mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:40:09 +08:00
[rllib] Move a3c implementation from examples/ to python/ray/rllib/ (#698)
* rllib v0 * fix imports * lint * comments * update docs * a3c wip * a3c wip * report stats * update doc * name is too long * fix small bug * propagate exception on error * fetch metrics * fix lint
This commit is contained in:
committed by
Philipp Moritz
parent
efce49cfbc
commit
2d81edfcdc
@@ -0,0 +1,112 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.rnn as rnn
|
||||
import distutils.version
|
||||
|
||||
from ray.rllib.a3c.policy import (
|
||||
categorical_sample, conv2d, linear, flatten,
|
||||
normalized_columns_initializer, Policy)
|
||||
|
||||
# True when the installed TensorFlow reports version 1.0.0 or newer. The code
# below uses this flag to choose which namespace the LSTM classes are read
# from.
use_tf100_api = (distutils.version.LooseVersion("1.0.0") <=
                 distutils.version.LooseVersion(tf.VERSION))
|
||||
|
||||
|
||||
class LSTMPolicy(Policy):
    """Policy whose graph is four convolution layers followed by one LSTM.

    Two linear heads on the LSTM output produce the action logits and the
    value estimate. The LSTM state is fed through explicit placeholders and
    returned as tensors, so callers carry it from one call to the next.
    """
    def setup_graph(self, ob_space, ac_space):
        """Setup model used for Policy.

        In this A3C implementation, both the Critic and the Actor share the
        model.

        Args:
            ob_space: Sequence of observation dimensions; it is appended to
                a leading ``None`` dimension to form the input placeholder.
            ac_space: Output size of the action head (also the depth of the
                one-hot sample).

        Sets ``x``, ``state_size``, ``state_init``, ``state_in``, ``logits``,
        ``vf``, ``state_out``, ``sample``, ``var_list`` and ``global_step``.
        """
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

        # Four ELU-activated convolutions named "l1".."l4", each with 32
        # filters, a 3x3 kernel and stride 2.
        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # Introduce a "fake" batch dimension of 1 after flatten so that we can do
        # LSTM over the time dim.
        x = tf.expand_dims(flatten(x), [0])

        size = 256
        # The LSTM cell class is read from a different namespace depending on
        # the TensorFlow version flag computed at import time.
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        # One-element tensor holding the size of the first dimension of the
        # input; passed to dynamic_rnn as the sequence length of the single
        # fake batch entry.
        step_size = tf.shape(self.x)[:1]

        # All-zero initial cell (c) and hidden (h) state for a batch of one.
        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        # Placeholders through which callers feed the LSTM state.
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]

        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
        lstm_c, lstm_h = lstm_state
        # Drop the fake batch dimension again: one row of `size` features per
        # input row.
        x = tf.reshape(lstm_outputs, [-1, size])
        self.logits = linear(x, ac_space, "action",
                             normalized_columns_initializer(0.01))
        # Value head output reshaped to a flat vector.
        self.vf = tf.reshape(linear(x, 1, "value",
                                    normalized_columns_initializer(1.0)), [-1])
        # First row of the final cell and hidden state.
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        # Only row 0 of the one-hot sampled actions is exposed.
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        # Trainable variables created under the current variable scope.
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.global_step = tf.get_variable(
            "global_step", [], tf.int32,
            initializer=tf.constant_initializer(0, dtype=tf.int32),
            trainable=False)

    def get_gradients(self, batch):
        """Computing the gradient is actually model-dependent.

        The LSTM needs its hidden states in order to compute the gradient
        accurately.

        Args:
            batch: Object with ``si``, ``a``, ``adv``, ``r`` and ``features``
                attributes; ``features[0]`` and ``features[1]`` are fed as
                the LSTM input state.

        Returns:
            The result of evaluating ``self.grads`` in the session.
        """
        feed_dict = {
            self.x: batch.si,
            self.ac: batch.a,
            self.adv: batch.adv,
            self.r: batch.r,
            self.state_in[0]: batch.features[0],
            self.state_in[1]: batch.features[1]
        }
        # Counts how many times gradients have been requested.
        self.local_steps += 1
        return self.sess.run(self.grads, feed_dict=feed_dict)

    def act(self, ob, c, h):
        """Evaluate the policy on one observation.

        Args:
            ob: A single observation; it is wrapped in a list before feeding.
            c: Value fed to the first LSTM state placeholder.
            h: Value fed to the second LSTM state placeholder.

        Returns:
            A list with the sampled action, the value output and the two
            output LSTM state tensors, in that order.
        """
        return self.sess.run([self.sample, self.vf] + self.state_out,
                             {self.x: [ob],
                              self.state_in[0]: c,
                              self.state_in[1]: h})

    def value(self, ob, c, h):
        """Return the first element of the value output for one observation."""
        return self.sess.run(self.vf, {self.x: [ob],
                                       self.state_in[0]: c,
                                       self.state_in[1]: h})[0]

    def get_initial_features(self):
        """Return the all-zero initial LSTM state ``[c_init, h_init]``."""
        return self.state_init
