mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:37:28 +08:00
[rllib] Generalizing A3C Sampling Classes (#1250)
This commit is contained in:
+35
-21
@@ -10,17 +10,24 @@ import ray
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
from ray.rllib.a3c.runner import RemoteRunner
|
||||
from ray.rllib.a3c.shared_model import SharedModel
|
||||
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
|
||||
from ray.rllib.a3c.common import get_policy_cls
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"num_workers": 4,
|
||||
"num_batches_per_iteration": 100,
|
||||
|
||||
# Size of rollout batch
|
||||
"batch_size": 10,
|
||||
"use_lstm": True,
|
||||
"use_pytorch": False,
|
||||
# Which observation filter to apply to the observation
|
||||
"observation_filter": "NoFilter",
|
||||
# Which reward filter to apply to the reward
|
||||
"reward_filter": "NoFilter",
|
||||
|
||||
"model": {"grayscale": True,
|
||||
"zero_mean": False,
|
||||
"dim": 42,
|
||||
@@ -34,38 +41,43 @@ class A3CAgent(Agent):
|
||||
|
||||
def _init(self):
|
||||
self.env = create_and_wrap(self.env_creator, self.config["model"])
|
||||
if self.config["use_lstm"]:
|
||||
policy_cls = SharedModelLSTM
|
||||
elif self.config["use_pytorch"]:
|
||||
from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
|
||||
policy_cls = SharedTorchPolicy
|
||||
else:
|
||||
policy_cls = SharedModel
|
||||
policy_cls = get_policy_cls(self.config)
|
||||
self.policy = policy_cls(
|
||||
self.env.observation_space.shape, self.env.action_space)
|
||||
self.obs_filter = get_filter(
|
||||
self.config["observation_filter"],
|
||||
self.env.observation_space.shape)
|
||||
self.rew_filter = get_filter(self.config["reward_filter"], ())
|
||||
self.agents = [
|
||||
RemoteRunner.remote(self.env_creator, policy_cls, i,
|
||||
self.config["batch_size"],
|
||||
self.config["model"], self.logdir)
|
||||
RemoteRunner.remote(self.env_creator, self.config, self.logdir)
|
||||
for i in range(self.config["num_workers"])]
|
||||
self.parameters = self.policy.get_weights()
|
||||
|
||||
def _train(self):
|
||||
gradient_list = [
|
||||
agent.compute_gradient.remote(self.parameters)
|
||||
for agent in self.agents]
|
||||
remote_params = ray.put(self.parameters)
|
||||
ray.get([agent.set_weights.remote(remote_params)
|
||||
for agent in self.agents])
|
||||
|
||||
gradient_list = {agent.compute_gradient.remote(): agent
|
||||
for agent in self.agents}
|
||||
max_batches = self.config["num_batches_per_iteration"]
|
||||
batches_so_far = len(gradient_list)
|
||||
while gradient_list:
|
||||
done_id, gradient_list = ray.wait(gradient_list)
|
||||
gradient, info = ray.get(done_id)[0]
|
||||
[done_id], _ = ray.wait(list(gradient_list))
|
||||
gradient, info = ray.get(done_id)
|
||||
agent = gradient_list.pop(done_id)
|
||||
self.obs_filter.update(info["obs_filter"])
|
||||
self.rew_filter.update(info["rew_filter"])
|
||||
self.policy.apply_gradients(gradient)
|
||||
self.parameters = self.policy.get_weights()
|
||||
|
||||
if batches_so_far < max_batches:
|
||||
batches_so_far += 1
|
||||
gradient_list.extend(
|
||||
[self.agents[info["id"]].compute_gradient.remote(
|
||||
self.parameters)])
|
||||
agent.update_filters.remote(
|
||||
obs_filter=self.obs_filter,
|
||||
rew_filter=self.rew_filter)
|
||||
agent.set_weights.remote(self.parameters)
|
||||
gradient_list[agent.compute_gradient.remote()] = agent
|
||||
res = self._fetch_metrics_from_workers()
|
||||
return res
|
||||
|
||||
@@ -95,13 +107,15 @@ class A3CAgent(Agent):
|
||||
def _save(self):
|
||||
checkpoint_path = os.path.join(
|
||||
self.logdir, "checkpoint-{}".format(self.iteration))
|
||||
objects = [self.parameters]
|
||||
objects = [self.parameters, self.obs_filter, self.rew_filter]
|
||||
pickle.dump(objects, open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
objects = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.parameters = objects[0]
|
||||
self.obs_filter = objects[1]
|
||||
self.rew_filter = objects[2]
|
||||
self.policy.set_weights(self.parameters)
|
||||
|
||||
def compute_action(self, observation):
|
||||
|
||||
@@ -11,27 +11,41 @@ def discount(x, gamma):
|
||||
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||||
|
||||
|
||||
def process_rollout(rollout, gamma, lambda_=1.0):
|
||||
"""Given a rollout, compute its returns and the advantage."""
|
||||
batch_si = np.asarray(rollout.states)
|
||||
batch_a = np.asarray(rollout.actions)
|
||||
rewards = np.asarray(rollout.rewards)
|
||||
vpred_t = np.asarray(rollout.values + [rollout.r])
|
||||
def process_rollout(rollout, reward_filter, gamma, lambda_=1.0):
|
||||
"""Given a rollout, compute its returns and the advantage.
|
||||
|
||||
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
|
||||
TODO(rliaw): generalize this"""
|
||||
batch_si = np.asarray(rollout.data["state"])
|
||||
batch_a = np.asarray(rollout.data["action"])
|
||||
rewards = np.asarray(rollout.data["reward"])
|
||||
vpred_t = np.asarray(rollout.data["value"] + [rollout.last_r])
|
||||
|
||||
rewards_plus_v = np.asarray(rollout.data["reward"] + [rollout.last_r])
|
||||
batch_r = discount(rewards_plus_v, gamma)[:-1]
|
||||
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
|
||||
# This formula for the advantage comes "Generalized Advantage Estimation":
|
||||
# https://arxiv.org/abs/1506.02438
|
||||
batch_adv = discount(delta_t, gamma * lambda_)
|
||||
for i in range(batch_adv.shape[0]):
|
||||
batch_adv[i] = reward_filter(batch_adv[i])
|
||||
|
||||
features = rollout.features[0]
|
||||
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
|
||||
features = rollout.data["features"][0]
|
||||
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.is_terminal(),
|
||||
features)
|
||||
|
||||
|
||||
def get_policy_cls(config):
|
||||
if config["use_lstm"]:
|
||||
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
|
||||
policy_cls = SharedModelLSTM
|
||||
elif config["use_pytorch"]:
|
||||
from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
|
||||
policy_cls = SharedTorchPolicy
|
||||
else:
|
||||
from ray.rllib.a3c.shared_model import SharedModel
|
||||
policy_cls = SharedModel
|
||||
return policy_cls
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
"Batch", ["si", "a", "adv", "r", "terminal", "features"])
|
||||
|
||||
CompletedRollout = namedtuple(
|
||||
"CompletedRollout", ["episode_length", "episode_reward"])
|
||||
|
||||
@@ -2,82 +2,72 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
import tensorflow as tf
|
||||
import six.moves.queue as queue
|
||||
from ray.rllib.a3c.runner_thread import RunnerThread
|
||||
from ray.rllib.a3c.common import process_rollout
|
||||
from ray.rllib.a3c.tfpolicy import TFPolicy
|
||||
import ray
|
||||
import os
|
||||
from ray.rllib.a3c.envs import create_and_wrap
|
||||
from ray.rllib.a3c.common import process_rollout, get_policy_cls
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.sampler import AsyncSampler
|
||||
|
||||
|
||||
class Runner(object):
|
||||
"""Actor object to start running simulation on workers.
|
||||
|
||||
The gradient computation is also executed from this object.
|
||||
|
||||
Attributes:
|
||||
policy: Copy of graph used for policy. Used by sampler and gradients.
|
||||
rew_filter: Reward filter used in rollout post-processing.
|
||||
sampler: Component for interacting with environment and generating
|
||||
rollouts.
|
||||
logdir: Directory for logging.
|
||||
"""
|
||||
def __init__(self, env_creator, policy_cls, actor_id, batch_size,
|
||||
preprocess_config, logdir):
|
||||
env = create_and_wrap(env_creator, preprocess_config)
|
||||
self.id = actor_id
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
self.env = env = create_and_wrap(env_creator, config["model"])
|
||||
policy_cls = get_policy_cls(config)
|
||||
# TODO(rliaw): should change this to be just env.observation_space
|
||||
self.policy = policy_cls(env.observation_space.shape, env.action_space)
|
||||
self.runner = RunnerThread(env, self.policy, batch_size)
|
||||
self.env = env
|
||||
self.logdir = logdir
|
||||
self.start()
|
||||
obs_filter = get_filter(
|
||||
config["observation_filter"], env.observation_space.shape)
|
||||
self.rew_filter = get_filter(config["reward_filter"], ())
|
||||
|
||||
def pull_batch_from_queue(self):
|
||||
"""Take a rollout from the queue of the thread runner."""
|
||||
rollout = self.runner.queue.get(timeout=600.0)
|
||||
if isinstance(rollout, BaseException):
|
||||
raise rollout
|
||||
while not rollout.terminal:
|
||||
try:
|
||||
part = self.runner.queue.get_nowait()
|
||||
if isinstance(part, BaseException):
|
||||
raise rollout
|
||||
rollout.extend(part)
|
||||
except queue.Empty:
|
||||
break
|
||||
return rollout
|
||||
self.sampler = AsyncSampler(env, self.policy, config["batch_size"],
|
||||
obs_filter)
|
||||
self.logdir = logdir
|
||||
|
||||
def get_data(self):
|
||||
"""
|
||||
Returns:
|
||||
trajectory: trajectory information
|
||||
obs_filter: Current state of observation filter
|
||||
rew_filter: Current state of reward filter"""
|
||||
rollout, obs_filter = self.sampler.get_data()
|
||||
return rollout, obs_filter, self.rew_filter
|
||||
|
||||
def get_completed_rollout_metrics(self):
|
||||
"""Returns metrics on previously completed rollouts.
|
||||
|
||||
Calling this clears the queue of completed rollout metrics.
|
||||
"""
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.runner.metrics_queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
return self.sampler.get_metrics()
|
||||
|
||||
def start(self):
|
||||
summary_writer = tf.summary.FileWriter(
|
||||
os.path.join(self.logdir, "agent_%d" % self.id))
|
||||
self.summary_writer = summary_writer
|
||||
if isinstance(self.policy, TFPolicy):
|
||||
self.runner.start_runner(self.policy.sess, summary_writer)
|
||||
else:
|
||||
self.runner.start_runner(tf.Session(), summary_writer)
|
||||
|
||||
def compute_gradient(self, params):
|
||||
self.policy.set_weights(params)
|
||||
rollout = self.pull_batch_from_queue()
|
||||
batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
|
||||
def compute_gradient(self):
|
||||
rollout, obsf_snapshot = self.sampler.get_data()
|
||||
batch = process_rollout(
|
||||
rollout, self.rew_filter, gamma=0.99, lambda_=1.0)
|
||||
gradient, info = self.policy.compute_gradients(batch)
|
||||
if "summary" in info:
|
||||
self.summary_writer.add_summary(
|
||||
tf.Summary.FromString(info['summary']),
|
||||
self.policy.local_steps)
|
||||
self.summary_writer.flush()
|
||||
info = {"id": self.id,
|
||||
"size": len(batch.a)}
|
||||
info["obs_filter"] = obsf_snapshot
|
||||
info["rew_filter"] = self.rew_filter
|
||||
return gradient, info
|
||||
|
||||
def set_weights(self, params):
|
||||
self.policy.set_weights(params)
|
||||
|
||||
def update_filters(self, obs_filter=None, rew_filter=None):
|
||||
if rew_filter:
|
||||
# No special handling required since outside of threaded code
|
||||
self.rew_filter = rew_filter.copy()
|
||||
if obs_filter:
|
||||
self.sampler.update_obs_filter(obs_filter)
|
||||
|
||||
|
||||
RemoteRunner = ray.remote(Runner)
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
from ray.rllib.a3c.common import CompletedRollout
|
||||
|
||||
|
||||
class PartialRollout(object):
|
||||
"""A piece of a complete rollout.
|
||||
|
||||
We run our agent, and process its experience once it has processed enough
|
||||
steps.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.states = []
|
||||
self.actions = []
|
||||
self.rewards = []
|
||||
self.values = []
|
||||
self.r = 0.0
|
||||
self.terminal = False
|
||||
self.features = []
|
||||
|
||||
def add(self, state, action, reward, value, terminal, features):
|
||||
self.states += [state]
|
||||
self.actions += [action]
|
||||
self.rewards += [reward]
|
||||
self.values += [value]
|
||||
self.terminal = terminal
|
||||
self.features += [features]
|
||||
|
||||
def extend(self, other):
|
||||
assert not self.terminal
|
||||
self.states.extend(other.states)
|
||||
self.actions.extend(other.actions)
|
||||
self.rewards.extend(other.rewards)
|
||||
self.values.extend(other.values)
|
||||
self.r = other.r
|
||||
self.terminal = other.terminal
|
||||
self.features.extend(other.features)
|
||||
|
||||
|
||||
class RunnerThread(threading.Thread):
|
||||
"""This thread interacts with the environment and tells it what to do."""
|
||||
def __init__(self, env, policy, num_local_steps, visualise=False):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.env = env
|
||||
self.last_features = None
|
||||
self.policy = policy
|
||||
self.daemon = True
|
||||
self.sess = None
|
||||
self.summary_writer = None
|
||||
self.visualise = visualise
|
||||
|
||||
def start_runner(self, sess, summary_writer):
|
||||
self.sess = sess
|
||||
self.summary_writer = summary_writer
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
with self.sess.as_default():
|
||||
self._run()
|
||||
except BaseException as e:
|
||||
self.queue.put(e)
|
||||
raise e
|
||||
|
||||
def _run(self):
|
||||
rollout_provider = env_runner(
|
||||
self.env, self.policy, self.num_local_steps,
|
||||
self.summary_writer, self.visualise)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
# set to some large number. This is an empirical observation.
|
||||
item = next(rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
self.queue.put(item, timeout=600.0)
|
||||
|
||||
|
||||
def env_runner(env, policy, num_local_steps, summary_writer, render):
|
||||
"""This implements the logic of the thread runner.
|
||||
|
||||
It continually runs the policy, and as long as the rollout exceeds a
|
||||
certain length, the thread runner appends the policy to the queue.
|
||||
"""
|
||||
last_state = env.reset()
|
||||
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
last_features = policy.get_initial_features()
|
||||
length = 0
|
||||
rewards = 0
|
||||
rollout_number = 0
|
||||
|
||||
while True:
|
||||
terminal_end = False
|
||||
rollout = PartialRollout()
|
||||
|
||||
for _ in range(num_local_steps):
|
||||
fetched = policy.compute_action(last_state, *last_features)
|
||||
action, value_, features = fetched[0], fetched[1], fetched[2:]
|
||||
# Argmax to convert from one-hot.
|
||||
state, reward, terminal, info = env.step(action)
|
||||
if render:
|
||||
env.render()
|
||||
|
||||
length += 1
|
||||
rewards += reward
|
||||
if length >= timestep_limit:
|
||||
terminal = True
|
||||
|
||||
# Collect the experience.
|
||||
rollout.add(last_state, action, reward, value_, terminal,
|
||||
last_features)
|
||||
|
||||
last_state = state
|
||||
last_features = features
|
||||
|
||||
if info:
|
||||
summary = tf.Summary()
|
||||
for k, v in info.items():
|
||||
summary.value.add(tag=k, simple_value=float(v))
|
||||
summary_writer.add_summary(summary, rollout_number)
|
||||
summary_writer.flush()
|
||||
|
||||
if terminal:
|
||||
terminal_end = True
|
||||
yield CompletedRollout(length, rewards)
|
||||
|
||||
if (length >= timestep_limit or
|
||||
not env.metadata.get("semantics.autoreset")):
|
||||
last_state = env.reset()
|
||||
last_features = policy.get_initial_features()
|
||||
rollout_number += 1
|
||||
length = 0
|
||||
rewards = 0
|
||||
break
|
||||
|
||||
if not terminal_end:
|
||||
rollout.r = policy.value(last_state, *last_features)
|
||||
|
||||
# Once we have enough experience, yield it, and have the ThreadRunner
|
||||
# place it on a queue.
|
||||
yield rollout
|
||||
@@ -9,6 +9,10 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
|
||||
class SharedModel(TFPolicy):
|
||||
|
||||
other_output = ["value"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
|
||||
|
||||
@@ -52,7 +56,7 @@ class SharedModel(TFPolicy):
|
||||
def compute_action(self, ob, *args):
|
||||
action, vf = self.sess.run([self.sample, self.vf],
|
||||
{self.x: [ob]})
|
||||
return action[0], vf[0]
|
||||
return action[0], {"value": vf[0]}
|
||||
|
||||
def value(self, ob, *args):
|
||||
vf = self.sess.run(self.vf, {self.x: [ob]})
|
||||
|
||||
@@ -10,6 +10,16 @@ from ray.rllib.models.lstm import LSTM
|
||||
|
||||
|
||||
class SharedModelLSTM(TFPolicy):
|
||||
"""
|
||||
Attributes:
|
||||
other_output (list): Other than `action`, the other return values from
|
||||
`compute_gradient`.
|
||||
is_recurrent (bool): True if is a recurrent network (requires features
|
||||
to be tracked).
|
||||
"""
|
||||
|
||||
other_output = ["value", "features"]
|
||||
is_recurrent = True
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
|
||||
@@ -66,10 +76,9 @@ class SharedModelLSTM(TFPolicy):
|
||||
action, vf, c, h = self.sess.run(
|
||||
[self.sample, self.vf] + self.state_out,
|
||||
{self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
|
||||
return action[0], vf[0], c, h
|
||||
return action[0], {"value": vf[0], "features": (c, h)}
|
||||
|
||||
def value(self, ob, c, h):
|
||||
# process_rollout is very non-intuitive due to value being a float
|
||||
vf = self.sess.run(self.vf, {self.x: [ob],
|
||||
self.state_in[0]: c,
|
||||
self.state_in[1]: h})
|
||||
|
||||
@@ -14,6 +14,9 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
class SharedTorchPolicy(TorchPolicy):
|
||||
"""Assumes nonrecurrent."""
|
||||
|
||||
other_output = ["value"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, ob_space, ac_space, **kwargs):
|
||||
super(SharedTorchPolicy, self).__init__(
|
||||
ob_space, ac_space, **kwargs)
|
||||
@@ -30,7 +33,7 @@ class SharedTorchPolicy(TorchPolicy):
|
||||
logits, values = self._model(ob)
|
||||
samples = self._model.probs(logits).multinomial().squeeze()
|
||||
values = values.squeeze(0)
|
||||
return var_to_np(samples), var_to_np(values)
|
||||
return var_to_np(samples), {"value": var_to_np(values)}
|
||||
|
||||
def compute_logits(self, ob, *args):
|
||||
with self.lock:
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.dqn import logger, models
|
||||
from ray.rllib.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from ray.rllib.ppo.filter import RunningStat
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
|
||||
@@ -11,8 +11,7 @@ import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
# TODO(rkn): Move these filters out of PPO to somewhere common.
|
||||
from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
|
||||
|
||||
def rollout(policy, env, timestep_limit=None, add_noise=False):
|
||||
@@ -46,16 +45,8 @@ class GenericPolicy(object):
|
||||
self.action_space = action_space
|
||||
self.action_noise_std = action_noise_std
|
||||
self.preprocessor = preprocessor
|
||||
|
||||
if observation_filter == "MeanStdFilter":
|
||||
self.observation_filter = MeanStdFilter(
|
||||
self.preprocessor.shape, clip=None)
|
||||
elif observation_filter == "NoFilter":
|
||||
self.observation_filter = NoFilter()
|
||||
else:
|
||||
raise Exception("Unknown observation_filter: " +
|
||||
str("observation_filter"))
|
||||
|
||||
self.observation_filter = get_filter(
|
||||
observation_filter, self.preprocessor.shape)
|
||||
self.inputs = tf.placeholder(
|
||||
tf.float32, [None] + list(self.preprocessor.shape))
|
||||
|
||||
|
||||
@@ -14,9 +14,9 @@ import ray
|
||||
|
||||
from ray.rllib.parallel import LocalSyncParallelOptimizer
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.filter import get_filter, MeanStdFilter
|
||||
from ray.rllib.ppo.env import BatchedEnv
|
||||
from ray.rllib.ppo.loss import ProximalPolicyLoss
|
||||
from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
|
||||
from ray.rllib.ppo.rollout import (
|
||||
rollouts, add_return_values, add_advantage_values)
|
||||
from ray.rllib.ppo.utils import flatten, concatenate
|
||||
@@ -137,14 +137,8 @@ class Runner(object):
|
||||
self.common_policy = self.par_opt.get_common_loss()
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.common_policy.loss, self.sess)
|
||||
if config["observation_filter"] == "MeanStdFilter":
|
||||
self.observation_filter = MeanStdFilter(
|
||||
self.preprocessor.shape, clip=None)
|
||||
elif config["observation_filter"] == "NoFilter":
|
||||
self.observation_filter = NoFilter()
|
||||
else:
|
||||
raise Exception("Unknown observation_filter: " +
|
||||
str(config["observation_filter"]))
|
||||
self.observation_filter = get_filter(
|
||||
config["observation_filter"], self.preprocessor.shape)
|
||||
self.reward_filter = MeanStdFilter((), clip=5.0)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
|
||||
@@ -5,19 +5,41 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NoFilter(object):
|
||||
def __init__(self):
|
||||
class BaseFilter(object):
|
||||
"""Processes input, possibly statefully."""
|
||||
|
||||
def update(self, other, *args, **kwargs):
|
||||
"""Updates self with "new state" from other filter."""
|
||||
raise NotImplementedError
|
||||
|
||||
def copy(self):
|
||||
"""Creates a new object with same state as self.
|
||||
|
||||
Returns:
|
||||
copy (Filter): Copy of self"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sync(self, other):
|
||||
"""Copies all state from other filter to self."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NoFilter(BaseFilter):
|
||||
def __init__(self, *args):
|
||||
pass
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
return np.asarray(x)
|
||||
|
||||
def update(self, other):
|
||||
def update(self, other, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def copy(self):
|
||||
return self
|
||||
|
||||
def sync(self, other):
|
||||
pass
|
||||
|
||||
|
||||
# http://www.johndcook.com/blog/standard_deviation/
|
||||
class RunningStat(object):
|
||||
@@ -103,12 +125,22 @@ class MeanStdFilter(object):
|
||||
def clear_buffer(self):
|
||||
self.buffer = RunningStat(self.shape)
|
||||
|
||||
def update(self, other):
|
||||
# `update` takes another filter and
|
||||
# only applies the information from the buffer.
|
||||
def update(self, other, copy_buffer=False):
|
||||
"""Takes another filter and only applies the information from the
|
||||
buffer.
|
||||
|
||||
Using notation `F(state, buffer)`
|
||||
Given `Filter1(x1, y1)` and `Filter2(x2, yt)`,
|
||||
`update` modifies `Filter1` to `Filter1(x1 + yt, y1)`
|
||||
If `copy_buffer`, then `Filter1` is modified to
|
||||
`Filter1(x1 + yt, yt)`.
|
||||
"""
|
||||
self.rs.update(other.buffer)
|
||||
if copy_buffer:
|
||||
self.buffer = other.buffer.copy()
|
||||
|
||||
def copy(self):
|
||||
"""Returns a copy of Filter."""
|
||||
other = MeanStdFilter(self.shape)
|
||||
other.demean = self.demean
|
||||
other.destd = self.destd
|
||||
@@ -117,6 +149,20 @@ class MeanStdFilter(object):
|
||||
other.buffer = self.buffer.copy()
|
||||
return other
|
||||
|
||||
def sync(self, other):
|
||||
"""Syncs all fields together from other filter.
|
||||
|
||||
Using notation `F(state, buffer)`
|
||||
Given `Filter1(x1, y1)` and `Filter2(x2, yt)`,
|
||||
`sync` modifies `Filter1` to `Filter1(x2, yt)`
|
||||
"""
|
||||
assert other.shape == self.shape, "Shapes don't match!"
|
||||
self.demean = other.demean
|
||||
self.destd = other.destd
|
||||
self.clip = other.clip
|
||||
self.rs = other.rs.copy()
|
||||
self.buffer = other.buffer.copy()
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
x = np.asarray(x)
|
||||
if update:
|
||||
@@ -138,8 +184,19 @@ class MeanStdFilter(object):
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return 'MeanStdFilter({}, {}, {}, {}, {})'.format(
|
||||
self.shape, self.demean, self.destd, self.clip, self.rs)
|
||||
return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format(
|
||||
self.shape, self.demean, self.destd,
|
||||
self.clip, self.rs, self.buffer)
|
||||
|
||||
|
||||
def get_filter(filter_config, shape):
|
||||
if filter_config == "MeanStdFilter":
|
||||
return MeanStdFilter(shape, clip=None)
|
||||
elif filter_config == "NoFilter":
|
||||
return NoFilter()
|
||||
else:
|
||||
raise Exception("Unknown observation_filter: " +
|
||||
str(filter_config))
|
||||
|
||||
|
||||
def test_running_stat():
|
||||
@@ -0,0 +1,288 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
def lock_wrap(func, lock):
|
||||
def wrapper(*args, **kwargs):
|
||||
with lock:
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
class PartialRollout(object):
|
||||
"""A piece of a complete rollout.
|
||||
|
||||
We run our agent, and process its experience once it has processed enough
|
||||
steps.
|
||||
"""
|
||||
|
||||
fields = ["state", "action", "reward", "terminal", "features"]
|
||||
|
||||
def __init__(self, extra_fields=None):
|
||||
"""Initializers internals. Maintains a `last_r` field
|
||||
in support of partial rollouts, used in bootstrapping advantage
|
||||
estimation.
|
||||
|
||||
Args:
|
||||
extra_fields: Optional field for object to keep track.
|
||||
"""
|
||||
if extra_fields:
|
||||
self.fields.extend(extra_fields)
|
||||
self.data = {k: [] for k in self.fields}
|
||||
self.last_r = 0.0
|
||||
|
||||
def add(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
self.data[k] += [v]
|
||||
|
||||
def extend(self, other_rollout):
|
||||
"""Extends internal data structure. Assumes other_rollout contains
|
||||
data that occured afterwards."""
|
||||
|
||||
assert not self.is_terminal()
|
||||
assert all(k in other_rollout.fields for k in self.fields)
|
||||
for k, v in other_rollout.data.items():
|
||||
self.data[k].extend(v)
|
||||
self.last_r = other_rollout.last_r
|
||||
|
||||
def is_terminal(self):
|
||||
"""Check if terminal.
|
||||
|
||||
Returns:
|
||||
terminal (bool): if rollout has terminated."""
|
||||
return self.data["terminal"][-1]
|
||||
|
||||
|
||||
CompletedRollout = namedtuple(
|
||||
"CompletedRollout", ["episode_length", "episode_reward"])
|
||||
|
||||
|
||||
class SyncSampler(object):
|
||||
"""This class interacts with the environment and tells it what to do.
|
||||
|
||||
Note that batch_size is only a unit of measure here. Batches can
|
||||
accumulate and the gradient can be calculated on up to 5 batches.
|
||||
|
||||
This class provides data on invocation, rather than on a separate
|
||||
thread."""
|
||||
async = False
|
||||
|
||||
def __init__(self, env, policy, num_local_steps, obs_filter):
|
||||
self.num_local_steps = num_local_steps
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self.obs_filter = obs_filter
|
||||
self.rollout_provider = _env_runner(
|
||||
self.env, self.policy, self.num_local_steps, self.obs_filter)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def update_obs_filter(self, other_filter):
|
||||
"""Method to update observation filter with copy from driver.
|
||||
Since this class is synchronous, updating the observation
|
||||
filter should be a straightforward replacement
|
||||
|
||||
Args:
|
||||
other_filter: Another filter (of same type)."""
|
||||
self.obs_filter = other_filter.copy()
|
||||
|
||||
def get_data(self):
|
||||
while True:
|
||||
item = next(self.rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
obsf_snapshot = self.obs_filter.copy()
|
||||
if hasattr(self.obs_filter, "clear_buffer"):
|
||||
self.obs_filter.clear_buffer()
|
||||
return item, obsf_snapshot
|
||||
|
||||
def get_metrics(self):
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
|
||||
|
||||
class AsyncSampler(threading.Thread):
|
||||
"""This class interacts with the environment and tells it what to do.
|
||||
|
||||
Note that batch_size is only a unit of measure here. Batches can
|
||||
accumulate and the gradient can be calculated on up to 5 batches."""
|
||||
async = True
|
||||
|
||||
def __init__(self, env, policy, num_local_steps, obs_filter):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.env = env
|
||||
self.policy = policy
|
||||
self.obs_filter = obs_filter
|
||||
self.obs_f_lock = threading.Lock()
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self._run()
|
||||
except BaseException as e:
|
||||
self.queue.put(e)
|
||||
raise e
|
||||
|
||||
def update_obs_filter(self, other_filter):
|
||||
"""Method to update observation filter with copy from driver.
|
||||
Applies delta since last `clear_buffer` to given new filter,
|
||||
and syncs current filter to new filter. `self.obs_filter` is
|
||||
kept in place due to the `lock_wrap`.
|
||||
|
||||
Args:
|
||||
other_filter: Another filter (of same type)."""
|
||||
with self.obs_f_lock:
|
||||
new_filter = other_filter.copy()
|
||||
# Applies delta to filter, including buffer
|
||||
new_filter.update(self.obs_filter, copy_buffer=True)
|
||||
# copies everything back into original filter - needed
|
||||
# due to `lock_wrap`
|
||||
self.obs_filter.sync(new_filter)
|
||||
|
||||
def _run(self):
|
||||
"""Sets observation filter into an atomic region and starts
|
||||
other thread for running."""
|
||||
safe_obs_filter = lock_wrap(self.obs_filter, self.obs_f_lock)
|
||||
rollout_provider = _env_runner(
|
||||
self.env, self.policy, self.num_local_steps, safe_obs_filter)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
# set to some large number. This is an empirical observation.
|
||||
item = next(rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
self.queue.put(item, timeout=600.0)
|
||||
|
||||
def get_data(self):
|
||||
"""Gets currently accumulated data and a snapshot of the current
|
||||
observation filter. The snapshot also clears the accumulated delta.
|
||||
Note that in between getting the rollout and acquiring the lock,
|
||||
the other thread can run, resulting in slight discrepamcies
|
||||
between data retrieved and filter statistics.
|
||||
|
||||
Returns:
|
||||
rollout: trajectory data (unprocessed)
|
||||
obsf_snapshot: snapshot of observation filter.
|
||||
"""
|
||||
|
||||
rollout = self._pull_batch_from_queue()
|
||||
with self.obs_f_lock:
|
||||
obsf_snapshot = self.obs_filter.copy()
|
||||
if hasattr(self.obs_filter, "clear_buffer"):
|
||||
self.obs_filter.clear_buffer()
|
||||
return rollout, obsf_snapshot
|
||||
|
||||
def _pull_batch_from_queue(self):
|
||||
"""Take a rollout from the queue of the thread runner."""
|
||||
rollout = self.queue.get(timeout=600.0)
|
||||
if isinstance(rollout, BaseException):
|
||||
raise rollout
|
||||
while not rollout.is_terminal():
|
||||
try:
|
||||
part = self.queue.get_nowait()
|
||||
if isinstance(part, BaseException):
|
||||
raise rollout
|
||||
rollout.extend(part)
|
||||
except queue.Empty:
|
||||
break
|
||||
return rollout
|
||||
|
||||
def get_metrics(self):
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
|
||||
|
||||
def _env_runner(env, policy, num_local_steps, obs_filter):
|
||||
"""This implements the logic of the thread runner.
|
||||
|
||||
It continually runs the policy, and as long as the rollout exceeds a
|
||||
certain length, the thread runner appends the policy to the queue. Yields
|
||||
when `timestep_limit` is surpassed, environment terminates, or
|
||||
`num_local_steps` is reached.
|
||||
|
||||
Args:
|
||||
env: Environment generated by env_creator
|
||||
policy: Policy used to interact with environment. Also sets fields
|
||||
to be included in `PartialRollout`
|
||||
num_local_steps: Number of steps before `PartialRollout` is yielded.
|
||||
obs_filter: Filter used to process observations.
|
||||
|
||||
Yields:
|
||||
rollout (PartialRollout): Object containing state, action, reward,
|
||||
terminal condition, and other fields as dictated by `policy`.
|
||||
"""
|
||||
last_state = obs_filter(env.reset())
|
||||
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
last_features = features = policy.get_initial_features()
|
||||
length = 0
|
||||
rewards = 0
|
||||
rollout_number = 0
|
||||
|
||||
while True:
|
||||
terminal_end = False
|
||||
rollout = PartialRollout(extra_fields=policy.other_output)
|
||||
|
||||
for _ in range(num_local_steps):
|
||||
action, pi_info = policy.compute_action(last_state, *last_features)
|
||||
if policy.is_recurrent:
|
||||
features = pi_info["features"]
|
||||
del pi_info["features"]
|
||||
state, reward, terminal, info = env.step(action)
|
||||
state = obs_filter(state)
|
||||
|
||||
length += 1
|
||||
rewards += reward
|
||||
if length >= timestep_limit:
|
||||
terminal = True
|
||||
|
||||
# Collect the experience.
|
||||
rollout.add(state=last_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
terminal=terminal,
|
||||
features=last_features,
|
||||
**pi_info)
|
||||
|
||||
last_state = state
|
||||
last_features = features
|
||||
|
||||
if terminal:
|
||||
terminal_end = True
|
||||
yield CompletedRollout(length, rewards)
|
||||
|
||||
if (length >= timestep_limit or
|
||||
not env.metadata.get("semantics.autoreset")):
|
||||
last_state = obs_filter(env.reset())
|
||||
last_features = policy.get_initial_features()
|
||||
rollout_number += 1
|
||||
length = 0
|
||||
rewards = 0
|
||||
break
|
||||
|
||||
if not terminal_end:
|
||||
rollout.last_r = policy.value(last_state, *last_features)
|
||||
|
||||
# Once we have enough experience, yield it, and have the ThreadRunner
|
||||
# place it on a queue.
|
||||
yield rollout
|
||||
Reference in New Issue
Block a user