[rllib] Revert "use make template" which seems to break DQN/Atari (#5134)

* Revert "use make template"

This reverts commit 291e9e0031c6e315fe24e5b4973dea375fe73918.

* Add debug logging of the target-update ops and of the variables being optimized
This commit is contained in:
Eric Liang
2019-07-07 19:51:26 -07:00
committed by GitHub
parent 7e020e7183
commit 893744b3be
5 changed files with 71 additions and 50 deletions
@@ -4,6 +4,7 @@ from __future__ import print_function
"""Basic example of a DQN policy without any optimizations."""
from gym.spaces import Discrete
import logging
import ray
from ray.rllib.agents.dqn.simple_q_model import SimpleQModel
@@ -18,6 +19,7 @@ from ray.rllib.utils import try_import_tf
from ray.rllib.utils.tf_ops import huber_loss
tf = try_import_tf()
logger = logging.getLogger(__name__)
Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"
@@ -55,6 +57,7 @@ class TargetNetworkMixin(object):
(self.q_func_vars, self.target_q_func_vars)
for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
update_target_expr.append(var_target.assign(var))
logger.debug("Update target op {}".format(var_target))
self.update_target_expr = tf.group(*update_target_expr)
def update_target(self):
+2 -2
View File
@@ -120,11 +120,11 @@ class ModelV2(object):
def variables(self):
"""Returns the list of variables for this model."""
return self.var_list
return list(self.var_list)
def trainable_variables(self):
"""Returns the list of trainable variables for this model."""
return self.variables()
return [v for v in self.variables() if v.trainable]
def __call__(self, input_dict, state, seq_lens):
"""Call the model with the given input tensors and state.
+58 -48
View File
@@ -27,52 +27,9 @@ def make_v1_wrapper(legacy_model_cls):
model_config, name)
self.legacy_model_cls = legacy_model_cls
def instance_template(input_dict, state, seq_lens):
# create a new model instance
with tf.variable_scope(self.name):
new_instance = self.legacy_model_cls(
input_dict, obs_space, action_space, num_outputs,
model_config, state, seq_lens)
return new_instance
self.instance_template = tf.make_template("instance_template",
instance_template)
# Tracks the last v1 model created by the call to forward
self.cur_instance = None
def vf_template(last_layer, input_dict):
with tf.variable_scope(self.variable_scope):
with tf.variable_scope("value_function"):
# Simple case: sharing the feature layer
if model_config["vf_share_layers"]:
return tf.reshape(
linear(last_layer, 1, "value_function",
normc_initializer(1.0)), [-1])
# Create a new separate model with no RNN state, etc.
branch_model_config = model_config.copy()
branch_model_config["free_log_std"] = False
if branch_model_config["use_lstm"]:
branch_model_config["use_lstm"] = False
logger.warning(
"It is not recommended to use a LSTM model "
"with vf_share_layers=False (consider "
"setting it to True). If you want to not "
"share layers, you can implement a custom "
"LSTM model that overrides the "
"value_function() method.")
branch_instance = legacy_model_cls(
input_dict,
obs_space,
action_space,
1,
branch_model_config,
state_in=None,
seq_lens=None)
return tf.reshape(branch_instance.outputs, [-1])
self.vf_template = tf.make_template("vf_template", vf_template)
# XXX: Try to guess the initial state size. Since the size of the
# state is known only after forward() for V1 models, it might be
# wrong.
@@ -90,6 +47,9 @@ def make_v1_wrapper(legacy_model_cls):
else:
self.initial_state = []
# Tracks branches created so far
self.branches_created = set()
with tf.variable_scope(self.name) as scope:
self.variable_scope = scope
@@ -99,7 +59,18 @@ def make_v1_wrapper(legacy_model_cls):
@override(ModelV2)
def __call__(self, input_dict, state, seq_lens):
new_instance = self.instance_template(input_dict, state, seq_lens)
if self.cur_instance:
# create a weight-sharing model copy
with tf.variable_scope(self.cur_instance.scope, reuse=True):
new_instance = self.legacy_model_cls(
input_dict, self.obs_space, self.action_space,
self.num_outputs, self.model_config, state, seq_lens)
else:
# create a new model instance
with tf.variable_scope(self.name):
new_instance = self.legacy_model_cls(
input_dict, self.obs_space, self.action_space,
self.num_outputs, self.model_config, state, seq_lens)
if len(new_instance.state_init) != len(self.get_initial_state()):
raise ValueError(
"When using a custom recurrent ModelV1 model, you should "
@@ -114,8 +85,11 @@ def make_v1_wrapper(legacy_model_cls):
@override(ModelV2)
def variables(self):
return super(ModelV1Wrapper, self).variables() + scope_vars(
self.variable_scope)
var_list = super(ModelV1Wrapper, self).variables()
for v in scope_vars(self.variable_scope):
if v not in var_list:
var_list.append(v)
return var_list
@override(ModelV2)
def custom_loss(self, policy_loss, loss_inputs):
@@ -128,7 +102,43 @@ def make_v1_wrapper(legacy_model_cls):
@override(ModelV2)
def value_function(self):
assert self.cur_instance, "must call forward first"
return self.vf_template(self.cur_instance.last_layer,
self.cur_instance.input_dict)
with self._branch_variable_scope("value_function"):
# Simple case: sharing the feature layer
if self.model_config["vf_share_layers"]:
return tf.reshape(
linear(self.cur_instance.last_layer, 1,
"value_function", normc_initializer(1.0)), [-1])
# Create a new separate model with no RNN state, etc.
branch_model_config = self.model_config.copy()
branch_model_config["free_log_std"] = False
if branch_model_config["use_lstm"]:
branch_model_config["use_lstm"] = False
logger.warning(
"It is not recommended to use a LSTM model with "
"vf_share_layers=False (consider setting it to True). "
"If you want to not share layers, you can implement "
"a custom LSTM model that overrides the "
"value_function() method.")
branch_instance = self.legacy_model_cls(
self.cur_instance.input_dict,
self.obs_space,
self.action_space,
1,
branch_model_config,
state_in=None,
seq_lens=None)
return tf.reshape(branch_instance.outputs, [-1])
def _branch_variable_scope(self, branch_type):
    """Return a variable scope for the named branch of this model.

    The first request for a given ``branch_type`` records it in
    ``self.branches_created`` and uses ``reuse=tf.AUTO_REUSE``; every
    later request for the same branch uses ``reuse=True``.

    Args:
        branch_type: Name of the branch, e.g. "value_function".

    Returns:
        The ``tf.variable_scope`` object for the branch; the caller is
        expected to enter it with a ``with`` statement.
    """
    first_request = branch_type not in self.branches_created
    if first_request:
        self.branches_created.add(branch_type)
    reuse = tf.AUTO_REUSE if first_request else True

    # NOTE(review): the branch scope is constructed while the model's own
    # scope is active, but the caller enters it only after this ``with``
    # has exited -- confirm it still nests under self.variable_scope.
    with tf.variable_scope(self.variable_scope):
        return tf.variable_scope(branch_type, reuse=reuse)
return ModelV1Wrapper
@@ -166,6 +166,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
continue
policy = self.policies[policy_id]
policy._debug_vars()
tuples = policy._get_loss_inputs_dict(
batch, shuffle=self.shuffle_sequences)
data_keys = [ph for _, ph in policy._loss_inputs]
+7
View File
@@ -349,6 +349,11 @@ class TFPolicy(Policy):
self._is_training = tf.placeholder_with_default(False, ())
return self._is_training
def _debug_vars(self):
    """Log every variable found in ``self._grads_and_vars``.

    Logging is gated on ``log_once("grad_vars")`` (presumably true only
    the first time this key is seen -- verify against the helper), so the
    variable list is not repeated on every training call.
    """
    if not log_once("grad_vars"):
        return
    for _grad, variable in self._grads_and_vars:
        logger.info("Optimizing variable {}".format(variable))
def _extra_input_signature_def(self):
"""Extra input signatures to add when exporting tf model.
Inferred from extra_compute_action_feed_dict()
@@ -436,6 +441,7 @@ class TFPolicy(Policy):
return fetches[0], fetches[1:-1], fetches[-1]
def _build_compute_gradients(self, builder, postprocessed_batch):
self._debug_vars()
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(
@@ -455,6 +461,7 @@ class TFPolicy(Policy):
return fetches[0]
def _build_learn_on_batch(self, builder, postprocessed_batch):
self._debug_vars()
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict(
self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))