mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 16:31:25 +08:00
[RLlib] Make TFModelV2 behave more like TorchModelV2: Obsolete register_variables. Unify variable dicts. (#13339)
This commit is contained in:
@@ -1089,6 +1089,13 @@ py_test(
|
||||
srcs = ["models/tests/test_distributions.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_models",
|
||||
tags = ["models"],
|
||||
size = "small",
|
||||
srcs = ["models/tests/test_models.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_preprocessors",
|
||||
tags = ["models"],
|
||||
|
||||
@@ -89,7 +89,6 @@ class DDPGTFModel(TFModelV2):
|
||||
actor_out = tf.keras.layers.Lambda(lambda_)(actor_out)
|
||||
|
||||
self.policy_model = tf.keras.Model(self.model_out, actor_out)
|
||||
self.register_variables(self.policy_model.variables)
|
||||
|
||||
# Build the Q-model(s).
|
||||
self.actions_input = tf.keras.layers.Input(
|
||||
@@ -116,12 +115,10 @@ class DDPGTFModel(TFModelV2):
|
||||
return q_net
|
||||
|
||||
self.q_model = build_q_net("q", self.model_out, self.actions_input)
|
||||
self.register_variables(self.q_model.variables)
|
||||
|
||||
if twin_q:
|
||||
self.twin_q_model = build_q_net("twin_q", self.model_out,
|
||||
self.actions_input)
|
||||
self.register_variables(self.twin_q_model.variables)
|
||||
else:
|
||||
self.twin_q_model = None
|
||||
|
||||
|
||||
@@ -162,13 +162,11 @@ class DistributionalQTFModel(TFModelV2):
|
||||
|
||||
q_out = build_action_value(name + "/action_value/", self.model_out)
|
||||
self.q_value_head = tf.keras.Model(self.model_out, q_out)
|
||||
self.register_variables(self.q_value_head.variables)
|
||||
|
||||
if dueling:
|
||||
state_out = build_state_score(name + "/state_value/",
|
||||
self.model_out)
|
||||
self.state_value_head = tf.keras.Model(self.model_out, state_out)
|
||||
self.register_variables(self.state_value_head.variables)
|
||||
|
||||
def get_q_value_distributions(self,
|
||||
model_out: TensorType) -> List[TensorType]:
|
||||
|
||||
@@ -91,8 +91,6 @@ class SACTFModel(TFModelV2):
|
||||
])
|
||||
self.shift_and_log_scale_diag = self.action_model(self.model_out)
|
||||
|
||||
self.register_variables(self.action_model.variables)
|
||||
|
||||
self.actions_input = None
|
||||
if not self.discrete:
|
||||
self.actions_input = tf.keras.layers.Input(
|
||||
@@ -123,12 +121,10 @@ class SACTFModel(TFModelV2):
|
||||
return q_net
|
||||
|
||||
self.q_net = build_q_net("q", self.model_out, self.actions_input)
|
||||
self.register_variables(self.q_net.variables)
|
||||
|
||||
if twin_q:
|
||||
self.twin_q_net = build_q_net("twin_q", self.model_out,
|
||||
self.actions_input)
|
||||
self.register_variables(self.twin_q_net.variables)
|
||||
else:
|
||||
self.twin_q_net = None
|
||||
|
||||
@@ -147,8 +143,6 @@ class SACTFModel(TFModelV2):
|
||||
target_entropy = -np.prod(action_space.shape)
|
||||
self.target_entropy = target_entropy
|
||||
|
||||
self.register_variables([self.log_alpha])
|
||||
|
||||
def get_q_values(self,
|
||||
model_out: TensorType,
|
||||
actions: Optional[TensorType] = None) -> TensorType:
|
||||
|
||||
@@ -33,6 +33,7 @@ if __name__ == "__main__":
|
||||
"model": {
|
||||
"custom_model": "bn_model",
|
||||
},
|
||||
"lr": 0.0003,
|
||||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
"num_workers": 0,
|
||||
@@ -45,7 +46,7 @@ if __name__ == "__main__":
|
||||
"episode_reward_mean": args.stop_reward,
|
||||
}
|
||||
|
||||
results = tune.run(args.run, stop=stop, config=config, verbose=1)
|
||||
results = tune.run(args.run, stop=stop, config=config, verbose=2)
|
||||
|
||||
if args.as_test:
|
||||
check_learning_achieved(results, args.stop_reward)
|
||||
|
||||
@@ -71,7 +71,6 @@ class CustomModel(TFModelV2):
|
||||
model_config, name)
|
||||
self.model = FullyConnectedNetwork(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
self.register_variables(self.model.variables())
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
return self.model.forward(input_dict, state, seq_lens)
|
||||
|
||||
@@ -48,7 +48,6 @@ class MyKerasModel(TFModelV2):
|
||||
activation=None,
|
||||
kernel_initializer=normc_initializer(0.01))(layer_1)
|
||||
self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
model_out, self._value_out = self.base_model(input_dict["obs"])
|
||||
@@ -84,7 +83,6 @@ class MyKerasQModel(DistributionalQTFModel):
|
||||
activation=tf.nn.relu,
|
||||
kernel_initializer=normc_initializer(1.0))(layer_1)
|
||||
self.base_model = tf.keras.Model(self.inputs, layer_out)
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
# Implement the core forward method
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
|
||||
@@ -69,14 +69,12 @@ class AutoregressiveActionModel(TFModelV2):
|
||||
|
||||
# Base layers
|
||||
self.base_model = tf.keras.Model(obs_input, [context, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
self.base_model.summary()
|
||||
|
||||
# Autoregressive action sampler
|
||||
self.action_model = tf.keras.Model([ctx_input, a1_input],
|
||||
[a1_logits, a2_logits])
|
||||
self.action_model.summary()
|
||||
self.register_variables(self.action_model.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
context, self._value_out = self.base_model(input_dict["obs"])
|
||||
|
||||
@@ -123,7 +123,6 @@ class KerasBatchNormModel(TFModelV2):
|
||||
|
||||
self.base_model = tf.keras.models.Model(
|
||||
inputs=[inputs, is_training], outputs=[output, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
|
||||
@@ -23,7 +23,6 @@ class CentralizedCriticModel(TFModelV2):
|
||||
# Base of the model
|
||||
self.model = FullyConnectedNetwork(obs_space, action_space,
|
||||
num_outputs, model_config, name)
|
||||
self.register_variables(self.model.variables())
|
||||
|
||||
# Central VF maps (obs, opp_obs, opp_act) -> vf_pred
|
||||
obs = tf.keras.layers.Input(shape=(6, ), name="obs")
|
||||
@@ -37,7 +36,6 @@ class CentralizedCriticModel(TFModelV2):
|
||||
1, activation=None, name="c_vf_out")(central_vf_dense)
|
||||
self.central_vf = tf.keras.Model(
|
||||
inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
|
||||
self.register_variables(self.central_vf.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
@@ -79,11 +77,9 @@ class YetAnotherCentralizedCriticModel(TFModelV2):
|
||||
num_outputs,
|
||||
model_config,
|
||||
name + "_action")
|
||||
self.register_variables(self.action_model.variables())
|
||||
|
||||
self.value_model = FullyConnectedNetwork(obs_space, action_space, 1,
|
||||
model_config, name + "_vf")
|
||||
self.register_variables(self.value_model.variables())
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
self._value_out, _ = self.value_model({
|
||||
|
||||
@@ -53,7 +53,6 @@ class CNNPlusFCConcatModel(TFModelV2):
|
||||
name="cnn_{}".format(i))
|
||||
concat_size += cnn.num_outputs
|
||||
self.cnns[i] = cnn
|
||||
self.register_variables(cnn.variables())
|
||||
# Discrete inputs -> One-hot encode.
|
||||
elif isinstance(component, Discrete):
|
||||
concat_size += component.n
|
||||
@@ -82,7 +81,6 @@ class CNNPlusFCConcatModel(TFModelV2):
|
||||
kernel_initializer=normc_initializer(0.01))(concat_layer)
|
||||
self.logits_and_value_model = tf.keras.models.Model(
|
||||
concat_layer, [logits_layer, value_layer])
|
||||
self.register_variables(self.logits_and_value_model.variables)
|
||||
else:
|
||||
self.num_outputs = concat_size
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ class CustomLossModel(TFModelV2):
|
||||
num_outputs,
|
||||
model_config,
|
||||
name="fcnet")
|
||||
self.register_variables(self.fcnet.variables())
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
|
||||
@@ -40,7 +40,6 @@ class EagerModel(TFModelV2):
|
||||
|
||||
out = tf.keras.layers.Lambda(lambda_)(out)
|
||||
self.base_model = tf.keras.models.Model(inputs, [out, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
|
||||
@@ -68,7 +68,6 @@ class MobileV2PlusRNNModel(RecurrentNetwork):
|
||||
self.rnn_model = tf.keras.Model(
|
||||
inputs=[inputs, seq_in, state_in_h, state_in_c],
|
||||
outputs=[logits, values, state_h, state_c])
|
||||
self.register_variables(self.rnn_model.variables)
|
||||
self.rnn_model.summary()
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
|
||||
@@ -36,7 +36,6 @@ class ParametricActionsModel(DistributionalQTFModel):
|
||||
self.action_embed_model = FullyConnectedNetwork(
|
||||
Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
|
||||
model_config, name + "_action_embed")
|
||||
self.register_variables(self.action_embed_model.variables())
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
# Extract the available actions tensor from the observation.
|
||||
|
||||
@@ -54,7 +54,6 @@ class RNNModel(RecurrentNetwork):
|
||||
self.rnn_model = tf.keras.Model(
|
||||
inputs=[input_layer, seq_in, state_in_h, state_in_c],
|
||||
outputs=[logits, values, state_h, state_c])
|
||||
self.register_variables(self.rnn_model.variables)
|
||||
self.rnn_model.summary()
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
|
||||
@@ -107,7 +107,6 @@ class RNNSpyModel(RecurrentNetwork):
|
||||
[inputs, seq_lens, state_in_h, state_in_c],
|
||||
[logits, value_out, state_out_h, state_out_c])
|
||||
self.base_model.summary()
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
|
||||
@@ -39,7 +39,6 @@ class TF2SharedWeightsModel(TFModelV2):
|
||||
vf = tf.keras.layers.Dense(
|
||||
units=1, activation=None, name="value_out")(last_layer)
|
||||
self.base_model = tf.keras.models.Model(inputs, [output, vf])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
@@ -80,7 +79,6 @@ class SharedWeightsModel1(TFModelV2):
|
||||
vf = tf.keras.layers.Dense(
|
||||
units=1, activation=None, name="value_out")(last_layer)
|
||||
self.base_model = tf.keras.models.Model(inputs, [output, vf])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
@@ -114,7 +112,6 @@ class SharedWeightsModel2(TFModelV2):
|
||||
vf = tf.keras.layers.Dense(
|
||||
units=1, activation=None, name="value_out")(last_layer)
|
||||
self.base_model = tf.keras.models.Model(inputs, [output, vf])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
|
||||
@@ -47,7 +47,6 @@ class CustomTFRPGModel(TFModelV2):
|
||||
name)
|
||||
self.model = TFFCNet(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
self.register_variables(self.model.variables())
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
# The unpacked input tensors, where M=MAX_PLAYERS, N=MAX_ITEMS:
|
||||
|
||||
@@ -36,7 +36,6 @@ class FrameStackingCartPoleModel(TFModelV2):
|
||||
out = tf.keras.layers.Dense(self.num_outputs)(layer1)
|
||||
values = tf.keras.layers.Dense(1)(layer1)
|
||||
self.base_model = tf.keras.models.Model([input_], [out, values])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
self._last_value = None
|
||||
|
||||
|
||||
+24
-14
@@ -378,7 +378,9 @@ class ModelCatalog:
|
||||
if model_config.get("use_lstm") else AttentionWrapper)
|
||||
model_cls._wrapped_forward = forward
|
||||
|
||||
# Track and warn if vars were created but not registered.
|
||||
# Obsolete: Track and warn if vars were created but not
|
||||
# registered. Only still do this, if users do register their
|
||||
# variables. If not (which they shouldn't), don't check here.
|
||||
created = set()
|
||||
|
||||
def track_var_creation(next_creator, **kw):
|
||||
@@ -407,19 +409,27 @@ class ModelCatalog:
|
||||
# Other error -> re-raise.
|
||||
else:
|
||||
raise e
|
||||
registered = set(instance.variables())
|
||||
not_registered = set()
|
||||
for var in created:
|
||||
if var not in registered:
|
||||
not_registered.add(var)
|
||||
if not_registered:
|
||||
raise ValueError(
|
||||
"It looks like variables {} were created as part "
|
||||
"of {} but does not appear in model.variables() "
|
||||
"({}). Did you forget to call "
|
||||
"model.register_variables() on the variables in "
|
||||
"question?".format(not_registered, instance,
|
||||
registered))
|
||||
|
||||
# User still registered TFModelV2's variables: Check, whether
|
||||
# ok.
|
||||
registered = set(instance.var_list)
|
||||
if len(registered) > 0:
|
||||
not_registered = set()
|
||||
for var in created:
|
||||
if var not in registered:
|
||||
not_registered.add(var)
|
||||
if not_registered:
|
||||
raise ValueError(
|
||||
"It looks like you are still using "
|
||||
"`{}.register_variables()` to register your "
|
||||
"model's weights. This is no longer required, but "
|
||||
"if you are still calling this method at least "
|
||||
"once, you must make sure to register all created "
|
||||
"variables properly. The missing variables are {},"
|
||||
" and you only registered {}. "
|
||||
"Did you forget to call `register_variables()` on "
|
||||
"some of the variables in question?".format(
|
||||
instance, not_registered, registered))
|
||||
elif framework == "torch":
|
||||
# Try wrapping custom model with LSTM/attention, if required.
|
||||
if model_config.get("use_lstm") or \
|
||||
|
||||
@@ -94,8 +94,7 @@ class TestDistributions(unittest.TestCase):
|
||||
|
||||
inputs = inputs_space.sample()
|
||||
|
||||
for fw, sess in framework_iterator(
|
||||
session=True, frameworks=("tf", "tf2", "torch")):
|
||||
for fw, sess in framework_iterator(session=True):
|
||||
# Create the correct distribution object.
|
||||
cls = JAXCategorical if fw == "jax" else Categorical if \
|
||||
fw != "torch" else TorchCategorical
|
||||
@@ -218,8 +217,7 @@ class TestDistributions(unittest.TestCase):
|
||||
input_space = Box(-2.0, 2.0, shape=(2000, 10))
|
||||
low, high = -2.0, 1.0
|
||||
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=("torch", "tf", "tfe"), session=True):
|
||||
for fw, sess in framework_iterator(session=True):
|
||||
cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian
|
||||
|
||||
# Do a stability test using extreme NN outputs to see whether
|
||||
@@ -310,8 +308,7 @@ class TestDistributions(unittest.TestCase):
|
||||
"""Tests the DiagGaussian ActionDistribution for all frameworks."""
|
||||
input_space = Box(-2.0, 1.0, shape=(2000, 10))
|
||||
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=("torch", "tf", "tfe"), session=True):
|
||||
for fw, sess in framework_iterator(session=True):
|
||||
cls = DiagGaussian if fw != "torch" else TorchDiagGaussian
|
||||
|
||||
# Do a stability test using extreme NN outputs to see whether
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
from gym.spaces import Box
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
class TestTFModel(TFModelV2):
    """TFModelV2 holding a plain keras Model and an RLlib FC net as members.

    Neither sub-model is passed to `register_variables()`.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # Sub-model 1: A small keras model (3 inputs -> Dense(2)).
        keras_in = tf.keras.layers.Input(shape=(3, ))
        keras_out = tf.keras.layers.Dense(2)(keras_in)
        self.keras_model = tf.keras.models.Model(
            inputs=[keras_in], outputs=[keras_out])
        # Sub-model 2: An RLlib FullyConnectedNetwork (tf) with 3 outputs
        # and an empty model config.
        self.fc_net = FullyConnectedNetwork(obs_space, action_space, 3, {},
                                            "fc1")

    def forward(self, input_dict, state, seq_lens):
        """Feeds the flat obs through both sub-models; concats the outputs.

        Returns:
            Tuple of the concatenated (along the last axis) outputs and an
            empty state list.
        """
        flat_obs = input_dict["obs_flat"]
        from_keras = self.keras_model(flat_obs)
        from_fc, _ = self.fc_net({"obs": flat_obs})
        return tf.concat([from_keras, from_fc], axis=-1), []
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
    """Tests ModelV2 classes and their modularization capabilities."""

    def test_tf_modelv2(self):
        """Checks output and variable dict of a composite TFModelV2.

        Uses the dedicated `assertEqual`/`assertIn` methods (instead of
        `assertTrue(a == b)`) so that failures report the offending values.
        """
        obs_space = Box(-1.0, 1.0, (3, ))
        action_space = Box(-1.0, 1.0, (2, ))
        my_tf_model = TestTFModel(obs_space, action_space, 5, {},
                                  "my_tf_model")
        # Call the model.
        out, states = my_tf_model({"obs": np.array([obs_space.sample()])})
        self.assertEqual(out.shape, (1, 5))
        self.assertEqual(out.dtype, tf.float32)
        self.assertEqual(states, [])

        # Named `model_vars` to not shadow the `vars` builtin.
        model_vars = my_tf_model.variables(as_dict=True)
        self.assertEqual(len(model_vars), 6)
        expected_keys = (
            "keras_model.dense.kernel:0",
            "keras_model.dense.bias:0",
            "fc_net.base_model.fc_out.kernel:0",
            "fc_net.base_model.fc_out.bias:0",
            "fc_net.base_model.value_out.kernel:0",
            "fc_net.base_model.value_out.bias:0",
        )
        for key in expected_keys:
            self.assertIn(key, model_vars)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run this file's tests verbosely via pytest and use pytest's return
    # value as the process exit code.
    import pytest
    import sys

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
||||
@@ -117,7 +117,6 @@ class TrXLNet(RecurrentNetwork):
|
||||
name="logits")(E_out)
|
||||
|
||||
self.base_model = tf.keras.models.Model([inputs], [logits])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
@override(RecurrentNetwork)
|
||||
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
|
||||
@@ -287,7 +286,6 @@ class GTrXLNet(RecurrentNetwork):
|
||||
self.trxl_model = tf.keras.Model(
|
||||
inputs=[input_layer] + memory_ins, outputs=outs + memory_outs[:-1])
|
||||
|
||||
self.register_variables(self.trxl_model.variables)
|
||||
self.trxl_model.summary()
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
@@ -386,7 +384,6 @@ class AttentionWrapper(TFModelV2):
|
||||
position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"],
|
||||
init_gru_gate_bias=cfg["attention_init_gru_gate_bias"],
|
||||
)
|
||||
self.register_variables(self.gtrxl.variables())
|
||||
|
||||
# `self.num_outputs` right now is the number of nodes coming from the
|
||||
# attention net.
|
||||
@@ -399,11 +396,9 @@ class AttentionWrapper(TFModelV2):
|
||||
# values.
|
||||
out = tf.keras.layers.Dense(self.num_outputs, activation=None)(input_)
|
||||
self._logits_branch = tf.keras.models.Model([input_], [out])
|
||||
self.register_variables(self._logits_branch.variables)
|
||||
|
||||
out = tf.keras.layers.Dense(1, activation=None)(input_)
|
||||
self._value_branch = tf.keras.models.Model([input_], [out])
|
||||
self.register_variables(self._value_branch.variables)
|
||||
|
||||
self.view_requirements = self.gtrxl.view_requirements
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
num_outputs = num_outputs // 2
|
||||
self.log_std_var = tf.Variable(
|
||||
[0.0] * num_outputs, dtype=tf.float32, name="log_std")
|
||||
self.register_variables([self.log_std_var])
|
||||
|
||||
# We are using obs_flat, so take the flattened shape as input.
|
||||
inputs = tf.keras.layers.Input(
|
||||
@@ -115,7 +114,6 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
self.base_model = tf.keras.Model(
|
||||
inputs, [(logits_out
|
||||
if logits_out is not None else last_layer), value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
|
||||
@@ -50,7 +50,6 @@ class RecurrentNetwork(TFModelV2):
|
||||
self.rnn_model = tf.keras.Model(
|
||||
inputs=[input_layer, seq_in, state_in_h, state_in_c],
|
||||
outputs=[output_layer, state_h, state_c])
|
||||
self.register_variables(self.rnn_model.variables)
|
||||
self.rnn_model.summary()
|
||||
"""
|
||||
|
||||
@@ -179,7 +178,6 @@ class LSTMWrapper(RecurrentNetwork):
|
||||
self._rnn_model = tf.keras.Model(
|
||||
inputs=[input_layer, seq_in, state_in_h, state_in_c],
|
||||
outputs=[logits, values, state_h, state_c])
|
||||
self.register_variables(self._rnn_model.variables)
|
||||
self._rnn_model.summary()
|
||||
|
||||
# Add prev-a/r to this model's view, if required.
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import contextlib
|
||||
import gym
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from ray.util import log_once
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
@@ -12,7 +15,7 @@ tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@PublicAPI
|
||||
class TFModelV2(ModelV2):
|
||||
"""TF version of ModelV2.
|
||||
"""TF version of ModelV2, which is always also a keras Model.
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
@@ -33,18 +36,18 @@ class TFModelV2(ModelV2):
|
||||
value_layer = tf.keras.layers.Dense(...)(hidden_layer)
|
||||
self.base_model = tf.keras.Model(
|
||||
input_layer, [output_layer, value_layer])
|
||||
self.register_variables(self.base_model.variables)
|
||||
"""
|
||||
|
||||
ModelV2.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
framework="tf")
|
||||
|
||||
# Deprecated: TFModelV2s now automatically track their variables.
|
||||
self.var_list = []
|
||||
|
||||
if tf1.executing_eagerly():
|
||||
self.graph = None
|
||||
else:
|
||||
@@ -65,13 +68,41 @@ class TFModelV2(ModelV2):
|
||||
|
||||
def register_variables(self, variables: List[TensorType]) -> None:
    """Appends the given variables to this model's `var_list`.

    Deprecated: Gated by `log_once`, a non-raising (`error=False`)
    deprecation warning is issued for `TFModelV2.register_variables`.

    Args:
        variables (List[TensorType]): The variables to add to
            `self.var_list`.
    """
    warn_key = "deprecated_tfmodelv2_register_variables"
    if log_once(warn_key):
        deprecation_warning(
            old="TFModelV2.register_variables", error=False)
    self.var_list.extend(variables)
|
||||
|
||||
@override(ModelV2)
def variables(self, as_dict: bool = False) -> List[TensorType]:
    """Returns this model's variables as a list or (if `as_dict`) a dict.

    If `self.var_list` is non-empty (variables were registered via
    `register_variables()`), only those variables are returned. Otherwise,
    the variables are collected from this object's own attributes.

    Args:
        as_dict (bool): If True, return a dict mapping names to variables
            instead of a list.

    Returns:
        A list of variables, or a dict (name -> variable) if `as_dict` is
        True.
    """
    if not as_dict:
        # Old way using `register_variables`: Return a copy of the list.
        if self.var_list:
            return list(self.var_list)
        # New way: Take the values of the automatically built var dict.
        return list(self.variables(as_dict=True).values())

    # Old way using `register_variables`: Key by the variables' names.
    if self.var_list:
        return {v.name: v for v in self.var_list}

    # New way: Automatically determine the var tree from our attributes.
    var_tree = {}
    for prop, value in self.__dict__.items():
        # Keras Model: key=prop + "." + var-name (replace '/' by '.').
        if isinstance(value, tf.keras.models.Model):
            for var in value.variables:
                var_tree[prop + "." + re.sub("/", ".", var.name)] = var
        # Other TFModelV2: Include its vars (prefixed by prop) into ours.
        elif isinstance(value, TFModelV2):
            sub_vars = value.variables(as_dict=True)
            for key, var in sub_vars.items():
                var_tree[prop + "." + key] = var
        # tf.Variable: key=prop.
        elif isinstance(value, tf.Variable):
            var_tree[prop] = value
    return var_tree
|
||||
|
||||
@override(ModelV2)
|
||||
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
|
||||
|
||||
@@ -140,7 +140,6 @@ class VisionNetwork(TFModelV2):
|
||||
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
|
||||
|
||||
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict: Dict[str, TensorType],
|
||||
state: List[TensorType],
|
||||
|
||||
@@ -11,7 +11,7 @@ _, nn = try_import_torch()
|
||||
|
||||
@PublicAPI
|
||||
class TorchModelV2(ModelV2):
|
||||
"""Torch version of ModelV2.
|
||||
"""Torch version of ModelV2, which is also always a torch.nn.Module.
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
inherit from nn.Module and implement forward() in a subclass."""
|
||||
|
||||
@@ -49,8 +49,6 @@ class MyKerasModel(TFModelV2):
|
||||
else:
|
||||
self.base_model = tf.keras.Model(self.inputs, layer_out)
|
||||
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
if self.model_config["vf_share_layers"]:
|
||||
model_out, self._value_out = self.base_model(input_dict["obs"])
|
||||
|
||||
@@ -240,7 +240,6 @@ class DictSpyModel(TFModelV2):
|
||||
self.num_outputs = num_outputs or 64
|
||||
out = tf.keras.layers.Dense(self.num_outputs)(input_)
|
||||
self._main_layer = tf.keras.models.Model([input_], [out])
|
||||
self.register_variables(self._main_layer.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
def spy(pos, front_cam, task):
|
||||
@@ -282,7 +281,6 @@ class TupleSpyModel(TFModelV2):
|
||||
self.num_outputs = num_outputs or 64
|
||||
out = tf.keras.layers.Dense(self.num_outputs)(input_)
|
||||
self._main_layer = tf.keras.models.Model([input_], [out])
|
||||
self.register_variables(self._main_layer.variables)
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
def spy(pos, cam, task):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -8,15 +9,18 @@ logger = logging.getLogger(__name__)
|
||||
DEPRECATED_VALUE = -1
|
||||
|
||||
|
||||
def deprecation_warning(old, new=None, error=None):
|
||||
"""
|
||||
Logs (via the `logger` object) or throws a deprecation warning/error.
|
||||
def deprecation_warning(
|
||||
old: str,
|
||||
new: Optional[str] = None,
|
||||
error: Optional[Union[bool, Exception]] = None) -> None:
|
||||
"""Warns (via the `logger` object) or throws a deprecation warning/error.
|
||||
|
||||
Args:
|
||||
old (str): A description of the "thing" that is to be deprecated.
|
||||
new (Optional[str]): A description of the new "thing" that replaces it.
|
||||
error (Optional[Union[bool,Exception]]): Whether or which exception to
|
||||
throw. If True, throw ValueError.
|
||||
error (Optional[Union[bool, Exception]]): Whether or which exception to
|
||||
throw. If True, throw ValueError. If False, just warn.
|
||||
If Exception, throw that Exception.
|
||||
"""
|
||||
msg = "`{}` has been deprecated.{}".format(
|
||||
old, (" Use `{}` instead.".format(new) if new else ""))
|
||||
|
||||
@@ -208,7 +208,6 @@ class Curiosity(Exploration):
|
||||
self._curiosity_feature_net.base_model.variables + \
|
||||
self._curiosity_inverse_fcnet.variables + \
|
||||
self._curiosity_forward_fcnet.variables
|
||||
self.model.register_variables(self._optimizer_var_list)
|
||||
self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
|
||||
# Create placeholders and initialize the loss.
|
||||
if self.framework == "tf":
|
||||
|
||||
Reference in New Issue
Block a user