mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
[rllib] Export policy model checkpoint (#3637)
* Export policy model checkpoint * update comment
This commit is contained in:
@@ -443,6 +443,26 @@ class Agent(Trainable):
|
||||
"""
|
||||
self.local_evaluator.export_policy_model(export_dir, policy_id)
|
||||
|
||||
def export_policy_checkpoint(self,
                             export_dir,
                             filename_prefix="model",
                             policy_id=DEFAULT_POLICY_ID):
    """Write a tensorflow checkpoint of one policy to a local directory.

    The call is forwarded unchanged to this agent's local evaluator.

    Arguments:
        export_dir (string): Writable local directory.
        filename_prefix (string): file name prefix of checkpoint files.
        policy_id (string): Optional policy id to export.

    Example:
        >>> agent = MyAgent()
        >>> for _ in range(10):
        >>>     agent.train()
        >>> agent.export_policy_checkpoint("/tmp/export_dir")
    """
    evaluator = self.local_evaluator
    evaluator.export_policy_checkpoint(export_dir, filename_prefix,
                                       policy_id)
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
return ("\n\nYou can adjust the resource requests of RLlib agents by "
|
||||
|
||||
@@ -556,6 +556,13 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Export the model of the policy registered under ``policy_id``.

    The policy is looked up in ``self.policy_map`` and its
    ``export_model`` is called with ``export_dir``.
    """
    policy = self.policy_map[policy_id]
    policy.export_model(export_dir)
|
||||
|
||||
def export_policy_checkpoint(self,
                             export_dir,
                             filename_prefix="model",
                             policy_id=DEFAULT_POLICY_ID):
    """Export a checkpoint of the policy registered under ``policy_id``.

    The policy is looked up in ``self.policy_map`` and its
    ``export_checkpoint`` is called with ``export_dir`` and
    ``filename_prefix``.
    """
    policy = self.policy_map[policy_id]
    policy.export_checkpoint(export_dir, filename_prefix)
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
|
||||
policy_map = {}
|
||||
preprocessors = {}
|
||||
|
||||
@@ -205,6 +205,14 @@ class PolicyGraph(object):
|
||||
"""Export PolicyGraph to local directory for serving.
|
||||
|
||||
Arguments:
|
||||
export_dir (str): Local writable directory
|
||||
export_dir (str): Local writable directory.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def export_checkpoint(self, export_dir, filename_prefix="model"):
    """Export PolicyGraph checkpoint to local directory.

    The signature accepts ``filename_prefix`` so that it agrees with the
    way the policy evaluator invokes this method (it passes the prefix
    positionally) and with the tensorflow override.

    Arguments:
        export_dir (str): Local writable directory.
        filename_prefix (str): File name prefix of checkpoint files.

    Raises:
        NotImplementedError: Always; the base class provides no export.
    """
    raise NotImplementedError
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import logging
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
@@ -199,6 +200,14 @@ class TFPolicyGraph(PolicyGraph):
|
||||
signature_def_map=signature_def_map)
|
||||
builder.save()
|
||||
|
||||
@override(PolicyGraph)
def export_checkpoint(self, export_dir, filename_prefix="model"):
    """Export tensorflow checkpoint to export_dir.

    Arguments:
        export_dir (str): Local writable directory. It is created
            (including parents) when it does not exist yet.
        filename_prefix (str): File name prefix of checkpoint files.

    Raises:
        OSError: If export_dir is missing and cannot be created.
    """
    # Saving into a directory that does not exist can fail inside the
    # saver, so make sure the target directory is there first. An empty
    # export_dir means "current directory" and needs no creation.
    if export_dir:
        try:
            os.makedirs(export_dir)
        except OSError:
            # An already existing directory is fine; re-raise otherwise.
            if not os.path.isdir(export_dir):
                raise
    save_path = os.path.join(export_dir, filename_prefix)
    with self._sess.graph.as_default():
        saver = tf.train.Saver()
        saver.save(self._sess, save_path)
|
||||
|
||||
def copy(self, existing_inputs):
|
||||
"""Creates a copy of self using existing input placeholders.
|
||||
|
||||
|
||||
@@ -4,24 +4,28 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import ray
|
||||
import tensorflow as tf
|
||||
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
|
||||
ray.init(num_cpus=10)
|
||||
|
||||
|
||||
def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    """Train an agent on CartPole-v0 and export its policy in two formats.

    Arguments:
        algo_name (str): Name passed to get_agent_class to pick the agent.
        num_steps (int): Number of times ``train()`` is called.
        model_dir (str): Directory receiving the exported policy model.
        ckpt_dir (str): Directory receiving the exported checkpoint.
        prefix (str): File name prefix of the checkpoint files.
    """
    cls = get_agent_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    # Export tensorflow SavedModel for online serving
    alg.export_policy_model(model_dir)
|
||||
|
||||
|
||||
def restore(export_dir):
|
||||
def restore_saved_model(export_dir):
|
||||
signature_key = \
|
||||
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
g = tf.Graph()
|
||||
@@ -38,9 +42,24 @@ def restore(export_dir):
|
||||
print("https://www.tensorflow.org/guide/saved_model")
|
||||
|
||||
|
||||
def restore_checkpoint(export_dir, prefix):
    """Restore an exported tensorflow checkpoint and print its variables.

    Arguments:
        export_dir (str): Directory holding the checkpoint files.
        prefix (str): File name prefix the checkpoint was saved with.
    """
    meta_file = "%s.meta" % prefix
    # Import into a dedicated graph so that nodes already present in the
    # process-wide default graph cannot clash with the imported ones.
    graph = tf.Graph()
    with graph.as_default():
        # The context manager closes the session once printing is done.
        with tf.Session(graph=graph) as sess:
            saver = tf.train.import_meta_graph(
                os.path.join(export_dir, meta_file))
            saver.restore(sess, os.path.join(export_dir, prefix))
            print("Checkpoint restored!")
            print("Variables Information:")
            for v in tf.trainable_variables():
                value = sess.run(v)
                print(v.name, value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Train briefly, export both artifact kinds, then load each back.
    algo = "DQN"
    model_dir = "/tmp/model_export_dir"
    ckpt_dir = "/tmp/ckpt_export_dir"
    prefix = "model.ckpt"
    num_steps = 3
    train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix)
    restore_saved_model(model_dir)
    restore_checkpoint(ckpt_dir, prefix)
|
||||
|
||||
@@ -107,6 +107,14 @@ def test_export(algo_name, failures):
|
||||
failures.append(algo_name)
|
||||
shutil.rmtree(export_dir)
|
||||
|
||||
print("Exporting checkpoint", algo_name, export_dir)
|
||||
algo.export_policy_checkpoint(export_dir)
|
||||
if not os.path.exists(os.path.join(export_dir, "model.meta")) \
|
||||
or not os.path.exists(os.path.join(export_dir, "model.index")) \
|
||||
or not os.path.exists(os.path.join(export_dir, "checkpoint")):
|
||||
failures.append(algo_name)
|
||||
shutil.rmtree(export_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
failures = []
|
||||
|
||||
Reference in New Issue
Block a user