[rllib] Move a3c implementation from examples/ to python/ray/rllib/ (#698)

* rllib v0

* fix imports

* lint

* comments

* update docs

* a3c wip

* a3c wip

* report stats

* update doc

* name is too long

* fix small bug

* propagate exception on error

* fetch metrics

* fix lint
This commit is contained in:
Eric Liang
2017-06-29 08:49:56 -07:00
committed by Philipp Moritz
parent efce49cfbc
commit 2d81edfcdc
10 changed files with 199 additions and 133 deletions
+112
View File
@@ -0,0 +1,112 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version
from ray.rllib.a3c.policy import (
categorical_sample, conv2d, linear, flatten,
normalized_columns_initializer, Policy)
use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.0.0"))
class LSTMPolicy(Policy):
def setup_graph(self, ob_space, ac_space):
"""Setup model used for Policy.
In this A3C implementation, both the Critic and the Actor share the model.
"""
self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
for i in range(4):
x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
# Introduce a "fake" batch dimension of 1 after flatten so that we can do
# LSTM over the time dim.
x = tf.expand_dims(flatten(x), [0])
size = 256
if use_tf100_api:
lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
else:
lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
self.state_size = lstm.state_size
step_size = tf.shape(self.x)[:1]
c_init = np.zeros((1, lstm.state_size.c), np.float32)
h_init = np.zeros((1, lstm.state_size.h), np.float32)
self.state_init = [c_init, h_init]
c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
self.state_in = [c_in, h_in]
if use_tf100_api:
state_in = rnn.LSTMStateTuple(c_in, h_in)
else:
state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
lstm, x, initial_state=state_in, sequence_length=step_size,
time_major=False)
lstm_c, lstm_h = lstm_state
x = tf.reshape(lstm_outputs, [-1, size])
self.logits = linear(x, ac_space, "action",
normalized_columns_initializer(0.01))
self.vf = tf.reshape(linear(x, 1, "value",
normalized_columns_initializer(1.0)), [-1])
self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
self.sample = categorical_sample(self.logits, ac_space)[0, :]
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
self.global_step = tf.get_variable(
"global_step", [], tf.int32,
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)
def get_gradients(self, batch):
"""Computing the gradient is actually model-dependent.
The LSTM needs its hidden states in order to compute the gradient
accurately.
"""
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.state_in[0]: batch.features[0],
self.state_in[1]: batch.features[1]
}
self.local_steps += 1
return self.sess.run(self.grads, feed_dict=feed_dict)
def act(self, ob, c, h):
return self.sess.run([self.sample, self.vf] + self.state_out,
{self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})
def value(self, ob, c, h):
return self.sess.run(self.vf, {self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})[0]
def get_initial_features(self):
return self.state_init
class RawLSTMPolicy(LSTMPolicy):
def get_weights(self):
if not hasattr(self, "_weights"):
self._weights = self.variables.get_weights()
return self._weights
def set_weights(self, weights):
self._weights = weights
def model_update(self, grads):
for var, grad in zip(self.var_list, grads):
self._weights[var.name[:-2]] -= 1e-4 * grad
+3
View File
@@ -0,0 +1,3 @@
from ray.rllib.a3c.a3c import A3C, DEFAULT_CONFIG
__all__ = ["A3C", "DEFAULT_CONFIG"]
+126
View File
@@ -0,0 +1,126 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import six.moves.queue as queue
import os
import ray
from ray.rllib.a3c.LSTM import LSTMPolicy
from ray.rllib.a3c.runner import RunnerThread, process_rollout
from ray.rllib.a3c.envs import create_env
from ray.rllib.common import Algorithm, TrainingResult
DEFAULT_CONFIG = {
"num_workers": 4,
"num_batches_per_iteration": 100,
}
@ray.remote
class Runner(object):
"""Actor object to start running simulation on workers.
The gradient computation is also executed from this object.
"""
def __init__(self, env_name, actor_id, logdir="/tmp/ray/a3c/", start=True):
env = create_env(env_name)
self.id = actor_id
num_actions = env.action_space.n
self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
actor_id)
self.runner = RunnerThread(env, self.policy, 20)
self.env = env
self.logdir = logdir
if start:
self.start()
def pull_batch_from_queue(self):
"""Take a rollout from the queue of the thread runner."""
rollout = self.runner.queue.get(timeout=600.0)
if isinstance(rollout, BaseException):
raise rollout
while not rollout.terminal:
try:
part = self.runner.queue.get_nowait()
if isinstance(part, BaseException):
raise rollout
rollout.extend(part)
except queue.Empty:
break
return rollout
def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
Calling this clears the queue of completed rollout metrics.
"""
completed = []
while True:
try:
completed.append(self.runner.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
def start(self):
summary_writer = tf.summary.FileWriter(
os.path.join(self.logdir, "agent_%d" % self.id))
self.summary_writer = summary_writer
self.runner.start_runner(self.policy.sess, summary_writer)
def compute_gradient(self, params):
self.policy.set_weights(params)
rollout = self.pull_batch_from_queue()
batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
gradient = self.policy.get_gradients(batch)
info = {"id": self.id,
"size": len(batch.a)}
return gradient, info
class A3C(Algorithm):
def __init__(self, env_name, config):
Algorithm.__init__(self, env_name, config)
self.env = create_env(env_name)
self.policy = LSTMPolicy(
self.env.observation_space.shape, self.env.action_space.n, 0)
self.agents = [
Runner.remote(env_name, i) for i in range(config["num_workers"])]
self.parameters = self.policy.get_weights()
self.iteration = 0
def train(self):
gradient_list = [
agent.compute_gradient.remote(self.parameters)
for agent in self.agents]
max_batches = self.config["num_batches_per_iteration"]
batches_so_far = len(gradient_list)
while gradient_list:
done_id, gradient_list = ray.wait(gradient_list)
gradient, info = ray.get(done_id)[0]
self.policy.model_update(gradient)
self.parameters = self.policy.get_weights()
if batches_so_far < max_batches:
batches_so_far += 1
gradient_list.extend(
[self.agents[info["id"]].compute_gradient.remote(self.parameters)])
res = self.fetch_metrics_from_workers()
self.iteration += 1
return res
def fetch_metrics_from_workers(self):
episode_rewards = []
episode_lengths = []
metric_lists = [
a.get_completed_rollout_metrics.remote() for a in self.agents]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
res = TrainingResult(
self.iteration, np.mean(episode_rewards), np.mean(episode_lengths))
return res
+107
View File
@@ -0,0 +1,107 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import gym
from gym.spaces.box import Box
import logging
import numpy as np
import time
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def create_env(env_id):
env = gym.make(env_id)
env = AtariProcessing(env)
env = Diagnostic(env)
return env
def _process_frame42(frame):
frame = frame[34:(34 + 160), :160]
# Resize by half, then down to 42x42 (essentially mipmapping). If we resize
# directly we lose pixels that, when mapped to 42x42, aren't close enough to
# the pixel boundary.
frame = cv2.resize(frame, (80, 80))
frame = cv2.resize(frame, (42, 42))
frame = frame.mean(2)
frame = frame.astype(np.float32)
frame *= (1.0 / 255.0)
frame = np.reshape(frame, [42, 42, 1])
return frame
class AtariProcessing(gym.ObservationWrapper):
def __init__(self, env=None):
super(AtariProcessing, self).__init__(env)
self.observation_space = Box(0.0, 1.0, [42, 42, 1])
def _observation(self, observation):
return _process_frame42(observation)
class Diagnostic(gym.Wrapper):
def __init__(self, env=None):
super(Diagnostic, self).__init__(env)
self.diagnostics = DiagnosticsLogger()
def _reset(self):
observation = self.env.reset()
return self.diagnostics._after_reset(observation)
def _step(self, action):
results = self.env.step(action)
return self.diagnostics._after_step(*results)
class DiagnosticsLogger(object):
def __init__(self, log_interval=503):
self._episode_time = time.time()
self._last_time = time.time()
self._local_t = 0
self._log_interval = log_interval
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
self._last_episode_id = -1
def _after_reset(self, observation):
logger.info("Resetting environment")
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
return observation
def _after_step(self, observation, reward, done, info):
to_log = {}
if self._episode_length == 0:
self._episode_time = time.time()
self._local_t += 1
if self._local_t % self._log_interval == 0:
cur_time = time.time()
self._last_time = cur_time
if reward is not None:
self._episode_reward += reward
if observation is not None:
self._episode_length += 1
self._all_rewards.append(reward)
if done:
logger.info("Episode terminating: episode_reward=%s episode_length=%s",
self._episode_reward, self._episode_length)
total_time = time.time() - self._episode_time
to_log["global/episode_reward"] = self._episode_reward
to_log["global/episode_length"] = self._episode_length
to_log["global/episode_time"] = total_time
to_log["global/reward_per_time"] = self._episode_reward / total_time
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
return observation, reward, done, to_log
+32
View File
@@ -0,0 +1,32 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ray
from ray.rllib.a3c import A3C, DEFAULT_CONFIG
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
parser.add_argument("--environment", default="PongDeterministic-v3",
type=str, help="The gym environment to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-workers", default=4, type=int,
help="The number of A3C workers to use>")
args = parser.parse_args()
ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)
config = DEFAULT_CONFIG.copy()
config["num_workers"] = args.num_workers
a3c = A3C(args.environment, config)
while True:
res = a3c.train()
print("current status: {}".format(res))
+145
View File
@@ -0,0 +1,145 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import ray
class Policy(object):
"""The policy base class."""
def __init__(self, ob_space, ac_space, task, name="local"):
self.local_steps = 0
worker_device = "/job:localhost/replica:0/task:0/cpu:0"
self.g = tf.Graph()
with self.g.as_default(), tf.device(worker_device):
with tf.variable_scope(name):
self.setup_graph(ob_space, ac_space)
assert all([hasattr(self, attr)
for attr in ["vf", "logits", "x", "var_list"]])
print("Setting up loss")
self.setup_loss(ac_space)
self.initialize()
def setup_graph(self):
raise NotImplementedError
def setup_loss(self, num_actions, summarize=True):
self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac")
self.adv = tf.placeholder(tf.float32, [None], name="adv")
self.r = tf.placeholder(tf.float32, [None], name="r")
log_prob_tf = tf.nn.log_softmax(self.logits)
prob_tf = tf.nn.softmax(self.logits)
# The "policy gradients" loss: its derivative is precisely the policy
# gradient. Notice that self.ac is a placeholder that is provided
# externally. adv will contain the advantages, as calculated in
# process_rollout.
pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac,
[1]) * self.adv)
# loss of value function
vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r))
vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss")
entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
bs = tf.to_float(tf.shape(self.x)[0])
self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01
grads = tf.gradients(self.loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
grads_and_vars = list(zip(self.grads, self.var_list))
opt = tf.train.AdamOptimizer(1e-4)
self._apply_gradients = opt.apply_gradients(grads_and_vars)
if summarize:
tf.summary.scalar("model/policy_loss", pi_loss / bs)
tf.summary.scalar("model/value_loss", vf_loss / bs)
tf.summary.scalar("model/entropy", entropy / bs)
tf.summary.image("model/state", self.x)
self.summary_op = tf.summary.merge_all()
def initialize(self):
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
self.variables = ray.experimental.TensorFlowVariables(self.loss, self.sess)
self.sess.run(tf.global_variables_initializer())
def model_update(self, grads):
feed_dict = {self.grads[i]: grads[i]
for i in range(len(grads))}
self.sess.run(self._apply_gradients, feed_dict=feed_dict)
def get_weights(self):
weights = self.variables.get_weights()
return weights
def set_weights(self, weights):
self.variables.set_weights(weights)
def get_gradients(self, batch):
raise NotImplementedError
def get_vf_loss(self):
raise NotImplementedError
def act(self, ob):
raise NotImplementedError
def value(self, ob):
raise NotImplementedError
def normalized_columns_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def flatten(x):
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME",
dtype=tf.float32, collections=None):
with tf.variable_scope(name):
stride_shape = [1, stride[0], stride[1], 1]
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]),
num_filters]
# There are "num input feature maps * filter height * filter width"
# inputs to each hidden unit.
fan_in = np.prod(filter_shape[:3])
# Each unit in the lower layer receives a gradient from:
# "num output feature maps * filter height * filter width" / pooling size.
fan_out = np.prod(filter_shape[:2]) * num_filters
# Initialize weights with random weights.
w_bound = np.sqrt(6 / (fan_in + fan_out))
w = tf.get_variable("W", filter_shape, dtype,
tf.random_uniform_initializer(-w_bound, w_bound),
collections=collections)
b = tf.get_variable("b", [1, 1, 1, num_filters],
initializer=tf.constant_initializer(0.0),
collections=collections)
return tf.nn.conv2d(x, w, stride_shape, pad) + b
def linear(x, size, name, initializer=None, bias_init=0):
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
initializer=initializer)
b = tf.get_variable(name + "/b", [size],
initializer=tf.constant_initializer(bias_init))
return tf.matmul(x, w) + b
def categorical_sample(logits, d):
value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1],
keep_dims=True),
1), [1])
return tf.one_hot(value, d)
+182
View File
@@ -0,0 +1,182 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import numpy as np
import tensorflow as tf
import six.moves.queue as queue
import scipy.signal
import threading
def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
def process_rollout(rollout, gamma, lambda_=1.0):
"""Given a rollout, compute its returns and the advantage."""
batch_si = np.asarray(rollout.states)
batch_a = np.asarray(rollout.actions)
rewards = np.asarray(rollout.rewards)
vpred_t = np.asarray(rollout.values + [rollout.r])
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
batch_r = discount(rewards_plus_v, gamma)[:-1]
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
# This formula for the advantage comes "Generalized Advantage Estimation":
# https://arxiv.org/abs/1506.02438
batch_adv = discount(delta_t, gamma * lambda_)
features = rollout.features[0]
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
features)
Batch = namedtuple(
"Batch", ["si", "a", "adv", "r", "terminal", "features"])
CompletedRollout = namedtuple(
"CompletedRollout", ["episode_length", "episode_reward"])
class PartialRollout(object):
"""A piece of a complete rollout.
We run our agent, and process its experience once it has processed enough
steps.
"""
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.r = 0.0
self.terminal = False
self.features = []
def add(self, state, action, reward, value, terminal, features):
self.states += [state]
self.actions += [action]
self.rewards += [reward]
self.values += [value]
self.terminal = terminal
self.features += [features]
def extend(self, other):
assert not self.terminal
self.states.extend(other.states)
self.actions.extend(other.actions)
self.rewards.extend(other.rewards)
self.values.extend(other.values)
self.r = other.r
self.terminal = other.terminal
self.features.extend(other.features)
class RunnerThread(threading.Thread):
"""This thread interacts with the environment and tells it what to do."""
def __init__(self, env, policy, num_local_steps, visualise=False):
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.env = env
self.last_features = None
self.policy = policy
self.daemon = True
self.sess = None
self.summary_writer = None
self.visualise = visualise
def start_runner(self, sess, summary_writer):
self.sess = sess
self.summary_writer = summary_writer
self.start()
def run(self):
try:
with self.sess.as_default():
self._run()
except BaseException as e:
self.queue.put(e)
raise e
def _run(self):
rollout_provider = env_runner(
self.env, self.policy, self.num_local_steps,
self.summary_writer, self.visualise)
while True:
# The timeout variable exists because apparently, if one worker dies, the
# other workers won't die with it, unless the timeout is set to some
# large number. This is an empirical observation.
item = next(rollout_provider)
if isinstance(item, CompletedRollout):
self.metrics_queue.put(item)
else:
self.queue.put(item, timeout=600.0)
def env_runner(env, policy, num_local_steps, summary_writer, render):
"""This implements the logic of the thread runner.
It continually runs the policy, and as long as the rollout exceeds a certain
length, the thread runner appends the policy to the queue.
"""
last_state = env.reset()
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
".max_episode_steps")
last_features = policy.get_initial_features()
length = 0
rewards = 0
rollout_number = 0
while True:
terminal_end = False
rollout = PartialRollout()
for _ in range(num_local_steps):
fetched = policy.act(last_state, *last_features)
action, value_, features = fetched[0], fetched[1], fetched[2:]
# Argmax to convert from one-hot.
state, reward, terminal, info = env.step(action.argmax())
if render:
env.render()
length += 1
rewards += reward
if length >= timestep_limit:
terminal = True
# Collect the experience.
rollout.add(last_state, action, reward, value_, terminal, last_features)
last_state = state
last_features = features
if info:
summary = tf.Summary()
for k, v in info.items():
summary.value.add(tag=k, simple_value=float(v))
summary_writer.add_summary(summary, rollout_number)
summary_writer.flush()
if terminal:
terminal_end = True
yield CompletedRollout(length, rewards)
if length >= timestep_limit or not env.metadata.get("semantics"
".autoreset"):
last_state = env.reset()
last_features = policy.get_initial_features()
rollout_number += 1
length = 0
rewards = 0
break
if not terminal_end:
rollout.r = policy.value(last_state, *last_features)
# Once we have enough experience, yield it, and have the ThreadRunner
# place it on a queue.
yield rollout