[RLlib] Issue 7046 cannot restore keras model from h5 file. (#7482)

This commit is contained in:
Sven Mika
2020-03-23 20:19:30 +01:00
committed by GitHub
parent ee8c9ff732
commit 1138f2ebed
15 changed files with 364 additions and 40 deletions
+12
View File
@@ -141,6 +141,18 @@ Here is an example of the basic usage (for a more complete example, see `custom_
checkpoint = trainer.save()
print("checkpoint saved at", checkpoint)
# Also, in case you have trained a model outside of ray/RLlib and have created
# an h5-file with weight values in it, e.g.
# my_keras_model_trained_outside_rllib.save_weights("model.h5")
# (see: https://keras.io/models/about-keras-models/)
# ... you can load the h5-weights into your Trainer's Policy's ModelV2
# (tf or torch) by doing:
trainer.import_model("my_weights.h5")
# NOTE: In order for this to work, your (custom) model needs to implement
# the `import_from_h5` method.
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
# for detailed examples for tf- and torch trainers/models.
.. note::
+4 -1
View File
@@ -473,13 +473,16 @@ class Trainable:
export model to local directory.
Args:
export_formats (list): List of formats that should be exported.
export_formats (Union[list,str]): Format or list of (str) formats
that should be exported.
export_dir (str): Optional dir to place the exported model.
Defaults to self.logdir.
Returns:
A dict that maps ExportFormats to successfully exported models.
"""
if isinstance(export_formats, str):
export_formats = [export_formats]
export_dir = export_dir or self.logdir
return self._export_model(export_formats, export_dir)
+11 -9
View File
@@ -48,28 +48,30 @@ class Location:
class ExportFormat:
"""Describes the format to export the trial Trainable.
"""Describes the format to import/export the trial Trainable.
This may correspond to different file formats based on the
Trainable implementation.
"""
CHECKPOINT = "checkpoint"
MODEL = "model"
H5 = "h5"
@staticmethod
def validate(export_formats):
"""Validates export_formats.
def validate(formats):
"""Validates formats.
Raises:
ValueError if the format is unknown.
"""
for i in range(len(export_formats)):
export_formats[i] = export_formats[i].strip().lower()
if export_formats[i] not in [
ExportFormat.CHECKPOINT, ExportFormat.MODEL
for i in range(len(formats)):
formats[i] = formats[i].strip().lower()
if formats[i] not in [
ExportFormat.CHECKPOINT, ExportFormat.MODEL,
ExportFormat.H5
]:
raise TuneError("Unsupported export format: " +
export_formats[i])
raise TuneError("Unsupported import/export format: " +
formats[i])
def checkpoint_deleter(trial_id, runner):
+8
View File
@@ -1034,6 +1034,14 @@ py_test(
srcs = ["tests/test_lstm.py"]
)
py_test(
name = "tests/test_model_imports",
tags = ["tests_dir", "tests_dir_M", "model_imports"],
size = "small",
data = glob(["tests/data/model_weights/**"]),
srcs = ["tests/test_model_imports.py"]
)
py_test(
name = "tests/test_multi_agent_env",
tags = ["tests_dir", "tests_dir_M"],
+45
View File
@@ -628,6 +628,7 @@ class Trainer(Trainable):
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint-{}".format(self.iteration))
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
return checkpoint_path
@override(Trainable)
@@ -866,6 +867,25 @@ class Trainer(Trainable):
self.workers.local_worker().export_policy_checkpoint(
export_dir, filename_prefix, policy_id)
@DeveloperAPI
def import_policy_model_from_h5(self,
import_file,
policy_id=DEFAULT_POLICY_ID):
"""Imports a policy's model with given policy_id from a local h5 file.
Arguments:
import_file (str): The h5 file to import from.
policy_id (string): Optional policy id to import into.
Example:
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>> trainer.train()
"""
self.workers.local_worker().import_policy_model_from_h5(
import_file, policy_id)
@DeveloperAPI
def collect_metrics(self, selected_workers=None):
"""Collects metrics from the remote workers of this agent.
@@ -1003,6 +1023,31 @@ class Trainer(Trainable):
exported[ExportFormat.MODEL] = path
return exported
def import_model(self, import_file):
"""Imports a model from import_file.
Note: Currently, only h5 files are supported.
Args:
import_file (str): The file to import the model from.
Returns:
A dict that maps ExportFormats to successfully exported models.
"""
# Check for existence.
if not os.path.exists(import_file):
raise FileNotFoundError(
"`import_file` '{}' does not exist! Can't import Model.".
format(import_file))
# Get the format of the given file.
import_format = "h5" # TODO(sven): Support checkpoint loading.
ExportFormat.validate([import_format])
if import_format != ExportFormat.H5:
raise NotImplementedError
else:
return self.import_policy_model_from_h5(import_file)
def __getstate__(self):
state = {}
if hasattr(self, "workers"):
+6
View File
@@ -782,6 +782,12 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
self.policy_map[policy_id].export_model(export_dir)
@DeveloperAPI
def import_policy_model_from_h5(self,
import_file,
policy_id=DEFAULT_POLICY_ID):
self.policy_map[policy_id].import_model_from_h5(import_file)
@DeveloperAPI
def export_policy_checkpoint(self,
export_dir,
+14
View File
@@ -192,6 +192,20 @@ class ModelV2:
i += 1
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
def import_from_h5(self, h5_file):
"""Imports weights from an h5 file.
Args:
h5_file (str): The h5 file name to import weights from.
Example:
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
>>> trainer.train()
"""
raise NotImplementedError
def last_output(self):
"""Returns the last output returned from calling the model."""
return self._last_output
+1 -3
View File
@@ -46,7 +46,6 @@ class Categorical(TFActionDistribution):
@DeveloperAPI
def __init__(self, inputs, model=None, temperature=1.0):
assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
self.n = inputs.shape[-1]
# Allow softmax formula w/ temperature != 1.0:
# Divide inputs by temperature.
super().__init__(inputs / temperature, model)
@@ -104,8 +103,7 @@ class MultiCategorical(TFActionDistribution):
@override(ActionDistribution)
def deterministic_sample(self):
return tf.stack(
[cat.deterministic_sample() for cat in self.cats],
axis=1)
[cat.deterministic_sample() for cat in self.cats], axis=1)
@override(ActionDistribution)
def logp(self, actions):
+3 -3
View File
@@ -2,16 +2,16 @@
It supports both traced and non-traced eager execution modes."""
import logging
import functools
import logging
import numpy as np
from ray.util.debug import log_once
from ray.rllib.evaluation.episode import _flatten_action
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \
ACTION_LOGP
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.policy import ACTION_PROB, ACTION_LOGP
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
+9
View File
@@ -375,6 +375,15 @@ class Policy(metaclass=ABCMeta):
"""
raise NotImplementedError
@DeveloperAPI
def import_model_from_h5(self, import_file):
"""Imports Policy from local file.
Arguments:
import_file (str): Local readable file.
"""
raise NotImplementedError
def _create_exploration(self, action_space, config):
"""Creates the Policy's Exploration object.
+8
View File
@@ -364,6 +364,14 @@ class TFPolicy(Policy):
saver = tf.train.Saver()
saver.save(self._sess, save_path)
@override(Policy)
def import_model_from_h5(self, import_file):
"""Imports weights into tf model."""
# Make sure the session is the right one (see issue #7046).
with self._sess.graph.as_default():
with self._sess.as_default():
return self.model.import_from_h5(import_file)
@DeveloperAPI
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders.
+7 -2
View File
@@ -259,16 +259,21 @@ class TorchPolicy(Policy):
@override(Policy)
def export_model(self, export_dir):
"""TODO: implement for torch.
"""TODO(sven): implement for torch.
"""
raise NotImplementedError
@override(Policy)
def export_checkpoint(self, export_dir):
"""TODO: implement for torch.
"""TODO(sven): implement for torch.
"""
raise NotImplementedError
@override(Policy)
def import_model_from_h5(self, import_file):
"""Imports weights into torch model."""
return self.model.import_from_h5(import_file)
@DeveloperAPI
class LearningRateSchedule:
Binary file not shown.
+22 -22
View File
@@ -19,19 +19,9 @@ def get_mean_action(alg, obs):
CONFIGS = {
"SAC": {
"A3C": {
"explore": False,
},
"ES": {
"explore": False,
"episodes_per_batch": 10,
"train_batch_size": 100,
"num_workers": 2,
"noise_size": 2500000,
"observation_filter": "MeanStdFilter"
},
"DQN": {
"explore": False
"num_workers": 1
},
"APEX_DDPG": {
"explore": False,
@@ -42,27 +32,37 @@ CONFIGS = {
"num_replay_buffer_shards": 1,
},
},
"ARS": {
"explore": False,
"num_rollouts": 10,
"num_workers": 2,
"noise_size": 2500000,
"observation_filter": "MeanStdFilter"
},
"DDPG": {
"explore": False,
"timesteps_per_iteration": 100
},
"DQN": {
"explore": False
},
"ES": {
"explore": False,
"episodes_per_batch": 10,
"train_batch_size": 100,
"num_workers": 2,
"noise_size": 2500000,
"observation_filter": "MeanStdFilter"
},
"PPO": {
"explore": False,
"num_sgd_iter": 5,
"train_batch_size": 1000,
"num_workers": 2
},
"A3C": {
"SAC": {
"explore": False,
"num_workers": 1
},
"ARS": {
"explore": False,
"num_rollouts": 10,
"num_workers": 2,
"noise_size": 2500000,
"observation_filter": "MeanStdFilter"
}
}
@@ -121,7 +121,7 @@ def export_test(alg_name, failures):
else:
algo = cls(config=CONFIGS[alg_name], env="CartPole-v0")
for _ in range(3):
for _ in range(2):
res = algo.train()
print("current status: " + str(res))
+214
View File
@@ -0,0 +1,214 @@
#!/usr/bin/env python
import h5py
import numpy as np
from pathlib import Path
from tensorflow.python.eager.context import eager_mode
import unittest
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check
tf = try_import_tf()
torch, nn = try_import_torch()
class MyKerasModel(TFModelV2):
"""Custom model for policy gradient algorithms."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super(MyKerasModel, self).__init__(obs_space, action_space,
num_outputs, model_config, name)
self.inputs = tf.keras.layers.Input(
shape=obs_space.shape, name="observations")
layer_1 = tf.keras.layers.Dense(
16,
name="layer1",
activation=tf.nn.relu,
kernel_initializer=normc_initializer(1.0))(self.inputs)
layer_out = tf.keras.layers.Dense(
num_outputs,
name="out",
activation=None,
kernel_initializer=normc_initializer(0.01))(layer_1)
if self.model_config["vf_share_layers"]:
value_out = tf.keras.layers.Dense(
1,
name="value",
activation=None,
kernel_initializer=normc_initializer(0.01))(layer_1)
self.base_model = tf.keras.Model(self.inputs,
[layer_out, value_out])
else:
self.base_model = tf.keras.Model(self.inputs, layer_out)
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
if self.model_config["vf_share_layers"]:
model_out, self._value_out = self.base_model(input_dict["obs"])
else:
model_out = self.base_model(input_dict["obs"])
self._value_out = tf.zeros(
shape=(tf.shape(input_dict["obs"])[0], ))
return model_out, state
def value_function(self):
return tf.reshape(self._value_out, [-1])
def import_from_h5(self, import_file):
# Override this to define custom weight loading behavior from h5 files.
self.base_model.load_weights(import_file)
class MyTorchModel(TorchModelV2, nn.Module):
"""Generic vision network."""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
nn.Module.__init__(self)
self.device = torch.device("cuda"
if torch.cuda.is_available() else "cpu")
self.layer_1 = nn.Linear(obs_space.shape[0], 16).to(self.device)
self.layer_out = nn.Linear(16, num_outputs).to(self.device)
self.value_branch = nn.Linear(16, 1).to(self.device)
self.cur_value = None
def forward(self, input_dict, state, seq_lens):
layer_1_out = self.layer_1(input_dict["obs"])
logits = self.layer_out(layer_1_out)
self.cur_value = self.value_branch(layer_1_out).squeeze(1)
return logits, state
def value_function(self):
assert self.cur_value is not None, "Must call `forward()` first!"
return self.cur_value
def import_from_h5(self, import_file):
# Override this to define custom weight loading behavior from h5 files.
f = h5py.File(import_file)
self.layer_1.load_state_dict({
"weight": torch.Tensor(
np.transpose(f["layer1"]["default_policy"]["layer1"][
"kernel:0"].value)),
"bias": torch.Tensor(
np.transpose(
f["layer1"]["default_policy"]["layer1"]["bias:0"].value)),
})
self.layer_out.load_state_dict({
"weight": torch.Tensor(
np.transpose(
f["out"]["default_policy"]["out"]["kernel:0"].value)),
"bias": torch.Tensor(
np.transpose(
f["out"]["default_policy"]["out"]["bias:0"].value)),
})
self.value_branch.load_state_dict({
"weight": torch.Tensor(
np.transpose(
f["value"]["default_policy"]["value"]["kernel:0"].value)),
"bias": torch.Tensor(
np.transpose(
f["value"]["default_policy"]["value"]["bias:0"].value)),
})
def model_import_test(algo, config, env):
# Get the abs-path to use (bazel-friendly).
rllib_dir = Path(__file__).parent.parent
import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5"
agent_cls = get_agent_class(algo)
for fw in ["tf", "torch"]:
print("framework={}".format(fw))
config["use_pytorch"] = fw == "torch"
config["eager"] = fw == "eager"
config["model"]["custom_model"] = "keras_model" if fw != "torch" else \
"torch_model"
eager_mode_ctx = None
if fw == "eager":
eager_mode_ctx = eager_mode()
eager_mode_ctx.__enter__()
assert tf.executing_eagerly()
elif fw == "tf":
assert not tf.executing_eagerly()
agent = agent_cls(config, env)
def current_weight(agent):
if fw == "tf":
return agent.get_weights()["default_policy"][
"default_policy/value/kernel"][0]
elif fw == "torch":
return float(agent.get_weights()["default_policy"][
"value_branch.weight"][0][0])
else:
return agent.get_weights()["default_policy"][4][0]
# Import weights for our custom model from an h5 file.
weight_before_import = current_weight(agent)
agent.import_model(import_file=import_file)
weight_after_import = current_weight(agent)
check(weight_before_import, weight_after_import, false=True)
# Train for a while.
for _ in range(1):
agent.train()
weight_after_train = current_weight(agent)
# Weights should have changed.
check(weight_before_import, weight_after_train, false=True)
check(weight_after_import, weight_after_train, false=True)
# We can save the entire Agent and restore, weights should remain the
# same.
file = agent.save("after_train")
check(weight_after_train, current_weight(agent))
agent.restore(file)
check(weight_after_train, current_weight(agent))
# Import (untrained) weights again.
agent.import_model(import_file=import_file)
check(current_weight(agent), weight_after_import)
if eager_mode_ctx:
eager_mode_ctx.__exit__(None, None, None)
class TestModelImport(unittest.TestCase):
def setUp(self):
ray.init()
ModelCatalog.register_custom_model("keras_model", MyKerasModel)
ModelCatalog.register_custom_model("torch_model", MyTorchModel)
def tearDown(self):
ray.shutdown()
def test_ppo(self):
model_import_test(
"PPO",
config={
"num_workers": 0,
"vf_share_layers": True,
"model": {}
},
env="CartPole-v0")
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))