[rllib] PPO and A3C unification (#1253)

This commit is contained in:
Richard Liaw
2017-12-14 01:08:23 -08:00
committed by GitHub
parent 2f750e9ba7
commit c5c83a4465
19 changed files with 291 additions and 350 deletions
+9 -5
View File
@@ -8,8 +8,8 @@ import os
import ray
from ray.rllib.agent import Agent
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteRunner
from ray.rllib.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteA3CEvaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.tune.result import TrainingResult
@@ -49,7 +49,8 @@ class A3CAgent(Agent):
self.env.observation_space.shape)
self.rew_filter = get_filter(self.config["reward_filter"], ())
self.agents = [
RemoteRunner.remote(self.env_creator, self.config, self.logdir)
RemoteA3CEvaluator.remote(
self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.parameters = self.policy.get_weights()
@@ -105,6 +106,7 @@ class A3CAgent(Agent):
return result
def _save(self):
# TODO(rliaw): extend to also support saving worker state?
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
objects = [self.parameters, self.obs_filter, self.rew_filter]
@@ -118,6 +120,8 @@ class A3CAgent(Agent):
self.rew_filter = objects[2]
self.policy.set_weights(self.parameters)
# TODO(rliaw): augment to support LSTM
def compute_action(self, observation):
actions = self.policy.compute_action(observation)
return actions[0]
obs = self.obs_filter(observation, update=False)
action, info = self.policy.compute(obs)
return action
-35
View File
@@ -2,37 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import scipy.signal
from collections import namedtuple
def discount(x, gamma):
    """Return the discounted cumulative sum of ``x`` along axis 0.

    Entry ``t`` of the result equals ``x[t] + gamma * result[t + 1]``, with
    the last entry equal to ``x[-1]``.

    Args:
        x: Sequence or array that supports ``x[::-1]`` slicing.
        gamma: Discount factor applied per step.

    Returns:
        Array with the same leading length as ``x``.
    """
    # Running the recurrence y[n] = x[n] + gamma * y[n - 1] over the
    # time-reversed input accumulates from the end of the sequence; flipping
    # the output restores the original time order.
    backwards = x[::-1]
    accumulated = scipy.signal.lfilter([1], [1, -gamma], backwards, axis=0)
    return accumulated[::-1]
def process_rollout(rollout, reward_filter, gamma, lambda_=1.0):
    """Turn a rollout into a ``Batch`` with returns and advantages.

    TODO(rliaw): generalize this

    Args:
        rollout: Object exposing ``data`` (per-step lists under the keys
            "state", "action", "reward", "value" and "features"),
            ``last_r`` and ``is_terminal()``.
        reward_filter: Callable applied to each advantage entry; its return
            value replaces that entry.
        gamma: Discount factor.
        lambda_: GAE lambda parameter.

    Returns:
        A ``Batch`` of (states, actions, advantages, returns, terminal flag,
        first stored features entry).
    """
    data = rollout.data
    states = np.asarray(data["state"])
    actions = np.asarray(data["action"])
    rewards = np.asarray(data["reward"])

    # Both the value predictions and the rewards are extended with
    # ``last_r`` so the final step can bootstrap from it.
    values = np.asarray(data["value"] + [rollout.last_r])
    bootstrapped_rewards = np.asarray(data["reward"] + [rollout.last_r])
    returns = discount(bootstrapped_rewards, gamma)[:-1]

    # This formula for the advantage comes from
    # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
    td_residuals = rewards + gamma * values[1:] - values[:-1]
    advantages = discount(td_residuals, gamma * lambda_)
    for idx, raw_advantage in enumerate(advantages):
        advantages[idx] = reward_filter(raw_advantage)

    first_features = data["features"][0]
    terminal = rollout.is_terminal()
    return Batch(
        states, actions, advantages, returns, terminal, first_features)
def get_policy_cls(config):
if config["use_lstm"]:
@@ -45,7 +14,3 @@ def get_policy_cls(config):
from ray.rllib.a3c.shared_model import SharedModel
policy_cls = SharedModel
return policy_cls
Batch = namedtuple(
"Batch", ["si", "a", "adv", "r", "terminal", "features"])
+1 -1
View File
@@ -20,7 +20,7 @@ class Policy(object):
def compute_gradients(self, batch):
raise NotImplementedError
def compute_action(self, observations):
def compute(self, observations):
"""Compute action for a _single_ observation"""
raise NotImplementedError
+22 -18
View File
@@ -3,13 +3,15 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.a3c.common import process_rollout, get_policy_cls
from ray.rllib.envs import create_and_wrap
from ray.rllib.evaluator import Evaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.sampler import AsyncSampler
from ray.rllib.utils.process_rollout import process_rollout
class Runner(object):
class A3CEvaluator(Evaluator):
"""Actor object to start running simulation on workers.
The gradient computation is also executed from this object.
@@ -29,19 +31,16 @@ class Runner(object):
obs_filter = get_filter(
config["observation_filter"], env.observation_space.shape)
self.rew_filter = get_filter(config["reward_filter"], ())
self.sampler = AsyncSampler(env, self.policy, config["batch_size"],
obs_filter)
self.sampler = AsyncSampler(env, self.policy, obs_filter,
config["batch_size"])
self.logdir = logdir
def get_data(self):
def sample(self):
"""
Returns:
trajectory: trajectory information
obs_filter: Current state of observation filter
rew_filter: Current state of reward filter"""
rollout, obs_filter = self.sampler.get_data()
return rollout, obs_filter, self.rew_filter
trajectory (PartialRollout): Experience Samples from evaluator"""
rollout = self.sampler.get_data()
return rollout
def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
@@ -51,14 +50,19 @@ class Runner(object):
return self.sampler.get_metrics()
def compute_gradient(self):
rollout, obsf_snapshot = self.sampler.get_data()
batch = process_rollout(
rollout, self.rew_filter, gamma=0.99, lambda_=1.0)
gradient, info = self.policy.compute_gradients(batch)
info["obs_filter"] = obsf_snapshot
rollout = self.sampler.get_data()
obs_filter = self.sampler.get_obs_filter(flush=True)
traj = process_rollout(
rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True)
gradient, info = self.policy.compute_gradients(traj)
info["obs_filter"] = obs_filter
info["rew_filter"] = self.rew_filter
return gradient, info
def apply_gradient(self, grads):
self.policy.apply_gradients(grads)
def set_weights(self, params):
self.policy.set_weights(params)
@@ -70,4 +74,4 @@ class Runner(object):
self.sampler.update_obs_filter(obs_filter)
RemoteRunner = ray.remote(Runner)
RemoteA3CEvaluator = ray.remote(A3CEvaluator)
+8 -11
View File
@@ -10,7 +10,7 @@ from ray.rllib.models.catalog import ModelCatalog
class SharedModel(TFPolicy):
other_output = ["value"]
other_output = ["vf_preds"]
is_recurrent = False
def __init__(self, ob_space, ac_space, **kwargs):
@@ -35,13 +35,13 @@ class SharedModel(TFPolicy):
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)
def compute_gradients(self, batch):
def compute_gradients(self, trajectory):
info = {}
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.x: trajectory["observations"],
self.ac: trajectory["actions"],
self.adv: trajectory["advantages"],
self.r: trajectory["value_targets"],
}
self.grads = [g for g in self.grads if g is not None]
self.local_steps += 1
@@ -53,14 +53,11 @@ class SharedModel(TFPolicy):
grad = self.sess.run(self.grads, feed_dict=feed_dict)
return grad, info
def compute_action(self, ob, *args):
def compute(self, ob, *args):
action, vf = self.sess.run([self.sample, self.vf],
{self.x: [ob]})
return action[0], {"value": vf[0]}
return action[0], {"vf_preds": vf[0]}
def value(self, ob, *args):
vf = self.sess.run(self.vf, {self.x: [ob]})
return vf[0]
def get_initial_features(self):
return []
+11 -10
View File
@@ -18,7 +18,7 @@ class SharedModelLSTM(TFPolicy):
to be tracked).
"""
other_output = ["value", "features"]
other_output = ["vf_preds", "features"]
is_recurrent = True
def __init__(self, ob_space, ac_space, **kwargs):
@@ -48,19 +48,20 @@ class SharedModelLSTM(TFPolicy):
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)
def compute_gradients(self, batch):
def compute_gradients(self, trajectory):
"""Computing the gradient is actually model-dependent.
The LSTM needs its hidden states in order to compute the gradient
accurately.
"""
features = trajectory["features"][0]
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.state_in[0]: batch.features[0],
self.state_in[1]: batch.features[1]
self.x: trajectory["observations"],
self.ac: trajectory["actions"],
self.adv: trajectory["advantages"],
self.r: trajectory["value_targets"],
self.state_in[0]: features[0],
self.state_in[1]: features[1]
}
info = {}
self.local_steps += 1
@@ -72,11 +73,11 @@ class SharedModelLSTM(TFPolicy):
grad = self.sess.run(self.grads, feed_dict=feed_dict)
return grad, info
def compute_action(self, ob, c, h):
def compute(self, ob, c, h):
action, vf, c, h = self.sess.run(
[self.sample, self.vf] + self.state_out,
{self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
return action[0], {"value": vf[0], "features": (c, h)}
return action[0], {"vf_preds": vf[0], "features": (c, h)}
def value(self, ob, c, h):
vf = self.sess.run(self.vf, {self.x: [ob],
+3 -6
View File
@@ -14,7 +14,7 @@ from ray.rllib.models.catalog import ModelCatalog
class SharedTorchPolicy(TorchPolicy):
"""Assumes nonrecurrent."""
other_output = ["value"]
other_output = ["vf_preds"]
is_recurrent = False
def __init__(self, ob_space, ac_space, **kwargs):
@@ -26,14 +26,14 @@ class SharedTorchPolicy(TorchPolicy):
self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)
def compute_action(self, ob, *args):
def compute(self, ob, *args):
"""Should take in a SINGLE ob"""
with self.lock:
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
logits, values = self._model(ob)
samples = self._model.probs(logits).multinomial().squeeze()
values = values.squeeze(0)
return var_to_np(samples), {"value": var_to_np(values)}
return var_to_np(samples), {"vf_preds": var_to_np(values)}
def compute_logits(self, ob, *args):
with self.lock:
@@ -71,6 +71,3 @@ class SharedTorchPolicy(TorchPolicy):
overall_err = 0.5 * value_err + pi_err - entropy * 0.01
overall_err.backward()
torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
def get_initial_features(self):
return [None]
+1 -1
View File
@@ -92,7 +92,7 @@ class TFPolicy(Policy):
def compute_gradients(self, batch):
raise NotImplementedError
def compute_action(self, observations):
def compute(self, observation):
raise NotImplementedError
def value(self, ob):
-3
View File
@@ -73,6 +73,3 @@ class TorchPolicy(Policy):
This function regenerates the backward trace and
calculates the gradient."""
raise NotImplementedError
def get_initial_features(self):
return []
+12 -8
View File
@@ -8,19 +8,23 @@ import torch
from torch.autograd import Variable
def convert_batch(batch, has_features=False):
"""Convert batch from numpy to PT variable"""
states = Variable(torch.from_numpy(batch.si).float())
acs = Variable(torch.from_numpy(batch.a))
advs = Variable(torch.from_numpy(batch.adv.copy()).float())
def convert_batch(trajectory, has_features=False):
"""Convert trajectory from numpy to PT variable"""
states = Variable(torch.from_numpy(
trajectory["observations"]).float())
acs = Variable(torch.from_numpy(
trajectory["actions"]))
advs = Variable(torch.from_numpy(
trajectory["advantages"].copy()).float())
advs = advs.view(-1, 1)
rs = Variable(torch.from_numpy(batch.r.copy()).float())
rs = Variable(torch.from_numpy(
trajectory["value_targets"]).float())
rs = rs.view(-1, 1)
if has_features:
features = [Variable(torch.from_numpy(f))
for f in batch.features]
for f in trajectory["features"]]
else:
features = batch.features
features = trajectory["features"]
return states, acs, advs, rs, features
+11 -6
View File
@@ -10,9 +10,12 @@ from ray.rllib.models import ModelCatalog
class ProximalPolicyLoss(object):
other_output = ["vf_preds", "logprobs"]
is_recurrent = False
def __init__(
self, observation_space, action_space,
observations, returns, advantages, actions,
observations, value_targets, advantages, actions,
prev_logits, prev_vf_preds, logit_dim,
kl_coeff, distribution_class, config, sess):
assert (isinstance(action_space, gym.spaces.Discrete) or
@@ -55,11 +58,11 @@ class ProximalPolicyLoss(object):
# We use a huber loss here to be more robust against outliers,
# which seem to occur when the rollouts get longer (the variance
# scales superlinearly with the length of the rollout)
self.vf_loss1 = tf.square(self.value_function - returns)
self.vf_loss1 = tf.square(self.value_function - value_targets)
vf_clipped = prev_vf_preds + tf.clip_by_value(
self.value_function - prev_vf_preds,
-config["clip_param"], config["clip_param"])
self.vf_loss2 = tf.square(vf_clipped - returns)
self.vf_loss2 = tf.square(vf_clipped - value_targets)
self.vf_loss = tf.minimum(self.vf_loss1, self.vf_loss2)
self.mean_vf_loss = tf.reduce_mean(self.vf_loss)
self.loss = tf.reduce_mean(
@@ -82,9 +85,11 @@ class ProximalPolicyLoss(object):
self.policy_results = [
self.sampler, self.curr_logits, tf.constant("NA")]
def compute(self, observations):
return self.sess.run(self.policy_results,
feed_dict={self.observations: observations})
def compute(self, observation):
action, logprobs, vf = self.sess.run(
self.policy_results,
feed_dict={self.observations: [observation]})
return action[0], {"vf_preds": vf[0], "logprobs": logprobs[0]}
def loss(self):
return self.loss
+16 -13
View File
@@ -11,8 +11,9 @@ import tensorflow as tf
from tensorflow.python import debug as tf_debug
import ray
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult
from ray.rllib.agent import Agent
from ray.rllib.utils.filter import get_filter
from ray.rllib.ppo.runner import Runner, RemoteRunner
from ray.rllib.ppo.rollout import collect_samples
from ray.rllib.ppo.utils import shuffle
@@ -90,10 +91,10 @@ class PPOAgent(Agent):
self.global_step = 0
self.kl_coeff = self.config["kl_coeff"]
self.model = Runner(
self.env_creator, 1, self.config, self.logdir, False)
self.env_creator, self.config, self.logdir, False)
self.agents = [
RemoteRunner.remote(
self.env_creator, 1, self.config, self.logdir, True)
self.env_creator, self.config, self.logdir, True)
for _ in range(self.config["num_workers"])]
self.start_time = time.time()
if self.config["write_logs"]:
@@ -102,6 +103,9 @@ class PPOAgent(Agent):
else:
self.file_writer = None
self.saver = tf.train.Saver(max_to_keep=None)
self.obs_filter = get_filter(
self.config["observation_filter"],
self.model.env.observation_space.shape)
def _train(self):
agents = self.agents
@@ -114,11 +118,11 @@ class PPOAgent(Agent):
weights = ray.put(model.get_weights())
[a.load_weights.remote(weights) for a in agents]
trajectory, total_reward, traj_len_mean = collect_samples(
agents, config, self.model.observation_filter,
agents, config, self.obs_filter,
self.model.reward_filter)
print("total reward is ", total_reward)
print("trajectory length mean is ", traj_len_mean)
print("timesteps:", trajectory["dones"].shape[0])
print("timesteps:", trajectory["actions"].shape[0])
if self.file_writer:
traj_stats = tf.Summary(value=[
tf.Summary.Value(
@@ -135,10 +139,7 @@ class PPOAgent(Agent):
# to guard against the case where all values are equal
return (value - value.mean()) / max(1e-4, value.std())
if config["use_gae"]:
trajectory["advantages"] = standardized(trajectory["advantages"])
else:
trajectory["returns"] = standardized(trajectory["returns"])
trajectory["advantages"] = standardized(trajectory["advantages"])
rollouts_end = time.time()
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
@@ -238,7 +239,7 @@ class PPOAgent(Agent):
result = TrainingResult(
episode_reward_mean=total_reward,
episode_len_mean=traj_len_mean,
timesteps_this_iter=trajectory["dones"].shape[0],
timesteps_this_iter=trajectory["actions"].shape[0],
info=info)
return result
@@ -253,7 +254,8 @@ class PPOAgent(Agent):
self.model.save(),
self.global_step,
self.kl_coeff,
agent_state]
agent_state,
self.obs_filter]
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path
@@ -266,7 +268,8 @@ class PPOAgent(Agent):
ray.get([
a.restore.remote(o)
for (a, o) in zip(self.agents, extra_data[3])])
self.obs_filter = extra_data[4]
def compute_action(self, observation):
observation = self.model.observation_filter(observation, update=False)
return self.model.common_policy.compute([observation])[0][0]
observation = self.obs_filter(observation, update=False)
return self.model.common_policy.compute(observation)[0]
+4 -96
View File
@@ -8,92 +8,6 @@ import ray
from ray.rllib.ppo.utils import concatenate
def rollouts(policy, env, horizon, observation_filter, reward_filter):
    """Run a policy in a batched environment and record every step.

    Args:
        policy: Object with a ``compute(observation)`` method returning an
            ``(action, logprob, vf_pred)`` triple.
        env: Environment exposing ``reset()``, ``batchsize`` and
            ``step(action)``; ``step`` returns
            ``(observation, raw_reward, done)``.
        horizon: Upper bound on the number of steps taken.
        observation_filter: Callable applied to every observation before it
            is passed to the policy.
        reward_filter: Accepted for interface compatibility; this function
            does not use it.

    Returns:
        A dict with keys "observations", "raw_rewards", "actions",
        "logprobs", "vf_preds" and "dones". Each value is the per-step
        records combined with ``np.vstack``.
    """
    obs = observation_filter(env.reset())
    finished = np.array(env.batchsize * [False])
    steps_taken = 0

    columns = {
        "observations": [],  # Filtered observations
        "raw_rewards": [],   # Empirical rewards
        "actions": [],       # Actions sampled by the policy
        "logprobs": [],      # Last layer of the policy network
        "vf_preds": [],      # Value function predictions
        "dones": [],         # Has this rollout terminated?
    }

    while True:
        action, logprob, vf_pred = policy.compute(obs)
        columns["vf_preds"].append(vf_pred)
        columns["observations"].append(obs[None])
        columns["actions"].append(action[None])
        columns["logprobs"].append(logprob[None])

        obs, raw_reward, finished = env.step(action)
        obs = observation_filter(obs)
        columns["raw_rewards"].append(raw_reward[None])
        columns["dones"].append(finished[None])

        steps_taken += 1
        # Stop once every environment in the batch is done or the step
        # budget is exhausted.
        if finished.all() or steps_taken >= horizon:
            break

    return {
        "observations": np.vstack(columns["observations"]),
        "raw_rewards": np.vstack(columns["raw_rewards"]),
        "actions": np.vstack(columns["actions"]),
        "logprobs": np.vstack(columns["logprobs"]),
        "vf_preds": np.vstack(columns["vf_preds"]),
        "dones": np.vstack(columns["dones"]),
    }
def add_return_values(trajectory, gamma, reward_filter):
    """Store discounted returns under ``trajectory["returns"]``.

    Args:
        trajectory: Dict holding "raw_rewards" and "dones" arrays, indexed
            as ``[t, :]``.
        gamma: Discount factor.
        reward_filter: Callable invoked on each computed row of returns.

    Returns:
        None; ``trajectory`` is modified in place.
    """
    rewards = trajectory["raw_rewards"]
    dones = trajectory["dones"]
    returns = np.zeros_like(rewards)
    running_return = np.zeros(rewards.shape[1], dtype="float32")

    # NOTE(review): iteration starts at the second-to-last step, so the final
    # row of ``returns`` stays zero -- confirm this is intended.
    for t in range(len(rewards) - 2, -1, -1):
        masked_reward = rewards[t, :] * (1 - dones[t, :])
        running_return = masked_reward + gamma * running_return
        returns[t, :] = running_return
        # NOTE(review): the filter's return value is discarded here.
        reward_filter(returns[t, :])

    trajectory["returns"] = returns
def add_advantage_values(trajectory, gamma, lam, reward_filter):
    """Store advantages and TD(lambda) returns on ``trajectory``.

    Writes ``trajectory["advantages"]`` and
    ``trajectory["td_lambda_returns"]`` (advantages plus "vf_preds").

    Args:
        trajectory: Dict holding "raw_rewards", "vf_preds" and "dones"
            arrays, indexed as ``[t, :]``.
        gamma: Discount factor.
        lam: Lambda weight applied to the accumulated advantage.
        reward_filter: Callable invoked on each computed row of advantages.

    Returns:
        None; ``trajectory`` is modified in place.
    """
    rewards = trajectory["raw_rewards"]
    vf_preds = trajectory["vf_preds"]
    dones = trajectory["dones"]
    advantages = np.zeros_like(rewards)
    running_advantage = np.zeros(rewards.shape[1], dtype="float32")

    # Walk backwards from the second-to-last step; each step needs the
    # value prediction and done flag of step t + 1.
    for t in range(len(rewards) - 2, -1, -1):
        alive_now = 1 - dones[t, :]
        alive_next = 1 - dones[t + 1, :]
        delta = (rewards[t, :] * alive_now
                 + gamma * vf_preds[t + 1, :] * alive_next
                 - vf_preds[t, :])
        running_advantage = (
            delta + gamma * lam * running_advantage * alive_next)
        advantages[t, :] = running_advantage
        # NOTE(review): the filter's return value is discarded here.
        reward_filter(advantages[t, :])

    trajectory["advantages"] = advantages
    trajectory["td_lambda_returns"] = advantages + trajectory["vf_preds"]
def collect_samples(agents,
config,
observation_filter,
@@ -106,26 +20,20 @@ def collect_samples(agents,
# computed to the agent that they are computed on; we start some initial
# tasks here.
agent_dict = {agent.compute_steps.remote(
config["gamma"], config["lambda"],
config["horizon"], config["min_steps_per_task"],
observation_filter, reward_filter):
config, observation_filter, reward_filter):
agent for agent in agents}
while num_timesteps_so_far < config["timesteps_per_batch"]:
# TODO(pcm): Make wait support arbitrary iterators and remove the
# conversion to list here.
[next_trajectory], waiting_trajectories = ray.wait(
list(agent_dict.keys()))
[next_trajectory], _ = ray.wait(list(agent_dict))
agent = agent_dict.pop(next_trajectory)
# Start task with next trajectory and record it in the dictionary.
agent_dict[agent.compute_steps.remote(
config["gamma"], config["lambda"],
config["horizon"], config["min_steps_per_task"],
observation_filter, reward_filter)] = (
agent)
config, observation_filter, reward_filter)] = agent
trajectory, rewards, lengths, obs_f, rew_f = ray.get(next_trajectory)
total_rewards.extend(rewards)
trajectory_lengths.extend(lengths)
num_timesteps_so_far += len(trajectory["dones"])
num_timesteps_so_far += sum(lengths)
trajectories.append(trajectory)
observation_filter.update(obs_f)
reward_filter.update(rew_f)
+60 -82
View File
@@ -14,12 +14,13 @@ import ray
from ray.rllib.parallel import LocalSyncParallelOptimizer
from ray.rllib.models import ModelCatalog
from ray.rllib.envs import create_and_wrap
from ray.rllib.utils.sampler import SyncSampler
from ray.rllib.utils.filter import get_filter, MeanStdFilter
from ray.rllib.ppo.env import BatchedEnv
from ray.rllib.utils.process_rollout import process_rollout
from ray.rllib.ppo.loss import ProximalPolicyLoss
from ray.rllib.ppo.rollout import (
rollouts, add_return_values, add_advantage_values)
from ray.rllib.ppo.utils import flatten, concatenate
from ray.rllib.ppo.utils import concatenate
# TODO(pcm): Make sure that both observation_filter and reward_filter
# are correctly handled, i.e. (a) the values are accumulated across
@@ -37,7 +38,8 @@ class Runner(object):
network weights. When run as a remote agent, only this graph is used.
"""
def __init__(self, env_creator, batchsize, config, logdir, is_remote):
def __init__(self, env_creator, config, logdir, is_remote):
self.is_remote = is_remote
if is_remote:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
devices = ["/cpu:0"]
@@ -46,12 +48,11 @@ class Runner(object):
self.devices = devices
self.config = config
self.logdir = logdir
self.env = BatchedEnv(env_creator, batchsize, config)
self.env = create_and_wrap(env_creator, config["model"])
if is_remote:
config_proto = tf.ConfigProto()
else:
config_proto = tf.ConfigProto(**config["tf_session_args"])
self.preprocessor = self.env.preprocessor
self.sess = tf.Session(config=config_proto)
if config["tf_debug_inf_or_nan"] and not is_remote:
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
@@ -65,13 +66,14 @@ class Runner(object):
# The input observations.
self.observations = tf.placeholder(
tf.float32, shape=(None,) + self.preprocessor.shape)
tf.float32, shape=(None,) + self.env.observation_space.shape)
# Targets of the value function.
self.returns = tf.placeholder(tf.float32, shape=(None,))
self.value_targets = tf.placeholder(tf.float32, shape=(None,))
# Advantage values in the policy gradient estimator.
self.advantages = tf.placeholder(tf.float32, shape=(None,))
action_space = self.env.action_space
# TODO(rliaw): pull this into model_catalog
if isinstance(action_space, gym.spaces.Box):
self.actions = tf.placeholder(
tf.float32, shape=(None, action_space.shape[0]))
@@ -98,17 +100,17 @@ class Runner(object):
self.batch_size = config["sgd_batchsize"]
self.per_device_batch_size = int(self.batch_size / len(devices))
def build_loss(obs, rets, advs, acts, plog, pvf_preds):
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):
return ProximalPolicyLoss(
self.env.observation_space, self.env.action_space,
obs, rets, advs, acts, plog, pvf_preds, self.logit_dim,
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
self.kl_coeff, self.distribution_class, self.config,
self.sess)
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
self.devices,
[self.observations, self.returns, self.advantages,
[self.observations, self.value_targets, self.advantages,
self.actions, self.prev_logits, self.prev_vf_preds],
self.per_device_batch_size,
build_loss,
@@ -137,33 +139,26 @@ class Runner(object):
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss, self.sess)
self.observation_filter = get_filter(
config["observation_filter"], self.preprocessor.shape)
obs_filter = get_filter(
config["observation_filter"], self.env.observation_space.shape)
self.sampler = SyncSampler(
self.env, self.common_policy, obs_filter,
self.config["horizon"], self.config["horizon"])
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())
def load_data(self, trajectories, full_trace):
if self.config["use_gae"]:
return self.par_opt.load_data(
self.sess,
[trajectories["observations"],
trajectories["td_lambda_returns"],
trajectories["advantages"],
trajectories["actions"].squeeze(),
trajectories["logprobs"],
trajectories["vf_preds"]],
full_trace=full_trace)
else:
dummy = np.zeros((trajectories["observations"].shape[0],))
return self.par_opt.load_data(
self.sess,
[trajectories["observations"],
dummy,
trajectories["returns"],
trajectories["actions"].squeeze(),
trajectories["logprobs"],
dummy],
full_trace=full_trace)
use_gae = self.config["use_gae"]
dummy = np.zeros_like(trajectories["advantages"])
return self.par_opt.load_data(
self.sess,
[trajectories["observations"],
trajectories["value_targets"] if use_gae else dummy,
trajectories["advantages"],
trajectories["actions"].squeeze(),
trajectories["logprobs"],
trajectories["vf_preds"] if use_gae else dummy],
full_trace=full_trace)
def run_sgd_minibatch(
self, batch_index, kl_coeff, full_trace, file_writer):
@@ -177,12 +172,14 @@ class Runner(object):
file_writer=file_writer if full_trace else None)
def save(self):
return pickle.dumps([self.observation_filter, self.reward_filter])
obs_filter = self.sampler.get_obs_filter()
return pickle.dumps([obs_filter, self.reward_filter])
def restore(self, objs):
objs = pickle.loads(objs)
self.observation_filter = objs[0]
self.reward_filter = objs[1]
obs_filter = objs[0]
rew_filter = objs[1]
self.update_filters(obs_filter, rew_filter)
def get_weights(self):
return self.variables.get_weights()
@@ -190,29 +187,22 @@ class Runner(object):
def load_weights(self, weights):
self.variables.set_weights(weights)
def compute_trajectory(self, gamma, lam, horizon):
"""Compute a single rollout on the agent and return."""
trajectory = rollouts(
self.common_policy,
self.env, horizon, self.observation_filter, self.reward_filter)
if self.config["use_gae"]:
add_advantage_values(trajectory, gamma, lam, self.reward_filter)
else:
add_return_values(trajectory, gamma, self.reward_filter)
return trajectory
def update_filters(self, obs_filter=None, rew_filter=None):
if rew_filter:
# No special handling required since outside of threaded code
self.reward_filter = rew_filter.copy()
if obs_filter:
self.sampler.update_obs_filter(obs_filter)
def compute_steps(
self, gamma, lam, horizon, min_steps_per_task,
observation_filter, reward_filter):
def get_obs_filter(self):
return self.sampler.get_obs_filter()
def compute_steps(self, config, obs_filter, rew_filter):
"""Compute multiple rollouts and concatenate the results.
Args:
gamma: MDP discount factor
lam: GAE(lambda) parameter
horizon: Number of steps after which a rollout gets cut
min_steps_per_task: Lower bound on the number of states to be
collected.
observation_filter: Function that is applied to each of the
config: Configuration parameters
obs_filter: Function that is applied to each of the
observations.
reward_filter: Function that is applied to each of the rewards.
@@ -221,38 +211,26 @@ class Runner(object):
total_rewards: Total rewards of the trajectories.
trajectory_lengths: Lengths of the trajectories.
"""
# Update our local filters
self.observation_filter = observation_filter.copy()
self.reward_filter = reward_filter.copy()
num_steps_so_far = 0
trajectories = []
total_rewards = []
trajectory_lengths = []
while True:
trajectory = self.compute_trajectory(gamma, lam, horizon)
total_rewards.append(
trajectory["raw_rewards"].sum(axis=0).mean())
trajectory_lengths.append(
np.logical_not(trajectory["dones"]).sum(axis=0).mean())
trajectory = flatten(trajectory)
not_done = np.logical_not(trajectory["dones"])
# Filtering out states that are done. We do this because
# trajectories are batched and cut only if all the trajectories
# in the batch terminated, so we can potentially get rid of
# some of the states here.
trajectory = {key: val[not_done]
for key, val in trajectory.items()}
num_steps_so_far += trajectory["raw_rewards"].shape[0]
self.update_filters(obs_filter, rew_filter)
while num_steps_so_far < config["min_steps_per_task"]:
rollout = self.sampler.get_data()
trajectory = process_rollout(
rollout, self.reward_filter, config["gamma"],
config["lambda"], use_gae=config["use_gae"])
num_steps_so_far += trajectory["rewards"].shape[0]
trajectories.append(trajectory)
if num_steps_so_far >= min_steps_per_task:
break
metrics = self.sampler.get_metrics()
total_rewards, trajectory_lengths = zip(*[
(c.episode_reward, c.episode_length) for c in metrics])
updated_obs_filter = self.sampler.get_obs_filter(flush=True)
return (
concatenate(trajectories),
total_rewards,
trajectory_lengths,
self.observation_filter,
updated_obs_filter,
self.reward_filter)
+1 -1
View File
@@ -30,7 +30,7 @@ def concatenate(weights_list):
def shuffle(trajectory):
permutation = np.random.permutation(trajectory["dones"].shape[0])
permutation = np.random.permutation(trajectory["actions"].shape[0])
for key, val in trajectory.items():
trajectory[key] = val[permutation]
return trajectory
+3 -3
View File
@@ -5,7 +5,7 @@ from __future__ import print_function
import numpy as np
class BaseFilter(object):
class Filter(object):
"""Processes input, possibly statefully."""
def update(self, other, *args, **kwargs):
@@ -24,7 +24,7 @@ class BaseFilter(object):
raise NotImplementedError
class NoFilter(BaseFilter):
class NoFilter(Filter):
def __init__(self, *args):
pass
@@ -107,7 +107,7 @@ class RunningStat(object):
return self._M.shape
class MeanStdFilter(object):
class MeanStdFilter(Filter):
"""Keeps track of a running mean for seen states"""
def __init__(self, shape, demean=True, destd=True, clip=10.0):
+40
View File
@@ -0,0 +1,40 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import scipy.signal
def discount(x, gamma):
    """Compute the discounted cumulative sum of ``x`` over axis 0.

    Args:
        x: Sequence or array that supports ``x[::-1]`` slicing.
        gamma: Per-step discount factor.

    Returns:
        Array whose entry ``t`` is ``x[t] + gamma * out[t + 1]``, with the
        final entry equal to ``x[-1]``.
    """
    numerator = [1]
    denominator = [1, -gamma]
    # The filter implements y[n] = x[n] + gamma * y[n - 1]; feeding it the
    # reversed input and reversing the output makes the sum run forward in
    # time.
    reversed_sums = scipy.signal.lfilter(
        numerator, denominator, x[::-1], axis=0)
    return reversed_sums[::-1]
def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
    """Stack a rollout into arrays and attach advantages.

    Args:
        rollout: Object exposing ``data`` (mapping from field name to a
            per-step list, including "actions" and "rewards") and ``last_r``.
        reward_filter: Callable applied to each advantage entry; its return
            value replaces that entry.
        gamma: Discount factor.
        lambda_: GAE lambda parameter; only used when ``use_gae`` is true.
        use_gae: If true, compute GAE advantages from "vf_preds" and also
            store "value_targets"; otherwise store discounted rewards
            (bootstrapped with ``last_r``) as the advantages.

    Returns:
        Dict with every field of ``rollout.data`` stacked via ``np.stack``,
        plus "advantages" and, when ``use_gae`` is true, "value_targets".
    """
    data = rollout.data
    expected_len = len(data["actions"])
    traj = {}
    for field in data:
        traj[field] = np.stack(data[field])

    if not use_gae:
        rewards_and_bootstrap = np.stack(
            data["rewards"] + [np.array(rollout.last_r)]).squeeze()
        traj["advantages"] = discount(rewards_and_bootstrap, gamma)[:-1]
    else:
        assert "vf_preds" in data, "Values not found!"
        # Append ``last_r`` so the final step has a successor value.
        values = np.stack(
            data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
        td_residuals = traj["rewards"] + gamma * values[1:] - values[:-1]
        # This formula for the advantage comes from
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        traj["advantages"] = discount(td_residuals, gamma * lambda_)
        # Value targets are taken before the reward filter touches the
        # advantages below.
        traj["value_targets"] = traj["advantages"] + traj["vf_preds"]

    advantages = traj["advantages"]
    for idx in range(advantages.shape[0]):
        advantages[idx] = reward_filter(advantages[idx])

    lengths_match = all(
        column.shape[0] == expected_len for column in traj.values())
    assert lengths_match, "Rollout stacked incorrectly!"
    return traj
+89 -51
View File
@@ -19,9 +19,14 @@ class PartialRollout(object):
We run our agent, and process its experience once it has processed enough
steps.
Attributes:
data (dict): Stores rollout data. All numpy arrays other than
`observations` and `features` will be squeezed.
last_r (float): Value of next state. Used for bootstrapping.
"""
fields = ["state", "action", "reward", "terminal", "features"]
fields = ["observations", "actions", "rewards", "terminal", "features"]
def __init__(self, extra_fields=None):
"""Initializes internals. Maintains a `last_r` field
@@ -72,23 +77,40 @@ class SyncSampler(object):
thread."""
async = False
def __init__(self, env, policy, num_local_steps, obs_filter):
def __init__(self, env, policy, obs_filter,
num_local_steps, horizon=None):
self.num_local_steps = num_local_steps
self.horizon = horizon
self.env = env
self.policy = policy
self.obs_filter = obs_filter
self._obs_filter = obs_filter
self.rollout_provider = _env_runner(
self.env, self.policy, self.num_local_steps, self.obs_filter)
self.env, self.policy, self.num_local_steps, self.horizon,
self._obs_filter)
self.metrics_queue = queue.Queue()
def update_obs_filter(self, other_filter):
"""Method to update observation filter with copy from driver.
Since this class is synchronous, updating the observation
filter should be a straightforward replacement
def get_obs_filter(self, flush=False):
    """Gets a snapshot of the current observation filter.

    By default the accumulated delta in the filter is left untouched.

    Args:
        flush (bool): If True, and the filter exposes `clear_buffer`,
            the accumulated state in its buffer is cleared after the
            snapshot has been taken.

    Returns:
        snapshot (Filter): Copy of observation filter.
    """
    snapshot = self._obs_filter.copy()
    # Not every filter type has a buffer, hence the `hasattr` guard.
    if flush and hasattr(self._obs_filter, "clear_buffer"):
        self._obs_filter.clear_buffer()
    return snapshot
def update_obs_filter(self, other_filter):
    """Updates observation filter with copy from driver.

    The existing filter object is synced in place via its `sync` method
    rather than being replaced.

    Args:
        other_filter: Another filter (of same type).
    """
    self._obs_filter.sync(other_filter)
def get_data(self):
while True:
@@ -96,10 +118,7 @@ class SyncSampler(object):
if isinstance(item, CompletedRollout):
self.metrics_queue.put(item)
else:
obsf_snapshot = self.obs_filter.copy()
if hasattr(self.obs_filter, "clear_buffer"):
self.obs_filter.clear_buffer()
return item, obsf_snapshot
return item
def get_metrics(self):
completed = []
@@ -118,15 +137,17 @@ class AsyncSampler(threading.Thread):
accumulate and the gradient can be calculated on up to 5 batches."""
async = True
def __init__(self, env, policy, obs_filter,
             num_local_steps, horizon=None):
    """Store the sampling configuration and start the sampler thread.

    Args:
        env: Environment to sample from.
        policy: Policy used to compute actions.
        obs_filter: Observation filter. Kept as `self._obs_filter` and
            guarded by `self._obs_f_lock`.
        num_local_steps: Stored as `self.num_local_steps` for the
            rollout generator.
        horizon: Stored as `self.horizon` for the rollout generator.
            Defaults to None.
    """
    threading.Thread.__init__(self)
    # Bounded queue: at most 5 items can be pending at once.
    self.queue = queue.Queue(5)
    self.metrics_queue = queue.Queue()
    self.num_local_steps = num_local_steps
    self.horizon = horizon
    self.env = env
    self.policy = policy
    self._obs_filter = obs_filter
    self._obs_f_lock = threading.Lock()
    # The thread is started as part of construction.
    self.start()
