mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 21:07:06 +08:00
[RLlib] Issue 7046 cannot restore keras model from h5 file. (#7482)
This commit is contained in:
@@ -141,6 +141,18 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
||||
checkpoint = trainer.save()
|
||||
print("checkpoint saved at", checkpoint)
|
||||
|
||||
# Also, in case you have trained a model outside of ray/RLlib and have created
|
||||
# an h5-file with weight values in it, e.g.
|
||||
# my_keras_model_trained_outside_rllib.save_weights("model.h5")
|
||||
# (see: https://keras.io/models/about-keras-models/)
|
||||
|
||||
# ... you can load the h5-weights into your Trainer's Policy's ModelV2
|
||||
# (tf or torch) by doing:
|
||||
trainer.import_model("my_weights.h5")
|
||||
# NOTE: In order for this to work, your (custom) model needs to implement
|
||||
# the `import_from_h5` method.
|
||||
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
|
||||
# for detailed examples for tf- and torch trainers/models.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
@@ -473,13 +473,16 @@ class Trainable:
|
||||
export model to local directory.
|
||||
|
||||
Args:
|
||||
export_formats (list): List of formats that should be exported.
|
||||
export_formats (Union[list,str]): Format or list of (str) formats
|
||||
that should be exported.
|
||||
export_dir (str): Optional dir to place the exported model.
|
||||
Defaults to self.logdir.
|
||||
|
||||
Returns:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
if isinstance(export_formats, str):
|
||||
export_formats = [export_formats]
|
||||
export_dir = export_dir or self.logdir
|
||||
return self._export_model(export_formats, export_dir)
|
||||
|
||||
|
||||
@@ -48,28 +48,30 @@ class Location:
|
||||
|
||||
|
||||
class ExportFormat:
|
||||
"""Describes the format to export the trial Trainable.
|
||||
"""Describes the format to import/export the trial Trainable.
|
||||
|
||||
This may correspond to different file formats based on the
|
||||
Trainable implementation.
|
||||
"""
|
||||
CHECKPOINT = "checkpoint"
|
||||
MODEL = "model"
|
||||
H5 = "h5"
|
||||
|
||||
@staticmethod
|
||||
def validate(export_formats):
|
||||
"""Validates export_formats.
|
||||
def validate(formats):
|
||||
"""Validates formats.
|
||||
|
||||
Raises:
|
||||
ValueError if the format is unknown.
|
||||
"""
|
||||
for i in range(len(export_formats)):
|
||||
export_formats[i] = export_formats[i].strip().lower()
|
||||
if export_formats[i] not in [
|
||||
ExportFormat.CHECKPOINT, ExportFormat.MODEL
|
||||
for i in range(len(formats)):
|
||||
formats[i] = formats[i].strip().lower()
|
||||
if formats[i] not in [
|
||||
ExportFormat.CHECKPOINT, ExportFormat.MODEL,
|
||||
ExportFormat.H5
|
||||
]:
|
||||
raise TuneError("Unsupported export format: " +
|
||||
export_formats[i])
|
||||
raise TuneError("Unsupported import/export format: " +
|
||||
formats[i])
|
||||
|
||||
|
||||
def checkpoint_deleter(trial_id, runner):
|
||||
|
||||
@@ -1034,6 +1034,14 @@ py_test(
|
||||
srcs = ["tests/test_lstm.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_model_imports",
|
||||
tags = ["tests_dir", "tests_dir_M", "model_imports"],
|
||||
size = "small",
|
||||
data = glob(["tests/data/model_weights/**"]),
|
||||
srcs = ["tests/test_model_imports.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_multi_agent_env",
|
||||
tags = ["tests_dir", "tests_dir_M"],
|
||||
|
||||
@@ -628,6 +628,7 @@ class Trainer(Trainable):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
pickle.dump(self.__getstate__(), open(checkpoint_path, "wb"))
|
||||
|
||||
return checkpoint_path
|
||||
|
||||
@override(Trainable)
|
||||
@@ -866,6 +867,25 @@ class Trainer(Trainable):
|
||||
self.workers.local_worker().export_policy_checkpoint(
|
||||
export_dir, filename_prefix, policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def import_policy_model_from_h5(self,
|
||||
import_file,
|
||||
policy_id=DEFAULT_POLICY_ID):
|
||||
"""Imports a policy's model with given policy_id from a local h5 file.
|
||||
|
||||
Arguments:
|
||||
import_file (str): The h5 file to import from.
|
||||
policy_id (string): Optional policy id to import into.
|
||||
|
||||
Example:
|
||||
>>> trainer = MyTrainer()
|
||||
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
|
||||
>>> for _ in range(10):
|
||||
>>> trainer.train()
|
||||
"""
|
||||
self.workers.local_worker().import_policy_model_from_h5(
|
||||
import_file, policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_metrics(self, selected_workers=None):
|
||||
"""Collects metrics from the remote workers of this agent.
|
||||
@@ -1003,6 +1023,31 @@ class Trainer(Trainable):
|
||||
exported[ExportFormat.MODEL] = path
|
||||
return exported
|
||||
|
||||
def import_model(self, import_file):
|
||||
"""Imports a model from import_file.
|
||||
|
||||
Note: Currently, only h5 files are supported.
|
||||
|
||||
Args:
|
||||
import_file (str): The file to import the model from.
|
||||
|
||||
Returns:
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
# Check for existence.
|
||||
if not os.path.exists(import_file):
|
||||
raise FileNotFoundError(
|
||||
"`import_file` '{}' does not exist! Can't import Model.".
|
||||
format(import_file))
|
||||
# Get the format of the given file.
|
||||
import_format = "h5" # TODO(sven): Support checkpoint loading.
|
||||
|
||||
ExportFormat.validate([import_format])
|
||||
if import_format != ExportFormat.H5:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return self.import_policy_model_from_h5(import_file)
|
||||
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
if hasattr(self, "workers"):
|
||||
|
||||
@@ -782,6 +782,12 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
|
||||
self.policy_map[policy_id].export_model(export_dir)
|
||||
|
||||
@DeveloperAPI
|
||||
def import_policy_model_from_h5(self,
|
||||
import_file,
|
||||
policy_id=DEFAULT_POLICY_ID):
|
||||
self.policy_map[policy_id].import_model_from_h5(import_file)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
|
||||
export_dir,
|
||||
|
||||
@@ -192,6 +192,20 @@ class ModelV2:
|
||||
i += 1
|
||||
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
|
||||
|
||||
def import_from_h5(self, h5_file):
|
||||
"""Imports weights from an h5 file.
|
||||
|
||||
Args:
|
||||
h5_file (str): The h5 file name to import weights from.
|
||||
|
||||
Example:
|
||||
>>> trainer = MyTrainer()
|
||||
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
|
||||
>>> for _ in range(10):
|
||||
>>> trainer.train()
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def last_output(self):
|
||||
"""Returns the last output returned from calling the model."""
|
||||
return self._last_output
|
||||
|
||||
@@ -46,7 +46,6 @@ class Categorical(TFActionDistribution):
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs, model=None, temperature=1.0):
|
||||
assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
|
||||
self.n = inputs.shape[-1]
|
||||
# Allow softmax formula w/ temperature != 1.0:
|
||||
# Divide inputs by temperature.
|
||||
super().__init__(inputs / temperature, model)
|
||||
@@ -104,8 +103,7 @@ class MultiCategorical(TFActionDistribution):
|
||||
@override(ActionDistribution)
|
||||
def deterministic_sample(self):
|
||||
return tf.stack(
|
||||
[cat.deterministic_sample() for cat in self.cats],
|
||||
axis=1)
|
||||
[cat.deterministic_sample() for cat in self.cats], axis=1)
|
||||
|
||||
@override(ActionDistribution)
|
||||
def logp(self, actions):
|
||||
|
||||
@@ -2,16 +2,16 @@
|
||||
|
||||
It supports both traced and non-traced eager execution modes."""
|
||||
|
||||
import logging
|
||||
import functools
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.evaluation.episode import _flatten_action
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY, ACTION_PROB, \
|
||||
ACTION_LOGP
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.policy import ACTION_PROB, ACTION_LOGP
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
@@ -375,6 +375,15 @@ class Policy(metaclass=ABCMeta):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def import_model_from_h5(self, import_file):
|
||||
"""Imports Policy from local file.
|
||||
|
||||
Arguments:
|
||||
import_file (str): Local readable file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _create_exploration(self, action_space, config):
|
||||
"""Creates the Policy's Exploration object.
|
||||
|
||||
|
||||
@@ -364,6 +364,14 @@ class TFPolicy(Policy):
|
||||
saver = tf.train.Saver()
|
||||
saver.save(self._sess, save_path)
|
||||
|
||||
@override(Policy)
|
||||
def import_model_from_h5(self, import_file):
|
||||
"""Imports weights into tf model."""
|
||||
# Make sure the session is the right one (see issue #7046).
|
||||
with self._sess.graph.as_default():
|
||||
with self._sess.as_default():
|
||||
return self.model.import_from_h5(import_file)
|
||||
|
||||
@DeveloperAPI
|
||||
def copy(self, existing_inputs):
|
||||
"""Creates a copy of self using existing input placeholders.
|
||||
|
||||
@@ -259,16 +259,21 @@ class TorchPolicy(Policy):
|
||||
|
||||
@override(Policy)
|
||||
def export_model(self, export_dir):
|
||||
"""TODO: implement for torch.
|
||||
"""TODO(sven): implement for torch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
def export_checkpoint(self, export_dir):
|
||||
"""TODO: implement for torch.
|
||||
"""TODO(sven): implement for torch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override(Policy)
|
||||
def import_model_from_h5(self, import_file):
|
||||
"""Imports weights into torch model."""
|
||||
return self.model.import_from_h5(import_file)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class LearningRateSchedule:
|
||||
|
||||
Binary file not shown.
@@ -19,19 +19,9 @@ def get_mean_action(alg, obs):
|
||||
|
||||
|
||||
CONFIGS = {
|
||||
"SAC": {
|
||||
"A3C": {
|
||||
"explore": False,
|
||||
},
|
||||
"ES": {
|
||||
"explore": False,
|
||||
"episodes_per_batch": 10,
|
||||
"train_batch_size": 100,
|
||||
"num_workers": 2,
|
||||
"noise_size": 2500000,
|
||||
"observation_filter": "MeanStdFilter"
|
||||
},
|
||||
"DQN": {
|
||||
"explore": False
|
||||
"num_workers": 1
|
||||
},
|
||||
"APEX_DDPG": {
|
||||
"explore": False,
|
||||
@@ -42,27 +32,37 @@ CONFIGS = {
|
||||
"num_replay_buffer_shards": 1,
|
||||
},
|
||||
},
|
||||
"ARS": {
|
||||
"explore": False,
|
||||
"num_rollouts": 10,
|
||||
"num_workers": 2,
|
||||
"noise_size": 2500000,
|
||||
"observation_filter": "MeanStdFilter"
|
||||
},
|
||||
"DDPG": {
|
||||
"explore": False,
|
||||
"timesteps_per_iteration": 100
|
||||
},
|
||||
"DQN": {
|
||||
"explore": False
|
||||
},
|
||||
"ES": {
|
||||
"explore": False,
|
||||
"episodes_per_batch": 10,
|
||||
"train_batch_size": 100,
|
||||
"num_workers": 2,
|
||||
"noise_size": 2500000,
|
||||
"observation_filter": "MeanStdFilter"
|
||||
},
|
||||
"PPO": {
|
||||
"explore": False,
|
||||
"num_sgd_iter": 5,
|
||||
"train_batch_size": 1000,
|
||||
"num_workers": 2
|
||||
},
|
||||
"A3C": {
|
||||
"SAC": {
|
||||
"explore": False,
|
||||
"num_workers": 1
|
||||
},
|
||||
"ARS": {
|
||||
"explore": False,
|
||||
"num_rollouts": 10,
|
||||
"num_workers": 2,
|
||||
"noise_size": 2500000,
|
||||
"observation_filter": "MeanStdFilter"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ def export_test(alg_name, failures):
|
||||
else:
|
||||
algo = cls(config=CONFIGS[alg_name], env="CartPole-v0")
|
||||
|
||||
for _ in range(3):
|
||||
for _ in range(2):
|
||||
res = algo.train()
|
||||
print("current status: " + str(res))
|
||||
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tensorflow.python.eager.context import eager_mode
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
class MyKerasModel(TFModelV2):
|
||||
"""Custom model for policy gradient algorithms."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
super(MyKerasModel, self).__init__(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
self.inputs = tf.keras.layers.Input(
|
||||
shape=obs_space.shape, name="observations")
|
||||
layer_1 = tf.keras.layers.Dense(
|
||||
16,
|
||||
name="layer1",
|
||||
activation=tf.nn.relu,
|
||||
kernel_initializer=normc_initializer(1.0))(self.inputs)
|
||||
layer_out = tf.keras.layers.Dense(
|
||||
num_outputs,
|
||||
name="out",
|
||||
activation=None,
|
||||
kernel_initializer=normc_initializer(0.01))(layer_1)
|
||||
if self.model_config["vf_share_layers"]:
|
||||
value_out = tf.keras.layers.Dense(
|
||||
1,
|
||||
name="value",
|
||||
activation=None,
|
||||
kernel_initializer=normc_initializer(0.01))(layer_1)
|
||||
self.base_model = tf.keras.Model(self.inputs,
|
||||
[layer_out, value_out])
|
||||
else:
|
||||
self.base_model = tf.keras.Model(self.inputs, layer_out)
|
||||
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
if self.model_config["vf_share_layers"]:
|
||||
model_out, self._value_out = self.base_model(input_dict["obs"])
|
||||
else:
|
||||
model_out = self.base_model(input_dict["obs"])
|
||||
self._value_out = tf.zeros(
|
||||
shape=(tf.shape(input_dict["obs"])[0], ))
|
||||
return model_out, state
|
||||
|
||||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
def import_from_h5(self, import_file):
|
||||
# Override this to define custom weight loading behavior from h5 files.
|
||||
self.base_model.load_weights(import_file)
|
||||
|
||||
|
||||
class MyTorchModel(TorchModelV2, nn.Module):
|
||||
"""Generic vision network."""
|
||||
|
||||
def __init__(self, obs_space, action_space, num_outputs, model_config,
|
||||
name):
|
||||
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.device = torch.device("cuda"
|
||||
if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.layer_1 = nn.Linear(obs_space.shape[0], 16).to(self.device)
|
||||
self.layer_out = nn.Linear(16, num_outputs).to(self.device)
|
||||
self.value_branch = nn.Linear(16, 1).to(self.device)
|
||||
self.cur_value = None
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
layer_1_out = self.layer_1(input_dict["obs"])
|
||||
logits = self.layer_out(layer_1_out)
|
||||
self.cur_value = self.value_branch(layer_1_out).squeeze(1)
|
||||
return logits, state
|
||||
|
||||
def value_function(self):
|
||||
assert self.cur_value is not None, "Must call `forward()` first!"
|
||||
return self.cur_value
|
||||
|
||||
def import_from_h5(self, import_file):
|
||||
# Override this to define custom weight loading behavior from h5 files.
|
||||
f = h5py.File(import_file)
|
||||
self.layer_1.load_state_dict({
|
||||
"weight": torch.Tensor(
|
||||
np.transpose(f["layer1"]["default_policy"]["layer1"][
|
||||
"kernel:0"].value)),
|
||||
"bias": torch.Tensor(
|
||||
np.transpose(
|
||||
f["layer1"]["default_policy"]["layer1"]["bias:0"].value)),
|
||||
})
|
||||
self.layer_out.load_state_dict({
|
||||
"weight": torch.Tensor(
|
||||
np.transpose(
|
||||
f["out"]["default_policy"]["out"]["kernel:0"].value)),
|
||||
"bias": torch.Tensor(
|
||||
np.transpose(
|
||||
f["out"]["default_policy"]["out"]["bias:0"].value)),
|
||||
})
|
||||
self.value_branch.load_state_dict({
|
||||
"weight": torch.Tensor(
|
||||
np.transpose(
|
||||
f["value"]["default_policy"]["value"]["kernel:0"].value)),
|
||||
"bias": torch.Tensor(
|
||||
np.transpose(
|
||||
f["value"]["default_policy"]["value"]["bias:0"].value)),
|
||||
})
|
||||
|
||||
|
||||
def model_import_test(algo, config, env):
|
||||
# Get the abs-path to use (bazel-friendly).
|
||||
rllib_dir = Path(__file__).parent.parent
|
||||
import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5"
|
||||
|
||||
agent_cls = get_agent_class(algo)
|
||||
|
||||
for fw in ["tf", "torch"]:
|
||||
print("framework={}".format(fw))
|
||||
|
||||
config["use_pytorch"] = fw == "torch"
|
||||
config["eager"] = fw == "eager"
|
||||
config["model"]["custom_model"] = "keras_model" if fw != "torch" else \
|
||||
"torch_model"
|
||||
|
||||
eager_mode_ctx = None
|
||||
if fw == "eager":
|
||||
eager_mode_ctx = eager_mode()
|
||||
eager_mode_ctx.__enter__()
|
||||
assert tf.executing_eagerly()
|
||||
elif fw == "tf":
|
||||
assert not tf.executing_eagerly()
|
||||
|
||||
agent = agent_cls(config, env)
|
||||
|
||||
def current_weight(agent):
|
||||
if fw == "tf":
|
||||
return agent.get_weights()["default_policy"][
|
||||
"default_policy/value/kernel"][0]
|
||||
elif fw == "torch":
|
||||
return float(agent.get_weights()["default_policy"][
|
||||
"value_branch.weight"][0][0])
|
||||
else:
|
||||
return agent.get_weights()["default_policy"][4][0]
|
||||
|
||||
# Import weights for our custom model from an h5 file.
|
||||
weight_before_import = current_weight(agent)
|
||||
agent.import_model(import_file=import_file)
|
||||
weight_after_import = current_weight(agent)
|
||||
check(weight_before_import, weight_after_import, false=True)
|
||||
|
||||
# Train for a while.
|
||||
for _ in range(1):
|
||||
agent.train()
|
||||
weight_after_train = current_weight(agent)
|
||||
# Weights should have changed.
|
||||
check(weight_before_import, weight_after_train, false=True)
|
||||
check(weight_after_import, weight_after_train, false=True)
|
||||
|
||||
# We can save the entire Agent and restore, weights should remain the
|
||||
# same.
|
||||
file = agent.save("after_train")
|
||||
check(weight_after_train, current_weight(agent))
|
||||
agent.restore(file)
|
||||
check(weight_after_train, current_weight(agent))
|
||||
|
||||
# Import (untrained) weights again.
|
||||
agent.import_model(import_file=import_file)
|
||||
check(current_weight(agent), weight_after_import)
|
||||
|
||||
if eager_mode_ctx:
|
||||
eager_mode_ctx.__exit__(None, None, None)
|
||||
|
||||
|
||||
class TestModelImport(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
ModelCatalog.register_custom_model("keras_model", MyKerasModel)
|
||||
ModelCatalog.register_custom_model("torch_model", MyTorchModel)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def test_ppo(self):
|
||||
model_import_test(
|
||||
"PPO",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"vf_share_layers": True,
|
||||
"model": {}
|
||||
},
|
||||
env="CartPole-v0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user