[rllib] Generalizing A3C Sampling Classes (#1250)

This commit is contained in:
Richard Liaw
2017-11-30 00:22:25 -08:00
committed by GitHub
parent 7db07acc4f
commit 483dee2ff3
13 changed files with 487 additions and 274 deletions
+35 -21
View File
@@ -10,17 +10,24 @@ import ray
from ray.rllib.agent import Agent
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.a3c.runner import RemoteRunner
from ray.rllib.a3c.shared_model import SharedModel
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.tune.result import TrainingResult
DEFAULT_CONFIG = {
"num_workers": 4,
"num_batches_per_iteration": 100,
# Size of rollout batch
"batch_size": 10,
"use_lstm": True,
"use_pytorch": False,
# Which observation filter to apply to the observation
"observation_filter": "NoFilter",
# Which reward filter to apply to the reward
"reward_filter": "NoFilter",
"model": {"grayscale": True,
"zero_mean": False,
"dim": 42,
@@ -34,38 +41,43 @@ class A3CAgent(Agent):
def _init(self):
self.env = create_and_wrap(self.env_creator, self.config["model"])
if self.config["use_lstm"]:
policy_cls = SharedModelLSTM
elif self.config["use_pytorch"]:
from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
policy_cls = SharedTorchPolicy
else:
policy_cls = SharedModel
policy_cls = get_policy_cls(self.config)
self.policy = policy_cls(
self.env.observation_space.shape, self.env.action_space)
self.obs_filter = get_filter(
self.config["observation_filter"],
self.env.observation_space.shape)
self.rew_filter = get_filter(self.config["reward_filter"], ())
self.agents = [
RemoteRunner.remote(self.env_creator, policy_cls, i,
self.config["batch_size"],
self.config["model"], self.logdir)
RemoteRunner.remote(self.env_creator, self.config, self.logdir)
for i in range(self.config["num_workers"])]
self.parameters = self.policy.get_weights()
def _train(self):
gradient_list = [
agent.compute_gradient.remote(self.parameters)
for agent in self.agents]
remote_params = ray.put(self.parameters)
ray.get([agent.set_weights.remote(remote_params)
for agent in self.agents])
gradient_list = {agent.compute_gradient.remote(): agent
for agent in self.agents}
max_batches = self.config["num_batches_per_iteration"]
batches_so_far = len(gradient_list)
while gradient_list:
done_id, gradient_list = ray.wait(gradient_list)
gradient, info = ray.get(done_id)[0]
[done_id], _ = ray.wait(list(gradient_list))
gradient, info = ray.get(done_id)
agent = gradient_list.pop(done_id)
self.obs_filter.update(info["obs_filter"])
self.rew_filter.update(info["rew_filter"])
self.policy.apply_gradients(gradient)
self.parameters = self.policy.get_weights()
if batches_so_far < max_batches:
batches_so_far += 1
gradient_list.extend(
[self.agents[info["id"]].compute_gradient.remote(
self.parameters)])
agent.update_filters.remote(
obs_filter=self.obs_filter,
rew_filter=self.rew_filter)
agent.set_weights.remote(self.parameters)
gradient_list[agent.compute_gradient.remote()] = agent
res = self._fetch_metrics_from_workers()
return res
@@ -95,13 +107,15 @@ class A3CAgent(Agent):
def _save(self):
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
objects = [self.parameters]
objects = [self.parameters, self.obs_filter, self.rew_filter]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path
def _restore(self, checkpoint_path):
    """Restore weights and filter state from a checkpoint file.

    The checkpoint is read as a pickled sequence laid out as
    ``[parameters, obs_filter, rew_filter]``; the restored parameters are
    also pushed into the local policy.

    Args:
        checkpoint_path: Path of the pickle file to load.
    """
    # NOTE: unpickling can execute arbitrary code - only load checkpoints
    # that come from a trusted source.
    # A context manager guarantees the file handle is closed, instead of
    # leaving it to the garbage collector.
    with open(checkpoint_path, "rb") as checkpoint_file:
        objects = pickle.load(checkpoint_file)
    self.parameters = objects[0]
    self.obs_filter = objects[1]
    self.rew_filter = objects[2]
    self.policy.set_weights(self.parameters)
def compute_action(self, observation):
+26 -12
View File
@@ -11,27 +11,41 @@ def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
def process_rollout(rollout, gamma, lambda_=1.0):
"""Given a rollout, compute its returns and the advantage."""
batch_si = np.asarray(rollout.states)
batch_a = np.asarray(rollout.actions)
rewards = np.asarray(rollout.rewards)
vpred_t = np.asarray(rollout.values + [rollout.r])
def process_rollout(rollout, reward_filter, gamma, lambda_=1.0):
"""Given a rollout, compute its returns and the advantage.
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
TODO(rliaw): generalize this"""
batch_si = np.asarray(rollout.data["state"])
batch_a = np.asarray(rollout.data["action"])
rewards = np.asarray(rollout.data["reward"])
vpred_t = np.asarray(rollout.data["value"] + [rollout.last_r])
rewards_plus_v = np.asarray(rollout.data["reward"] + [rollout.last_r])
batch_r = discount(rewards_plus_v, gamma)[:-1]
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
# This formula for the advantage comes from "Generalized Advantage Estimation":
# https://arxiv.org/abs/1506.02438
batch_adv = discount(delta_t, gamma * lambda_)
for i in range(batch_adv.shape[0]):
batch_adv[i] = reward_filter(batch_adv[i])
features = rollout.features[0]
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
features = rollout.data["features"][0]
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.is_terminal(),
features)
def get_policy_cls(config):
    """Pick the policy class selected by the agent configuration.

    `use_lstm` takes precedence over `use_pytorch`; when neither flag is
    set the plain shared model is used. Each implementation is imported
    lazily, only when it is actually selected.

    Args:
        config: Mapping with the `use_lstm` and `use_pytorch` flags.

    Returns:
        The policy class (not an instance).
    """
    if config["use_lstm"]:
        from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
        return SharedModelLSTM
    if config["use_pytorch"]:
        from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
        return SharedTorchPolicy
    from ray.rllib.a3c.shared_model import SharedModel
    return SharedModel
Batch = namedtuple(
"Batch", ["si", "a", "adv", "r", "terminal", "features"])
CompletedRollout = namedtuple(
"CompletedRollout", ["episode_length", "episode_reward"])
+46 -56
View File
@@ -2,82 +2,72 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.a3c.envs import create_and_wrap
import tensorflow as tf
import six.moves.queue as queue
from ray.rllib.a3c.runner_thread import RunnerThread
from ray.rllib.a3c.common import process_rollout
from ray.rllib.a3c.tfpolicy import TFPolicy
import ray
import os
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.a3c.common import process_rollout, get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.sampler import AsyncSampler
class Runner(object):
"""Actor object to start running simulation on workers.
The gradient computation is also executed from this object.
Attributes:
policy: Copy of graph used for policy. Used by sampler and gradients.
rew_filter: Reward filter used in rollout post-processing.
sampler: Component for interacting with environment and generating
rollouts.
logdir: Directory for logging.
"""
def __init__(self, env_creator, policy_cls, actor_id, batch_size,
preprocess_config, logdir):
env = create_and_wrap(env_creator, preprocess_config)
self.id = actor_id
def __init__(self, env_creator, config, logdir):
self.env = env = create_and_wrap(env_creator, config["model"])
policy_cls = get_policy_cls(config)
# TODO(rliaw): should change this to be just env.observation_space
self.policy = policy_cls(env.observation_space.shape, env.action_space)
self.runner = RunnerThread(env, self.policy, batch_size)
self.env = env
self.logdir = logdir
self.start()
obs_filter = get_filter(
config["observation_filter"], env.observation_space.shape)
self.rew_filter = get_filter(config["reward_filter"], ())
def pull_batch_from_queue(self):
"""Take a rollout from the queue of the thread runner."""
rollout = self.runner.queue.get(timeout=600.0)
if isinstance(rollout, BaseException):
raise rollout
while not rollout.terminal:
try:
part = self.runner.queue.get_nowait()
if isinstance(part, BaseException):
raise rollout
rollout.extend(part)
except queue.Empty:
break
return rollout
self.sampler = AsyncSampler(env, self.policy, config["batch_size"],
obs_filter)
self.logdir = logdir
def get_data(self):
"""
Returns:
trajectory: trajectory information
obs_filter: Current state of observation filter
rew_filter: Current state of reward filter"""
rollout, obs_filter = self.sampler.get_data()
return rollout, obs_filter, self.rew_filter
def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
Calling this clears the queue of completed rollout metrics.
"""
completed = []
while True:
try:
completed.append(self.runner.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
return self.sampler.get_metrics()
def start(self):
summary_writer = tf.summary.FileWriter(
os.path.join(self.logdir, "agent_%d" % self.id))
self.summary_writer = summary_writer
if isinstance(self.policy, TFPolicy):
self.runner.start_runner(self.policy.sess, summary_writer)
else:
self.runner.start_runner(tf.Session(), summary_writer)
def compute_gradient(self, params):
self.policy.set_weights(params)
rollout = self.pull_batch_from_queue()
batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
def compute_gradient(self):
rollout, obsf_snapshot = self.sampler.get_data()
batch = process_rollout(
rollout, self.rew_filter, gamma=0.99, lambda_=1.0)
gradient, info = self.policy.compute_gradients(batch)
if "summary" in info:
self.summary_writer.add_summary(
tf.Summary.FromString(info['summary']),
self.policy.local_steps)
self.summary_writer.flush()
info = {"id": self.id,
"size": len(batch.a)}
info["obs_filter"] = obsf_snapshot
info["rew_filter"] = self.rew_filter
return gradient, info
def set_weights(self, params):
self.policy.set_weights(params)
def update_filters(self, obs_filter=None, rew_filter=None):
if rew_filter:
# No special handling required since outside of threaded code
self.rew_filter = rew_filter.copy()
if obs_filter:
self.sampler.update_obs_filter(obs_filter)
RemoteRunner = ray.remote(Runner)
-151
View File
@@ -1,151 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import six.moves.queue as queue
import threading
from ray.rllib.a3c.common import CompletedRollout
class PartialRollout(object):
    """Accumulates a slice of experience from a rollout.

    Steps are appended one at a time; slices can be chained with `extend`
    until a terminal step has been recorded.
    """

    def __init__(self):
        self.r = 0.0
        self.terminal = False
        self.states, self.actions = [], []
        self.rewards, self.values = [], []
        self.features = []

    def add(self, state, action, reward, value, terminal, features):
        """Record one step; `terminal` overwrites the rollout's flag."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.terminal = terminal
        self.features.append(features)

    def extend(self, other):
        """Append the steps of `other`, which must come after this slice.

        Also takes over `other`'s `r` and `terminal` values.
        """
        assert not self.terminal
        for name in ("states", "actions", "rewards", "values"):
            getattr(self, name).extend(getattr(other, name))
        self.r = other.r
        self.terminal = other.terminal
        self.features.extend(other.features)
class RunnerThread(threading.Thread):
    """Background thread that drives the environment with the policy."""

    def __init__(self, env, policy, num_local_steps, visualise=False):
        super(RunnerThread, self).__init__()
        self.daemon = True
        self.env = env
        self.policy = policy
        self.num_local_steps = num_local_steps
        self.visualise = visualise
        self.last_features = None
        # Populated by `start_runner` before the thread is started.
        self.sess = None
        self.summary_writer = None
        # Rollouts go through a bounded queue, metrics through an
        # unbounded one.
        self.queue = queue.Queue(5)
        self.metrics_queue = queue.Queue()

    def start_runner(self, sess, summary_writer):
        """Store the session and summary writer, then start the thread."""
        self.sess = sess
        self.summary_writer = summary_writer
        self.start()

    def run(self):
        """Thread entry point; forwards any failure to the rollout queue."""
        try:
            with self.sess.as_default():
                self._run()
        except BaseException as err:
            self.queue.put(err)
            raise err

    def _run(self):
        """Pump items from `env_runner` into the matching queue, forever."""
        provider = env_runner(
            self.env, self.policy, self.num_local_steps,
            self.summary_writer, self.visualise)
        while True:
            item = next(provider)
            if not isinstance(item, CompletedRollout):
                # The timeout variable exists because apparently, if one
                # worker dies, the other workers won't die with it, unless
                # the timeout is set to some large number. This is an
                # empirical observation.
                self.queue.put(item, timeout=600.0)
                continue
            self.metrics_queue.put(item)
def env_runner(env, policy, num_local_steps, summary_writer, render):
    """Generator implementing the sampling loop of the thread runner.

    Continually steps `env` with actions from `policy`. A rollout slice is
    yielded after at most `num_local_steps` steps, or earlier when the
    episode is reset; a `CompletedRollout` is yielded whenever an episode
    terminates.

    Args:
        env: Environment to interact with.
        policy: Policy providing `compute_action`, `value` and
            `get_initial_features`.
        num_local_steps: Maximum number of steps per yielded rollout.
        summary_writer: Writer that receives the per-step `info` values.
        render: If truthy, `env.render()` is called after every step.

    Yields:
        `CompletedRollout` or `PartialRollout` objects.
    """
    last_state = env.reset()
    # `.get` returns None when the env spec carries no TimeLimit tag.
    timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
                                       ".max_episode_steps")
    last_features = policy.get_initial_features()
    length = 0
    rewards = 0
    rollout_number = 0

    while True:
        terminal_end = False
        rollout = PartialRollout()

        for _ in range(num_local_steps):
            fetched = policy.compute_action(last_state, *last_features)
            action, value_, features = fetched[0], fetched[1], fetched[2:]
            state, reward, terminal, info = env.step(action)
            if render:
                env.render()

            length += 1
            rewards += reward
            # Only enforce the limit when one is configured: comparing an
            # int with None raises TypeError on Python 3 and is always
            # true on Python 2.
            hit_limit = (timestep_limit is not None and
                         length >= timestep_limit)
            if hit_limit:
                terminal = True

            # Collect the experience.
            rollout.add(last_state, action, reward, value_, terminal,
                        last_features)

            last_state = state
            last_features = features

            if info:
                summary = tf.Summary()
                for k, v in info.items():
                    summary.value.add(tag=k, simple_value=float(v))
                summary_writer.add_summary(summary, rollout_number)
                summary_writer.flush()

            if terminal:
                terminal_end = True
                yield CompletedRollout(length, rewards)

                if hit_limit or not env.metadata.get("semantics.autoreset"):
                    last_state = env.reset()
                    last_features = policy.get_initial_features()
                    rollout_number += 1
                    length = 0
                    rewards = 0
                    break

        if not terminal_end:
            # Bootstrap value for a rollout that was cut off mid-episode.
            rollout.r = policy.value(last_state, *last_features)

        # Once we have enough experience, yield it, and have the ThreadRunner
        # place it on a queue.
        yield rollout
+5 -1
View File
@@ -9,6 +9,10 @@ from ray.rllib.models.catalog import ModelCatalog
class SharedModel(TFPolicy):
other_output = ["value"]
is_recurrent = False
def __init__(self, ob_space, ac_space, **kwargs):
super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
@@ -52,7 +56,7 @@ class SharedModel(TFPolicy):
def compute_action(self, ob, *args):
action, vf = self.sess.run([self.sample, self.vf],
{self.x: [ob]})
return action[0], vf[0]
return action[0], {"value": vf[0]}
def value(self, ob, *args):
vf = self.sess.run(self.vf, {self.x: [ob]})
+11 -2
View File
@@ -10,6 +10,16 @@ from ray.rllib.models.lstm import LSTM
class SharedModelLSTM(TFPolicy):
"""
Attributes:
other_output (list): Other than `action`, the other return values from
`compute_action`.
is_recurrent (bool): True if is a recurrent network (requires features
to be tracked).
"""
other_output = ["value", "features"]
is_recurrent = True
def __init__(self, ob_space, ac_space, **kwargs):
super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
@@ -66,10 +76,9 @@ class SharedModelLSTM(TFPolicy):
action, vf, c, h = self.sess.run(
[self.sample, self.vf] + self.state_out,
{self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
return action[0], vf[0], c, h
return action[0], {"value": vf[0], "features": (c, h)}
def value(self, ob, c, h):
# process_rollout is very non-intuitive due to value being a float
vf = self.sess.run(self.vf, {self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})
+4 -1
View File
@@ -14,6 +14,9 @@ from ray.rllib.models.catalog import ModelCatalog
class SharedTorchPolicy(TorchPolicy):
"""Assumes nonrecurrent."""
other_output = ["value"]
is_recurrent = False
def __init__(self, ob_space, ac_space, **kwargs):
super(SharedTorchPolicy, self).__init__(
ob_space, ac_space, **kwargs)
@@ -30,7 +33,7 @@ class SharedTorchPolicy(TorchPolicy):
logits, values = self._model(ob)
samples = self._model.probs(logits).multinomial().squeeze()
values = values.squeeze(0)
return var_to_np(samples), var_to_np(values)
return var_to_np(samples), {"value": var_to_np(values)}
def compute_logits(self, ob, *args):
with self.lock:
+1 -1
View File
@@ -16,7 +16,7 @@ from ray.rllib.dqn import logger, models
from ray.rllib.dqn.common.wrappers import wrap_dqn
from ray.rllib.dqn.common.schedules import LinearSchedule
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from ray.rllib.ppo.filter import RunningStat
from ray.rllib.utils.filter import RunningStat
from ray.tune.result import TrainingResult
+3 -12
View File
@@ -11,8 +11,7 @@ import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
# TODO(rkn): Move these filters out of PPO to somewhere common.
from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
from ray.rllib.utils.filter import get_filter
def rollout(policy, env, timestep_limit=None, add_noise=False):
@@ -46,16 +45,8 @@ class GenericPolicy(object):
self.action_space = action_space
self.action_noise_std = action_noise_std
self.preprocessor = preprocessor
if observation_filter == "MeanStdFilter":
self.observation_filter = MeanStdFilter(
self.preprocessor.shape, clip=None)
elif observation_filter == "NoFilter":
self.observation_filter = NoFilter()
else:
raise Exception("Unknown observation_filter: " +
str("observation_filter"))
self.observation_filter = get_filter(
observation_filter, self.preprocessor.shape)
self.inputs = tf.placeholder(
tf.float32, [None] + list(self.preprocessor.shape))
+3 -9
View File
@@ -14,9 +14,9 @@ import ray
from ray.rllib.parallel import LocalSyncParallelOptimizer
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.filter import get_filter, MeanStdFilter
from ray.rllib.ppo.env import BatchedEnv
from ray.rllib.ppo.loss import ProximalPolicyLoss
from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
from ray.rllib.ppo.rollout import (
rollouts, add_return_values, add_advantage_values)
from ray.rllib.ppo.utils import flatten, concatenate
@@ -137,14 +137,8 @@ class Runner(object):
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss, self.sess)
if config["observation_filter"] == "MeanStdFilter":
self.observation_filter = MeanStdFilter(
self.preprocessor.shape, clip=None)
elif config["observation_filter"] == "NoFilter":
self.observation_filter = NoFilter()
else:
raise Exception("Unknown observation_filter: " +
str(config["observation_filter"]))
self.observation_filter = get_filter(
config["observation_filter"], self.preprocessor.shape)
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())
View File
@@ -5,19 +5,41 @@ from __future__ import print_function
import numpy as np
class NoFilter(object):
def __init__(self):
class BaseFilter(object):
    """Interface for input processors that may carry state."""

    def update(self, other, *args, **kwargs):
        """Fold the "new state" held by another filter into this one."""
        raise NotImplementedError

    def copy(self):
        """Build a new filter holding the same state as this one.

        Returns:
            copy (Filter): Copy of self
        """
        raise NotImplementedError

    def sync(self, other):
        """Overwrite this filter's state with all state of `other`."""
        raise NotImplementedError
class NoFilter(BaseFilter):
def __init__(self, *args):
pass
def __call__(self, x, update=True):
return np.asarray(x)
def update(self, other):
def update(self, other, *args, **kwargs):
pass
def copy(self):
return self
def sync(self, other):
pass
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
@@ -103,12 +125,22 @@ class MeanStdFilter(object):
def clear_buffer(self):
self.buffer = RunningStat(self.shape)
def update(self, other):
# `update` takes another filter and
# only applies the information from the buffer.
def update(self, other, copy_buffer=False):
"""Takes another filter and only applies the information from the
buffer.
Using notation `F(state, buffer)`
Given `Filter1(x1, y1)` and `Filter2(x2, yt)`,
`update` modifies `Filter1` to `Filter1(x1 + yt, y1)`
If `copy_buffer`, then `Filter1` is modified to
`Filter1(x1 + yt, yt)`.
"""
self.rs.update(other.buffer)
if copy_buffer:
self.buffer = other.buffer.copy()
def copy(self):
"""Returns a copy of Filter."""
other = MeanStdFilter(self.shape)
other.demean = self.demean
other.destd = self.destd
@@ -117,6 +149,20 @@ class MeanStdFilter(object):
other.buffer = self.buffer.copy()
return other
def sync(self, other):
    """Make this filter mirror `other` completely.

    Using notation `F(state, buffer)`
    Given `Filter1(x1, y1)` and `Filter2(x2, yt)`,
    `sync` modifies `Filter1` to `Filter1(x2, yt)`
    """
    assert other.shape == self.shape, "Shapes don't match!"
    # Plain settings are taken over as-is ...
    for attr in ("demean", "destd", "clip"):
        setattr(self, attr, getattr(other, attr))
    # ... while the running statistics are copied, so the two filters do
    # not end up sharing them.
    self.rs = other.rs.copy()
    self.buffer = other.buffer.copy()
def __call__(self, x, update=True):
x = np.asarray(x)
if update:
@@ -138,8 +184,19 @@ class MeanStdFilter(object):
return x
def __repr__(self):
return 'MeanStdFilter({}, {}, {}, {}, {})'.format(
self.shape, self.demean, self.destd, self.clip, self.rs)
return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format(
self.shape, self.demean, self.destd,
self.clip, self.rs, self.buffer)
def get_filter(filter_config, shape):
    """Instantiate the filter named by `filter_config`.

    Args:
        filter_config: Either "MeanStdFilter" or "NoFilter".
        shape: Shape handed to `MeanStdFilter`; not used for "NoFilter".

    Returns:
        A new filter instance.

    Raises:
        Exception: If `filter_config` is not one of the known names.
    """
    if filter_config == "MeanStdFilter":
        return MeanStdFilter(shape, clip=None)
    if filter_config == "NoFilter":
        return NoFilter()
    message = "Unknown observation_filter: " + str(filter_config)
    raise Exception(message)
def test_running_stat():
+288
View File
@@ -0,0 +1,288 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six.moves.queue as queue
import threading
from collections import namedtuple
def lock_wrap(func, lock):
    """Return a callable that runs `func` while holding `lock`.

    Args:
        func: Callable to guard.
        lock: Context manager entered around every call to `func`.

    Returns:
        A wrapper that forwards all arguments to `func` and returns its
        result.
    """
    def guarded(*args, **kwargs):
        with lock:
            result = func(*args, **kwargs)
            return result

    return guarded
class PartialRollout(object):
    """A piece of a complete rollout.

    We run our agent, and process its experience once it has processed enough
    steps.

    Attributes:
        fields: Names of the per-step quantities tracked by the rollout.
        data: Maps each field name to the list of recorded values.
        last_r: Bootstrap value used for advantage estimation when the
            rollout does not end in a terminal step.
    """

    fields = ["state", "action", "reward", "terminal", "features"]

    def __init__(self, extra_fields=None):
        """Initializes internals. Maintains a `last_r` field
        in support of partial rollouts, used in bootstrapping advantage
        estimation.

        Args:
            extra_fields: Optional additional field names for the rollout
                to keep track of.
        """
        # Work on a per-instance copy: extending the class-level list in
        # place would leak the extra fields into every other rollout and
        # make the list grow with each instantiation.
        self.fields = list(type(self).fields)
        if extra_fields:
            self.fields.extend(extra_fields)
        self.data = {k: [] for k in self.fields}
        self.last_r = 0.0

    def add(self, **kwargs):
        """Appends one step; every keyword must name a tracked field."""
        for k, v in kwargs.items():
            self.data[k].append(v)

    def extend(self, other_rollout):
        """Extends internal data structure. Assumes other_rollout contains
        data that occurred afterwards."""
        assert not self.is_terminal()
        assert all(k in other_rollout.fields for k in self.fields)
        for k, v in other_rollout.data.items():
            self.data[k].extend(v)
        self.last_r = other_rollout.last_r

    def is_terminal(self):
        """Check if terminal.

        Returns:
            terminal (bool): if rollout has terminated. A rollout without
                any recorded step has not terminated."""
        terminals = self.data["terminal"]
        return terminals[-1] if terminals else False
# Episode summary (total number of steps and summed reward) that the
# environment runner yields each time an episode terminates.
CompletedRollout = namedtuple(
"CompletedRollout", ["episode_length", "episode_reward"])
class SyncSampler(object):
    """This class interacts with the environment and tells it what to do.

    This class provides data on invocation, rather than on a separate
    thread: every `get_data` call advances the environment runner until it
    produces the next rollout.

    Attributes:
        is_async: Always False for this sampler. Also reachable under the
            legacy name via `getattr(sampler, "async")`.
    """

    is_async = False

    def __init__(self, env, policy, num_local_steps, obs_filter):
        self.num_local_steps = num_local_steps
        self.env = env
        self.policy = policy
        self.obs_filter = obs_filter
        # The generator keeps a reference to this very filter object, so
        # the filter must only ever be modified in place.
        self.rollout_provider = _env_runner(
            self.env, self.policy, self.num_local_steps, self.obs_filter)
        self.metrics_queue = queue.Queue()

    def update_obs_filter(self, other_filter):
        """Method to update observation filter with copy from driver.

        Since this class is synchronous, the state of `other_filter` can
        simply be copied over. The filter is synced in place rather than
        rebound, because the rollout generator holds on to the original
        object and would otherwise never see the update.

        Args:
            other_filter: Another filter (of same type)."""
        self.obs_filter.sync(other_filter)

    def get_data(self):
        """Runs the environment until the next rollout is available.

        Completed-episode metrics produced on the way are queued for
        `get_metrics`.

        Returns:
            rollout: trajectory data (unprocessed)
            obsf_snapshot: copy of the observation filter, taken before its
                buffer (if it has one) is cleared.
        """
        while True:
            item = next(self.rollout_provider)
            if isinstance(item, CompletedRollout):
                self.metrics_queue.put(item)
            else:
                obsf_snapshot = self.obs_filter.copy()
                if hasattr(self.obs_filter, "clear_buffer"):
                    self.obs_filter.clear_buffer()
                return item, obsf_snapshot

    def get_metrics(self):
        """Drains and returns all queued completed-episode metrics."""
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait())
            except queue.Empty:
                break
        return completed


