mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:46:10 +08:00
[Tune] Add export_formats option to export policy graphs (#3868)
Earlier PRs (#3585 and #3637) introduced export_policy_model and export_policy_checkpoint so that users can export a TensorFlow model and checkpoint. For Ray Tune users, however, these APIs are not accessible through YAML configurations. This pull request adds an export_formats option that lets users choose the desired export formats.
This commit is contained in:
committed by
Richard Liaw
parent
b9eed2e86c
commit
1302fafc0b
@@ -22,7 +22,7 @@ from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources
|
||||
from ray.tune.trial import Resources, ExportFormat
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
@@ -602,6 +602,20 @@ class Agent(Trainable):
|
||||
input_evaluation_method=config["input_evaluation"],
|
||||
output_creator=output_creator)
|
||||
|
||||
@override(Trainable)
def _export_model(self, export_formats, export_dir):
    """Export the policy in each requested format below export_dir.

    Every requested format is written to its own sub-directory named
    after the format, i.e. ``<export_dir>/<format>``.

    Args:
        export_formats (list): Formats to export; checked with
            ExportFormat.validate() before anything is written.
        export_dir (str): Directory under which the exports are placed.

    Return:
        A dict that maps each exported format to the path it was
        written to.
    """
    ExportFormat.validate(export_formats)

    # (format, name of the method that performs the export), in the
    # order the exports are carried out. The method is looked up only
    # when its format was actually requested.
    exporters = (
        (ExportFormat.CHECKPOINT, "export_policy_checkpoint"),
        (ExportFormat.MODEL, "export_policy_model"),
    )

    results = {}
    for fmt, method_name in exporters:
        if fmt not in export_formats:
            continue
        target = os.path.join(export_dir, fmt)
        getattr(self, method_name)(target)
        results[fmt] = target
    return results
|
||||
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
if hasattr(self, "local_evaluator"):
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import errno
|
||||
import logging
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
@@ -205,6 +206,12 @@ class TFPolicyGraph(PolicyGraph):
|
||||
@override(PolicyGraph)
|
||||
def export_checkpoint(self, export_dir, filename_prefix="model"):
|
||||
"""Export tensorflow checkpoint to export_dir."""
|
||||
try:
|
||||
os.makedirs(export_dir)
|
||||
except OSError as e:
|
||||
# ignore error if export dir already exists
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
save_path = os.path.join(export_dir, filename_prefix)
|
||||
with self._sess.graph.as_default():
|
||||
saver = tf.train.Saver()
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import ray
|
||||
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.tune.trial import ExportFormat
|
||||
|
||||
|
||||
def get_mean_action(alg, obs):
|
||||
@@ -89,6 +90,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
|
||||
|
||||
|
||||
def test_export(algo_name, failures):
|
||||
def valid_tf_model(model_dir):
|
||||
return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
|
||||
and os.listdir(os.path.join(model_dir, "variables"))
|
||||
|
||||
def valid_tf_checkpoint(checkpoint_dir):
|
||||
return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
|
||||
and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
|
||||
and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
|
||||
|
||||
cls = get_agent_class(algo_name)
|
||||
if "DDPG" in algo_name:
|
||||
algo = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
@@ -102,16 +112,22 @@ def test_export(algo_name, failures):
|
||||
export_dir = "/tmp/export_dir_%s" % algo_name
|
||||
print("Exporting model ", algo_name, export_dir)
|
||||
algo.export_policy_model(export_dir)
|
||||
if not os.path.exists(os.path.join(export_dir, "saved_model.pb")) \
|
||||
or not os.listdir(os.path.join(export_dir, "variables")):
|
||||
if not valid_tf_model(export_dir):
|
||||
failures.append(algo_name)
|
||||
shutil.rmtree(export_dir)
|
||||
|
||||
print("Exporting checkpoint", algo_name, export_dir)
|
||||
algo.export_policy_checkpoint(export_dir)
|
||||
if not os.path.exists(os.path.join(export_dir, "model.meta")) \
|
||||
or not os.path.exists(os.path.join(export_dir, "model.index")) \
|
||||
or not os.path.exists(os.path.join(export_dir, "checkpoint")):
|
||||
if not valid_tf_checkpoint(export_dir):
|
||||
failures.append(algo_name)
|
||||
shutil.rmtree(export_dir)
|
||||
|
||||
print("Exporting default policy", algo_name, export_dir)
|
||||
algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
|
||||
export_dir)
|
||||
if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
|
||||
or not valid_tf_checkpoint(os.path.join(export_dir,
|
||||
ExportFormat.CHECKPOINT)):
|
||||
failures.append(algo_name)
|
||||
shutil.rmtree(export_dir)
|
||||
|
||||
|
||||
@@ -102,6 +102,12 @@ def make_parser(parser_creator=None, **kwargs):
|
||||
action="store_true",
|
||||
help="Whether to checkpoint at the end of the experiment. "
|
||||
"Default is False.")
|
||||
parser.add_argument(
|
||||
"--export-formats",
|
||||
default=None,
|
||||
help="List of formats that exported at the end of the experiment. "
|
||||
"Default is None. For RLlib, 'checkpoint' and 'model' are "
|
||||
"supported for TensorFlow policy graphs.")
|
||||
parser.add_argument(
|
||||
"--max-failures",
|
||||
default=3,
|
||||
@@ -181,6 +187,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
stopping_criterion=spec.get("stop", {}),
|
||||
checkpoint_freq=args.checkpoint_freq,
|
||||
checkpoint_at_end=args.checkpoint_at_end,
|
||||
export_formats=spec.get("export_formats", []),
|
||||
# str(None) doesn't create None
|
||||
restore_path=spec.get("restore"),
|
||||
upload_dir=args.upload_dir,
|
||||
|
||||
@@ -72,6 +72,8 @@ class Experiment(object):
|
||||
checkpoints. A value of 0 (default) disables checkpointing.
|
||||
checkpoint_at_end (bool): Whether to checkpoint at the end of the
|
||||
experiment regardless of the checkpoint_freq. Default is False.
|
||||
export_formats (list): List of formats that exported at the end of
|
||||
the experiment. Default is None.
|
||||
max_failures (int): Try to recover a trial from its last
|
||||
checkpoint at least this many times. Only applies if
|
||||
checkpointing is enabled. Setting to -1 will lead to infinite
|
||||
@@ -119,6 +121,7 @@ class Experiment(object):
|
||||
sync_function=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
export_formats=None,
|
||||
max_failures=3,
|
||||
restore=None,
|
||||
repeat=None,
|
||||
@@ -146,6 +149,7 @@ class Experiment(object):
|
||||
"sync_function": sync_function,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
"export_formats": export_formats or [],
|
||||
"max_failures": max_failures,
|
||||
"restore": restore
|
||||
}
|
||||
|
||||
@@ -346,3 +346,14 @@ class RayTrialExecutor(TrialExecutor):
|
||||
logger.exception("Error restoring runner for Trial %s.", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
return False
|
||||
|
||||
def export_trial_if_needed(self, trial):
    """Export the trial's model in the formats listed on the trial.

    Args:
        trial: Trial whose ``export_formats`` are read and whose
            ``runner`` performs the export remotely.

    Return:
        A dict that maps ExportFormats to successfully exported models.
        An empty dict is returned when no export format is set.
    """
    requested = trial.export_formats
    # Nothing requested (None or empty): skip the remote call entirely.
    if not requested or len(requested) == 0:
        return {}
    pending = trial.runner.export_model.remote(trial.export_formats)
    return ray.get(pending)
|
||||
|
||||
@@ -23,7 +23,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial import Trial, Resources, ExportFormat
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import grid_search, BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
|
||||
@@ -679,6 +679,47 @@ class RunExperimentTest(unittest.TestCase):
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertTrue(trial.has_checkpoint())
|
||||
|
||||
def testExportFormats(self):
    """A trial with export_formats set writes its export into its logdir."""

    class train(Trainable):
        def _train(self):
            # Finish after a single iteration.
            return {"timesteps_this_iter": 1, "done": True}

        def _export_model(self, export_formats, export_dir):
            # Write a marker file and report it for the first format.
            path = export_dir + "/exported"
            with open(path, "w") as marker:
                marker.write("OK")
            return {export_formats[0]: path}

    experiment_spec = {
        "foo": {
            "run": train,
            "export_formats": ["format"]
        }
    }
    finished_trials = run_experiments(experiment_spec)
    for finished in finished_trials:
        self.assertEqual(finished.status, Trial.TERMINATED)
        marker_path = os.path.join(finished.logdir, "exported")
        self.assertTrue(os.path.exists(marker_path))
|
||||
|
||||
def testInvalidExportFormats(self):
    """An unknown export format makes run_experiments raise TuneError."""

    class train(Trainable):
        def _train(self):
            # Finish after a single iteration.
            return {"timesteps_this_iter": 1, "done": True}

        def _export_model(self, export_formats, export_dir):
            # "format" is not a known ExportFormat, so this raises.
            ExportFormat.validate(export_formats)
            return {}

    experiment_spec = {
        "foo": {
            "run": train,
            "export_formats": ["format"]
        }
    }
    with self.assertRaises(TuneError):
        run_experiments(experiment_spec)
|
||||
|
||||
def testDeprecatedResources(self):
|
||||
class train(Trainable):
|
||||
def _train(self):
|
||||
|
||||
@@ -331,6 +331,23 @@ class Trainable(object):
|
||||
self.restore(checkpoint_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def export_model(self, export_formats, export_dir=None):
    """Export the model in the given formats.

    The actual work is delegated to _export_model(), which subclasses
    override to write the model to a local directory.

    Args:
        export_formats (list): List of formats that should be exported.
        export_dir (str): Optional dir to place the exported model.
            Falls back to self.logdir when not given (or empty).

    Return:
        A dict that maps ExportFormats to successfully exported models.
    """
    target_dir = export_dir if export_dir else self.logdir
    return self._export_model(export_formats, target_dir)
|
||||
|
||||
def reset_config(self, new_config):
|
||||
"""Resets configuration without restarting the trial.
|
||||
|
||||
@@ -402,6 +419,18 @@ class Trainable(object):
|
||||
"""Subclasses should override this for any cleanup on stop."""
|
||||
pass
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
"""Subclasses should override this to export model.
|
||||
|
||||
Args:
|
||||
export_formats (list): List of formats that should be exported.
|
||||
export_dir (str): Directory to place exported models.
|
||||
|
||||
Return:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
from ray.tune.function_runner import FunctionRunner
|
||||
|
||||
@@ -133,6 +133,29 @@ class Checkpoint(object):
|
||||
return Checkpoint(Checkpoint.MEMORY, value)
|
||||
|
||||
|
||||
class ExportFormat(object):
    """Describes the format to export the trial Trainable.

    This may correspond to different file formats based on the
    Trainable implementation.
    """
    CHECKPOINT = "checkpoint"
    MODEL = "model"

    @staticmethod
    def validate(export_formats):
        """Validates export_formats.

        Args:
            export_formats (list): Formats to check; each entry must be
                ExportFormat.CHECKPOINT or ExportFormat.MODEL.

        Raises:
            TuneError if any of the formats is unknown.
        """
        supported = (ExportFormat.CHECKPOINT, ExportFormat.MODEL)
        for export_format in export_formats:
            if export_format not in supported:
                # Use format() rather than "+" so that a non-string entry
                # is reported as a TuneError, not a TypeError raised while
                # building the message.
                raise TuneError(
                    "Unsupported export format: {}".format(export_format))
|
||||
|
||||
|
||||
class Trial(object):
|
||||
"""A trial object holds the state for one model training run.
|
||||
|
||||
@@ -159,6 +182,7 @@ class Trial(object):
|
||||
stopping_criterion=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
export_formats=None,
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
@@ -195,6 +219,7 @@ class Trial(object):
|
||||
self.checkpoint_at_end = checkpoint_at_end
|
||||
self._checkpoint = Checkpoint(
|
||||
storage=Checkpoint.DISK, value=restore_path)
|
||||
self.export_formats = export_formats
|
||||
self.status = Trial.PENDING
|
||||
self.logdir = None
|
||||
self.runner = None
|
||||
|
||||
@@ -205,3 +205,15 @@ class TrialExecutor(object):
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"save() method")
|
||||
|
||||
def export_trial_if_needed(self, trial):
    """Exports model of this trial based on trial.export_formats.

    This base implementation always raises; subclasses of TrialExecutor
    must provide their own.

    Args:
        trial (Trial): The trial whose model should be exported.

    Return:
        A dict that maps ExportFormats to successfully exported models.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = ("Subclasses of TrialExecutor must provide "
               "export_trial_if_needed() method")
    raise NotImplementedError(message)
|
||||
|
||||
@@ -404,6 +404,7 @@ class TrialRunner(object):
|
||||
# Checkpoint before ending the trial
|
||||
# if checkpoint_at_end experiment option is set to True
|
||||
self._checkpoint_trial_if_needed(trial)
|
||||
self.trial_executor.export_trial_if_needed(trial)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
|
||||
Reference in New Issue
Block a user