mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 19:22:51 +08:00
[rllib] PPO and A3C unification (#1253)
This commit is contained in:
@@ -8,8 +8,8 @@ import os
|
||||
|
||||
import ray
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
from ray.rllib.a3c.runner import RemoteRunner
|
||||
from ray.rllib.envs import create_and_wrap
|
||||
from ray.rllib.a3c.runner import RemoteA3CEvaluator
|
||||
from ray.rllib.a3c.common import get_policy_cls
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.tune.result import TrainingResult
|
||||
@@ -49,7 +49,8 @@ class A3CAgent(Agent):
|
||||
self.env.observation_space.shape)
|
||||
self.rew_filter = get_filter(self.config["reward_filter"], ())
|
||||
self.agents = [
|
||||
RemoteRunner.remote(self.env_creator, self.config, self.logdir)
|
||||
RemoteA3CEvaluator.remote(
|
||||
self.env_creator, self.config, self.logdir)
|
||||
for i in range(self.config["num_workers"])]
|
||||
self.parameters = self.policy.get_weights()
|
||||
|
||||
@@ -105,6 +106,7 @@ class A3CAgent(Agent):
|
||||
return result
|
||||
|
||||
def _save(self):
|
||||
# TODO(rliaw): extend to also support saving worker state?
|
||||
checkpoint_path = os.path.join(
|
||||
self.logdir, "checkpoint-{}".format(self.iteration))
|
||||
objects = [self.parameters, self.obs_filter, self.rew_filter]
|
||||
@@ -118,6 +120,8 @@ class A3CAgent(Agent):
|
||||
self.rew_filter = objects[2]
|
||||
self.policy.set_weights(self.parameters)
|
||||
|
||||
# TODO(rliaw): augment to support LSTM
|
||||
def compute_action(self, observation):
|
||||
actions = self.policy.compute_action(observation)
|
||||
return actions[0]
|
||||
obs = self.obs_filter(observation, update=False)
|
||||
action, info = self.policy.compute(obs)
|
||||
return action
|
||||
|
||||
@@ -2,37 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
def discount(x, gamma):
|
||||
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||||
|
||||
|
||||
def process_rollout(rollout, reward_filter, gamma, lambda_=1.0):
|
||||
"""Given a rollout, compute its returns and the advantage.
|
||||
|
||||
TODO(rliaw): generalize this"""
|
||||
batch_si = np.asarray(rollout.data["state"])
|
||||
batch_a = np.asarray(rollout.data["action"])
|
||||
rewards = np.asarray(rollout.data["reward"])
|
||||
vpred_t = np.asarray(rollout.data["value"] + [rollout.last_r])
|
||||
|
||||
rewards_plus_v = np.asarray(rollout.data["reward"] + [rollout.last_r])
|
||||
batch_r = discount(rewards_plus_v, gamma)[:-1]
|
||||
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
|
||||
# This formula for the advantage comes from "Generalized Advantage Estimation":
|
||||
# https://arxiv.org/abs/1506.02438
|
||||
batch_adv = discount(delta_t, gamma * lambda_)
|
||||
for i in range(batch_adv.shape[0]):
|
||||
batch_adv[i] = reward_filter(batch_adv[i])
|
||||
|
||||
features = rollout.data["features"][0]
|
||||
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.is_terminal(),
|
||||
features)
|
||||
|
||||
|
||||
def get_policy_cls(config):
|
||||
if config["use_lstm"]:
|
||||
@@ -45,7 +14,3 @@ def get_policy_cls(config):
|
||||
from ray.rllib.a3c.shared_model import SharedModel
|
||||
policy_cls = SharedModel
|
||||
return policy_cls
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
"Batch", ["si", "a", "adv", "r", "terminal", "features"])
|
||||
|
||||
@@ -20,7 +20,7 @@ class Policy(object):
|
||||
def compute_gradients(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observations):
|
||||
def compute(self, observations):
|
||||
"""Compute action for a _single_ observation"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -3,13 +3,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
from ray.rllib.a3c.common import process_rollout, get_policy_cls
|
||||
from ray.rllib.envs import create_and_wrap
|
||||
from ray.rllib.evaluator import Evaluator
|
||||
from ray.rllib.a3c.common import get_policy_cls
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.sampler import AsyncSampler
|
||||
from ray.rllib.utils.process_rollout import process_rollout
|
||||
|
||||
|
||||
class Runner(object):
|
||||
class A3CEvaluator(Evaluator):
|
||||
"""Actor object to start running simulation on workers.
|
||||
|
||||
The gradient computation is also executed from this object.
|
||||
@@ -29,19 +31,16 @@ class Runner(object):
|
||||
obs_filter = get_filter(
|
||||
config["observation_filter"], env.observation_space.shape)
|
||||
self.rew_filter = get_filter(config["reward_filter"], ())
|
||||
|
||||
self.sampler = AsyncSampler(env, self.policy, config["batch_size"],
|
||||
obs_filter)
|
||||
self.sampler = AsyncSampler(env, self.policy, obs_filter,
|
||||
config["batch_size"])
|
||||
self.logdir = logdir
|
||||
|
||||
def get_data(self):
|
||||
def sample(self):
|
||||
"""
|
||||
Returns:
|
||||
trajectory: trajectory information
|
||||
obs_filter: Current state of observation filter
|
||||
rew_filter: Current state of reward filter"""
|
||||
rollout, obs_filter = self.sampler.get_data()
|
||||
return rollout, obs_filter, self.rew_filter
|
||||
trajectory (PartialRollout): Experience Samples from evaluator"""
|
||||
rollout = self.sampler.get_data()
|
||||
return rollout
|
||||
|
||||
def get_completed_rollout_metrics(self):
|
||||
"""Returns metrics on previously completed rollouts.
|
||||
@@ -51,14 +50,19 @@ class Runner(object):
|
||||
return self.sampler.get_metrics()
|
||||
|
||||
def compute_gradient(self):
|
||||
rollout, obsf_snapshot = self.sampler.get_data()
|
||||
batch = process_rollout(
|
||||
rollout, self.rew_filter, gamma=0.99, lambda_=1.0)
|
||||
gradient, info = self.policy.compute_gradients(batch)
|
||||
info["obs_filter"] = obsf_snapshot
|
||||
rollout = self.sampler.get_data()
|
||||
obs_filter = self.sampler.get_obs_filter(flush=True)
|
||||
|
||||
traj = process_rollout(
|
||||
rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True)
|
||||
gradient, info = self.policy.compute_gradients(traj)
|
||||
info["obs_filter"] = obs_filter
|
||||
info["rew_filter"] = self.rew_filter
|
||||
return gradient, info
|
||||
|
||||
def apply_gradient(self, grads):
|
||||
self.policy.apply_gradients(grads)
|
||||
|
||||
def set_weights(self, params):
|
||||
self.policy.set_weights(params)
|
||||
|
||||
@@ -70,4 +74,4 @@ class Runner(object):
|
||||
self.sampler.update_obs_filter(obs_filter)
|
||||
|
||||
|
||||
RemoteRunner = ray.remote(Runner)
|
||||
RemoteA3CEvaluator = ray.remote(A3CEvaluator)
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
class SharedModel(TFPolicy):
|
||||
|
||||
other_output = ["value"]
|
||||
other_output = ["vf_preds"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
@@ -35,13 +35,13 @@ class SharedModel(TFPolicy):
|
||||
initializer=tf.constant_initializer(0, dtype=tf.int32),
|
||||
trainable=False)
|
||||
|
||||
def compute_gradients(self, batch):
|
||||
def compute_gradients(self, trajectory):
|
||||
info = {}
|
||||
feed_dict = {
|
||||
self.x: batch.si,
|
||||
self.ac: batch.a,
|
||||
self.adv: batch.adv,
|
||||
self.r: batch.r,
|
||||
self.x: trajectory["observations"],
|
||||
self.ac: trajectory["actions"],
|
||||
self.adv: trajectory["advantages"],
|
||||
self.r: trajectory["value_targets"],
|
||||
}
|
||||
self.grads = [g for g in self.grads if g is not None]
|
||||
self.local_steps += 1
|
||||
@@ -53,14 +53,11 @@ class SharedModel(TFPolicy):
|
||||
grad = self.sess.run(self.grads, feed_dict=feed_dict)
|
||||
return grad, info
|
||||
|
||||
def compute_action(self, ob, *args):
|
||||
def compute(self, ob, *args):
|
||||
action, vf = self.sess.run([self.sample, self.vf],
|
||||
{self.x: [ob]})
|
||||
return action[0], {"value": vf[0]}
|
||||
return action[0], {"vf_preds": vf[0]}
|
||||
|
||||
def value(self, ob, *args):
|
||||
vf = self.sess.run(self.vf, {self.x: [ob]})
|
||||
return vf[0]
|
||||
|
||||
def get_initial_features(self):
|
||||
return []
|
||||
|
||||
@@ -18,7 +18,7 @@ class SharedModelLSTM(TFPolicy):
|
||||
to be tracked).
|
||||
"""
|
||||
|
||||
other_output = ["value", "features"]
|
||||
other_output = ["vf_preds", "features"]
|
||||
is_recurrent = True
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
@@ -48,19 +48,20 @@ class SharedModelLSTM(TFPolicy):
|
||||
initializer=tf.constant_initializer(0, dtype=tf.int32),
|
||||
trainable=False)
|
||||
|
||||
def compute_gradients(self, batch):
|
||||
def compute_gradients(self, trajectory):
|
||||
"""Computing the gradient is actually model-dependent.
|
||||
|
||||
The LSTM needs its hidden states in order to compute the gradient
|
||||
accurately.
|
||||
"""
|
||||
features = trajectory["features"][0]
|
||||
feed_dict = {
|
||||
self.x: batch.si,
|
||||
self.ac: batch.a,
|
||||
self.adv: batch.adv,
|
||||
self.r: batch.r,
|
||||
self.state_in[0]: batch.features[0],
|
||||
self.state_in[1]: batch.features[1]
|
||||
self.x: trajectory["observations"],
|
||||
self.ac: trajectory["actions"],
|
||||
self.adv: trajectory["advantages"],
|
||||
self.r: trajectory["value_targets"],
|
||||
self.state_in[0]: features[0],
|
||||
self.state_in[1]: features[1]
|
||||
}
|
||||
info = {}
|
||||
self.local_steps += 1
|
||||
@@ -72,11 +73,11 @@ class SharedModelLSTM(TFPolicy):
|
||||
grad = self.sess.run(self.grads, feed_dict=feed_dict)
|
||||
return grad, info
|
||||
|
||||
def compute_action(self, ob, c, h):
|
||||
def compute(self, ob, c, h):
|
||||
action, vf, c, h = self.sess.run(
|
||||
[self.sample, self.vf] + self.state_out,
|
||||
{self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
|
||||
return action[0], {"value": vf[0], "features": (c, h)}
|
||||
return action[0], {"vf_preds": vf[0], "features": (c, h)}
|
||||
|
||||
def value(self, ob, c, h):
|
||||
vf = self.sess.run(self.vf, {self.x: [ob],
|
||||
|
||||
@@ -14,7 +14,7 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
class SharedTorchPolicy(TorchPolicy):
|
||||
"""Assumes nonrecurrent."""
|
||||
|
||||
other_output = ["value"]
|
||||
other_output = ["vf_preds"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
@@ -26,14 +26,14 @@ class SharedTorchPolicy(TorchPolicy):
|
||||
self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
|
||||
self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)
|
||||
|
||||
def compute_action(self, ob, *args):
|
||||
def compute(self, ob, *args):
|
||||
"""Should take in a SINGLE ob"""
|
||||
with self.lock:
|
||||
ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
|
||||
logits, values = self._model(ob)
|
||||
samples = self._model.probs(logits).multinomial().squeeze()
|
||||
values = values.squeeze(0)
|
||||
return var_to_np(samples), {"value": var_to_np(values)}
|
||||
return var_to_np(samples), {"vf_preds": var_to_np(values)}
|
||||
|
||||
def compute_logits(self, ob, *args):
|
||||
with self.lock:
|
||||
@@ -71,6 +71,3 @@ class SharedTorchPolicy(TorchPolicy):
|
||||
overall_err = 0.5 * value_err + pi_err - entropy * 0.01
|
||||
overall_err.backward()
|
||||
torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
|
||||
|
||||
def get_initial_features(self):
|
||||
return [None]
|
||||
|
||||
@@ -92,7 +92,7 @@ class TFPolicy(Policy):
|
||||
def compute_gradients(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_action(self, observations):
|
||||
def compute(self, observation):
|
||||
raise NotImplementedError
|
||||
|
||||
def value(self, ob):
|
||||
|
||||
@@ -73,6 +73,3 @@ class TorchPolicy(Policy):
|
||||
This function regenerates the backward trace and
|
||||
calculates the gradient."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_initial_features(self):
|
||||
return []
|
||||
|
||||
@@ -8,19 +8,23 @@ import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def convert_batch(batch, has_features=False):
|
||||
"""Convert batch from numpy to PT variable"""
|
||||
states = Variable(torch.from_numpy(batch.si).float())
|
||||
acs = Variable(torch.from_numpy(batch.a))
|
||||
advs = Variable(torch.from_numpy(batch.adv.copy()).float())
|
||||
def convert_batch(trajectory, has_features=False):
|
||||
"""Convert trajectory from numpy to PT variable"""
|
||||
states = Variable(torch.from_numpy(
|
||||
trajectory["observations"]).float())
|
||||
acs = Variable(torch.from_numpy(
|
||||
trajectory["actions"]))
|
||||
advs = Variable(torch.from_numpy(
|
||||
trajectory["advantages"].copy()).float())
|
||||
advs = advs.view(-1, 1)
|
||||
rs = Variable(torch.from_numpy(batch.r.copy()).float())
|
||||
rs = Variable(torch.from_numpy(
|
||||
trajectory["value_targets"]).float())
|
||||
rs = rs.view(-1, 1)
|
||||
if has_features:
|
||||
features = [Variable(torch.from_numpy(f))
|
||||
for f in batch.features]
|
||||
for f in trajectory["features"]]
|
||||
else:
|
||||
features = batch.features
|
||||
features = trajectory["features"]
|
||||
return states, acs, advs, rs, features
|
||||
|
||||
|
||||
|
||||
@@ -10,9 +10,12 @@ from ray.rllib.models import ModelCatalog
|
||||
|
||||
class ProximalPolicyLoss(object):
|
||||
|
||||
other_output = ["vf_preds", "logprobs"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(
|
||||
self, observation_space, action_space,
|
||||
observations, returns, advantages, actions,
|
||||
observations, value_targets, advantages, actions,
|
||||
prev_logits, prev_vf_preds, logit_dim,
|
||||
kl_coeff, distribution_class, config, sess):
|
||||
assert (isinstance(action_space, gym.spaces.Discrete) or
|
||||
@@ -55,11 +58,11 @@ class ProximalPolicyLoss(object):
|
||||
# We use a huber loss here to be more robust against outliers,
|
||||
# which seem to occur when the rollouts get longer (the variance
|
||||
# scales superlinearly with the length of the rollout)
|
||||
self.vf_loss1 = tf.square(self.value_function - returns)
|
||||
self.vf_loss1 = tf.square(self.value_function - value_targets)
|
||||
vf_clipped = prev_vf_preds + tf.clip_by_value(
|
||||
self.value_function - prev_vf_preds,
|
||||
-config["clip_param"], config["clip_param"])
|
||||
self.vf_loss2 = tf.square(vf_clipped - returns)
|
||||
self.vf_loss2 = tf.square(vf_clipped - value_targets)
|
||||
self.vf_loss = tf.minimum(self.vf_loss1, self.vf_loss2)
|
||||
self.mean_vf_loss = tf.reduce_mean(self.vf_loss)
|
||||
self.loss = tf.reduce_mean(
|
||||
@@ -82,9 +85,11 @@ class ProximalPolicyLoss(object):
|
||||
self.policy_results = [
|
||||
self.sampler, self.curr_logits, tf.constant("NA")]
|
||||
|
||||
def compute(self, observations):
|
||||
return self.sess.run(self.policy_results,
|
||||
feed_dict={self.observations: observations})
|
||||
def compute(self, observation):
|
||||
action, logprobs, vf = self.sess.run(
|
||||
self.policy_results,
|
||||
feed_dict={self.observations: [observation]})
|
||||
return action[0], {"vf_preds": vf[0], "logprobs": logprobs[0]}
|
||||
|
||||
def loss(self):
|
||||
return self.loss
|
||||
|
||||
+16
-13
@@ -11,8 +11,9 @@ import tensorflow as tf
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
import ray
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.ppo.runner import Runner, RemoteRunner
|
||||
from ray.rllib.ppo.rollout import collect_samples
|
||||
from ray.rllib.ppo.utils import shuffle
|
||||
@@ -90,10 +91,10 @@ class PPOAgent(Agent):
|
||||
self.global_step = 0
|
||||
self.kl_coeff = self.config["kl_coeff"]
|
||||
self.model = Runner(
|
||||
self.env_creator, 1, self.config, self.logdir, False)
|
||||
self.env_creator, self.config, self.logdir, False)
|
||||
self.agents = [
|
||||
RemoteRunner.remote(
|
||||
self.env_creator, 1, self.config, self.logdir, True)
|
||||
self.env_creator, self.config, self.logdir, True)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.start_time = time.time()
|
||||
if self.config["write_logs"]:
|
||||
@@ -102,6 +103,9 @@ class PPOAgent(Agent):
|
||||
else:
|
||||
self.file_writer = None
|
||||
self.saver = tf.train.Saver(max_to_keep=None)
|
||||
self.obs_filter = get_filter(
|
||||
self.config["observation_filter"],
|
||||
self.model.env.observation_space.shape)
|
||||
|
||||
def _train(self):
|
||||
agents = self.agents
|
||||
@@ -114,11 +118,11 @@ class PPOAgent(Agent):
|
||||
weights = ray.put(model.get_weights())
|
||||
[a.load_weights.remote(weights) for a in agents]
|
||||
trajectory, total_reward, traj_len_mean = collect_samples(
|
||||
agents, config, self.model.observation_filter,
|
||||
agents, config, self.obs_filter,
|
||||
self.model.reward_filter)
|
||||
print("total reward is ", total_reward)
|
||||
print("trajectory length mean is ", traj_len_mean)
|
||||
print("timesteps:", trajectory["dones"].shape[0])
|
||||
print("timesteps:", trajectory["actions"].shape[0])
|
||||
if self.file_writer:
|
||||
traj_stats = tf.Summary(value=[
|
||||
tf.Summary.Value(
|
||||
@@ -135,10 +139,7 @@ class PPOAgent(Agent):
|
||||
# to guard against the case where all values are equal
|
||||
return (value - value.mean()) / max(1e-4, value.std())
|
||||
|
||||
if config["use_gae"]:
|
||||
trajectory["advantages"] = standardized(trajectory["advantages"])
|
||||
else:
|
||||
trajectory["returns"] = standardized(trajectory["returns"])
|
||||
trajectory["advantages"] = standardized(trajectory["advantages"])
|
||||
|
||||
rollouts_end = time.time()
|
||||
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
|
||||
@@ -238,7 +239,7 @@ class PPOAgent(Agent):
|
||||
result = TrainingResult(
|
||||
episode_reward_mean=total_reward,
|
||||
episode_len_mean=traj_len_mean,
|
||||
timesteps_this_iter=trajectory["dones"].shape[0],
|
||||
timesteps_this_iter=trajectory["actions"].shape[0],
|
||||
info=info)
|
||||
|
||||
return result
|
||||
@@ -253,7 +254,8 @@ class PPOAgent(Agent):
|
||||
self.model.save(),
|
||||
self.global_step,
|
||||
self.kl_coeff,
|
||||
agent_state]
|
||||
agent_state,
|
||||
self.obs_filter]
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
@@ -266,7 +268,8 @@ class PPOAgent(Agent):
|
||||
ray.get([
|
||||
a.restore.remote(o)
|
||||
for (a, o) in zip(self.agents, extra_data[3])])
|
||||
self.obs_filter = extra_data[4]
|
||||
|
||||
def compute_action(self, observation):
|
||||
observation = self.model.observation_filter(observation, update=False)
|
||||
return self.model.common_policy.compute([observation])[0][0]
|
||||
observation = self.obs_filter(observation, update=False)
|
||||
return self.model.common_policy.compute(observation)[0]
|
||||
|
||||
@@ -8,92 +8,6 @@ import ray
|
||||
from ray.rllib.ppo.utils import concatenate
|
||||
|
||||
|
||||
def rollouts(policy, env, horizon, observation_filter, reward_filter):
|
||||
"""Perform a batch of rollouts of a policy in an environment.
|
||||
|
||||
Args:
|
||||
policy: The policy that will be rolled out. Can be an arbitrary object
|
||||
that supports a compute(observation) function.
|
||||
env: The environment the rollout is computed in. Needs to support the
|
||||
OpenAI gym API and needs to support batches of data.
|
||||
horizon: Upper bound for the number of timesteps for each rollout in
|
||||
the batch.
|
||||
observation_filter: Function that is applied to each of the
|
||||
observations.
|
||||
reward_filter: Function that is applied to each of the rewards.
|
||||
|
||||
Returns:
|
||||
A trajectory, which is a dictionary with keys "observations",
|
||||
"raw_rewards", "actions", "logprobs", "vf_preds", "dones". Each
|
||||
value is an array of shape (num_timesteps, env.batchsize, shape).
|
||||
"""
|
||||
|
||||
observation = observation_filter(env.reset())
|
||||
done = np.array(env.batchsize * [False])
|
||||
t = 0
|
||||
observations = [] # Filtered observations
|
||||
raw_rewards = [] # Empirical rewards
|
||||
actions = [] # Actions sampled by the policy
|
||||
logprobs = [] # Last layer of the policy network
|
||||
vf_preds = [] # Value function predictions
|
||||
dones = [] # Has this rollout terminated?
|
||||
|
||||
while True:
|
||||
action, logprob, vfpred = policy.compute(observation)
|
||||
vf_preds.append(vfpred)
|
||||
observations.append(observation[None])
|
||||
actions.append(action[None])
|
||||
logprobs.append(logprob[None])
|
||||
observation, raw_reward, done = env.step(action)
|
||||
observation = observation_filter(observation)
|
||||
raw_rewards.append(raw_reward[None])
|
||||
dones.append(done[None])
|
||||
t += 1
|
||||
if done.all() or t >= horizon:
|
||||
break
|
||||
|
||||
return {"observations": np.vstack(observations),
|
||||
"raw_rewards": np.vstack(raw_rewards),
|
||||
"actions": np.vstack(actions),
|
||||
"logprobs": np.vstack(logprobs),
|
||||
"vf_preds": np.vstack(vf_preds),
|
||||
"dones": np.vstack(dones)}
|
||||
|
||||
|
||||
def add_return_values(trajectory, gamma, reward_filter):
|
||||
rewards = trajectory["raw_rewards"]
|
||||
dones = trajectory["dones"]
|
||||
returns = np.zeros_like(rewards)
|
||||
last_return = np.zeros(rewards.shape[1], dtype="float32")
|
||||
|
||||
for t in reversed(range(len(rewards) - 1)):
|
||||
last_return = rewards[t, :] * (1 - dones[t, :]) + gamma * last_return
|
||||
returns[t, :] = last_return
|
||||
reward_filter(returns[t, :])
|
||||
|
||||
trajectory["returns"] = returns
|
||||
|
||||
|
||||
def add_advantage_values(trajectory, gamma, lam, reward_filter):
|
||||
rewards = trajectory["raw_rewards"]
|
||||
vf_preds = trajectory["vf_preds"]
|
||||
dones = trajectory["dones"]
|
||||
advantages = np.zeros_like(rewards)
|
||||
last_advantage = np.zeros(rewards.shape[1], dtype="float32")
|
||||
|
||||
for t in reversed(range(len(rewards) - 1)):
|
||||
delta = rewards[t, :] * (1 - dones[t, :]) + \
|
||||
gamma * vf_preds[t+1, :] * (1 - dones[t+1, :]) - vf_preds[t, :]
|
||||
last_advantage = \
|
||||
delta + gamma * lam * last_advantage * (1 - dones[t+1, :])
|
||||
advantages[t, :] = last_advantage
|
||||
reward_filter(advantages[t, :])
|
||||
|
||||
trajectory["advantages"] = advantages
|
||||
trajectory["td_lambda_returns"] = \
|
||||
trajectory["advantages"] + trajectory["vf_preds"]
|
||||
|
||||
|
||||
def collect_samples(agents,
|
||||
config,
|
||||
observation_filter,
|
||||
@@ -106,26 +20,20 @@ def collect_samples(agents,
|
||||
# computed to the agent that they are computed on; we start some initial
|
||||
# tasks here.
|
||||
agent_dict = {agent.compute_steps.remote(
|
||||
config["gamma"], config["lambda"],
|
||||
config["horizon"], config["min_steps_per_task"],
|
||||
observation_filter, reward_filter):
|
||||
config, observation_filter, reward_filter):
|
||||
agent for agent in agents}
|
||||
while num_timesteps_so_far < config["timesteps_per_batch"]:
|
||||
# TODO(pcm): Make wait support arbitrary iterators and remove the
|
||||
# conversion to list here.
|
||||
[next_trajectory], waiting_trajectories = ray.wait(
|
||||
list(agent_dict.keys()))
|
||||
[next_trajectory], _ = ray.wait(list(agent_dict))
|
||||
agent = agent_dict.pop(next_trajectory)
|
||||
# Start task with next trajectory and record it in the dictionary.
|
||||
agent_dict[agent.compute_steps.remote(
|
||||
config["gamma"], config["lambda"],
|
||||
config["horizon"], config["min_steps_per_task"],
|
||||
observation_filter, reward_filter)] = (
|
||||
agent)
|
||||
config, observation_filter, reward_filter)] = agent
|
||||
trajectory, rewards, lengths, obs_f, rew_f = ray.get(next_trajectory)
|
||||
total_rewards.extend(rewards)
|
||||
trajectory_lengths.extend(lengths)
|
||||
num_timesteps_so_far += len(trajectory["dones"])
|
||||
num_timesteps_so_far += sum(lengths)
|
||||
trajectories.append(trajectory)
|
||||
observation_filter.update(obs_f)
|
||||
reward_filter.update(rew_f)
|
||||
|
||||
@@ -14,12 +14,13 @@ import ray
|
||||
|
||||
from ray.rllib.parallel import LocalSyncParallelOptimizer
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.envs import create_and_wrap
|
||||
from ray.rllib.utils.sampler import SyncSampler
|
||||
from ray.rllib.utils.filter import get_filter, MeanStdFilter
|
||||
from ray.rllib.ppo.env import BatchedEnv
|
||||
from ray.rllib.utils.process_rollout import process_rollout
|
||||
from ray.rllib.ppo.loss import ProximalPolicyLoss
|
||||
from ray.rllib.ppo.rollout import (
|
||||
rollouts, add_return_values, add_advantage_values)
|
||||
from ray.rllib.ppo.utils import flatten, concatenate
|
||||
from ray.rllib.ppo.utils import concatenate
|
||||
|
||||
|
||||
# TODO(pcm): Make sure that both observation_filter and reward_filter
|
||||
# are correctly handled, i.e. (a) the values are accumulated across
|
||||
@@ -37,7 +38,8 @@ class Runner(object):
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
|
||||
def __init__(self, env_creator, batchsize, config, logdir, is_remote):
|
||||
def __init__(self, env_creator, config, logdir, is_remote):
|
||||
self.is_remote = is_remote
|
||||
if is_remote:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
devices = ["/cpu:0"]
|
||||
@@ -46,12 +48,11 @@ class Runner(object):
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = BatchedEnv(env_creator, batchsize, config)
|
||||
self.env = create_and_wrap(env_creator, config["model"])
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
config_proto = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.preprocessor = self.env.preprocessor
|
||||
self.sess = tf.Session(config=config_proto)
|
||||
if config["tf_debug_inf_or_nan"] and not is_remote:
|
||||
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
|
||||
@@ -65,13 +66,14 @@ class Runner(object):
|
||||
|
||||
# The input observations.
|
||||
self.observations = tf.placeholder(
|
||||
tf.float32, shape=(None,) + self.preprocessor.shape)
|
||||
tf.float32, shape=(None,) + self.env.observation_space.shape)
|
||||
# Targets of the value function.
|
||||
self.returns = tf.placeholder(tf.float32, shape=(None,))
|
||||
self.value_targets = tf.placeholder(tf.float32, shape=(None,))
|
||||
# Advantage values in the policy gradient estimator.
|
||||
self.advantages = tf.placeholder(tf.float32, shape=(None,))
|
||||
|
||||
action_space = self.env.action_space
|
||||
# TODO(rliaw): pull this into model_catalog
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
self.actions = tf.placeholder(
|
||||
tf.float32, shape=(None, action_space.shape[0]))
|
||||
@@ -98,17 +100,17 @@ class Runner(object):
|
||||
self.batch_size = config["sgd_batchsize"]
|
||||
self.per_device_batch_size = int(self.batch_size / len(devices))
|
||||
|
||||
def build_loss(obs, rets, advs, acts, plog, pvf_preds):
|
||||
def build_loss(obs, vtargets, advs, acts, plog, pvf_preds):
|
||||
return ProximalPolicyLoss(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
obs, rets, advs, acts, plog, pvf_preds, self.logit_dim,
|
||||
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
|
||||
self.kl_coeff, self.distribution_class, self.config,
|
||||
self.sess)
|
||||
|
||||
self.par_opt = LocalSyncParallelOptimizer(
|
||||
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
|
||||
self.devices,
|
||||
[self.observations, self.returns, self.advantages,
|
||||
[self.observations, self.value_targets, self.advantages,
|
||||
self.actions, self.prev_logits, self.prev_vf_preds],
|
||||
self.per_device_batch_size,
|
||||
build_loss,
|
||||
@@ -137,33 +139,26 @@ class Runner(object):
|
||||
self.common_policy = self.par_opt.get_common_loss()
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.common_policy.loss, self.sess)
|
||||
self.observation_filter = get_filter(
|
||||
config["observation_filter"], self.preprocessor.shape)
|
||||
obs_filter = get_filter(
|
||||
config["observation_filter"], self.env.observation_space.shape)
|
||||
self.sampler = SyncSampler(
|
||||
self.env, self.common_policy, obs_filter,
|
||||
self.config["horizon"], self.config["horizon"])
|
||||
self.reward_filter = MeanStdFilter((), clip=5.0)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
def load_data(self, trajectories, full_trace):
|
||||
if self.config["use_gae"]:
|
||||
return self.par_opt.load_data(
|
||||
self.sess,
|
||||
[trajectories["observations"],
|
||||
trajectories["td_lambda_returns"],
|
||||
trajectories["advantages"],
|
||||
trajectories["actions"].squeeze(),
|
||||
trajectories["logprobs"],
|
||||
trajectories["vf_preds"]],
|
||||
full_trace=full_trace)
|
||||
else:
|
||||
dummy = np.zeros((trajectories["observations"].shape[0],))
|
||||
return self.par_opt.load_data(
|
||||
self.sess,
|
||||
[trajectories["observations"],
|
||||
dummy,
|
||||
trajectories["returns"],
|
||||
trajectories["actions"].squeeze(),
|
||||
trajectories["logprobs"],
|
||||
dummy],
|
||||
full_trace=full_trace)
|
||||
use_gae = self.config["use_gae"]
|
||||
dummy = np.zeros_like(trajectories["advantages"])
|
||||
return self.par_opt.load_data(
|
||||
self.sess,
|
||||
[trajectories["observations"],
|
||||
trajectories["value_targets"] if use_gae else dummy,
|
||||
trajectories["advantages"],
|
||||
trajectories["actions"].squeeze(),
|
||||
trajectories["logprobs"],
|
||||
trajectories["vf_preds"] if use_gae else dummy],
|
||||
full_trace=full_trace)
|
||||
|
||||
def run_sgd_minibatch(
|
||||
self, batch_index, kl_coeff, full_trace, file_writer):
|
||||
@@ -177,12 +172,14 @@ class Runner(object):
|
||||
file_writer=file_writer if full_trace else None)
|
||||
|
||||
def save(self):
|
||||
return pickle.dumps([self.observation_filter, self.reward_filter])
|
||||
obs_filter = self.sampler.get_obs_filter()
|
||||
return pickle.dumps([obs_filter, self.reward_filter])
|
||||
|
||||
def restore(self, objs):
|
||||
objs = pickle.loads(objs)
|
||||
self.observation_filter = objs[0]
|
||||
self.reward_filter = objs[1]
|
||||
obs_filter = objs[0]
|
||||
rew_filter = objs[1]
|
||||
self.update_filters(obs_filter, rew_filter)
|
||||
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
@@ -190,29 +187,22 @@ class Runner(object):
|
||||
def load_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
|
||||
def compute_trajectory(self, gamma, lam, horizon):
|
||||
"""Compute a single rollout on the agent and return."""
|
||||
trajectory = rollouts(
|
||||
self.common_policy,
|
||||
self.env, horizon, self.observation_filter, self.reward_filter)
|
||||
if self.config["use_gae"]:
|
||||
add_advantage_values(trajectory, gamma, lam, self.reward_filter)
|
||||
else:
|
||||
add_return_values(trajectory, gamma, self.reward_filter)
|
||||
return trajectory
|
||||
def update_filters(self, obs_filter=None, rew_filter=None):
|
||||
if rew_filter:
|
||||
# No special handling required since outside of threaded code
|
||||
self.reward_filter = rew_filter.copy()
|
||||
if obs_filter:
|
||||
self.sampler.update_obs_filter(obs_filter)
|
||||
|
||||
def compute_steps(
|
||||
self, gamma, lam, horizon, min_steps_per_task,
|
||||
observation_filter, reward_filter):
|
||||
def get_obs_filter(self):
|
||||
return self.sampler.get_obs_filter()
|
||||
|
||||
def compute_steps(self, config, obs_filter, rew_filter):
|
||||
"""Compute multiple rollouts and concatenate the results.
|
||||
|
||||
Args:
|
||||
gamma: MDP discount factor
|
||||
lam: GAE(lambda) parameter
|
||||
horizon: Number of steps after which a rollout gets cut
|
||||
min_steps_per_task: Lower bound on the number of states to be
|
||||
collected.
|
||||
observation_filter: Function that is applied to each of the
|
||||
config: Configuration parameters
|
||||
obs_filter: Function that is applied to each of the
|
||||
observations.
|
||||
rew_filter: Function that is applied to each of the rewards.
|
||||
|
||||
@@ -221,38 +211,26 @@ class Runner(object):
|
||||
total_rewards: Total rewards of the trajectories.
|
||||
trajectory_lengths: Lengths of the trajectories.
|
||||
"""
|
||||
|
||||
# Update our local filters
|
||||
self.observation_filter = observation_filter.copy()
|
||||
self.reward_filter = reward_filter.copy()
|
||||
|
||||
num_steps_so_far = 0
|
||||
trajectories = []
|
||||
total_rewards = []
|
||||
trajectory_lengths = []
|
||||
while True:
|
||||
trajectory = self.compute_trajectory(gamma, lam, horizon)
|
||||
total_rewards.append(
|
||||
trajectory["raw_rewards"].sum(axis=0).mean())
|
||||
trajectory_lengths.append(
|
||||
np.logical_not(trajectory["dones"]).sum(axis=0).mean())
|
||||
trajectory = flatten(trajectory)
|
||||
not_done = np.logical_not(trajectory["dones"])
|
||||
# Filtering out states that are done. We do this because
|
||||
# trajectories are batched and cut only if all the trajectories
|
||||
# in the batch terminated, so we can potentially get rid of
|
||||
# some of the states here.
|
||||
trajectory = {key: val[not_done]
|
||||
for key, val in trajectory.items()}
|
||||
num_steps_so_far += trajectory["raw_rewards"].shape[0]
|
||||
self.update_filters(obs_filter, rew_filter)
|
||||
|
||||
while num_steps_so_far < config["min_steps_per_task"]:
|
||||
rollout = self.sampler.get_data()
|
||||
trajectory = process_rollout(
|
||||
rollout, self.reward_filter, config["gamma"],
|
||||
config["lambda"], use_gae=config["use_gae"])
|
||||
num_steps_so_far += trajectory["rewards"].shape[0]
|
||||
trajectories.append(trajectory)
|
||||
if num_steps_so_far >= min_steps_per_task:
|
||||
break
|
||||
metrics = self.sampler.get_metrics()
|
||||
total_rewards, trajectory_lengths = zip(*[
|
||||
(c.episode_reward, c.episode_length) for c in metrics])
|
||||
updated_obs_filter = self.sampler.get_obs_filter(flush=True)
|
||||
return (
|
||||
concatenate(trajectories),
|
||||
total_rewards,
|
||||
trajectory_lengths,
|
||||
self.observation_filter,
|
||||
updated_obs_filter,
|
||||
self.reward_filter)
|
||||
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ def concatenate(weights_list):
|
||||
|
||||
|
||||
def shuffle(trajectory):
|
||||
permutation = np.random.permutation(trajectory["dones"].shape[0])
|
||||
permutation = np.random.permutation(trajectory["actions"].shape[0])
|
||||
for key, val in trajectory.items():
|
||||
trajectory[key] = val[permutation]
|
||||
return trajectory
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseFilter(object):
|
||||
class Filter(object):
|
||||
"""Processes input, possibly statefully."""
|
||||
|
||||
def update(self, other, *args, **kwargs):
|
||||
@@ -24,7 +24,7 @@ class BaseFilter(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NoFilter(BaseFilter):
|
||||
class NoFilter(Filter):
|
||||
def __init__(self, *args):
|
||||
pass
|
||||
|
||||
@@ -107,7 +107,7 @@ class RunningStat(object):
|
||||
return self._M.shape
|
||||
|
||||
|
||||
class MeanStdFilter(object):
|
||||
class MeanStdFilter(Filter):
|
||||
"""Keeps track of a running mean for seen states"""
|
||||
|
||||
def __init__(self, shape, demean=True, destd=True, clip=10.0):
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
|
||||
def discount(x, gamma):
    """Return the discounted cumulative sum of `x` along axis 0.

    Entry ``t`` of the result equals ``x[t] + gamma * result[t + 1]``, with
    the last entry equal to ``x[-1]``.

    Args:
        x: Sequence of values, indexed by time step along axis 0.
        gamma: Discount factor applied per step.

    Returns:
        Array with the same leading length as `x`.
    """
    reversed_x = x[::-1]
    # With b=[1], a=[1, -gamma] the filter computes
    # y[n] = x[n] + gamma * y[n - 1]; running it over the time-reversed
    # input accumulates future values into each step.
    accumulated = scipy.signal.lfilter([1], [1, -gamma], reversed_x, axis=0)
    return accumulated[::-1]
|
||||
|
||||
|
||||
def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True):
    """Turn a rollout into stacked arrays with advantages attached.

    Args:
        rollout: Object exposing `data` (mapping of field name to a list of
            per-step values; must contain "actions" and "rewards") and
            `last_r` (value appended after the final step for bootstrapping).
        reward_filter: Callable applied to each entry of the advantages.
        gamma: Discount factor.
        lambda_: GAE lambda; only used when `use_gae` is True.
        use_gae: If True, compute GAE advantages and value targets from the
            "vf_preds" field; otherwise use discounted returns as advantages.

    Returns:
        dict mapping every field of `rollout.data` to a stacked array, plus
        "advantages" (and "value_targets" when `use_gae` is True).
    """
    data = rollout.data
    expected_len = len(data["actions"])
    traj = {key: np.stack(data[key]) for key in data}

    if use_gae:
        assert "vf_preds" in data, "Values not found!"
        # Value predictions for every step followed by the bootstrap value.
        vpred_t = np.stack(
            data["vf_preds"] + [np.array(rollout.last_r)]).squeeze()
        next_values = vpred_t[1:]
        current_values = vpred_t[:-1]
        delta_t = traj["rewards"] + gamma * next_values - current_values
        # This formula for the advantage comes from
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        traj["advantages"] = discount(delta_t, gamma * lambda_)
        traj["value_targets"] = traj["advantages"] + traj["vf_preds"]
    else:
        # Rewards followed by the bootstrap value; the trailing bootstrap
        # entry is dropped again after discounting.
        rewards_plus_v = np.stack(
            data["rewards"] + [np.array(rollout.last_r)]).squeeze()
        traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]

    # The filter is called once per entry, in order, because it may be
    # stateful (presumably a running-statistics filter; confirm with caller).
    advantages = traj["advantages"]
    for idx in range(advantages.shape[0]):
        advantages[idx] = reward_filter(advantages[idx])

    lengths_match = all(
        arr.shape[0] == expected_len for arr in traj.values())
    assert lengths_match, "Rollout stacked incorrectly!"
    return traj
|
||||
@@ -19,9 +19,14 @@ class PartialRollout(object):
|
||||
|
||||
We run our agent, and process its experience once it has processed enough
|
||||
steps.
|
||||
|
||||
Attributes:
|
||||
data (dict): Stores rollout data. All numpy arrays other than
|
||||
`observations` and `features` will be squeezed.
|
||||
last_r (float): Value of next state. Used for bootstrapping.
|
||||
"""
|
||||
|
||||
fields = ["state", "action", "reward", "terminal", "features"]
|
||||
fields = ["observations", "actions", "rewards", "terminal", "features"]
|
||||
|
||||
def __init__(self, extra_fields=None):
|
||||
"""Initializers internals. Maintains a `last_r` field
|
||||
@@ -72,23 +77,40 @@ class SyncSampler(object):
|
||||
thread."""
|
||||
async = False
|
||||
|
||||
def __init__(self, env, policy, num_local_steps, obs_filter):
|
||||
def __init__(self, env, policy, obs_filter,
|
||||
num_local_steps, horizon=None):
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self.obs_filter = obs_filter
|
||||
self._obs_filter = obs_filter
|
||||
self.rollout_provider = _env_runner(
|
||||
self.env, self.policy, self.num_local_steps, self.obs_filter)
|
||||
self.env, self.policy, self.num_local_steps, self.horizon,
|
||||
self._obs_filter)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def update_obs_filter(self, other_filter):
|
||||
"""Method to update observation filter with copy from driver.
|
||||
Since this class is synchronous, updating the observation
|
||||
filter should be a straightforward replacement
|
||||
def get_obs_filter(self, flush=False):
|
||||
"""Gets a snapshot of the current observation filter. The snapshot
|
||||
by default does not clear the accumulated delta.
|
||||
|
||||
Args:
|
||||
other_filter: Another filter (of same type)."""
|
||||
self.obs_filter = other_filter.copy()
|
||||
flush (bool): If True, accumulated state in buffer is cleared.
|
||||
|
||||
Returns:
|
||||
snapshot (Filter): Copy of observation filter.
|
||||
"""
|
||||
snapshot = self._obs_filter.copy()
|
||||
if flush and hasattr(self._obs_filter, "clear_buffer"):
|
||||
self._obs_filter.clear_buffer()
|
||||
return snapshot
|
||||
|
||||
def update_obs_filter(self, other_filter):
|
||||
"""Updates observation filter with copy from driver.
|
||||
|
||||
Args:
|
||||
other_filter: Another filter (of same type).
|
||||
"""
|
||||
self._obs_filter.sync(other_filter)
|
||||
|
||||
def get_data(self):
|
||||
while True:
|
||||
@@ -96,10 +118,7 @@ class SyncSampler(object):
|
||||
if isinstance(item, CompletedRollout):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
obsf_snapshot = self.obs_filter.copy()
|
||||
if hasattr(self.obs_filter, "clear_buffer"):
|
||||
self.obs_filter.clear_buffer()
|
||||
return item, obsf_snapshot
|
||||
return item
|
||||
|
||||
def get_metrics(self):
|
||||
completed = []
|
||||
@@ -118,15 +137,17 @@ class AsyncSampler(threading.Thread):
|
||||
accumulate and the gradient can be calculated on up to 5 batches."""
|
||||
async = True
|
||||
|
||||
def __init__(self, env, policy, num_local_steps, obs_filter):
|
||||
def __init__(self, env, policy, obs_filter,
|
||||
num_local_steps, horizon=None):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self.obs_filter = obs_filter
|
||||
self.obs_f_lock = threading.Lock()
|
||||
self._obs_filter = obs_filter
|
||||
self._obs_f_lock = threading.Lock()
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
@@ -139,25 +160,26 @@ class AsyncSampler(threading.Thread):
|
||||
def update_obs_filter(self, other_filter):
|
||||
"""Method to update observation filter with copy from driver.
|
||||
Applies delta since last `clear_buffer` to given new filter,
|
||||
and syncs current filter to new filter. `self.obs_filter` is
|
||||
and syncs current filter to new filter. `self._obs_filter` is
|
||||
kept in place due to the `lock_wrap`.
|
||||
|
||||
Args:
|
||||
other_filter: Another filter (of same type)."""
|
||||
with self.obs_f_lock:
|
||||
with self._obs_f_lock:
|
||||
new_filter = other_filter.copy()
|
||||
# Applies delta to filter, including buffer
|
||||
new_filter.update(self.obs_filter, copy_buffer=True)
|
||||
new_filter.update(self._obs_filter, copy_buffer=True)
|
||||
# copies everything back into original filter - needed
|
||||
# due to `lock_wrap`
|
||||
self.obs_filter.sync(new_filter)
|
||||
self._obs_filter.sync(new_filter)
|
||||
|
||||
def _run(self):
|
||||
"""Sets observation filter into an atomic region and starts
|
||||
other thread for running."""
|
||||
safe_obs_filter = lock_wrap(self.obs_filter, self.obs_f_lock)
|
||||
safe_obs_filter = lock_wrap(self._obs_filter, self._obs_f_lock)
|
||||
rollout_provider = _env_runner(
|
||||
self.env, self.policy, self.num_local_steps, safe_obs_filter)
|
||||
self.env, self.policy, self.num_local_steps,
|
||||
self.horizon, safe_obs_filter)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -168,24 +190,32 @@ class AsyncSampler(threading.Thread):
|
||||
else:
|
||||
self.queue.put(item, timeout=600.0)
|
||||
|
||||
def get_data(self):
|
||||
"""Gets currently accumulated data and a snapshot of the current
|
||||
observation filter. The snapshot also clears the accumulated delta.
|
||||
Note that in between getting the rollout and acquiring the lock,
|
||||
def get_obs_filter(self, flush=False):
|
||||
"""Gets a snapshot of the current observation filter. The snapshot
|
||||
also clears the accumulated delta. Note that in between getting
|
||||
the rollout from self.queue and acquiring the lock here,
|
||||
the other thread can run, resulting in slight discrepancies
|
||||
between data retrieved and filter statistics.
|
||||
|
||||
Returns:
|
||||
rollout: trajectory data (unprocessed)
|
||||
obsf_snapshot: snapshot of observation filter.
|
||||
snapshot (Filter): Copy of observation filter.
|
||||
"""
|
||||
|
||||
with self._obs_f_lock:
|
||||
snapshot = self._obs_filter.copy()
|
||||
if hasattr(self._obs_filter, "clear_buffer"):
|
||||
self._obs_filter.clear_buffer()
|
||||
return snapshot
|
||||
|
||||
def get_data(self):
|
||||
"""Gets currently accumulated data.
|
||||
|
||||
Returns:
|
||||
rollout (PartialRollout): trajectory data (unprocessed)
|
||||
"""
|
||||
|
||||
rollout = self._pull_batch_from_queue()
|
||||
with self.obs_f_lock:
|
||||
obsf_snapshot = self.obs_filter.copy()
|
||||
if hasattr(self.obs_filter, "clear_buffer"):
|
||||
self.obs_filter.clear_buffer()
|
||||
return rollout, obsf_snapshot
|
||||
return rollout
|
||||
|
||||
def _pull_batch_from_queue(self):
|
||||
"""Take a rollout from the queue of the thread runner."""
|
||||
@@ -212,7 +242,7 @@ class AsyncSampler(threading.Thread):
|
||||
return completed
|
||||
|
||||
|
||||
def _env_runner(env, policy, num_local_steps, obs_filter):
|
||||
def _env_runner(env, policy, num_local_steps, horizon, obs_filter):
|
||||
"""This implements the logic of the thread runner.
|
||||
|
||||
It continually runs the policy, and as long as the rollout exceeds a
|
||||
@@ -231,10 +261,15 @@ def _env_runner(env, policy, num_local_steps, obs_filter):
|
||||
rollout (PartialRollout): Object containing state, action, reward,
|
||||
terminal condition, and other fields as dictated by `policy`.
|
||||
"""
|
||||
last_state = obs_filter(env.reset())
|
||||
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
last_features = features = policy.get_initial_features()
|
||||
last_observation = obs_filter(env.reset())
|
||||
horizon = horizon if horizon else env.spec.tags.get(
|
||||
"wrapper_config.TimeLimit.max_episode_steps")
|
||||
assert horizon > 0
|
||||
if hasattr(policy, "get_initial_features"):
|
||||
last_features = policy.get_initial_features()
|
||||
else:
|
||||
last_features = []
|
||||
features = last_features
|
||||
length = 0
|
||||
rewards = 0
|
||||
rollout_number = 0
|
||||
@@ -244,44 +279,47 @@ def _env_runner(env, policy, num_local_steps, obs_filter):
|
||||
rollout = PartialRollout(extra_fields=policy.other_output)
|
||||
|
||||
for _ in range(num_local_steps):
|
||||
action, pi_info = policy.compute_action(last_state, *last_features)
|
||||
action, pi_info = policy.compute(last_observation, *last_features)
|
||||
if policy.is_recurrent:
|
||||
features = pi_info["features"]
|
||||
del pi_info["features"]
|
||||
state, reward, terminal, info = env.step(action)
|
||||
state = obs_filter(state)
|
||||
observation, reward, terminal, info = env.step(action)
|
||||
observation = obs_filter(observation)
|
||||
|
||||
length += 1
|
||||
rewards += reward
|
||||
if length >= timestep_limit:
|
||||
if length >= horizon:
|
||||
terminal = True
|
||||
|
||||
# Collect the experience.
|
||||
rollout.add(state=last_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
rollout.add(observations=last_observation,
|
||||
actions=action,
|
||||
rewards=reward,
|
||||
terminal=terminal,
|
||||
features=last_features,
|
||||
**pi_info)
|
||||
|
||||
last_state = state
|
||||
last_observation = observation
|
||||
last_features = features
|
||||
|
||||
if terminal:
|
||||
terminal_end = True
|
||||
yield CompletedRollout(length, rewards)
|
||||
|
||||
if (length >= timestep_limit or
|
||||
if (length >= horizon or
|
||||
not env.metadata.get("semantics.autoreset")):
|
||||
last_state = obs_filter(env.reset())
|
||||
last_features = policy.get_initial_features()
|
||||
last_observation = obs_filter(env.reset())
|
||||
if hasattr(policy, "get_initial_features"):
|
||||
last_features = policy.get_initial_features()
|
||||
else:
|
||||
last_features = []
|
||||
rollout_number += 1
|
||||
length = 0
|
||||
rewards = 0
|
||||
break
|
||||
|
||||
if not terminal_end:
|
||||
rollout.last_r = policy.value(last_state, *last_features)
|
||||
rollout.last_r = policy.value(last_observation, *last_features)
|
||||
|
||||
# Once we have enough experience, yield it, and have the ThreadRunner
|
||||
# place it on a queue.
|
||||
|
||||
Reference in New Issue
Block a user