# Review notes (free text ahead of the first "diff --git" header; it is not
# part of any hunk).
#
# simple_q_policy.py: imports logging, creates a module-level logger, and
#   emits one debug message per target variable inside the loop that builds
#   the target-update assign ops.
# modelv2.py: ModelV2.variables() returns list(self.var_list), i.e. a new
#   list rather than the attribute itself; trainable_variables() keeps only
#   the variables whose `trainable` attribute is truthy.
#
# NOTE(review): line breaks and leading whitespace were lost in this copy of
# the patch and have been reconstructed from the hunk headers -- confirm the
# context lines against the target files before applying.
diff --git a/python/ray/rllib/agents/dqn/simple_q_policy.py b/python/ray/rllib/agents/dqn/simple_q_policy.py
index 15ae8b976..0212fdef6 100644
--- a/python/ray/rllib/agents/dqn/simple_q_policy.py
+++ b/python/ray/rllib/agents/dqn/simple_q_policy.py
@@ -4,6 +4,7 @@ from __future__ import print_function
 """Basic example of a DQN policy without any optimizations."""
 
 from gym.spaces import Discrete
+import logging
 
 import ray
 from ray.rllib.agents.dqn.simple_q_model import SimpleQModel
@@ -18,6 +19,7 @@ from ray.rllib.utils import try_import_tf
 from ray.rllib.utils.tf_ops import huber_loss
 
 tf = try_import_tf()
+logger = logging.getLogger(__name__)
 
 Q_SCOPE = "q_func"
 Q_TARGET_SCOPE = "target_q_func"
@@ -55,6 +57,7 @@ class TargetNetworkMixin(object):
             (self.q_func_vars, self.target_q_func_vars)
         for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
             update_target_expr.append(var_target.assign(var))
+            logger.debug("Update target op {}".format(var_target))
         self.update_target_expr = tf.group(*update_target_expr)
 
     def update_target(self):
diff --git a/python/ray/rllib/models/modelv2.py b/python/ray/rllib/models/modelv2.py
index 2fca20bfb..4705532b1 100644
--- a/python/ray/rllib/models/modelv2.py
+++ b/python/ray/rllib/models/modelv2.py
@@ -120,11 +120,11 @@ class ModelV2(object):
 
     def variables(self):
         """Returns the list of variables for this model."""
-        return self.var_list
+        return list(self.var_list)
 
     def trainable_variables(self):
         """Returns the list of trainable variables for this model."""
