[RLlib] Make TFModelV2 behave more like TorchModelV2: Obsolete register_variables. Unify variable dicts. (#13339)

This commit is contained in:
Sven Mika
2021-01-11 22:42:30 +01:00
committed by GitHub
parent c43fa12e73
commit e2b2abb88b
33 changed files with 143 additions and 83 deletions
+7
View File
@@ -1089,6 +1089,13 @@ py_test(
srcs = ["models/tests/test_distributions.py"]
)
py_test(
name = "test_models",
tags = ["models"],
size = "small",
srcs = ["models/tests/test_models.py"]
)
py_test(
name = "test_preprocessors",
tags = ["models"],
-3
View File
@@ -89,7 +89,6 @@ class DDPGTFModel(TFModelV2):
actor_out = tf.keras.layers.Lambda(lambda_)(actor_out)
self.policy_model = tf.keras.Model(self.model_out, actor_out)
self.register_variables(self.policy_model.variables)
# Build the Q-model(s).
self.actions_input = tf.keras.layers.Input(
@@ -116,12 +115,10 @@ class DDPGTFModel(TFModelV2):
return q_net
self.q_model = build_q_net("q", self.model_out, self.actions_input)
self.register_variables(self.q_model.variables)
if twin_q:
self.twin_q_model = build_q_net("twin_q", self.model_out,
self.actions_input)
self.register_variables(self.twin_q_model.variables)
else:
self.twin_q_model = None
@@ -162,13 +162,11 @@ class DistributionalQTFModel(TFModelV2):
q_out = build_action_value(name + "/action_value/", self.model_out)
self.q_value_head = tf.keras.Model(self.model_out, q_out)
self.register_variables(self.q_value_head.variables)
if dueling:
state_out = build_state_score(name + "/state_value/",
self.model_out)
self.state_value_head = tf.keras.Model(self.model_out, state_out)
self.register_variables(self.state_value_head.variables)
def get_q_value_distributions(self,
model_out: TensorType) -> List[TensorType]:
-6
View File
@@ -91,8 +91,6 @@ class SACTFModel(TFModelV2):
])
self.shift_and_log_scale_diag = self.action_model(self.model_out)
self.register_variables(self.action_model.variables)
self.actions_input = None
if not self.discrete:
self.actions_input = tf.keras.layers.Input(
@@ -123,12 +121,10 @@ class SACTFModel(TFModelV2):
return q_net
self.q_net = build_q_net("q", self.model_out, self.actions_input)
self.register_variables(self.q_net.variables)
if twin_q:
self.twin_q_net = build_q_net("twin_q", self.model_out,
self.actions_input)
self.register_variables(self.twin_q_net.variables)
else:
self.twin_q_net = None
@@ -147,8 +143,6 @@ class SACTFModel(TFModelV2):
target_entropy = -np.prod(action_space.shape)
self.target_entropy = target_entropy
self.register_variables([self.log_alpha])
def get_q_values(self,
model_out: TensorType,
actions: Optional[TensorType] = None) -> TensorType:
+2 -1
View File
@@ -33,6 +33,7 @@ if __name__ == "__main__":
"model": {
"custom_model": "bn_model",
},
"lr": 0.0003,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"num_workers": 0,
@@ -45,7 +46,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward,
}
results = tune.run(args.run, stop=stop, config=config, verbose=1)
results = tune.run(args.run, stop=stop, config=config, verbose=2)
if args.as_test:
check_learning_achieved(results, args.stop_reward)
-1
View File
@@ -71,7 +71,6 @@ class CustomModel(TFModelV2):
model_config, name)
self.model = FullyConnectedNetwork(obs_space, action_space,
num_outputs, model_config, name)
self.register_variables(self.model.variables())
def forward(self, input_dict, state, seq_lens):
return self.model.forward(input_dict, state, seq_lens)
-2
View File
@@ -48,7 +48,6 @@ class MyKerasModel(TFModelV2):
activation=None,
kernel_initializer=normc_initializer(0.01))(layer_1)
self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
model_out, self._value_out = self.base_model(input_dict["obs"])
@@ -84,7 +83,6 @@ class MyKerasQModel(DistributionalQTFModel):
activation=tf.nn.relu,
kernel_initializer=normc_initializer(1.0))(layer_1)
self.base_model = tf.keras.Model(self.inputs, layer_out)
self.register_variables(self.base_model.variables)
# Implement the core forward method
def forward(self, input_dict, state, seq_lens):
@@ -69,14 +69,12 @@ class AutoregressiveActionModel(TFModelV2):
# Base layers
self.base_model = tf.keras.Model(obs_input, [context, value_out])
self.register_variables(self.base_model.variables)
self.base_model.summary()
# Autoregressive action sampler
self.action_model = tf.keras.Model([ctx_input, a1_input],
[a1_logits, a2_logits])
self.action_model.summary()
self.register_variables(self.action_model.variables)
def forward(self, input_dict, state, seq_lens):
context, self._value_out = self.base_model(input_dict["obs"])
@@ -123,7 +123,6 @@ class KerasBatchNormModel(TFModelV2):
self.base_model = tf.keras.models.Model(
inputs=[inputs, is_training], outputs=[output, value_out])
self.register_variables(self.base_model.variables)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@@ -23,7 +23,6 @@ class CentralizedCriticModel(TFModelV2):
# Base of the model
self.model = FullyConnectedNetwork(obs_space, action_space,
num_outputs, model_config, name)
self.register_variables(self.model.variables())
# Central VF maps (obs, opp_obs, opp_act) -> vf_pred
obs = tf.keras.layers.Input(shape=(6, ), name="obs")
@@ -37,7 +36,6 @@ class CentralizedCriticModel(TFModelV2):
1, activation=None, name="c_vf_out")(central_vf_dense)
self.central_vf = tf.keras.Model(
inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
self.register_variables(self.central_vf.variables)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@@ -79,11 +77,9 @@ class YetAnotherCentralizedCriticModel(TFModelV2):
num_outputs,
model_config,
name + "_action")
self.register_variables(self.action_model.variables())
self.value_model = FullyConnectedNetwork(obs_space, action_space, 1,
model_config, name + "_vf")
self.register_variables(self.value_model.variables())
def forward(self, input_dict, state, seq_lens):
self._value_out, _ = self.value_model({
@@ -53,7 +53,6 @@ class CNNPlusFCConcatModel(TFModelV2):
name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
self.register_variables(cnn.variables())
# Discrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
concat_size += component.n
@@ -82,7 +81,6 @@ class CNNPlusFCConcatModel(TFModelV2):
kernel_initializer=normc_initializer(0.01))(concat_layer)
self.logits_and_value_model = tf.keras.models.Model(
concat_layer, [logits_layer, value_layer])
self.register_variables(self.logits_and_value_model.variables)
else:
self.num_outputs = concat_size
@@ -29,7 +29,6 @@ class CustomLossModel(TFModelV2):
num_outputs,
model_config,
name="fcnet")
self.register_variables(self.fcnet.variables())
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
-1
View File
@@ -40,7 +40,6 @@ class EagerModel(TFModelV2):
out = tf.keras.layers.Lambda(lambda_)(out)
self.base_model = tf.keras.models.Model(inputs, [out, value_out])
self.register_variables(self.base_model.variables)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@@ -68,7 +68,6 @@ class MobileV2PlusRNNModel(RecurrentNetwork):
self.rnn_model = tf.keras.Model(
inputs=[inputs, seq_in, state_in_h, state_in_c],
outputs=[logits, values, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
@override(RecurrentNetwork)
@@ -36,7 +36,6 @@ class ParametricActionsModel(DistributionalQTFModel):
self.action_embed_model = FullyConnectedNetwork(
Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
model_config, name + "_action_embed")
self.register_variables(self.action_embed_model.variables())
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
-1
View File
@@ -54,7 +54,6 @@ class RNNModel(RecurrentNetwork):
self.rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[logits, values, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
@override(RecurrentNetwork)
-1
View File
@@ -107,7 +107,6 @@ class RNNSpyModel(RecurrentNetwork):
[inputs, seq_lens, state_in_h, state_in_c],
[logits, value_out, state_out_h, state_out_c])
self.base_model.summary()
self.register_variables(self.base_model.variables)
@override(RecurrentNetwork)
def forward_rnn(self, inputs, state, seq_lens):
@@ -39,7 +39,6 @@ class TF2SharedWeightsModel(TFModelV2):
vf = tf.keras.layers.Dense(
units=1, activation=None, name="value_out")(last_layer)
self.base_model = tf.keras.models.Model(inputs, [output, vf])
self.register_variables(self.base_model.variables)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@@ -80,7 +79,6 @@ class SharedWeightsModel1(TFModelV2):
vf = tf.keras.layers.Dense(
units=1, activation=None, name="value_out")(last_layer)
self.base_model = tf.keras.models.Model(inputs, [output, vf])
self.register_variables(self.base_model.variables)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@@ -114,7 +112,6 @@ class SharedWeightsModel2(TFModelV2):
vf = tf.keras.layers.Dense(
units=1, activation=None, name="value_out")(last_layer)
self.base_model = tf.keras.models.Model(inputs, [output, vf])
self.register_variables(self.base_model.variables)
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
@@ -47,7 +47,6 @@ class CustomTFRPGModel(TFModelV2):
name)
self.model = TFFCNet(obs_space, action_space, num_outputs,
model_config, name)
self.register_variables(self.model.variables())
def forward(self, input_dict, state, seq_lens):
# The unpacked input tensors, where M=MAX_PLAYERS, N=MAX_ITEMS:
@@ -36,7 +36,6 @@ class FrameStackingCartPoleModel(TFModelV2):
out = tf.keras.layers.Dense(self.num_outputs)(layer1)
values = tf.keras.layers.Dense(1)(layer1)
self.base_model = tf.keras.models.Model([input_], [out, values])
self.register_variables(self.base_model.variables)
self._last_value = None
+24 -14
View File
@@ -378,7 +378,9 @@ class ModelCatalog:
if model_config.get("use_lstm") else AttentionWrapper)
model_cls._wrapped_forward = forward
# Track and warn if vars were created but not registered.
# Obsolete: Track and warn if vars were created but not
# registered. Only still do this, if users do register their
# variables. If not (which they shouldn't), don't check here.
created = set()
def track_var_creation(next_creator, **kw):
@@ -407,19 +409,27 @@ class ModelCatalog:
# Other error -> re-raise.
else:
raise e
registered = set(instance.variables())
not_registered = set()
for var in created:
if var not in registered:
not_registered.add(var)
if not_registered:
raise ValueError(
"It looks like variables {} were created as part "
"of {} but does not appear in model.variables() "
"({}). Did you forget to call "
"model.register_variables() on the variables in "
"question?".format(not_registered, instance,
registered))
# User still registered TFModelV2's variables: Check, whether
# ok.
registered = set(instance.var_list)
if len(registered) > 0:
not_registered = set()
for var in created:
if var not in registered:
not_registered.add(var)
if not_registered:
raise ValueError(
"It looks like you are still using "
"`{}.register_variables()` to register your "
"model's weights. This is no longer required, but "
"if you are still calling this method at least "
"once, you must make sure to register all created "
"variables properly. The missing variables are {},"
" and you only registered {}. "
"Did you forget to call `register_variables()` on "
"some of the variables in question?".format(
instance, not_registered, registered))
elif framework == "torch":
# Try wrapping custom model with LSTM/attention, if required.
if model_config.get("use_lstm") or \
+3 -6
View File
@@ -94,8 +94,7 @@ class TestDistributions(unittest.TestCase):
inputs = inputs_space.sample()
for fw, sess in framework_iterator(
session=True, frameworks=("tf", "tf2", "torch")):
for fw, sess in framework_iterator(session=True):
# Create the correct distribution object.
cls = JAXCategorical if fw == "jax" else Categorical if \
fw != "torch" else TorchCategorical
@@ -218,8 +217,7 @@ class TestDistributions(unittest.TestCase):
input_space = Box(-2.0, 2.0, shape=(2000, 10))
low, high = -2.0, 1.0
for fw, sess in framework_iterator(
frameworks=("torch", "tf", "tfe"), session=True):
for fw, sess in framework_iterator(session=True):
cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian
# Do a stability test using extreme NN outputs to see whether
@@ -310,8 +308,7 @@ class TestDistributions(unittest.TestCase):
"""Tests the DiagGaussian ActionDistribution for all frameworks."""
input_space = Box(-2.0, 1.0, shape=(2000, 10))
for fw, sess in framework_iterator(
frameworks=("torch", "tf", "tfe"), session=True):
for fw, sess in framework_iterator(session=True):
cls = DiagGaussian if fw != "torch" else TorchDiagGaussian
# Do a stability test using extreme NN outputs to see whether
+59
View File
@@ -0,0 +1,59 @@
from gym.spaces import Box
import numpy as np
import unittest
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.utils.framework import try_import_tf
tf1, tf, tfv = try_import_tf()
class TestTFModel(TFModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
super().__init__(obs_space, action_space, num_outputs, model_config,
name)
input_ = tf.keras.layers.Input(shape=(3, ))
output = tf.keras.layers.Dense(2)(input_)
# A keras model inside.
self.keras_model = tf.keras.models.Model([input_], [output])
# A RLlib FullyConnectedNetwork (tf) inside (which is also a keras
# Model).
self.fc_net = FullyConnectedNetwork(obs_space, action_space, 3, {},
"fc1")
def forward(self, input_dict, state, seq_lens):
obs = input_dict["obs_flat"]
out1 = self.keras_model(obs)
out2, _ = self.fc_net({"obs": obs})
return tf.concat([out1, out2], axis=-1), []
class TestModels(unittest.TestCase):
"""Tests ModelV2 classes and their modularization capabilities."""
def test_tf_modelv2(self):
obs_space = Box(-1.0, 1.0, (3, ))
action_space = Box(-1.0, 1.0, (2, ))
my_tf_model = TestTFModel(obs_space, action_space, 5, {},
"my_tf_model")
# Call the model.
out, states = my_tf_model({"obs": np.array([obs_space.sample()])})
self.assertTrue(out.shape == (1, 5))
self.assertTrue(out.dtype == tf.float32)
self.assertTrue(states == [])
vars = my_tf_model.variables(as_dict=True)
self.assertTrue(len(vars) == 6)
self.assertTrue("keras_model.dense.kernel:0" in vars)
self.assertTrue("keras_model.dense.bias:0" in vars)
self.assertTrue("fc_net.base_model.fc_out.kernel:0" in vars)
self.assertTrue("fc_net.base_model.fc_out.bias:0" in vars)
self.assertTrue("fc_net.base_model.value_out.kernel:0" in vars)
self.assertTrue("fc_net.base_model.value_out.bias:0" in vars)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
-5
View File
@@ -117,7 +117,6 @@ class TrXLNet(RecurrentNetwork):
name="logits")(E_out)
self.base_model = tf.keras.models.Model([inputs], [logits])
self.register_variables(self.base_model.variables)
@override(RecurrentNetwork)
def forward_rnn(self, inputs: TensorType, state: List[TensorType],
@@ -287,7 +286,6 @@ class GTrXLNet(RecurrentNetwork):
self.trxl_model = tf.keras.Model(
inputs=[input_layer] + memory_ins, outputs=outs + memory_outs[:-1])
self.register_variables(self.trxl_model.variables)
self.trxl_model.summary()
# __sphinx_doc_begin__
@@ -386,7 +384,6 @@ class AttentionWrapper(TFModelV2):
position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"],
init_gru_gate_bias=cfg["attention_init_gru_gate_bias"],
)
self.register_variables(self.gtrxl.variables())
# `self.num_outputs` right now is the number of nodes coming from the
# attention net.
@@ -399,11 +396,9 @@ class AttentionWrapper(TFModelV2):
# values.
out = tf.keras.layers.Dense(self.num_outputs, activation=None)(input_)
self._logits_branch = tf.keras.models.Model([input_], [out])
self.register_variables(self._logits_branch.variables)
out = tf.keras.layers.Dense(1, activation=None)(input_)
self._value_branch = tf.keras.models.Model([input_], [out])
self.register_variables(self._value_branch.variables)
self.view_requirements = self.gtrxl.view_requirements
-2
View File
@@ -33,7 +33,6 @@ class FullyConnectedNetwork(TFModelV2):
num_outputs = num_outputs // 2
self.log_std_var = tf.Variable(
[0.0] * num_outputs, dtype=tf.float32, name="log_std")
self.register_variables([self.log_std_var])
# We are using obs_flat, so take the flattened shape as input.
inputs = tf.keras.layers.Input(
@@ -115,7 +114,6 @@ class FullyConnectedNetwork(TFModelV2):
self.base_model = tf.keras.Model(
inputs, [(logits_out
if logits_out is not None else last_layer), value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
-2
View File
@@ -50,7 +50,6 @@ class RecurrentNetwork(TFModelV2):
self.rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[output_layer, state_h, state_c])
self.register_variables(self.rnn_model.variables)
self.rnn_model.summary()
"""
@@ -179,7 +178,6 @@ class LSTMWrapper(RecurrentNetwork):
self._rnn_model = tf.keras.Model(
inputs=[input_layer, seq_in, state_in_h, state_in_c],
outputs=[logits, values, state_h, state_c])
self.register_variables(self._rnn_model.variables)
self._rnn_model.summary()
# Add prev-a/r to this model's view, if required.
+38 -7
View File
@@ -1,9 +1,12 @@
import contextlib
import gym
import re
from typing import List
from ray.util import log_once
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
@@ -12,7 +15,7 @@ tf1, tf, tfv = try_import_tf()
@PublicAPI
class TFModelV2(ModelV2):
"""TF version of ModelV2.
"""TF version of ModelV2, which is always also a keras Model.
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
@@ -33,18 +36,18 @@ class TFModelV2(ModelV2):
value_layer = tf.keras.layers.Dense(...)(hidden_layer)
self.base_model = tf.keras.Model(
input_layer, [output_layer, value_layer])
self.register_variables(self.base_model.variables)
"""
ModelV2.__init__(
self,
super().__init__(
obs_space,
action_space,
num_outputs,
model_config,
name,
framework="tf")
# Deprecated: TFModelV2 now automatically track their variables.
self.var_list = []
if tf1.executing_eagerly():
self.graph = None
else:
@@ -65,13 +68,41 @@ class TFModelV2(ModelV2):
def register_variables(self, variables: List[TensorType]) -> None:
"""Register the given list of variables with this model."""
if log_once("deprecated_tfmodelv2_register_variables"):
deprecation_warning(
old="TFModelV2.register_variables", error=False)
self.var_list.extend(variables)
@override(ModelV2)
def variables(self, as_dict: bool = False) -> List[TensorType]:
if as_dict:
return {v.name: v for v in self.var_list}
return list(self.var_list)
# Old way using `register_variables`.
if self.var_list:
return {v.name: v for v in self.var_list}
# New way: Automatically determine the var tree.
else:
ret = {}
for prop, value in self.__dict__.items():
# Keras Model: key=k + "." + var-name (replace '/' by '.').
if isinstance(value, tf.keras.models.Model):
for var in value.variables:
key = prop + "." + re.sub("/", ".", var.name)
ret[key] = var
# Other TFModelV2: Include its vars into ours.
elif isinstance(value, TFModelV2):
for key, var in value.variables(as_dict=True).items():
ret[prop + "." + key] = var
# tf.Variable
elif isinstance(value, tf.Variable):
ret[prop] = value
return ret
# Old way using `register_variables`.
if self.var_list:
return list(self.var_list)
# New way: Automatically determine the var tree.
else:
return list(self.variables(as_dict=True).values())
@override(ModelV2)
def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
-1
View File
@@ -140,7 +140,6 @@ class VisionNetwork(TFModelV2):
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
self.register_variables(self.base_model.variables)
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
+1 -1
View File
@@ -11,7 +11,7 @@ _, nn = try_import_torch()
@PublicAPI
class TorchModelV2(ModelV2):
"""Torch version of ModelV2.
"""Torch version of ModelV2, which is also always a torch.nn.Module.
Note that this class by itself is not a valid model unless you
inherit from nn.Module and implement forward() in a subclass."""
-2
View File
@@ -49,8 +49,6 @@ class MyKerasModel(TFModelV2):
else:
self.base_model = tf.keras.Model(self.inputs, layer_out)
self.register_variables(self.base_model.variables)
def forward(self, input_dict, state, seq_lens):
if self.model_config["vf_share_layers"]:
model_out, self._value_out = self.base_model(input_dict["obs"])
@@ -240,7 +240,6 @@ class DictSpyModel(TFModelV2):
self.num_outputs = num_outputs or 64
out = tf.keras.layers.Dense(self.num_outputs)(input_)
self._main_layer = tf.keras.models.Model([input_], [out])
self.register_variables(self._main_layer.variables)
def forward(self, input_dict, state, seq_lens):
def spy(pos, front_cam, task):
@@ -282,7 +281,6 @@ class TupleSpyModel(TFModelV2):
self.num_outputs = num_outputs or 64
out = tf.keras.layers.Dense(self.num_outputs)(input_)
self._main_layer = tf.keras.models.Model([input_], [out])
self.register_variables(self._main_layer.variables)
def forward(self, input_dict, state, seq_lens):
def spy(pos, cam, task):
+9 -5
View File
@@ -1,4 +1,5 @@
import logging
from typing import Optional, Union
logger = logging.getLogger(__name__)
@@ -8,15 +9,18 @@ logger = logging.getLogger(__name__)
DEPRECATED_VALUE = -1
def deprecation_warning(old, new=None, error=None):
"""
Logs (via the `logger` object) or throws a deprecation warning/error.
def deprecation_warning(
old: str,
new: Optional[str] = None,
error: Optional[Union[bool, Exception]] = None) -> None:
"""Warns (via the `logger` object) or throws a deprecation warning/error.
Args:
old (str): A description of the "thing" that is to be deprecated.
new (Optional[str]): A description of the new "thing" that replaces it.
error (Optional[Union[bool,Exception]]): Whether or which exception to
throw. If True, throw ValueError.
error (Optional[Union[bool, Exception]]): Whether or which exception to
throw. If True, throw ValueError. If False, just warn.
If Exception, throw that Exception.
"""
msg = "`{}` has been deprecated.{}".format(
old, (" Use `{}` instead.".format(new) if new else ""))
-1
View File
@@ -208,7 +208,6 @@ class Curiosity(Exploration):
self._curiosity_feature_net.base_model.variables + \
self._curiosity_inverse_fcnet.variables + \
self._curiosity_forward_fcnet.variables
self.model.register_variables(self._optimizer_var_list)
self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
# Create placeholders and initialize the loss.
if self.framework == "tf":