From 483dee2ff37bb43b8e4639521abddd761500b2e7 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 30 Nov 2017 00:22:25 -0800 Subject: [PATCH] [rllib] Generalizing A3C Sampling Classes (#1250) --- python/ray/rllib/a3c/a3c.py | 56 ++-- python/ray/rllib/a3c/common.py | 38 ++- python/ray/rllib/a3c/runner.py | 102 ++++--- python/ray/rllib/a3c/runner_thread.py | 151 ---------- python/ray/rllib/a3c/shared_model.py | 6 +- python/ray/rllib/a3c/shared_model_lstm.py | 13 +- python/ray/rllib/a3c/shared_torch_policy.py | 5 +- python/ray/rllib/dqn/dqn.py | 2 +- python/ray/rllib/es/policies.py | 15 +- python/ray/rllib/ppo/runner.py | 12 +- python/ray/rllib/utils/__init__.py | 0 python/ray/rllib/{ppo => utils}/filter.py | 73 ++++- python/ray/rllib/utils/sampler.py | 288 ++++++++++++++++++++ 13 files changed, 487 insertions(+), 274 deletions(-) delete mode 100644 python/ray/rllib/a3c/runner_thread.py create mode 100644 python/ray/rllib/utils/__init__.py rename python/ray/rllib/{ppo => utils}/filter.py (68%) create mode 100644 python/ray/rllib/utils/sampler.py diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 41e363769..ed036af00 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -10,17 +10,24 @@ import ray from ray.rllib.agent import Agent from ray.rllib.a3c.envs import create_and_wrap from ray.rllib.a3c.runner import RemoteRunner -from ray.rllib.a3c.shared_model import SharedModel -from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM +from ray.rllib.a3c.common import get_policy_cls +from ray.rllib.utils.filter import get_filter from ray.tune.result import TrainingResult DEFAULT_CONFIG = { "num_workers": 4, "num_batches_per_iteration": 100, + + # Size of rollout batch "batch_size": 10, "use_lstm": True, "use_pytorch": False, + # Which observation filter to apply to the observation + "observation_filter": "NoFilter", + # Which reward filter to apply to the reward + "reward_filter": "NoFilter", + "model": {"grayscale": True, "zero_mean": False, "dim": 42, @@ -34,38 +41,43 @@ class A3CAgent(Agent): def _init(self): self.env = create_and_wrap(self.env_creator, self.config["model"]) - if self.config["use_lstm"]: - policy_cls = SharedModelLSTM - elif self.config["use_pytorch"]: - from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy - policy_cls = SharedTorchPolicy - else: - policy_cls = SharedModel + policy_cls = get_policy_cls(self.config) self.policy = policy_cls( self.env.observation_space.shape, self.env.action_space) + self.obs_filter = get_filter( + self.config["observation_filter"], + self.env.observation_space.shape) + self.rew_filter = get_filter(self.config["reward_filter"], ()) self.agents = [ - RemoteRunner.remote(self.env_creator, policy_cls, i, - self.config["batch_size"], - self.config["model"], self.logdir) + RemoteRunner.remote(self.env_creator, self.config, self.logdir) for i in range(self.config["num_workers"])] self.parameters = self.policy.get_weights() def _train(self): - gradient_list = [ - agent.compute_gradient.remote(self.parameters) - for agent in self.agents] + remote_params = ray.put(self.parameters) + ray.get([agent.set_weights.remote(remote_params) + for agent in self.agents]) + + gradient_list = {agent.compute_gradient.remote(): agent + for agent in self.agents} max_batches = self.config["num_batches_per_iteration"] batches_so_far = len(gradient_list) while gradient_list: - done_id, gradient_list = ray.wait(gradient_list) - gradient, info = ray.get(done_id)[0] + [done_id], _ = ray.wait(list(gradient_list)) + gradient, info = ray.get(done_id) + agent = gradient_list.pop(done_id) + self.obs_filter.update(info["obs_filter"]) + self.rew_filter.update(info["rew_filter"]) self.policy.apply_gradients(gradient) self.parameters = self.policy.get_weights() + if batches_so_far < max_batches: batches_so_far += 1 - gradient_list.extend( - [self.agents[info["id"]].compute_gradient.remote( - self.parameters)]) + agent.update_filters.remote( + obs_filter=self.obs_filter, + rew_filter=self.rew_filter) + agent.set_weights.remote(self.parameters) + gradient_list[agent.compute_gradient.remote()] = agent res = self._fetch_metrics_from_workers() return res @@ -95,13 +107,15 @@ class A3CAgent(Agent): def _save(self): checkpoint_path = os.path.join( self.logdir, "checkpoint-{}".format(self.iteration)) - objects = [self.parameters] + objects = [self.parameters, self.obs_filter, self.rew_filter] pickle.dump(objects, open(checkpoint_path, "wb")) return checkpoint_path def _restore(self, checkpoint_path): objects = pickle.load(open(checkpoint_path, "rb")) self.parameters = objects[0] + self.obs_filter = objects[1] + self.rew_filter = objects[2] self.policy.set_weights(self.parameters) def compute_action(self, observation): diff --git a/python/ray/rllib/a3c/common.py b/python/ray/rllib/a3c/common.py index 17e6e7f9b..4c528a2da 100644 --- a/python/ray/rllib/a3c/common.py +++ b/python/ray/rllib/a3c/common.py @@ -11,27 +11,41 @@ def discount(x, gamma): return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] -def process_rollout(rollout, gamma, lambda_=1.0): - """Given a rollout, compute its returns and the advantage.""" - batch_si = np.asarray(rollout.states) - batch_a = np.asarray(rollout.actions) - rewards = np.asarray(rollout.rewards) - vpred_t = np.asarray(rollout.values + [rollout.r]) +def process_rollout(rollout, reward_filter, gamma, lambda_=1.0): + """Given a rollout, compute its returns and the advantage. - rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) + TODO(rliaw): generalize this""" + batch_si = np.asarray(rollout.data["state"]) + batch_a = np.asarray(rollout.data["action"]) + rewards = np.asarray(rollout.data["reward"]) + vpred_t = np.asarray(rollout.data["value"] + [rollout.last_r]) + + rewards_plus_v = np.asarray(rollout.data["reward"] + [rollout.last_r]) batch_r = discount(rewards_plus_v, gamma)[:-1] delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] # This formula for the advantage comes "Generalized Advantage Estimation": # https://arxiv.org/abs/1506.02438 batch_adv = discount(delta_t, gamma * lambda_) + for i in range(batch_adv.shape[0]): + batch_adv[i] = reward_filter(batch_adv[i]) - features = rollout.features[0] - return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, + features = rollout.data["features"][0] + return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.is_terminal(), features) +def get_policy_cls(config): + if config["use_lstm"]: + from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM + policy_cls = SharedModelLSTM + elif config["use_pytorch"]: + from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy + policy_cls = SharedTorchPolicy + else: + from ray.rllib.a3c.shared_model import SharedModel + policy_cls = SharedModel + return policy_cls + + Batch = namedtuple( "Batch", ["si", "a", "adv", "r", "terminal", "features"]) - -CompletedRollout = namedtuple( - "CompletedRollout", ["episode_length", "episode_reward"]) diff --git a/python/ray/rllib/a3c/runner.py b/python/ray/rllib/a3c/runner.py index c490ad42e..1ddc1c9b1 100644 --- a/python/ray/rllib/a3c/runner.py +++ b/python/ray/rllib/a3c/runner.py @@ -2,82 +2,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from ray.rllib.a3c.envs import create_and_wrap -import tensorflow as tf -import six.moves.queue as queue -from ray.rllib.a3c.runner_thread import RunnerThread -from ray.rllib.a3c.common import process_rollout -from ray.rllib.a3c.tfpolicy import TFPolicy import ray -import os +from ray.rllib.a3c.envs import create_and_wrap +from ray.rllib.a3c.common import process_rollout, get_policy_cls +from ray.rllib.utils.filter import get_filter +from ray.rllib.utils.sampler import AsyncSampler class Runner(object): """Actor object to start running simulation on workers. The gradient computation is also executed from this object. + + Attributes: + policy: Copy of graph used for policy. Used by sampler and gradients. + rew_filter: Reward filter used in rollout post-processing. + sampler: Component for interacting with environment and generating + rollouts. + logdir: Directory for logging. """ - def __init__(self, env_creator, policy_cls, actor_id, batch_size, - preprocess_config, logdir): - env = create_and_wrap(env_creator, preprocess_config) - self.id = actor_id + def __init__(self, env_creator, config, logdir): + self.env = env = create_and_wrap(env_creator, config["model"]) + policy_cls = get_policy_cls(config) # TODO(rliaw): should change this to be just env.observation_space self.policy = policy_cls(env.observation_space.shape, env.action_space) - self.runner = RunnerThread(env, self.policy, batch_size) - self.env = env - self.logdir = logdir - self.start() + obs_filter = get_filter( + config["observation_filter"], env.observation_space.shape) + self.rew_filter = get_filter(config["reward_filter"], ()) - def pull_batch_from_queue(self): - """Take a rollout from the queue of the thread runner.""" - rollout = self.runner.queue.get(timeout=600.0) - if isinstance(rollout, BaseException): - raise rollout - while not rollout.terminal: - try: - part = self.runner.queue.get_nowait() - if isinstance(part, BaseException): - raise rollout - rollout.extend(part) - except queue.Empty: - break - return rollout + self.sampler = AsyncSampler(env, self.policy, config["batch_size"], + obs_filter) + self.logdir = logdir + + def get_data(self): + """ + Returns: + trajectory: trajectory information + obs_filter: Current state of observation filter + rew_filter: Current state of reward filter""" + rollout, obs_filter = self.sampler.get_data() + return rollout, obs_filter, self.rew_filter def get_completed_rollout_metrics(self): """Returns metrics on previously completed rollouts. Calling this clears the queue of completed rollout metrics. """ - completed = [] - while True: - try: - completed.append(self.runner.metrics_queue.get_nowait()) - except queue.Empty: - break - return completed + return self.sampler.get_metrics() - def start(self): - summary_writer = tf.summary.FileWriter( - os.path.join(self.logdir, "agent_%d" % self.id)) - self.summary_writer = summary_writer - if isinstance(self.policy, TFPolicy): - self.runner.start_runner(self.policy.sess, summary_writer) - else: - self.runner.start_runner(tf.Session(), summary_writer) - - def compute_gradient(self, params): - self.policy.set_weights(params) - rollout = self.pull_batch_from_queue() - batch = process_rollout(rollout, gamma=0.99, lambda_=1.0) + def compute_gradient(self): + rollout, obsf_snapshot = self.sampler.get_data() + batch = process_rollout( + rollout, self.rew_filter, gamma=0.99, lambda_=1.0) gradient, info = self.policy.compute_gradients(batch) - if "summary" in info: - self.summary_writer.add_summary( - tf.Summary.FromString(info['summary']), - self.policy.local_steps) - self.summary_writer.flush() - info = {"id": self.id, - "size": len(batch.a)} + info["obs_filter"] = obsf_snapshot + info["rew_filter"] = self.rew_filter return gradient, info + def set_weights(self, params): + self.policy.set_weights(params) + + def update_filters(self, obs_filter=None, rew_filter=None): + if rew_filter: + # No special handling required since outside of threaded code + self.rew_filter = rew_filter.copy() + if obs_filter: + self.sampler.update_obs_filter(obs_filter) + RemoteRunner = ray.remote(Runner) diff --git a/python/ray/rllib/a3c/runner_thread.py b/python/ray/rllib/a3c/runner_thread.py deleted file mode 100644 index c82125434..000000000 --- a/python/ray/rllib/a3c/runner_thread.py +++ /dev/null @@ -1,151 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -import six.moves.queue as queue -import threading -from ray.rllib.a3c.common import CompletedRollout - - -class PartialRollout(object): - """A piece of a complete rollout. - - We run our agent, and process its experience once it has processed enough - steps. - """ - def __init__(self): - self.states = [] - self.actions = [] - self.rewards = [] - self.values = [] - self.r = 0.0 - self.terminal = False - self.features = [] - - def add(self, state, action, reward, value, terminal, features): - self.states += [state] - self.actions += [action] - self.rewards += [reward] - self.values += [value] - self.terminal = terminal - self.features += [features] - - def extend(self, other): - assert not self.terminal - self.states.extend(other.states) - self.actions.extend(other.actions) - self.rewards.extend(other.rewards) - self.values.extend(other.values) - self.r = other.r - self.terminal = other.terminal - self.features.extend(other.features) - - -class RunnerThread(threading.Thread): - """This thread interacts with the environment and tells it what to do.""" - def __init__(self, env, policy, num_local_steps, visualise=False): - threading.Thread.__init__(self) - self.queue = queue.Queue(5) - self.metrics_queue = queue.Queue() - self.num_local_steps = num_local_steps - self.env = env - self.last_features = None - self.policy = policy - self.daemon = True - self.sess = None - self.summary_writer = None - self.visualise = visualise - - def start_runner(self, sess, summary_writer): - self.sess = sess - self.summary_writer = summary_writer - self.start() - - def run(self): - try: - with self.sess.as_default(): - self._run() - except BaseException as e: - self.queue.put(e) - raise e - - def _run(self): - rollout_provider = env_runner( - self.env, self.policy, self.num_local_steps, - self.summary_writer, self.visualise) - while True: - # The timeout variable exists because apparently, if one worker - # dies, the other workers won't die with it, unless the timeout is - # set to some large number. This is an empirical observation. - item = next(rollout_provider) - if isinstance(item, CompletedRollout): - self.metrics_queue.put(item) - else: - self.queue.put(item, timeout=600.0) - - -def env_runner(env, policy, num_local_steps, summary_writer, render): - """This implements the logic of the thread runner. - - It continually runs the policy, and as long as the rollout exceeds a - certain length, the thread runner appends the policy to the queue. - """ - last_state = env.reset() - timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit" - ".max_episode_steps") - last_features = policy.get_initial_features() - length = 0 - rewards = 0 - rollout_number = 0 - - while True: - terminal_end = False - rollout = PartialRollout() - - for _ in range(num_local_steps): - fetched = policy.compute_action(last_state, *last_features) - action, value_, features = fetched[0], fetched[1], fetched[2:] - # Argmax to convert from one-hot. - state, reward, terminal, info = env.step(action) - if render: - env.render() - - length += 1 - rewards += reward - if length >= timestep_limit: - terminal = True - - # Collect the experience. - rollout.add(last_state, action, reward, value_, terminal, - last_features) - - last_state = state - last_features = features - - if info: - summary = tf.Summary() - for k, v in info.items(): - summary.value.add(tag=k, simple_value=float(v)) - summary_writer.add_summary(summary, rollout_number) - summary_writer.flush() - - if terminal: - terminal_end = True - yield CompletedRollout(length, rewards) - - if (length >= timestep_limit or - not env.metadata.get("semantics.autoreset")): - last_state = env.reset() - last_features = policy.get_initial_features() - rollout_number += 1 - length = 0 - rewards = 0 - break - - if not terminal_end: - rollout.r = policy.value(last_state, *last_features) - - # Once we have enough experience, yield it, and have the ThreadRunner - # place it on a queue. - yield rollout diff --git a/python/ray/rllib/a3c/shared_model.py b/python/ray/rllib/a3c/shared_model.py index ac5fcae44..38612ee31 100644 --- a/python/ray/rllib/a3c/shared_model.py +++ b/python/ray/rllib/a3c/shared_model.py @@ -9,6 +9,10 @@ from ray.rllib.models.catalog import ModelCatalog class SharedModel(TFPolicy): + + other_output = ["value"] + is_recurrent = False + def __init__(self, ob_space, ac_space, **kwargs): super(SharedModel, self).__init__(ob_space, ac_space, **kwargs) @@ -52,7 +56,7 @@ class SharedModel(TFPolicy): def compute_action(self, ob, *args): action, vf = self.sess.run([self.sample, self.vf], {self.x: [ob]}) - return action[0], vf[0] + return action[0], {"value": vf[0]} def value(self, ob, *args): vf = self.sess.run(self.vf, {self.x: [ob]}) diff --git a/python/ray/rllib/a3c/shared_model_lstm.py b/python/ray/rllib/a3c/shared_model_lstm.py index f6b5b2619..b09a6ac88 100644 --- a/python/ray/rllib/a3c/shared_model_lstm.py +++ b/python/ray/rllib/a3c/shared_model_lstm.py @@ -10,6 +10,16 @@ from ray.rllib.models.lstm import LSTM class SharedModelLSTM(TFPolicy): + """ + Attributes: + other_output (list): Other than `action`, the other return values from + `compute_gradient`. + is_recurrent (bool): True if is a recurrent network (requires features + to be tracked). + """ + + other_output = ["value", "features"] + is_recurrent = True def __init__(self, ob_space, ac_space, **kwargs): super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs) @@ -66,10 +76,9 @@ class SharedModelLSTM(TFPolicy): action, vf, c, h = self.sess.run( [self.sample, self.vf] + self.state_out, {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h}) - return action[0], vf[0], c, h + return action[0], {"value": vf[0], "features": (c, h)} def value(self, ob, c, h): - # process_rollout is very non-intuitive due to value being a float vf = self.sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h}) diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py index b29e5541b..b3d7da081 100644 --- a/python/ray/rllib/a3c/shared_torch_policy.py +++ b/python/ray/rllib/a3c/shared_torch_policy.py @@ -14,6 +14,9 @@ from ray.rllib.models.catalog import ModelCatalog class SharedTorchPolicy(TorchPolicy): """Assumes nonrecurrent.""" + other_output = ["value"] + is_recurrent = False + def __init__(self, ob_space, ac_space, **kwargs): super(SharedTorchPolicy, self).__init__( ob_space, ac_space, **kwargs) @@ -30,7 +33,7 @@ class SharedTorchPolicy(TorchPolicy): logits, values = self._model(ob) samples = self._model.probs(logits).multinomial().squeeze() values = values.squeeze(0) - return var_to_np(samples), var_to_np(values) + return var_to_np(samples), {"value": var_to_np(values)} def compute_logits(self, ob, *args): with self.lock: diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 14fc9edaa..86d251b91 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -16,7 +16,7 @@ from ray.rllib.dqn import logger, models from ray.rllib.dqn.common.wrappers import wrap_dqn from ray.rllib.dqn.common.schedules import LinearSchedule from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer -from ray.rllib.ppo.filter import RunningStat +from ray.rllib.utils.filter import RunningStat from ray.tune.result import TrainingResult diff --git a/python/ray/rllib/es/policies.py b/python/ray/rllib/es/policies.py index c07f73248..41476f5f3 100644 --- a/python/ray/rllib/es/policies.py +++ b/python/ray/rllib/es/policies.py @@ -11,8 +11,7 @@ import tensorflow as tf import ray from ray.rllib.models import ModelCatalog -# TODO(rkn): Move these filters out of PPO to somewhere common. -from ray.rllib.ppo.filter import NoFilter, MeanStdFilter +from ray.rllib.utils.filter import get_filter def rollout(policy, env, timestep_limit=None, add_noise=False): @@ -46,16 +45,8 @@ class GenericPolicy(object): self.action_space = action_space self.action_noise_std = action_noise_std self.preprocessor = preprocessor - - if observation_filter == "MeanStdFilter": - self.observation_filter = MeanStdFilter( - self.preprocessor.shape, clip=None) - elif observation_filter == "NoFilter": - self.observation_filter = NoFilter() - else: - raise Exception("Unknown observation_filter: " + - str("observation_filter")) - + self.observation_filter = get_filter( + observation_filter, self.preprocessor.shape) self.inputs = tf.placeholder( tf.float32, [None] + list(self.preprocessor.shape)) diff --git a/python/ray/rllib/ppo/runner.py b/python/ray/rllib/ppo/runner.py index 5c4728986..27ae5119d 100644 --- a/python/ray/rllib/ppo/runner.py +++ b/python/ray/rllib/ppo/runner.py @@ -14,9 +14,9 @@ import ray from ray.rllib.parallel import LocalSyncParallelOptimizer from ray.rllib.models import ModelCatalog +from ray.rllib.utils.filter import get_filter, MeanStdFilter from ray.rllib.ppo.env import BatchedEnv from ray.rllib.ppo.loss import ProximalPolicyLoss -from ray.rllib.ppo.filter import NoFilter, MeanStdFilter from ray.rllib.ppo.rollout import ( rollouts, add_return_values, add_advantage_values) from ray.rllib.ppo.utils import flatten, concatenate @@ -137,14 +137,8 @@ class Runner(object): self.common_policy = self.par_opt.get_common_loss() self.variables = ray.experimental.TensorFlowVariables( self.common_policy.loss, self.sess) - if config["observation_filter"] == "MeanStdFilter": - self.observation_filter = MeanStdFilter( - self.preprocessor.shape, clip=None) - elif config["observation_filter"] == "NoFilter": - self.observation_filter = NoFilter() - else: - raise Exception("Unknown observation_filter: " + - str(config["observation_filter"])) + self.observation_filter = get_filter( + config["observation_filter"], self.preprocessor.shape) self.reward_filter = MeanStdFilter((), clip=5.0) self.sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/rllib/ppo/filter.py b/python/ray/rllib/utils/filter.py similarity index 68% rename from python/ray/rllib/ppo/filter.py rename to python/ray/rllib/utils/filter.py index bd424ebd3..cf22a224e 100644 --- a/python/ray/rllib/ppo/filter.py +++ b/python/ray/rllib/utils/filter.py @@ -5,19 +5,41 @@ from __future__ import print_function import numpy as np -class NoFilter(object): - def __init__(self): +class BaseFilter(object): + """Processes input, possibly statefully.""" + + def update(self, other, *args, **kwargs): + """Updates self with "new state" from other filter.""" + raise NotImplementedError + + def copy(self): + """Creates a new object with same state as self. + + Returns: + copy (Filter): Copy of self""" + raise NotImplementedError + + def sync(self, other): + """Copies all state from other filter to self.""" + raise NotImplementedError + + +class NoFilter(BaseFilter): + def __init__(self, *args): pass def __call__(self, x, update=True): return np.asarray(x) - def update(self, other): + def update(self, other, *args, **kwargs): pass def copy(self): return self + def sync(self, other): + pass + # http://www.johndcook.com/blog/standard_deviation/ class RunningStat(object): @@ -103,12 +125,22 @@ class MeanStdFilter(object): def clear_buffer(self): self.buffer = RunningStat(self.shape) - def update(self, other): - # `update` takes another filter and - # only applies the information from the buffer. + def update(self, other, copy_buffer=False): + """Takes another filter and only applies the information from the + buffer. + + Using notation `F(state, buffer)` + Given `Filter1(x1, y1)` and `Filter2(x2, yt)`, + `update` modifies `Filter1` to `Filter1(x1 + yt, y1)` + If `copy_buffer`, then `Filter1` is modified to + `Filter1(x1 + yt, yt)`. + """ self.rs.update(other.buffer) + if copy_buffer: + self.buffer = other.buffer.copy() def copy(self): + """Returns a copy of Filter.""" other = MeanStdFilter(self.shape) other.demean = self.demean other.destd = self.destd @@ -117,6 +149,20 @@ class MeanStdFilter(object): other.buffer = self.buffer.copy() return other + def sync(self, other): + """Syncs all fields together from other filter. + + Using notation `F(state, buffer)` + Given `Filter1(x1, y1)` and `Filter2(x2, yt)`, + `sync` modifies `Filter1` to `Filter1(x2, yt)` + """ + assert other.shape == self.shape, "Shapes don't match!" + self.demean = other.demean + self.destd = other.destd + self.clip = other.clip + self.rs = other.rs.copy() + self.buffer = other.buffer.copy() + def __call__(self, x, update=True): x = np.asarray(x) if update: @@ -138,8 +184,19 @@ class MeanStdFilter(object): return x def __repr__(self): - return 'MeanStdFilter({}, {}, {}, {}, {})'.format( - self.shape, self.demean, self.destd, self.clip, self.rs) + return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format( + self.shape, self.demean, self.destd, + self.clip, self.rs, self.buffer) + + +def get_filter(filter_config, shape): + if filter_config == "MeanStdFilter": + return MeanStdFilter(shape, clip=None) + elif filter_config == "NoFilter": + return NoFilter() + else: + raise Exception("Unknown observation_filter: " + + str(filter_config)) def test_running_stat(): diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py new file mode 100644 index 000000000..92fddfe53 --- /dev/null +++ b/python/ray/rllib/utils/sampler.py @@ -0,0 +1,288 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six.moves.queue as queue +import threading +from collections import namedtuple + + +def lock_wrap(func, lock): + def wrapper(*args, **kwargs): + with lock: + return func(*args, **kwargs) + return wrapper + + +class PartialRollout(object): + """A piece of a complete rollout. + + We run our agent, and process its experience once it has processed enough + steps. + """ + + fields = ["state", "action", "reward", "terminal", "features"] + + def __init__(self, extra_fields=None): + """Initializers internals. Maintains a `last_r` field + in support of partial rollouts, used in bootstrapping advantage + estimation. + + Args: + extra_fields: Optional field for object to keep track. + """ + if extra_fields: + self.fields.extend(extra_fields) + self.data = {k: [] for k in self.fields} + self.last_r = 0.0 + + def add(self, **kwargs): + for k, v in kwargs.items(): + self.data[k] += [v] + + def extend(self, other_rollout): + """Extends internal data structure. Assumes other_rollout contains + data that occured afterwards.""" + + assert not self.is_terminal() + assert all(k in other_rollout.fields for k in self.fields) + for k, v in other_rollout.data.items(): + self.data[k].extend(v) + self.last_r = other_rollout.last_r + + def is_terminal(self): + """Check if terminal. + + Returns: + terminal (bool): if rollout has terminated.""" + return self.data["terminal"][-1] + + +CompletedRollout = namedtuple( + "CompletedRollout", ["episode_length", "episode_reward"]) + + +class SyncSampler(object): + """This class interacts with the environment and tells it what to do. + + Note that batch_size is only a unit of measure here. Batches can + accumulate and the gradient can be calculated on up to 5 batches. + + This class provides data on invocation, rather than on a separate + thread.""" + async = False + + def __init__(self, env, policy, num_local_steps, obs_filter): + self.num_local_steps = num_local_steps + self.env = env + self.policy = policy + self.obs_filter = obs_filter + self.rollout_provider = _env_runner( + self.env, self.policy, self.num_local_steps, self.obs_filter) + self.metrics_queue = queue.Queue() + + def update_obs_filter(self, other_filter): + """Method to update observation filter with copy from driver. + Since this class is synchronous, updating the observation + filter should be a straightforward replacement + + Args: + other_filter: Another filter (of same type).""" + self.obs_filter = other_filter.copy() + + def get_data(self): + while True: + item = next(self.rollout_provider) + if isinstance(item, CompletedRollout): + self.metrics_queue.put(item) + else: + obsf_snapshot = self.obs_filter.copy() + if hasattr(self.obs_filter, "clear_buffer"): + self.obs_filter.clear_buffer() + return item, obsf_snapshot + + def get_metrics(self): + completed = [] + while True: + try: + completed.append(self.metrics_queue.get_nowait()) + except queue.Empty: + break + return completed + + +class AsyncSampler(threading.Thread): + """This class interacts with the environment and tells it what to do. + + Note that batch_size is only a unit of measure here. Batches can + accumulate and the gradient can be calculated on up to 5 batches.""" + async = True + + def __init__(self, env, policy, num_local_steps, obs_filter): + threading.Thread.__init__(self) + self.queue = queue.Queue(5) + self.metrics_queue = queue.Queue() + self.num_local_steps = num_local_steps + self.env = env + self.policy = policy + self.obs_filter = obs_filter + self.obs_f_lock = threading.Lock() + self.start() + + def run(self): + try: + self._run() + except BaseException as e: + self.queue.put(e) + raise e + + def update_obs_filter(self, other_filter): + """Method to update observation filter with copy from driver. + Applies delta since last `clear_buffer` to given new filter, + and syncs current filter to new filter. `self.obs_filter` is + kept in place due to the `lock_wrap`. + + Args: + other_filter: Another filter (of same type).""" + with self.obs_f_lock: + new_filter = other_filter.copy() + # Applies delta to filter, including buffer + new_filter.update(self.obs_filter, copy_buffer=True) + # copies everything back into original filter - needed + # due to `lock_wrap` + self.obs_filter.sync(new_filter) + + def _run(self): + """Sets observation filter into an atomic region and starts + other thread for running.""" + safe_obs_filter = lock_wrap(self.obs_filter, self.obs_f_lock) + rollout_provider = _env_runner( + self.env, self.policy, self.num_local_steps, safe_obs_filter) + while True: + # The timeout variable exists because apparently, if one worker + # dies, the other workers won't die with it, unless the timeout is + # set to some large number. This is an empirical observation. + item = next(rollout_provider) + if isinstance(item, CompletedRollout): + self.metrics_queue.put(item) + else: + self.queue.put(item, timeout=600.0) + + def get_data(self): + """Gets currently accumulated data and a snapshot of the current + observation filter. The snapshot also clears the accumulated delta. + Note that in between getting the rollout and acquiring the lock, + the other thread can run, resulting in slight discrepamcies + between data retrieved and filter statistics. + + Returns: + rollout: trajectory data (unprocessed) + obsf_snapshot: snapshot of observation filter. + """ + + rollout = self._pull_batch_from_queue() + with self.obs_f_lock: + obsf_snapshot = self.obs_filter.copy() + if hasattr(self.obs_filter, "clear_buffer"): + self.obs_filter.clear_buffer() + return rollout, obsf_snapshot + + def _pull_batch_from_queue(self): + """Take a rollout from the queue of the thread runner.""" + rollout = self.queue.get(timeout=600.0) + if isinstance(rollout, BaseException): + raise rollout + while not rollout.is_terminal(): + try: + part = self.queue.get_nowait() + if isinstance(part, BaseException): + raise rollout + rollout.extend(part) + except queue.Empty: + break + return rollout + + def get_metrics(self): + completed = [] + while True: + try: + completed.append(self.metrics_queue.get_nowait()) + except queue.Empty: + break + return completed + + +def _env_runner(env, policy, num_local_steps, obs_filter): + """This implements the logic of the thread runner. + + It continually runs the policy, and as long as the rollout exceeds a + certain length, the thread runner appends the policy to the queue. Yields + when `timestep_limit` is surpassed, environment terminates, or + `num_local_steps` is reached. + + Args: + env: Environment generated by env_creator + policy: Policy used to interact with environment. Also sets fields + to be included in `PartialRollout` + num_local_steps: Number of steps before `PartialRollout` is yielded. + obs_filter: Filter used to process observations. + + Yields: + rollout (PartialRollout): Object containing state, action, reward, + terminal condition, and other fields as dictated by `policy`. + """ + last_state = obs_filter(env.reset()) + timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit" + ".max_episode_steps") + last_features = features = policy.get_initial_features() + length = 0 + rewards = 0 + rollout_number = 0 + + while True: + terminal_end = False + rollout = PartialRollout(extra_fields=policy.other_output) + + for _ in range(num_local_steps): + action, pi_info = policy.compute_action(last_state, *last_features) + if policy.is_recurrent: + features = pi_info["features"] + del pi_info["features"] + state, reward, terminal, info = env.step(action) + state = obs_filter(state) + + length += 1 + rewards += reward + if length >= timestep_limit: + terminal = True + + # Collect the experience. + rollout.add(state=last_state, + action=action, + reward=reward, + terminal=terminal, + features=last_features, + **pi_info) + + last_state = state + last_features = features + + if terminal: + terminal_end = True + yield CompletedRollout(length, rewards) + + if (length >= timestep_limit or + not env.metadata.get("semantics.autoreset")): + last_state = obs_filter(env.reset()) + last_features = policy.get_initial_features() + rollout_number += 1 + length = 0 + rewards = 0 + break + + if not terminal_end: + rollout.last_r = policy.value(last_state, *last_features) + + # Once we have enough experience, yield it, and have the ThreadRunner + # place it on a queue. + yield rollout