|
||||
|
||||
|
||||
class RawLSTMPolicy(LSTMPolicy):
    """LSTMPolicy variant that keeps its weights in a cached Python object.

    Weights are read from ``self.variables`` once and then held in
    ``self._weights``; updates are applied to that cached object directly
    with a fixed step size instead of going through the TensorFlow optimizer.
    """

    def get_weights(self):
        """Return the cached weights, loading them on first use."""
        if not hasattr(self, "_weights"):
            self._weights = self.variables.get_weights()
        return self._weights

    def set_weights(self, weights):
        """Replace the cached weights with ``weights``."""
        self._weights = weights

    def model_update(self, grads):
        """Apply one plain gradient step to the cached weights, in place.

        Args:
            grads: Gradients ordered like ``self.var_list``.
        """
        # Go through get_weights() so the cache is populated even when
        # neither get_weights() nor set_weights() has been called yet;
        # reading self._weights directly would raise AttributeError then.
        weights = self.get_weights()
        for var, grad in zip(self.var_list, grads):
            # The key is the variable name without its last two characters
            # (presumably the ":0" output suffix -- confirm against the key
            # format used by self.variables.get_weights()).
            weights[var.name[:-2]] -= 1e-4 * grad
|
||||
@@ -0,0 +1,3 @@
|
||||
from ray.rllib.a3c.a3c import A3C, DEFAULT_CONFIG
|
||||
|
||||
# Names exported by this package: the A3C algorithm class and its default
# configuration dictionary.
__all__ = ["A3C", "DEFAULT_CONFIG"]
|
||||
@@ -0,0 +1,126 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import six.moves.queue as queue
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.rllib.a3c.LSTM import LSTMPolicy
|
||||
from ray.rllib.a3c.runner import RunnerThread, process_rollout
|
||||
from ray.rllib.a3c.envs import create_env
|
||||
from ray.rllib.common import Algorithm, TrainingResult
|
||||
|
||||
|
||||
# Default A3C settings.
DEFAULT_CONFIG = dict(
    # Number of remote Runner actors created by A3C.
    num_workers=4,
    # Bound checked by A3C.train() before it hands a runner another
    # gradient task within the same iteration.
    num_batches_per_iteration=100,
)
|
||||
|
||||
|
||||
@ray.remote
class Runner(object):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.
    """
    def __init__(self, env_name, actor_id, logdir="/tmp/ray/a3c/", start=True):
        """Create the environment, the policy and the rollout thread.

        Args:
            env_name: Identifier passed to ``create_env``.
            actor_id: Identifier of this runner; stored as ``self.id`` and
                used in the summary directory name.
            logdir: Directory under which the summary writer logs.
            start: When true, the rollout thread is started immediately.
        """
        env = create_env(env_name)
        self.id = actor_id
        num_actions = env.action_space.n
        self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
                                 actor_id)
        # 20 is the number of environment steps per partial rollout.
        self.runner = RunnerThread(env, self.policy, 20)
        self.env = env
        self.logdir = logdir
        if start:
            self.start()

    def pull_batch_from_queue(self):
        """Take a rollout from the queue of the thread runner.

        Blocks up to 600 seconds for the first piece, then appends any
        further pieces that are already queued until the rollout is terminal
        or the queue is empty.

        Raises:
            BaseException: Whatever exception object the runner thread put
                on the queue.
        """
        rollout = self.runner.queue.get(timeout=600.0)
        if isinstance(rollout, BaseException):
            raise rollout
        while not rollout.terminal:
            try:
                part = self.runner.queue.get_nowait()
            except queue.Empty:
                break
            if isinstance(part, BaseException):
                # Re-raise the exception that was queued, not the rollout
                # collected so far (which is not an exception at all).
                raise part
            rollout.extend(part)
        return rollout

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        completed = []
        while True:
            try:
                completed.append(self.runner.metrics_queue.get_nowait())
            except queue.Empty:
                break
        return completed

    def start(self):
        """Create the summary writer and start the rollout thread."""
        summary_writer = tf.summary.FileWriter(
            os.path.join(self.logdir, "agent_%d" % self.id))
        self.summary_writer = summary_writer
        self.runner.start_runner(self.policy.sess, summary_writer)

    def compute_gradient(self, params):
        """Load ``params`` into the policy and compute one gradient.

        Returns:
            A ``(gradient, info)`` pair where ``info`` holds this runner's
            ``id`` and the number of actions in the batch (``size``).
        """
        self.policy.set_weights(params)
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
        gradient = self.policy.get_gradients(batch)
        info = {"id": self.id,
                "size": len(batch.a)}
        return gradient, info
|
||||
|
||||
|
||||
class A3C(Algorithm):
    """Driver that applies gradients computed by remote Runner actors."""

    def __init__(self, env_name, config):
        """Create the local policy and ``config["num_workers"]`` runners."""
        Algorithm.__init__(self, env_name, config)
        self.env = create_env(env_name)
        self.policy = LSTMPolicy(
            self.env.observation_space.shape, self.env.action_space.n, 0)
        self.agents = [
            Runner.remote(env_name, i) for i in range(config["num_workers"])]
        self.parameters = self.policy.get_weights()
        self.iteration = 0

    def train(self):
        """Run one training iteration and return its TrainingResult.

        Every runner starts with one gradient task. Each time a gradient
        arrives it is applied to the local policy, and the runner that
        produced it receives the updated weights for another task while the
        batch count is below ``config["num_batches_per_iteration"]``.
        """
        gradient_list = [
            agent.compute_gradient.remote(self.parameters)
            for agent in self.agents]
        max_batches = self.config["num_batches_per_iteration"]
        batches_so_far = len(gradient_list)
        while gradient_list:
            done_id, gradient_list = ray.wait(gradient_list)
            gradient, info = ray.get(done_id)[0]
            self.policy.model_update(gradient)
            self.parameters = self.policy.get_weights()
            if batches_so_far < max_batches:
                batches_so_far += 1
                # Hand the next task to the runner that just finished.
                runner = self.agents[info["id"]]
                gradient_list.append(
                    runner.compute_gradient.remote(self.parameters))
        res = self.fetch_metrics_from_workers()
        self.iteration += 1
        return res

    def fetch_metrics_from_workers(self):
        """Collect completed-episode metrics from all runners.

        Returns:
            TrainingResult built from the current iteration number, the mean
            episode reward and the mean episode length. Both means are NaN
            when no episode completed since the last call.
        """
        episode_rewards = []
        episode_lengths = []
        metric_lists = [
            a.get_completed_rollout_metrics.remote() for a in self.agents]
        for metrics in metric_lists:
            for episode in ray.get(metrics):
                episode_lengths.append(episode.episode_length)
                episode_rewards.append(episode.episode_reward)
        # np.mean of an empty list also yields NaN but emits a RuntimeWarning;
        # produce the NaN explicitly when no episode has finished yet.
        nan = float("nan")
        mean_reward = np.mean(episode_rewards) if episode_rewards else nan
        mean_length = np.mean(episode_lengths) if episode_lengths else nan
        res = TrainingResult(self.iteration, mean_reward, mean_length)
        return res
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import cv2
|
||||
import gym
|
||||
from gym.spaces.box import Box
|
||||
import logging
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def create_env(env_id):
    """Build the gym environment ``env_id`` with both wrappers applied.

    The environment is wrapped in AtariProcessing first and in Diagnostic
    second, so the returned object is the Diagnostic wrapper.
    """
    return Diagnostic(AtariProcessing(gym.make(env_id)))
|
||||
|
||||
|
||||
def _process_frame42(frame):
    """Crop, shrink and rescale a frame into a (42, 42, 1) float32 array."""
    first_row = 34
    cropped = frame[first_row:(first_row + 160), :160]
    # Resize by half, then down to 42x42 (essentially mipmapping). Resizing
    # directly would lose pixels that, when mapped to 42x42, aren't close
    # enough to the pixel boundary.
    halved = cv2.resize(cropped, (80, 80))
    small = cv2.resize(halved, (42, 42))
    # Average over axis 2, then scale the values by 1/255.
    averaged = small.mean(2).astype(np.float32)
    averaged *= (1.0 / 255.0)
    return np.reshape(averaged, [42, 42, 1])
|
||||
|
||||
|
||||
class AtariProcessing(gym.ObservationWrapper):
    """Observation wrapper that passes frames through _process_frame42."""

    def __init__(self, env=None):
        super(AtariProcessing, self).__init__(env)
        # Processed frames are 42x42x1 with values declared in [0.0, 1.0].
        frame_shape = [42, 42, 1]
        self.observation_space = Box(0.0, 1.0, frame_shape)

    def _observation(self, observation):
        """Return the processed version of ``observation``."""
        processed = _process_frame42(observation)
        return processed
|
||||
|
||||
|
||||
class Diagnostic(gym.Wrapper):
    """Wrapper that routes reset and step results through a DiagnosticsLogger."""

    def __init__(self, env=None):
        super(Diagnostic, self).__init__(env)
        self.diagnostics = DiagnosticsLogger()

    def _reset(self):
        """Reset the wrapped environment and notify the logger."""
        first_observation = self.env.reset()
        return self.diagnostics._after_reset(first_observation)

    def _step(self, action):
        """Step the wrapped environment; return what the logger returns."""
        step_result = self.env.step(action)
        return self.diagnostics._after_step(*step_result)
|
||||
|
||||
|
||||
class DiagnosticsLogger(object):
    """Tracks per-episode reward, length and wall-clock time.

    ``_after_step`` returns the step tuple with its info entry replaced by a
    dict of episode statistics; the dict is empty unless the episode ended
    on that step.
    """

    def __init__(self, log_interval=503):
        """Initialize counters; ``log_interval`` is measured in steps."""
        self._episode_time = time.time()
        self._last_time = time.time()
        self._local_t = 0
        self._log_interval = log_interval
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []
        self._last_episode_id = -1

    def _after_reset(self, observation):
        """Clear the per-episode counters and return ``observation``."""
        logger.info("Resetting environment")
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []
        return observation

    @staticmethod
    def _reward_rate(episode_reward, elapsed):
        """Return reward per second, or 0.0 when no time has elapsed.

        An episode that ends within the clock's resolution has an elapsed
        time of zero; dividing by it would raise ZeroDivisionError.
        """
        if elapsed > 0:
            return episode_reward / elapsed
        return 0.0

    def _after_step(self, observation, reward, done, info):
        """Update the counters for one environment step.

        Args:
            observation: Observation returned by the environment.
            reward: Reward returned by the environment (may be None).
            done: Whether the episode ended on this step.
            info: Info returned by the environment; it is not used and is
                replaced in the return value.

        Returns:
            ``(observation, reward, done, to_log)`` where ``to_log`` holds
            the ``global/*`` episode statistics when ``done`` is true and is
            empty otherwise.
        """
        to_log = {}
        # The episode clock starts at the first step of each episode.
        if self._episode_length == 0:
            self._episode_time = time.time()

        self._local_t += 1

        if self._local_t % self._log_interval == 0:
            self._last_time = time.time()

        if reward is not None:
            self._episode_reward += reward
            if observation is not None:
                self._episode_length += 1
            self._all_rewards.append(reward)

        if done:
            logger.info("Episode terminating: episode_reward=%s episode_length=%s",
                        self._episode_reward, self._episode_length)
            total_time = time.time() - self._episode_time
            to_log["global/episode_reward"] = self._episode_reward
            to_log["global/episode_length"] = self._episode_length
            to_log["global/episode_time"] = total_time
            to_log["global/reward_per_time"] = self._reward_rate(
                self._episode_reward, total_time)
            self._episode_reward = 0
            self._episode_length = 0
            self._all_rewards = []

        return observation, reward, done, to_log
|
||||
Executable
+32
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from ray.rllib.a3c import A3C, DEFAULT_CONFIG
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse options, start Ray, then train forever
    # and print the result of every iteration.
    parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
    parser.add_argument("--environment", default="PongDeterministic-v3",
                        type=str, help="The gym environment to use.")
    parser.add_argument("--redis-address", default=None, type=str,
                        help="The Redis address of the cluster.")
    parser.add_argument("--num-workers", default=4, type=int,
                        help="The number of A3C workers to use.")

    args = parser.parse_args()
    ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)

    # Copy the defaults so the shared module-level dict is left untouched.
    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = args.num_workers

    a3c = A3C(args.environment, config)

    while True:
        res = a3c.train()
        print("current status: {}".format(res))
|
||||
@@ -0,0 +1,145 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import ray
|
||||
|
||||
|
||||
class Policy(object):
    """The policy base class."""
    def __init__(self, ob_space, ac_space, task, name="local"):
        """Build the graph, the loss and the session for this policy.

        Args:
            ob_space: Forwarded to ``setup_graph``.
            ac_space: Forwarded to ``setup_graph`` and, as the number of
                actions, to ``setup_loss``.
            task: Not used in this method.
            name: Variable scope under which ``setup_graph`` runs.
        """
        self.local_steps = 0
        worker_device = "/job:localhost/replica:0/task:0/cpu:0"
        # Every policy owns a private graph.
        self.g = tf.Graph()
        with self.g.as_default(), tf.device(worker_device):
            with tf.variable_scope(name):
                self.setup_graph(ob_space, ac_space)
                # setup_graph must have defined these attributes; setup_loss
                # reads them.
                assert all([hasattr(self, attr)
                            for attr in ["vf", "logits", "x", "var_list"]])
            # NOTE(review): the source view lost its indentation; the loss
            # setup and initialization are placed inside the graph/device
            # context but outside the variable scope -- confirm.
            print("Setting up loss")
            self.setup_loss(ac_space)
            self.initialize()

    def setup_graph(self):
        """Build the model; must be overridden.

        NOTE(review): ``__init__`` calls this with ``(ob_space, ac_space)``
        and LSTMPolicy overrides it with that signature, so this base
        signature does not match how the method is invoked.
        """
        raise NotImplementedError

    def setup_loss(self, num_actions, summarize=True):
        """Create the loss, the clipped gradients and the apply-gradients op.

        Args:
            num_actions: Width of the action placeholder ``self.ac``.
            summarize: When true, register scalar and image summaries and
                store the merged op in ``self.summary_op``.
        """
        # Placeholders fed when computing gradients: actions, advantages and
        # returns.
        self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac")
        self.adv = tf.placeholder(tf.float32, [None], name="adv")
        self.r = tf.placeholder(tf.float32, [None], name="r")

        log_prob_tf = tf.nn.log_softmax(self.logits)
        prob_tf = tf.nn.softmax(self.logits)

        # The "policy gradients" loss: its derivative is precisely the policy
        # gradient. Notice that self.ac is a placeholder that is provided
        # externally. adv will contain the advantages, as calculated in
        # process_rollout.
        pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac,
                                                [1]) * self.adv)

        # loss of value function
        vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r))
        # tf.Print passes vf_loss through unchanged and prints its value
        # whenever the op is evaluated.
        vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss")
        entropy = - tf.reduce_sum(prob_tf * log_prob_tf)

        # Size of the first input dimension as a float; used only to scale
        # the summaries below.
        bs = tf.to_float(tf.shape(self.x)[0])
        self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01

        grads = tf.gradients(self.loss, self.var_list)
        # Clip by global norm (40.0); the pre-clip norm is discarded.
        self.grads, _ = tf.clip_by_global_norm(grads, 40.0)

        grads_and_vars = list(zip(self.grads, self.var_list))
        opt = tf.train.AdamOptimizer(1e-4)
        self._apply_gradients = opt.apply_gradients(grads_and_vars)

        if summarize:
            tf.summary.scalar("model/policy_loss", pi_loss / bs)
            tf.summary.scalar("model/value_loss", vf_loss / bs)
            tf.summary.scalar("model/entropy", entropy / bs)
            tf.summary.image("model/state", self.x)
            self.summary_op = tf.summary.merge_all()

    def initialize(self):
        """Create the session and variable helper, then initialize variables."""
        self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
            intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
        self.variables = ray.experimental.TensorFlowVariables(self.loss, self.sess)
        self.sess.run(tf.global_variables_initializer())

    def model_update(self, grads):
        """Apply externally computed gradients through the optimizer.

        Each entry of ``grads`` is fed in place of the matching
        ``self.grads`` tensor before the apply-gradients op runs.
        """
        feed_dict = {self.grads[i]: grads[i]
                     for i in range(len(grads))}
        self.sess.run(self._apply_gradients, feed_dict=feed_dict)

    def get_weights(self):
        """Return the weights reported by ``self.variables``."""
        weights = self.variables.get_weights()
        return weights

    def set_weights(self, weights):
        """Load ``weights`` through ``self.variables``."""
        self.variables.set_weights(weights)

    def get_gradients(self, batch):
        """Compute gradients for ``batch``; must be overridden."""
        raise NotImplementedError

    def get_vf_loss(self):
        """Must be overridden."""
        raise NotImplementedError

    def act(self, ob):
        """Choose an action for ``ob``; must be overridden.

        NOTE(review): LSTMPolicy overrides this as ``act(self, ob, c, h)``.
        """
        raise NotImplementedError

    def value(self, ob):
        """Estimate the value of ``ob``; must be overridden.

        NOTE(review): LSTMPolicy overrides this as ``value(self, ob, c, h)``.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
def normalized_columns_initializer(std=1.0):
    """Return an initializer producing random values normalized along axis 0.

    The values are drawn from a standard normal distribution and rescaled so
    that the L2 norm taken over axis 0 equals ``std``. The ``dtype`` and
    ``partition_info`` arguments of the returned function are accepted but
    not used; the result is always float32.
    """
    def _initializer(shape, dtype=None, partition_info=None):
        values = np.random.randn(*shape).astype(np.float32)
        norms = np.sqrt(np.square(values).sum(axis=0, keepdims=True))
        values *= std / norms
        return tf.constant(values)

    return _initializer
|
||||
|
||||
|
||||
def flatten(x):
    """Reshape ``x`` to 2-D, merging every dimension after the first."""
    trailing_dims = x.get_shape().as_list()[1:]
    return tf.reshape(x, [-1, np.prod(trailing_dims)])
|
||||
|
||||
|
||||
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME",
           dtype=tf.float32, collections=None):
    """Apply a 2-D convolution plus bias with variables scoped under ``name``.

    The kernel "W" is drawn uniformly from ``[-w_bound, w_bound]`` with
    ``w_bound = sqrt(6 / (fan_in + fan_out))``; the bias "b" starts at zero.
    """
    with tf.variable_scope(name):
        strides = [1, stride[0], stride[1], 1]
        in_channels = int(x.get_shape()[3])
        kernel_shape = [filter_size[0], filter_size[1], in_channels,
                        num_filters]

        # Each hidden unit has "num input feature maps * filter height *
        # filter width" inputs.
        fan_in = np.prod(kernel_shape[:3])
        # Each unit in the lower layer receives a gradient from
        # "num output feature maps * filter height * filter width" /
        # pooling size.
        fan_out = np.prod(kernel_shape[:2]) * num_filters
        w_bound = np.sqrt(6.0 / (fan_in + fan_out))

        kernel = tf.get_variable(
            "W", kernel_shape, dtype,
            tf.random_uniform_initializer(-w_bound, w_bound),
            collections=collections)
        bias = tf.get_variable(
            "b", [1, 1, 1, num_filters],
            initializer=tf.constant_initializer(0.0),
            collections=collections)
        convolved = tf.nn.conv2d(x, kernel, strides, pad)
        return convolved + bias
|
||||
|
||||
|
||||
def linear(x, size, name, initializer=None, bias_init=0):
    """Return ``x @ w + b`` using variables named ``name + "/w"`` and ``"/b"``.

    ``w`` has shape ``[x.get_shape()[1], size]`` and uses ``initializer``;
    ``b`` has shape ``[size]`` and starts at the constant ``bias_init``.
    """
    input_dim = x.get_shape()[1]
    weights = tf.get_variable(name + "/w", [input_dim, size],
                              initializer=initializer)
    bias = tf.get_variable(name + "/b", [size],
                           initializer=tf.constant_initializer(bias_init))
    product = tf.matmul(x, weights)
    return product + bias
|
||||
|
||||
|
||||
def categorical_sample(logits, d):
    """Draw one sample per row of ``logits``, one-hot encoded with depth ``d``."""
    # Subtract the per-row maximum before sampling.
    row_max = tf.reduce_max(logits, [1], keep_dims=True)
    shifted = logits - row_max
    drawn = tf.multinomial(shifted, 1)
    value = tf.squeeze(drawn, [1])
    return tf.one_hot(value, d)
|
||||
@@ -0,0 +1,182 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import six.moves.queue as queue
|
||||
import scipy.signal
|
||||
import threading
|
||||
|
||||
|
||||
def discount(x, gamma):
    """Return the discounted cumulative sums of ``x`` along axis 0.

    Element ``t`` of the result equals ``x[t] + gamma * result[t + 1]``; it
    is computed by running a first-order filter over the reversed input and
    reversing the output again.
    """
    backwards = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, -gamma], backwards, axis=0)
    return filtered[::-1]
|
||||
|
||||
|
||||
def process_rollout(rollout, gamma, lambda_=1.0):
    """Given a rollout, compute its returns and the advantage.

    Args:
        rollout: Object with ``states``, ``actions``, ``rewards``, ``values``
            and ``features`` lists plus ``r`` and ``terminal``.
        gamma: Discount factor.
        lambda_: Extra factor multiplied into ``gamma`` when discounting the
            advantage terms.

    Returns:
        A Batch of states, actions, advantages, discounted returns, the
        terminal flag and the first entry of ``rollout.features``.
    """
    tail = [rollout.r]
    states = np.asarray(rollout.states)
    actions = np.asarray(rollout.actions)
    rewards = np.asarray(rollout.rewards)
    values = np.asarray(rollout.values + tail)

    # Discounted returns, with rollout.r appended as the final term and then
    # dropped from the result.
    returns = discount(np.asarray(rollout.rewards + tail), gamma)[:-1]
    residuals = rewards + gamma * values[1:] - values[:-1]
    # This formula for the advantage comes from "Generalized Advantage
    # Estimation": https://arxiv.org/abs/1506.02438
    advantages = discount(residuals, gamma * lambda_)

    first_features = rollout.features[0]
    return Batch(states, actions, advantages, returns, rollout.terminal,
                 first_features)
|
||||
|
||||
|
||||
# Training batch built by process_rollout: states, actions, advantages,
# discounted returns, the terminal flag and the initial features.
Batch = namedtuple(
    typename="Batch",
    field_names=["si", "a", "adv", "r", "terminal", "features"])

# Summary of one finished episode.
CompletedRollout = namedtuple(
    typename="CompletedRollout",
    field_names=["episode_length", "episode_reward"])
|
||||
|
||||
|
||||
class PartialRollout(object):
    """A piece of a complete rollout.

    The agent is run for a number of steps and the experience gathered so
    far is then processed; several pieces can be joined with ``extend``.
    """

    def __init__(self):
        self.states, self.actions = [], []
        self.rewards, self.values = [], []
        self.r = 0.0
        self.terminal = False
        self.features = []

    def add(self, state, action, reward, value, terminal, features):
        """Record one step; ``terminal`` overwrites the rollout's flag."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.features.append(features)
        self.terminal = terminal

    def extend(self, other):
        """Append the steps of ``other`` and adopt its ``r`` and ``terminal``.

        A rollout that is already terminal must not be extended.
        """
        assert not self.terminal
        for field in ("states", "actions", "rewards", "values", "features"):
            getattr(self, field).extend(getattr(other, field))
        self.r = other.r
        self.terminal = other.terminal
|
||||
|
||||
|
||||
class RunnerThread(threading.Thread):
    """This thread interacts with the environment and tells it what to do."""

    def __init__(self, env, policy, num_local_steps, visualise=False):
        super(RunnerThread, self).__init__()
        self.env = env
        self.policy = policy
        self.num_local_steps = num_local_steps
        self.visualise = visualise
        # Rollout pieces go to a bounded queue (5 entries); completed-episode
        # metrics go to an unbounded one.
        self.queue = queue.Queue(5)
        self.metrics_queue = queue.Queue()
        self.last_features = None
        self.daemon = True
        # Both are supplied later through start_runner().
        self.sess = None
        self.summary_writer = None

    def start_runner(self, sess, summary_writer):
        """Store the session and summary writer, then start the thread."""
        self.sess = sess
        self.summary_writer = summary_writer
        self.start()

    def run(self):
        """Thread body: run the loop and forward any failure to the queue."""
        try:
            with self.sess.as_default():
                self._run()
        except BaseException as error:
            # Hand the exception to whoever reads the queue before
            # re-raising it in this thread.
            self.queue.put(error)
            raise error

    def _run(self):
        """Pull items from env_runner forever and dispatch them to a queue."""
        provider = env_runner(
            self.env, self.policy, self.num_local_steps,
            self.summary_writer, self.visualise)
        while True:
            produced = next(provider)
            if not isinstance(produced, CompletedRollout):
                # The timeout exists because apparently, if one worker dies,
                # the other workers won't die with it, unless the timeout is
                # set to some large number. This is an empirical observation.
                self.queue.put(produced, timeout=600.0)
            else:
                self.metrics_queue.put(produced)
|
||||
|
||||
|
||||
def env_runner(env, policy, num_local_steps, summary_writer, render):
    """This implements the logic of the thread runner.

    It continually runs the policy and, after at most ``num_local_steps``
    steps or when an episode ends, yields the collected rollout so that the
    thread runner can place it on its queue.

    Args:
        env: Environment providing ``reset``, ``step``, ``render``, ``spec``
            and ``metadata``.
        policy: Object providing ``act``, ``value`` and
            ``get_initial_features``.
        num_local_steps: Maximum number of steps per yielded rollout.
        summary_writer: Writer that receives the per-step ``info`` values.
        render: When true, ``env.render()`` is called after every step.

    Yields:
        ``CompletedRollout(length, rewards)`` when an episode ends, and a
        ``PartialRollout`` after every batch of steps.
    """
    last_state = env.reset()
    # May be None when the environment declares no step limit.
    timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
                                       ".max_episode_steps")
    last_features = policy.get_initial_features()
    length = 0
    rewards = 0
    rollout_number = 0

    def limit_reached(steps):
        # Comparing an int with None raises TypeError on Python 3 and is
        # always true on Python 2, so a missing limit means "no limit".
        return timestep_limit is not None and steps >= timestep_limit

    while True:
        terminal_end = False
        rollout = PartialRollout()

        for _ in range(num_local_steps):
            fetched = policy.act(last_state, *last_features)
            action, value_, features = fetched[0], fetched[1], fetched[2:]
            # Argmax to convert from one-hot.
            state, reward, terminal, info = env.step(action.argmax())
            if render:
                env.render()

            length += 1
            rewards += reward
            if limit_reached(length):
                terminal = True

            # Collect the experience.
            rollout.add(last_state, action, reward, value_, terminal,
                        last_features)

            last_state = state
            last_features = features

            if info:
                summary = tf.Summary()
                for k, v in info.items():
                    summary.value.add(tag=k, simple_value=float(v))
                summary_writer.add_summary(summary, rollout_number)
                summary_writer.flush()

            if terminal:
                terminal_end = True
                yield CompletedRollout(length, rewards)

                # Reset explicitly unless the environment resets itself, or
                # when the episode was cut off by the step limit.
                if (limit_reached(length) or
                        not env.metadata.get("semantics.autoreset")):
                    last_state = env.reset()
                last_features = policy.get_initial_features()
                rollout_number += 1
                length = 0
                rewards = 0
                break

        # A rollout that did not end an episode carries the policy's value
        # estimate for the state it stopped in.
        if not terminal_end:
            rollout.r = policy.value(last_state, *last_features)

        # Once we have enough experience, yield it, and have the ThreadRunner
        # place it on a queue.
        yield rollout
|
||||
Reference in New Issue
Block a user