Files
ray/python/ray/rllib/evaluation/tf_policy_graph.py
T
Eric Liang d01dc9e22d [rllib] format with yapf (#2427)
* initial yapf

* manual fix yapf bugs
2018-07-19 15:30:36 -07:00

228 lines
9.1 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import ray
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.tf_run_builder import TFRunBuilder
class TFPolicyGraph(PolicyGraph):
    """An agent policy and loss implemented in TensorFlow.

    Extending this class enables RLlib to perform TensorFlow specific
    optimizations on the policy graph, e.g., parallelization across gpus or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    Subclasses build their observation input, action sampler and loss
    tensors first and then pass them to ``TFPolicyGraph.__init__``, which
    creates the optimizer, the gradient tensors and the apply op from them.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.

    Examples:
        >>> policy = TFPolicyGraphSubclass(
        ...     observation_space, action_space, sess, obs_input,
        ...     action_sampler, loss, loss_inputs)

        >>> print(policy.compute_actions([1, 0, 2]))
        (array([0, 1, 1]), [], {})

        >>> print(policy.postprocess_trajectory(SampleBatch({...})))
        SampleBatch({"action": ..., "advantages": ..., ...})
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 sess,
                 obs_input,
                 action_sampler,
                 loss,
                 loss_inputs,
                 state_inputs=None,
                 state_outputs=None,
                 seq_lens=None,
                 max_seq_len=20):
        """Initialize the policy graph.

        Arguments:
            observation_space (gym.Space): Observation space of the env.
            action_space (gym.Space): Action space of the env.
            sess (Session): TensorFlow session to use.
            obs_input (Tensor): input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            action_sampler (Tensor): Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss (Tensor): scalar policy loss output tensor.
            loss_inputs (list): list of (name, placeholder) tuples, one for
                each loss input argument. Each name must correspond to a
                SampleBatch column key returned by postprocess_trajectory(),
                and each placeholder has shape [BATCH_SIZE, data...].
            state_inputs (list): list of RNN state input Tensors.
            state_outputs (list): list of RNN state output Tensors.
            seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
                models/lstm.py for more information. Required whenever
                state_inputs is non-empty.
            max_seq_len (int): max sequence length for LSTM training.

        Raises:
            AssertionError: if the number of state inputs, state outputs and
                initial states (from get_initial_state()) differ, or if
                state_inputs are given without seq_lens.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self._sess = sess
        self._obs_input = obs_input
        self._sampler = action_sampler
        self._loss = loss
        self._loss_inputs = loss_inputs
        self._loss_input_dict = dict(self._loss_inputs)
        # Scalar flag fed by the build_* methods below; evaluates to True
        # whenever it is not fed explicitly.
        self._is_training = tf.placeholder_with_default(True, ())
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        # Register RNN state placeholders under the "state_in_<i>" keys that
        # _get_loss_inputs_dict() looks up in the sample batch.
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len
        # optimizer() and gradients() are overridable hooks invoked here, so
        # any attributes they depend on must exist before this point.
        self._optimizer = self.optimizer()
        # Variables that receive no gradient (g is None) are excluded from
        # both the gradient fetches and the apply op.
        self._grads_and_vars = [(g, v)
                                for (g, v) in self.gradients(self._optimizer)
                                if g is not None]
        self._grads = [g for (g, _) in self._grads_and_vars]
        self._apply_op = self._optimizer.apply_gradients(self._grads_and_vars)
        self._variables = ray.experimental.TensorFlowVariables(
            self._loss, self._sess)

        assert len(self._state_inputs) == len(self._state_outputs) == \
            len(self.get_initial_state()), \
            (self._state_inputs, self._state_outputs, self.get_initial_state())
        if self._state_inputs:
            assert self._seq_lens is not None

    def build_compute_actions(self,
                              builder,
                              obs_batch,
                              state_batches=None,
                              is_training=False):
        """Register the feeds and fetches needed to compute actions.

        Arguments:
            builder (TFRunBuilder): run builder to add feeds and fetches to.
            obs_batch: batch of observations fed to the observation input.
            state_batches (list): one batch of RNN state per state input, or
                None when the policy has no state inputs.
            is_training (bool): value fed to the is_training placeholder.

        Returns:
            Tuple (action fetch, list of state-output fetches, extra-fetches
            fetch) of handles from ``builder.add_fetches``; resolve them with
            ``builder.get()``.

        Raises:
            AssertionError: if the number of state batches does not match
                the number of state inputs.
        """
        state_batches = state_batches or []
        assert len(self._state_inputs) == len(state_batches), \
            (self._state_inputs, state_batches)
        builder.add_feed_dict(self.extra_compute_action_feed_dict())
        builder.add_feed_dict({self._obs_input: obs_batch})
        if state_batches:
            # Each observation is evaluated as its own length-1 sequence.
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        builder.add_feed_dict({self._is_training: is_training})
        builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
        fetches = builder.add_fetches([self._sampler] + self._state_outputs +
                                      [self.extra_compute_action_fetches()])
        return fetches[0], fetches[1:-1], fetches[-1]

    def compute_actions(self, obs_batch, state_batches=None,
                        is_training=False):
        """Compute actions for a batch of observations.

        Uses a fresh TFRunBuilder; see build_compute_actions() for the
        arguments.

        Returns:
            Tuple (actions, list of state outputs, extra fetches) as
            resolved by ``TFRunBuilder.get``.
        """
        builder = TFRunBuilder(self._sess, "compute_actions")
        fetches = self.build_compute_actions(builder, obs_batch, state_batches,
                                             is_training)
        return builder.get(fetches)

    def _get_loss_inputs_dict(self, batch):
        """Build the feed dict mapping loss input placeholders to batch data.

        Without RNN state, each (name, placeholder) pair in loss_inputs is
        fed ``batch[name]`` directly. With RNN state, the feature columns and
        the "state_in_<i>" columns are passed through chop_into_sequences()
        (together with ``batch["t"]`` and max_seq_len), and the returned
        sequence lengths are fed to the seq_lens placeholder.
        """
        feed_dict = {}

        # Simple case
        if not self._state_inputs:
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        # RNN case
        feature_keys = [k for k, _ in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch["t"], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys], self._max_seq_len)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens
        return feed_dict

    def build_compute_gradients(self, builder, postprocessed_batch):
        """Register the feeds and fetches needed to compute gradients.

        Returns:
            Tuple (gradients fetch, extra compute-grad fetches fetch) of
            handles from ``builder.add_fetches``. The gradients follow the
            order of the non-None gradients selected in __init__.
        """
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        fetches = builder.add_fetches(
            [self._grads, self.extra_compute_grad_fetches()])
        return fetches[0], fetches[1]

    def compute_gradients(self, postprocessed_batch):
        """Compute gradients of the loss on a postprocessed sample batch.

        Returns:
            Tuple (gradients, extra compute-grad fetches), resolved.
        """
        builder = TFRunBuilder(self._sess, "compute_gradients")
        fetches = self.build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)

    def build_apply_gradients(self, builder, gradients):
        """Register the feeds and fetches needed to apply given gradients.

        The supplied values are fed in place of the gradient tensors, in the
        same order as returned by compute_gradients().

        Returns:
            Handle for the extra apply-grad fetches; the apply op is fetched
            but its result is not returned.

        Raises:
            AssertionError: if len(gradients) differs from the number of
                gradient tensors.
        """
        assert len(gradients) == len(self._grads), (gradients, self._grads)
        builder.add_feed_dict(self.extra_apply_grad_feed_dict())
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(dict(zip(self._grads, gradients)))
        fetches = builder.add_fetches(
            [self._apply_op, self.extra_apply_grad_fetches()])
        return fetches[1]

    def apply_gradients(self, gradients):
        """Apply externally computed gradients to the policy variables.

        Returns:
            The resolved extra apply-grad fetches.
        """
        builder = TFRunBuilder(self._sess, "apply_gradients")
        fetches = self.build_apply_gradients(builder, gradients)
        return builder.get(fetches)

    def build_compute_apply(self, builder, postprocessed_batch):
        """Register feeds and fetches to compute and apply gradients at once.

        Gradients are not fed here: the apply op consumes the gradient
        tensors derived from the loss, which is evaluated on the fed loss
        inputs within the same run.

        Returns:
            Tuple (extra compute-grad fetches, extra apply-grad fetches) of
            fetch handles.
        """
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(self.extra_apply_grad_feed_dict())
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        builder.add_feed_dict({self._is_training: True})
        fetches = builder.add_fetches([
            self._apply_op,
            self.extra_compute_grad_fetches(),
            self.extra_apply_grad_fetches()
        ])
        return fetches[1], fetches[2]

    def compute_apply(self, postprocessed_batch):
        """Compute and apply gradients for a postprocessed sample batch.

        Returns:
            Tuple (extra compute-grad fetches, extra apply-grad fetches),
            resolved.
        """
        builder = TFRunBuilder(self._sess, "compute_apply")
        fetches = self.build_compute_apply(builder, postprocessed_batch)
        return builder.get(fetches)

    def get_weights(self):
        """Return the policy weights via TensorFlowVariables.get_flat()."""
        return self._variables.get_flat()

    def set_weights(self, weights):
        """Set the policy weights via TensorFlowVariables.set_flat()."""
        return self._variables.set_flat(weights)

    def extra_compute_action_feed_dict(self):
        """Extra feeds for computing actions; override in subclasses."""
        return {}

    def extra_compute_action_fetches(self):
        """Extra fetches when computing actions; override in subclasses."""
        return {}  # e.g., value function

    def extra_compute_grad_feed_dict(self):
        """Extra feeds for computing gradients; override in subclasses."""
        return {}  # e.g., kl_coeff

    def extra_compute_grad_fetches(self):
        """Extra fetches when computing gradients; override in subclasses."""
        return {}  # e.g., td error

    def extra_apply_grad_feed_dict(self):
        """Extra feeds for applying gradients; override in subclasses."""
        return {}

    def extra_apply_grad_fetches(self):
        """Extra fetches when applying gradients; override in subclasses."""
        return {}  # e.g., batch norm updates

    def optimizer(self):
        """Return the optimizer; called once from __init__.

        Defaults to ``tf.train.AdamOptimizer()`` with its default arguments.
        """
        return tf.train.AdamOptimizer()

    def gradients(self, optimizer):
        """Return (gradient, variable) pairs for the loss.

        Pairs whose gradient is None are dropped by __init__.
        """
        return optimizer.compute_gradients(self._loss)

    def loss_inputs(self):
        """Return the (name, placeholder) list given to the constructor."""
        return self._loss_inputs