[rllib] Support batch norm layers (#3369)

* batch norm

* lint

* fix dqn/ddpg update ops

* bn model

* Update tf_policy_graph.py

* Update multi_gpu_impl.py

* Apply suggestions from code review

Co-Authored-By: ericl <ekhliang@gmail.com>
This commit is contained in:
Eric Liang
2018-11-29 13:33:39 -08:00
committed by GitHub
parent 4d2010a852
commit 07d8cbf414
19 changed files with 182 additions and 49 deletions
@@ -53,7 +53,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.model = ModelCatalog.get_model({
"obs": self.observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
self.vf = self.model.value_function()
+3 -5
View File
@@ -385,13 +385,11 @@ class Agent(Trainable):
observation, update=False)
if state:
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False),
lambda p: p.compute_single_action(filtered_obs, state),
policy_id=policy_id)
return self.local_evaluator.for_policy(
lambda p: p.compute_single_action(
filtered_obs, state, is_training=False)[0],
policy_id=policy_id)
lambda p: p.compute_single_action(filtered_obs, state)[0],
policy_id=policy_id)
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
@@ -199,7 +199,9 @@ class DDPGPolicyGraph(TFPolicyGraph):
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
self.eps = tf.placeholder(tf.float32, (), name="eps")
self.cur_observations = tf.placeholder(
tf.float32, shape=(None, ) + observation_space.shape)
tf.float32,
shape=(None, ) + observation_space.shape,
name="cur_obs")
# Actor: P (policy) network
with tf.variable_scope(P_SCOPE) as scope:
@@ -236,7 +238,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
# p network evaluation
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
self.p_t = self._build_p_network(self.obs_t, observation_space)
p_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)
# target p network evaluation
with tf.variable_scope(P_TARGET_SCOPE) as scope:
@@ -257,6 +263,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
is_target=True)
# q network evaluation
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.variable_scope(Q_SCOPE) as scope:
q_t, model = self._build_q_network(self.obs_t, observation_space,
self.act_t)
@@ -269,6 +276,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
twin_q_t, twin_model = self._build_q_network(
self.obs_t, observation_space, self.act_t)
self.twin_q_func_vars = _scope_vars(scope.name)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
# target q network evalution
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
@@ -345,7 +354,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=model.loss() + self.loss.total_loss,
loss_inputs=self.loss_inputs)
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops + p_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())
# Note that this encompasses both the policy and Q-value networks and
@@ -359,7 +369,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
def _build_q_network(self, obs, obs_space, actions):
q_net = QNetwork(
ModelCatalog.get_model({
"obs": obs
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), actions,
self.config["critic_hiddens"],
self.config["critic_hidden_activation"])
@@ -368,7 +379,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
def _build_p_network(self, obs, obs_space):
return PNetwork(
ModelCatalog.get_model({
"obs": obs
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), self.dim_actions,
self.config["actor_hiddens"],
self.config["actor_hidden_activation"]).action_scores
@@ -306,8 +306,12 @@ class DQNPolicyGraph(TFPolicyGraph):
# q network evaluation
with tf.variable_scope(Q_SCOPE, reuse=True):
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
self.obs_t, observation_space)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)
# target q network evalution
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
@@ -372,13 +376,15 @@ class DQNPolicyGraph(TFPolicyGraph):
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=model.loss() + self.loss.loss,
loss_inputs=self.loss_inputs)
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())
def _build_q_network(self, obs, space):
qnet = QNetwork(
ModelCatalog.get_model({
"obs": obs
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, space, self.num_actions, self.config["model"]),
self.num_actions, self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
@@ -133,6 +133,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
"obs": observations,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
logit_dim,
@@ -35,7 +35,8 @@ class PGPolicyGraph(TFPolicyGraph):
self.model = ModelCatalog.get_model({
"obs": obs,
"prev_actions": prev_actions,
"prev_rewards": prev_rewards
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, obs_space, self.logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs) # logit for each action
+3 -2
View File
@@ -24,7 +24,7 @@ DEFAULT_CONFIG = with_common_config({
"sample_batch_size": 200,
# Number of timesteps collected for each SGD round
"train_batch_size": 4000,
# Total SGD batch size across all devices for SGD (multi-gpu only)
# Total SGD batch size across all devices for SGD
"sgd_minibatch_size": 128,
# Number of SGD iterations in each outer loop
"num_sgd_iter": 30,
@@ -49,7 +49,8 @@ DEFAULT_CONFIG = with_common_config({
"batch_mode": "truncate_episodes",
# Which observation filter to apply to the observation
"observation_filter": "MeanStdFilter",
# Use the sync samples optimizer instead of the multi-gpu one
# Uses the sync samples optimizer instead of the multi-gpu one. This does
# not support minibatches.
"simple_optimizer": False,
})
# __sphinx_doc_end__
@@ -158,7 +158,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
{
"obs": obs_ph,
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
},
observation_space,
logit_dim,
@@ -191,7 +192,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.value_function = ModelCatalog.get_model({
"obs": obs_ph,
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
}, observation_space, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
+1 -5
View File
@@ -42,7 +42,6 @@ class PolicyGraph(object):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
"""Compute actions for the current policy.
@@ -51,7 +50,6 @@ class PolicyGraph(object):
state_batches (list): list of RNN state input batches, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
is_training (bool): whether we are training the policy
episodes (list): MultiAgentEpisode for each obs in obs_batch.
This provides access to all of the internal episode state,
which may be useful for model-based or multiagent algorithms.
@@ -71,7 +69,6 @@ class PolicyGraph(object):
state,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episode=None):
"""Unbatched version of compute_actions.
@@ -80,7 +77,6 @@ class PolicyGraph(object):
state_batches (list): list of RNN state inputs, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
is_training (bool): whether we are training the policy
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
@@ -92,7 +88,7 @@ class PolicyGraph(object):
"""
[action], state_out, info = self.compute_actions(
[obs], [[s] for s in state], is_training, episodes=[episode])
[obs], [[s] for s in state], episodes=[episode])
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
+1 -3
View File
@@ -436,15 +436,13 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
builder, [t.obs for t in eval_data],
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True)
prev_reward_batch=[t.prev_reward for t in eval_data])
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data],
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True,
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
for k, v in pending_fetches.items():
+33 -12
View File
@@ -30,7 +30,7 @@ class TFPolicyGraph(PolicyGraph):
Examples:
>>> policy = TFPolicyGraphSubclass(
sess, obs_input, action_sampler, loss, loss_inputs, is_training)
sess, obs_input, action_sampler, loss, loss_inputs)
>>> print(policy.compute_actions([1, 0, 2]))
(array([0, 1, 1]), [], {})
@@ -53,7 +53,8 @@ class TFPolicyGraph(PolicyGraph):
prev_reward_input=None,
seq_lens=None,
max_seq_len=20,
batch_divisibility_req=1):
batch_divisibility_req=1,
update_ops=None):
"""Initialize the policy graph.
Arguments:
@@ -82,6 +83,9 @@ class TFPolicyGraph(PolicyGraph):
batch_divisibility_req (int): pad all agent experiences batches to
multiples of this value. This only has an effect if not using
a LSTM model.
update_ops (list): override the batchnorm update ops to run when
applying gradients. Otherwise we run all update ops found in
the current variable scope.
"""
self.observation_space = observation_space
@@ -94,7 +98,7 @@ class TFPolicyGraph(PolicyGraph):
self._loss = loss
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
self._is_training = tf.placeholder_with_default(True, ())
self._is_training = self._get_is_training_placeholder()
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
for i, ph in enumerate(self._state_inputs):
@@ -108,14 +112,24 @@ class TFPolicyGraph(PolicyGraph):
for (g, v) in self.gradients(self._optimizer)
if g is not None]
self._grads = [g for (g, v) in self._grads_and_vars]
# specify global_step for TD3 which needs to count the num updates
self._apply_op = self._optimizer.apply_gradients(
self._grads_and_vars,
global_step=tf.train.get_or_create_global_step())
self._variables = ray.experimental.TensorFlowVariables(
self._loss, self._sess)
# gather update ops for any batch norm layers
if update_ops:
self._update_ops = update_ops
else:
self._update_ops = tf.get_collection(
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
if self._update_ops:
logger.debug("Update ops to run on apply gradient: {}".format(
self._update_ops))
with tf.control_dependencies(self._update_ops):
# specify global_step for TD3 which needs to count the num updates
self._apply_op = self._optimizer.apply_gradients(
self._grads_and_vars,
global_step=tf.train.get_or_create_global_step())
if len(self._state_inputs) != len(self._state_outputs):
raise ValueError(
"Number of state input and output tensors must match, got: "
@@ -138,7 +152,6 @@ class TFPolicyGraph(PolicyGraph):
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
state_batches = state_batches or []
assert len(self._state_inputs) == len(state_batches), \
@@ -151,7 +164,7 @@ class TFPolicyGraph(PolicyGraph):
builder.add_feed_dict({self._prev_action_input: prev_action_batch})
if self._prev_reward_input is not None and prev_reward_batch:
builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
builder.add_feed_dict({self._is_training: is_training})
builder.add_feed_dict({self._is_training: False})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
fetches = builder.add_fetches([self._sampler] + self._state_outputs +
[self.extra_compute_action_fetches()])
@@ -162,12 +175,11 @@ class TFPolicyGraph(PolicyGraph):
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
builder = TFRunBuilder(self._sess, "compute_actions")
fetches = self.build_compute_actions(builder, obs_batch, state_batches,
prev_action_batch,
prev_reward_batch, is_training)
prev_reward_batch)
return builder.get(fetches)
def _get_loss_inputs_dict(self, batch):
@@ -287,6 +299,15 @@ class TFPolicyGraph(PolicyGraph):
def loss_inputs(self):
return self._loss_inputs
def _get_is_training_placeholder(self):
"""Get the placeholder for _is_training, i.e., for batch norm layers.
This can be called safely before __init__ has run.
"""
if not hasattr(self, "_is_training"):
self._is_training = tf.placeholder_with_default(False, ())
return self._is_training
class LearningRateSchedule(object):
"""Mixin for TFPolicyGraph that adds a learning rate schedule."""
@@ -72,7 +72,6 @@ class TorchPolicyGraph(PolicyGraph):
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
if state_batches:
raise NotImplementedError("Torch RNN support")
@@ -0,0 +1,64 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of using a custom model with batch norm."""
import argparse
import tensorflow as tf
import tensorflow.contrib.slim as slim
import ray
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.tune import run_experiments
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=200)
parser.add_argument("--run", type=str, default="PPO")
class BatchNormModel(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
last_layer = input_dict["obs"]
hiddens = [256, 256]
for i, size in enumerate(hiddens):
label = "fc{}".format(i)
last_layer = slim.fully_connected(
last_layer,
size,
weights_initializer=normc_initializer(1.0),
activation_fn=tf.nn.tanh,
scope=label)
# Add a batch norm layer
last_layer = tf.layers.batch_normalization(
last_layer, training=input_dict["is_training"])
output = slim.fully_connected(
last_layer,
num_outputs,
weights_initializer=normc_initializer(0.01),
activation_fn=None,
scope="fc_out")
return output, last_layer
if __name__ == "__main__":
args = parser.parse_args()
ray.init()
ModelCatalog.register_custom_model("bn_model", BatchNormModel)
run_experiments({
"batch_norm_demo": {
"run": args.run,
"env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0",
"stop": {
"training_iteration": args.num_iters
},
"config": {
"model": {
"custom_model": "bn_model",
},
"num_workers": 0,
},
},
})
+3 -2
View File
@@ -23,7 +23,7 @@ class Model(object):
Attributes:
input_dict (dict): Dictionary of input tensors, including "obs",
"prev_action", "prev_reward".
"prev_action", "prev_reward", "is_training".
outputs (Tensor): The output vector of this model, of shape
[BATCH_SIZE, num_outputs].
last_layer (Tensor): The feature layer right before the model output,
@@ -108,7 +108,7 @@ class Model(object):
Arguments:
input_dict (dict): Dictionary of input tensors, including "obs",
"prev_action", "prev_reward".
"prev_action", "prev_reward", "is_training".
num_outputs (int): Output tensor must be of size
[BATCH_SIZE, num_outputs].
options (dict): Model options.
@@ -124,6 +124,7 @@ class Model(object):
>>> print(input_dict)
{'prev_actions': <tf.Tensor shape=(?,) dtype=int64>,
'prev_rewards': <tf.Tensor shape=(?,) dtype=float32>,
'is_training': <tf.Tensor shape=(), dtype=bool>,
'obs': OrderedDict([
('sensors', OrderedDict([
('front_cam', [
+19 -1
View File
@@ -3,12 +3,15 @@ from __future__ import division
from __future__ import print_function
from collections import namedtuple
import logging
import tensorflow as tf
# Variable scope in which created variables will be placed under
TOWER_SCOPE_NAME = "tower"
logger = logging.getLogger(__name__)
class LocalSyncParallelOptimizer(object):
"""Optimizer that runs in parallel across multiple local devices.
@@ -63,6 +66,8 @@ class LocalSyncParallelOptimizer(object):
# First initialize the shared loss network
with tf.name_scope(TOWER_SCOPE_NAME):
self._shared_loss = build_graph(self.loss_inputs)
shared_ops = tf.get_collection(
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
# Then setup the per-device loss graphs that use the shared weights
self._batch_index = tf.placeholder(tf.int32, name="batch_index")
@@ -95,7 +100,20 @@ class LocalSyncParallelOptimizer(object):
clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping)
for i, (grad, var) in enumerate(avg):
avg[i] = (clipped[i], var)
self._train_op = self.optimizer.apply_gradients(avg)
# gather update ops for any batch norm layers. TODO(ekl) here we will
# use all the ops found which won't work for DQN / DDPG, but those
# aren't supported with multi-gpu right now anyways.
self._update_ops = tf.get_collection(
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
for op in shared_ops:
self._update_ops.remove(op) # only care about tower update ops
if self._update_ops:
logger.debug("Update ops to run on apply gradient: {}".format(
self._update_ops))
with tf.control_dependencies(self._update_ops):
self._train_op = self.optimizer.apply_gradients(avg)
def load_data(self, sess, inputs, state_inputs):
"""Bulk loads the specified inputs into device memory.
@@ -323,7 +323,6 @@ class TestMultiAgentEnv(unittest.TestCase):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
@@ -348,7 +347,6 @@ class TestMultiAgentEnv(unittest.TestCase):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
# Pretend we did a model-based rollout and want to return
# the extra trajectory.
@@ -25,7 +25,6 @@ class MockPolicyGraph(PolicyGraph):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
return [0] * len(obs_batch), [], {}
@@ -43,7 +42,6 @@ class BadPolicyGraph(PolicyGraph):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
is_training=False,
episodes=None):
raise Exception("intentional error")