From b4f61dfd503e23df5ea08194037ea09bf5d99540 Mon Sep 17 00:00:00 2001 From: Tianming Xu Date: Thu, 27 Dec 2018 07:43:06 +0800 Subject: [PATCH] [rllib] Export policy model checkpoint (#3637) * Export policy model checkpoint * update comment --- python/ray/rllib/agents/agent.py | 20 +++++++++++ .../ray/rllib/evaluation/policy_evaluator.py | 7 ++++ python/ray/rllib/evaluation/policy_graph.py | 10 +++++- .../ray/rllib/evaluation/tf_policy_graph.py | 9 +++++ .../examples/export/cartpole_dqn_export.py | 35 ++++++++++++++----- .../ray/rllib/test/test_checkpoint_restore.py | 8 +++++ 6 files changed, 80 insertions(+), 9 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index f6d00671d..6f151b8e0 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -443,6 +443,26 @@ class Agent(Trainable): """ self.local_evaluator.export_policy_model(export_dir, policy_id) + def export_policy_checkpoint(self, + export_dir, + filename_prefix="model", + policy_id=DEFAULT_POLICY_ID): + """Export tensorflow policy model checkpoint to local directory. + + Arguments: + export_dir (string): Writable local directory. + filename_prefix (string): file name prefix of checkpoint files. + policy_id (string): Optional policy id to export. 
+ + Example: + >>> agent = MyAgent() + >>> for _ in range(10): + >>> agent.train() + >>> agent.export_policy_checkpoint("/tmp/export_dir") + """ + self.local_evaluator.export_policy_checkpoint( + export_dir, filename_prefix, policy_id) + @classmethod def resource_help(cls, config): return ("\n\nYou can adjust the resource requests of RLlib agents by " diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 0e52e1efe..f5c250aa9 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -556,6 +556,13 @@ class PolicyEvaluator(EvaluatorInterface): def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): self.policy_map[policy_id].export_model(export_dir) + def export_policy_checkpoint(self, + export_dir, + filename_prefix="model", + policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].export_checkpoint(export_dir, + filename_prefix) + def _build_policy_map(self, policy_dict, policy_config): policy_map = {} preprocessors = {} diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index 39a569fd7..fc4be5706 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -205,6 +205,14 @@ class PolicyGraph(object): """Export PolicyGraph to local directory for serving. Arguments: - export_dir (str): Local writable directory + export_dir (str): Local writable directory. + """ + raise NotImplementedError + + def export_checkpoint(self, export_dir): + """Export PolicyGraph checkpoint to local directory. + + Arguments: + export_dir (str): Local writable directory. 
""" raise NotImplementedError diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 3e1fcd0a3..7574864c9 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import logging import tensorflow as tf import numpy as np @@ -199,6 +200,14 @@ class TFPolicyGraph(PolicyGraph): signature_def_map=signature_def_map) builder.save() + @override(PolicyGraph) + def export_checkpoint(self, export_dir, filename_prefix="model"): + """Export tensorflow checkpoint to export_dir.""" + save_path = os.path.join(export_dir, filename_prefix) + with self._sess.graph.as_default(): + saver = tf.train.Saver() + saver.save(self._sess, save_path) + def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders. diff --git a/python/ray/rllib/examples/export/cartpole_dqn_export.py b/python/ray/rllib/examples/export/cartpole_dqn_export.py index a7125e431..6bfcae060 100644 --- a/python/ray/rllib/examples/export/cartpole_dqn_export.py +++ b/python/ray/rllib/examples/export/cartpole_dqn_export.py @@ -4,24 +4,28 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import ray import tensorflow as tf -from ray.rllib.agents.agent import get_agent_class +from ray.rllib.agents.registry import get_agent_class ray.init(num_cpus=10) -def train_and_export(algo_name, num_steps, export_dir): +def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix): cls = get_agent_class(algo_name) alg = cls(config={}, env="CartPole-v0") - for _ in range(3): + for _ in range(num_steps): alg.train() - alg.export_policy_model(export_dir) + # Export tensorflow checkpoint for fine-tuning + alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix) + # Export 
tensorflow SavedModel for online serving + alg.export_policy_model(model_dir) -def restore(export_dir): +def restore_saved_model(export_dir): signature_key = \ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY g = tf.Graph() @@ -38,9 +42,24 @@ def restore(export_dir): print("https://www.tensorflow.org/guide/saved_model") +def restore_checkpoint(export_dir, prefix): + sess = tf.Session() + meta_file = "%s.meta" % prefix + saver = tf.train.import_meta_graph(os.path.join(export_dir, meta_file)) + saver.restore(sess, os.path.join(export_dir, prefix)) + print("Checkpoint restored!") + print("Variables Information:") + for v in tf.trainable_variables(): + value = sess.run(v) + print(v.name, value) + + if __name__ == "__main__": algo = "DQN" - export_dir = "/tmp/export_dir" + model_dir = "/tmp/model_export_dir" + ckpt_dir = "/tmp/ckpt_export_dir" + prefix = "model.ckpt" num_steps = 3 - train_and_export(algo, num_steps, export_dir) - restore(export_dir) + train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix) + restore_saved_model(model_dir) + restore_checkpoint(ckpt_dir, prefix) diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index f3b4aa623..926c8573c 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -107,6 +107,14 @@ def test_export(algo_name, failures): failures.append(algo_name) shutil.rmtree(export_dir) + print("Exporting checkpoint", algo_name, export_dir) + algo.export_policy_checkpoint(export_dir) + if not os.path.exists(os.path.join(export_dir, "model.meta")) \ + or not os.path.exists(os.path.join(export_dir, "model.index")) \ + or not os.path.exists(os.path.join(export_dir, "checkpoint")): + failures.append(algo_name) + shutil.rmtree(export_dir) + if __name__ == "__main__": failures = []