# `async` is a reserved keyword since Python 3.7 and can no longer be
# written as a plain class attribute. Keep the legacy attribute name
# available for callers that still look it up.
setattr(SyncSampler, "async", SyncSampler.is_async)
class AsyncSampler(threading.Thread):
    """This class interacts with the environment and tells it what to do.

    Sampling runs on a background thread that is started on construction.
    Note that batch_size is only a unit of measure here. Batches can
    accumulate and the gradient can be calculated on up to 5 batches.

    Attributes:
        is_async: Always True for this sampler. Also reachable under the
            legacy name via `getattr(sampler, "async")`.
    """

    is_async = True

    def __init__(self, env, policy, num_local_steps, obs_filter):
        threading.Thread.__init__(self)
        # The sampling loop never returns; as a daemon (like the
        # RunnerThread this class replaces) it cannot keep the process
        # alive on shutdown.
        self.daemon = True
        self.queue = queue.Queue(5)
        self.metrics_queue = queue.Queue()
        self.num_local_steps = num_local_steps
        self.env = env
        self.policy = policy
        self.obs_filter = obs_filter
        self.obs_f_lock = threading.Lock()
        self.start()

    def run(self):
        """Thread entry point; forwards any failure to the rollout queue so
        that the consumer can re-raise it."""
        try:
            self._run()
        except BaseException as e:
            self.queue.put(e)
            raise e

    def update_obs_filter(self, other_filter):
        """Method to update observation filter with copy from driver.

        Applies delta since last `clear_buffer` to given new filter,
        and syncs current filter to new filter. `self.obs_filter` is
        kept in place due to the `lock_wrap`.

        Args:
            other_filter: Another filter (of same type)."""
        with self.obs_f_lock:
            new_filter = other_filter.copy()
            # Applies delta to filter, including buffer
            new_filter.update(self.obs_filter, copy_buffer=True)
            # copies everything back into original filter - needed
            # due to `lock_wrap`
            self.obs_filter.sync(new_filter)

    def _run(self):
        """Sets observation filter into an atomic region and runs the
        sampling loop."""
        safe_obs_filter = lock_wrap(self.obs_filter, self.obs_f_lock)
        rollout_provider = _env_runner(
            self.env, self.policy, self.num_local_steps, safe_obs_filter)
        while True:
            # The timeout variable exists because apparently, if one worker
            # dies, the other workers won't die with it, unless the timeout is
            # set to some large number. This is an empirical observation.
            item = next(rollout_provider)
            if isinstance(item, CompletedRollout):
                self.metrics_queue.put(item)
            else:
                self.queue.put(item, timeout=600.0)

    def get_data(self):
        """Gets currently accumulated data and a snapshot of the current
        observation filter. The snapshot also clears the accumulated delta.

        Note that in between getting the rollout and acquiring the lock,
        the other thread can run, resulting in slight discrepancies
        between data retrieved and filter statistics.

        Returns:
            rollout: trajectory data (unprocessed)
            obsf_snapshot: snapshot of observation filter.
        """
        rollout = self._pull_batch_from_queue()
        with self.obs_f_lock:
            obsf_snapshot = self.obs_filter.copy()
            if hasattr(self.obs_filter, "clear_buffer"):
                self.obs_filter.clear_buffer()
        return rollout, obsf_snapshot

    def _pull_batch_from_queue(self):
        """Take a rollout from the queue of the thread runner.

        Further pieces that are already queued are merged into the rollout
        until it is terminal or the queue is empty.

        Raises:
            BaseException: Whatever exception the sampling thread put on
                the queue.
        """
        rollout = self.queue.get(timeout=600.0)
        if isinstance(rollout, BaseException):
            raise rollout
        while not rollout.is_terminal():
            try:
                part = self.queue.get_nowait()
            except queue.Empty:
                break
            if isinstance(part, BaseException):
                # Re-raise the worker's exception itself; raising the
                # rollout object would only produce a TypeError.
                raise part
            rollout.extend(part)
        return rollout

    def get_metrics(self):
        """Drains and returns all queued completed-episode metrics."""
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait())
            except queue.Empty:
                break
        return completed


