[rllib] Cleanup RNN support and make it work with multi-GPU optimizer (#2394)

Cleanup: TFPolicyGraph now automatically adds loss input entries for state_in_*, so that graph sub-classes don't need to worry about it.

Multi-GPU support:

Allow setting up model tower replicas with existing state input tensors

Truncate the per-device minibatch slices so that they are always a multiple of max_seq_len.
This commit is contained in:
Eric Liang
2018-07-17 06:55:46 +02:00
committed by GitHub
parent 1b645fcc8b
commit 0cecf6b79c
14 changed files with 163 additions and 138 deletions
@@ -49,7 +49,6 @@ class A3CPolicyGraph(TFPolicyGraph):
[-1])
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
is_training = tf.placeholder_with_default(True, ())
# Setup the policy loss
if isinstance(action_space, gym.spaces.Box):
@@ -74,16 +73,13 @@ class A3CPolicyGraph(TFPolicyGraph):
("advantages", advantages),
("value_targets", v_target),
]
for i, ph in enumerate(self.model.state_in):
loss_in.append(("state_in_{}".format(i), ph))
self.state_in = self.model.state_in
self.state_out = self.model.state_out
TFPolicyGraph.__init__(
self, observation_space, action_space, self.sess,
obs_input=self.observations, action_sampler=action_dist.sample(),
loss=self.loss.total_loss, loss_inputs=loss_in,
is_training=is_training, state_inputs=self.state_in,
state_outputs=self.state_out,
state_inputs=self.state_in, state_outputs=self.state_out,
seq_lens=self.model.seq_lens,
max_seq_len=self.config["model"]["max_seq_len"])
+2
View File
@@ -46,6 +46,8 @@ COMMON_CONFIG = {
"gpu_options": {
"allow_growth": True,
},
"log_device_placement": False,
"device_count": {"CPU": 1},
"allow_soft_placement": True, # required by PPO multi-gpu
},
# Whether to LZ4 compress observations
@@ -262,12 +262,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
("dones", self.done_mask),
("weights", self.importance_weights),
]
self.is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, observation_space, action_space, self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions, loss=self.loss.total_loss,
loss_inputs=self.loss_inputs, is_training=self.is_training)
loss_inputs=self.loss_inputs)
self.sess.run(tf.global_variables_initializer())
# Note that this encompasses both the policy and Q-value networks and
@@ -171,12 +171,11 @@ class DQNPolicyGraph(TFPolicyGraph):
("dones", self.done_mask),
("weights", self.importance_weights),
]
self.is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, observation_space, action_space, self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions, loss=self.loss.loss,
loss_inputs=self.loss_inputs, is_training=self.is_training)
loss_inputs=self.loss_inputs)
self.sess.run(tf.global_variables_initializer())
def optimizer(self):
@@ -41,16 +41,10 @@ class PGPolicyGraph(TFPolicyGraph):
("advantages", advantages),
]
# LSTM support
for i, ph in enumerate(self.model.state_in):
loss_in.append(("state_in_{}".format(i), ph))
is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, obs_space, action_space, sess, obs_input=obs,
action_sampler=action_dist.sample(), loss=loss,
loss_inputs=loss_in, is_training=is_training,
state_inputs=self.model.state_in,
loss_inputs=loss_in, state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
seq_lens=self.model.seq_lens,
max_seq_len=config["model"]["max_seq_len"])
+1 -1
View File
@@ -50,7 +50,7 @@ DEFAULT_CONFIG = with_common_config({
"simple_optimizer": False,
# Override model config
"model": {
# Use LSTM model (note: requires simple optimizer for now).
# Whether to use LSTM model
"use_lstm": False,
# Max seq length for LSTM training.
"max_seq_len": 20,
+18 -21
View File
@@ -92,9 +92,10 @@ class PPOPolicyGraph(TFPolicyGraph):
dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space)
if existing_inputs:
self.loss_in = existing_inputs
obs_ph, value_targets_ph, adv_ph, act_ph, \
logits_ph, vf_preds_ph = [ph for _, ph in existing_inputs]
logits_ph, vf_preds_ph = existing_inputs[:6]
existing_state_in = existing_inputs[6:-1]
existing_seq_lens = existing_inputs[-1]
else:
obs_ph = tf.placeholder(
tf.float32, name="obs", shape=(None,)+observation_space.shape)
@@ -107,23 +108,20 @@ class PPOPolicyGraph(TFPolicyGraph):
tf.float32, name="vf_preds", shape=(None,))
value_targets_ph = tf.placeholder(
tf.float32, name="value_targets", shape=(None,))
existing_state_in = None
existing_seq_lens = None
self.loss_in = [
("obs", obs_ph),
("value_targets", value_targets_ph),
("advantages", adv_ph),
("actions", act_ph),
("logits", logits_ph),
("vf_preds", vf_preds_ph),
]
self.loss_in = [
("obs", obs_ph),
("value_targets", value_targets_ph),
("advantages", adv_ph),
("actions", act_ph),
("logits", logits_ph),
("vf_preds", vf_preds_ph),
]
self.model = ModelCatalog.get_model(
obs_ph, logit_dim, self.config["model"])
# LSTM support
if not existing_inputs:
for i, ph in enumerate(self.model.state_in):
self.loss_in.append(("state_in_{}".format(i), ph))
obs_ph, logit_dim, self.config["model"],
state_in=existing_state_in, seq_lens=existing_seq_lens)
# KL Coefficient
self.kl_coeff = tf.get_variable(
@@ -155,15 +153,14 @@ class PPOPolicyGraph(TFPolicyGraph):
clip_param=self.config["clip_param"],
vf_loss_coeff=self.config["kl_target"],
use_gae=self.config["use_gae"])
self.is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, observation_space, action_space,
self.sess, obs_input=obs_ph,
action_sampler=self.sampler, loss=self.loss_obj.loss,
loss_inputs=self.loss_in, is_training=self.is_training,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out, seq_lens=self.model.seq_lens)
loss_inputs=self.loss_in, state_inputs=self.model.state_in,
state_outputs=self.model.state_out, seq_lens=self.model.seq_lens,
max_seq_len=config["model"]["max_seq_len"])
self.sess.run(tf.global_variables_initializer())
+12 -11
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import ray
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -36,9 +37,8 @@ class TFPolicyGraph(PolicyGraph):
def __init__(
self, observation_space, action_space, sess, obs_input,
action_sampler, loss, loss_inputs, is_training,
state_inputs=None, state_outputs=None, seq_lens=None,
max_seq_len=20):
action_sampler, loss, loss_inputs, state_inputs=None,
state_outputs=None, seq_lens=None, max_seq_len=20):
"""Initialize the policy graph.
Arguments:
@@ -54,10 +54,8 @@ class TFPolicyGraph(PolicyGraph):
input argument. Each placeholder name must correspond to a
SampleBatch column key returned by postprocess_trajectory(),
and has shape [BATCH_SIZE, data...].
is_training (Tensor): input placeholder for whether we are
currently training the policy.
state_inputs (list): list of RNN state output Tensors.
state_outputs (list): list of initial state values.
state_inputs (list): list of RNN state input Tensors.
state_outputs (list): list of RNN state output Tensors.
seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
models/lstm.py for more information.
@@ -72,9 +70,11 @@ class TFPolicyGraph(PolicyGraph):
self._loss = loss
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
self._is_training = is_training
self._is_training = tf.placeholder_with_default(True, ())
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
for i, ph in enumerate(self._state_inputs):
self._loss_input_dict["state_in_{}".format(i)] = ph
self._seq_lens = seq_lens
self._max_seq_len = max_seq_len
self._optimizer = self.optimizer()
@@ -99,6 +99,8 @@ class TFPolicyGraph(PolicyGraph):
(self._state_inputs, state_batches)
builder.add_feed_dict(self.extra_compute_action_feed_dict())
builder.add_feed_dict({self._obs_input: obs_batch})
if state_batches:
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
builder.add_feed_dict({self._is_training: is_training})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
fetches = builder.add_fetches(
@@ -123,10 +125,9 @@ class TFPolicyGraph(PolicyGraph):
return feed_dict
# RNN case
feature_keys = [
k for k, v in self._loss_inputs if not k.startswith("state_in_")]
feature_keys = [k for k, v in self._loss_inputs]
state_keys = [
k for k, v in self._loss_inputs if k.startswith("state_in_")]
"state_in_{}".format(i) for i in range(len(self._state_inputs))]
feature_sequences, initial_states, seq_lens = chop_into_sequences(
batch["t"],
[batch[k] for k in feature_keys],
+13 -7
View File
@@ -138,41 +138,47 @@ class ModelCatalog(object):
" not supported".format(action_space))
@staticmethod
def get_model(inputs, num_outputs, options=None):
def get_model(
inputs, num_outputs, options=None, state_in=None, seq_lens=None):
"""Returns a suitable model conforming to given input and output specs.
Args:
inputs (Tensor): The input tensor to the model.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
state_in (list): Optional RNN state in tensors.
seq_in (Tensor): Optional RNN sequence length tensor.
Returns:
model (Model): Neural network model.
"""
options = options or {}
model = ModelCatalog._get_model(inputs, num_outputs, options)
model = ModelCatalog._get_model(
inputs, num_outputs, options, state_in, seq_lens)
if options.get("use_lstm"):
model = LSTM(model.last_layer, num_outputs, options)
model = LSTM(
model.last_layer, num_outputs, options, state_in, seq_lens)
return model
@staticmethod
def _get_model(inputs, num_outputs, options):
def _get_model(inputs, num_outputs, options, state_in, seq_lens):
if "custom_model" in options:
model = options["custom_model"]
print("Using custom model {}".format(model))
return _global_registry.get(RLLIB_MODEL, model)(
inputs, num_outputs, options)
inputs, num_outputs, options,
state_in=state_in, seq_lens=seq_lens)
obs_rank = len(inputs.shape) - 1
# num_outputs > 1 used to avoid hitting this with the value function
if isinstance(options.get("custom_options", {}).get(
"multiagent_fcnet_hiddens", 1), list) and num_outputs > 1:
return MultiAgentFullyConnectedNetwork(inputs,
num_outputs, options)
return MultiAgentFullyConnectedNetwork(
inputs, num_outputs, options)
if obs_rank > 1:
return VisionNetwork(inputs, num_outputs, options)
+9 -4
View File
@@ -41,8 +41,8 @@ def add_time_dimension(padded_inputs, seq_lens):
# Sequence lengths have to be specified for LSTM batch inputs. The
# input batch must be padded to the max seq length given here. That is,
# batch_size == len(seq_lens) * max(seq_lens)
max_seq_len = tf.reduce_max(seq_lens)
padded_batch_size = tf.shape(padded_inputs)[0]
max_seq_len = padded_batch_size // tf.shape(seq_lens)[0]
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
@@ -155,9 +155,14 @@ class LSTM(Model):
np.zeros(lstm.state_size.h, np.float32)]
# Setup LSTM inputs
c_in = tf.placeholder(tf.float32, [None, lstm.state_size.c], name="c")
h_in = tf.placeholder(tf.float32, [None, lstm.state_size.h], name="h")
self.state_in = [c_in, h_in]
if self.state_in:
c_in, h_in = self.state_in
else:
c_in = tf.placeholder(
tf.float32, [None, lstm.state_size.c], name="c")
h_in = tf.placeholder(
tf.float32, [None, lstm.state_size.h], name="h")
self.state_in = [c_in, h_in]
# Setup LSTM outputs
if use_tf100_api:
+8 -6
View File
@@ -37,17 +37,19 @@ class Model(object):
a scale parameter (like a standard deviation).
"""
def __init__(self, inputs, num_outputs, options):
def __init__(
self, inputs, num_outputs, options, state_in=None, seq_lens=None):
self.inputs = inputs
# Default attribute values for the non-RNN case
self.state_init = []
self.state_in = []
self.state_in = state_in or []
self.state_out = []
self.seq_lens = tf.placeholder_with_default(
tf.ones( # reshape needed for older tf versions
tf.reshape(tf.shape(inputs)[0], [1]), dtype=tf.int32),
[None], name="seq_lens")
if seq_lens is not None:
self.seq_lens = seq_lens
else:
self.seq_lens = tf.placeholder(
dtype=tf.int32, shape=[None], name="seq_lens")
if options.get("free_log_std", False):
assert num_outputs % 2 == 0
+68 -65
View File
@@ -3,9 +3,7 @@ from __future__ import division
from __future__ import print_function
from collections import namedtuple
import os
from tensorflow.python.client import timeline
import tensorflow as tf
@@ -34,9 +32,11 @@ class LocalSyncParallelOptimizer(object):
Args:
optimizer: Delegate TensorFlow optimizer object.
devices: List of the names of TensorFlow devices to parallelize over.
input_placeholders: List of (name, input_placeholder)
for the loss function. Tensors of these shapes will be passed
to build_graph() in order to define the per-device loss ops.
input_placeholders: List of input_placeholders for the loss function.
Tensors of these shapes will be passed to build_graph() in order
to define the per-device loss ops.
rnn_inputs: Extra input placeholders for RNN inputs. These will have
shape [BATCH_SIZE // MAX_SEQ_LEN, ...].
per_device_batch_size: Number of tuples to optimize over at a time per
device. In each call to `optimize()`,
`len(devices) * per_device_batch_size` tuples of data will be
@@ -47,7 +47,7 @@ class LocalSyncParallelOptimizer(object):
grad_norm_clipping: None or int stdev to clip grad norms by
"""
def __init__(self, optimizer, devices, input_placeholders,
def __init__(self, optimizer, devices, input_placeholders, rnn_inputs,
per_device_batch_size, build_graph, logdir,
grad_norm_clipping=None):
# TODO(rliaw): remove logdir
@@ -55,27 +55,31 @@ class LocalSyncParallelOptimizer(object):
self.devices = devices
self.batch_size = per_device_batch_size * len(devices)
self.per_device_batch_size = per_device_batch_size
self.loss_inputs = input_placeholders
self.loss_inputs = input_placeholders + rnn_inputs
self.build_graph = build_graph
self.logdir = logdir
# First initialize the shared loss network
with tf.name_scope(TOWER_SCOPE_NAME):
self._shared_loss = build_graph(input_placeholders)
self._shared_loss = build_graph(self.loss_inputs)
# Then setup the per-device loss graphs that use the shared weights
self._batch_index = tf.placeholder(tf.int32)
self._batch_index = tf.placeholder(tf.int32, name="batch_index")
# When loading RNN input, we dynamically determine the max seq len
self._max_seq_len = tf.placeholder(tf.int32, name="max_seq_len")
self._loaded_max_seq_len = 1
# Split on the CPU in case the data doesn't fit in GPU memory.
with tf.device("/cpu:0"):
names, placeholders = zip(*input_placeholders)
data_splits = zip(
*[tf.split(ph, len(devices)) for ph in placeholders])
*[tf.split(ph, len(devices)) for ph in self.loss_inputs])
self._towers = []
for device, device_placeholders in zip(self.devices, data_splits):
self._towers.append(
self._setup_device(device, zip(names, device_placeholders)))
self._setup_device(
device, device_placeholders, len(input_placeholders)))
avg = average_gradients([t.grads for t in self._towers])
if grad_norm_clipping:
@@ -84,7 +88,7 @@ class LocalSyncParallelOptimizer(object):
avg[i] = (tf.clip_by_norm(grad, grad_norm_clipping), var)
self._train_op = self.optimizer.apply_gradients(avg)
def load_data(self, sess, inputs, full_trace=False):
def load_data(self, sess, inputs, state_inputs):
"""Bulk loads the specified inputs into device memory.
The shape of the inputs must conform to the shapes of the input
@@ -95,37 +99,47 @@ class LocalSyncParallelOptimizer(object):
Args:
sess: TensorFlow session.
inputs: List of Tensors matching the input placeholders specified
at construction time of this optimizer.
full_trace: Whether to profile data loading.
inputs: List of arrays matching the input placeholders, of shape
[BATCH_SIZE, ...].
state_inputs: List of RNN input arrays. These arrays have size
[BATCH_SIZE / MAX_SEQ_LEN, ...].
Returns:
The number of tuples loaded per device.
"""
feed_dict = {}
assert len(self.loss_inputs) == len(inputs)
for (name, ph), arr in zip(self.loss_inputs, inputs):
truncated_arr = make_divisible_by(arr, self.batch_size)
feed_dict[ph] = truncated_arr
truncated_len = len(truncated_arr)
assert len(self.loss_inputs) == len(inputs + state_inputs), \
(self.loss_inputs, inputs, state_inputs)
if full_trace:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
# The RNN truncation case is more complicated
if len(state_inputs) > 0:
seq_len = len(inputs[0]) // len(state_inputs[0])
self._loaded_max_seq_len = seq_len
assert len(state_inputs[0]) * seq_len == len(inputs[0])
# Make sure the shorter state inputs arrays are evenly divisible
state_inputs = [
make_divisible_by(arr, self.batch_size)
for arr in state_inputs
]
# Then truncate the data inputs to match
inputs = [
arr[:len(state_inputs[0]) * seq_len]
for arr in inputs
]
assert len(state_inputs[0]) * seq_len == len(inputs[0])
assert len(state_inputs[0]) % self.batch_size == 0
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
feed_dict[ph] = arr
truncated_len = len(inputs[0])
else:
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
run_metadata = tf.RunMetadata()
for ph, arr in zip(self.loss_inputs, inputs + state_inputs):
truncated_arr = make_divisible_by(arr, self.batch_size)
feed_dict[ph] = truncated_arr
truncated_len = len(truncated_arr)
sess.run(
[t.init_op for t in self._towers],
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata)
if full_trace:
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file = open(os.path.join(self.logdir, "timeline-load.json"),
"w")
trace_file.write(trace.generate_chrome_trace_format())
[t.init_op for t in self._towers], feed_dict=feed_dict)
tuples_per_device = truncated_len / len(self.devices)
assert tuples_per_device > 0, \
@@ -136,7 +150,7 @@ class LocalSyncParallelOptimizer(object):
assert tuples_per_device % self.per_device_batch_size == 0
return tuples_per_device
def optimize(self, sess, batch_index, file_writer=None):
def optimize(self, sess, batch_index):
"""Run a single step of SGD.
Runs a SGD step over a slice of the preloaded batch with size given by
@@ -151,19 +165,14 @@ class LocalSyncParallelOptimizer(object):
batch_index: Offset into the preloaded data. This value must be
between `0` and `tuples_per_device`. The amount of data to
process is always fixed to `per_device_batch_size`.
file_writer: If specified, tf metrics will be written out using
this.
Returns:
The outputs of extra_ops evaluated over the batch.
"""
if file_writer:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
else:
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
run_metadata = tf.RunMetadata()
feed_dict = {self._batch_index: batch_index}
feed_dict = {
self._batch_index: batch_index,
self._max_seq_len: self._loaded_max_seq_len,
}
for tower in self._towers:
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
feed_dict.update(tower.loss_graph.extra_apply_grad_feed_dict())
@@ -173,21 +182,7 @@ class LocalSyncParallelOptimizer(object):
fetches.update(tower.loss_graph.extra_compute_grad_fetches())
fetches.update(tower.loss_graph.extra_apply_grad_fetches())
outs = sess.run(
fetches,
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata)
if file_writer:
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"),
"w")
trace_file.write(trace.generate_chrome_trace_format())
file_writer.add_run_metadata(
run_metadata, "sgd_train_{}".format(batch_index))
return outs
return sess.run(fetches, feed_dict=feed_dict)
def get_common_loss(self):
return self._shared_loss
@@ -195,23 +190,31 @@ class LocalSyncParallelOptimizer(object):
def get_device_losses(self):
return [t.loss_graph for t in self._towers]
def _setup_device(self, device, device_input_placeholders):
def _setup_device(self, device, device_input_placeholders, num_data_in):
assert num_data_in <= len(device_input_placeholders)
with tf.device(device):
with tf.name_scope(TOWER_SCOPE_NAME):
device_input_batches = []
device_input_slices = []
for name, ph in device_input_placeholders:
for i, ph in enumerate(device_input_placeholders):
current_batch = tf.Variable(
ph, trainable=False, validate_shape=False,
collections=[])
device_input_batches.append(current_batch)
if i < num_data_in:
scale = self._max_seq_len
granularity = self._max_seq_len
else:
scale = self._max_seq_len
granularity = 1
current_slice = tf.slice(
current_batch,
[self._batch_index] + [0] * len(ph.shape[1:]),
([self.per_device_batch_size] + [-1] *
len(ph.shape[1:])))
([self._batch_index // scale * granularity] +
[0] * len(ph.shape[1:])),
([self.per_device_batch_size // scale * granularity] +
[-1] * len(ph.shape[1:])))
current_slice.set_shape(ph.shape)
device_input_slices.append((name, current_slice))
device_input_slices.append(current_slice)
graph_obj = self.build_graph(device_input_slices)
device_grads = graph_obj.gradients(self.optimizer)
return Tower(
@@ -55,12 +55,12 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
print("LocalMultiGPUOptimizer devices", self.devices)
assert set(self.local_evaluator.policy_map.keys()) == {"default"}, \
"Multi-agent is not supported"
("Multi-agent is not supported with multi-GPU. Try using the "
"simple optimizer instead.")
self.policy = self.local_evaluator.policy_map["default"]
assert isinstance(self.policy, TFPolicyGraph), \
"Only TF policies are supported"
assert len(self.policy.get_initial_state()) == 0, \
"No RNN support yet for multi-gpu. Try the simple optimizer."
("Only TF policies are supported with multi-GPU. Try using the "
"simple optimizer instead.")
# per-GPU graph copies created below must share vars with the policy
# reuse is set to AUTO_REUSE because Adam nodes are created after
@@ -68,10 +68,16 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
with self.local_evaluator.tf_sess.graph.as_default():
with self.local_evaluator.tf_sess.as_default():
with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
if self.policy._state_inputs:
rnn_inputs = self.policy._state_inputs + [
self.policy._seq_lens]
else:
rnn_inputs = []
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.sgd_stepsize),
self.devices,
self.policy.loss_inputs(),
[v for _, v in self.policy.loss_inputs()],
rnn_inputs,
self.per_device_batch_size,
self.policy.copy,
os.getcwd())
@@ -103,9 +109,17 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
samples.shuffle()
with self.load_timer:
tuples = self.policy._get_loss_inputs_dict(samples)
data_keys = [ph for _, ph in self.policy.loss_inputs()]
if self.policy._state_inputs:
state_keys = (
self.policy._state_inputs + [self.policy._seq_lens])
else:
state_keys = []
tuples_per_device = self.par_opt.load_data(
self.sess,
samples.columns([key for key, _ in self.policy.loss_inputs()]))
[tuples[k] for k in data_keys],
[tuples[k] for k in state_keys])
with self.grad_timer:
num_batches = (