From 1138f2ebedc8c69db334327a8dfeeac54cff4805 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 23 Mar 2020 20:19:30 +0100 Subject: [PATCH] [RLlib] Issue 7046 cannot restore keras model from h5 file. (#7482) --- doc/source/rllib-training.rst | 12 ++ python/ray/tune/trainable.py | 5 +- python/ray/tune/trial.py | 20 +- rllib/BUILD | 8 + rllib/agents/trainer.py | 45 +++++ rllib/evaluation/rollout_worker.py | 6 + rllib/models/modelv2.py | 14 ++ rllib/models/tf/tf_action_dist.py | 4 +- rllib/policy/eager_tf_policy.py | 6 +- rllib/policy/policy.py | 9 + rllib/policy/tf_policy.py | 8 + rllib/policy/torch_policy.py | 9 +- rllib/tests/data/model_weights/weights.h5 | Bin 0 -> 19880 bytes rllib/tests/test_checkpoint_restore.py | 44 ++--- rllib/tests/test_model_imports.py | 214 ++++++++++++++++++++++ 15 files changed, 364 insertions(+), 40 deletions(-) create mode 100644 rllib/tests/data/model_weights/weights.h5 create mode 100644 rllib/tests/test_model_imports.py diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 4556f55c4..0b95ef61f 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -141,6 +141,18 @@ Here is an example of the basic usage (for a more complete example, see `custom_ checkpoint = trainer.save() print("checkpoint saved at", checkpoint) + # Also, in case you have trained a model outside of ray/RLlib and have created + # an h5-file with weight values in it, e.g. + # my_keras_model_trained_outside_rllib.save_weights("my_weights.h5") + # (see: https://keras.io/models/about-keras-models/) + + # ... you can load the h5-weights into your Trainer's Policy's ModelV2 + # (tf or torch) by doing: + trainer.import_model("my_weights.h5") + # NOTE: In order for this to work, your (custom) model needs to implement + # the `import_from_h5` method. + # See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py + # for detailed examples for tf- and torch trainers/models. .. 
note:: diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index ebf5852b4..a1bab0fc6 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -473,13 +473,16 @@ class Trainable: export model to local directory. Args: - export_formats (list): List of formats that should be exported. + export_formats (Union[list,str]): Format or list of (str) formats + that should be exported. export_dir (str): Optional dir to place the exported model. Defaults to self.logdir. Returns: A dict that maps ExportFormats to successfully exported models. """ + if isinstance(export_formats, str): + export_formats = [export_formats] export_dir = export_dir or self.logdir return self._export_model(export_formats, export_dir) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index d74126539..942bdef95 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -48,28 +48,30 @@ class Location: class ExportFormat: - """Describes the format to export the trial Trainable. + """Describes the format to import/export the trial Trainable. This may correspond to different file formats based on the Trainable implementation. """ CHECKPOINT = "checkpoint" MODEL = "model" + H5 = "h5" @staticmethod - def validate(export_formats): - """Validates export_formats. + def validate(formats): + """Validates formats. Raises: ValueError if the format is unknown. 
""" - for i in range(len(export_formats)): - export_formats[i] = export_formats[i].strip().lower() - if export_formats[i] not in [ - ExportFormat.CHECKPOINT, ExportFormat.MODEL + for i in range(len(formats)): + formats[i] = formats[i].strip().lower() + if formats[i] not in [ + ExportFormat.CHECKPOINT, ExportFormat.MODEL, + ExportFormat.H5 ]: - raise TuneError("Unsupported export format: " + - export_formats[i]) + raise TuneError("Unsupported import/export format: " + + formats[i]) def checkpoint_deleter(trial_id, runner): diff --git a/rllib/BUILD b/rllib/BUILD index a92b0e9e2..76bf3b3d8 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1034,6 +1034,14 @@ py_test( srcs = ["tests/test_lstm.py"] ) +py_test( + name = "tests/test_model_imports", + tags = ["tests_dir", "tests_dir_M", "model_imports"], + size = "small", + data = glob(["tests/data/model_weights/**"]), + srcs = ["tests/test_model_imports.py"] +) + py_test( name = "tests/test_multi_agent_env", tags = ["tests_dir", "tests_dir_M"], diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 1a74eacbd..61d70358c 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -628,6 +628,7 @@ class Trainer(Trainable): checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-{}".format(self.iteration)) pickle.dump(self.__getstate__(), open(checkpoint_path, "wb")) + return checkpoint_path @override(Trainable) @@ -866,6 +867,25 @@ class Trainer(Trainable): self.workers.local_worker().export_policy_checkpoint( export_dir, filename_prefix, policy_id) + @DeveloperAPI + def import_policy_model_from_h5(self, + import_file, + policy_id=DEFAULT_POLICY_ID): + """Imports a policy's model with given policy_id from a local h5 file. + + Arguments: + import_file (str): The h5 file to import from. + policy_id (string): Optional policy id to import into. 
+ + Example: + >>> trainer = MyTrainer() + >>> trainer.import_policy_model_from_h5("/tmp/weights.h5") + >>> for _ in range(10): + >>> trainer.train() + """ + self.workers.local_worker().import_policy_model_from_h5( + import_file, policy_id) + @DeveloperAPI def collect_metrics(self, selected_workers=None): """Collects metrics from the remote workers of this agent. @@ -1003,6 +1023,31 @@ class Trainer(Trainable): exported[ExportFormat.MODEL] = path return exported + def import_model(self, import_file): + """Imports a model from import_file. + + Note: Currently, only h5 files are supported. + + Args: + import_file (str): The file to import the model from. + + Returns: + None. Weights are loaded into the local worker's default policy. + """ + # Check for existence. + if not os.path.exists(import_file): + raise FileNotFoundError( + "`import_file` '{}' does not exist! Can't import Model.". + format(import_file)) + # Get the format of the given file. + import_format = "h5" # TODO(sven): Support checkpoint loading. 
+ + ExportFormat.validate([import_format]) + if import_format != ExportFormat.H5: + raise NotImplementedError + else: + return self.import_policy_model_from_h5(import_file) + def __getstate__(self): state = {} if hasattr(self, "workers"): diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index e527220a8..529586676 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -782,6 +782,12 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): self.policy_map[policy_id].export_model(export_dir) + @DeveloperAPI + def import_policy_model_from_h5(self, + import_file, + policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].import_model_from_h5(import_file) + @DeveloperAPI def export_policy_checkpoint(self, export_dir, diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 3435f269e..48e5f9d6e 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -192,6 +192,20 @@ class ModelV2: i += 1 return self.__call__(input_dict, states, train_batch.get("seq_lens")) + def import_from_h5(self, h5_file): + """Imports weights from an h5 file. + + Args: + h5_file (str): The h5 file name to import weights from. + + Example: + >>> trainer = MyTrainer() + >>> trainer.import_policy_model_from_h5("/tmp/weights.h5") + >>> for _ in range(10): + >>> trainer.train() + """ + raise NotImplementedError + def last_output(self): """Returns the last output returned from calling the model.""" return self._last_output diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 3042e16b1..5edc941cd 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -46,7 +46,6 @@ class Categorical(TFActionDistribution): @DeveloperAPI def __init__(self, inputs, model=None, temperature=1.0): assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" 
- self.n = inputs.shape[-1] # Allow softmax formula w/ temperature != 1.0: # Divide inputs by temperature. super().__init__(inputs / temperature, model) @@ -104,8 +103,7 @@ class MultiCategorical(TFActionDistribution): @override(ActionDistribution) def deterministic_sample(self): return tf.stack( - [cat.deterministic_sample() for cat in self.cats], - axis=1) + [cat.deterministic_sample() for cat in self.cats], axis=1) @override(ActionDistribution) def logp(self, actions): diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 00e15a5ae..bed04b468 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -2,16 +2,16 @@ It supports both traced and non-traced eager execution modes.""" -import logging import functools +import logging import numpy as np from ray.util.debug import log_once from ray.rllib.evaluation.episode import _flatten_action from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \ + ACTION_LOGP from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.policy import ACTION_PROB, ACTION_LOGP from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 6d24b82c5..44f5446b3 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -375,6 +375,15 @@ class Policy(metaclass=ABCMeta): """ raise NotImplementedError + @DeveloperAPI + def import_model_from_h5(self, import_file): + """Imports the Policy's model weights from a local h5 file. + + Arguments: + import_file (str): Local readable h5 file. + """ + raise NotImplementedError + def _create_exploration(self, action_space, config): """Creates the Policy's Exploration object. 
diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 2432f618d..779fa303b 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -364,6 +364,14 @@ class TFPolicy(Policy): saver = tf.train.Saver() saver.save(self._sess, save_path) + @override(Policy) + def import_model_from_h5(self, import_file): + """Imports weights into tf model.""" + # Make sure the session is the right one (see issue #7046). + with self._sess.graph.as_default(): + with self._sess.as_default(): + return self.model.import_from_h5(import_file) + @DeveloperAPI def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders. diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index d9b6dacb4..eed61f42f 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -259,16 +259,21 @@ class TorchPolicy(Policy): @override(Policy) def export_model(self, export_dir): - """TODO: implement for torch. + """TODO(sven): implement for torch. """ raise NotImplementedError @override(Policy) def export_checkpoint(self, export_dir): - """TODO: implement for torch. + """TODO(sven): implement for torch. """ raise NotImplementedError + @override(Policy) + def import_model_from_h5(self, import_file): + """Imports weights into torch model.""" + return self.model.import_from_h5(import_file) + @DeveloperAPI class LearningRateSchedule: diff --git a/rllib/tests/data/model_weights/weights.h5 b/rllib/tests/data/model_weights/weights.h5 new file mode 100644 index 0000000000000000000000000000000000000000..4d3ed8f7a4677d1d9d69a3dbba59b7cf3563d678 GIT binary patch literal 19880 zcmeHP4@^@>7{3By?TWx=MxAq(=>}P}AY&V}^u3}Yol`No7&4<+qzZ_XK}E>a5k=E! 
zX2_fiaV#Tq3C!riGTdA%g)G^|rgLUdXRIY-F?8jT8l;g2=@~B z?%sX(-FLtDec!wLuI(Gv1q&lIk7*FT)M^xlqU0j{>!DZU5^@LX(7u)yRFt|XZ44m= z6$(eVO-JuXQTh5%awCOye0h0B29h{IK2X}NS3rF44PF>e2v{@Hm*M$l(@7^%+D`3^ zR*@sO%wAen;G z|A~_Z-9lUxn_?e=bcU5}K$WmE>jp{S5T3 zpH3LsX84ZS|4!!oti>{RR1Lt9EC>*C+KbB^rR$3vTSzYmx9}X23`vIMInMPU^LQkb zIOPa%1ULd50geDifFr;W;0SO8I0762jsQnsv>>ozY1RUqtAjjUEzQqC9>6-4Xqr#+ z4zTafrg;(IWAh|>I>Ac`qyT*CL{g7WqLlKr2}-^WT34r2hEYxeonNHr$piquQ9d26 z-hD=Z&0~VUxu{>kaC!Fiw#`M(oDxS-!Me)6JiY}O?vX-1ko99Ns*t=M%a4%HFW5XL zKEHSZx6|_r$om0R>X|@}FcPHGgkpP zTz!~kb8(-v$+>HgbnY{pSD?6JnMJhWab-B4%IF!3jVs^>mzcaioG}tFLeeSV;4ZRm zP*LwdsEt`S4D|eAa9TY0z5SivG{%y-wa+9pP#kbQMfi9ec<^l^m>IBfKu?8qQ2POy z@P>vuyptg=q(H*}DH{i>sNI4*){X}#cHFTM97B5dVP_0?hXDM%=JC7pjknh6)t%l3 zdzKg%l_}1;(q-DCs`S40=fSjtnmX@}$qunKN+&*f@^9ft-XgQQ{94+!a6@ZbL9=v$W`dFwfj}p$O#ER{CGllId zGg_~_Q)vEeYjo?%?~TI#+`oiVvF4V#i&Mm|_dXDI6kTrly|GL9sK%U@o*{^uImb*F znnm%}p=*NU>T}|5$1h&t+mkJk=j*+@GWU5qOMmqqJ@kw?En(e>x!yO0X-hVUyE@Ja z8HbjNuIrz8y{j&^)T*k5Q#s#w&HouhbSig-Pf6^I}c}hv^yi+6@@O3p<|6_ zZ_X~`>LXd7ulU#UMc8L2Q)`_Op6c*-Q}ZIb-M3DgJwKnHX!=1{Z;U$^Vsf9`o~pTC zW~y_ZFzt72^4$LJQ{%qnON~91drYAJ#hb5sUdh&Zstap9J-PEuH=;D2d9i1Wu7Ivr;4fE-Ra0vrL307qbaAmE!n zBKaN2W5NCy=pgb%aRgC|0$AUCQH7#?xGVlXhg(nc?~PQUK>JJG3xpBYox?p3KHkV8 z!0s=JS4nv!YDb`*GV|M)Otk>RJyOVDv*$)yR1z?MjZ^F3_mH;+s?;+<`q;g{5hQMouPhWqfvIeAbSG9A?aS^A`tPG||{- zBoG9O1Nmu$kH>*AZ!&kBAR7ntl&hWE4|E#p{PtREw-5(dI|e;CRU*LVx0}q$cD4q7 zZhe+*8Q*o^btYMf`&;J*zpku_CrHDSAT{50WyOd-5P$pIs&t}O6V)!z`vol|jLO-L z=+yY&!-xR;ej#3kP&~CGP(`_Kynp*!_VbPd`Fs}sil;xnSL&G{f0f+d4wp3c-{a!h z!b%%vPN;AEAz^~u=fe9Dq(D9RcvMT{kwu}9Z#>FZv=4Q~-}*LxJ{j31{i+4N0>vNo zBErYx55Zx){b%D3_)R0VAHV~C69W9n3JpP0cAsmZb_?=ZI}UP~G2hK!94N^oQ?GfE z&_Hou*-M0v$AK|#GIyLH8wd21tBKkVbQ)S4f;4=yfYf}~l_R^){U5}n%dP+b literal 0 HcmV?d00001 diff --git a/rllib/tests/test_checkpoint_restore.py b/rllib/tests/test_checkpoint_restore.py index 99d8a6286..7a35a7163 100644 --- a/rllib/tests/test_checkpoint_restore.py +++ b/rllib/tests/test_checkpoint_restore.py @@ -19,19 +19,9 @@ def 
get_mean_action(alg, obs): CONFIGS = { - "SAC": { + "A3C": { "explore": False, - }, - "ES": { - "explore": False, - "episodes_per_batch": 10, - "train_batch_size": 100, - "num_workers": 2, - "noise_size": 2500000, - "observation_filter": "MeanStdFilter" - }, - "DQN": { - "explore": False + "num_workers": 1 }, "APEX_DDPG": { "explore": False, @@ -42,27 +32,37 @@ CONFIGS = { "num_replay_buffer_shards": 1, }, }, + "ARS": { + "explore": False, + "num_rollouts": 10, + "num_workers": 2, + "noise_size": 2500000, + "observation_filter": "MeanStdFilter" + }, "DDPG": { "explore": False, "timesteps_per_iteration": 100 }, + "DQN": { + "explore": False + }, + "ES": { + "explore": False, + "episodes_per_batch": 10, + "train_batch_size": 100, + "num_workers": 2, + "noise_size": 2500000, + "observation_filter": "MeanStdFilter" + }, "PPO": { "explore": False, "num_sgd_iter": 5, "train_batch_size": 1000, "num_workers": 2 }, - "A3C": { + "SAC": { "explore": False, - "num_workers": 1 }, - "ARS": { - "explore": False, - "num_rollouts": 10, - "num_workers": 2, - "noise_size": 2500000, - "observation_filter": "MeanStdFilter" - } } @@ -121,7 +121,7 @@ def export_test(alg_name, failures): else: algo = cls(config=CONFIGS[alg_name], env="CartPole-v0") - for _ in range(3): + for _ in range(2): res = algo.train() print("current status: " + str(res)) diff --git a/rllib/tests/test_model_imports.py b/rllib/tests/test_model_imports.py new file mode 100644 index 000000000..8d3846aab --- /dev/null +++ b/rllib/tests/test_model_imports.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python + +import h5py +import numpy as np +from pathlib import Path +from tensorflow.python.eager.context import eager_mode +import unittest + +import ray +from ray.rllib.agents.registry import get_agent_class +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 
+from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.test_utils import check + +tf = try_import_tf() +torch, nn = try_import_torch() + + +class MyKerasModel(TFModelV2): + """Custom model for policy gradient algorithms.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(MyKerasModel, self).__init__(obs_space, action_space, + num_outputs, model_config, name) + self.inputs = tf.keras.layers.Input( + shape=obs_space.shape, name="observations") + layer_1 = tf.keras.layers.Dense( + 16, + name="layer1", + activation=tf.nn.relu, + kernel_initializer=normc_initializer(1.0))(self.inputs) + layer_out = tf.keras.layers.Dense( + num_outputs, + name="out", + activation=None, + kernel_initializer=normc_initializer(0.01))(layer_1) + if self.model_config["vf_share_layers"]: + value_out = tf.keras.layers.Dense( + 1, + name="value", + activation=None, + kernel_initializer=normc_initializer(0.01))(layer_1) + self.base_model = tf.keras.Model(self.inputs, + [layer_out, value_out]) + else: + self.base_model = tf.keras.Model(self.inputs, layer_out) + + self.register_variables(self.base_model.variables) + + def forward(self, input_dict, state, seq_lens): + if self.model_config["vf_share_layers"]: + model_out, self._value_out = self.base_model(input_dict["obs"]) + else: + model_out = self.base_model(input_dict["obs"]) + self._value_out = tf.zeros( + shape=(tf.shape(input_dict["obs"])[0], )) + return model_out, state + + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + def import_from_h5(self, import_file): + # Override this to define custom weight loading behavior from h5 files. 
+ self.base_model.load_weights(import_file) + + +class MyTorchModel(TorchModelV2, nn.Module): + """Custom fully connected torch model with a separate value branch.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + TorchModelV2.__init__(self, obs_space, action_space, num_outputs, + model_config, name) + nn.Module.__init__(self) + + self.device = torch.device("cuda" + if torch.cuda.is_available() else "cpu") + + self.layer_1 = nn.Linear(obs_space.shape[0], 16).to(self.device) + self.layer_out = nn.Linear(16, num_outputs).to(self.device) + self.value_branch = nn.Linear(16, 1).to(self.device) + self.cur_value = None + + def forward(self, input_dict, state, seq_lens): + layer_1_out = self.layer_1(input_dict["obs"]) + logits = self.layer_out(layer_1_out) + self.cur_value = self.value_branch(layer_1_out).squeeze(1) + return logits, state + + def value_function(self): + assert self.cur_value is not None, "Must call `forward()` first!" + return self.cur_value + + def import_from_h5(self, import_file): + # Override this to define custom weight loading behavior from h5 files. + f = h5py.File(import_file) + self.layer_1.load_state_dict({ + "weight": torch.Tensor( + np.transpose(f["layer1"]["default_policy"]["layer1"][ + "kernel:0"].value)), + "bias": torch.Tensor( + np.transpose( + f["layer1"]["default_policy"]["layer1"]["bias:0"].value)), + }) + self.layer_out.load_state_dict({ + "weight": torch.Tensor( + np.transpose( + f["out"]["default_policy"]["out"]["kernel:0"].value)), + "bias": torch.Tensor( + np.transpose( + f["out"]["default_policy"]["out"]["bias:0"].value)), + }) + self.value_branch.load_state_dict({ + "weight": torch.Tensor( + np.transpose( + f["value"]["default_policy"]["value"]["kernel:0"].value)), + "bias": torch.Tensor( + np.transpose( + f["value"]["default_policy"]["value"]["bias:0"].value)), + }) + + +def model_import_test(algo, config, env): + # Get the abs-path to use (bazel-friendly). 
+ rllib_dir = Path(__file__).parent.parent + import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5" + + agent_cls = get_agent_class(algo) + + for fw in ["tf", "torch"]: + print("framework={}".format(fw)) + + config["use_pytorch"] = fw == "torch" + config["eager"] = fw == "eager" + config["model"]["custom_model"] = "keras_model" if fw != "torch" else \ + "torch_model" + + eager_mode_ctx = None + if fw == "eager": + eager_mode_ctx = eager_mode() + eager_mode_ctx.__enter__() + assert tf.executing_eagerly() + elif fw == "tf": + assert not tf.executing_eagerly() + + agent = agent_cls(config, env) + + def current_weight(agent): + if fw == "tf": + return agent.get_weights()["default_policy"][ + "default_policy/value/kernel"][0] + elif fw == "torch": + return float(agent.get_weights()["default_policy"][ + "value_branch.weight"][0][0]) + else: + return agent.get_weights()["default_policy"][4][0] + + # Import weights for our custom model from an h5 file. + weight_before_import = current_weight(agent) + agent.import_model(import_file=import_file) + weight_after_import = current_weight(agent) + check(weight_before_import, weight_after_import, false=True) + + # Train for a while. + for _ in range(1): + agent.train() + weight_after_train = current_weight(agent) + # Weights should have changed. + check(weight_before_import, weight_after_train, false=True) + check(weight_after_import, weight_after_train, false=True) + + # We can save the entire Agent and restore, weights should remain the + # same. + file = agent.save("after_train") + check(weight_after_train, current_weight(agent)) + agent.restore(file) + check(weight_after_train, current_weight(agent)) + + # Import (untrained) weights again. 
+ agent.import_model(import_file=import_file) + check(current_weight(agent), weight_after_import) + + if eager_mode_ctx: + eager_mode_ctx.__exit__(None, None, None) + + +class TestModelImport(unittest.TestCase): + def setUp(self): + ray.init() + ModelCatalog.register_custom_model("keras_model", MyKerasModel) + ModelCatalog.register_custom_model("torch_model", MyTorchModel) + + def tearDown(self): + ray.shutdown() + + def test_ppo(self): + model_import_test( + "PPO", + config={ + "num_workers": 0, + "vf_share_layers": True, + "model": {} + }, + env="CartPole-v0") + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))