mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 01:23:10 +08:00
[rllib] Revert "use make template" which seems to break DQN/Atari (#5134)
* Revert "use make template"
  This reverts commit 291e9e0031c6e315fe24e5b4973dea375fe73918.
* debug vars
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
"""Basic example of a DQN policy without any optimizations."""
|
||||
|
||||
from gym.spaces import Discrete
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn.simple_q_model import SimpleQModel
|
||||
@@ -18,6 +19,7 @@ from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import huber_loss
|
||||
|
||||
tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Q_SCOPE = "q_func"
|
||||
Q_TARGET_SCOPE = "target_q_func"
|
||||
@@ -55,6 +57,7 @@ class TargetNetworkMixin(object):
|
||||
(self.q_func_vars, self.target_q_func_vars)
|
||||
for var, var_target in zip(self.q_func_vars, self.target_q_func_vars):
|
||||
update_target_expr.append(var_target.assign(var))
|
||||
logger.debug("Update target op {}".format(var_target))
|
||||
self.update_target_expr = tf.group(*update_target_expr)
|
||||
|
||||
def update_target(self):
|
||||
|
||||
@@ -120,11 +120,11 @@ class ModelV2(object):
|
||||
|
||||
def variables(self):
|
||||
"""Returns the list of variables for this model."""
|
||||
return self.var_list
|
||||
return list(self.var_list)
|
||||
|
||||
def trainable_variables(self):
|
||||
"""Returns the list of trainable variables for this model."""
|
||||
return self.variables()
|
||||
return [v for v in self.variables() if v.trainable]
|
||||
|
||||
def __call__(self, input_dict, state, seq_lens):
|
||||
"""Call the model with the given input tensors and state.
|
||||
|
||||
@@ -27,52 +27,9 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
model_config, name)
|
||||
self.legacy_model_cls = legacy_model_cls
|
||||
|
||||
def instance_template(input_dict, state, seq_lens):
|
||||
# create a new model instance
|
||||
with tf.variable_scope(self.name):
|
||||
new_instance = self.legacy_model_cls(
|
||||
input_dict, obs_space, action_space, num_outputs,
|
||||
model_config, state, seq_lens)
|
||||
return new_instance
|
||||
|
||||
self.instance_template = tf.make_template("instance_template",
|
||||
instance_template)
|
||||
# Tracks the last v1 model created by the call to forward
|
||||
self.cur_instance = None
|
||||
|
||||
def vf_template(last_layer, input_dict):
|
||||
with tf.variable_scope(self.variable_scope):
|
||||
with tf.variable_scope("value_function"):
|
||||
# Simple case: sharing the feature layer
|
||||
if model_config["vf_share_layers"]:
|
||||
return tf.reshape(
|
||||
linear(last_layer, 1, "value_function",
|
||||
normc_initializer(1.0)), [-1])
|
||||
|
||||
# Create a new separate model with no RNN state, etc.
|
||||
branch_model_config = model_config.copy()
|
||||
branch_model_config["free_log_std"] = False
|
||||
if branch_model_config["use_lstm"]:
|
||||
branch_model_config["use_lstm"] = False
|
||||
logger.warning(
|
||||
"It is not recommended to use a LSTM model "
|
||||
"with vf_share_layers=False (consider "
|
||||
"setting it to True). If you want to not "
|
||||
"share layers, you can implement a custom "
|
||||
"LSTM model that overrides the "
|
||||
"value_function() method.")
|
||||
branch_instance = legacy_model_cls(
|
||||
input_dict,
|
||||
obs_space,
|
||||
action_space,
|
||||
1,
|
||||
branch_model_config,
|
||||
state_in=None,
|
||||
seq_lens=None)
|
||||
return tf.reshape(branch_instance.outputs, [-1])
|
||||
|
||||
self.vf_template = tf.make_template("vf_template", vf_template)
|
||||
|
||||
# XXX: Try to guess the initial state size. Since the size of the
|
||||
# state is known only after forward() for V1 models, it might be
|
||||
# wrong.
|
||||
@@ -90,6 +47,9 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
else:
|
||||
self.initial_state = []
|
||||
|
||||
# Tracks branches created so far
|
||||
self.branches_created = set()
|
||||
|
||||
with tf.variable_scope(self.name) as scope:
|
||||
self.variable_scope = scope
|
||||
|
||||
@@ -99,7 +59,18 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
|
||||
@override(ModelV2)
|
||||
def __call__(self, input_dict, state, seq_lens):
|
||||
new_instance = self.instance_template(input_dict, state, seq_lens)
|
||||
if self.cur_instance:
|
||||
# create a weight-sharing model copy
|
||||
with tf.variable_scope(self.cur_instance.scope, reuse=True):
|
||||
new_instance = self.legacy_model_cls(
|
||||
input_dict, self.obs_space, self.action_space,
|
||||
self.num_outputs, self.model_config, state, seq_lens)
|
||||
else:
|
||||
# create a new model instance
|
||||
with tf.variable_scope(self.name):
|
||||
new_instance = self.legacy_model_cls(
|
||||
input_dict, self.obs_space, self.action_space,
|
||||
self.num_outputs, self.model_config, state, seq_lens)
|
||||
if len(new_instance.state_init) != len(self.get_initial_state()):
|
||||
raise ValueError(
|
||||
"When using a custom recurrent ModelV1 model, you should "
|
||||
@@ -114,8 +85,11 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
|
||||
@override(ModelV2)
|
||||
def variables(self):
|
||||
return super(ModelV1Wrapper, self).variables() + scope_vars(
|
||||
self.variable_scope)
|
||||
var_list = super(ModelV1Wrapper, self).variables()
|
||||
for v in scope_vars(self.variable_scope):
|
||||
if v not in var_list:
|
||||
var_list.append(v)
|
||||
return var_list
|
||||
|
||||
@override(ModelV2)
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
@@ -128,7 +102,43 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
assert self.cur_instance, "must call forward first"
|
||||
return self.vf_template(self.cur_instance.last_layer,
|
||||
self.cur_instance.input_dict)
|
||||
|
||||
with self._branch_variable_scope("value_function"):
|
||||
# Simple case: sharing the feature layer
|
||||
if self.model_config["vf_share_layers"]:
|
||||
return tf.reshape(
|
||||
linear(self.cur_instance.last_layer, 1,
|
||||
"value_function", normc_initializer(1.0)), [-1])
|
||||
|
||||
# Create a new separate model with no RNN state, etc.
|
||||
branch_model_config = self.model_config.copy()
|
||||
branch_model_config["free_log_std"] = False
|
||||
if branch_model_config["use_lstm"]:
|
||||
branch_model_config["use_lstm"] = False
|
||||
logger.warning(
|
||||
"It is not recommended to use a LSTM model with "
|
||||
"vf_share_layers=False (consider setting it to True). "
|
||||
"If you want to not share layers, you can implement "
|
||||
"a custom LSTM model that overrides the "
|
||||
"value_function() method.")
|
||||
branch_instance = self.legacy_model_cls(
|
||||
self.cur_instance.input_dict,
|
||||
self.obs_space,
|
||||
self.action_space,
|
||||
1,
|
||||
branch_model_config,
|
||||
state_in=None,
|
||||
seq_lens=None)
|
||||
return tf.reshape(branch_instance.outputs, [-1])
|
||||
|
||||
def _branch_variable_scope(self, branch_type):
|
||||
if branch_type in self.branches_created:
|
||||
reuse = True
|
||||
else:
|
||||
self.branches_created.add(branch_type)
|
||||
reuse = tf.AUTO_REUSE
|
||||
|
||||
with tf.variable_scope(self.variable_scope):
|
||||
return tf.variable_scope(branch_type, reuse=reuse)
|
||||
|
||||
return ModelV1Wrapper
|
||||
|
||||
@@ -166,6 +166,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
continue
|
||||
|
||||
policy = self.policies[policy_id]
|
||||
policy._debug_vars()
|
||||
tuples = policy._get_loss_inputs_dict(
|
||||
batch, shuffle=self.shuffle_sequences)
|
||||
data_keys = [ph for _, ph in policy._loss_inputs]
|
||||
|
||||
@@ -349,6 +349,11 @@ class TFPolicy(Policy):
|
||||
self._is_training = tf.placeholder_with_default(False, ())
|
||||
return self._is_training
|
||||
|
||||
def _debug_vars(self):
|
||||
if log_once("grad_vars"):
|
||||
for _, v in self._grads_and_vars:
|
||||
logger.info("Optimizing variable {}".format(v))
|
||||
|
||||
def _extra_input_signature_def(self):
|
||||
"""Extra input signatures to add when exporting tf model.
|
||||
Inferred from extra_compute_action_feed_dict()
|
||||
@@ -436,6 +441,7 @@ class TFPolicy(Policy):
|
||||
return fetches[0], fetches[1:-1], fetches[-1]
|
||||
|
||||
def _build_compute_gradients(self, builder, postprocessed_batch):
|
||||
self._debug_vars()
|
||||
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(
|
||||
@@ -455,6 +461,7 @@ class TFPolicy(Policy):
|
||||
return fetches[0]
|
||||
|
||||
def _build_learn_on_batch(self, builder, postprocessed_batch):
|
||||
self._debug_vars()
|
||||
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
||||
builder.add_feed_dict(
|
||||
self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
|
||||
|
||||
Reference in New Issue
Block a user