diff --git a/rllib/BUILD b/rllib/BUILD index 199cc5ad9..cc048db33 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1089,6 +1089,13 @@ py_test( srcs = ["models/tests/test_distributions.py"] ) +py_test( + name = "test_models", + tags = ["models"], + size = "small", + srcs = ["models/tests/test_models.py"] +) + py_test( name = "test_preprocessors", tags = ["models"], diff --git a/rllib/agents/ddpg/ddpg_tf_model.py b/rllib/agents/ddpg/ddpg_tf_model.py index cddc38006..a5b3a0ecb 100644 --- a/rllib/agents/ddpg/ddpg_tf_model.py +++ b/rllib/agents/ddpg/ddpg_tf_model.py @@ -89,7 +89,6 @@ class DDPGTFModel(TFModelV2): actor_out = tf.keras.layers.Lambda(lambda_)(actor_out) self.policy_model = tf.keras.Model(self.model_out, actor_out) - self.register_variables(self.policy_model.variables) # Build the Q-model(s). self.actions_input = tf.keras.layers.Input( @@ -116,12 +115,10 @@ class DDPGTFModel(TFModelV2): return q_net self.q_model = build_q_net("q", self.model_out, self.actions_input) - self.register_variables(self.q_model.variables) if twin_q: self.twin_q_model = build_q_net("twin_q", self.model_out, self.actions_input) - self.register_variables(self.twin_q_model.variables) else: self.twin_q_model = None diff --git a/rllib/agents/dqn/distributional_q_tf_model.py b/rllib/agents/dqn/distributional_q_tf_model.py index ed8064999..d921f7b78 100644 --- a/rllib/agents/dqn/distributional_q_tf_model.py +++ b/rllib/agents/dqn/distributional_q_tf_model.py @@ -162,13 +162,11 @@ class DistributionalQTFModel(TFModelV2): q_out = build_action_value(name + "/action_value/", self.model_out) self.q_value_head = tf.keras.Model(self.model_out, q_out) - self.register_variables(self.q_value_head.variables) if dueling: state_out = build_state_score(name + "/state_value/", self.model_out) self.state_value_head = tf.keras.Model(self.model_out, state_out) - self.register_variables(self.state_value_head.variables) def get_q_value_distributions(self, model_out: TensorType) -> List[TensorType]: diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index af6d1b539..4c890385f 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -91,8 +91,6 @@ class SACTFModel(TFModelV2): ]) self.shift_and_log_scale_diag = self.action_model(self.model_out) - self.register_variables(self.action_model.variables) - self.actions_input = None if not self.discrete: self.actions_input = tf.keras.layers.Input( @@ -123,12 +121,10 @@ class SACTFModel(TFModelV2): return q_net self.q_net = build_q_net("q", self.model_out, self.actions_input) - self.register_variables(self.q_net.variables) if twin_q: self.twin_q_net = build_q_net("twin_q", self.model_out, self.actions_input) - self.register_variables(self.twin_q_net.variables) else: self.twin_q_net = None @@ -147,8 +143,6 @@ class SACTFModel(TFModelV2): target_entropy = -np.prod(action_space.shape) self.target_entropy = target_entropy - self.register_variables([self.log_alpha]) - def get_q_values(self, model_out: TensorType, actions: Optional[TensorType] = None) -> TensorType: diff --git a/rllib/examples/batch_norm_model.py b/rllib/examples/batch_norm_model.py index 29f6dbe73..1a2a604ff 100644 --- a/rllib/examples/batch_norm_model.py +++ b/rllib/examples/batch_norm_model.py @@ -33,6 +33,7 @@ if __name__ == "__main__": "model": { "custom_model": "bn_model", }, + "lr": 0.0003, # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), "num_workers": 0, @@ -45,7 +46,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, stop=stop, config=config, verbose=1) + results = tune.run(args.run, stop=stop, config=config, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/custom_env.py b/rllib/examples/custom_env.py index 25f5e61e6..b9b060e43 100644 --- a/rllib/examples/custom_env.py +++ b/rllib/examples/custom_env.py @@ -71,7 +71,6 @@ class CustomModel(TFModelV2): model_config, name) self.model = FullyConnectedNetwork(obs_space, action_space, num_outputs, model_config, name) - self.register_variables(self.model.variables()) def forward(self, input_dict, state, seq_lens): return self.model.forward(input_dict, state, seq_lens) diff --git a/rllib/examples/custom_keras_model.py b/rllib/examples/custom_keras_model.py index a277ccd94..03ccb0859 100644 --- a/rllib/examples/custom_keras_model.py +++ b/rllib/examples/custom_keras_model.py @@ -48,7 +48,6 @@ class MyKerasModel(TFModelV2): activation=None, kernel_initializer=normc_initializer(0.01))(layer_1) self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out]) - self.register_variables(self.base_model.variables) def forward(self, input_dict, state, seq_lens): model_out, self._value_out = self.base_model(input_dict["obs"]) @@ -84,7 +83,6 @@ class MyKerasQModel(DistributionalQTFModel): activation=tf.nn.relu, kernel_initializer=normc_initializer(1.0))(layer_1) self.base_model = tf.keras.Model(self.inputs, layer_out) - self.register_variables(self.base_model.variables) # Implement the core forward method def forward(self, input_dict, state, seq_lens): diff --git a/rllib/examples/models/autoregressive_action_model.py b/rllib/examples/models/autoregressive_action_model.py index 5602f9b52..cb8af9ddc 100644 --- a/rllib/examples/models/autoregressive_action_model.py +++ b/rllib/examples/models/autoregressive_action_model.py @@ -69,14 +69,12 @@ class AutoregressiveActionModel(TFModelV2): # Base layers self.base_model = tf.keras.Model(obs_input, [context, value_out]) - self.register_variables(self.base_model.variables) self.base_model.summary() # Autoregressive action sampler self.action_model = tf.keras.Model([ctx_input, a1_input], [a1_logits, a2_logits]) self.action_model.summary() - self.register_variables(self.action_model.variables) def forward(self, input_dict, state, seq_lens): context, self._value_out = self.base_model(input_dict["obs"]) diff --git a/rllib/examples/models/batch_norm_model.py b/rllib/examples/models/batch_norm_model.py index 7d77ebc07..0ad833c89 100644 --- a/rllib/examples/models/batch_norm_model.py +++ b/rllib/examples/models/batch_norm_model.py @@ -123,7 +123,6 @@ class KerasBatchNormModel(TFModelV2): self.base_model = tf.keras.models.Model( inputs=[inputs, is_training], outputs=[output, value_out]) - self.register_variables(self.base_model.variables) @override(ModelV2) def forward(self, input_dict, state, seq_lens): diff --git a/rllib/examples/models/centralized_critic_models.py b/rllib/examples/models/centralized_critic_models.py index 23f1e8b92..7f6e370bf 100644 --- a/rllib/examples/models/centralized_critic_models.py +++ b/rllib/examples/models/centralized_critic_models.py @@ -23,7 +23,6 @@ class CentralizedCriticModel(TFModelV2): # Base of the model self.model = FullyConnectedNetwork(obs_space, action_space, num_outputs, model_config, name) - self.register_variables(self.model.variables()) # Central VF maps (obs, opp_obs, opp_act) -> vf_pred obs = tf.keras.layers.Input(shape=(6, ), name="obs") @@ -37,7 +36,6 @@ class CentralizedCriticModel(TFModelV2): 1, activation=None, name="c_vf_out")(central_vf_dense) self.central_vf = tf.keras.Model( inputs=[obs, opp_obs, opp_act], outputs=central_vf_out) - self.register_variables(self.central_vf.variables) @override(ModelV2) def forward(self, input_dict, state, seq_lens): @@ -79,11 +77,9 @@ class YetAnotherCentralizedCriticModel(TFModelV2): num_outputs, model_config, name + "_action") - self.register_variables(self.action_model.variables()) self.value_model = FullyConnectedNetwork(obs_space, action_space, 1, model_config, name + "_vf") - self.register_variables(self.value_model.variables()) def forward(self, input_dict, state, seq_lens): self._value_out, _ = self.value_model({ diff --git a/rllib/examples/models/cnn_plus_fc_concat_model.py b/rllib/examples/models/cnn_plus_fc_concat_model.py index e8cae19dd..6f8e3d85e 100644 --- a/rllib/examples/models/cnn_plus_fc_concat_model.py +++ b/rllib/examples/models/cnn_plus_fc_concat_model.py @@ -53,7 +53,6 @@ class CNNPlusFCConcatModel(TFModelV2): name="cnn_{}".format(i)) concat_size += cnn.num_outputs self.cnns[i] = cnn - self.register_variables(cnn.variables()) # Discrete inputs -> One-hot encode. elif isinstance(component, Discrete): concat_size += component.n @@ -82,7 +81,6 @@ class CNNPlusFCConcatModel(TFModelV2): kernel_initializer=normc_initializer(0.01))(concat_layer) self.logits_and_value_model = tf.keras.models.Model( concat_layer, [logits_layer, value_layer]) - self.register_variables(self.logits_and_value_model.variables) else: self.num_outputs = concat_size diff --git a/rllib/examples/models/custom_loss_model.py b/rllib/examples/models/custom_loss_model.py index e7c19cb9e..5e0d7b7c2 100644 --- a/rllib/examples/models/custom_loss_model.py +++ b/rllib/examples/models/custom_loss_model.py @@ -29,7 +29,6 @@ class CustomLossModel(TFModelV2): num_outputs, model_config, name="fcnet") - self.register_variables(self.fcnet.variables()) @override(ModelV2) def forward(self, input_dict, state, seq_lens): diff --git a/rllib/examples/models/eager_model.py b/rllib/examples/models/eager_model.py index a20236711..3b3b190b9 100644 --- a/rllib/examples/models/eager_model.py +++ b/rllib/examples/models/eager_model.py @@ -40,7 +40,6 @@ class EagerModel(TFModelV2): out = tf.keras.layers.Lambda(lambda_)(out) self.base_model = tf.keras.models.Model(inputs, [out, value_out]) - self.register_variables(self.base_model.variables) @override(ModelV2) def forward(self, input_dict, state, seq_lens): diff --git a/rllib/examples/models/mobilenet_v2_with_lstm_models.py b/rllib/examples/models/mobilenet_v2_with_lstm_models.py index 8afdbf188..c97778570 100644 --- a/rllib/examples/models/mobilenet_v2_with_lstm_models.py +++ b/rllib/examples/models/mobilenet_v2_with_lstm_models.py @@ -68,7 +68,6 @@ class MobileV2PlusRNNModel(RecurrentNetwork): self.rnn_model = tf.keras.Model( inputs=[inputs, seq_in, state_in_h, state_in_c], outputs=[logits, values, state_h, state_c]) - self.register_variables(self.rnn_model.variables) self.rnn_model.summary() @override(RecurrentNetwork) diff --git a/rllib/examples/models/parametric_actions_model.py b/rllib/examples/models/parametric_actions_model.py index abffbcafd..cbeda9645 100644 --- a/rllib/examples/models/parametric_actions_model.py +++ b/rllib/examples/models/parametric_actions_model.py @@ -36,7 +36,6 @@ class ParametricActionsModel(DistributionalQTFModel): self.action_embed_model = FullyConnectedNetwork( Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size, model_config, name + "_action_embed") - self.register_variables(self.action_embed_model.variables()) def forward(self, input_dict, state, seq_lens): # Extract the available actions tensor from the observation. diff --git a/rllib/examples/models/rnn_model.py b/rllib/examples/models/rnn_model.py index 84729ca3d..5a661f8a6 100644 --- a/rllib/examples/models/rnn_model.py +++ b/rllib/examples/models/rnn_model.py @@ -54,7 +54,6 @@ class RNNModel(RecurrentNetwork): self.rnn_model = tf.keras.Model( inputs=[input_layer, seq_in, state_in_h, state_in_c], outputs=[logits, values, state_h, state_c]) - self.register_variables(self.rnn_model.variables) self.rnn_model.summary() @override(RecurrentNetwork) diff --git a/rllib/examples/models/rnn_spy_model.py b/rllib/examples/models/rnn_spy_model.py index 1b1d95f1e..19009d7ea 100644 --- a/rllib/examples/models/rnn_spy_model.py +++ b/rllib/examples/models/rnn_spy_model.py @@ -107,7 +107,6 @@ class RNNSpyModel(RecurrentNetwork): [inputs, seq_lens, state_in_h, state_in_c], [logits, value_out, state_out_h, state_out_c]) self.base_model.summary() - self.register_variables(self.base_model.variables) @override(RecurrentNetwork) def forward_rnn(self, inputs, state, seq_lens): diff --git a/rllib/examples/models/shared_weights_model.py b/rllib/examples/models/shared_weights_model.py index 4e4c6c32a..8f3bf7ea4 100644 --- a/rllib/examples/models/shared_weights_model.py +++ b/rllib/examples/models/shared_weights_model.py @@ -39,7 +39,6 @@ class TF2SharedWeightsModel(TFModelV2): vf = tf.keras.layers.Dense( units=1, activation=None, name="value_out")(last_layer) self.base_model = tf.keras.models.Model(inputs, [output, vf]) - self.register_variables(self.base_model.variables) @override(ModelV2) def forward(self, input_dict, state, seq_lens): @@ -80,7 +79,6 @@ class SharedWeightsModel1(TFModelV2): vf = tf.keras.layers.Dense( units=1, activation=None, name="value_out")(last_layer) self.base_model = tf.keras.models.Model(inputs, [output, vf]) - self.register_variables(self.base_model.variables) @override(ModelV2) def forward(self, input_dict, state, seq_lens): @@ -114,7 +112,6 @@ class SharedWeightsModel2(TFModelV2): vf = tf.keras.layers.Dense( units=1, activation=None, name="value_out")(last_layer) self.base_model = tf.keras.models.Model(inputs, [output, vf]) - self.register_variables(self.base_model.variables) @override(ModelV2) def forward(self, input_dict, state, seq_lens): diff --git a/rllib/examples/models/simple_rpg_model.py b/rllib/examples/models/simple_rpg_model.py index 072437e5a..8615d8c30 100644 --- a/rllib/examples/models/simple_rpg_model.py +++ b/rllib/examples/models/simple_rpg_model.py @@ -47,7 +47,6 @@ class CustomTFRPGModel(TFModelV2): name) self.model = TFFCNet(obs_space, action_space, num_outputs, model_config, name) - self.register_variables(self.model.variables()) def forward(self, input_dict, state, seq_lens): # The unpacked input tensors, where M=MAX_PLAYERS, N=MAX_ITEMS: diff --git a/rllib/examples/models/trajectory_view_utilizing_models.py b/rllib/examples/models/trajectory_view_utilizing_models.py index 2360be025..41f53d872 100644 --- a/rllib/examples/models/trajectory_view_utilizing_models.py +++ b/rllib/examples/models/trajectory_view_utilizing_models.py @@ -36,7 +36,6 @@ class FrameStackingCartPoleModel(TFModelV2): out = tf.keras.layers.Dense(self.num_outputs)(layer1) values = tf.keras.layers.Dense(1)(layer1) self.base_model = tf.keras.models.Model([input_], [out, values]) - self.register_variables(self.base_model.variables) self._last_value = None diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 9638ed44b..a6e7415d4 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -378,7 +378,9 @@ class ModelCatalog: if model_config.get("use_lstm") else AttentionWrapper) model_cls._wrapped_forward = forward - # Track and warn if vars were created but not registered. + # Obsolete: Track and warn if vars were created but not + # registered. Only still do this, if users do register their + # variables. If not (which they shouldn't), don't check here. created = set() def track_var_creation(next_creator, **kw): @@ -407,19 +409,27 @@ class ModelCatalog: # Other error -> re-raise. else: raise e - registered = set(instance.variables()) - not_registered = set() - for var in created: - if var not in registered: - not_registered.add(var) - if not_registered: - raise ValueError( - "It looks like variables {} were created as part " - "of {} but does not appear in model.variables() " - "({}). Did you forget to call " - "model.register_variables() on the variables in " - "question?".format(not_registered, instance, - registered)) + + # User still registered TFModelV2's variables: Check, whether + # ok. + registered = set(instance.var_list) + if len(registered) > 0: + not_registered = set() + for var in created: + if var not in registered: + not_registered.add(var) + if not_registered: + raise ValueError( + "It looks like you are still using " + "`{}.register_variables()` to register your " + "model's weights. This is no longer required, but " + "if you are still calling this method at least " + "once, you must make sure to register all created " + "variables properly. The missing variables are {}," + " and you only registered {}. " + "Did you forget to call `register_variables()` on " + "some of the variables in question?".format( + instance, not_registered, registered)) elif framework == "torch": # Try wrapping custom model with LSTM/attention, if required. if model_config.get("use_lstm") or \ diff --git a/rllib/models/tests/test_distributions.py b/rllib/models/tests/test_distributions.py index 987f76a56..3dd14d0ae 100644 --- a/rllib/models/tests/test_distributions.py +++ b/rllib/models/tests/test_distributions.py @@ -94,8 +94,7 @@ class TestDistributions(unittest.TestCase): inputs = inputs_space.sample() - for fw, sess in framework_iterator( - session=True, frameworks=("tf", "tf2", "torch")): + for fw, sess in framework_iterator(session=True): # Create the correct distribution object. cls = JAXCategorical if fw == "jax" else Categorical if \ fw != "torch" else TorchCategorical @@ -218,8 +217,7 @@ class TestDistributions(unittest.TestCase): input_space = Box(-2.0, 2.0, shape=(2000, 10)) low, high = -2.0, 1.0 - for fw, sess in framework_iterator( - frameworks=("torch", "tf", "tfe"), session=True): + for fw, sess in framework_iterator(session=True): cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian # Do a stability test using extreme NN outputs to see whether @@ -310,8 +308,7 @@ class TestDistributions(unittest.TestCase): """Tests the DiagGaussian ActionDistribution for all frameworks.""" input_space = Box(-2.0, 1.0, shape=(2000, 10)) - for fw, sess in framework_iterator( - frameworks=("torch", "tf", "tfe"), session=True): + for fw, sess in framework_iterator(session=True): cls = DiagGaussian if fw != "torch" else TorchDiagGaussian # Do a stability test using extreme NN outputs to see whether diff --git a/rllib/models/tests/test_models.py b/rllib/models/tests/test_models.py new file mode 100644 index 000000000..424dea16c --- /dev/null +++ b/rllib/models/tests/test_models.py @@ -0,0 +1,59 @@ +from gym.spaces import Box +import numpy as np +import unittest + +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.tf.fcnet import FullyConnectedNetwork +from ray.rllib.utils.framework import try_import_tf + +tf1, tf, tfv = try_import_tf() + + +class TestTFModel(TFModelV2): + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super().__init__(obs_space, action_space, num_outputs, model_config, + name) + input_ = tf.keras.layers.Input(shape=(3, )) + output = tf.keras.layers.Dense(2)(input_) + # A keras model inside. + self.keras_model = tf.keras.models.Model([input_], [output]) + # A RLlib FullyConnectedNetwork (tf) inside (which is also a keras + # Model). + self.fc_net = FullyConnectedNetwork(obs_space, action_space, 3, {}, + "fc1") + + def forward(self, input_dict, state, seq_lens): + obs = input_dict["obs_flat"] + out1 = self.keras_model(obs) + out2, _ = self.fc_net({"obs": obs}) + return tf.concat([out1, out2], axis=-1), [] + + +class TestModels(unittest.TestCase): + """Tests ModelV2 classes and their modularization capabilities.""" + + def test_tf_modelv2(self): + obs_space = Box(-1.0, 1.0, (3, )) + action_space = Box(-1.0, 1.0, (2, )) + my_tf_model = TestTFModel(obs_space, action_space, 5, {}, + "my_tf_model") + # Call the model. + out, states = my_tf_model({"obs": np.array([obs_space.sample()])}) + self.assertTrue(out.shape == (1, 5)) + self.assertTrue(out.dtype == tf.float32) + self.assertTrue(states == []) + vars = my_tf_model.variables(as_dict=True) + self.assertTrue(len(vars) == 6) + self.assertTrue("keras_model.dense.kernel:0" in vars) + self.assertTrue("keras_model.dense.bias:0" in vars) + self.assertTrue("fc_net.base_model.fc_out.kernel:0" in vars) + self.assertTrue("fc_net.base_model.fc_out.bias:0" in vars) + self.assertTrue("fc_net.base_model.value_out.kernel:0" in vars) + self.assertTrue("fc_net.base_model.value_out.bias:0" in vars) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 4e79eb51a..fadd5ed89 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -117,7 +117,6 @@ class TrXLNet(RecurrentNetwork): name="logits")(E_out) self.base_model = tf.keras.models.Model([inputs], [logits]) - self.register_variables(self.base_model.variables) @override(RecurrentNetwork) def forward_rnn(self, inputs: TensorType, state: List[TensorType], @@ -287,7 +286,6 @@ class GTrXLNet(RecurrentNetwork): self.trxl_model = tf.keras.Model( inputs=[input_layer] + memory_ins, outputs=outs + memory_outs[:-1]) - self.register_variables(self.trxl_model.variables) self.trxl_model.summary() # __sphinx_doc_begin__ @@ -386,7 +384,6 @@ class AttentionWrapper(TFModelV2): position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"], init_gru_gate_bias=cfg["attention_init_gru_gate_bias"], ) - self.register_variables(self.gtrxl.variables()) # `self.num_outputs` right now is the number of nodes coming from the # attention net. @@ -399,11 +396,9 @@ class AttentionWrapper(TFModelV2): # values. out = tf.keras.layers.Dense(self.num_outputs, activation=None)(input_) self._logits_branch = tf.keras.models.Model([input_], [out]) - self.register_variables(self._logits_branch.variables) out = tf.keras.layers.Dense(1, activation=None)(input_) self._value_branch = tf.keras.models.Model([input_], [out]) - self.register_variables(self._value_branch.variables) self.view_requirements = self.gtrxl.view_requirements diff --git a/rllib/models/tf/fcnet.py b/rllib/models/tf/fcnet.py index e556741dd..eea01014d 100644 --- a/rllib/models/tf/fcnet.py +++ b/rllib/models/tf/fcnet.py @@ -33,7 +33,6 @@ class FullyConnectedNetwork(TFModelV2): num_outputs = num_outputs // 2 self.log_std_var = tf.Variable( [0.0] * num_outputs, dtype=tf.float32, name="log_std") - self.register_variables([self.log_std_var]) # We are using obs_flat, so take the flattened shape as input. inputs = tf.keras.layers.Input( @@ -115,7 +114,6 @@ class FullyConnectedNetwork(TFModelV2): self.base_model = tf.keras.Model( inputs, [(logits_out if logits_out is not None else last_layer), value_out]) - self.register_variables(self.base_model.variables) def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType], diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index fa51b54d0..8618cecc5 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -50,7 +50,6 @@ class RecurrentNetwork(TFModelV2): self.rnn_model = tf.keras.Model( inputs=[input_layer, seq_in, state_in_h, state_in_c], outputs=[output_layer, state_h, state_c]) - self.register_variables(self.rnn_model.variables) self.rnn_model.summary() """ @@ -179,7 +178,6 @@ class LSTMWrapper(RecurrentNetwork): self._rnn_model = tf.keras.Model( inputs=[input_layer, seq_in, state_in_h, state_in_c], outputs=[logits, values, state_h, state_c]) - self.register_variables(self._rnn_model.variables) self._rnn_model.summary() # Add prev-a/r to this model's view, if required. diff --git a/rllib/models/tf/tf_modelv2.py b/rllib/models/tf/tf_modelv2.py index 09625781b..78e1e0276 100644 --- a/rllib/models/tf/tf_modelv2.py +++ b/rllib/models/tf/tf_modelv2.py @@ -1,9 +1,12 @@ import contextlib import gym +import re from typing import List +from ray.util import log_once from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import ModelConfigDict, TensorType @@ -12,7 +15,7 @@ tf1, tf, tfv = try_import_tf() @PublicAPI class TFModelV2(ModelV2): - """TF version of ModelV2. + """TF version of ModelV2, which is always also a keras Model. Note that this class by itself is not a valid model unless you implement forward() in a subclass.""" @@ -33,18 +36,18 @@ class TFModelV2(ModelV2): value_layer = tf.keras.layers.Dense(...)(hidden_layer) self.base_model = tf.keras.Model( input_layer, [output_layer, value_layer]) - self.register_variables(self.base_model.variables) """ - - ModelV2.__init__( - self, + super().__init__( obs_space, action_space, num_outputs, model_config, name, framework="tf") + + # Deprecated: TFModelV2 now automatically track their variables. self.var_list = [] + if tf1.executing_eagerly(): self.graph = None else: @@ -65,13 +68,41 @@ class TFModelV2(ModelV2): def register_variables(self, variables: List[TensorType]) -> None: """Register the given list of variables with this model.""" + if log_once("deprecated_tfmodelv2_register_variables"): + deprecation_warning( + old="TFModelV2.register_variables", error=False) self.var_list.extend(variables) @override(ModelV2) def variables(self, as_dict: bool = False) -> List[TensorType]: if as_dict: - return {v.name: v for v in self.var_list} - return list(self.var_list) + # Old way using `register_variables`. + if self.var_list: + return {v.name: v for v in self.var_list} + # New way: Automatically determine the var tree. + else: + ret = {} + for prop, value in self.__dict__.items(): + # Keras Model: key=k + "." + var-name (replace '/' by '.'). + if isinstance(value, tf.keras.models.Model): + for var in value.variables: + key = prop + "." + re.sub("/", ".", var.name) + ret[key] = var + # Other TFModelV2: Include its vars into ours. + elif isinstance(value, TFModelV2): + for key, var in value.variables(as_dict=True).items(): + ret[prop + "." + key] = var + # tf.Variable + elif isinstance(value, tf.Variable): + ret[prop] = value + return ret + + # Old way using `register_variables`. + if self.var_list: + return list(self.var_list) + # New way: Automatically determine the var tree. + else: + return list(self.variables(as_dict=True).values()) @override(ModelV2) def trainable_variables(self, as_dict: bool = False) -> List[TensorType]: diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index c2a8de5d2..039ad4389 100644 --- a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -140,7 +140,6 @@ class VisionNetwork(TFModelV2): lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) self.base_model = tf.keras.Model(inputs, [conv_out, value_out]) - self.register_variables(self.base_model.variables) def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType], diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index 39bc336ed..f56cf9978 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -11,7 +11,7 @@ _, nn = try_import_torch() @PublicAPI class TorchModelV2(ModelV2): - """Torch version of ModelV2. + """Torch version of ModelV2, which is also always a torch.nn.Module. Note that this class by itself is not a valid model unless you inherit from nn.Module and implement forward() in a subclass.""" diff --git a/rllib/tests/test_model_imports.py b/rllib/tests/test_model_imports.py index b92f5d3a6..405b96b90 100644 --- a/rllib/tests/test_model_imports.py +++ b/rllib/tests/test_model_imports.py @@ -49,8 +49,6 @@ class MyKerasModel(TFModelV2): else: self.base_model = tf.keras.Model(self.inputs, layer_out) - self.register_variables(self.base_model.variables) - def forward(self, input_dict, state, seq_lens): if self.model_config["vf_share_layers"]: model_out, self._value_out = self.base_model(input_dict["obs"]) diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py index 736e69780..1a10e8c71 100644 --- a/rllib/tests/test_nested_observation_spaces.py +++ b/rllib/tests/test_nested_observation_spaces.py @@ -240,7 +240,6 @@ class DictSpyModel(TFModelV2): self.num_outputs = num_outputs or 64 out = tf.keras.layers.Dense(self.num_outputs)(input_) self._main_layer = tf.keras.models.Model([input_], [out]) - self.register_variables(self._main_layer.variables) def forward(self, input_dict, state, seq_lens): def spy(pos, front_cam, task): @@ -282,7 +281,6 @@ class TupleSpyModel(TFModelV2): self.num_outputs = num_outputs or 64 out = tf.keras.layers.Dense(self.num_outputs)(input_) self._main_layer = tf.keras.models.Model([input_], [out]) - self.register_variables(self._main_layer.variables) def forward(self, input_dict, state, seq_lens): def spy(pos, cam, task): diff --git a/rllib/utils/deprecation.py b/rllib/utils/deprecation.py index 05788059b..430d0a66a 100644 --- a/rllib/utils/deprecation.py +++ b/rllib/utils/deprecation.py @@ -1,4 +1,5 @@ import logging +from typing import Optional, Union logger = logging.getLogger(__name__) @@ -8,15 +9,18 @@ logger = logging.getLogger(__name__) DEPRECATED_VALUE = -1 -def deprecation_warning(old, new=None, error=None): - """ - Logs (via the `logger` object) or throws a deprecation warning/error. +def deprecation_warning( + old: str, + new: Optional[str] = None, + error: Optional[Union[bool, Exception]] = None) -> None: + """Warns (via the `logger` object) or throws a deprecation warning/error. Args: old (str): A description of the "thing" that is to be deprecated. new (Optional[str]): A description of the new "thing" that replaces it. - error (Optional[Union[bool,Exception]]): Whether or which exception to - throw. If True, throw ValueError. + error (Optional[Union[bool, Exception]]): Whether or which exception to + throw. If True, throw ValueError. If False, just warn. + If Exception, throw that Exception. """ msg = "`{}` has been deprecated.{}".format( old, (" Use `{}` instead.".format(new) if new else "")) diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index ec91c53d3..45b75ac22 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -208,7 +208,6 @@ class Curiosity(Exploration): self._curiosity_feature_net.base_model.variables + \ self._curiosity_inverse_fcnet.variables + \ self._curiosity_forward_fcnet.variables - self.model.register_variables(self._optimizer_var_list) self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr) # Create placeholders and initialize the loss. if self.framework == "tf":