mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:00:10 +08:00
[rllib] Cleanup RNN support and make it work with multi-GPU optimizer (#2394)
Cleanup: TFPolicyGraph now automatically adds loss input entries for state_in_*, so that graph sub-classes don't need to worry about it. Multi-GPU support: Allow setting up model tower replicas with existing state input tensors Truncate the per-device minibatch slices so that they are always a multiple of max_seq_len.
This commit is contained in:
@@ -49,7 +49,6 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||
[-1])
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
tf.get_variable_scope().name)
|
||||
is_training = tf.placeholder_with_default(True, ())
|
||||
|
||||
# Setup the policy loss
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
@@ -74,16 +73,13 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||
("advantages", advantages),
|
||||
("value_targets", v_target),
|
||||
]
|
||||
for i, ph in enumerate(self.model.state_in):
|
||||
loss_in.append(("state_in_{}".format(i), ph))
|
||||
self.state_in = self.model.state_in
|
||||
self.state_out = self.model.state_out
|
||||
TFPolicyGraph.__init__(
|
||||
self, observation_space, action_space, self.sess,
|
||||
obs_input=self.observations, action_sampler=action_dist.sample(),
|
||||
loss=self.loss.total_loss, loss_inputs=loss_in,
|
||||
is_training=is_training, state_inputs=self.state_in,
|
||||
state_outputs=self.state_out,
|
||||
state_inputs=self.state_in, state_outputs=self.state_out,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"])
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ COMMON_CONFIG = {
|
||||
"gpu_options": {
|
||||
"allow_growth": True,
|
||||
},
|
||||
"log_device_placement": False,
|
||||
"device_count": {"CPU": 1},
|
||||
"allow_soft_placement": True, # required by PPO multi-gpu
|
||||
},
|
||||
# Whether to LZ4 compress observations
|
||||
|
||||
@@ -262,12 +262,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
("dones", self.done_mask),
|
||||
("weights", self.importance_weights),
|
||||
]
|
||||
self.is_training = tf.placeholder_with_default(True, ())
|
||||
TFPolicyGraph.__init__(
|
||||
self, observation_space, action_space, self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions, loss=self.loss.total_loss,
|
||||
loss_inputs=self.loss_inputs, is_training=self.is_training)
|
||||
loss_inputs=self.loss_inputs)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
# Note that this encompasses both the policy and Q-value networks and
|
||||
|
||||
@@ -171,12 +171,11 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
("dones", self.done_mask),
|
||||
("weights", self.importance_weights),
|
||||
]
|
||||
self.is_training = tf.placeholder_with_default(True, ())
|
||||
TFPolicyGraph.__init__(
|
||||
self, observation_space, action_space, self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions, loss=self.loss.loss,
|
||||
loss_inputs=self.loss_inputs, is_training=self.is_training)
|
||||
loss_inputs=self.loss_inputs)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
def optimizer(self):
|
||||
|
||||
@@ -41,16 +41,10 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
("advantages", advantages),
|
||||
]
|
||||
|
||||
# LSTM support
|
||||
for i, ph in enumerate(self.model.state_in):
|
||||
loss_in.append(("state_in_{}".format(i), ph))
|
||||
|
||||
is_training = tf.placeholder_with_default(True, ())
|
||||
TFPolicyGraph.__init__(
|
||||
self, obs_space, action_space, sess, obs_input=obs,
|
||||
action_sampler=action_dist.sample(), loss=loss,
|
||||
loss_inputs=loss_in, is_training=is_training,
|
||||
state_inputs=self.model.state_in,
|
||||
loss_inputs=loss_in, state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=config["model"]["max_seq_len"])
|
||||
|
||||
@@ -50,7 +50,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"simple_optimizer": False,
|
||||
# Override model config
|
||||
"model": {
|
||||
# Use LSTM model (note: requires simple optimizer for now).
|
||||
# Whether to use LSTM model
|
||||
"use_lstm": False,
|
||||
# Max seq length for LSTM training.
|
||||
"max_seq_len": 20,
|
||||
|
||||
@@ -92,9 +92,10 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||
dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space)
|
||||
|
||||
if existing_inputs:
|
||||
self.loss_in = existing_inputs
|
||||
obs_ph, value_targets_ph, adv_ph, act_ph, \
|
||||
logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs]
|
||||
logits_ph, vf_preds_ph = existing_inputs[:6]
|
||||
existing_state_in = existing_inputs[6:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
obs_ph = tf.placeholder(
|
||||
tf.float32, name="obs", shape=(None,)+observation_space.shape)
|
||||
@@ -107,23 +108,20 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||
tf.float32, name="vf_preds", shape=(None,))
|
||||
value_targets_ph = tf.placeholder(
|
||||
tf.float32, name="value_targets", shape=(None,))
|
||||
existing_state_in = None
|
||||
existing_seq_lens = None
|
||||
|
||||
self.loss_in = [
|
||||
("obs", obs_ph),
|
||||
("value_targets", value_targets_ph),
|
||||
("advantages", adv_ph),
|
||||
("actions", act_ph),
|
||||
("logits", logits_ph),
|
||||
("vf_preds", vf_preds_ph),
|
||||
]
|
||||
|
||||
self.loss_in = [
|
||||
("obs", obs_ph),
|
||||
("value_targets", value_targets_ph),
|
||||
("advantages", adv_ph),
|
||||
("actions", act_ph),
|
||||
("logits", logits_ph),
|
||||
("vf_preds", vf_preds_ph),
|
||||
]
|
||||
self.model = ModelCatalog.get_model(
|
||||
obs_ph, logit_dim, self.config["model"])
|
||||
|
||||
# LSTM support
|
||||
if not existing_inputs:
|
||||
for i, ph in enumerate(self.model.state_in):
|
||||
self.loss_in.append(("state_in_{}".format(i), ph))
|
||||
obs_ph, logit_dim, self.config["model"],
|
||||
state_in=existing_state_in, seq_lens=existing_seq_lens)
|
||||
|
||||
# KL Coefficient
|
||||
self.kl_coeff = tf.get_variable(
|
||||
@@ -155,15 +153,14 @@ class PPOPolicyGraph(TFPolicyGraph):
|
||||
clip_param=self.config["clip_param"],
|
||||
vf_loss_coeff=self.config["kl_target"],
|
||||
use_gae=self.config["use_gae"])
|
||||
self.is_training = tf.placeholder_with_default(True, ())
|
||||
|
||||
TFPolicyGraph.__init__(
|
||||
self, observation_space, action_space,
|
||||
self.sess, obs_input=obs_ph,
|
||||
action_sampler=self.sampler, loss=self.loss_obj.loss,
|
||||
loss_inputs=self.loss_in, is_training=self.is_training,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out, seq_lens=self.model.seq_lens)
|
||||
loss_inputs=self.loss_in, state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out, seq_lens=self.model.seq_lens,
|
||||
max_seq_len=config["model"]["max_seq_len"])
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -36,9 +37,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
def __init__(
|
||||
self, observation_space, action_space, sess, obs_input,
|
||||
action_sampler, loss, loss_inputs, is_training,
|
||||
state_inputs=None, state_outputs=None, seq_lens=None,
|
||||
max_seq_len=20):
|
||||
action_sampler, loss, loss_inputs, state_inputs=None,
|
||||
state_outputs=None, seq_lens=None, max_seq_len=20):
|
||||
"""Initialize the policy graph.
|
||||
|
||||
Arguments:
|
||||
@@ -54,10 +54,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
input argument. Each placeholder name must correspond to a
|
||||
SampleBatch column key returned by postprocess_trajectory(),
|
||||
and has shape [BATCH_SIZE, data...].
|
||||
is_training (Tensor): input placeholder for whether we are
|
||||
currently training the policy.
|
||||
state_inputs (list): list of RNN state output Tensors.
|
||||
state_outputs (list): list of initial state values.
|
||||
state_inputs (list): list of RNN state input Tensors.
|
||||
state_outputs (list): list of RNN state output Tensors.
|
||||
seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
|
||||
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
|
||||
models/lstm.py for more information.
|
||||
@@ -72,9 +70,11 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._loss = loss
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = is_training
|
||||
self._is_training = tf.placeholder_with_default(True, ())
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
self._loss_input_dict["state_in_{}".format(i)] = ph
|
||||
self._seq_lens = seq_lens
|
||||
self._max_seq_len = max_seq_len
|
||||
self._optimizer = self.optimizer()
|
||||
@@ -99,6 +99,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
(self._state_inputs, state_batches)
|
||||
builder.add_feed_dict(self.extra_compute_action_feed_dict())
|
||||
builder.add_feed_dict({self._obs_input: obs_batch})
|
||||
if state_batches:
|
||||
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
|
||||
builder.add_feed_dict({self._is_training: is_training})
|
||||
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
|
||||
fetches = builder.add_fetches(
|
||||
@@ -123,10 +125,9 @@ class TFPolicyGraph(PolicyGraph):
|
||||
return feed_dict
|
||||
|
||||
# RNN case
|
||||
feature_keys = [
|
||||
k for k, v in self._loss_inputs if not k.startswith("state_in_")]
|
||||
feature_keys = [k for k, v in self._loss_inputs]
|
||||
state_keys = [
|
||||
k for k, v in self._loss_inputs if k.startswith("state_in_")]
|
||||
"state_in_{}".format(i) for i in range(len(self._state_inputs))]
|
||||
feature_sequences, initial_states, seq_lens = chop_into_sequences(
|
||||
batch["t"],
|
||||
[batch[k] for k in feature_keys],
|
||||
|
||||
@@ -138,41 +138,47 @@ class ModelCatalog(object):
|
||||
" not supported".format(action_space))
|
||||
|
||||
@staticmethod
|
||||
def get_model(inputs, num_outputs, options=None):
|
||||
def get_model(
|
||||
inputs, num_outputs, options=None, state_in=None, seq_lens=None):
|
||||
"""Returns a suitable model conforming to given input and output specs.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): The input tensor to the model.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
state_in (list): Optional RNN state in tensors.
|
||||
seq_in (Tensor): Optional RNN sequence length tensor.
|
||||
|
||||
Returns:
|
||||
model (Model): Neural network model.
|
||||
"""
|
||||
|
||||
options = options or {}
|
||||
model = ModelCatalog._get_model(inputs, num_outputs, options)
|
||||
model = ModelCatalog._get_model(
|
||||
inputs, num_outputs, options, state_in, seq_lens)
|
||||
|
||||
if options.get("use_lstm"):
|
||||
model = LSTM(model.last_layer, num_outputs, options)
|
||||
model = LSTM(
|
||||
model.last_layer, num_outputs, options, state_in, seq_lens)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _get_model(inputs, num_outputs, options):
|
||||
def _get_model(inputs, num_outputs, options, state_in, seq_lens):
|
||||
if "custom_model" in options:
|
||||
model = options["custom_model"]
|
||||
print("Using custom model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
inputs, num_outputs, options)
|
||||
inputs, num_outputs, options,
|
||||
state_in=state_in, seq_lens=seq_lens)
|
||||
|
||||
obs_rank = len(inputs.shape) - 1
|
||||
|
||||
# num_outputs > 1 used to avoid hitting this with the value function
|
||||
if isinstance(options.get("custom_options", {}).get(
|
||||
"multiagent_fcnet_hiddens", 1), list) and num_outputs > 1:
|
||||
return MultiAgentFullyConnectedNetwork(inputs,
|
||||
num_outputs, options)
|
||||
return MultiAgentFullyConnectedNetwork(
|
||||
inputs, num_outputs, options)
|
||||
|
||||
if obs_rank > 1:
|
||||
return VisionNetwork(inputs, num_outputs, options)
|
||||
|
||||
@@ -41,8 +41,8 @@ def add_time_dimension(padded_inputs, seq_lens):
|
||||
# Sequence lengths have to be specified for LSTM batch inputs. The
|
||||
# input batch must be padded to the max seq length given here. That is,
|
||||
# batch_size == len(seq_lens) * max(seq_lens)
|
||||
max_seq_len = tf.reduce_max(seq_lens)
|
||||
padded_batch_size = tf.shape(padded_inputs)[0]
|
||||
max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]
|
||||
|
||||
# Dynamically reshape the padded batch to introduce a time dimension.
|
||||
new_batch_size = padded_batch_size // max_seq_len
|
||||
@@ -155,9 +155,14 @@ class LSTM(Model):
|
||||
np.zeros(lstm.state_size.h, np.float32)]
|
||||
|
||||
# Setup LSTM inputs
|
||||
c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c], name="c")
|
||||
h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h], name="h")
|
||||
self.state_in = [c_in, h_in]
|
||||
if self.state_in:
|
||||
c_in, h_in = self.state_in
|
||||
else:
|
||||
c_in = tf.placeholder(
|
||||
tf.float32, [None, lstm.state_size.c], name="c")
|
||||
h_in = tf.placeholder(
|
||||
tf.float32, [None, lstm.state_size.h], name="h")
|
||||
self.state_in = [c_in, h_in]
|
||||
|
||||
# Setup LSTM outputs
|
||||
if use_tf100_api:
|
||||
|
||||
@@ -37,17 +37,19 @@ class Model(object):
|
||||
a scale parameter (like a standard deviation).
|
||||
"""
|
||||
|
||||
def __init__(self, inputs, num_outputs, options):
|
||||
def __init__(
|
||||
self, inputs, num_outputs, options, state_in=None, seq_lens=None):
|
||||
self.inputs = inputs
|
||||
|
||||
# Default attribute values for the non-RNN case
|
||||
self.state_init = []
|
||||
self.state_in = []
|
||||
self.state_in = state_in or []
|
||||
self.state_out = []
|
||||
self.seq_lens = tf.placeholder_with_default(
|
||||
tf.ones( # reshape needed for older tf versions
|
||||
tf.reshape(tf.shape(inputs)[0], [1]), dtype=tf.int32),
|
||||
[None], name="seq_lens")
|
||||
if seq_lens is not None:
|
||||
self.seq_lens = seq_lens
|
||||
else:
|
||||
self.seq_lens = tf.placeholder(
|
||||
dtype=tf.int32, shape=[None], name="seq_lens")
|
||||
|
||||
if options.get("free_log_std", False):
|
||||
assert num_outputs % 2 == 0
|
||||
|
||||
@@ -3,9 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
from tensorflow.python.client import timeline
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@@ -34,9 +32,11 @@ class LocalSyncParallelOptimizer(object):
|
||||
Args:
|
||||
optimizer: Delegate TensorFlow optimizer object.
|
||||
devices: List of the names of TensorFlow devices to parallelize over.
|
||||
input_placeholders: List of (name, input_placeholder)
|
||||
for the loss function. Tensors of these shapes will be passed
|
||||
to build_graph() in order to define the per-device loss ops.
|
||||
input_placeholders: List of input_placeholders for the loss function.
|
||||
Tensors of these shapes will be passed to build_graph() in order
|
||||
to define the per-device loss ops.
|
||||
rnn_inputs: Extra input placeholders for RNN inputs. These will have
|
||||
shape [BATCH_SIZE // MAX_SEQ_LEN, ...].
|
||||
per_device_batch_size: Number of tuples to optimize over at a time per
|
||||
device. In each call to `optimize()`,
|
||||
`len(devices) * per_device_batch_size` tuples of data will be
|
||||
@@ -47,7 +47,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
grad_norm_clipping: None or int stdev to clip grad norms by
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, devices, input_placeholders,
|
||||
def __init__(self, optimizer, devices, input_placeholders, rnn_inputs,
|
||||
per_device_batch_size, build_graph, logdir,
|
||||
grad_norm_clipping=None):
|
||||
# TODO(rliaw): remove logdir
|
||||
@@ -55,27 +55,31 @@ class LocalSyncParallelOptimizer(object):
|
||||
self.devices = devices
|
||||
self.batch_size = per_device_batch_size * len(devices)
|
||||
self.per_device_batch_size = per_device_batch_size
|
||||
self.loss_inputs = input_placeholders
|
||||
self.loss_inputs = input_placeholders + rnn_inputs
|
||||
self.build_graph = build_graph
|
||||
self.logdir = logdir
|
||||
|
||||
# First initialize the shared loss network
|
||||
with tf.name_scope(TOWER_SCOPE_NAME):
|
||||
self._shared_loss = build_graph(input_placeholders)
|
||||
self._shared_loss = build_graph(self.loss_inputs)
|
||||
|
||||
# Then setup the per-device loss graphs that use the shared weights
|
||||
self._batch_index = tf.placeholder(tf.int32)
|
||||
self._batch_index = tf.placeholder(tf.int32, name="batch_index")
|
||||
|
||||
# When loading RNN input, we dynamically determine the max seq len
|
||||
self._max_seq_len = tf.placeholder(tf.int32, name="max_seq_len")
|
||||
self._loaded_max_seq_len = 1
|
||||
|
||||
# Split on the CPU in case the data doesn't fit in GPU memory.
|
||||
with tf.device("/cpu:0"):
|
||||
names, placeholders = zip(*input_placeholders)
|
||||
data_splits = zip(
|
||||
*[tf.split(ph, len(devices)) for ph in placeholders])
|
||||
*[tf.split(ph, len(devices)) for ph in self.loss_inputs])
|
||||
|
||||
self._towers = []
|
||||
for device, device_placeholders in zip(self.devices, data_splits):
|
||||
self._towers.append(
|
||||
self._setup_device(device, zip(names, device_placeholders)))
|
||||
self._setup_device(
|
||||
device, device_placeholders, len(input_placeholders)))
|
||||
|
||||
avg = average_gradients([t.grads for t in self._towers])
|
||||
if grad_norm_clipping:
|
||||
@@ -84,7 +88,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
avg[i] = (tf.clip_by_norm(grad, grad_norm_clipping), var)
|
||||
self._train_op = self.optimizer.apply_gradients(avg)
|
||||
|
||||
def load_data(self, sess, inputs, full_trace=False):
|
||||
def load_data(self, sess, inputs, state_inputs):
|
||||
"""Bulk loads the specified inputs into device memory.
|
||||
|
||||
The shape of the inputs must conform to the shapes of the input
|
||||
@@ -95,37 +99,47 @@ class LocalSyncParallelOptimizer(object):
|
||||
|
||||
Args:
|
||||
sess: TensorFlow session.
|
||||
inputs: List of Tensors matching the input placeholders specified
|
||||
at construction time of this optimizer.
|
||||
full_trace: Whether to profile data loading.
|
||||
inputs: List of arrays matching the input placeholders, of shape
|
||||
[BATCH_SIZE, ...].
|
||||
state_inputs: List of RNN input arrays. These arrays have size
|
||||
[BATCH_SIZE / MAX_SEQ_LEN, ...].
|
||||
|
||||
Returns:
|
||||
The number of tuples loaded per device.
|
||||
"""
|
||||
|
||||
feed_dict = {}
|
||||
assert len(self.loss_inputs) == len(inputs)
|
||||
for (name, ph), arr in zip(self.loss_inputs, inputs):
|
||||
truncated_arr = make_divisible_by(arr, self.batch_size)
|
||||
feed_dict[ph] = truncated_arr
|
||||
truncated_len = len(truncated_arr)
|
||||
assert len(self.loss_inputs) == len(inputs + state_inputs), \
|
||||
(self.loss_inputs, inputs, state_inputs)
|
||||
|
||||
if full_trace:
|
||||
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
|
||||
# The RNN truncation case is more complicated
|
||||
if len(state_inputs) > 0:
|
||||
seq_len = len(inputs[0]) // len(state_inputs[0])
|
||||
self._loaded_max_seq_len = seq_len
|
||||
assert len(state_inputs[0]) * seq_len == len(inputs[0])
|
||||
# Make sure the shorter state inputs arrays are evenly divisible
|
||||
state_inputs = [
|
||||
make_divisible_by(arr, self.batch_size)
|
||||
for arr in state_inputs
|
||||
]
|
||||
# Then truncate the data inputs to match
|
||||
inputs = [
|
||||
arr[:len(state_inputs[0]) * seq_len]
|
||||
for arr in inputs
|
||||
]
|
||||
assert len(state_inputs[0]) * seq_len == len(inputs[0])
|
||||
assert len(state_inputs[0]) % self.batch_size == 0
|
||||
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
|
||||
feed_dict[ph] = arr
|
||||
truncated_len = len(inputs[0])
|
||||
else:
|
||||
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
|
||||
run_metadata = tf.RunMetadata()
|
||||
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
|
||||
truncated_arr = make_divisible_by(arr, self.batch_size)
|
||||
feed_dict[ph] = truncated_arr
|
||||
truncated_len = len(truncated_arr)
|
||||
|
||||
sess.run(
|
||||
[t.init_op for t in self._towers],
|
||||
feed_dict=feed_dict,
|
||||
options=run_options,
|
||||
run_metadata=run_metadata)
|
||||
if full_trace:
|
||||
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
|
||||
trace_file = open(os.path.join(self.logdir, "timeline-load.json"),
|
||||
"w")
|
||||
trace_file.write(trace.generate_chrome_trace_format())
|
||||
[t.init_op for t in self._towers], feed_dict=feed_dict)
|
||||
|
||||
tuples_per_device = truncated_len / len(self.devices)
|
||||
assert tuples_per_device > 0, \
|
||||
@@ -136,7 +150,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
assert tuples_per_device % self.per_device_batch_size == 0
|
||||
return tuples_per_device
|
||||
|
||||
def optimize(self, sess, batch_index, file_writer=None):
|
||||
def optimize(self, sess, batch_index):
|
||||
"""Run a single step of SGD.
|
||||
|
||||
Runs a SGD step over a slice of the preloaded batch with size given by
|
||||
@@ -151,19 +165,14 @@ class LocalSyncParallelOptimizer(object):
|
||||
batch_index: Offset into the preloaded data. This value must be
|
||||
between `0` and `tuples_per_device`. The amount of data to
|
||||
process is always fixed to `per_device_batch_size`.
|
||||
file_writer: If specified, tf metrics will be written out using
|
||||
this.
|
||||
|
||||
Returns:
|
||||
The outputs of extra_ops evaluated over the batch.
|
||||
"""
|
||||
if file_writer:
|
||||
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
|
||||
else:
|
||||
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
|
||||
run_metadata = tf.RunMetadata()
|
||||
|
||||
feed_dict = {self._batch_index: batch_index}
|
||||
feed_dict = {
|
||||
self._batch_index: batch_index,
|
||||
self._max_seq_len: self._loaded_max_seq_len,
|
||||
}
|
||||
for tower in self._towers:
|
||||
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
|
||||
feed_dict.update(tower.loss_graph.extra_apply_grad_feed_dict())
|
||||
@@ -173,21 +182,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
fetches.update(tower.loss_graph.extra_compute_grad_fetches())
|
||||
fetches.update(tower.loss_graph.extra_apply_grad_fetches())
|
||||
|
||||
outs = sess.run(
|
||||
fetches,
|
||||
feed_dict=feed_dict,
|
||||
options=run_options,
|
||||
run_metadata=run_metadata)
|
||||
|
||||
if file_writer:
|
||||
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
|
||||
trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"),
|
||||
"w")
|
||||
trace_file.write(trace.generate_chrome_trace_format())
|
||||
file_writer.add_run_metadata(
|
||||
run_metadata, "sgd_train_{}".format(batch_index))
|
||||
|
||||
return outs
|
||||
return sess.run(fetches, feed_dict=feed_dict)
|
||||
|
||||
def get_common_loss(self):
|
||||
return self._shared_loss
|
||||
@@ -195,23 +190,31 @@ class LocalSyncParallelOptimizer(object):
|
||||
def get_device_losses(self):
|
||||
return [t.loss_graph for t in self._towers]
|
||||
|
||||
def _setup_device(self, device, device_input_placeholders):
|
||||
def _setup_device(self, device, device_input_placeholders, num_data_in):
|
||||
assert num_data_in <= len(device_input_placeholders)
|
||||
with tf.device(device):
|
||||
with tf.name_scope(TOWER_SCOPE_NAME):
|
||||
device_input_batches = []
|
||||
device_input_slices = []
|
||||
for name, ph in device_input_placeholders:
|
||||
for i, ph in enumerate(device_input_placeholders):
|
||||
current_batch = tf.Variable(
|
||||
ph, trainable=False, validate_shape=False,
|
||||
collections=[])
|
||||
device_input_batches.append(current_batch)
|
||||
if i < num_data_in:
|
||||
scale = self._max_seq_len
|
||||
granularity = self._max_seq_len
|
||||
else:
|
||||
scale = self._max_seq_len
|
||||
granularity = 1
|
||||
current_slice = tf.slice(
|
||||
current_batch,
|
||||
[self._batch_index] + [0] * len(ph.shape[1:]),
|
||||
([self.per_device_batch_size] + [-1] *
|
||||
len(ph.shape[1:])))
|
||||
([self._batch_index // scale * granularity] +
|
||||
[0] * len(ph.shape[1:])),
|
||||
([self.per_device_batch_size // scale * granularity] +
|
||||
[-1] * len(ph.shape[1:])))
|
||||
current_slice.set_shape(ph.shape)
|
||||
device_input_slices.append((name, current_slice))
|
||||
device_input_slices.append(current_slice)
|
||||
graph_obj = self.build_graph(device_input_slices)
|
||||
device_grads = graph_obj.gradients(self.optimizer)
|
||||
return Tower(
|
||||
|
||||
@@ -55,12 +55,12 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
print("LocalMultiGPUOptimizer devices", self.devices)
|
||||
|
||||
assert set(self.local_evaluator.policy_map.keys()) == {"default"}, \
|
||||
"Multi-agent is not supported"
|
||||
("Multi-agent is not supported with multi-GPU. Try using the "
|
||||
"simple optimizer instead.")
|
||||
self.policy = self.local_evaluator.policy_map["default"]
|
||||
assert isinstance(self.policy, TFPolicyGraph), \
|
||||
"Only TF policies are supported"
|
||||
assert len(self.policy.get_initial_state()) == 0, \
|
||||
"No RNN support yet for multi-gpu. Try the simple optimizer."
|
||||
("Only TF policies are supported with multi-GPU. Try using the "
|
||||
"simple optimizer instead.")
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
@@ -68,10 +68,16 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
with self.local_evaluator.tf_sess.graph.as_default():
|
||||
with self.local_evaluator.tf_sess.as_default():
|
||||
with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
|
||||
if self.policy._state_inputs:
|
||||
rnn_inputs = self.policy._state_inputs + [
|
||||
self.policy._seq_lens]
|
||||
else:
|
||||
rnn_inputs = []
|
||||
self.par_opt = LocalSyncParallelOptimizer(
|
||||
tf.train.AdamOptimizer(self.sgd_stepsize),
|
||||
self.devices,
|
||||
self.policy.loss_inputs(),
|
||||
[v for _, v in self.policy.loss_inputs()],
|
||||
rnn_inputs,
|
||||
self.per_device_batch_size,
|
||||
self.policy.copy,
|
||||
os.getcwd())
|
||||
@@ -103,9 +109,17 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
samples.shuffle()
|
||||
|
||||
with self.load_timer:
|
||||
tuples = self.policy._get_loss_inputs_dict(samples)
|
||||
data_keys = [ph for _, ph in self.policy.loss_inputs()]
|
||||
if self.policy._state_inputs:
|
||||
state_keys = (
|
||||
self.policy._state_inputs + [self.policy._seq_lens])
|
||||
else:
|
||||
state_keys = []
|
||||
tuples_per_device = self.par_opt.load_data(
|
||||
self.sess,
|
||||
samples.columns([key for key, _ in self.policy.loss_inputs()]))
|
||||
[tuples[k] for k in data_keys],
|
||||
[tuples[k] for k in state_keys])
|
||||
|
||||
with self.grad_timer:
|
||||
num_batches = (
|
||||
|
||||
Reference in New Issue
Block a user