mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 17:55:15 +08:00
[rllib] Basic IMPALA implementation (using deepmind's reference vtrace.py) (#2504)
Rename AsyncSamplesOptimizer -> AsyncReplayOptimizer. Add AsyncSamplesOptimizer that implements the IMPALA architecture. Integrate V-trace with the A3C policy graph. Audit the V-trace integration. Benchmark: compare vs A3C, and with V-trace on/off. PongNoFrameskip-v4 on IMPALA scales from 16 to 128 workers, solving Pong in <10 min. For reference, solving this env takes ~40 minutes for Ape-X and several hours for A3C.
This commit is contained in:
@@ -19,7 +19,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
def _register_all():
|
||||
for key in [
|
||||
"PPO", "ES", "DQN", "APEX", "A3C", "BC", "PG", "DDPG", "APEX_DDPG",
|
||||
"__fake", "__sigmoid_fake_data", "__parameter_tuning"
|
||||
"IMPALA", "__fake", "__sigmoid_fake_data", "__parameter_tuning"
|
||||
]:
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
register_trainable(key, get_agent_class(key))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Note: Keep in sync with changes to VTracePolicyGraph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@@ -77,8 +79,6 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||
("advantages", advantages),
|
||||
("value_targets", v_target),
|
||||
]
|
||||
self.state_in = self.model.state_in
|
||||
self.state_out = self.model.state_out
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
@@ -88,29 +88,21 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||
action_sampler=action_dist.sample(),
|
||||
loss=self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.state_in,
|
||||
state_outputs=self.state_out,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
seq_lens=self.model.seq_lens,
|
||||
max_seq_len=self.config["model"]["max_seq_len"])
|
||||
|
||||
if self.config.get("summarize"):
|
||||
bs = tf.to_float(tf.shape(self.observations)[0])
|
||||
tf.summary.scalar("model/policy_graph", self.loss.pi_loss / bs)
|
||||
tf.summary.scalar("model/value_loss", self.loss.vf_loss / bs)
|
||||
tf.summary.scalar("model/entropy", self.loss.entropy / bs)
|
||||
tf.summary.scalar("model/grad_gnorm", tf.global_norm(self._grads))
|
||||
tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list))
|
||||
self.summary_op = tf.summary.merge_all()
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
def extra_compute_action_fetches(self):
|
||||
return {"vf_preds": self.vf}
|
||||
|
||||
def value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob]}
|
||||
assert len(args) == len(self.state_in), (args, self.state_in)
|
||||
for k, v in zip(self.state_in, args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
feed_dict[k] = v
|
||||
vf = self.sess.run(self.vf, feed_dict)
|
||||
return vf[0]
|
||||
@@ -126,7 +118,15 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||
|
||||
def extra_compute_grad_fetches(self):
|
||||
if self.config.get("summarize"):
|
||||
return {"summary": self.summary_op}
|
||||
return {
|
||||
"stats": {
|
||||
"policy_loss": self.loss.pi_loss,
|
||||
"value_loss": self.loss.vf_loss,
|
||||
"entropy": self.loss.entropy,
|
||||
"grad_gnorm": tf.global_norm(self._grads),
|
||||
"var_gnorm": tf.global_norm(self.var_list),
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@@ -139,7 +139,7 @@ class A3CPolicyGraph(TFPolicyGraph):
|
||||
last_r = 0.0
|
||||
else:
|
||||
next_state = []
|
||||
for i in range(len(self.state_in)):
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self.value(sample_batch["new_obs"][-1], *next_state)
|
||||
return compute_advantages(sample_batch, last_r, self.config["gamma"],
|
||||
|
||||
@@ -360,6 +360,9 @@ def get_agent_class(alg):
|
||||
elif alg == "PG":
|
||||
from ray.rllib.agents import pg
|
||||
return pg.PGAgent
|
||||
elif alg == "IMPALA":
|
||||
from ray.rllib.agents import impala
|
||||
return impala.ImpalaAgent
|
||||
elif alg == "script":
|
||||
from ray.tune import script_runner
|
||||
return script_runner.ScriptRunner
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray.tune.trial import Resources
|
||||
APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
|
||||
DDPG_CONFIG,
|
||||
{
|
||||
"optimizer_class": "AsyncSamplesOptimizer",
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
DDPG_CONFIG["optimizer"], {
|
||||
"max_weight_sync_delay": 400,
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray.tune.trial import Resources
|
||||
APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
DQN_CONFIG,
|
||||
{
|
||||
"optimizer_class": "AsyncSamplesOptimizer",
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
DQN_CONFIG["optimizer"], {
|
||||
"max_weight_sync_delay": 400,
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from ray.rllib.agents.impala.impala import ImpalaAgent, DEFAULT_CONFIG
|
||||
|
||||
__all__ = ["ImpalaAgent", "DEFAULT_CONFIG"]
|
||||
@@ -0,0 +1,123 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.utils import FilterManager
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
# Top-level config keys that ImpalaAgent._init() copies into
# config["optimizer"] when the optimizer config does not already set them.
OPTIMIZER_SHARED_CONFIGS = [
    "sample_batch_size",
    "train_batch_size",
]

# IMPALA-specific defaults, passed through with_common_config() (imported
# from ray.rllib.agents.agent).
DEFAULT_CONFIG = with_common_config({
    # V-trace params (see vtrace.py).
    # When "vtrace" is falsy, ImpalaAgent._init() uses A3CPolicyGraph instead
    # of VTracePolicyGraph.
    "vtrace": True,
    "vtrace_clip_rho_threshold": 1.0,
    "vtrace_clip_pg_rho_threshold": 1.0,

    # System params.
    # sample_batch_size and train_batch_size are shared with the optimizer
    # (see OPTIMIZER_SHARED_CONFIGS).
    "sample_batch_size": 50,
    "train_batch_size": 500,
    # ImpalaAgent._train() keeps stepping the optimizer until at least this
    # many seconds have elapsed.
    "min_iter_time_s": 10,
    # When set, the policy graph returns loss statistics with each gradient
    # computation.
    "summarize": False,
    # Whether default_resource_request() asks for one GPU for the agent.
    "gpu": True,
    # Number of remote evaluators created in ImpalaAgent._init().
    "num_workers": 2,
    # Per-worker resources; multiplied by num_workers in
    # default_resource_request().
    "num_cpus_per_worker": 1,
    "num_gpus_per_worker": 0,

    # Learning params.
    # Gradients are clipped to this global norm.
    "grad_clip": 40.0,
    # Learning rate handed to the Adam optimizer.
    "lr": 0.0001,
    "vf_loss_coeff": 0.5,
    # The V-trace loss adds entropy * entropy_coeff to the total loss, so a
    # negative coefficient rewards higher entropy.
    "entropy_coeff": -0.01,

    # Model and preprocessor options.
    # When true, VTracePolicyGraph clips rewards to [-1, 1].
    "clip_rewards": True,
    "preprocessor_pref": "deepmind",
    "model": {
        "use_lstm": False,
        "max_seq_len": 20,
        "dim": 80,
    },
})
|
||||
|
||||
|
||||
class ImpalaAgent(Agent):
    """IMPALA implementation using DeepMind's V-trace.

    Builds one local and ``num_workers`` remote evaluators and drives them
    with an ``AsyncSamplesOptimizer``. The policy graph is
    ``VTracePolicyGraph`` when ``config["vtrace"]`` is truthy, otherwise
    ``A3CPolicyGraph``.
    """

    _agent_name = "IMPALA"
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        """Return the ``Resources`` to reserve for a trial of this agent.

        Args:
            config: User config dict; missing keys fall back to the class
                defaults.

        Returns:
            A ``Resources`` with one CPU (plus one GPU if ``config["gpu"]``
            is truthy) for the agent itself, and the per-worker CPU/GPU
            amounts multiplied by ``num_workers`` as extra resources.
        """
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=1 if cf["gpu"] else 0,
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    def _init(self):
        """Create the evaluators and the async samples optimizer."""
        # Share the batch-size settings with the optimizer unless the
        # optimizer config overrides them explicitly.
        for k in OPTIMIZER_SHARED_CONFIGS:
            if k not in self.config["optimizer"]:
                self.config["optimizer"][k] = self.config[k]
        if self.config["vtrace"]:
            policy_cls = VTracePolicyGraph
        else:
            policy_cls = A3CPolicyGraph
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, policy_cls)
        # NOTE(review): workers are created with a fixed {"num_cpus": 1}
        # while default_resource_request() reserves num_cpus_per_worker per
        # worker -- confirm whether these should agree.
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, policy_cls, self.config["num_workers"],
            {"num_cpus": 1})
        self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
                                               self.remote_evaluators,
                                               self.config["optimizer"])

    def _train(self):
        """Run optimizer steps for one training iteration.

        At least one step is taken; further steps follow until
        ``min_iter_time_s`` seconds have elapsed.

        Returns:
            The optimizer's collected metrics with ``timesteps_this_iter``
            set to the number of steps sampled during this call.
        """
        prev_steps = self.optimizer.num_steps_sampled
        start = time.time()
        self.optimizer.step()
        while time.time() - start < self.config["min_iter_time_s"]:
            self.optimizer.step()
        FilterManager.synchronize(self.local_evaluator.filters,
                                  self.remote_evaluators)
        result = self.optimizer.collect_metrics()
        result = result._replace(
            timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps)
        return result

    def _stop(self):
        """Terminate all remote evaluators."""
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        """Pickle local and remote evaluator state next to the checkpoint.

        Args:
            checkpoint_dir: Directory in which to write the checkpoint.

        Returns:
            The checkpoint path; the state itself is written to
            ``<checkpoint path>.extra_data``.
        """
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        # Use a context manager so the file is flushed and closed even if
        # pickling fails part-way through.
        with open(checkpoint_path + ".extra_data", "wb") as f:
            pickle.dump(extra_data, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Restore evaluator state written by ``_save``.

        Args:
            checkpoint_path: Path previously returned by ``_save``.
        """
        # NOTE: pickle.load executes arbitrary code from the file; only
        # restore checkpoints that come from a trusted source.
        with open(checkpoint_path + ".extra_data", "rb") as f:
            extra_data = pickle.load(f)
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])
|
||||
@@ -0,0 +1,300 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Functions to compute V-trace off-policy actor critic targets.
|
||||
|
||||
For details and theory see:
|
||||
|
||||
"IMPALA: Scalable Distributed Deep-RL with
|
||||
Importance Weighted Actor-Learner Architectures"
|
||||
by Espeholt, Soyer, Munos et al.
|
||||
|
||||
See https://arxiv.org/abs/1802.01561 for the full paper.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# Short alias for tf.contrib.framework.nest.
nest = tf.contrib.framework.nest

# Return type of from_logits(): the V-trace targets and advantages together
# with the log importance weights and per-policy action log-probabilities
# they were derived from.
VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [
    'vs', 'pg_advantages', 'log_rhos', 'behaviour_action_log_probs',
    'target_action_log_probs'
])

# Return type of from_importance_weights(): value targets (vs) and policy
# gradient advantages.
VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages')
|
||||
|
||||
|
||||
def log_probs_from_logits_and_actions(policy_logits, actions):
    """Returns the log-probability of each taken action under a softmax policy.

    T is the time dimension (0 to T-1), B the batch size and NUM_ACTIONS the
    number of discrete actions.

    Args:
      policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] holding the
        un-normalized log-probabilities that parameterize the softmax policy.
      actions: An int32 tensor of shape [T, B] with the actions taken.

    Returns:
      A float32 tensor of shape [T, B] with the log-probability the policy
      assigns to each taken action.
    """
    logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32)
    chosen = tf.convert_to_tensor(actions, dtype=tf.int32)

    logits.shape.assert_has_rank(3)
    chosen.shape.assert_has_rank(2)

    # Sparse softmax cross-entropy is -log softmax(logits)[action], so its
    # negation is the wanted log-probability.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=chosen)
    return -cross_entropy
|
||||
|
||||
|
||||
def from_logits(behaviour_policy_logits,
                target_policy_logits,
                actions,
                discounts,
                rewards,
                values,
                bootstrap_value,
                clip_rho_threshold=1.0,
                clip_pg_rho_threshold=1.0,
                name='vtrace_from_logits'):
    r"""V-trace for softmax policies.

    Computes V-trace actor critic targets for softmax policies as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    The target policy is the policy being improved; the behaviour policy is
    the one that generated the given rewards and actions.

    T is the time dimension (0 to T-1), B the batch size and NUM_ACTIONS the
    number of actions.

    Args:
      behaviour_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]
        with un-normalized log-probabilities parametrizing the softmax
        behaviour policy.
      target_policy_logits: A float32 tensor of shape [T, B, NUM_ACTIONS] with
        un-normalized log-probabilities parametrizing the softmax target
        policy.
      actions: An int32 tensor of shape [T, B] of actions sampled from the
        behaviour policy.
      discounts: A float32 tensor of shape [T, B] with the discount
        encountered when following the behaviour policy.
      rewards: A float32 tensor of shape [T, B] with the rewards generated by
        following the behaviour policy.
      values: A float32 tensor of shape [T, B] with the value function
        estimates wrt. the target policy.
      bootstrap_value: A float32 of shape [B] with the value function estimate
        at time T.
      clip_rho_threshold: A scalar float32 tensor with the clipping threshold
        for importance weights (rho) when calculating the baseline targets
        (vs). rho^bar in the paper.
      clip_pg_rho_threshold: A scalar float32 tensor with the clipping
        threshold on rho_s in
        \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
      name: The name scope that all V-trace operations will be created in.

    Returns:
      A `VTraceFromLogitsReturns` namedtuple with the following fields:
        vs: A float32 tensor of shape [T, B]. Can be used as target to train a
          baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
          estimate of the advantage in the calculation of policy gradients.
        log_rhos: A float32 tensor of shape [T, B] containing the log
          importance sampling weights (log rhos).
        behaviour_action_log_probs: A float32 tensor of shape [T, B]
          containing behaviour policy action log probabilities
          (log \mu(a_t)).
        target_action_log_probs: A float32 tensor of shape [T, B] containing
          target policy action log probabilities (log \pi(a_t)).
    """
    behaviour_logits = tf.convert_to_tensor(
        behaviour_policy_logits, dtype=tf.float32)
    target_logits = tf.convert_to_tensor(
        target_policy_logits, dtype=tf.float32)
    taken_actions = tf.convert_to_tensor(actions, dtype=tf.int32)

    # Only logits and actions are rank-checked here; the remaining inputs are
    # validated by from_importance_weights.
    for logits in (behaviour_logits, target_logits):
        logits.shape.assert_has_rank(3)
    taken_actions.shape.assert_has_rank(2)

    scope_inputs = [
        behaviour_logits, target_logits, taken_actions, discounts, rewards,
        values, bootstrap_value
    ]
    with tf.name_scope(name, values=scope_inputs):
        target_log_probs = log_probs_from_logits_and_actions(
            target_logits, taken_actions)
        behaviour_log_probs = log_probs_from_logits_and_actions(
            behaviour_logits, taken_actions)
        # log(pi(a) / mu(a)), kept in log-space.
        log_importance_weights = target_log_probs - behaviour_log_probs
        base_returns = from_importance_weights(
            log_rhos=log_importance_weights,
            discounts=discounts,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold)
        return VTraceFromLogitsReturns(
            vs=base_returns.vs,
            pg_advantages=base_returns.pg_advantages,
            log_rhos=log_importance_weights,
            behaviour_action_log_probs=behaviour_log_probs,
            target_action_log_probs=target_log_probs)
|
||||
|
||||
|
||||
def from_importance_weights(log_rhos,
                            discounts,
                            rewards,
                            values,
                            bootstrap_value,
                            clip_rho_threshold=1.0,
                            clip_pg_rho_threshold=1.0,
                            name='vtrace_from_importance_weights'):
    r"""V-trace from log importance weights.

    Calculates V-trace actor critic targets as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1 and B refers to the batch size. This
    code also supports the case where all tensors have the same number of
    additional dimensions, e.g., `rewards` is [T, B, C], `values` is
    [T, B, C], `bootstrap_value` is [B, C].

    Args:
      log_rhos: A float32 tensor of shape [T, B] (the same rank as `values`,
        `discounts` and `rewards`, which is asserted below) representing the
        log importance sampling weights, i.e.
        log(target_policy(a) / behaviour_policy(a)). V-trace performs
        operations on rhos in log-space for numerical stability.
      discounts: A float32 tensor of shape [T, B] with discounts encountered
        when following the behaviour policy.
      rewards: A float32 tensor of shape [T, B] containing rewards generated by
        following the behaviour policy.
      values: A float32 tensor of shape [T, B] with the value function
        estimates wrt. the target policy.
      bootstrap_value: A float32 of shape [B] with the value function estimate
        at time T.
      clip_rho_threshold: A scalar float32 tensor with the clipping threshold
        for importance weights (rho) when calculating the baseline targets
        (vs). rho^bar in the paper. If None, no clipping is applied.
      clip_pg_rho_threshold: A scalar float32 tensor with the clipping
        threshold on rho_s in
        \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). If None, no
        clipping is applied.
      name: The name scope that all V-trace operations will be created in.

    Returns:
      A VTraceReturns namedtuple (vs, pg_advantages) where:
        vs: A float32 tensor of shape [T, B]. Can be used as target to
          train a baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
          advantage in the calculation of policy gradients.
      Both fields are wrapped in tf.stop_gradient.
    """
    log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32)
    discounts = tf.convert_to_tensor(discounts, dtype=tf.float32)
    rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
    values = tf.convert_to_tensor(values, dtype=tf.float32)
    bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32)
    # None means "do not clip", so only convert real thresholds.
    if clip_rho_threshold is not None:
        clip_rho_threshold = tf.convert_to_tensor(
            clip_rho_threshold, dtype=tf.float32)
    if clip_pg_rho_threshold is not None:
        clip_pg_rho_threshold = tf.convert_to_tensor(
            clip_pg_rho_threshold, dtype=tf.float32)

    # Make sure tensor ranks are consistent.
    rho_rank = log_rhos.shape.ndims  # Usually 2.
    values.shape.assert_has_rank(rho_rank)
    # The bootstrap value has no time axis, hence one rank less.
    bootstrap_value.shape.assert_has_rank(rho_rank - 1)
    discounts.shape.assert_has_rank(rho_rank)
    rewards.shape.assert_has_rank(rho_rank)
    if clip_rho_threshold is not None:
        clip_rho_threshold.shape.assert_has_rank(0)
    if clip_pg_rho_threshold is not None:
        clip_pg_rho_threshold.shape.assert_has_rank(0)

    with tf.name_scope(
            name,
            values=[log_rhos, discounts, rewards, values, bootstrap_value]):
        rhos = tf.exp(log_rhos)
        if clip_rho_threshold is not None:
            clipped_rhos = tf.minimum(
                clip_rho_threshold, rhos, name='clipped_rhos')
        else:
            clipped_rhos = rhos

        # Trace-cutting coefficients: importance weights capped at 1.
        cs = tf.minimum(1.0, rhos, name='cs')
        # Shift values one step forward in time and append the bootstrapped
        # value, giving [V(x_1), ..., V(x_T)].
        values_t_plus_1 = tf.concat(
            [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
        # Importance-weighted temporal-difference errors.
        deltas = clipped_rhos * (
            rewards + discounts * values_t_plus_1 - values)

        # All sequences are reversed, computation starts from the back.
        sequences = (
            tf.reverse(discounts, axis=[0]),
            tf.reverse(cs, axis=[0]),
            tf.reverse(deltas, axis=[0]),
        )

        # V-trace vs are calculated through a scan from the back to the
        # beginning of the given trajectory. With acc holding the result for
        # the following time step, each step computes
        #   delta_t + discount_t * c_t * acc.
        def scanfunc(acc, sequence_item):
            discount_t, c_t, delta_t = sequence_item
            return delta_t + discount_t * c_t * acc

        # The accumulator starts at zero (nothing follows the last step).
        initial_values = tf.zeros_like(bootstrap_value)
        vs_minus_v_xs = tf.scan(
            fn=scanfunc,
            elems=sequences,
            initializer=initial_values,
            parallel_iterations=1,
            back_prop=False,
            name='scan')
        # Reverse the results back to original order.
        vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs')

        # Add V(x_s) to get v_s.
        vs = tf.add(vs_minus_v_xs, values, name='vs')

        # Advantage for policy gradient: uses the V-trace targets of the
        # next step (bootstrap value after the final step).
        vs_t_plus_1 = tf.concat(
            [vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
        if clip_pg_rho_threshold is not None:
            clipped_pg_rhos = tf.minimum(
                clip_pg_rho_threshold, rhos, name='clipped_pg_rhos')
        else:
            clipped_pg_rhos = rhos
        pg_advantages = (
            clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))

        # Make sure no gradients backpropagated through the returned values.
        return VTraceReturns(
            vs=tf.stop_gradient(vs),
            pg_advantages=tf.stop_gradient(pg_advantages))
|
||||
@@ -0,0 +1,217 @@
|
||||
"""Adapted from A3CPolicyGraph to add V-trace.
|
||||
|
||||
Keep in sync with changes to A3CPolicyGraph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.misc import linear, normc_initializer
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
|
||||
|
||||
class VTraceLoss(object):
    """Actor-critic loss whose targets come from V-trace."""

    def __init__(self,
                 actions,
                 actions_logp,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 vf_loss_coeff=0.5,
                 entropy_coeff=-0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0):
        """Builds the policy gradient loss with V-trace importance weighting.

        All per-step inputs are time-major tensors of shape [T, B, ...],
        where `B` is the batch size. The batch axis must be explicit so that
        V-trace can respect episode cut boundaries.

        Args:
            actions: An integer tensor of shape [T, B] (it is cast to int32
                and vtrace.from_logits requires rank 2).
            actions_logp: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
            target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            vf_loss_coeff: Weight of the value function loss in total_loss.
            entropy_coeff: Weight of the summed entropy in total_loss.
            clip_rho_threshold: Passed to vtrace.from_logits.
            clip_pg_rho_threshold: Passed to vtrace.from_logits.
        """

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            int_actions = tf.cast(actions, tf.int32)
            # Terminal steps get a discount of zero.
            step_discounts = tf.to_float(~dones) * discount
            rho_clip = tf.cast(clip_rho_threshold, tf.float32)
            pg_rho_clip = tf.cast(clip_pg_rho_threshold, tf.float32)
            vtrace_returns = vtrace.from_logits(
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=target_logits,
                actions=int_actions,
                discounts=step_discounts,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                clip_rho_threshold=rho_clip,
                clip_pg_rho_threshold=pg_rho_clip)

        # Policy gradient term.
        weighted_logp = actions_logp * vtrace_returns.pg_advantages
        self.pi_loss = -tf.reduce_sum(weighted_logp)

        # Baseline term: squared error against the V-trace targets.
        squared_error = tf.square(values - vtrace_returns.vs)
        self.vf_loss = 0.5 * tf.reduce_sum(squared_error)

        # Entropy term.
        self.entropy = tf.reduce_sum(actions_entropy)

        # Weighted sum of the three terms.
        policy_and_value = self.pi_loss + self.vf_loss * vf_loss_coeff
        self.total_loss = policy_and_value + self.entropy * entropy_coeff
|
||||
|
||||
|
||||
class VTracePolicyGraph(TFPolicyGraph):
    """TF policy graph that trains an actor-critic model with a V-trace loss.

    Adapted from A3CPolicyGraph; keep in sync with changes to it.
    """

    def __init__(self, observation_space, action_space, config):
        """Builds the model, the V-trace loss and the TFPolicyGraph plumbing.

        Args:
            observation_space: Observation space; its shape sizes the
                observation placeholder.
            action_space: A gym Box or Discrete action space.
            config: Config dict; keys that are missing fall back to the
                IMPALA defaults.

        Raises:
            UnsupportedSpaceException: If the action space is neither Box
                nor Discrete.
        """
        # Fall back to the IMPALA defaults (not the A3C ones): this graph
        # reads IMPALA-only keys such as "vtrace_clip_rho_threshold", which
        # are defined in the IMPALA DEFAULT_CONFIG.
        config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
            "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config
        self.sess = tf.get_default_session()

        # Setup the policy
        self.observations = tf.placeholder(
            tf.float32, [None] + list(observation_space.shape))
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"])
        self.model = ModelCatalog.get_model(self.observations, logit_dim,
                                            self.config["model"])
        action_dist = dist_class(self.model.outputs)
        # Scalar value head on top of the model's last layer, flattened to
        # one value per input row.
        values = tf.reshape(
            linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
            [-1])
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # Setup the policy loss
        if isinstance(action_space, gym.spaces.Box):
            ac_size = action_space.shape[0]
            actions = tf.placeholder(tf.float32, [None, ac_size], name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            ac_size = action_space.n
            actions = tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for IMPALA.".format(
                    action_space))
        dones = tf.placeholder(tf.bool, [None], name="dones")
        rewards = tf.placeholder(tf.float32, [None], name="rewards")
        behaviour_logits = tf.placeholder(
            tf.float32, [None, ac_size], name="behaviour_logits")

        def to_batches(tensor):
            """Reshape a flat [B * T, ...] tensor to time-major [T, B, ...]."""
            if self.config["model"]["use_lstm"]:
                B = tf.shape(self.model.seq_lens)[0]
                T = tf.shape(tensor)[0] // B
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                T = (self.config["sample_batch_size"] //
                     self.config["num_envs_per_worker"])
                B = tf.shape(tensor)[0] // T
            rs = tf.reshape(tensor,
                            tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
            # swap B and T axes
            return tf.transpose(
                rs,
                [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0]))))

        if self.config["clip_rewards"]:
            clipped_rewards = tf.clip_by_value(rewards, -1, 1)
        else:
            clipped_rewards = rewards

        # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
        # The last time step is dropped from every input and only supplies
        # the bootstrap value.
        self.loss = VTraceLoss(
            actions=to_batches(actions)[:-1],
            actions_logp=to_batches(action_dist.logp(actions))[:-1],
            actions_entropy=to_batches(action_dist.entropy())[:-1],
            dones=to_batches(dones)[:-1],
            behaviour_logits=to_batches(behaviour_logits)[:-1],
            target_logits=to_batches(self.model.outputs)[:-1],
            discount=config["gamma"],
            rewards=to_batches(clipped_rewards)[:-1],
            values=to_batches(values)[:-1],
            bootstrap_value=to_batches(values)[-1],
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.config["entropy_coeff"],
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

        # Initialize TFPolicyGraph
        loss_in = [
            ("actions", actions),
            ("dones", dones),
            ("behaviour_logits", behaviour_logits),
            ("rewards", rewards),
            ("obs", self.observations),
        ]
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=self.observations,
            action_sampler=action_dist.sample(),
            loss=self.loss.total_loss,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"])

        self.sess.run(tf.global_variables_initializer())

    def optimizer(self):
        """Return an Adam optimizer using the configured learning rate."""
        return tf.train.AdamOptimizer(self.config["lr"])

    def gradients(self, optimizer):
        """Return (gradient, variable) pairs clipped by global norm.

        The `optimizer` argument is not used; gradients are taken directly
        with tf.gradients and clipped to config["grad_clip"].
        """
        grads = tf.gradients(self.loss.total_loss, self.var_list)
        self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
        clipped_grads = list(zip(self.grads, self.var_list))
        return clipped_grads

    def extra_compute_action_fetches(self):
        """Also fetch the model logits as "behaviour_logits"."""
        return {"behaviour_logits": self.model.outputs}

    def extra_compute_grad_fetches(self):
        """Return loss statistics when config["summarize"] is set."""
        if self.config.get("summarize"):
            return {
                "stats": {
                    "policy_loss": self.loss.pi_loss,
                    "value_loss": self.loss.vf_loss,
                    "entropy": self.loss.entropy,
                    # NOTE(review): self._grads is not assigned in this class
                    # (gradients() stores self.grads); presumably the base
                    # class sets it -- confirm.
                    "grad_gnorm": tf.global_norm(self._grads),
                    "var_gnorm": tf.global_norm(self.var_list),
                },
            }
        else:
            return {}

    def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
        """Drop the unused "new_obs" column and return the batch."""
        del sample_batch.data["new_obs"]  # not used, so save some bandwidth
        return sample_batch

    def get_initial_state(self):
        """Return the model's initial state."""
        return self.model.state_init
|
||||
@@ -1,4 +1,5 @@
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.async_replay_optimizer import AsyncReplayOptimizer
|
||||
from ray.rllib.optimizers.async_samples_optimizer import AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.async_gradients_optimizer import \
|
||||
AsyncGradientsOptimizer
|
||||
@@ -7,6 +8,7 @@ from ray.rllib.optimizers.sync_replay_optimizer import SyncReplayOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
|
||||
|
||||
__all__ = [
|
||||
"PolicyOptimizer", "AsyncSamplesOptimizer", "AsyncGradientsOptimizer",
|
||||
"SyncSamplesOptimizer", "SyncReplayOptimizer", "LocalMultiGPUOptimizer"
|
||||
"PolicyOptimizer", "AsyncReplayOptimizer", "AsyncSamplesOptimizer",
|
||||
"AsyncGradientsOptimizer", "SyncSamplesOptimizer", "SyncReplayOptimizer",
|
||||
"LocalMultiGPUOptimizer"
|
||||
]
|
||||
|
||||
@@ -20,6 +20,7 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
self.wait_timer = TimerStat()
|
||||
self.dispatch_timer = TimerStat()
|
||||
self.grads_per_step = grads_per_step
|
||||
self.learner_stats = {}
|
||||
if not self.remote_evaluators:
|
||||
raise ValueError(
|
||||
"Async optimizer requires at least 1 remote evaluator")
|
||||
@@ -41,6 +42,8 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
with self.wait_timer:
|
||||
fut, e = gradient_queue.pop(0)
|
||||
gradient, info = ray.get(fut)
|
||||
if "stats" in info:
|
||||
self.learner_stats = info["stats"]
|
||||
|
||||
if gradient is not None:
|
||||
with self.apply_timer:
|
||||
@@ -61,4 +64,5 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
"wait_time_ms": round(1000 * self.wait_timer.mean, 3),
|
||||
"apply_time_ms": round(1000 * self.apply_timer.mean, 3),
|
||||
"dispatch_time_ms": round(1000 * self.dispatch_timer.mean, 3),
|
||||
"learner": self.learner_stats,
|
||||
})
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
"""Implements Distributed Prioritized Experience Replay.
|
||||
|
||||
https://arxiv.org/abs/1803.00933"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
from six.moves import queue
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
SAMPLE_QUEUE_DEPTH = 2
|
||||
REPLAY_QUEUE_DEPTH = 4
|
||||
LEARNER_QUEUE_MAX_SIZE = 16
|
||||
|
||||
|
||||
@ray.remote
class ReplayActor(object):
    """One shard of the prioritized replay buffer, hosted in a Ray actor.

    Ray actors are single-threaded, so several of these shards may be
    created to spread the replay work and increase parallelism."""

    def __init__(self, num_shards, learning_starts, buffer_size,
                 train_batch_size, prioritized_replay_alpha,
                 prioritized_replay_beta, prioritized_replay_eps,
                 clip_rewards):
        # The global thresholds are split evenly across the shards.
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards

        self.train_batch_size = train_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps

        self.replay_buffer = PrioritizedReplayBuffer(
            self.buffer_size,
            alpha=prioritized_replay_alpha,
            clip_rewards=clip_rewards)

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()

    def get_host(self):
        """Return the second field of ``os.uname()`` (the node name)."""
        uname_fields = os.uname()
        return uname_fields[1]

    def add_batch(self, batch):
        """Insert every row of a (non multi-agent) batch into the buffer."""
        PolicyOptimizer._check_not_multiagent(batch)
        with self.add_batch_timer:
            for transition in batch.rows():
                self.replay_buffer.add(
                    transition["obs"],
                    transition["actions"],
                    transition["rewards"],
                    transition["new_obs"],
                    transition["dones"],
                    transition["weights"])

    def replay(self):
        """Sample one training batch, or return None if the shard holds
        fewer than ``replay_starts`` entries."""
        with self.replay_timer:
            if len(self.replay_buffer) < self.replay_starts:
                return None

            sampled = self.replay_buffer.sample(
                self.train_batch_size, beta=self.prioritized_replay_beta)
            (obs, acts, rews, next_obs, done_flags, is_weights,
             indexes) = sampled

            return SampleBatch({
                "obs": obs,
                "actions": acts,
                "rewards": rews,
                "new_obs": next_obs,
                "dones": done_flags,
                "weights": is_weights,
                "batch_indexes": indexes
            })

    def update_priorities(self, batch_indexes, td_errors):
        """Set new priorities to ``|td_error| + eps`` for the given rows."""
        with self.update_priorities_timer:
            priorities = np.abs(td_errors) + self.prioritized_replay_eps
            self.replay_buffer.update_priorities(batch_indexes, priorities)

    def stats(self):
        """Return mean timings (in ms) merged with the buffer's own stats."""

        def mean_ms(timer):
            return round(1000 * timer.mean, 3)

        stat = {
            "add_batch_time_ms": mean_ms(self.add_batch_timer),
            "replay_time_ms": mean_ms(self.replay_timer),
            "update_priorities_time_ms": mean_ms(
                self.update_priorities_timer),
        }
        stat.update(self.replay_buffer.stats())
        return stat