-        return self.variables()
+        return [v for v in self.variables() if v.trainable]
 
     def __call__(self, input_dict, state, seq_lens):
         """Call the model with the given input tensors and state.
diff --git a/python/ray/rllib/models/tf/modelv1_compat.py b/python/ray/rllib/models/tf/modelv1_compat.py
index 6bc0ac65e..eb3dbe5fa 100644
--- a/python/ray/rllib/models/tf/modelv1_compat.py
+++ b/python/ray/rllib/models/tf/modelv1_compat.py
@@ -27,52 +27,9 @@ def make_v1_wrapper(legacy_model_cls):
                                model_config, name)
             self.legacy_model_cls = legacy_model_cls
 
-            def instance_template(input_dict, state, seq_lens):
-                # create a new model instance
-                with tf.variable_scope(self.name):
-                    new_instance = self.legacy_model_cls(
-                        input_dict, obs_space, action_space, num_outputs,
-                        model_config, state, seq_lens)
-                return new_instance
-
-            self.instance_template = tf.make_template("instance_template",
-                                                      instance_template)
             # Tracks the last v1 model created by the call to forward
             self.cur_instance = None
 
-            def vf_template(last_layer, input_dict):
-                with tf.variable_scope(self.variable_scope):
-                    with tf.variable_scope("value_function"):
-                        # Simple case: sharing the feature layer
-                        if model_config["vf_share_layers"]:
-                            return tf.reshape(
-                                linear(last_layer, 1, "value_function",
-                                       normc_initializer(1.0)), [-1])
-
-                        # Create a new separate model with no RNN state, etc.
-                        branch_model_config = model_config.copy()
-                        branch_model_config["free_log_std"] = False
-                        if branch_model_config["use_lstm"]:
-                            branch_model_config["use_lstm"] = False
-                            logger.warning(
-                                "It is not recommended to use a LSTM model "
-                                "with vf_share_layers=False (consider "
-                                "setting it to True). If you want to not "
-                                "share layers, you can implement a custom "
-                                "LSTM model that overrides the "
-                                "value_function() method.")
-                        branch_instance = legacy_model_cls(
-                            input_dict,
-                            obs_space,
-                            action_space,
-                            1,
-                            branch_model_config,
-                            state_in=None,
-                            seq_lens=None)
-                        return tf.reshape(branch_instance.outputs, [-1])
-
-            self.vf_template = tf.make_template("vf_template", vf_template)
-
             # XXX: Try to guess the initial state size. Since the size of the
             # state is known only after forward() for V1 models, it might be
             # wrong.
@@ -90,6 +47,9 @@ def make_v1_wrapper(legacy_model_cls):
             else:
                 self.initial_state = []
 
+            # Tracks branches created so far
+            self.branches_created = set()
+
             with tf.variable_scope(self.name) as scope:
                 self.variable_scope = scope
 
@@ -99,7 +59,18 @@ def make_v1_wrapper(legacy_model_cls):
 
         @override(ModelV2)
         def __call__(self, input_dict, state, seq_lens):
-            new_instance = self.instance_template(input_dict, state, seq_lens)
+            if self.cur_instance:
+                # create a weight-sharing model copy
+                with tf.variable_scope(self.cur_instance.scope, reuse=True):
+                    new_instance = self.legacy_model_cls(
+                        input_dict, self.obs_space, self.action_space,
+                        self.num_outputs, self.model_config, state, seq_lens)
+            else:
+                # create a new model instance
+                with tf.variable_scope(self.name):
+                    new_instance = self.legacy_model_cls(
+                        input_dict, self.obs_space, self.action_space,
+                        self.num_outputs, self.model_config, state, seq_lens)
             if len(new_instance.state_init) != len(self.get_initial_state()):
                 raise ValueError(
                     "When using a custom recurrent ModelV1 model, you should "
@@ -114,8 +85,11 @@ def make_v1_wrapper(legacy_model_cls):
 
         @override(ModelV2)
         def variables(self):
-            return super(ModelV1Wrapper, self).variables() + scope_vars(
-                self.variable_scope)
+            var_list = super(ModelV1Wrapper, self).variables()
+            for v in scope_vars(self.variable_scope):
+                if v not in var_list:
+                    var_list.append(v)
+            return var_list
 
         @override(ModelV2)
         def custom_loss(self, policy_loss, loss_inputs):
@@ -128,7 +102,55 @@ def make_v1_wrapper(legacy_model_cls):
         @override(ModelV2)
         def value_function(self):
             assert self.cur_instance, "must call forward first"
-            return self.vf_template(self.cur_instance.last_layer,
-                                    self.cur_instance.input_dict)
+
+            with self._branch_variable_scope("value_function"):
+                # Simple case: sharing the feature layer
+                if self.model_config["vf_share_layers"]:
+                    return tf.reshape(
+                        linear(self.cur_instance.last_layer, 1,
+                               "value_function", normc_initializer(1.0)), [-1])
+
+                # Create a new separate model with no RNN state, etc.
+                branch_model_config = self.model_config.copy()
+                branch_model_config["free_log_std"] = False
+                if branch_model_config["use_lstm"]:
+                    branch_model_config["use_lstm"] = False
+                    logger.warning(
+                        "It is not recommended to use a LSTM model with "
+                        "vf_share_layers=False (consider setting it to True). "
+                        "If you want to not share layers, you can implement "
+                        "a custom LSTM model that overrides the "
+                        "value_function() method.")
+                branch_instance = self.legacy_model_cls(
+                    self.cur_instance.input_dict,
+                    self.obs_space,
+                    self.action_space,
+                    1,
+                    branch_model_config,
+                    state_in=None,
+                    seq_lens=None)
+                return tf.reshape(branch_instance.outputs, [-1])
+
+        def _branch_variable_scope(self, branch_type):
+            """Return a scope for `branch_type` nested in self.variable_scope.
+
+            The first request for a branch uses tf.AUTO_REUSE; later requests
+            for the same branch use reuse=True.
+            """
+            if branch_type in self.branches_created:
+                reuse = True
+            else:
+                self.branches_created.add(branch_type)
+                reuse = tf.AUTO_REUSE
+
+            # A string-named scope is resolved against the scope that is
+            # current when it is entered, not when it is constructed, so
+            # returning one from inside a `with` block loses the parent.
+            # Capture the nested scope object here and hand back a scope
+            # that re-enters it by object instead.
+            with tf.variable_scope(self.variable_scope):
+                with tf.variable_scope(branch_type, reuse=reuse) as branch:
+                    pass
+            return tf.variable_scope(branch, reuse=reuse)
 
     return ModelV1Wrapper
diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py
index 11fd80e01..b6329f038 100644
--- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py
+++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py
@@ -166,6 +166,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                     continue
 
                 policy = self.policies[policy_id]
+                policy._debug_vars()
                 tuples = policy._get_loss_inputs_dict(
                     batch, shuffle=self.shuffle_sequences)
                 data_keys = [ph for _, ph in policy._loss_inputs]
diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py
index 7348c7b04..a4d456f83 100644
--- a/python/ray/rllib/policy/tf_policy.py
+++ b/python/ray/rllib/policy/tf_policy.py
@@ -349,6 +349,12 @@ class TFPolicy(Policy):
             self._is_training = tf.placeholder_with_default(False, ())
         return self._is_training
 
+    def _debug_vars(self):
+        """Log each variable in self._grads_and_vars; gated by log_once."""
+        if log_once("grad_vars"):
+            for _, v in self._grads_and_vars:
+                logger.info("Optimizing variable {}".format(v))
+
     def _extra_input_signature_def(self):
         """Extra input signatures to add when exporting tf model.
         Inferred from extra_compute_action_feed_dict()
@@ -436,6 +442,7 @@ class TFPolicy(Policy):
         return fetches[0], fetches[1:-1], fetches[-1]
 
     def _build_compute_gradients(self, builder, postprocessed_batch):
+        self._debug_vars()
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
         builder.add_feed_dict({self._is_training: True})
         builder.add_feed_dict(
@@ -455,6 +462,7 @@ class TFPolicy(Policy):
         return fetches[0]
 
     def _build_learn_on_batch(self, builder, postprocessed_batch):
+        self._debug_vars()
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
         builder.add_feed_dict(
             self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))