[Tune] Add export_formats option to export policy graphs (#3868)

In earlier PRs, PR#3585 and PR#3637, export_policy_model and export_policy_checkpoint were introduced for users to export TensorFlow model and checkpoint.

For Ray Tune users, these APIs are not accessible through YAML configurations.

In this pull request, export_formats option is provided to enable users to choose the desired export format.
This commit is contained in:
Tianming Xu
2019-02-01 09:07:27 +08:00
committed by Richard Liaw
parent b9eed2e86c
commit 1302fafc0b
11 changed files with 174 additions and 7 deletions
+15 -1
View File
@@ -22,7 +22,7 @@ from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources
from ray.tune.trial import Resources, ExportFormat
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
@@ -602,6 +602,20 @@ class Agent(Trainable):
input_evaluation_method=config["input_evaluation"],
output_creator=output_creator)
@override(Trainable)
def _export_model(self, export_formats, export_dir):
ExportFormat.validate(export_formats)
exported = {}
if ExportFormat.CHECKPOINT in export_formats:
path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
self.export_policy_checkpoint(path)
exported[ExportFormat.CHECKPOINT] = path
if ExportFormat.MODEL in export_formats:
path = os.path.join(export_dir, ExportFormat.MODEL)
self.export_policy_model(path)
exported[ExportFormat.MODEL] = path
return exported
def __getstate__(self):
state = {}
if hasattr(self, "local_evaluator"):
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import os
import errno
import logging
import tensorflow as tf
import numpy as np
@@ -205,6 +206,12 @@ class TFPolicyGraph(PolicyGraph):
@override(PolicyGraph)
def export_checkpoint(self, export_dir, filename_prefix="model"):
"""Export tensorflow checkpoint to export_dir."""
try:
os.makedirs(export_dir)
except OSError as e:
# ignore error if export dir already exists
if e.errno != errno.EEXIST:
raise
save_path = os.path.join(export_dir, filename_prefix)
with self._sess.graph.as_default():
saver = tf.train.Saver()
@@ -10,6 +10,7 @@ import numpy as np
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.tune.trial import ExportFormat
def get_mean_action(alg, obs):
@@ -89,6 +90,15 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
def test_export(algo_name, failures):
def valid_tf_model(model_dir):
return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
and os.listdir(os.path.join(model_dir, "variables"))
def valid_tf_checkpoint(checkpoint_dir):
return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
cls = get_agent_class(algo_name)
if "DDPG" in algo_name:
algo = cls(config=CONFIGS[name], env="Pendulum-v0")
@@ -102,16 +112,22 @@ def test_export(algo_name, failures):
export_dir = "/tmp/export_dir_%s" % algo_name
print("Exporting model ", algo_name, export_dir)
algo.export_policy_model(export_dir)
if not os.path.exists(os.path.join(export_dir, "saved_model.pb")) \
or not os.listdir(os.path.join(export_dir, "variables")):
if not valid_tf_model(export_dir):
failures.append(algo_name)
shutil.rmtree(export_dir)
print("Exporting checkpoint", algo_name, export_dir)
algo.export_policy_checkpoint(export_dir)
if not os.path.exists(os.path.join(export_dir, "model.meta")) \
or not os.path.exists(os.path.join(export_dir, "model.index")) \
or not os.path.exists(os.path.join(export_dir, "checkpoint")):
if not valid_tf_checkpoint(export_dir):
failures.append(algo_name)
shutil.rmtree(export_dir)
print("Exporting default policy", algo_name, export_dir)
algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
export_dir)
if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
or not valid_tf_checkpoint(os.path.join(export_dir,
ExportFormat.CHECKPOINT)):
failures.append(algo_name)
shutil.rmtree(export_dir)
+7
View File
@@ -102,6 +102,12 @@ def make_parser(parser_creator=None, **kwargs):
action="store_true",
help="Whether to checkpoint at the end of the experiment. "
"Default is False.")
parser.add_argument(
"--export-formats",
default=None,
help="List of formats that exported at the end of the experiment. "
"Default is None. For RLlib, 'checkpoint' and 'model' are "
"supported for TensorFlow policy graphs.")
parser.add_argument(
"--max-failures",
default=3,
@@ -181,6 +187,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
export_formats=spec.get("export_formats", []),
# str(None) doesn't create None
restore_path=spec.get("restore"),
upload_dir=args.upload_dir,
+4
View File
@@ -72,6 +72,8 @@ class Experiment(object):
checkpoints. A value of 0 (default) disables checkpointing.
checkpoint_at_end (bool): Whether to checkpoint at the end of the
experiment regardless of the checkpoint_freq. Default is False.
export_formats (list): List of formats that exported at the end of
the experiment. Default is None.
max_failures (int): Try to recover a trial from its last
checkpoint at least this many times. Only applies if
checkpointing is enabled. Setting to -1 will lead to infinite
@@ -119,6 +121,7 @@ class Experiment(object):
sync_function=None,
checkpoint_freq=0,
checkpoint_at_end=False,
export_formats=None,
max_failures=3,
restore=None,
repeat=None,
@@ -146,6 +149,7 @@ class Experiment(object):
"sync_function": sync_function,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"export_formats": export_formats or [],
"max_failures": max_failures,
"restore": restore
}
+11
View File
@@ -346,3 +346,14 @@ class RayTrialExecutor(TrialExecutor):
logger.exception("Error restoring runner for Trial %s.", trial)
self.set_status(trial, Trial.ERROR)
return False
def export_trial_if_needed(self, trial):
"""Exports model of this trial based on trial.export_formats.
Return:
A dict that maps ExportFormats to successfully exported models.
"""
if trial.export_formats and len(trial.export_formats) > 0:
return ray.get(
trial.runner.export_model.remote(trial.export_formats))
return {}
+42 -1
View File
@@ -23,7 +23,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
from ray.tune.logger import Logger
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, Resources
from ray.tune.trial import Trial, Resources, ExportFormat
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import grid_search, BasicVariantGenerator
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
@@ -679,6 +679,47 @@ class RunExperimentTest(unittest.TestCase):
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(trial.has_checkpoint())
def testExportFormats(self):
class train(Trainable):
def _train(self):
return {"timesteps_this_iter": 1, "done": True}
def _export_model(self, export_formats, export_dir):
path = export_dir + "/exported"
with open(path, "w") as f:
f.write("OK")
return {export_formats[0]: path}
trials = run_experiments({
"foo": {
"run": train,
"export_formats": ["format"]
}
})
for trial in trials:
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertTrue(
os.path.exists(os.path.join(trial.logdir, "exported")))
def testInvalidExportFormats(self):
class train(Trainable):
def _train(self):
return {"timesteps_this_iter": 1, "done": True}
def _export_model(self, export_formats, export_dir):
ExportFormat.validate(export_formats)
return {}
def fail_trial():
run_experiments({
"foo": {
"run": train,
"export_formats": ["format"]
}
})
self.assertRaises(TuneError, fail_trial)
def testDeprecatedResources(self):
class train(Trainable):
def _train(self):
+29
View File
@@ -331,6 +331,23 @@ class Trainable(object):
self.restore(checkpoint_path)
shutil.rmtree(tmpdir)
def export_model(self, export_formats, export_dir=None):
"""Exports model based on export_formats.
Subclasses should override _export_model() to actually
export model to local directory.
Args:
export_formats (list): List of formats that should be exported.
export_dir (str): Optional dir to place the exported model.
Defaults to self.logdir.
Return:
A dict that maps ExportFormats to successfully exported models.
"""
export_dir = export_dir or self.logdir
return self._export_model(export_formats, export_dir)
def reset_config(self, new_config):
"""Resets configuration without restarting the trial.
@@ -402,6 +419,18 @@ class Trainable(object):
"""Subclasses should override this for any cleanup on stop."""
pass
def _export_model(self, export_formats, export_dir):
"""Subclasses should override this to export model.
Args:
export_formats (list): List of formats that should be exported.
export_dir (str): Directory to place exported models.
Return:
A dict that maps ExportFormats to successfully exported models.
"""
return {}
def wrap_function(train_func):
from ray.tune.function_runner import FunctionRunner
+25
View File
@@ -133,6 +133,29 @@ class Checkpoint(object):
return Checkpoint(Checkpoint.MEMORY, value)
class ExportFormat(object):
"""Describes the format to export the trial Trainable.
This may correspond to different file formats based on the
Trainable implementation.
"""
CHECKPOINT = "checkpoint"
MODEL = "model"
@staticmethod
def validate(export_formats):
"""Validates export_formats.
Raises:
ValueError if the format is unknown.
"""
for export_format in export_formats:
if export_format not in [
ExportFormat.CHECKPOINT, ExportFormat.MODEL
]:
raise TuneError("Unsupported export format: " + export_format)
class Trial(object):
"""A trial object holds the state for one model training run.
@@ -159,6 +182,7 @@ class Trial(object):
stopping_criterion=None,
checkpoint_freq=0,
checkpoint_at_end=False,
export_formats=None,
restore_path=None,
upload_dir=None,
trial_name_creator=None,
@@ -195,6 +219,7 @@ class Trial(object):
self.checkpoint_at_end = checkpoint_at_end
self._checkpoint = Checkpoint(
storage=Checkpoint.DISK, value=restore_path)
self.export_formats = export_formats
self.status = Trial.PENDING
self.logdir = None
self.runner = None
+12
View File
@@ -205,3 +205,15 @@ class TrialExecutor(object):
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"save() method")
def export_trial_if_needed(self, trial):
"""Exports model of this trial based on trial.export_formats.
Args:
trial (Trial): The state of this trial to be saved.
Return:
A dict that maps ExportFormats to successfully exported models.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"export_trial_if_needed() method")
+1
View File
@@ -404,6 +404,7 @@ class TrialRunner(object):
# Checkpoint before ending the trial
# if checkpoint_at_end experiment option is set to True
self._checkpoint_trial_if_needed(trial)
self.trial_executor.export_trial_if_needed(trial)
self.trial_executor.stop_trial(trial)
else:
assert False, "Invalid scheduling decision: {}".format(