|
||||
|
||||
|
||||
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.

    Attributes set in ``__init__``:
        inqueue: bounded queue of ``(replay_actor, batch_or_None)`` items
            fed by the main thread.
        outqueue: queue of ``(replay_actor, batch, td_error, count)`` tuples
            produced after each applied update.
        weights_updated: flag raised after an update has actually been
            applied; the main thread clears it when it re-reads the weights.
    """

    def __init__(self, local_evaluator):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_evaluator = local_evaluator
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        # Daemon thread: it must not keep the process alive on shutdown.
        self.daemon = True
        self.weights_updated = False

    def run(self):
        """Process queue items forever."""
        while True:
            self.step()

    def step(self):
        """Take one item from ``inqueue`` and apply it if it holds data.

        A ``None`` batch is skipped: nothing is put on ``outqueue`` and
        ``weights_updated`` is left untouched, because no gradient step ran
        and the weights are unchanged.
        """
        with self.queue_timer:
            ra, replay = self.inqueue.get()

        if replay is not None:
            with self.grad_timer:
                td_error = self.local_evaluator.compute_apply(replay)[
                    "td_error"]
            # Only signal new weights once an update has really happened;
            # otherwise the main thread would re-broadcast identical weights.
            self.weights_updated = True
            self.outqueue.put((ra, replay, td_error, replay.count))

        self.learner_queue_size.push(self.inqueue.qsize())
|
||||
|
||||
|
||||
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote evaluators (Ape-X actors), and replay buffer actors.

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=512,
              sample_batch_size=50,
              num_replay_buffer_shards=1,
              max_weight_sync_delay=400,
              clip_rewards=True,
              debug=False):
        """Start the learner thread, the replay shards and the sampling.

        Raises:
            ValueError: if there are no remote evaluators.
        """
        # Validate before starting the learner thread or creating actors so
        # that a misconfiguration does not leave them running.
        if not self.remote_evaluators:
            raise ValueError(
                "Async optimizer requires at least 1 remote evaluator")

        self.debug = debug
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        self.replay_actors = create_colocated(ReplayActor, [
            num_replay_buffer_shards, learning_starts, buffer_size,
            train_batch_size, prioritized_replay_alpha,
            prioritized_replay_beta, prioritized_replay_eps, clip_rewards
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "enqueue", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def step(self):
        """Run one iteration of the event loop and record throughput."""
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        # Train timing is only recorded once the first update has been seen.
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    def _step(self):
        """Route finished samples, replays and learner results.

        Returns:
            A ``(sample_timesteps, train_timesteps)`` tuple for this pass.
        """
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            # Each completed entry is (evaluator, (batch_id, count_id)); only
            # the counts are fetched here, the batch ids are forwarded as-is.
            counts = ray.get([c[1][1] for c in completed])
            for i, (ev, (sample_batch, count)) in enumerate(completed):
                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.local_evaluator.get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                with self.timers["get_samples"]:
                    samples = ray.get(replay)
                with self.timers["enqueue"]:
                    self.learner.inqueue.put((ra, samples))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, replay, td_error, count = self.learner.outqueue.get()
                ra.update_priorities.remote(replay["batch_indexes"], td_error)
                train_timesteps += count

        return sample_timesteps, train_timesteps

    def stats(self):
        """Return optimizer statistics merged over the base class stats.

        The detailed breakdown (replay shard 0 stats, timings, pending task
        counts, learner queue) is only collected when ``debug`` is set, so
        the remote stats call is skipped otherwise.
        """
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
        }
        if self.debug:
            replay_stats = ray.get(self.replay_actors[0].stats.remote())
            timing = {
                "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
                for k in self.timers
            }
            timing["learner_grad_time_ms"] = round(
                1000 * self.learner.grad_timer.mean, 3)
            timing["learner_dequeue_time_ms"] = round(
                1000 * self.learner.queue_timer.mean, 3)
            stats.update({
                "replay_shard_0": replay_stats,
                "timing_breakdown": timing,
                "pending_sample_tasks": self.sample_tasks.count,
                "pending_replay_tasks": self.replay_tasks.count,
                "learner_queue": self.learner.learner_queue_size.stats(),
            })
        return dict(PolicyOptimizer.stats(self), **stats)
|
||||
@@ -1,108 +1,28 @@
|
||||
"""Implements Distributed Prioritized Experience Replay.
|
||||
"""Implements the IMPALA architecture.
|
||||
|
||||
https://arxiv.org/abs/1803.00933"""
|
||||
https://arxiv.org/abs/1802.01561"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
from six.moves import queue
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.actors import TaskPool, create_colocated
|
||||
from ray.rllib.utils.actors import TaskPool
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
SAMPLE_QUEUE_DEPTH = 2
|
||||
REPLAY_QUEUE_DEPTH = 4
|
||||
LEARNER_QUEUE_MAX_SIZE = 16
|
||||
|
||||
|
||||
@ray.remote
|
||||
class ReplayActor(object):
|
||||
"""A replay buffer shard.
|
||||
|
||||
Ray actors are single-threaded, so for scalability multiple replay actors
|
||||
may be created to increase parallelism."""
|
||||
|
||||
def __init__(self, num_shards, learning_starts, buffer_size,
|
||||
train_batch_size, prioritized_replay_alpha,
|
||||
prioritized_replay_beta, prioritized_replay_eps,
|
||||
clip_rewards):
|
||||
self.replay_starts = learning_starts // num_shards
|
||||
self.buffer_size = buffer_size // num_shards
|
||||
self.train_batch_size = train_batch_size
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
|
||||
self.replay_buffer = PrioritizedReplayBuffer(
|
||||
self.buffer_size,
|
||||
alpha=prioritized_replay_alpha,
|
||||
clip_rewards=clip_rewards)
|
||||
|
||||
# Metrics
|
||||
self.add_batch_timer = TimerStat()
|
||||
self.replay_timer = TimerStat()
|
||||
self.update_priorities_timer = TimerStat()
|
||||
|
||||
def get_host(self):
|
||||
return os.uname()[1]
|
||||
|
||||
def add_batch(self, batch):
|
||||
PolicyOptimizer._check_not_multiagent(batch)
|
||||
with self.add_batch_timer:
|
||||
for row in batch.rows():
|
||||
self.replay_buffer.add(row["obs"], row["actions"],
|
||||
row["rewards"], row["new_obs"],
|
||||
row["dones"], row["weights"])
|
||||
|
||||
def replay(self):
|
||||
with self.replay_timer:
|
||||
if len(self.replay_buffer) < self.replay_starts:
|
||||
return None
|
||||
|
||||
(obses_t, actions, rewards, obses_tp1, dones, weights,
|
||||
batch_indexes) = self.replay_buffer.sample(
|
||||
self.train_batch_size, beta=self.prioritized_replay_beta)
|
||||
|
||||
batch = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
return batch
|
||||
|
||||
def update_priorities(self, batch_indexes, td_errors):
|
||||
with self.update_priorities_timer:
|
||||
new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps)
|
||||
self.replay_buffer.update_priorities(batch_indexes, new_priorities)
|
||||
|
||||
def stats(self):
|
||||
stat = {
|
||||
"add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
|
||||
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
|
||||
"update_priorities_time_ms": round(
|
||||
1000 * self.update_priorities_timer.mean, 3),
|
||||
}
|
||||
stat.update(self.replay_buffer.stats())
|
||||
return stat
|
||||
|
||||
|
||||
class LearnerThread(threading.Thread):
|
||||
"""Background thread that updates the local model from replay data.
|
||||
"""Background thread that updates the local model from sample trajectories.
|
||||
|
||||
The learner thread communicates with the main thread through Queues. This
|
||||
is needed since Ray operations can only be run on the main thread. In
|
||||
@@ -119,7 +39,8 @@ class LearnerThread(threading.Thread):
|
||||
self.queue_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
self.daemon = True
|
||||
self.weights_updated = False
|
||||
self.weights_updated = 0
|
||||
self.stats = {}
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
@@ -127,86 +48,57 @@ class LearnerThread(threading.Thread):
|
||||
|
||||
def step(self):
|
||||
with self.queue_timer:
|
||||
ra, replay = self.inqueue.get()
|
||||
if replay is not None:
|
||||
ra, batch = self.inqueue.get()
|
||||
|
||||
if batch is not None:
|
||||
with self.grad_timer:
|
||||
td_error = self.local_evaluator.compute_apply(replay)[
|
||||
"td_error"]
|
||||
self.outqueue.put((ra, replay, td_error, replay.count))
|
||||
fetches = self.local_evaluator.compute_apply(batch)
|
||||
self.weights_updated += 1
|
||||
if "stats" in fetches:
|
||||
self.stats = fetches["stats"]
|
||||
self.outqueue.put(batch.count)
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
self.weights_updated = True
|
||||
|
||||
|
||||
class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
"""Main event loop of the Ape-X optimizer (async sampling with replay).
|
||||
"""Main event loop of the IMPALA architecture.
|
||||
|
||||
This class coordinates the data transfers between the learner thread,
|
||||
remote evaluators (Ape-X actors), and replay buffer actors.
|
||||
This class coordinates the data transfers between the learner thread
|
||||
and remote evaluators (IMPALA actors).
|
||||
"""
|
||||
|
||||
This optimizer requires that policy evaluators return an additional
|
||||
"td_error" array in the info return of compute_gradients(). This error
|
||||
term will be used for sample prioritization."""
|
||||
|
||||
def _init(self,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=512,
|
||||
sample_batch_size=50,
|
||||
num_replay_buffer_shards=1,
|
||||
max_weight_sync_delay=400,
|
||||
clip_rewards=True,
|
||||
debug=False):
|
||||
def _init(self, train_batch_size=512, sample_batch_size=50, debug=False):
|
||||
|
||||
self.debug = debug
|
||||
self.replay_starts = learning_starts
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
self.max_weight_sync_delay = max_weight_sync_delay
|
||||
self.learning_started = False
|
||||
self.train_batch_size = train_batch_size
|
||||
|
||||
self.learner = LearnerThread(self.local_evaluator)
|
||||
self.learner.start()
|
||||
|
||||
self.replay_actors = create_colocated(ReplayActor, [
|
||||
num_replay_buffer_shards, learning_starts, buffer_size,
|
||||
train_batch_size, prioritized_replay_alpha,
|
||||
prioritized_replay_beta, prioritized_replay_eps, clip_rewards
|
||||
], num_replay_buffer_shards)
|
||||
assert len(self.remote_evaluators) > 0
|
||||
|
||||
# Stats
|
||||
self.timers = {
|
||||
k: TimerStat()
|
||||
for k in [
|
||||
"put_weights", "get_samples", "enqueue", "sample_processing",
|
||||
"replay_processing", "update_priorities", "train", "sample"
|
||||
]
|
||||
for k in
|
||||
["put_weights", "enqueue", "sample_processing", "train", "sample"]
|
||||
}
|
||||
self.num_weight_syncs = 0
|
||||
self.learning_started = False
|
||||
|
||||
# Number of worker steps since the last weight update
|
||||
self.steps_since_update = {}
|
||||
|
||||
# Otherwise kick off replay tasks for local gradient updates
|
||||
self.replay_tasks = TaskPool()
|
||||
for ra in self.replay_actors:
|
||||
for _ in range(REPLAY_QUEUE_DEPTH):
|
||||
self.replay_tasks.add(ra, ra.replay.remote())
|
||||
|
||||
# Kick off async background sampling
|
||||
self.sample_tasks = TaskPool()
|
||||
weights = self.local_evaluator.get_weights()
|
||||
for ev in self.remote_evaluators:
|
||||
ev.set_weights.remote(weights)
|
||||
self.steps_since_update[ev] = 0
|
||||
for _ in range(SAMPLE_QUEUE_DEPTH):
|
||||
self.sample_tasks.add(ev, ev.sample_with_count.remote())
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
|
||||
self.batch_buffer = []
|
||||
|
||||
def step(self):
|
||||
assert self.learner.is_alive()
|
||||
start = time.time()
|
||||
sample_timesteps, train_timesteps = self._step()
|
||||
time_delta = time.time() - start
|
||||
@@ -225,50 +117,37 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
weights = None
|
||||
|
||||
with self.timers["sample_processing"]:
|
||||
completed = list(self.sample_tasks.completed())
|
||||
counts = ray.get([c[1][1] for c in completed])
|
||||
for i, (ev, (sample_batch, count)) in enumerate(completed):
|
||||
sample_timesteps += counts[i]
|
||||
for ev, sample_batch in self.sample_tasks.completed_prefetch():
|
||||
sample_batch = ray.get(sample_batch)
|
||||
sample_timesteps += sample_batch.count
|
||||
self.batch_buffer.append(sample_batch)
|
||||
if sum(b.count
|
||||
for b in self.batch_buffer) >= self.train_batch_size:
|
||||
train_batch = self.batch_buffer[0].concat_samples(
|
||||
self.batch_buffer)
|
||||
with self.timers["enqueue"]:
|
||||
self.learner.inqueue.put((ev, train_batch))
|
||||
self.batch_buffer = []
|
||||
|
||||
# Send the data to the replay buffer
|
||||
random.choice(
|
||||
self.replay_actors).add_batch.remote(sample_batch)
|
||||
|
||||
# Update weights if needed
|
||||
self.steps_since_update[ev] += counts[i]
|
||||
if self.steps_since_update[ev] >= self.max_weight_sync_delay:
|
||||
# Note that it's important to pull new weights once
|
||||
# updated to avoid excessive correlation between actors
|
||||
if weights is None or self.learner.weights_updated:
|
||||
self.learner.weights_updated = False
|
||||
with self.timers["put_weights"]:
|
||||
weights = ray.put(
|
||||
self.local_evaluator.get_weights())
|
||||
ev.set_weights.remote(weights)
|
||||
self.num_weight_syncs += 1
|
||||
self.steps_since_update[ev] = 0
|
||||
# Note that it's important to pull new weights once
|
||||
# updated to avoid excessive correlation between actors
|
||||
if weights is None or self.learner.weights_updated:
|
||||
self.learner.weights_updated = False
|
||||
with self.timers["put_weights"]:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
ev.set_weights.remote(weights)
|
||||
self.num_weight_syncs += 1
|
||||
|
||||
# Kick off another sample request
|
||||
self.sample_tasks.add(ev, ev.sample_with_count.remote())
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
|
||||
with self.timers["replay_processing"]:
|
||||
for ra, replay in self.replay_tasks.completed():
|
||||
self.replay_tasks.add(ra, ra.replay.remote())
|
||||
with self.timers["get_samples"]:
|
||||
samples = ray.get(replay)
|
||||
with self.timers["enqueue"]:
|
||||
self.learner.inqueue.put((ra, samples))
|
||||
|
||||
with self.timers["update_priorities"]:
|
||||
while not self.learner.outqueue.empty():
|
||||
ra, replay, td_error, count = self.learner.outqueue.get()
|
||||
ra.update_priorities.remote(replay["batch_indexes"], td_error)
|
||||
train_timesteps += count
|
||||
while not self.learner.outqueue.empty():
|
||||
count = self.learner.outqueue.get()
|
||||
train_timesteps += count
|
||||
|
||||
return sample_timesteps, train_timesteps
|
||||
|
||||
def stats(self):
|
||||
replay_stats = ray.get(self.replay_actors[0].stats.remote())
|
||||
timing = {
|
||||
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
|
||||
for k in self.timers
|
||||
@@ -284,12 +163,12 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
"num_weight_syncs": self.num_weight_syncs,
|
||||
}
|
||||
debug_stats = {
|
||||
"replay_shard_0": replay_stats,
|
||||
"timing_breakdown": timing,
|
||||
"pending_sample_tasks": self.sample_tasks.count,
|
||||
"pending_replay_tasks": self.replay_tasks.count,
|
||||
"learner_queue": self.learner.learner_queue_size.stats(),
|
||||
}
|
||||
if self.debug:
|
||||
stats.update(debug_stats)
|
||||
if self.learner.stats:
|
||||
stats["learner"] = self.learner.stats
|
||||
return dict(PolicyOptimizer.stats(self), **stats)
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
# This can reach 18-19 reward within 10 minutes on a Tesla M60 GPU (e.g., G3 EC2 node)
|
||||
# with 32 workers and 10 envs per worker. This is more efficient than the non-vectorized
|
||||
# configuration which requires 128 workers to achieve the same performance.
|
||||
pong-impala-vectorized:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 500 # 50 * num_envs_per_worker
|
||||
train_batch_size: 500
|
||||
num_workers: 32
|
||||
num_envs_per_worker: 10
|
||||
@@ -0,0 +1,13 @@
|
||||
# This can reach 18-19 reward within 10 minutes on a Tesla M60 GPU (e.g., G3 EC2 node):
|
||||
# 128 workers -> 8 minutes
|
||||
# 32 workers -> 17 minutes
|
||||
# 16 workers -> 40 min+
|
||||
# See also: pong-impala-vectorized.yaml
|
||||
pong-impala:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 128
|
||||
num_envs_per_worker: 1
|
||||
@@ -12,6 +12,7 @@ class TaskPool(object):
|
||||
def __init__(self):
|
||||
self._tasks = {}
|
||||
self._objects = {}
|
||||
self._fetching = []
|
||||
|
||||
def add(self, worker, all_obj_ids):
|
||||
if isinstance(all_obj_ids, list):
|
||||
@@ -28,6 +29,25 @@ class TaskPool(object):
|
||||
for obj_id in ready:
|
||||
yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))
|
||||
|
||||
def completed_prefetch(self):
    """Similar to completed but only returns once the object is local.

    Newly completed objects are asked to be fetched and queued in
    ``self._fetching``; each call then yields the ``(worker, obj_id)``
    pairs whose object the plasma client already contains and keeps the
    others queued for a later call.

    Assumes each task has exactly one object id.
    """

    for worker, obj_id in self.completed():
        plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
        ray.worker.global_worker.plasma_client.fetch([plasma_id])
        self._fetching.append((worker, obj_id))

    pending = self._fetching
    remaining = []
    examined = 0
    try:
        for worker, obj_id in pending:
            plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
            is_local = ray.worker.global_worker.plasma_client.contains(
                plasma_id)
            # Count the entry as handled only once its status is known, so
            # an error above leaves it queued.
            examined += 1
            if is_local:
                yield (worker, obj_id)
            else:
                remaining.append((worker, obj_id))
    finally:
        # Commit the bookkeeping even if the consumer stops iterating early;
        # otherwise entries already yielded would be yielded again on the
        # next call. Entries not looked at yet stay queued.
        self._fetching = remaining + pending[examined:]
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
return len(self._tasks)
|
||||
|
||||
Reference in New Issue
Block a user