diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 4556f55c4..0b95ef61f 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -141,6 +141,18 @@ Here is an example of the basic usage (for a more complete example, see `custom_ checkpoint = trainer.save() print("checkpoint saved at", checkpoint) + # Also, in case you have trained a model outside of ray/RLlib and have created + # an h5-file with weight values in it, e.g. + # my_keras_model_trained_outside_rllib.save_weights("my_weights.h5") + # (see: https://keras.io/models/about-keras-models/) + + # ... you can load the h5-weights into your Trainer's Policy's ModelV2 + # (tf or torch) by doing: + trainer.import_model("my_weights.h5") + # NOTE: In order for this to work, your (custom) model needs to implement + # the `import_from_h5` method. + # See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py + # for detailed examples for tf- and torch trainers/models. .. note:: diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index ebf5852b4..a1bab0fc6 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -473,13 +473,16 @@ class Trainable: export model to local directory. Args: - export_formats (list): List of formats that should be exported. + export_formats (Union[list,str]): Format or list of (str) formats + that should be exported. export_dir (str): Optional dir to place the exported model. Defaults to self.logdir. Returns: A dict that maps ExportFormats to successfully exported models. 
""" + if isinstance(export_formats, str): + export_formats = [export_formats] export_dir = export_dir or self.logdir return self._export_model(export_formats, export_dir) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index d74126539..942bdef95 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -48,28 +48,30 @@ class Location: class ExportFormat: - """Describes the format to export the trial Trainable. + """Describes the format to import/export the trial Trainable. This may correspond to different file formats based on the Trainable implementation. """ CHECKPOINT = "checkpoint" MODEL = "model" + H5 = "h5" @staticmethod - def validate(export_formats): - """Validates export_formats. + def validate(formats): + """Validates formats. Raises: ValueError if the format is unknown. """ - for i in range(len(export_formats)): - export_formats[i] = export_formats[i].strip().lower() - if export_formats[i] not in [ - ExportFormat.CHECKPOINT, ExportFormat.MODEL + for i in range(len(formats)): + formats[i] = formats[i].strip().lower() + if formats[i] not in [ + ExportFormat.CHECKPOINT, ExportFormat.MODEL, + ExportFormat.H5 ]: - raise TuneError("Unsupported export format: " + - export_formats[i]) + raise TuneError("Unsupported import/export format: " + + formats[i]) def checkpoint_deleter(trial_id, runner): diff --git a/rllib/BUILD b/rllib/BUILD index a92b0e9e2..76bf3b3d8 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1034,6 +1034,14 @@ py_test( srcs = ["tests/test_lstm.py"] ) +py_test( + name = "tests/test_model_imports", + tags = ["tests_dir", "tests_dir_M", "model_imports"], + size = "small", + data = glob(["tests/data/model_weights/**"]), + srcs = ["tests/test_model_imports.py"] +) + py_test( name = "tests/test_multi_agent_env", tags = ["tests_dir", "tests_dir_M"], diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 1a74eacbd..61d70358c 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -628,6 +628,7 @@ 
class Trainer(Trainable): checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-{}".format(self.iteration)) pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + return checkpoint_path @override(Trainable) @@ -866,6 +867,25 @@ class Trainer(Trainable): self.workers.local_worker().export_policy_checkpoint( export_dir, filename_prefix, policy_id) + @DeveloperAPI + def import_policy_model_from_h5(self, + import_file, + policy_id=DEFAULT_POLICY_ID): + """Imports a policy's model with given policy_id from a local h5 file. + + Arguments: + import_file (str): The h5 file to import from. + policy_id (str): Optional policy id to import into. + + Example: + >>> trainer = MyTrainer() + >>> trainer.import_policy_model_from_h5("/tmp/weights.h5") + >>> for _ in range(10): + >>> trainer.train() + """ + self.workers.local_worker().import_policy_model_from_h5( + import_file, policy_id) + @DeveloperAPI def collect_metrics(self, selected_workers=None): """Collects metrics from the remote workers of this agent. @@ -1003,6 +1023,31 @@ class Trainer(Trainable): exported[ExportFormat.MODEL] = path return exported + def import_model(self, import_file): + """Imports a model from import_file. + + Note: Currently, only h5 files are supported. + + Args: + import_file (str): The file to import the model from. + + Returns: + None (the weights are loaded in place into the policy's model). + """ + # Check for existence. + if not os.path.exists(import_file): + raise FileNotFoundError( + "`import_file` '{}' does not exist! Can't import Model.". + format(import_file)) + # Get the format of the given file. + import_format = "h5" # TODO(sven): Support checkpoint loading. 
+ + ExportFormat.validate([import_format]) + if import_format != ExportFormat.H5: + raise NotImplementedError + else: + return self.import_policy_model_from_h5(import_file) + def __getstate__(self): state = {} if hasattr(self, "workers"): diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index e527220a8..529586676 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -782,6 +782,12 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): self.policy_map[policy_id].export_model(export_dir) + @DeveloperAPI + def import_policy_model_from_h5(self, + import_file, + policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].import_model_from_h5(import_file) + @DeveloperAPI def export_policy_checkpoint(self, export_dir, diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 3435f269e..48e5f9d6e 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -192,6 +192,20 @@ class ModelV2: i += 1 return self.__call__(input_dict, states, train_batch.get("seq_lens")) + def import_from_h5(self, h5_file): + """Imports weights from an h5 file. + + Args: + h5_file (str): The h5 file name to import weights from. + + Example: + >>> trainer = MyTrainer() + >>> trainer.import_policy_model_from_h5("/tmp/weights.h5") + >>> for _ in range(10): + >>> trainer.train() + """ + raise NotImplementedError + def last_output(self): """Returns the last output returned from calling the model.""" return self._last_output diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 3042e16b1..5edc941cd 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -46,7 +46,6 @@ class Categorical(TFActionDistribution): @DeveloperAPI def __init__(self, inputs, model=None, temperature=1.0): assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" 
- self.n = inputs.shape[-1] # Allow softmax formula w/ temperature != 1.0: # Divide inputs by temperature. super().__init__(inputs / temperature, model) @@ -104,8 +103,7 @@ class MultiCategorical(TFActionDistribution): @override(ActionDistribution) def deterministic_sample(self): return tf.stack( - [cat.deterministic_sample() for cat in self.cats], - axis=1) + [cat.deterministic_sample() for cat in self.cats], axis=1) @override(ActionDistribution) def logp(self, actions): diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 00e15a5ae..bed04b468 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -2,16 +2,16 @@ It supports both traced and non-traced eager execution modes.""" -import logging import functools +import logging import numpy as np from ray.util.debug import log_once from ray.rllib.evaluation.episode import _flatten_action from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \ + ACTION_LOGP from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.policy import ACTION_PROB, ACTION_LOGP from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 6d24b82c5..44f5446b3 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -375,6 +375,15 @@ class Policy(metaclass=ABCMeta): """ raise NotImplementedError + @DeveloperAPI + def import_model_from_h5(self, import_file): + """Imports the Policy's model weights from a local h5 file. + + Arguments: + import_file (str): Path to a local, readable h5 file. + """ + raise NotImplementedError + def _create_exploration(self, action_space, config): """Creates the Policy's Exploration object. 
diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 2432f618d..779fa303b 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -364,6 +364,14 @@ class TFPolicy(Policy): saver = tf.train.Saver() saver.save(self._sess, save_path) + @override(Policy) + def import_model_from_h5(self, import_file): + """Imports weights into tf model.""" + # Make sure the session is the right one (see issue #7046). + with self._sess.graph.as_default(): + with self._sess.as_default(): + return self.model.import_from_h5(import_file) + @DeveloperAPI def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders. diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index d9b6dacb4..eed61f42f 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -259,16 +259,21 @@ class TorchPolicy(Policy): @override(Policy) def export_model(self, export_dir): - """TODO: implement for torch. + """TODO(sven): implement for torch. """ raise NotImplementedError @override(Policy) def export_checkpoint(self, export_dir): - """TODO: implement for torch. + """TODO(sven): implement for torch. 
""" raise NotImplementedError + @override(Policy) + def import_model_from_h5(self, import_file): + """Imports weights into torch model.""" + return self.model.import_from_h5(import_file) + @DeveloperAPI class LearningRateSchedule: diff --git a/rllib/tests/data/model_weights/weights.h5 b/rllib/tests/data/model_weights/weights.h5 new file mode 100644 index 000000000..4d3ed8f7a Binary files /dev/null and b/rllib/tests/data/model_weights/weights.h5 differ diff --git a/rllib/tests/test_checkpoint_restore.py b/rllib/tests/test_checkpoint_restore.py index 99d8a6286..7a35a7163 100644 --- a/rllib/tests/test_checkpoint_restore.py +++ b/rllib/tests/test_checkpoint_restore.py @@ -19,19 +19,9 @@ def get_mean_action(alg, obs): CONFIGS = { - "SAC": { + "A3C": { "explore": False, - }, - "ES": { - "explore": False, - "episodes_per_batch": 10, - "train_batch_size": 100, - "num_workers": 2, - "noise_size": 2500000, - "observation_filter": "MeanStdFilter" - }, - "DQN": { - "explore": False + "num_workers": 1 }, "APEX_DDPG": { "explore": False, @@ -42,27 +32,37 @@ CONFIGS = { "num_replay_buffer_shards": 1, }, }, + "ARS": { + "explore": False, + "num_rollouts": 10, + "num_workers": 2, + "noise_size": 2500000, + "observation_filter": "MeanStdFilter" + }, "DDPG": { "explore": False, "timesteps_per_iteration": 100 }, + "DQN": { + "explore": False + }, + "ES": { + "explore": False, + "episodes_per_batch": 10, + "train_batch_size": 100, + "num_workers": 2, + "noise_size": 2500000, + "observation_filter": "MeanStdFilter" + }, "PPO": { "explore": False, "num_sgd_iter": 5, "train_batch_size": 1000, "num_workers": 2 }, - "A3C": { + "SAC": { "explore": False, - "num_workers": 1 }, - "ARS": { - "explore": False, - "num_rollouts": 10, - "num_workers": 2, - "noise_size": 2500000, - "observation_filter": "MeanStdFilter" - } } @@ -121,7 +121,7 @@ def export_test(alg_name, failures): else: algo = cls(config=CONFIGS[alg_name], env="CartPole-v0") - for _ in range(3): + for _ in range(2): res = 
algo.train() print("current status: " + str(res)) diff --git a/rllib/tests/test_model_imports.py b/rllib/tests/test_model_imports.py new file mode 100644 index 000000000..8d3846aab --- /dev/null +++ b/rllib/tests/test_model_imports.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python + +import h5py +import numpy as np +from pathlib import Path +from tensorflow.python.eager.context import eager_mode +import unittest + +import ray +from ray.rllib.agents.registry import get_agent_class +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.test_utils import check + +tf = try_import_tf() +torch, nn = try_import_torch() + + +class MyKerasModel(TFModelV2): + """Custom model for policy gradient algorithms.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(MyKerasModel, self).__init__(obs_space, action_space, + num_outputs, model_config, name) + self.inputs = tf.keras.layers.Input( + shape=obs_space.shape, name="observations") + layer_1 = tf.keras.layers.Dense( + 16, + name="layer1", + activation=tf.nn.relu, + kernel_initializer=normc_initializer(1.0))(self.inputs) + layer_out = tf.keras.layers.Dense( + num_outputs, + name="out", + activation=None, + kernel_initializer=normc_initializer(0.01))(layer_1) + if self.model_config["vf_share_layers"]: + value_out = tf.keras.layers.Dense( + 1, + name="value", + activation=None, + kernel_initializer=normc_initializer(0.01))(layer_1) + self.base_model = tf.keras.Model(self.inputs, + [layer_out, value_out]) + else: + self.base_model = tf.keras.Model(self.inputs, layer_out) + + self.register_variables(self.base_model.variables) + + def forward(self, input_dict, state, seq_lens): + if self.model_config["vf_share_layers"]: + model_out, 
self._value_out = self.base_model(input_dict["obs"]) + else: + model_out = self.base_model(input_dict["obs"]) + self._value_out = tf.zeros( + shape=(tf.shape(input_dict["obs"])[0], )) + return model_out, state + + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + def import_from_h5(self, import_file): + # Override this to define custom weight loading behavior from h5 files. + self.base_model.load_weights(import_file) + + +class MyTorchModel(TorchModelV2, nn.Module): + """Custom fully connected torch model with a separate value branch.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + TorchModelV2.__init__(self, obs_space, action_space, num_outputs, + model_config, name) + nn.Module.__init__(self) + + self.device = torch.device("cuda" + if torch.cuda.is_available() else "cpu") + + self.layer_1 = nn.Linear(obs_space.shape[0], 16).to(self.device) + self.layer_out = nn.Linear(16, num_outputs).to(self.device) + self.value_branch = nn.Linear(16, 1).to(self.device) + self.cur_value = None + + def forward(self, input_dict, state, seq_lens): + layer_1_out = self.layer_1(input_dict["obs"]) + logits = self.layer_out(layer_1_out) + self.cur_value = self.value_branch(layer_1_out).squeeze(1) + return logits, state + + def value_function(self): + assert self.cur_value is not None, "Must call `forward()` first!" + return self.cur_value + + def import_from_h5(self, import_file): + # Override this to define custom weight loading behavior from h5 files. 
+ f = h5py.File(import_file) + self.layer_1.load_state_dict({ + "weight": torch.Tensor( + np.transpose(f["layer1"]["default_policy"]["layer1"][ + "kernel:0"].value)), + "bias": torch.Tensor( + np.transpose( + f["layer1"]["default_policy"]["layer1"]["bias:0"].value)), + }) + self.layer_out.load_state_dict({ + "weight": torch.Tensor( + np.transpose( + f["out"]["default_policy"]["out"]["kernel:0"].value)), + "bias": torch.Tensor( + np.transpose( + f["out"]["default_policy"]["out"]["bias:0"].value)), + }) + self.value_branch.load_state_dict({ + "weight": torch.Tensor( + np.transpose( + f["value"]["default_policy"]["value"]["kernel:0"].value)), + "bias": torch.Tensor( + np.transpose( + f["value"]["default_policy"]["value"]["bias:0"].value)), + }) + + +def model_import_test(algo, config, env): + # Get the abs-path to use (bazel-friendly). + rllib_dir = Path(__file__).parent.parent + import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5" + + agent_cls = get_agent_class(algo) + + for fw in ["tf", "torch"]: + print("framework={}".format(fw)) + + config["use_pytorch"] = fw == "torch" + config["eager"] = fw == "eager" + config["model"]["custom_model"] = "keras_model" if fw != "torch" else \ + "torch_model" + + eager_mode_ctx = None + if fw == "eager": + eager_mode_ctx = eager_mode() + eager_mode_ctx.__enter__() + assert tf.executing_eagerly() + elif fw == "tf": + assert not tf.executing_eagerly() + + agent = agent_cls(config, env) + + def current_weight(agent): + if fw == "tf": + return agent.get_weights()["default_policy"][ + "default_policy/value/kernel"][0] + elif fw == "torch": + return float(agent.get_weights()["default_policy"][ + "value_branch.weight"][0][0]) + else: + return agent.get_weights()["default_policy"][4][0] + + # Import weights for our custom model from an h5 file. 
+ weight_before_import = current_weight(agent) + agent.import_model(import_file=import_file) + weight_after_import = current_weight(agent) + check(weight_before_import, weight_after_import, false=True) + + # Train for a while. + for _ in range(1): + agent.train() + weight_after_train = current_weight(agent) + # Weights should have changed. + check(weight_before_import, weight_after_train, false=True) + check(weight_after_import, weight_after_train, false=True) + + # We can save the entire Agent and restore, weights should remain the + # same. + file = agent.save("after_train") + check(weight_after_train, current_weight(agent)) + agent.restore(file) + check(weight_after_train, current_weight(agent)) + + # Import (untrained) weights again. + agent.import_model(import_file=import_file) + check(current_weight(agent), weight_after_import) + + if eager_mode_ctx: + eager_mode_ctx.__exit__(None, None, None) + + +class TestModelImport(unittest.TestCase): + def setUp(self): + ray.init() + ModelCatalog.register_custom_model("keras_model", MyKerasModel) + ModelCatalog.register_custom_model("torch_model", MyTorchModel) + + def tearDown(self): + ray.shutdown() + + def test_ppo(self): + model_import_test( + "PPO", + config={ + "num_workers": 0, + "vf_share_layers": True, + "model": {} + }, + env="CartPole-v0") + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))