mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 03:21:00 +08:00
[rllib] Support batch norm layers (#3369)
* batch norm * lint * fix dqn/ddpg update ops * bn model * Update tf_policy_graph.py * Update multi_gpu_impl.py * Apply suggestions from code review Co-Authored-By: ericl <ekhliang@gmail.com>
This commit is contained in:
@@ -53,7 +53,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.model = ModelCatalog.get_model({
|
||||
"obs": self.observations,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards
|
||||
"prev_rewards": prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, logit_dim, self.config["model"])
|
||||
action_dist = dist_class(self.model.outputs)
|
||||
self.vf = self.model.value_function()
|
||||
|
||||
@@ -385,13 +385,11 @@ class Agent(Trainable):
|
||||
observation, update=False)
|
||||
if state:
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(
|
||||
filtered_obs, state, is_training=False),
|
||||
lambda p: p.compute_single_action(filtered_obs, state),
|
||||
policy_id=policy_id)
|
||||
return self.local_evaluator.for_policy(
|
||||
lambda p: p.compute_single_action(
|
||||
filtered_obs, state, is_training=False)[0],
|
||||
policy_id=policy_id)
|
||||
lambda p: p.compute_single_action(filtered_obs, state)[0],
|
||||
policy_id=policy_id)
|
||||
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
@@ -199,7 +199,9 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
|
||||
self.eps = tf.placeholder(tf.float32, (), name="eps")
|
||||
self.cur_observations = tf.placeholder(
|
||||
tf.float32, shape=(None, ) + observation_space.shape)
|
||||
tf.float32,
|
||||
shape=(None, ) + observation_space.shape,
|
||||
name="cur_obs")
|
||||
|
||||
# Actor: P (policy) network
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
@@ -236,7 +238,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# p network evaluation
|
||||
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
self.p_t = self._build_p_network(self.obs_t, observation_space)
|
||||
p_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
|
||||
# target p network evaluation
|
||||
with tf.variable_scope(P_TARGET_SCOPE) as scope:
|
||||
@@ -257,6 +263,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
is_target=True)
|
||||
|
||||
# q network evaluation
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_t, model = self._build_q_network(self.obs_t, observation_space,
|
||||
self.act_t)
|
||||
@@ -269,6 +276,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
twin_q_t, twin_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.twin_q_func_vars = _scope_vars(scope.name)
|
||||
q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
@@ -345,7 +354,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
loss=model.loss() + self.loss.total_loss,
|
||||
loss_inputs=self.loss_inputs)
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops + p_batchnorm_update_ops)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Note that this encompasses both the policy and Q-value networks and
|
||||
@@ -359,7 +369,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def _build_q_network(self, obs, obs_space, actions):
|
||||
q_net = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, 1, self.config["model"]), actions,
|
||||
self.config["critic_hiddens"],
|
||||
self.config["critic_hidden_activation"])
|
||||
@@ -368,7 +379,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def _build_p_network(self, obs, obs_space):
|
||||
return PNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, 1, self.config["model"]), self.dim_actions,
|
||||
self.config["actor_hiddens"],
|
||||
self.config["actor_hidden_activation"]).action_scores
|
||||
|
||||
@@ -306,8 +306,12 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
|
||||
self.obs_t, observation_space)
|
||||
q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
@@ -372,13 +376,15 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
loss=model.loss() + self.loss.loss,
|
||||
loss_inputs=self.loss_inputs)
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
def _build_q_network(self, obs, space):
|
||||
qnet = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, space, self.num_actions, self.config["model"]),
|
||||
self.num_actions, self.config["dueling"], self.config["hiddens"],
|
||||
self.config["noisy"], self.config["num_atoms"],
|
||||
|
||||
@@ -133,6 +133,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"obs": observations,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
logit_dim,
|
||||
|
||||
@@ -35,7 +35,8 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
self.model = ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards
|
||||
"prev_rewards": prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, self.logit_dim, self.config["model"])
|
||||
action_dist = dist_class(self.model.outputs) # logit for each action
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"sample_batch_size": 200,
|
||||
# Number of timesteps collected for each SGD round
|
||||
"train_batch_size": 4000,
|
||||
# Total SGD batch size across all devices for SGD (multi-gpu only)
|
||||
# Total SGD batch size across all devices for SGD
|
||||
"sgd_minibatch_size": 128,
|
||||
# Number of SGD iterations in each outer loop
|
||||
"num_sgd_iter": 30,
|
||||
@@ -49,7 +49,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"batch_mode": "truncate_episodes",
|
||||
# Which observation filter to apply to the observation
|
||||
"observation_filter": "MeanStdFilter",
|
||||
# Use the sync samples optimizer instead of the multi-gpu one
|
||||
# Uses the sync samples optimizer instead of the multi-gpu one. This does
|
||||
# not support minibatches.
|
||||
"simple_optimizer": False,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
|
||||
@@ -158,7 +158,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
{
|
||||
"obs": obs_ph,
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
logit_dim,
|
||||
@@ -191,7 +192,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.value_function = ModelCatalog.get_model({
|
||||
"obs": obs_ph,
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, 1, vf_config).outputs
|
||||
self.value_function = tf.reshape(self.value_function, [-1])
|
||||
else:
|
||||
|
||||
@@ -42,7 +42,6 @@ class PolicyGraph(object):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
"""Compute actions for the current policy.
|
||||
|
||||
@@ -51,7 +50,6 @@ class PolicyGraph(object):
|
||||
state_batches (list): list of RNN state input batches, if any
|
||||
prev_action_batch (np.ndarray): batch of previous action values
|
||||
prev_reward_batch (np.ndarray): batch of previous rewards
|
||||
is_training (bool): whether we are training the policy
|
||||
episodes (list): MultiAgentEpisode for each obs in obs_batch.
|
||||
This provides access to all of the internal episode state,
|
||||
which may be useful for model-based or multiagent algorithms.
|
||||
@@ -71,7 +69,6 @@ class PolicyGraph(object):
|
||||
state,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episode=None):
|
||||
"""Unbatched version of compute_actions.
|
||||
|
||||
@@ -80,7 +77,6 @@ class PolicyGraph(object):
|
||||
state_batches (list): list of RNN state inputs, if any
|
||||
prev_action_batch (np.ndarray): batch of previous action values
|
||||
prev_reward_batch (np.ndarray): batch of previous rewards
|
||||
is_training (bool): whether we are training the policy
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
@@ -92,7 +88,7 @@ class PolicyGraph(object):
|
||||
"""
|
||||
|
||||
[action], state_out, info = self.compute_actions(
|
||||
[obs], [[s] for s in state], is_training, episodes=[episode])
|
||||
[obs], [[s] for s in state], episodes=[episode])
|
||||
return action, [s[0] for s in state_out], \
|
||||
{k: v[0] for k, v in info.items()}
|
||||
|
||||
|
||||
@@ -436,15 +436,13 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
builder, [t.obs for t in eval_data],
|
||||
rnn_in_cols,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
is_training=True)
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data])
|
||||
else:
|
||||
eval_results[policy_id] = policy.compute_actions(
|
||||
[t.obs for t in eval_data],
|
||||
rnn_in_cols,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
is_training=True,
|
||||
episodes=[active_episodes[t.env_id] for t in eval_data])
|
||||
if builder:
|
||||
for k, v in pending_fetches.items():
|
||||
|
||||
@@ -30,7 +30,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
Examples:
|
||||
>>> policy = TFPolicyGraphSubclass(
|
||||
sess, obs_input, action_sampler, loss, loss_inputs, is_training)
|
||||
sess, obs_input, action_sampler, loss, loss_inputs)
|
||||
|
||||
>>> print(policy.compute_actions([1, 0, 2]))
|
||||
(array([0, 1, 1]), [], {})
|
||||
@@ -53,7 +53,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
prev_reward_input=None,
|
||||
seq_lens=None,
|
||||
max_seq_len=20,
|
||||
batch_divisibility_req=1):
|
||||
batch_divisibility_req=1,
|
||||
update_ops=None):
|
||||
"""Initialize the policy graph.
|
||||
|
||||
Arguments:
|
||||
@@ -82,6 +83,9 @@ class TFPolicyGraph(PolicyGraph):
|
||||
batch_divisibility_req (int): pad all agent experiences batches to
|
||||
multiples of this value. This only has an effect if not using
|
||||
a LSTM model.
|
||||
update_ops (list): override the batchnorm update ops to run when
|
||||
applying gradients. Otherwise we run all update ops found in
|
||||
the current variable scope.
|
||||
"""
|
||||
|
||||
self.observation_space = observation_space
|
||||
@@ -94,7 +98,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._loss = loss
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = tf.placeholder_with_default(True, ())
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
@@ -108,14 +112,24 @@ class TFPolicyGraph(PolicyGraph):
|
||||
for (g, v) in self.gradients(self._optimizer)
|
||||
if g is not None]
|
||||
self._grads = [g for (g, v) in self._grads_and_vars]
|
||||
# specify global_step for TD3 which needs to count the num updates
|
||||
self._apply_op = self._optimizer.apply_gradients(
|
||||
self._grads_and_vars,
|
||||
global_step=tf.train.get_or_create_global_step())
|
||||
|
||||
self._variables = ray.experimental.TensorFlowVariables(
|
||||
self._loss, self._sess)
|
||||
|
||||
# gather update ops for any batch norm layers
|
||||
if update_ops:
|
||||
self._update_ops = update_ops
|
||||
else:
|
||||
self._update_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
if self._update_ops:
|
||||
logger.debug("Update ops to run on apply gradient: {}".format(
|
||||
self._update_ops))
|
||||
with tf.control_dependencies(self._update_ops):
|
||||
# specify global_step for TD3 which needs to count the num updates
|
||||
self._apply_op = self._optimizer.apply_gradients(
|
||||
self._grads_and_vars,
|
||||
global_step=tf.train.get_or_create_global_step())
|
||||
|
||||
if len(self._state_inputs) != len(self._state_outputs):
|
||||
raise ValueError(
|
||||
"Number of state input and output tensors must match, got: "
|
||||
@@ -138,7 +152,6 @@ class TFPolicyGraph(PolicyGraph):
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
state_batches = state_batches or []
|
||||
assert len(self._state_inputs) == len(state_batches), \
|
||||
@@ -151,7 +164,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
builder.add_feed_dict({self._prev_action_input: prev_action_batch})
|
||||
if self._prev_reward_input is not None and prev_reward_batch:
|
||||
builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
|
||||
builder.add_feed_dict({self._is_training: is_training})
|
||||
builder.add_feed_dict({self._is_training: False})
|
||||
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
|
||||
fetches = builder.add_fetches([self._sampler] + self._state_outputs +
|
||||
[self.extra_compute_action_fetches()])
|
||||
@@ -162,12 +175,11 @@ class TFPolicyGraph(PolicyGraph):
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
builder = TFRunBuilder(self._sess, "compute_actions")
|
||||
fetches = self.build_compute_actions(builder, obs_batch, state_batches,
|
||||
prev_action_batch,
|
||||
prev_reward_batch, is_training)
|
||||
prev_reward_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
def _get_loss_inputs_dict(self, batch):
|
||||
@@ -287,6 +299,15 @@ class TFPolicyGraph(PolicyGraph):
|
||||
def loss_inputs(self):
|
||||
return self._loss_inputs
|
||||
|
||||
def _get_is_training_placeholder(self):
|
||||
"""Get the placeholder for _is_training, i.e., for batch norm layers.
|
||||
|
||||
This can be called safely before __init__ has run.
|
||||
"""
|
||||
if not hasattr(self, "_is_training"):
|
||||
self._is_training = tf.placeholder_with_default(False, ())
|
||||
return self._is_training
|
||||
|
||||
|
||||
class LearningRateSchedule(object):
|
||||
"""Mixin for TFPolicyGraph that adds a learning rate schedule."""
|
||||
|
||||
@@ -72,7 +72,6 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
if state_batches:
|
||||
raise NotImplementedError("Torch RNN support")
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
"""Example of using a custom model with batch norm."""
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.slim as slim
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.models.misc import normc_initializer
|
||||
from ray.tune import run_experiments
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-iters", type=int, default=200)
|
||||
parser.add_argument("--run", type=str, default="PPO")
|
||||
|
||||
|
||||
class BatchNormModel(Model):
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
last_layer = input_dict["obs"]
|
||||
hiddens = [256, 256]
|
||||
for i, size in enumerate(hiddens):
|
||||
label = "fc{}".format(i)
|
||||
last_layer = slim.fully_connected(
|
||||
last_layer,
|
||||
size,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
activation_fn=tf.nn.tanh,
|
||||
scope=label)
|
||||
# Add a batch norm layer
|
||||
last_layer = tf.layers.batch_normalization(
|
||||
last_layer, training=input_dict["is_training"])
|
||||
output = slim.fully_connected(
|
||||
last_layer,
|
||||
num_outputs,
|
||||
weights_initializer=normc_initializer(0.01),
|
||||
activation_fn=None,
|
||||
scope="fc_out")
|
||||
return output, last_layer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init()
|
||||
|
||||
ModelCatalog.register_custom_model("bn_model", BatchNormModel)
|
||||
run_experiments({
|
||||
"batch_norm_demo": {
|
||||
"run": args.run,
|
||||
"env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters
|
||||
},
|
||||
"config": {
|
||||
"model": {
|
||||
"custom_model": "bn_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -23,7 +23,7 @@ class Model(object):
|
||||
|
||||
Attributes:
|
||||
input_dict (dict): Dictionary of input tensors, including "obs",
|
||||
"prev_action", "prev_reward".
|
||||
"prev_action", "prev_reward", "is_training".
|
||||
outputs (Tensor): The output vector of this model, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
last_layer (Tensor): The feature layer right before the model output,
|
||||
@@ -108,7 +108,7 @@ class Model(object):
|
||||
|
||||
Arguments:
|
||||
input_dict (dict): Dictionary of input tensors, including "obs",
|
||||
"prev_action", "prev_reward".
|
||||
"prev_action", "prev_reward", "is_training".
|
||||
num_outputs (int): Output tensor must be of size
|
||||
[BATCH_SIZE, num_outputs].
|
||||
options (dict): Model options.
|
||||
@@ -124,6 +124,7 @@ class Model(object):
|
||||
>>> print(input_dict)
|
||||
{'prev_actions': <tf.Tensor shape=(?,) dtype=int64>,
|
||||
'prev_rewards': <tf.Tensor shape=(?,) dtype=float32>,
|
||||
'is_training': <tf.Tensor shape=(), dtype=bool>,
|
||||
'obs': OrderedDict([
|
||||
('sensors', OrderedDict([
|
||||
('front_cam', [
|
||||
|
||||
@@ -3,12 +3,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# Variable scope in which created variables will be placed under
|
||||
TOWER_SCOPE_NAME = "tower"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalSyncParallelOptimizer(object):
|
||||
"""Optimizer that runs in parallel across multiple local devices.
|
||||
@@ -63,6 +66,8 @@ class LocalSyncParallelOptimizer(object):
|
||||
# First initialize the shared loss network
|
||||
with tf.name_scope(TOWER_SCOPE_NAME):
|
||||
self._shared_loss = build_graph(self.loss_inputs)
|
||||
shared_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
|
||||
# Then setup the per-device loss graphs that use the shared weights
|
||||
self._batch_index = tf.placeholder(tf.int32, name="batch_index")
|
||||
@@ -95,7 +100,20 @@ class LocalSyncParallelOptimizer(object):
|
||||
clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
|
||||
for i, (grad, var) in enumerate(avg):
|
||||
avg[i] = (clipped[i], var)
|
||||
self._train_op = self.optimizer.apply_gradients(avg)
|
||||
|
||||
# gather update ops for any batch norm layers. TODO(ekl) here we will
|
||||
# use all the ops found which won't work for DQN / DDPG, but those
|
||||
# aren't supported with multi-gpu right now anyways.
|
||||
self._update_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
for op in shared_ops:
|
||||
self._update_ops.remove(op) # only care about tower update ops
|
||||
if self._update_ops:
|
||||
logger.debug("Update ops to run on apply gradient: {}".format(
|
||||
self._update_ops))
|
||||
|
||||
with tf.control_dependencies(self._update_ops):
|
||||
self._train_op = self.optimizer.apply_gradients(avg)
|
||||
|
||||
def load_data(self, sess, inputs, state_inputs):
|
||||
"""Bulk loads the specified inputs into device memory.
|
||||
|
||||
@@ -323,7 +323,6 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
|
||||
|
||||
@@ -348,7 +347,6 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
# Pretend we did a model-based rollout and want to return
|
||||
# the extra trajectory.
|
||||
|
||||
@@ -25,7 +25,6 @@ class MockPolicyGraph(PolicyGraph):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
return [0] * len(obs_batch), [], {}
|
||||
|
||||
@@ -43,7 +42,6 @@ class BadPolicyGraph(PolicyGraph):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
raise Exception("intentional error")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user