mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[rllib] Export tensorflow model of policy graph (#3585)
* Export tensorflow model of policy graph
* Add tests, examples, pydocs and infer extra signatures from existing methods
* Add example usage in export_policy_model comment
* Fix lint error
* Fix lint error
* Fix lint error
This commit is contained in:
@@ -16,6 +16,7 @@ import ray
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
@@ -427,6 +428,21 @@ class Agent(Trainable):
|
||||
self.config) for i in range(count)
|
||||
]
|
||||
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Write the model of one policy to a local directory.

    The work is delegated to the local evaluator's export_policy_model().

    Arguments:
        export_dir (string): Writable local directory.
        policy_id (string): Optional policy id to export.

    Example:
        >>> agent = MyAgent()
        >>> for _ in range(10):
        ...     agent.train()
        >>> agent.export_policy_model("/tmp/export_dir")
    """
    evaluator = self.local_evaluator
    evaluator.export_policy_model(export_dir, policy_id)
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
return ("\n\nYou can adjust the resource requests of RLlib agents by "
|
||||
|
||||
@@ -553,6 +553,9 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
def set_global_vars(self, global_vars):
    """Pass ``global_vars`` to on_global_var_update() of every policy."""

    def notify(policy, _unused):
        # foreach_policy supplies a second argument that is not needed here.
        return policy.on_global_var_update(global_vars)

    self.foreach_policy(notify)
|
||||
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Export the model of the policy stored under ``policy_id``.

    Arguments:
        export_dir (string): Directory handed to the policy's export_model().
        policy_id (string): Key looked up in self.policy_map.
    """
    policy = self.policy_map[policy_id]
    policy.export_model(export_dir)
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
|
||||
policy_map = {}
|
||||
preprocessors = {}
|
||||
|
||||
@@ -200,3 +200,11 @@ class PolicyGraph(object):
|
||||
global_vars (dict): Global variables broadcast from the driver.
|
||||
"""
|
||||
pass
|
||||
|
||||
def export_model(self, export_dir):
    """Export this PolicyGraph to a local directory for serving.

    This base implementation only raises.

    Arguments:
        export_dir (str): Local writable directory

    Raises:
        NotImplementedError: Always, in this implementation.
    """
    raise NotImplementedError()
|
||||
|
||||
@@ -188,6 +188,17 @@ class TFPolicyGraph(PolicyGraph):
|
||||
def set_weights(self, weights):
    """Apply ``weights`` with self._variables.set_flat(); return its result."""
    variables = self._variables
    return variables.set_flat(weights)
|
||||
|
||||
@override(PolicyGraph)
def export_model(self, export_dir):
    """Save the session's graph and variables as a TensorFlow SavedModel.

    The meta graph is tagged with SERVING and carries the signature map
    returned by _build_signature_def().

    Arguments:
        export_dir (str): Directory passed to SavedModelBuilder.
    """
    session = self._sess
    with session.graph.as_default():
        saved_model_builder = tf.saved_model.builder.SavedModelBuilder(
            export_dir)
        signatures = self._build_signature_def()
        serving_tags = [tf.saved_model.tag_constants.SERVING]
        saved_model_builder.add_meta_graph_and_variables(
            session, serving_tags, signature_def_map=signatures)
        saved_model_builder.save()
|
||||
|
||||
def copy(self, existing_inputs):
|
||||
"""Creates a copy of self using existing input placeholders.
|
||||
|
||||
@@ -218,6 +229,26 @@ class TFPolicyGraph(PolicyGraph):
|
||||
"""Extra values to fetch and return from apply_gradients()."""
|
||||
return {} # e.g., batch norm updates
|
||||
|
||||
def _extra_input_signature_def(self):
    """Extra input signatures to add when exporting tf model.

    Inferred from extra_compute_action_feed_dict(): every key of that feed
    dict is mapped from its ``name`` attribute to its TensorInfo.
    """
    signature = {}
    for feed_key in self.extra_compute_action_feed_dict():
        signature[feed_key.name] = \
            tf.saved_model.utils.build_tensor_info(feed_key)
    return signature
|
||||
|
||||
def _extra_output_signature_def(self):
    """Extra output signatures to add when exporting tf model.

    Inferred from extra_compute_action_fetches(): every entry keeps its key
    and has its value converted to a TensorInfo.
    """
    signature = {}
    for fetch_name, tensor in self.extra_compute_action_fetches().items():
        signature[fetch_name] = tf.saved_model.utils.build_tensor_info(tensor)
    return signature
|
||||
|
||||
def optimizer(self):
    """TF optimizer to use for policy optimization.

    Returns a tf.train.AdamOptimizer constructed without arguments.
    """
    adam = tf.train.AdamOptimizer()
    return adam