# `async` is a reserved keyword since Python 3.7 and can no longer be
# written as a plain class attribute. Keep the legacy attribute name
# available for callers that still look it up.
setattr(AsyncSampler, "async", AsyncSampler.is_async)
def _env_runner(env, policy, num_local_steps, obs_filter):
    """This implements the logic of the thread runner.

    It continually runs the policy and collects the experience into
    rollouts. A rollout is yielded once `num_local_steps` steps have been
    taken, or earlier when the episode is reset (the environment terminated
    or the configured `timestep_limit` was reached).

    Args:
        env: Environment generated by env_creator
        policy: Policy used to interact with environment. Also sets fields
            to be included in `PartialRollout`
        num_local_steps: Number of steps before `PartialRollout` is yielded.
        obs_filter: Filter used to process observations.

    Yields:
        rollout (PartialRollout): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
            A `CompletedRollout` is yielded as well whenever an episode
            terminates.
    """
    last_state = obs_filter(env.reset())
    # `.get` returns None when the env spec carries no TimeLimit tag.
    timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
                                       ".max_episode_steps")
    last_features = features = policy.get_initial_features()
    length = 0
    rewards = 0

    while True:
        terminal_end = False
        rollout = PartialRollout(extra_fields=policy.other_output)

        for _ in range(num_local_steps):
            action, pi_info = policy.compute_action(last_state, *last_features)
            if policy.is_recurrent:
                # The recurrent state feeds the next step; it is stored in
                # the rollout through the explicit `features` field below.
                features = pi_info.pop("features")
            state, reward, terminal, info = env.step(action)
            state = obs_filter(state)

            length += 1
            rewards += reward
            # Only enforce the limit when one is configured: comparing an
            # int with None raises TypeError on Python 3 and is always
            # true on Python 2.
            hit_limit = (timestep_limit is not None and
                         length >= timestep_limit)
            if hit_limit:
                terminal = True

            # Collect the experience.
            rollout.add(state=last_state,
                        action=action,
                        reward=reward,
                        terminal=terminal,
                        features=last_features,
                        **pi_info)

            last_state = state
            last_features = features

            if terminal:
                terminal_end = True
                yield CompletedRollout(length, rewards)

                if hit_limit or not env.metadata.get("semantics.autoreset"):
                    last_state = obs_filter(env.reset())
                    last_features = policy.get_initial_features()
                    length = 0
                    rewards = 0
                    break

        if not terminal_end:
            # Bootstrap value for a rollout that was cut off mid-episode.
            rollout.last_r = policy.value(last_state, *last_features)

        # Once we have enough experience, yield it, and have the ThreadRunner
        # place it on a queue.
        yield rollout