diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 8122e3491..45e2c1539 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -22,7 +22,7 @@ from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils import FilterManager, deep_update, merge_dicts from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.trainable import Trainable -from ray.tune.trial import Resources +from ray.tune.trial import Resources, ExportFormat from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR @@ -602,6 +602,20 @@ class Agent(Trainable): input_evaluation_method=config["input_evaluation"], output_creator=output_creator) + @override(Trainable) + def _export_model(self, export_formats, export_dir): + ExportFormat.validate(export_formats) + exported = {} + if ExportFormat.CHECKPOINT in export_formats: + path = os.path.join(export_dir, ExportFormat.CHECKPOINT) + self.export_policy_checkpoint(path) + exported[ExportFormat.CHECKPOINT] = path + if ExportFormat.MODEL in export_formats: + path = os.path.join(export_dir, ExportFormat.MODEL) + self.export_policy_model(path) + exported[ExportFormat.MODEL] = path + return exported + def __getstate__(self): state = {} if hasattr(self, "local_evaluator"): diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index fa5654161..fd1dc273e 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import os +import errno import logging import tensorflow as tf import numpy as np @@ -205,6 +206,12 @@ class TFPolicyGraph(PolicyGraph): @override(PolicyGraph) def export_checkpoint(self, export_dir, filename_prefix="model"): """Export tensorflow checkpoint to export_dir.""" + try: + os.makedirs(export_dir) + except OSError as e: + # ignore error if export dir already exists + if e.errno != errno.EEXIST: + raise save_path = os.path.join(export_dir, filename_prefix) with self._sess.graph.as_default(): saver = tf.train.Saver() diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index 926c8573c..e1ad5a5a9 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -10,6 +10,7 @@ import numpy as np import ray from ray.rllib.agents.registry import get_agent_class +from ray.tune.trial import ExportFormat def get_mean_action(alg, obs): @@ -89,6 +90,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures): def test_export(algo_name, failures): + def valid_tf_model(model_dir): + return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \ + and os.listdir(os.path.join(model_dir, "variables")) + + def valid_tf_checkpoint(checkpoint_dir): + return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \ + and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \ + and os.path.exists(os.path.join(checkpoint_dir, "checkpoint")) + cls = get_agent_class(algo_name) if "DDPG" in algo_name: algo = cls(config=CONFIGS[name], env="Pendulum-v0") @@ -102,16 +112,22 @@ def test_export(algo_name, failures): export_dir = "/tmp/export_dir_%s" % algo_name print("Exporting model ", algo_name, export_dir) algo.export_policy_model(export_dir) - if not os.path.exists(os.path.join(export_dir, "saved_model.pb")) \ - or not os.listdir(os.path.join(export_dir, "variables")): + if not valid_tf_model(export_dir): failures.append(algo_name) shutil.rmtree(export_dir) print("Exporting checkpoint", algo_name, export_dir) algo.export_policy_checkpoint(export_dir) - if not os.path.exists(os.path.join(export_dir, "model.meta")) \ - or not os.path.exists(os.path.join(export_dir, "model.index")) \ - or not os.path.exists(os.path.join(export_dir, "checkpoint")): + if not valid_tf_checkpoint(export_dir): + failures.append(algo_name) + shutil.rmtree(export_dir) + + print("Exporting default policy", algo_name, export_dir) + algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL], + export_dir) + if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \ + or not valid_tf_checkpoint(os.path.join(export_dir, + ExportFormat.CHECKPOINT)): failures.append(algo_name) shutil.rmtree(export_dir) diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index aa0caa437..e44f5de81 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -102,6 +102,12 @@ def make_parser(parser_creator=None, **kwargs): action="store_true", help="Whether to checkpoint at the end of the experiment. " "Default is False.") + parser.add_argument( + "--export-formats", + default=None, + help="List of formats that exported at the end of the experiment. " + "Default is None. For RLlib, 'checkpoint' and 'model' are " + "supported for TensorFlow policy graphs.") parser.add_argument( "--max-failures", default=3, @@ -181,6 +187,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): stopping_criterion=spec.get("stop", {}), checkpoint_freq=args.checkpoint_freq, checkpoint_at_end=args.checkpoint_at_end, + export_formats=spec.get("export_formats", []), # str(None) doesn't create None restore_path=spec.get("restore"), upload_dir=args.upload_dir, diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 31a1ce7a8..ce52c3489 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -72,6 +72,8 @@ class Experiment(object): checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the experiment regardless of the checkpoint_freq. Default is False. + export_formats (list): List of formats that exported at the end of + the experiment. Default is None. max_failures (int): Try to recover a trial from its last checkpoint at least this many times. Only applies if checkpointing is enabled. Setting to -1 will lead to infinite @@ -119,6 +121,7 @@ class Experiment(object): sync_function=None, checkpoint_freq=0, checkpoint_at_end=False, + export_formats=None, max_failures=3, restore=None, repeat=None, @@ -146,6 +149,7 @@ class Experiment(object): "sync_function": sync_function, "checkpoint_freq": checkpoint_freq, "checkpoint_at_end": checkpoint_at_end, + "export_formats": export_formats or [], "max_failures": max_failures, "restore": restore } diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index a6072d643..ff672a329 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -346,3 +346,14 @@ class RayTrialExecutor(TrialExecutor): logger.exception("Error restoring runner for Trial %s.", trial) self.set_status(trial, Trial.ERROR) return False + + def export_trial_if_needed(self, trial): + """Exports model of this trial based on trial.export_formats. + + Return: + A dict that maps ExportFormats to successfully exported models. + """ + if trial.export_formats and len(trial.export_formats) > 0: + return ray.get( + trial.runner.export_model.remote(trial.export_formats)) + return {} diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 56e638a48..735b06594 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -23,7 +23,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, from ray.tune.logger import Logger from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment -from ray.tune.trial import Trial, Resources +from ray.tune.trial import Trial, Resources, ExportFormat from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import grid_search, BasicVariantGenerator from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm, @@ -679,6 +679,47 @@ class RunExperimentTest(unittest.TestCase): self.assertEqual(trial.status, Trial.TERMINATED) self.assertTrue(trial.has_checkpoint()) + def testExportFormats(self): + class train(Trainable): + def _train(self): + return {"timesteps_this_iter": 1, "done": True} + + def _export_model(self, export_formats, export_dir): + path = export_dir + "/exported" + with open(path, "w") as f: + f.write("OK") + return {export_formats[0]: path} + + trials = run_experiments({ + "foo": { + "run": train, + "export_formats": ["format"] + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue( + os.path.exists(os.path.join(trial.logdir, "exported"))) + + def testInvalidExportFormats(self): + class train(Trainable): + def _train(self): + return {"timesteps_this_iter": 1, "done": True} + + def _export_model(self, export_formats, export_dir): + ExportFormat.validate(export_formats) + return {} + + def fail_trial(): + run_experiments({ + "foo": { + "run": train, + "export_formats": ["format"] + } + }) + + self.assertRaises(TuneError, fail_trial) + def testDeprecatedResources(self): class train(Trainable): def _train(self): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 931fa5220..581085fd1 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -331,6 +331,23 @@ class Trainable(object): self.restore(checkpoint_path) shutil.rmtree(tmpdir) + def export_model(self, export_formats, export_dir=None): + """Exports model based on export_formats. + + Subclasses should override _export_model() to actually + export model to local directory. + + Args: + export_formats (list): List of formats that should be exported. + export_dir (str): Optional dir to place the exported model. + Defaults to self.logdir. + + Return: + A dict that maps ExportFormats to successfully exported models. + """ + export_dir = export_dir or self.logdir + return self._export_model(export_formats, export_dir) + def reset_config(self, new_config): """Resets configuration without restarting the trial. @@ -402,6 +419,18 @@ class Trainable(object): """Subclasses should override this for any cleanup on stop.""" pass + def _export_model(self, export_formats, export_dir): + """Subclasses should override this to export model. + + Args: + export_formats (list): List of formats that should be exported. + export_dir (str): Directory to place exported models. + + Return: + A dict that maps ExportFormats to successfully exported models. + """ + return {} + def wrap_function(train_func): from ray.tune.function_runner import FunctionRunner diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 12693a6e3..36057ac2a 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -133,6 +133,29 @@ class Checkpoint(object): return Checkpoint(Checkpoint.MEMORY, value) +class ExportFormat(object): + """Describes the format to export the trial Trainable. + + This may correspond to different file formats based on the + Trainable implementation. + """ + CHECKPOINT = "checkpoint" + MODEL = "model" + + @staticmethod + def validate(export_formats): + """Validates export_formats. + + Raises: + ValueError if the format is unknown. + """ + for export_format in export_formats: + if export_format not in [ + ExportFormat.CHECKPOINT, ExportFormat.MODEL + ]: + raise TuneError("Unsupported export format: " + export_format) + + class Trial(object): """A trial object holds the state for one model training run. @@ -159,6 +182,7 @@ class Trial(object): stopping_criterion=None, checkpoint_freq=0, checkpoint_at_end=False, + export_formats=None, restore_path=None, upload_dir=None, trial_name_creator=None, @@ -195,6 +219,7 @@ class Trial(object): self.checkpoint_at_end = checkpoint_at_end self._checkpoint = Checkpoint( storage=Checkpoint.DISK, value=restore_path) + self.export_formats = export_formats self.status = Trial.PENDING self.logdir = None self.runner = None diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 22d6d85eb..8bb944bec 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -205,3 +205,15 @@ class TrialExecutor(object): """ raise NotImplementedError("Subclasses of TrialExecutor must provide " "save() method") + + def export_trial_if_needed(self, trial): + """Exports model of this trial based on trial.export_formats. + + Args: + trial (Trial): The state of this trial to be saved. + + Return: + A dict that maps ExportFormats to successfully exported models. + """ + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "export_trial_if_needed() method") diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index ce3c648d2..4f11d5226 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -404,6 +404,7 @@ class TrialRunner(object): # Checkpoint before ending the trial # if checkpoint_at_end experiment option is set to True self._checkpoint_trial_if_needed(trial) + self.trial_executor.export_trial_if_needed(trial) self.trial_executor.stop_trial(trial) else: assert False, "Invalid scheduling decision: {}".format(