|
||||
@@ -226,6 +257,46 @@ class TFPolicyGraph(PolicyGraph):
|
||||
"""Override for custom gradient computation."""
|
||||
return optimizer.compute_gradients(self._loss)
|
||||
|
||||
def _build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder.

    Returns a dict with a single entry: the default serving signature key
    mapped to a PREDICT signature built from the input and output tensors
    collected below.
    """
    # build input signatures
    inputs = self._extra_input_signature_def()
    to_info = tf.saved_model.utils.build_tensor_info
    inputs["observations"] = to_info(self._obs_input)

    # These inputs are only part of the signature when they are set.
    optional_inputs = (
        ("seq_lens", self._seq_lens),
        ("prev_action", self._prev_action_input),
        ("prev_reward", self._prev_reward_input),
    )
    for key, tensor in optional_inputs:
        if tensor is not None:
            inputs[key] = to_info(tensor)
    inputs["is_training"] = to_info(self._is_training)

    # State tensors are keyed by their own tensor name.
    for state_in in self._state_inputs:
        inputs[state_in.name] = to_info(state_in)

    # build output signatures
    outputs = self._extra_output_signature_def()
    outputs["actions"] = to_info(self._sampler)
    for state_out in self._state_outputs:
        outputs[state_out.name] = to_info(state_out)

    signature_constants = tf.saved_model.signature_constants
    predict_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs, outputs, signature_constants.PREDICT_METHOD_NAME)
    default_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    return {default_key: predict_signature}
|
||||
|
||||
def _build_compute_actions(self,
|
||||
builder,
|
||||
obs_batch,
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
import tensorflow as tf
|
||||
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
|
||||
# Start Ray at import time for this example, requesting num_cpus=10.
ray.init(num_cpus=10)
|
||||
|
||||
|
||||
def train_and_export(algo_name, num_steps, export_dir):
    """Train an agent on CartPole-v0 and export its policy model.

    Arguments:
        algo_name (str): Name passed to get_agent_class().
        num_steps (int): Number of train() calls to make before exporting.
        export_dir (str): Directory passed to export_policy_model().
    """
    cls = get_agent_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    # Use the caller's num_steps; the loop used to be hard-coded to 3 and
    # silently ignored this parameter.
    for _ in range(num_steps):
        alg.train()

    alg.export_policy_model(export_dir)
|
||||
|
||||
|
||||
def restore(export_dir):
    """Load the SavedModel in ``export_dir`` and print its serving signature.

    The model is loaded into a fresh graph and session, which is closed
    when the function returns.
    """
    signature_key = \
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        serving_tags = [tf.saved_model.tag_constants.SERVING]
        meta_graph_def = tf.saved_model.load(sess, serving_tags, export_dir)
        print("Model restored!")
        print("Signature Def Information:")
        print(meta_graph_def.signature_def[signature_key])
        print("You can inspect the model using TensorFlow SavedModel CLI.")
        print("https://www.tensorflow.org/guide/saved_model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo settings: train DQN for a few iterations, export the policy
    # model, then load it back and print its signature.
    export_dir = "/tmp/export_dir"
    num_steps = 3
    algo = "DQN"

    train_and_export(algo, num_steps, export_dir)
    restore(export_dir)
|
||||
@@ -4,6 +4,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
@@ -55,7 +57,7 @@ CONFIGS = {
|
||||
}
|
||||
|
||||
|
||||
def test(use_object_store, alg_name, failures):
|
||||
def test_ckpt_restore(use_object_store, alg_name, failures):
|
||||
cls = get_agent_class(alg_name)
|
||||
if "DDPG" in alg_name:
|
||||
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
@@ -86,11 +88,37 @@ def test(use_object_store, alg_name, failures):
|
||||
failures.append((alg_name, [a1, a2]))
|
||||
|
||||
|
||||
def test_export(algo_name, failures):
    """Train an agent briefly, export its policy model and check the output.

    The export counts as successful when ``saved_model.pb`` exists in the
    export directory and its ``variables`` subdirectory is non-empty.
    Otherwise ``algo_name`` is appended to ``failures``.

    Arguments:
        algo_name (str): Name passed to get_agent_class() and used as the
            key into CONFIGS.
        failures (list): Collects the names of algorithms whose export
            did not produce the expected files.
    """
    cls = get_agent_class(algo_name)
    env = "Pendulum-v0" if "DDPG" in algo_name else "CartPole-v0"
    # Index CONFIGS with the parameter, not the module-level loop variable
    # `name`, so the function also works when called outside that loop.
    algo = cls(config=CONFIGS[algo_name], env=env)

    for _ in range(3):
        res = algo.train()
        print("current status: " + str(res))

    export_dir = "/tmp/export_dir_%s" % algo_name
    print("Exporting model ", algo_name, export_dir)
    try:
        algo.export_policy_model(export_dir)
        model_file = os.path.join(export_dir, "saved_model.pb")
        variables_dir = os.path.join(export_dir, "variables")
        # Check isdir first so a missing directory is recorded as a
        # failure instead of making os.listdir raise.
        exported = (os.path.exists(model_file)
                    and os.path.isdir(variables_dir)
                    and bool(os.listdir(variables_dir)))
        if not exported:
            failures.append(algo_name)
    finally:
        # Always clean up, and tolerate the directory never having been
        # created so cleanup cannot mask the recorded failure.
        shutil.rmtree(export_dir, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Checkpoint/restore round trip, with and without the object store.
    failures = []
    for use_object_store in [False, True]:
        # The loop variable must stay `name`: the test functions in this
        # file read it as a module-level global when indexing CONFIGS.
        for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]:
            # Only the renamed function is called; the old `test` name no
            # longer exists after the rename to test_ckpt_restore.
            test_ckpt_restore(use_object_store, name, failures)

    assert not failures, failures
    print("All checkpoint restore tests passed!")

    # Policy model export for the algorithms listed here.
    failures = []
    for name in ["DQN", "DDPG", "PPO", "A3C"]:
        test_export(name, failures)
    assert not failures, failures
    print("All export tests passed!")
|
||||
|
||||
Reference in New Issue
Block a user