def run(self):
@@ -139,25 +160,26 @@ class AsyncSampler(threading.Thread):
def update_obs_filter(self, other_filter):
    """Method to update observation filter with copy from driver.

    Applies delta since last `clear_buffer` to given new filter,
    and syncs current filter to new filter. `self._obs_filter` is
    mutated in place rather than rebound, because `_run` wraps this
    same object with `lock_wrap`.

    Args:
        other_filter: Another filter (of same type). It is copied
            first, so the caller's object is not modified here.
    """
    with self._obs_f_lock:
        new_filter = other_filter.copy()
        # Applies delta to filter, including buffer
        new_filter.update(self._obs_filter, copy_buffer=True)
        # copies everything back into original filter - needed
        # due to `lock_wrap`
        self._obs_filter.sync(new_filter)
def _run(self):
"""Sets observation filter into an atomic region and starts
other thread for running."""
safe_obs_filter = lock_wrap(self.obs_filter, self.obs_f_lock)
safe_obs_filter = lock_wrap(self._obs_filter, self._obs_f_lock)
rollout_provider = _env_runner(
self.env, self.policy, self.num_local_steps, safe_obs_filter)
self.env, self.policy, self.num_local_steps,
self.horizon, safe_obs_filter)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -168,24 +190,32 @@ class AsyncSampler(threading.Thread):
else:
self.queue.put(item, timeout=600.0)
def get_data(self):
"""Gets currently accumulated data and a snapshot of the current
observation filter. The snapshot also clears the accumulated delta.
Note that in between getting the rollout and acquiring the lock,
def get_obs_filter(self, flush=False):
    """Gets a snapshot of the current observation filter. The snapshot
    also clears the accumulated delta. Note that in between getting
    the rollout from self.queue and acquiring the lock here,
    the other thread can run, resulting in slight discrepancies
    between data retrieved and filter statistics.

    Args:
        flush (bool): Accepted but not read by this implementation;
            the buffer is cleared whenever the filter exposes
            `clear_buffer`.
            NOTE(review): confirm whether `flush` should gate the
            clear before changing this; callers may rely on it.

    Returns:
        snapshot (Filter): Copy of observation filter.
    """
    # Copy and clear under the lock so both happen atomically.
    with self._obs_f_lock:
        snapshot = self._obs_filter.copy()
        if hasattr(self._obs_filter, "clear_buffer"):
            self._obs_filter.clear_buffer()
    return snapshot
def get_data(self):
    """Gets currently accumulated data.

    The observation filter snapshot is no longer returned here; it is
    obtained separately through `get_obs_filter`.

    Returns:
        rollout (PartialRollout): trajectory data (unprocessed)
    """
    rollout = self._pull_batch_from_queue()
    return rollout
def _pull_batch_from_queue(self):
"""Take a rollout from the queue of the thread runner."""
@@ -212,7 +242,7 @@ class AsyncSampler(threading.Thread):
return completed
def _env_runner(env, policy, num_local_steps, obs_filter):
def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
"""This implements the logic of the thread runner.
It continually runs the policy, and as long as the rollout exceeds a
@@ -231,10 +261,15 @@ def _env_runner(env, policy, num_local_steps, obs_filter):
rollout (PartialRollout): Object containing state, action, reward,
terminal condition, and other fields as dictated by `policy`.
"""
last_state = obs_filter(env.reset())
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
".max_episode_steps")
last_features = features = policy.get_initial_features()
last_observation = obs_filter(env.reset())
horizon = horizon if horizon else env.spec.tags.get(
"wrapper_config.TimeLimit.max_episode_steps")
assert horizon > 0
if hasattr(policy, "get_initial_features"):
last_features = policy.get_initial_features()
else:
last_features = []
features = last_features
length = 0
rewards = 0
rollout_number = 0
@@ -244,44 +279,47 @@ def _env_runner(env, policy, num_local_steps, obs_filter):
rollout = PartialRollout(extra_fields=policy.other_output)
for _ in range(num_local_steps):
action, pi_info = policy.compute_action(last_state, *last_features)
action, pi_info = policy.compute(last_observation, *last_features)
if policy.is_recurrent:
features = pi_info["features"]
del pi_info["features"]
state, reward, terminal, info = env.step(action)
state = obs_filter(state)
observation, reward, terminal, info = env.step(action)
observation = obs_filter(observation)
length += 1
rewards += reward
if length >= timestep_limit:
if length >= horizon:
terminal = True
# Collect the experience.
rollout.add(state=last_state,
action=action,
reward=reward,
rollout.add(observations=last_observation,
actions=action,
rewards=reward,
terminal=terminal,
features=last_features,
**pi_info)
last_state = state
last_observation = observation
last_features = features
if terminal:
terminal_end = True
yield CompletedRollout(length, rewards)
if (length >= timestep_limit or
if (length >= horizon or
not env.metadata.get("semantics.autoreset")):
last_state = obs_filter(env.reset())
last_features = policy.get_initial_features()
last_observation = obs_filter(env.reset())
if hasattr(policy, "get_initial_features"):
last_features = policy.get_initial_features()
else:
last_features = []
rollout_number += 1
length = 0
rewards = 0
break
if not terminal_end:
rollout.last_r = policy.value(last_state, *last_features)
rollout.last_r = policy.value(last_observation, *last_features)
# Once we have enough experience, yield it, and have the ThreadRunner
# place it on a queue.