[rllib] Feature/soft actor critic v2 (#5328)

* Add base for Soft Actor-Critic

* Pick changes from old SAC branch

* Update sac.py

* First implementation of sac model

* Remove unnecessary SAC imports

* Prune unnecessary noise and exploration code

* Implement SAC model and use that in SAC policy

* runs but doesn't learn

* clear state

* fix batch size

* Add missing alpha grads and vars

* -200 by 2k timesteps

* doc

* lazy squash

* one file

* ignore tfp

* revert done
Kristian Hartikainen
2019-08-01 23:37:36 -07:00
committed by Eric Liang
parent 3ae54a2b20
commit 13fb9fe3db
21 changed files with 827 additions and 26 deletions
+1
@@ -28,6 +28,7 @@ MOCK_MODULES = [
"scipy",
"scipy.signal",
"scipy.stats",
"tensorflow_probability",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.all_reduce",
+21 -6
@@ -164,12 +164,12 @@ Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/m
**Atari results @10M steps**: `more details <https://github.com/ray-project/rl-experiments>`__
=============  =========  ==================  ===============  =================
Atari env      RLlib DQN  RLlib Dueling DDQN  RLlib Dist. DQN  Hessel et al. DQN
=============  =========  ==================  ===============  =================
BeamRider      2869       1910                4447             ~2000
Breakout       287        312                 410              ~150
Qbert          3921       7968                15780            ~4000
SpaceInvaders  650        1001                1025             ~500
=============  =========  ==================  ===============  =================
**DQN-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
@@ -217,7 +217,7 @@ SpaceInvaders 671 944 ~800
===========  =========================  =============================
MuJoCo env   RLlib PPO 16-workers @ 1h  Fan et al PPO 16-workers @ 1h
===========  =========================  =============================
HalfCheetah  9664                       ~7700
===========  =========================  =============================
.. figure:: ppo.png
@@ -232,6 +232,21 @@ HalfCheetah 9664 ~7700
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
Soft Actor Critic (SAC)
-----------------------
`[paper] <https://arxiv.org/pdf/1801.01290>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/sac/sac.py>`__
RLlib's soft actor-critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. SAC has two fields for configuring custom models, ``policy_model`` and ``Q_model``, and currently supports only continuous action spaces. The implementation is currently *experimental*.
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/regression_tests/pendulum-sac.yaml>`__
**SAC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
.. literalinclude:: ../../python/ray/rllib/agents/sac/sac.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
Derivative-free
~~~~~~~~~~~~~~~
+1
@@ -18,6 +18,7 @@ DQN, Rainbow **Yes** `+parametric`_ No **Yes** No
DDPG, TD3 No **Yes** **Yes** No
APEX-DQN **Yes** `+parametric`_ No **Yes** No
APEX-DDPG No **Yes** **Yes** No
SAC (todo) **Yes** **Yes** No
ES **Yes** **Yes** No No
ARS **Yes** **Yes** No No
QMIX **Yes** No **Yes** **Yes**
+1 -1
@@ -37,7 +37,7 @@ The ``rllib train`` command (same as the ``train.py`` script in the repo) has a
The most important options are for choosing the environment
with ``--env`` (any OpenAI gym environment including ones registered by the user
can be used) and for choosing the algorithm with ``--run``
(available options are ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
(available options are ``SAC``, ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
Evaluating Trained Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~
+2
@@ -73,6 +73,8 @@ Algorithms
- `Proximal Policy Optimization (PPO) <rllib-algorithms.html#proximal-policy-optimization-ppo>`__
- `Soft Actor Critic (SAC) <rllib-algorithms.html#soft-actor-critic-sac>`__
* Derivative-free
- `Augmented Random Search (ARS) <rllib-algorithms.html#augmented-random-search-ars>`__
+1 -1
@@ -7,7 +7,7 @@ RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
# The following is needed to support TensorFlow 1.14
RUN conda remove -y --force wrapt
RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open
RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open tensorflow_probability
RUN pip install -U h5py # Mutes FutureWarnings
RUN pip install --upgrade bayesian-optimization
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
+6
@@ -9,6 +9,11 @@ import traceback
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
def _import_sac():
from ray.rllib.agents import sac
return sac.SACTrainer
def _import_appo():
from ray.rllib.agents import ppo
return ppo.APPOTrainer
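The import functions above are registered by name so that an agent's dependencies (for SAC, `tensorflow_probability`) are only imported when that algorithm is requested. A minimal sketch of the pattern follows, using standard-library modules as stand-ins for agent packages; `LAZY_REGISTRY` and `get_class` are hypothetical names chosen for illustration, not RLlib's API.

```python
import importlib


def _import_decoder():
    # Stand-in for `from ray.rllib.agents import sac; return sac.SACTrainer`.
    module = importlib.import_module("json")
    return module.JSONDecoder


def _import_fraction():
    module = importlib.import_module("fractions")
    return module.Fraction


# Hypothetical registry: names map to functions, not to classes, so nothing
# is imported until a name is actually looked up.
LAZY_REGISTRY = {
    "DECODER": _import_decoder,
    "FRACTION": _import_fraction,
}


def get_class(name):
    """Resolve a registered name to its class, importing on demand."""
    try:
        importer = LAZY_REGISTRY[name]
    except KeyError:
        raise ValueError("Unknown algorithm: {}".format(name))
    return importer()
```

Looking up `"FRACTION"` imports `fractions` at that moment and returns the class; an unregistered name raises `ValueError`.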
@@ -95,6 +100,7 @@ def _import_marwil():
ALGORITHMS = {
"SAC": _import_sac,
"DDPG": _import_ddpg,
"APEX_DDPG": _import_apex_ddpg,
"TD3": _import_td3,
+1
@@ -0,0 +1 @@
Implementation of Soft Actor-Critic (https://arxiv.org/abs/1812.05905).
+13
@@ -0,0 +1,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_agent
SACAgent = renamed_agent(SACTrainer)
__all__ = [
"SACTrainer",
"DEFAULT_CONFIG",
]
+119
@@ -0,0 +1,119 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.sac.sac_policy import SACTFPolicy
OPTIMIZER_SHARED_CONFIGS = [
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
"train_batch_size", "learning_starts"
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Model ===
"twin_q": True,
"use_state_preprocessor": False,
"policy": "GaussianLatentSpacePolicy",
# RLlib model options for the Q function
"Q_model": {
"hidden_activation": "relu",
"hidden_layer_sizes": (256, 256),
},
# RLlib model options for the policy function
"policy_model": {
"hidden_activation": "relu",
"hidden_layer_sizes": (256, 256),
},
# === Learning ===
# Update the target by \tau * policy + (1-\tau) * target_policy
"tau": 5e-3,
# Target entropy lower bound. This is the inverse of reward scale,
# and will be optimized automatically.
"target_entropy": "auto",
# Disable setting done=True at end of episode.
"no_done_at_end": True,
# N-step target updates
"n_step": 1,
# === Evaluation ===
# The evaluation stats will be reported under the "evaluation" metric key.
"evaluation_interval": 1,
# Number of episodes to run per evaluation period.
"evaluation_num_episodes": 1,
# Extra configuration that disables exploration.
"evaluation_config": {
"exploration_enabled": False,
},
# === Exploration ===
# Number of env steps to optimize for before returning
"timesteps_per_iteration": 1000,
"exploration_enabled": True,
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": int(1e6),
# If True prioritized replay buffer will be used.
# TODO(hartikainen): Make sure this works or remove the option.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
"beta_annealing_fraction": 0.2,
"final_prioritized_replay_beta": 0.4,
"compress_observations": False,
# === Optimization ===
"optimization": {
"actor_learning_rate": 3e-4,
"critic_learning_rate": 3e-4,
"entropy_learning_rate": 3e-4,
},
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": None,
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
"sample_batch_size": 1,
# Size of a batch sampled from the replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 256,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 0,
# === Parallelism ===
# Whether to use a GPU for local optimization.
"num_gpus": 0,
# Number of workers used for collecting samples. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
"num_workers": 0,
# Number of GPUs to allocate per worker (no GPU is allocated if 0).
"num_gpus_per_worker": 0,
# Number of CPUs to allocate per worker.
"num_cpus_per_worker": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# TODO(ekl) these are unused; remove them from sac config
"per_worker_exploration": False,
"exploration_fraction": 0.1,
"schedule_max_timesteps": 100000,
"exploration_final_eps": 0.02,
})
# __sphinx_doc_end__
# yapf: enable
SACTrainer = GenericOffPolicyTrainer.with_updates(
name="SAC", default_config=DEFAULT_CONFIG, default_policy=SACTFPolicy)
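Several of the defaults above are nested dictionaries (`Q_model`, `policy_model`, `optimization`), so a user override typically touches only one inner key. The sketch below illustrates overlaying such an override on the defaults. It assumes nested keys are merged rather than replaced wholesale; `SAC_DEFAULTS` is a hand-copied subset of the config and `merge_config` is a hypothetical helper, not the function RLlib itself uses.

```python
import copy

# Hypothetical subset of the defaults listed above, for illustration only.
SAC_DEFAULTS = {
    "twin_q": True,
    "tau": 5e-3,
    "Q_model": {
        "hidden_activation": "relu",
        "hidden_layer_sizes": (256, 256),
    },
    "optimization": {
        "actor_learning_rate": 3e-4,
        "critic_learning_rate": 3e-4,
        "entropy_learning_rate": 3e-4,
    },
}


def merge_config(defaults, overrides):
    """Recursively overlay `overrides` on a copy of `defaults`."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            raise KeyError("Unknown config key: {}".format(key))
        if isinstance(value, dict) and isinstance(merged[key], dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


config = merge_config(SAC_DEFAULTS, {
    "Q_model": {"hidden_layer_sizes": (64, 64)},
    "optimization": {"critic_learning_rate": 1e-3},
})
```

Here `config["Q_model"]["hidden_activation"]` keeps its default while the layer sizes change, and `SAC_DEFAULTS` itself is left unmodified.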
+232
@@ -0,0 +1,232 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf, try_import_tfp
tf = try_import_tf()
tfp = try_import_tfp()
SCALE_DIAG_MIN_MAX = (-20, 2)
def SquashBijector():
# lazy def since it depends on tfp
class SquashBijector(tfp.bijectors.Bijector):
def __init__(self, validate_args=False, name="tanh"):
super(SquashBijector, self).__init__(
forward_min_event_ndims=0,
validate_args=validate_args,
name=name)
def _forward(self, x):
return tf.nn.tanh(x)
def _inverse(self, y):
return tf.atanh(y)
def _forward_log_det_jacobian(self, x):
return 2. * (np.log(2.) - x - tf.nn.softplus(-2. * x))
return SquashBijector()
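The expression in `_forward_log_det_jacobian` is an algebraically equivalent rewrite of `log(1 - tanh(x)^2)`, the log-derivative of tanh. The following plain-Python check (scalar inputs, no TensorFlow; the function names are illustrative) compares the two forms on a few values.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_log_det_jacobian(x):
    """Stable form used above: 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def naive_log_det_jacobian(x):
    """Direct form: log(d tanh(x) / dx) = log(1 - tanh(x)^2)."""
    return math.log(1.0 - math.tanh(x) ** 2)


for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(tanh_log_det_jacobian(x) - naive_log_det_jacobian(x)) < 1e-9
```

The rewritten form stays finite for large `|x|`, where `tanh(x)` rounds to exactly 1 in floating point and the direct form would take the logarithm of zero.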
class SACModel(TFModelV2):
"""Extension of standard TFModel for SAC.
Data flow:
obs -> forward() -> model_out
model_out -> get_policy_output() -> pi(s)
model_out, actions -> get_q_values() -> Q(s, a)
model_out, actions -> get_twin_q_values() -> Q_twin(s, a)
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
actor_hidden_activation="relu",
actor_hiddens=(256, 256),
critic_hidden_activation="relu",
critic_hiddens=(256, 256),
twin_q=False):
"""Initialize variables of this model.
Extra model kwargs:
actor_hidden_activation (str): activation for actor network
actor_hiddens (list): hidden layers sizes for actor network
critic_hidden_activation (str): activation for critic network
critic_hiddens (list): hidden layers sizes for critic network
twin_q (bool): build twin Q networks
Note that the core layers for forward() are not defined here, this
only defines the layers for the output heads. Those layers for
forward() should be defined in subclasses of SACModel.
"""
if tfp is None:
raise ImportError("tensorflow-probability package not found")
super(SACModel, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
self.action_dim = np.product(action_space.shape)
self.model_out = tf.keras.layers.Input(
shape=(num_outputs, ), name="model_out")
self.actions = tf.keras.layers.Input(
shape=(self.action_dim, ), name="actions")
shift_and_log_scale_diag = tf.keras.Sequential([
tf.keras.layers.Dense(
units=hidden,
activation=getattr(tf.nn, actor_hidden_activation),
name="action_hidden_{}".format(i))
for i, hidden in enumerate(actor_hiddens)
] + [
tf.keras.layers.Dense(
units=tfp.layers.MultivariateNormalTriL.params_size(
self.action_dim),
activation=None,
name="action_out")
])(self.model_out)
shift, log_scale_diag = tf.keras.layers.Lambda(
lambda shift_and_log_scale_diag: tf.split(
shift_and_log_scale_diag,
num_or_size_splits=2,
axis=-1)
)(shift_and_log_scale_diag)
log_scale_diag = tf.keras.layers.Lambda(
lambda log_sd: tf.clip_by_value(log_sd, *SCALE_DIAG_MIN_MAX))(
log_scale_diag)
shift_and_log_scale_diag = tf.keras.layers.Concatenate(axis=-1)(
[shift, log_scale_diag])
raw_action_distribution = tfp.layers.MultivariateNormalTriL(
self.action_dim)(shift_and_log_scale_diag)
action_distribution = tfp.layers.DistributionLambda(
make_distribution_fn=SquashBijector())(raw_action_distribution)
# TODO(hartikainen): Remove the unnecessary Model call here
self.action_distribution_model = tf.keras.Model(
self.model_out, action_distribution)
self.register_variables(self.action_distribution_model.variables)
def build_q_net(name, observations, actions):
q_net = tf.keras.Sequential([
tf.keras.layers.Concatenate(axis=1),
] + [
tf.keras.layers.Dense(
units=units,
activation=getattr(tf.nn, critic_hidden_activation),
name="{}_hidden_{}".format(name, i))
for i, units in enumerate(critic_hiddens)
] + [
tf.keras.layers.Dense(
units=1, activation=None, name="{}_out".format(name))
])
# TODO(hartikainen): Remove the unnecessary Model call here
q_net = tf.keras.Model([observations, actions],
q_net([observations, actions]))
return q_net
self.q_net = build_q_net("q", self.model_out, self.actions)
self.register_variables(self.q_net.variables)
if twin_q:
self.twin_q_net = build_q_net("twin_q", self.model_out,
self.actions)
self.register_variables(self.twin_q_net.variables)
else:
self.twin_q_net = None
self.log_alpha = tf.Variable(0.0, dtype=tf.float32, name="log_alpha")
self.alpha = tf.exp(self.log_alpha)
self.register_variables([self.log_alpha])
def forward(self, input_dict, state, seq_lens):
"""This generates the model_out tensor input.
You must implement this as documented in modelv2.py."""
raise NotImplementedError
def get_policy_output(self, model_out, deterministic=False):
"""Return the (unscaled) output of the policy network.
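This returns tanh-squashed actions in (-1, 1), not yet rescaled to the
action space bounds, together with their log-probabilities.
Arguments:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
deterministic (bool): if True, return the squashed mean of the
distribution instead of a sample.
Returns:
tuple (actions, log_pis): actions has shape [BATCH_SIZE, action_dim];
log_pis is None when deterministic=True.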
"""
action_distribution = self.action_distribution_model(model_out)
if deterministic:
actions = action_distribution.bijector(
action_distribution.distribution.mean())
log_pis = None
else:
actions = action_distribution.sample()
log_pis = action_distribution.log_prob(actions)
return actions, log_pis
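The two branches of `get_policy_output` can be sketched without TensorFlow as follows. This is an illustration under a simplifying assumption (an independent Gaussian per action dimension), not a reproduction of the TFP layers used above; the function names and the small `1e-12` guard are choices made for this sketch.

```python
import math
import random


def sample_squashed_gaussian(mean, log_std, rng):
    """Sample a = tanh(u), u ~ N(mean, exp(log_std)^2); return (a, log pi(a))."""
    actions, log_prob = [], 0.0
    for mu, ls in zip(mean, log_std):
        std = math.exp(ls)
        u = rng.gauss(mu, std)
        # Log density of the unsquashed Gaussian sample.
        log_prob += (-0.5 * ((u - mu) / std) ** 2 - ls
                     - 0.5 * math.log(2.0 * math.pi))
        # Change of variables: subtract log |d tanh(u) / du|.
        log_prob -= math.log(1.0 - math.tanh(u) ** 2 + 1e-12)
        actions.append(math.tanh(u))
    return actions, log_prob


def deterministic_action(mean):
    """Deterministic branch: squash the Gaussian mean, no log-prob."""
    return [math.tanh(mu) for mu in mean], None
```

For example, `sample_squashed_gaussian([0.0, 0.5], [-1.0, -1.0], random.Random(0))` returns two actions strictly inside (-1, 1) and a single scalar log-probability for the joint action.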
def get_q_values(self, model_out, actions):
"""Return the Q estimates for the most recent forward pass.
This implements Q(s, a).
Arguments:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
actions (Tensor): action values that correspond with the most
recent batch of observations passed through forward(), of shape
[BATCH_SIZE, action_dim].
Returns:
tensor of shape [BATCH_SIZE].
"""
return self.q_net([model_out, actions])
def get_twin_q_values(self, model_out, actions):
"""Same as get_q_values but using the twin Q net.
This implements the twin Q(s, a).
Arguments:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
actions (Tensor): action values that correspond with the most
recent batch of observations passed through forward(), of shape
[BATCH_SIZE, action_dim].
Returns:
tensor of shape [BATCH_SIZE].
"""
return self.twin_q_net([model_out, actions])
def policy_variables(self):
"""Return the list of variables for the policy net."""
return list(self.action_distribution_model.variables)
def q_variables(self):
"""Return the list of variables for Q / twin Q nets."""
return self.q_net.variables + (self.twin_q_net.variables
if self.twin_q_net else [])
+367
@@ -0,0 +1,367 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Box
import numpy as np
import logging
import ray
import ray.experimental.tf_utils
from ray.rllib.agents.sac.sac_model import SACModel
from ray.rllib.agents.ddpg.noop_model import NoopModel
from ray.rllib.agents.dqn.dqn_policy import _postprocess_dqn, PRIO_WEIGHTS
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils import try_import_tf, try_import_tfp
from ray.rllib.utils.tf_ops import minimize_and_clip
tf = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
def build_sac_model(policy, obs_space, action_space, config):
if config["model"]["custom_model"]:
logger.warning(
"Setting use_state_preprocessor=True since a custom model "
"was specified.")
config["use_state_preprocessor"] = True
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(
"Action space {} is not supported for SAC.".format(action_space))
if len(action_space.shape) > 1:
raise UnsupportedSpaceException(
"Action space has multiple dimensions "
"{}. ".format(action_space.shape) +
"Consider reshaping this into a single dimension, "
"using a Tuple action space, or the multi-agent API.")
if config["use_state_preprocessor"]:
default_model = None # catalog decides
num_outputs = 256 # arbitrary
config["model"]["no_final_linear"] = True
else:
default_model = NoopModel
num_outputs = int(np.product(obs_space.shape))
policy.model = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs,
config["model"],
framework="tf",
model_interface=SACModel,
default_model=default_model,
name="sac_model",
actor_hidden_activation=config["policy_model"]["hidden_activation"],
actor_hiddens=config["policy_model"]["hidden_layer_sizes"],
critic_hidden_activation=config["Q_model"]["hidden_activation"],
critic_hiddens=config["Q_model"]["hidden_layer_sizes"],
twin_q=config["twin_q"])
policy.target_model = ModelCatalog.get_model_v2(
obs_space,
action_space,
num_outputs,
config["model"],
framework="tf",
model_interface=SACModel,
default_model=default_model,
name="target_sac_model",
actor_hidden_activation=config["policy_model"]["hidden_activation"],
actor_hiddens=config["policy_model"]["hidden_layer_sizes"],
critic_hidden_activation=config["Q_model"]["hidden_activation"],
critic_hiddens=config["Q_model"]["hidden_layer_sizes"],
twin_q=config["twin_q"])
return policy.model
def postprocess_trajectory(policy,
sample_batch,
other_agent_batches=None,
episode=None):
return _postprocess_dqn(policy, sample_batch)
def exploration_setting_inputs(policy):
return {
policy.stochastic: policy.config["exploration_enabled"],
}
def build_action_output(policy, model, input_dict, obs_space, action_space,
config):
model_out, _ = model({
"obs": input_dict[SampleBatch.CUR_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
def unsquash_actions(actions):
# Use sigmoid to scale to [0,1], but also double magnitude of input to
# emulate behaviour of tanh activation used in SAC and TD3 papers.
sigmoid_out = tf.nn.sigmoid(2 * actions)
# Rescale to actual env policy scale
# (shape of sigmoid_out is [batch_size, dim_actions], so we reshape to
# get same dims)
action_range = (action_space.high - action_space.low)[None]
low_action = action_space.low[None]
unsquashed_actions = action_range * sigmoid_out + low_action
return unsquashed_actions
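The comment in `unsquash_actions` relies on the identity `sigmoid(2x) = (tanh(x) + 1) / 2`. A scalar sketch of the rescaling, with illustrative function names and bounds, is:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def unsquash(action, low, high):
    """Map an unbounded value into [low, high] via sigmoid(2 * action)."""
    return (high - low) * sigmoid(2.0 * action) + low


def unsquash_tanh(action, low, high):
    """Equivalent form: rescale tanh(action) from [-1, 1] to [low, high]."""
    return low + (high - low) * (math.tanh(action) + 1.0) / 2.0
```

With Pendulum-like bounds of -2 and 2, an input of 0 maps to the midpoint 0, and both forms agree up to floating-point error.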
squashed_stochastic_actions, log_pis = policy.model.get_policy_output(
model_out, deterministic=False)
stochastic_actions = unsquash_actions(squashed_stochastic_actions)
squashed_deterministic_actions, _ = policy.model.get_policy_output(
model_out, deterministic=True)
deterministic_actions = unsquash_actions(squashed_deterministic_actions)
actions = tf.cond(policy.stochastic, lambda: stochastic_actions,
lambda: deterministic_actions)
action_probabilities = tf.cond(policy.stochastic, lambda: log_pis,
lambda: tf.zeros_like(log_pis))
policy.output_actions = actions
return actions, action_probabilities
def actor_critic_loss(policy, batch_tensors):
model_out_t, _ = policy.model({
"obs": batch_tensors[SampleBatch.CUR_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
model_out_tp1, _ = policy.model({
"obs": batch_tensors[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
target_model_out_tp1, _ = policy.target_model({
"obs": batch_tensors[SampleBatch.NEXT_OBS],
"is_training": policy._get_is_training_placeholder(),
}, [], None)
# TODO(hartikainen): figure actions and log pis
policy_t, log_pis_t = policy.model.get_policy_output(model_out_t)
policy_tp1, log_pis_tp1 = policy.model.get_policy_output(model_out_tp1)
log_alpha = policy.model.log_alpha
alpha = policy.model.alpha
# q network evaluation
q_t = policy.model.get_q_values(model_out_t,
batch_tensors[SampleBatch.ACTIONS])
if policy.config["twin_q"]:
twin_q_t = policy.model.get_twin_q_values(
model_out_t, batch_tensors[SampleBatch.ACTIONS])
# Q-values for actions sampled from the current policy in the current state
q_t_det_policy = policy.model.get_q_values(model_out_t, policy_t)
# target q network evaluation
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
if policy.config["twin_q"]:
twin_q_tp1 = policy.target_model.get_twin_q_values(
target_model_out_tp1, policy_tp1)
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
if policy.config["twin_q"]:
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
# the entropy term of the target uses the next-state log-probabilities
q_tp1 -= tf.expand_dims(alpha * log_pis_tp1, 1)
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
q_tp1_best_masked = (1.0 - tf.cast(batch_tensors[SampleBatch.DONES],
tf.float32)) * q_tp1_best
assert policy.config["n_step"] == 1, "TODO(hartikainen) n_step > 1"
# compute RHS of bellman equation
q_t_selected_target = tf.stop_gradient(
batch_tensors[SampleBatch.REWARDS] +
policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)
# compute the error (potentially clipped)
if policy.config["twin_q"]:
td_error = q_t_selected - q_t_selected_target
twin_td_error = twin_q_t_selected - q_t_selected_target
# square each critic's TD error before summing them for the stats
errors = 0.5 * (tf.square(td_error) + tf.square(twin_td_error))
td_error = td_error + twin_td_error
else:
td_error = q_t_selected - q_t_selected_target
errors = 0.5 * tf.square(td_error)
critic_loss = policy.model.custom_loss(
tf.reduce_mean(batch_tensors[PRIO_WEIGHTS] * errors), batch_tensors)
actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy)
target_entropy = (-np.prod(policy.action_space.shape)
if policy.config["target_entropy"] == "auto" else
policy.config["target_entropy"])
alpha_loss = -tf.reduce_mean(
log_alpha * tf.stop_gradient(log_pis_t + target_entropy))
# save for stats function
policy.q_t = q_t
policy.td_error = td_error
policy.actor_loss = actor_loss
policy.critic_loss = critic_loss
policy.alpha_loss = alpha_loss
# in a custom apply op we handle the losses separately, but return them
# combined in one loss for now
return actor_loss + critic_loss + alpha_loss
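For a single transition, the critic target built in `actor_critic_loss` can be written out in scalar form. The sketch below follows the SAC formulation, in which the entropy term uses the log-probability of the action sampled at the next state; the function names and the `gamma=0.99` default are assumptions for illustration.

```python
def soft_q_target(reward, done, q_tp1, twin_q_tp1, log_pi_tp1,
                  alpha, gamma=0.99, n_step=1):
    """Right-hand side of the soft Bellman equation for one transition."""
    # Clipped double-Q: take the smaller of the two target estimates.
    q_min = min(q_tp1, twin_q_tp1)
    # Entropy bonus for the action sampled at the next state.
    soft_value = q_min - alpha * log_pi_tp1
    # No bootstrapping past a terminal transition.
    return reward + (gamma ** n_step) * (1.0 - float(done)) * soft_value


def critic_error(q_t, twin_q_t, target):
    """Sum of the two squared TD errors, each scaled by 0.5."""
    return 0.5 * ((q_t - target) ** 2 + (twin_q_t - target) ** 2)
```

With `reward=1.0`, target estimates 10 and 8, `log_pi_tp1=-2.0`, `alpha=0.5` and `gamma=0.9`, the target is `1 + 0.9 * (8 + 1) = 9.1`; with `done=True` it reduces to the reward.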
def gradients(policy, optimizer, loss):
if policy.config["grad_norm_clipping"] is not None:
actor_grads_and_vars = minimize_and_clip(
policy._actor_optimizer,
policy.actor_loss,
var_list=policy.model.policy_variables(),
clip_val=policy.config["grad_norm_clipping"])
critic_grads_and_vars = minimize_and_clip(
policy._critic_optimizer,
policy.critic_loss,
var_list=policy.model.q_variables(),
clip_val=policy.config["grad_norm_clipping"])
alpha_grads_and_vars = minimize_and_clip(
policy._alpha_optimizer,
policy.alpha_loss,
var_list=[policy.model.log_alpha],
clip_val=policy.config["grad_norm_clipping"])
else:
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
policy.actor_loss, var_list=policy.model.policy_variables())
critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
policy.critic_loss, var_list=policy.model.q_variables())
alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
policy.alpha_loss, var_list=[policy.model.log_alpha])
# save these for later use in build_apply_op
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
if g is not None]
policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
if g is not None]
policy._alpha_grads_and_vars = [(g, v) for (g, v) in alpha_grads_and_vars
if g is not None]
grads_and_vars = (
policy._actor_grads_and_vars + policy._critic_grads_and_vars +
policy._alpha_grads_and_vars)
return grads_and_vars
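When `grad_norm_clipping` is set, each loss goes through `minimize_and_clip`. The sketch below assumes that helper clips every gradient by its own L2 norm (rather than by a global norm across variables) and mirrors the `g is not None` filtering above. Gradients are plain lists of floats here and the function names are illustrative.

```python
import math


def clip_by_norm(grad, clip_val):
    """Rescale `grad` so that its L2 norm is at most `clip_val`."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= clip_val or norm == 0.0:
        return list(grad)
    return [g * clip_val / norm for g in grad]


def clip_grads_and_vars(grads_and_vars, clip_val):
    """Clip each gradient separately and drop entries with no gradient."""
    return [(clip_by_norm(g, clip_val), v)
            for g, v in grads_and_vars if g is not None]
```

A gradient `[3.0, 4.0]` (norm 5) clipped at 1.0 becomes roughly `[0.6, 0.8]`, keeping its direction, while a gradient already within the limit is returned unchanged.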
def stats(policy, batch_tensors):
return {
"td_error": tf.reduce_mean(policy.td_error),
"actor_loss": tf.reduce_mean(policy.actor_loss),
"critic_loss": tf.reduce_mean(policy.critic_loss),
"mean_q": tf.reduce_mean(policy.q_t),
"max_q": tf.reduce_max(policy.q_t),
"min_q": tf.reduce_min(policy.q_t),
}
class ExplorationStateMixin(object):
def __init__(self, obs_space, action_space, config):
self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
def set_epsilon(self, epsilon):
pass
class TargetNetworkMixin(object):
def __init__(self, config):
# update_target_fn will be called periodically to copy Q network to
# target Q network
self.tau_value = config.get("tau")
self.tau = tf.placeholder(tf.float32, (), name="tau")
update_target_expr = []
model_vars = self.model.trainable_variables()
target_model_vars = self.target_model.trainable_variables()
assert len(model_vars) == len(target_model_vars), \
(model_vars, target_model_vars)
for var, var_target in zip(model_vars, target_model_vars):
update_target_expr.append(
var_target.assign(self.tau * var +
(1.0 - self.tau) * var_target))
logger.debug("Update target op {}".format(var_target))
self.update_target_expr = tf.group(*update_target_expr)
# Hard initial update
self.update_target(tau=1.0)
# support both hard and soft sync
def update_target(self, tau=None):
tau = tau or self.tau_value
return self.get_session().run(
self.update_target_expr, feed_dict={self.tau: tau})
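The target update above is a Polyak (exponential moving) average of the model weights. A scalar sketch with weights held in plain lists follows; the values are made up for illustration.

```python
def soft_update(target_vars, model_vars, tau):
    """Return tau * model + (1 - tau) * target for each variable pair."""
    assert len(target_vars) == len(model_vars)
    return [tau * v + (1.0 - tau) * t
            for v, t in zip(model_vars, target_vars)]


model = [1.0, -2.0]
target = [0.0, 0.0]

# Hard initial sync (tau=1.0) copies the model weights exactly.
target = soft_update(target, model, tau=1.0)

# Later updates with a small tau move the target only slightly.
model = [2.0, -2.0]
target = soft_update(target, model, tau=5e-3)
```

After the second call the first target weight has moved from 1.0 to about 1.005. Note that `update_target` in the source resolves its argument with `tau or self.tau_value`, so an explicit `tau=0` falls back to the configured value.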
class ActorCriticOptimizerMixin(object):
def __init__(self, config):
# create global step for counting the number of update operations
self.global_step = tf.train.get_or_create_global_step()
# use separate optimizers for actor & critic
self._actor_optimizer = tf.train.AdamOptimizer(
learning_rate=config["optimization"]["actor_learning_rate"])
self._critic_optimizer = tf.train.AdamOptimizer(
learning_rate=config["optimization"]["critic_learning_rate"])
self._alpha_optimizer = tf.train.AdamOptimizer(
learning_rate=config["optimization"]["entropy_learning_rate"])
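The third optimizer adjusts the entropy temperature through `log_alpha`. As a rough illustration of what one update does, the sketch below differentiates the alpha loss from `actor_critic_loss` by hand and takes a single plain gradient-descent step. The source uses Adam, so step sizes will differ; the function names and sample values are assumptions.

```python
import math


def alpha_loss_gradient(log_pis, target_entropy):
    """d/d(log_alpha) of -mean(log_alpha * (log_pi + target_entropy))."""
    return -sum(lp + target_entropy for lp in log_pis) / len(log_pis)


def update_log_alpha(log_alpha, log_pis, target_entropy, lr=3e-4):
    """One plain gradient-descent step (the source uses Adam instead)."""
    return log_alpha - lr * alpha_loss_gradient(log_pis, target_entropy)


action_dim = 1
target_entropy = -float(action_dim)  # the "auto" setting: -|A|

# Policy entropy (about 2.0 here) is above the target, so alpha shrinks.
log_alpha = update_log_alpha(0.0, [-2.1, -1.9], target_entropy)
alpha = math.exp(log_alpha)
```

When the sampled log-probabilities are high (low entropy), the sign of the gradient flips and `log_alpha` increases, which strengthens the entropy bonus.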
class ComputeTDErrorMixin(object):
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
if not self.loss_initialized():
return np.zeros_like(rew_t)
td_err = self.get_session().run(
self.td_error,
feed_dict={
self.get_placeholder(SampleBatch.CUR_OBS): [
np.array(ob) for ob in obs_t
],
self.get_placeholder(SampleBatch.ACTIONS): act_t,
self.get_placeholder(SampleBatch.REWARDS): rew_t,
self.get_placeholder(SampleBatch.NEXT_OBS): [
np.array(ob) for ob in obs_tp1
],
self.get_placeholder(SampleBatch.DONES): done_mask,
self.get_placeholder(PRIO_WEIGHTS): importance_weights
})
return td_err
def setup_early_mixins(policy, obs_space, action_space, config):
ExplorationStateMixin.__init__(policy, obs_space, action_space, config)
ActorCriticOptimizerMixin.__init__(policy, config)
def setup_late_mixins(policy, obs_space, action_space, config):
TargetNetworkMixin.__init__(policy, config)
SACTFPolicy = build_tf_policy(
name="SACTFPolicy",
get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
make_model=build_sac_model,
postprocess_fn=postprocess_trajectory,
extra_action_feed_fn=exploration_setting_inputs,
action_sampler_fn=build_action_output,
loss_fn=actor_critic_loss,
stats_fn=stats,
gradients_fn=gradients,
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
mixins=[
TargetNetworkMixin, ExplorationStateMixin, ActorCriticOptimizerMixin,
ComputeTDErrorMixin
],
before_init=setup_early_mixins,
after_init=setup_late_mixins,
obs_include_prev_action_reward=False)
+4
@@ -88,6 +88,10 @@ COMMON_CONFIG = {
# hit. This allows value estimation and RNN state to span across logical
# episodes denoted by horizon. This only has an effect if horizon != inf.
"soft_horizon": False,
# Don't set 'done' at the end of the episode. Note that you still need to
# set this if soft_horizon=True, unless your env is actually running
# forever without returning done=True.
"no_done_at_end": False,
# Arguments to pass to the env creator
"env_config": {},
# Environment name can also be passed via config
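The effect of `no_done_at_end` on a value-based learner can be illustrated with the last step of a time-limited episode. This sketch is not the sampler code; `recorded_done` and `td_target` are hypothetical helpers and the numbers are invented.

```python
def recorded_done(env_done, no_done_at_end):
    """Done flag stored in the sample batch for one transition."""
    return False if no_done_at_end else env_done


def td_target(reward, next_value, done, gamma=0.99):
    # Bootstrapping from the next state is masked out when done is True.
    return reward + gamma * (1.0 - float(done)) * next_value


# Final step of a time-limited episode, e.g. one cut off by a step cap.
reward, next_value, env_done = -1.0, -50.0, True

masked = td_target(reward, next_value, recorded_done(env_done, False))
bootstrapped = td_target(reward, next_value, recorded_done(env_done, True))
```

With the flag off, the target is just the final reward (-1.0); with it on, the target still bootstraps from the next state's value (-50.5), which is usually what is wanted when the episode ended only because of a time limit.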
@@ -132,6 +132,7 @@ class RolloutWorker(EvaluatorInterface):
remote_worker_envs=False,
remote_env_batch_wait_ms=0,
soft_horizon=False,
no_done_at_end=False,
seed=None,
_fake_sampler=False):
"""Initialize a rollout worker.
@@ -218,6 +219,8 @@ class RolloutWorker(EvaluatorInterface):
step / reset and model inference perf.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the
episode and instead record done=False.
seed (int): Set the seed of both np and tf to this value to
ensure each remote worker has unique exploration behavior.
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
@@ -408,7 +411,8 @@ class RolloutWorker(EvaluatorInterface):
tf_sess=self.tf_sess,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation,
soft_horizon=soft_horizon)
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end)
self.sampler.start()
else:
self.sampler = SyncSampler(
@@ -424,7 +428,8 @@ class RolloutWorker(EvaluatorInterface):
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
soft_horizon=soft_horizon)
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end)
self.input_reader = input_creator(self.io_context)
assert isinstance(self.input_reader, InputReader), self.input_reader
@@ -185,7 +185,8 @@ class MultiAgentSampleBatchBuilder(object):
"agent {} (policy {}). ".format(
agent_id, self.agent_to_policy[agent_id]) +
"Please ensure that you include the last observations "
"of all live agents when setting '__all__' done to True.")
"of all live agents when setting '__all__' done to True. "
"Alternatively, set no_done_at_end=True to allow this.")
@DeveloperAPI
def build_and_reset(self, episode):
+17 -10
@@ -75,7 +75,8 @@ class SyncSampler(SamplerInput):
pack=False,
tf_sess=None,
clip_actions=True,
soft_horizon=False):
soft_horizon=False,
no_done_at_end=False):
self.base_env = BaseEnv.to_base_env(env)
self.unroll_length = unroll_length
self.horizon = horizon
@@ -89,7 +90,8 @@ class SyncSampler(SamplerInput):
self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess, self.perf_stats, soft_horizon)
pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
no_done_at_end)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -135,7 +137,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
tf_sess=None,
clip_actions=True,
blackhole_outputs=False,
soft_horizon=False):
soft_horizon=False,
no_done_at_end=False):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
@@ -158,6 +161,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.clip_actions = clip_actions
self.blackhole_outputs = blackhole_outputs
self.soft_horizon = soft_horizon
self.no_done_at_end = no_done_at_end
self.perf_stats = PerfStats()
self.shutdown = False
@@ -181,7 +185,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
self.perf_stats, self.soft_horizon)
self.perf_stats, self.soft_horizon, self.no_done_at_end)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -226,7 +230,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
unroll_length, horizon, preprocessors, obs_filters,
clip_rewards, clip_actions, pack, callbacks, tf_sess,
perf_stats, soft_horizon):
perf_stats, soft_horizon, no_done_at_end):
"""This implements the common experience collection logic.
Args:
@@ -253,6 +257,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
perf_stats (PerfStats): Record perf stats into this object.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the episode
and instead record done=False.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@@ -310,7 +316,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, unroll_length, pack, callbacks,
soft_horizon)
soft_horizon, no_done_at_end)
perf_stats.processing_time += time.time() - t1
for o in outputs:
yield o
@@ -339,7 +345,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, preprocessors,
obs_filters, unroll_length, pack, callbacks,
soft_horizon):
soft_horizon, no_done_at_end):
"""Record new data from the environment and prepare for policy evaluation.
Returns:
@@ -434,8 +440,9 @@ def _process_observations(base_env, policies, batch_builder_pool,
rewards=rewards[env_id][agent_id],
prev_actions=episode.prev_action_for(agent_id),
prev_rewards=episode.prev_reward_for(agent_id),
dones=(False
if (hit_horizon and soft_horizon) else agent_done),
dones=(False if (no_done_at_end
or (hit_horizon and soft_horizon)) else
agent_done),
infos=infos[env_id].get(agent_id, {}),
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
@@ -447,7 +454,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if dones[env_id]["__all__"]:
if dones[env_id]["__all__"] and not no_done_at_end:
episode.batch_builder.check_missing_dones()
if (all_done and not pack) or \
episode.batch_builder.count >= unroll_length:
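The `dones=` expression changed in `_process_observations` above can be read as a small truth table. The following is a minimal sketch of that logic only; the helper name `recorded_done` is hypothetical and does not exist in RLlib.

```python
def recorded_done(agent_done, hit_horizon, soft_horizon, no_done_at_end):
    """Return the done flag that would be written into the sample batch.

    Hypothetical helper mirroring the conditional in the diff: the
    environment's done signal is overridden to False when
    `no_done_at_end` is set, or when a soft horizon was hit.
    """
    if no_done_at_end or (hit_horizon and soft_horizon):
        return False
    return agent_done


# The environment reported done=True, but the flag asks to ignore it.
print(recorded_done(True, False, False, True))
# Plain episode termination with both options off is recorded as-is.
print(recorded_done(True, False, False, False))
```

Note that the second change in the same function skips `check_missing_dones()` when `no_done_at_end` is set, which is consistent with the recorded flags never being True in that mode.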
@@ -211,6 +211,7 @@ class WorkerSet(object):
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
soft_horizon=config["soft_horizon"],
no_done_at_end=config["no_done_at_end"],
seed=(config["seed"] + worker_index)
if config["seed"] is not None else None,
_fake_sampler=config.get("_fake_sampler", False))
@@ -24,6 +24,7 @@ def get_mean_action(alg, obs):
ray.init(num_cpus=10)
CONFIGS = {
"SAC": {},
"ES": {
"episodes_per_batch": 10,
"train_batch_size": 100,
@@ -62,7 +63,7 @@ CONFIGS = {
def test_ckpt_restore(use_object_store, alg_name, failures):
cls = get_agent_class(alg_name)
if "DDPG" in alg_name:
if "DDPG" in alg_name or "SAC" in alg_name:
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
env = gym.make("Pendulum-v0")
@@ -82,7 +83,7 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
alg2.restore(alg1.save())
for _ in range(10):
if "DDPG" in alg_name:
if "DDPG" in alg_name or "SAC" in alg_name:
obs = np.clip(
np.random.uniform(size=3),
env.observation_space.low,
@@ -110,7 +111,7 @@ def test_export(algo_name, failures):
and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
cls = get_agent_class(algo_name)
if "DDPG" in algo_name:
if "DDPG" in algo_name or "SAC" in algo_name:
algo = cls(config=CONFIGS[name], env="Pendulum-v0")
else:
algo = cls(config=CONFIGS[name], env="CartPole-v0")
@@ -145,14 +146,16 @@ def test_export(algo_name, failures):
if __name__ == "__main__":
failures = []
for use_object_store in [False, True]:
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]:
for name in [
"SAC", "ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"
]:
test_ckpt_restore(use_object_store, name, failures)
assert not failures, failures
print("All checkpoint restore tests passed!")
failures = []
for name in ["DQN", "DDPG", "PPO", "A3C"]:
for name in ["SAC", "DQN", "DDPG", "PPO", "A3C"]:
test_export(name, failures)
assert not failures, failures
print("All export tests passed!")
@@ -0,0 +1,10 @@
pendulum-sac:
env: Pendulum-v0
run: SAC
stop:
episode_reward_mean: -300 # note that evaluation perf is higher
timesteps_total: 15000
config:
evaluation_interval: 1 # logged under evaluation/* metric keys
soft_horizon: true
metrics_smoothing_episodes: 10
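The `stop` block above lists two criteria, and a trial ends as soon as either one is reached. A minimal sketch of that rule, assuming each criterion is met when the reported metric is greater than or equal to its threshold; the function name `should_stop` and the sample result dicts are hypothetical and are not Tune's implementation.

```python
def should_stop(result, stop):
    """Return True once any stop criterion is reached.

    Assumption: a criterion counts as reached when the reported
    metric is >= its threshold; metrics missing from `result`
    are ignored.
    """
    return any(
        key in result and result[key] >= threshold
        for key, threshold in stop.items())


# Same thresholds as the pendulum-sac example.
stop = {"episode_reward_mean": -300, "timesteps_total": 15000}

# Reward target reached early: stop.
print(should_stop({"episode_reward_mean": -250, "timesteps_total": 4000}, stop))
# Neither target reached: keep training.
print(should_stop({"episode_reward_mean": -900, "timesteps_total": 4000}, stop))
```

Under this reading, `timesteps_total: 15000` acts as a budget cap for runs that never reach a mean reward of -300.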
@@ -76,6 +76,19 @@ def try_import_tf():
return None
def try_import_tfp():
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning(
"Not importing TensorFlow Probability for test purposes.")
return None
try:
import tensorflow_probability as tfp
return tfp
except ImportError:
return None
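Because `try_import_tfp` returns `None` instead of raising, callers have to check the result before use. A minimal sketch of the same optional-import pattern in generic form; `try_import_optional`, `require`, and the `SKIP_OPTIONAL_IMPORTS` variable are hypothetical names chosen for illustration and are not part of RLlib.

```python
import importlib
import os


def try_import_optional(module_name, skip_env_var="SKIP_OPTIONAL_IMPORTS"):
    """Import `module_name` if possible, otherwise return None.

    Hypothetical generic version of the pattern in the diff: an
    environment variable can disable the import for tests, and a
    missing package yields None instead of an exception.
    """
    if skip_env_var in os.environ:
        return None
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


def require(module, name):
    """Fail with a clear message at the point where the module is needed."""
    if module is None:
        raise ImportError("{} is required for this feature".format(name))
    return module


# A standard-library module stands in for the optional dependency here.
json_mod = require(try_import_optional("json"), "json")
print(json_mod.dumps({"ok": True}))
```

Deferring the failure to `require` keeps module import cheap and lets code paths that never touch the optional dependency run without it.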
__all__ = [
"Filter",
"FilterManager",