From deb26b954e56b17c5c4e4f61a37d20b9ce7352eb Mon Sep 17 00:00:00 2001 From: Tianming Xu Date: Sat, 22 Dec 2018 16:35:25 +0800 Subject: [PATCH] [rllib] Export tensorflow model of policy graph (#3585) * Export tensorflow model of policy graph * Add tests,examples,pydocs and infer extra signatures from existing methods * Add example usage in export_policy_model comment * Fix lint error * Fix lint error * Fix lint error --- python/ray/rllib/agents/agent.py | 16 +++++ .../ray/rllib/evaluation/policy_evaluator.py | 3 + python/ray/rllib/evaluation/policy_graph.py | 8 +++ .../ray/rllib/evaluation/tf_policy_graph.py | 71 +++++++++++++++++++ .../examples/export/cartpole_dqn_export.py | 46 ++++++++++++ .../ray/rllib/test/test_checkpoint_restore.py | 32 ++++++++- 6 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 python/ray/rllib/examples/export/cartpole_dqn_export.py diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 0dd93ef74..f6d00671d 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -16,6 +16,7 @@ import ray from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.annotations import override from ray.rllib.utils import FilterManager, deep_update, merge_dicts @@ -427,6 +428,21 @@ class Agent(Trainable): self.config) for i in range(count) ] + def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): + """Export policy model with given policy_id to local directory. + + Arguments: + export_dir (string): Writable local directory. + policy_id (string): Optional policy id to export. + + Example: + >>> agent = MyAgent() + >>> for _ in range(10): + >>> agent.train() + >>> agent.export_policy_model("/tmp/export_dir") + """ + self.local_evaluator.export_policy_model(export_dir, policy_id) + @classmethod def resource_help(cls, config): return ("\n\nYou can adjust the resource requests of RLlib agents by " diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index f2e78551f..0e52e1efe 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -553,6 +553,9 @@ class PolicyEvaluator(EvaluatorInterface): def set_global_vars(self, global_vars): self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars)) + def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].export_model(export_dir) + def _build_policy_map(self, policy_dict, policy_config): policy_map = {} preprocessors = {} diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index ef108a273..39a569fd7 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -200,3 +200,11 @@ class PolicyGraph(object): global_vars (dict): Global variables broadcast from the driver. """ pass + + def export_model(self, export_dir): + """Export PolicyGraph to local directory for serving. + + Arguments: + export_dir (str): Local writable directory + """ + raise NotImplementedError diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index a9064b7c0..3e1fcd0a3 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -188,6 +188,17 @@ class TFPolicyGraph(PolicyGraph): def set_weights(self, weights): return self._variables.set_flat(weights) + @override(PolicyGraph) + def export_model(self, export_dir): + """Export tensorflow graph to export_dir for serving.""" + with self._sess.graph.as_default(): + builder = tf.saved_model.builder.SavedModelBuilder(export_dir) + signature_def_map = self._build_signature_def() + builder.add_meta_graph_and_variables( + self._sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map) + builder.save() + def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders. @@ -218,6 +229,26 @@ class TFPolicyGraph(PolicyGraph): """Extra values to fetch and return from apply_gradients().""" return {} # e.g., batch norm updates + def _extra_input_signature_def(self): + """Extra input signatures to add when exporting tf model. + Inferred from extra_compute_action_feed_dict() + """ + feed_dict = self.extra_compute_action_feed_dict() + return { + k.name: tf.saved_model.utils.build_tensor_info(k) + for k in feed_dict.keys() + } + + def _extra_output_signature_def(self): + """Extra output signatures to add when exporting tf model. + Inferred from extra_compute_action_fetches() + """ + fetches = self.extra_compute_action_fetches() + return { + k: tf.saved_model.utils.build_tensor_info(fetches[k]) + for k in fetches.keys() + } + def optimizer(self): """TF optimizer to use for policy optimization.""" return tf.train.AdamOptimizer() @@ -226,6 +257,46 @@ class TFPolicyGraph(PolicyGraph): """Override for custom gradient computation.""" return optimizer.compute_gradients(self._loss) + def _build_signature_def(self): + """Build signature def map for tensorflow SavedModelBuilder. + """ + # build input signatures + input_signature = self._extra_input_signature_def() + input_signature["observations"] = \ + tf.saved_model.utils.build_tensor_info(self._obs_input) + + if self._seq_lens is not None: + input_signature["seq_lens"] = \ + tf.saved_model.utils.build_tensor_info(self._seq_lens) + if self._prev_action_input is not None: + input_signature["prev_action"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_action_input) + if self._prev_reward_input is not None: + input_signature["prev_reward"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_reward_input) + input_signature["is_training"] = \ + tf.saved_model.utils.build_tensor_info(self._is_training) + + for state_input in self._state_inputs: + input_signature[state_input.name] = \ + tf.saved_model.utils.build_tensor_info(state_input) + + # build output signatures + output_signature = self._extra_output_signature_def() + output_signature["actions"] = \ + tf.saved_model.utils.build_tensor_info(self._sampler) + for state_output in self._state_outputs: + output_signature[state_output.name] = \ + tf.saved_model.utils.build_tensor_info(state_output) + signature_def = ( + tf.saved_model.signature_def_utils.build_signature_def( + input_signature, output_signature, + tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) + signature_def_key = \ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501 + signature_def_map = {signature_def_key: signature_def} + return signature_def_map + def _build_compute_actions(self, builder, obs_batch, diff --git a/python/ray/rllib/examples/export/cartpole_dqn_export.py b/python/ray/rllib/examples/export/cartpole_dqn_export.py new file mode 100644 index 000000000..a7125e431 --- /dev/null +++ b/python/ray/rllib/examples/export/cartpole_dqn_export.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray +import tensorflow as tf + +from ray.rllib.agents.agent import get_agent_class + +ray.init(num_cpus=10) + + +def train_and_export(algo_name, num_steps, export_dir): + cls = get_agent_class(algo_name) + alg = cls(config={}, env="CartPole-v0") + for _ in range(3): + alg.train() + + alg.export_policy_model(export_dir) + + +def restore(export_dir): + signature_key = \ + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + g = tf.Graph() + with g.as_default(): + with tf.Session(graph=g) as sess: + meta_graph_def = \ + tf.saved_model.load(sess, + [tf.saved_model.tag_constants.SERVING], + export_dir) + print("Model restored!") + print("Signature Def Information:") + print(meta_graph_def.signature_def[signature_key]) + print("You can inspect the model using TensorFlow SavedModel CLI.") + print("https://www.tensorflow.org/guide/saved_model") + + +if __name__ == "__main__": + algo = "DQN" + export_dir = "/tmp/export_dir" + num_steps = 3 + train_and_export(algo, num_steps, export_dir) + restore(export_dir) diff --git a/python/ray/rllib/test/test_checkpoint_restore.py b/python/ray/rllib/test/test_checkpoint_restore.py index 75d3a0136..f3b4aa623 100644 --- a/python/ray/rllib/test/test_checkpoint_restore.py +++ b/python/ray/rllib/test/test_checkpoint_restore.py @@ -4,6 +4,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import shutil import numpy as np import ray @@ -55,7 +57,7 @@ CONFIGS = { } -def test(use_object_store, alg_name, failures): +def test_ckpt_restore(use_object_store, alg_name, failures): cls = get_agent_class(alg_name) if "DDPG" in alg_name: alg1 = cls(config=CONFIGS[name], env="Pendulum-v0") @@ -86,11 +88,37 @@ def test(use_object_store, alg_name, failures): failures.append((alg_name, [a1, a2])) +def test_export(algo_name, failures): + cls = get_agent_class(algo_name) + if "DDPG" in algo_name: + algo = cls(config=CONFIGS[name], env="Pendulum-v0") + else: + algo = cls(config=CONFIGS[name], env="CartPole-v0") + + for _ in range(3): + res = algo.train() + print("current status: " + str(res)) + + export_dir = "/tmp/export_dir_%s" % algo_name + print("Exporting model ", algo_name, export_dir) + algo.export_policy_model(export_dir) + if not os.path.exists(os.path.join(export_dir, "saved_model.pb")) \ + or not os.listdir(os.path.join(export_dir, "variables")): + failures.append(algo_name) + shutil.rmtree(export_dir) + + if __name__ == "__main__": failures = [] for use_object_store in [False, True]: for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]: - test(use_object_store, name, failures) + test_ckpt_restore(use_object_store, name, failures) assert not failures, failures print("All checkpoint restore tests passed!") + + failures = [] + for name in ["DQN", "DDPG", "PPO", "A3C"]: + test_export(name, failures) + assert not failures, failures + print("All export tests passed!")