mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:38:19 +08:00
[rllib] Feature/soft actor critic v2 (#5328)
* Add base for Soft Actor-Critic * Pick changes from old SAC branch * Update sac.py * First implementation of sac model * Remove unnecessary SAC imports * Prune unnecessary noise and exploration code * Implement SAC model and use that in SAC policy * runs but doesn't learn * clear state * fix batch size * Add missing alpha grads and vars * -200 by 2k timesteps * doc * lazy squash * one file * ignore tfp * revert done
This commit is contained in:
committed by
Eric Liang
parent
3ae54a2b20
commit
13fb9fe3db
@@ -28,6 +28,7 @@ MOCK_MODULES = [
|
||||
"scipy",
|
||||
"scipy.signal",
|
||||
"scipy.stats",
|
||||
"tensorflow_probability",
|
||||
"tensorflow",
|
||||
"tensorflow.contrib",
|
||||
"tensorflow.contrib.all_reduce",
|
||||
|
||||
@@ -164,12 +164,12 @@ Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/m
|
||||
**Atari results @10M steps**: `more details <https://github.com/ray-project/rl-experiments>`__
|
||||
|
||||
============= ======================== ============================= ============================== ===============================
|
||||
Atari env RLlib DQN RLlib Dueling DDQN RLlib Dist. DQN Hessel et al. DQN
|
||||
Atari env RLlib DQN RLlib Dueling DDQN RLlib Dist. DQN Hessel et al. DQN
|
||||
============= ======================== ============================= ============================== ===============================
|
||||
BeamRider 2869 1910 4447 ~2000
|
||||
Breakout 287 312 410 ~150
|
||||
Qbert 3921 7968 15780 ~4000
|
||||
SpaceInvaders 650 1001 1025 ~500
|
||||
BeamRider 2869 1910 4447 ~2000
|
||||
Breakout 287 312 410 ~150
|
||||
Qbert 3921 7968 15780 ~4000
|
||||
SpaceInvaders 650 1001 1025 ~500
|
||||
============= ======================== ============================= ============================== ===============================
|
||||
|
||||
**DQN-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
@@ -217,7 +217,7 @@ SpaceInvaders 671 944 ~800
|
||||
============= ========================= =============================
|
||||
MuJoCo env RLlib PPO 16-workers @ 1h Fan et al PPO 16-workers @ 1h
|
||||
============= ========================= =============================
|
||||
HalfCheetah 9664 ~7700
|
||||
HalfCheetah 9664 ~7700
|
||||
============= ========================= =============================
|
||||
|
||||
.. figure:: ppo.png
|
||||
@@ -232,6 +232,21 @@ HalfCheetah 9664 ~7700
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
-Soft Actor Critic (SAC)
|
||||
------------------------
|
||||
`[paper] <https://arxiv.org/pdf/1801.01290>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/sac/sac.py>`__
|
||||
|
||||
RLlib's Soft Actor-Critic (SAC) implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``. It currently supports only continuous action spaces and is still *experimental*.
|
||||
|
||||
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/python/ray/rllib/tuned_examples/regression_tests/pendulum-sac.yaml>`__
|
||||
|
||||
**SAC-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
.. literalinclude:: ../../python/ray/rllib/agents/sac/sac.py
|
||||
:language: python
|
||||
:start-after: __sphinx_doc_begin__
|
||||
:end-before: __sphinx_doc_end__
|
||||
|
||||
Derivative-free
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ DQN, Rainbow **Yes** `+parametric`_ No **Yes** No
|
||||
DDPG, TD3 No **Yes** **Yes** No
|
||||
APEX-DQN **Yes** `+parametric`_ No **Yes** No
|
||||
APEX-DDPG No **Yes** **Yes** No
|
||||
SAC (todo) **Yes** **Yes** No
|
||||
ES **Yes** **Yes** No No
|
||||
ARS **Yes** **Yes** No No
|
||||
QMIX **Yes** No **Yes** **Yes**
|
||||
|
||||
@@ -37,7 +37,7 @@ The ``rllib train`` command (same as the ``train.py`` script in the repo) has a
|
||||
The most important options are for choosing the environment
|
||||
with ``--env`` (any OpenAI gym environment including ones registered by the user
|
||||
can be used) and for choosing the algorithm with ``--run``
|
||||
(available options are ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
|
||||
(available options are ``SAC``, ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
|
||||
|
||||
Evaluating Trained Policies
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -73,6 +73,8 @@ Algorithms
|
||||
|
||||
- `Proximal Policy Optimization (PPO) <rllib-algorithms.html#proximal-policy-optimization-ppo>`__
|
||||
|
||||
- `Soft Actor Critic (SAC) <rllib-algorithms.html#soft-actor-critic-sac>`__
|
||||
|
||||
* Derivative-free
|
||||
|
||||
- `Augmented Random Search (ARS) <rllib-algorithms.html#augmented-random-search-ars>`__
|
||||
|
||||
@@ -7,7 +7,7 @@ RUN conda install -y numpy
|
||||
RUN apt-get install -y zlib1g-dev
|
||||
# The following is needed to support TensorFlow 1.14
|
||||
RUN conda remove -y --force wrapt
|
||||
RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open
|
||||
RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open tensorflow_probability
|
||||
RUN pip install -U h5py # Mutes FutureWarnings
|
||||
RUN pip install --upgrade bayesian-optimization
|
||||
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
|
||||
|
||||
@@ -9,6 +9,11 @@ import traceback
|
||||
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
||||
|
||||
|
||||
def _import_sac():
    """Import the SAC agent package on demand and return its trainer class.

    The import lives inside the function so that loading this registry
    module does not itself import the SAC agent package.
    """
    from ray.rllib.agents.sac import SACTrainer
    return SACTrainer
|
||||
|
||||
|
||||
def _import_appo():
    """Import the PPO agent package on demand and return the APPO trainer.

    The import lives inside the function so that loading this registry
    module does not itself import the PPO agent package.
    """
    from ray.rllib.agents.ppo import APPOTrainer
    return APPOTrainer
|
||||
@@ -95,6 +100,7 @@ def _import_marwil():
|
||||
|
||||
|
||||
ALGORITHMS = {
|
||||
"SAC": _import_sac,
|
||||
"DDPG": _import_ddpg,
|
||||
"APEX_DDPG": _import_apex_ddpg,
|
||||
"TD3": _import_td3,
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Implementation of Soft Actor-Critic (https://arxiv.org/abs/1812.05905).
|
||||
@@ -0,0 +1,13 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
# Alias for SACTrainer produced by ``renamed_agent``.
# NOTE(review): presumably a deprecated-name shim like the other agent
# packages use -- confirm against ray.rllib.utils.renamed_agent.
SACAgent = renamed_agent(SACTrainer)

# Public names of this package. SACAgent is listed so that the alias
# defined above is actually exported by ``from ... import *``.
__all__ = [
    "SACAgent",
    "SACTrainer",
    "DEFAULT_CONFIG",
]
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
|
||||
from ray.rllib.agents.sac.sac_policy import SACTFPolicy
|
||||
|
||||
# Names of trainer config keys related to replay/sampling.
# NOTE(review): not referenced by the rest of this module as shown;
# presumably mirrors the DQN list of settings forwarded to the optimizer --
# confirm before relying on it.
OPTIMIZER_SHARED_CONFIGS = [
    "buffer_size",
    "prioritized_replay",
    "prioritized_replay_alpha",
    "prioritized_replay_beta",
    "prioritized_replay_eps",
    "sample_batch_size",
    "train_batch_size",
    "learning_starts",
]
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Build a second ("twin") Q network in addition to the main one.
    "twin_q": True,
    # If True, observations go through a catalog-chosen model before the SAC
    # heads; if False, a no-op model passes the flattened observation through.
    # Forced to True when a custom model is configured.
    "use_state_preprocessor": False,
    # NOTE(review): not read by the SAC policy/model code shipped with this
    # config -- confirm whether this key is still needed.
    "policy": "GaussianLatentSpacePolicy",
    # RLlib model options for the Q function
    "Q_model": {
        "hidden_activation": "relu",
        "hidden_layer_sizes": (256, 256),
    },
    # RLlib model options for the policy function
    "policy_model": {
        "hidden_activation": "relu",
        "hidden_layer_sizes": (256, 256),
    },

    # === Learning ===
    # Update the target by \tau * policy + (1-\tau) * target_policy
    "tau": 5e-3,
    # Target entropy lower bound. This is the inverse of reward scale,
    # and will be optimized automatically. "auto" selects
    # -prod(action_space.shape).
    "target_entropy": "auto",
    # Disable setting done=True at end of episode.
    "no_done_at_end": True,
    # N-step target updates (the SAC loss currently asserts n_step == 1).
    "n_step": 1,

    # === Evaluation ===
    # The evaluation stats will be reported under the "evaluation" metric key.
    "evaluation_interval": 1,
    # Number of episodes to run per evaluation period.
    "evaluation_num_episodes": 1,
    # Extra configuration that disables exploration.
    "evaluation_config": {
        "exploration_enabled": False,
    },

    # === Exploration ===
    # Number of env steps to optimize for before returning
    "timesteps_per_iteration": 1000,
    # If True, actions are sampled from the policy distribution; if False,
    # the (squashed) distribution mean is used instead.
    "exploration_enabled": True,

    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": int(1e6),
    # If True prioritized replay buffer will be used.
    # TODO(hartikainen): Make sure this works or remove the option.
    "prioritized_replay": False,
    "prioritized_replay_alpha": 0.6,
    "prioritized_replay_beta": 0.4,
    "prioritized_replay_eps": 1e-6,
    "beta_annealing_fraction": 0.2,
    "final_prioritized_replay_beta": 0.4,
    "compress_observations": False,

    # === Optimization ===
    # Learning rates of the three separate Adam optimizers (policy, Q
    # function(s), entropy coefficient).
    "optimization": {
        "actor_learning_rate": 3e-4,
        "critic_learning_rate": 3e-4,
        "entropy_learning_rate": 3e-4,
    },
    # If not None, clip gradients during optimization at this value
    "grad_norm_clipping": None,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1500,
    # Update the replay buffer with this many samples at once. Note that this
    # setting applies per-worker if num_workers > 1.
    "sample_batch_size": 1,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 256,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 0,

    # === Parallelism ===
    # Whether to use a GPU for local optimization.
    "num_gpus": 0,
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to allocate GPUs for workers (if > 0).
    "num_gpus_per_worker": 0,
    # Whether to allocate CPUs for workers (if > 0).
    "num_cpus_per_worker": 1,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span
    "min_iter_time_s": 1,

    # TODO(ekl) these are unused; remove them from sac config
    "per_worker_exploration": False,
    "exploration_fraction": 0.1,
    "schedule_max_timesteps": 100000,
    "exploration_final_eps": 0.02,
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
# SAC trainer: the generic off-policy trainer specialised with the SAC
# default config and the SAC TensorFlow policy.
SACTrainer = GenericOffPolicyTrainer.with_updates(
    name="SAC",
    default_policy=SACTFPolicy,
    default_config=DEFAULT_CONFIG)
|
||||
@@ -0,0 +1,232 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.utils import try_import_tf, try_import_tfp
|
||||
|
||||
tf = try_import_tf()
|
||||
tfp = try_import_tfp()
|
||||
|
||||
SCALE_DIAG_MIN_MAX = (-20, 2)
|
||||
|
||||
|
||||
def SquashBijector():
    """Create and return a tanh bijector instance.

    The bijector class is declared inside this function instead of at module
    level because its base class comes from ``tfp``, which is imported
    optionally and may be ``None`` when this module is loaded.

    Returns:
        An instance of a ``tfp.bijectors.Bijector`` subclass whose forward
        transform is ``tanh``.
    """

    class SquashBijector(tfp.bijectors.Bijector):
        """Element-wise bijector y = tanh(x)."""

        def __init__(self, validate_args=False, name="tanh"):
            super(SquashBijector, self).__init__(
                name=name,
                validate_args=validate_args,
                forward_min_event_ndims=0)

        def _forward(self, x):
            return tf.nn.tanh(x)

        def _inverse(self, y):
            return tf.atanh(y)

        def _forward_log_det_jacobian(self, x):
            # log|d tanh(x)/dx| = log(1 - tanh(x)^2), written in the
            # softplus form below.
            log_two = np.log(2.)
            return 2. * (log_two - x - tf.nn.softplus(-2. * x))

    return SquashBijector()
|
||||
|
||||
|
||||
class SACModel(TFModelV2):
    """Extension of standard TFModel for SAC.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_policy_output() -> pi(s)
        model_out, actions -> get_q_values() -> Q(s, a)
        model_out, actions -> get_twin_q_values() -> Q_twin(s, a)

    Besides the policy and Q heads, the model owns the ``log_alpha``
    variable (log of the entropy coefficient) and the derived ``alpha``
    tensor.

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 actor_hidden_activation="relu",
                 actor_hiddens=(256, 256),
                 critic_hidden_activation="relu",
                 critic_hiddens=(256, 256),
                 twin_q=False):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network
            actor_hiddens (list): hidden layers sizes for actor network
            critic_hidden_activation (str): activation for critic network
            critic_hiddens (list): hidden layers sizes for critic network
            twin_q (bool): build twin Q networks

        Note that the core layers for forward() are not defined here, this
        only defines the layers for the output heads. Those layers for
        forward() should be defined in subclasses of SACModel.

        Raises:
            ImportError: if tensorflow-probability could not be imported.
        """

        if tfp is None:
            raise ImportError("tensorflow-probability package not found")

        super(SACModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)

        # np.prod instead of np.product: the latter is a deprecated alias
        # that newer NumPy releases no longer provide.
        self.action_dim = np.prod(action_space.shape)
        self.model_out = tf.keras.layers.Input(
            shape=(num_outputs, ), name="model_out")
        self.actions = tf.keras.layers.Input(
            shape=(self.action_dim, ), name="actions")

        # Policy head: MLP whose output parameterises the action
        # distribution.
        # NOTE(review): params_size(d) for a TriL normal is
        # d + d * (d + 1) / 2, which equals 2 * d only for d == 1, yet the
        # output is split into two equal halves below. This looks like it
        # only works for 1-D action spaces -- confirm before using larger
        # action dims.
        shift_and_log_scale_diag = tf.keras.Sequential([
            tf.keras.layers.Dense(
                units=hidden,
                activation=getattr(tf.nn, actor_hidden_activation),
                name="action_hidden_{}".format(i))
            for i, hidden in enumerate(actor_hiddens)
        ] + [
            tf.keras.layers.Dense(
                units=tfp.layers.MultivariateNormalTriL.params_size(
                    self.action_dim),
                activation=None,
                name="action_out")
        ])(self.model_out)

        shift, log_scale_diag = tf.keras.layers.Lambda(
            lambda shift_and_log_scale_diag: tf.split(
                shift_and_log_scale_diag,
                num_or_size_splits=2,
                axis=-1)
        )(shift_and_log_scale_diag)

        # Keep the scale parameters inside SCALE_DIAG_MIN_MAX.
        log_scale_diag = tf.keras.layers.Lambda(
            lambda log_sd: tf.clip_by_value(log_sd, *SCALE_DIAG_MIN_MAX))(
                log_scale_diag)

        shift_and_log_scale_diag = tf.keras.layers.Concatenate(axis=-1)(
            [shift, log_scale_diag])

        raw_action_distribution = tfp.layers.MultivariateNormalTriL(
            self.action_dim)(shift_and_log_scale_diag)

        # Wrap the raw normal in the tanh bijector so that sampled actions
        # are squashed.
        action_distribution = tfp.layers.DistributionLambda(
            make_distribution_fn=SquashBijector())(raw_action_distribution)

        # TODO(hartikainen): Remove the unnecessary Model call here
        self.action_distribution_model = tf.keras.Model(
            self.model_out, action_distribution)

        self.register_variables(self.action_distribution_model.variables)

        def build_q_net(name, observations, actions):
            """Build an MLP mapping concat(observations, actions) to 1 unit."""
            q_net = tf.keras.Sequential([
                tf.keras.layers.Concatenate(axis=1),
            ] + [
                tf.keras.layers.Dense(
                    units=units,
                    activation=getattr(tf.nn, critic_hidden_activation),
                    name="{}_hidden_{}".format(name, i))
                for i, units in enumerate(critic_hiddens)
            ] + [
                tf.keras.layers.Dense(
                    units=1, activation=None, name="{}_out".format(name))
            ])

            # TODO(hartikainen): Remove the unnecessary Model call here
            q_net = tf.keras.Model([observations, actions],
                                   q_net([observations, actions]))
            return q_net

        self.q_net = build_q_net("q", self.model_out, self.actions)
        self.register_variables(self.q_net.variables)

        if twin_q:
            self.twin_q_net = build_q_net("twin_q", self.model_out,
                                          self.actions)
            self.register_variables(self.twin_q_net.variables)
        else:
            self.twin_q_net = None

        # Entropy coefficient, stored in log space; alpha = exp(log_alpha).
        self.log_alpha = tf.Variable(0.0, dtype=tf.float32, name="log_alpha")
        self.alpha = tf.exp(self.log_alpha)

        self.register_variables([self.log_alpha])

    def forward(self, input_dict, state, seq_lens):
        """This generates the model_out tensor input.

        You must implement this as documented in modelv2.py."""
        raise NotImplementedError

    def get_policy_output(self, model_out, deterministic=False):
        """Return actions from the policy network and their log-probs.

        The actions come out of the tanh-squashed action distribution; they
        are not rescaled to the environment's action bounds here.

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            deterministic (bool): if True, return the squashed mean of the
                underlying distribution instead of a sample.

        Returns:
            Tuple ``(actions, log_pis)``. ``actions`` holds the squashed
            actions; ``log_pis`` holds the log-probabilities of the sampled
            actions, or is None when ``deterministic`` is True.
        """
        action_distribution = self.action_distribution_model(model_out)
        if deterministic:
            actions = action_distribution.bijector(
                action_distribution.distribution.mean())
            log_pis = None
        else:
            actions = action_distribution.sample()
            log_pis = action_distribution.log_prob(actions)

        return actions, log_pis

    def get_q_values(self, model_out, actions):
        """Return the Q estimates for the given embeddings and actions.

        This implements Q(s, a).

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Tensor): action values that correspond with the
                observations behind ``model_out``, of shape
                [BATCH_SIZE, action_dim].

        Returns:
            tensor of shape [BATCH_SIZE, 1] (the Q net ends in a single
            output unit).
        """
        return self.q_net([model_out, actions])

    def get_twin_q_values(self, model_out, actions):
        """Same as get_q_values but using the twin Q net.

        This implements the twin Q(s, a). Only usable when the model was
        built with ``twin_q=True``; otherwise ``twin_q_net`` is None.

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Tensor): action values that correspond with the
                observations behind ``model_out``, of shape
                [BATCH_SIZE, action_dim].

        Returns:
            tensor of shape [BATCH_SIZE, 1].
        """
        return self.twin_q_net([model_out, actions])

    def policy_variables(self):
        """Return the list of variables for the policy net."""

        return list(self.action_distribution_model.variables)

    def q_variables(self):
        """Return the list of variables for Q / twin Q nets."""

        q_vars = list(self.q_net.variables)
        if self.twin_q_net is not None:
            q_vars += list(self.twin_q_net.variables)
        return q_vars
|
||||
@@ -0,0 +1,367 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from gym.spaces import Box
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.sac.sac_model import SACModel
|
||||
from ray.rllib.agents.ddpg.noop_model import NoopModel
|
||||
from ray.rllib.agents.dqn.dqn_policy import _postprocess_dqn, PRIO_WEIGHTS
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils import try_import_tf, try_import_tfp
|
||||
from ray.rllib.utils.tf_ops import minimize_and_clip
|
||||
|
||||
tf = try_import_tf()
|
||||
tfp = try_import_tfp()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_sac_model(policy, obs_space, action_space, config):
    """Create the SAC model and its target copy and attach them to `policy`.

    Side effects: sets ``policy.model`` and ``policy.target_model``, and may
    modify ``config`` in place (``use_state_preprocessor`` is forced to True
    when a custom model is configured; ``config["model"]["no_final_linear"]``
    is set when the state preprocessor is used).

    Arguments:
        policy: policy object that receives the two models.
        obs_space: observation space of the policy.
        action_space: action space; must be a 1-D ``Box``.
        config (dict): trainer config.

    Returns:
        The newly created ``policy.model``.

    Raises:
        UnsupportedSpaceException: if the action space is not a ``Box`` or
            has more than one dimension.
    """
    if config["model"]["custom_model"]:
        logger.warning(
            "Setting use_state_preprocessor=True since a custom model "
            "was specified.")
        config["use_state_preprocessor"] = True
    if not isinstance(action_space, Box):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for SAC.".format(action_space))
    if len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space has multiple dimensions "
            "{}. ".format(action_space.shape) +
            "Consider reshaping this into a single dimension, "
            "using a Tuple action space, or the multi-agent API.")

    if config["use_state_preprocessor"]:
        default_model = None  # catalog decides
        num_outputs = 256  # arbitrary
        config["model"]["no_final_linear"] = True
    else:
        default_model = NoopModel
        # np.prod instead of np.product: the latter is a deprecated alias
        # that newer NumPy releases no longer provide.
        num_outputs = int(np.prod(obs_space.shape))

    def _make_model(name):
        # The online and target networks are built identically; only the
        # name differs.
        return ModelCatalog.get_model_v2(
            obs_space,
            action_space,
            num_outputs,
            config["model"],
            framework="tf",
            model_interface=SACModel,
            default_model=default_model,
            name=name,
            actor_hidden_activation=config["policy_model"][
                "hidden_activation"],
            actor_hiddens=config["policy_model"]["hidden_layer_sizes"],
            critic_hidden_activation=config["Q_model"]["hidden_activation"],
            critic_hiddens=config["Q_model"]["hidden_layer_sizes"],
            twin_q=config["twin_q"])

    policy.model = _make_model("sac_model")
    policy.target_model = _make_model("target_sac_model")

    return policy.model
|
||||
|
||||
|
||||
def postprocess_trajectory(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
    """Postprocess a sampled batch by delegating to the DQN postprocessor.

    ``other_agent_batches`` and ``episode`` are accepted but not used.

    Returns:
        Whatever ``_postprocess_dqn(policy, sample_batch)`` returns.
    """
    postprocessed = _postprocess_dqn(policy, sample_batch)
    return postprocessed
|
||||
|
||||
|
||||
def exploration_setting_inputs(policy):
    """Return the feed for the policy's ``stochastic`` placeholder.

    Returns:
        dict mapping ``policy.stochastic`` to the current value of
        ``policy.config["exploration_enabled"]``.
    """
    explore = policy.config["exploration_enabled"]
    return {policy.stochastic: explore}
|
||||
|
||||
|
||||
def build_action_output(policy, model, input_dict, obs_space, action_space,
                        config):
    """Build the action-sampling tensors of the policy.

    Both a sampled and a deterministic action are built; the boolean
    ``policy.stochastic`` placeholder selects between them at run time.

    Side effect: stores the selected actions in ``policy.output_actions``.

    Returns:
        Tuple ``(actions, action_probabilities)``. The second element holds
        the log-probs of the sampled actions when ``policy.stochastic`` is
        true, and zeros of the same shape otherwise.
    """
    model_out, _ = model({
        "obs": input_dict[SampleBatch.CUR_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    # Env action bounds with a leading batch axis so they broadcast against
    # [batch_size, dim_actions] tensors.
    low_action = action_space.low[None]
    action_range = (action_space.high - action_space.low)[None]

    def unsquash_actions(actions):
        # Use sigmoid to scale to [0,1], but also double magnitude of input to
        # emulate behaviour of tanh activation used in SAC and TD3 papers,
        # then rescale to the actual env action range.
        # NOTE(review): the inputs here already went through the model's tanh
        # bijector, so sigmoid(2 * x) only sees x in (-1, 1) and cannot reach
        # the ends of the action range -- confirm whether this is intended.
        return action_range * tf.nn.sigmoid(2 * actions) + low_action

    sampled_squashed, log_pis = policy.model.get_policy_output(
        model_out, deterministic=False)
    stochastic_actions = unsquash_actions(sampled_squashed)
    mean_squashed, _ = policy.model.get_policy_output(
        model_out, deterministic=True)
    deterministic_actions = unsquash_actions(mean_squashed)

    actions = tf.cond(policy.stochastic, lambda: stochastic_actions,
                      lambda: deterministic_actions)

    action_probabilities = tf.cond(policy.stochastic, lambda: log_pis,
                                   lambda: tf.zeros_like(log_pis))
    policy.output_actions = actions
    return actions, action_probabilities
|
||||
|
||||
|
||||
def actor_critic_loss(policy, batch_tensors):
    """Build the SAC actor, critic and entropy-coefficient losses.

    Side effects: stores ``q_t``, ``td_error``, ``actor_loss``,
    ``critic_loss`` and ``alpha_loss`` on ``policy`` for the stats and
    gradients functions.

    Arguments:
        policy: the policy holding ``model``, ``target_model`` and ``config``.
        batch_tensors: mapping from SampleBatch column names (and
            PRIO_WEIGHTS) to tensors.

    Returns:
        The sum of the three losses. The individual losses are optimised
        separately by the gradients function; the sum is returned only to
        satisfy the single-loss interface.
    """
    model_out_t, _ = policy.model({
        "obs": batch_tensors[SampleBatch.CUR_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    model_out_tp1, _ = policy.model({
        "obs": batch_tensors[SampleBatch.NEXT_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": batch_tensors[SampleBatch.NEXT_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    # Actions sampled from the current policy (and their log-probs) for the
    # current and the next observations.
    # TODO(hartikainen): figure actions and log pis
    policy_t, log_pis_t = policy.model.get_policy_output(model_out_t)
    policy_tp1, log_pis_tp1 = policy.model.get_policy_output(model_out_tp1)

    log_alpha = policy.model.log_alpha
    alpha = policy.model.alpha

    # q network evaluation
    q_t = policy.model.get_q_values(model_out_t,
                                    batch_tensors[SampleBatch.ACTIONS])
    if policy.config["twin_q"]:
        twin_q_t = policy.model.get_twin_q_values(
            model_out_t, batch_tensors[SampleBatch.ACTIONS])

    # Q-values of actions sampled from the current policy in the current
    # state (used by the actor loss).
    q_t_det_policy = policy.model.get_q_values(model_out_t, policy_t)

    # target q network evaluation
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if policy.config["twin_q"]:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)

    q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
    if policy.config["twin_q"]:
        twin_q_t_selected = tf.squeeze(
            twin_q_t, axis=len(twin_q_t.shape) - 1)
        q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

    # Soft value of the next state: Q(s', a') - alpha * log pi(a'|s').
    # The entropy term must use the log-probs of the next-state actions
    # (log_pis_tp1), matching the actions that q_tp1 was evaluated on.
    q_tp1 -= tf.expand_dims(alpha * log_pis_tp1, 1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
    q_tp1_best_masked = (1.0 - tf.cast(batch_tensors[SampleBatch.DONES],
                                       tf.float32)) * q_tp1_best

    assert policy.config["n_step"] == 1, "TODO(hartikainen) n_step > 1"

    # compute RHS of bellman equation
    q_t_selected_target = tf.stop_gradient(
        batch_tensors[SampleBatch.REWARDS] +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Per-critic TD errors. The squared-error terms are built from each
    # critic's own error; the reported td_error is the sum over critics.
    base_td_error = q_t_selected - q_t_selected_target
    if policy.config["twin_q"]:
        twin_td_error = twin_q_t_selected - q_t_selected_target
        errors = 0.5 * (
            tf.square(base_td_error) + tf.square(twin_td_error))
        td_error = base_td_error + twin_td_error
    else:
        errors = 0.5 * tf.square(base_td_error)
        td_error = base_td_error

    critic_loss = policy.model.custom_loss(
        tf.reduce_mean(batch_tensors[PRIO_WEIGHTS] * errors), batch_tensors)
    actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy)

    target_entropy = (-np.prod(policy.action_space.shape)
                      if policy.config["target_entropy"] == "auto" else
                      policy.config["target_entropy"])
    alpha_loss = -tf.reduce_mean(
        log_alpha * tf.stop_gradient(log_pis_t + target_entropy))

    # save for stats function
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss

    # in a custom apply op we handle the losses separately, but return them
    # combined in one loss for now
    return actor_loss + critic_loss + alpha_loss
|
||||
|
||||
|
||||
def gradients(policy, optimizer, loss):
    """Compute gradients separately for the actor, critic and alpha losses.

    Each loss is differentiated only with respect to its own variables,
    using its own optimizer. ``optimizer`` and ``loss`` are accepted but not
    used; the per-loss optimizers and losses stored on ``policy`` are used
    instead.

    Side effects: stores the three filtered (grad, var) lists on ``policy``
    as ``_actor_grads_and_vars``, ``_critic_grads_and_vars`` and
    ``_alpha_grads_and_vars``.

    Returns:
        The concatenation of the three (grad, var) lists, with entries whose
        gradient is None removed.
    """
    clip_val = policy.config["grad_norm_clipping"]
    actor_vars = policy.model.policy_variables()
    critic_vars = policy.model.q_variables()
    # The trainable entropy parameter is log_alpha. model.alpha is only the
    # derived tensor exp(log_alpha); the alpha loss is not a function of that
    # tensor, so differentiating with respect to it yields no gradient and
    # the entropy coefficient would never be updated.
    alpha_vars = [policy.model.log_alpha]

    if clip_val is not None:
        actor_grads_and_vars = minimize_and_clip(
            policy._actor_optimizer,
            policy.actor_loss,
            var_list=actor_vars,
            clip_val=clip_val)
        critic_grads_and_vars = minimize_and_clip(
            policy._critic_optimizer,
            policy.critic_loss,
            var_list=critic_vars,
            clip_val=clip_val)
        alpha_grads_and_vars = minimize_and_clip(
            policy._alpha_optimizer,
            policy.alpha_loss,
            var_list=alpha_vars,
            clip_val=clip_val)
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=actor_vars)
        critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
            policy.critic_loss, var_list=critic_vars)
        # Use the alpha optimizer here, consistent with the clipped branch.
        alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
            policy.alpha_loss, var_list=alpha_vars)

    # save these for later use in build_apply_op
    policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
                                    if g is not None]
    policy._critic_grads_and_vars = [(g, v)
                                     for (g, v) in critic_grads_and_vars
                                     if g is not None]
    policy._alpha_grads_and_vars = [(g, v) for (g, v) in alpha_grads_and_vars
                                    if g is not None]
    grads_and_vars = (
        policy._actor_grads_and_vars + policy._critic_grads_and_vars +
        policy._alpha_grads_and_vars)
    return grads_and_vars
|
||||
|
||||
|
||||
def stats(policy, batch_tensors):
    """Return scalar training statistics built from tensors saved on `policy`.

    Relies on ``td_error``, ``actor_loss``, ``critic_loss`` and ``q_t``
    having been stored on ``policy`` by the loss function.
    ``batch_tensors`` is accepted but not used.

    Returns:
        dict of stat name -> scalar tensor.
    """
    q_values = policy.q_t
    return {
        "td_error": tf.reduce_mean(policy.td_error),
        "actor_loss": tf.reduce_mean(policy.actor_loss),
        "critic_loss": tf.reduce_mean(policy.critic_loss),
        "mean_q": tf.reduce_mean(q_values),
        "max_q": tf.reduce_max(q_values),
        "min_q": tf.reduce_min(q_values),
    }
|
||||
|
||||
|
||||
class ExplorationStateMixin(object):
    """Mixin holding the placeholder that toggles stochastic action sampling."""

    def __init__(self, obs_space, action_space, config):
        # Scalar boolean placeholder; the arguments are accepted but unused.
        self.stochastic = tf.placeholder(
            dtype=tf.bool, shape=(), name="stochastic")

    def set_epsilon(self, epsilon):
        """Accept an epsilon value and ignore it (no-op)."""
        return None
|
||||
|
||||
|
||||
class TargetNetworkMixin(object):
    """Mixin that builds and runs the target-network update op.

    Expects the host object to provide ``model``, ``target_model`` and
    ``get_session()``.
    """

    def __init__(self, config):
        # update_target_fn will be called periodically to copy Q network to
        # target Q network
        self.tau_value = config.get("tau")
        self.tau = tf.placeholder(tf.float32, (), name="tau")
        model_vars = self.model.trainable_variables()
        target_model_vars = self.target_model.trainable_variables()
        assert len(model_vars) == len(target_model_vars), \
            (model_vars, target_model_vars)
        update_target_expr = []
        for var, var_target in zip(model_vars, target_model_vars):
            # target <- tau * online + (1 - tau) * target
            update_target_expr.append(
                var_target.assign(self.tau * var +
                                  (1.0 - self.tau) * var_target))
            logger.debug("Update target op {}".format(var_target))
        self.update_target_expr = tf.group(*update_target_expr)

        # Hard initial update
        self.update_target(tau=1.0)

    # support both hard and soft sync
    def update_target(self, tau=None):
        """Run the target update with mixing factor `tau`.

        Arguments:
            tau (float): mixing factor; 1.0 copies the online variables into
                the target. If None, the configured ``tau_value`` is used.
                An explicit 0.0 is honoured rather than replaced by the
                default.

        Returns:
            The result of running the update op in the session.
        """
        if tau is None:
            tau = self.tau_value
        return self.get_session().run(
            self.update_target_expr, feed_dict={self.tau: tau})
|
||||
|
||||
|
||||
class ActorCriticOptimizerMixin(object):
    """Mixin that creates the global step and the three Adam optimizers."""

    def __init__(self, config):
        """Create the optimizers from ``config["optimization"]``.

        Args:
            config: Mapping whose ``"optimization"`` entry provides
                ``actor_learning_rate``, ``critic_learning_rate`` and
                ``entropy_learning_rate``.
        """
        # Global step used for counting the number of update operations.
        self.global_step = tf.train.get_or_create_global_step()

        # Actor, critic and alpha each get their own optimizer and
        # learning rate.
        make_adam = tf.train.AdamOptimizer

        actor_lr = config["optimization"]["actor_learning_rate"]
        self._actor_optimizer = make_adam(learning_rate=actor_lr)

        critic_lr = config["optimization"]["critic_learning_rate"]
        self._critic_optimizer = make_adam(learning_rate=critic_lr)

        alpha_lr = config["optimization"]["entropy_learning_rate"]
        self._alpha_optimizer = make_adam(learning_rate=alpha_lr)
|
||||
|
||||
|
||||
class ComputeTDErrorMixin(object):
    """Mixin that evaluates the policy's ``td_error`` tensor on a batch."""

    def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
                         importance_weights):
        """Run ``self.td_error`` in the policy session for the given batch.

        Args:
            obs_t: Iterable of observations, fed as the current obs.
            act_t: Actions fed to the actions placeholder.
            rew_t: Rewards fed to the rewards placeholder.
            obs_tp1: Iterable of observations, fed as the next obs.
            done_mask: Values fed to the dones placeholder.
            importance_weights: Values fed to the PRIO_WEIGHTS placeholder.

        Returns:
            The evaluated TD error, or zeros shaped like ``rew_t`` when the
            loss has not been initialized yet.
        """
        # Nothing to evaluate before the loss exists.
        if not self.loss_initialized():
            return np.zeros_like(rew_t)

        session = self.get_session()
        fetch = self.td_error
        placeholder_for = self.get_placeholder
        feed = {
            placeholder_for(SampleBatch.CUR_OBS): [
                np.array(ob) for ob in obs_t
            ],
            placeholder_for(SampleBatch.ACTIONS): act_t,
            placeholder_for(SampleBatch.REWARDS): rew_t,
            placeholder_for(SampleBatch.NEXT_OBS): [
                np.array(ob) for ob in obs_tp1
            ],
            placeholder_for(SampleBatch.DONES): done_mask,
            placeholder_for(PRIO_WEIGHTS): importance_weights,
        }
        return session.run(fetch, feed_dict=feed)
|
||||
|
||||
|
||||
def setup_early_mixins(policy, obs_space, action_space, config):
    """Initialize the exploration and optimizer mixins on ``policy``.

    Registered as the ``before_init`` hook of the SAC policy.

    Args:
        policy: Policy instance the mixin constructors are applied to.
        obs_space: Forwarded to ``ExplorationStateMixin.__init__``.
        action_space: Forwarded to ``ExplorationStateMixin.__init__``.
        config: Forwarded to both mixin constructors.
    """
    # The mixin constructors are invoked explicitly on the policy object.
    ExplorationStateMixin.__init__(
        policy, obs_space, action_space, config)
    ActorCriticOptimizerMixin.__init__(policy, config)
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
    """Initialize the target-network mixin on ``policy``.

    Registered as the ``after_init`` hook of the SAC policy.

    Args:
        policy: Policy instance the mixin constructor is applied to.
        obs_space: Not used by this function.
        action_space: Not used by this function.
        config: Forwarded to ``TargetNetworkMixin.__init__``.
    """
    # Only the policy and its config are needed by the target-network mixin.
    TargetNetworkMixin.__init__(policy, config)
|
||||
|
||||
|
||||
# Assemble the SAC TF policy class from the functions and mixins defined in
# this module.
SACTFPolicy = build_tf_policy(
    name="SACTFPolicy",
    get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
    # Model, trajectory post-processing and action sampling.
    make_model=build_sac_model,
    postprocess_fn=postprocess_trajectory,
    extra_action_feed_fn=exploration_setting_inputs,
    action_sampler_fn=build_action_output,
    # Loss, gradients and reporting.
    loss_fn=actor_critic_loss,
    gradients_fn=gradients,
    stats_fn=stats,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    # Mixin classes and the hooks that call their constructors.
    mixins=[
        TargetNetworkMixin,
        ExplorationStateMixin,
        ActorCriticOptimizerMixin,
        ComputeTDErrorMixin,
    ],
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    obs_include_prev_action_reward=False)
|
||||
@@ -88,6 +88,10 @@ COMMON_CONFIG = {
|
||||
# hit. This allows value estimation and RNN state to span across logical
|
||||
# episodes denoted by horizon. This only has an effect if horizon != inf.
|
||||
"soft_horizon": False,
|
||||
# Don't set 'done' at the end of the episode. Note that you still need to
|
||||
# set this if soft_horizon=True, unless your env is actually running
|
||||
# forever without returning done=True.
|
||||
"no_done_at_end": False,
|
||||
# Arguments to pass to the env creator
|
||||
"env_config": {},
|
||||
# Environment name can also be passed via config
|
||||
|
||||
@@ -132,6 +132,7 @@ class RolloutWorker(EvaluatorInterface):
|
||||
remote_worker_envs=False,
|
||||
remote_env_batch_wait_ms=0,
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False,
|
||||
seed=None,
|
||||
_fake_sampler=False):
|
||||
"""Initialize a rollout worker.
|
||||
@@ -218,6 +219,8 @@ class RolloutWorker(EvaluatorInterface):
|
||||
step / reset and model inference perf.
|
||||
soft_horizon (bool): Calculate rewards but don't reset the
|
||||
environment when the horizon is hit.
|
||||
no_done_at_end (bool): Ignore the done=True at the end of the
|
||||
episode and instead record done=False.
|
||||
seed (int): Set the seed of both np and tf to this value to
|
||||
to ensure each remote worker has unique exploration behavior.
|
||||
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
|
||||
@@ -408,7 +411,8 @@ class RolloutWorker(EvaluatorInterface):
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
blackhole_outputs="simulation" in input_evaluation,
|
||||
soft_horizon=soft_horizon)
|
||||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
@@ -424,7 +428,8 @@ class RolloutWorker(EvaluatorInterface):
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
soft_horizon=soft_horizon)
|
||||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end)
|
||||
|
||||
self.input_reader = input_creator(self.io_context)
|
||||
assert isinstance(self.input_reader, InputReader), self.input_reader
|
||||
|
||||
@@ -185,7 +185,8 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
"agent {} (policy {}). ".format(
|
||||
agent_id, self.agent_to_policy[agent_id]) +
|
||||
"Please ensure that you include the last observations "
|
||||
"of all live agents when setting '__all__' done to True.")
|
||||
"of all live agents when setting '__all__' done to True. "
|
||||
"Alternatively, set no_done_at_end=True to allow this.")
|
||||
|
||||
@DeveloperAPI
|
||||
def build_and_reset(self, episode):
|
||||
|
||||
@@ -75,7 +75,8 @@ class SyncSampler(SamplerInput):
|
||||
pack=False,
|
||||
tf_sess=None,
|
||||
clip_actions=True,
|
||||
soft_horizon=False):
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False):
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
self.unroll_length = unroll_length
|
||||
self.horizon = horizon
|
||||
@@ -89,7 +90,8 @@ class SyncSampler(SamplerInput):
|
||||
self.base_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
pack, callbacks, tf_sess, self.perf_stats, soft_horizon)
|
||||
pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
|
||||
no_done_at_end)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -135,7 +137,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
tf_sess=None,
|
||||
clip_actions=True,
|
||||
blackhole_outputs=False,
|
||||
soft_horizon=False):
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False):
|
||||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
"Observation Filter must support concurrent updates."
|
||||
@@ -158,6 +161,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.clip_actions = clip_actions
|
||||
self.blackhole_outputs = blackhole_outputs
|
||||
self.soft_horizon = soft_horizon
|
||||
self.no_done_at_end = no_done_at_end
|
||||
self.perf_stats = PerfStats()
|
||||
self.shutdown = False
|
||||
|
||||
@@ -181,7 +185,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
|
||||
self.perf_stats, self.soft_horizon)
|
||||
self.perf_stats, self.soft_horizon, self.no_done_at_end)
|
||||
while not self.shutdown:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -226,7 +230,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
unroll_length, horizon, preprocessors, obs_filters,
|
||||
clip_rewards, clip_actions, pack, callbacks, tf_sess,
|
||||
perf_stats, soft_horizon):
|
||||
perf_stats, soft_horizon, no_done_at_end):
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
Args:
|
||||
@@ -253,6 +257,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
perf_stats (PerfStats): Record perf stats into this object.
|
||||
soft_horizon (bool): Calculate rewards but don't reset the
|
||||
environment when the horizon is hit.
|
||||
no_done_at_end (bool): Ignore the done=True at the end of the episode
|
||||
and instead record done=False.
|
||||
|
||||
Yields:
|
||||
rollout (SampleBatch): Object containing state, action, reward,
|
||||
@@ -310,7 +316,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
base_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
preprocessors, obs_filters, unroll_length, pack, callbacks,
|
||||
soft_horizon)
|
||||
soft_horizon, no_done_at_end)
|
||||
perf_stats.processing_time += time.time() - t1
|
||||
for o in outputs:
|
||||
yield o
|
||||
@@ -339,7 +345,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
active_episodes, unfiltered_obs, rewards, dones,
|
||||
infos, off_policy_actions, horizon, preprocessors,
|
||||
obs_filters, unroll_length, pack, callbacks,
|
||||
soft_horizon):
|
||||
soft_horizon, no_done_at_end):
|
||||
"""Record new data from the environment and prepare for policy evaluation.
|
||||
|
||||
Returns:
|
||||
@@ -434,8 +440,9 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
rewards=rewards[env_id][agent_id],
|
||||
prev_actions=episode.prev_action_for(agent_id),
|
||||
prev_rewards=episode.prev_reward_for(agent_id),
|
||||
dones=(False
|
||||
if (hit_horizon and soft_horizon) else agent_done),
|
||||
dones=(False if (no_done_at_end
|
||||
or (hit_horizon and soft_horizon)) else
|
||||
agent_done),
|
||||
infos=infos[env_id].get(agent_id, {}),
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info_for(agent_id))
|
||||
@@ -447,7 +454,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
if dones[env_id]["__all__"]:
|
||||
if dones[env_id]["__all__"] and not no_done_at_end:
|
||||
episode.batch_builder.check_missing_dones()
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= unroll_length:
|
||||
|
||||
@@ -211,6 +211,7 @@ class WorkerSet(object):
|
||||
remote_worker_envs=config["remote_worker_envs"],
|
||||
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
|
||||
soft_horizon=config["soft_horizon"],
|
||||
no_done_at_end=config["no_done_at_end"],
|
||||
seed=(config["seed"] + worker_index)
|
||||
if config["seed"] is not None else None,
|
||||
_fake_sampler=config.get("_fake_sampler", False))
|
||||
|
||||
@@ -24,6 +24,7 @@ def get_mean_action(alg, obs):
|
||||
ray.init(num_cpus=10)
|
||||
|
||||
CONFIGS = {
|
||||
"SAC": {},
|
||||
"ES": {
|
||||
"episodes_per_batch": 10,
|
||||
"train_batch_size": 100,
|
||||
@@ -62,7 +63,7 @@ CONFIGS = {
|
||||
|
||||
def test_ckpt_restore(use_object_store, alg_name, failures):
|
||||
cls = get_agent_class(alg_name)
|
||||
if "DDPG" in alg_name:
|
||||
if "DDPG" in alg_name or "SAC" in alg_name:
|
||||
alg1 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
alg2 = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
env = gym.make("Pendulum-v0")
|
||||
@@ -82,7 +83,7 @@ def test_ckpt_restore(use_object_store, alg_name, failures):
|
||||
alg2.restore(alg1.save())
|
||||
|
||||
for _ in range(10):
|
||||
if "DDPG" in alg_name:
|
||||
if "DDPG" in alg_name or "SAC" in alg_name:
|
||||
obs = np.clip(
|
||||
np.random.uniform(size=3),
|
||||
env.observation_space.low,
|
||||
@@ -110,7 +111,7 @@ def test_export(algo_name, failures):
|
||||
and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))
|
||||
|
||||
cls = get_agent_class(algo_name)
|
||||
if "DDPG" in algo_name:
|
||||
if "DDPG" in algo_name or "SAC" in algo_name:
|
||||
algo = cls(config=CONFIGS[name], env="Pendulum-v0")
|
||||
else:
|
||||
algo = cls(config=CONFIGS[name], env="CartPole-v0")
|
||||
@@ -145,14 +146,16 @@ def test_export(algo_name, failures):
|
||||
if __name__ == "__main__":
|
||||
failures = []
|
||||
for use_object_store in [False, True]:
|
||||
for name in ["ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"]:
|
||||
for name in [
|
||||
"SAC", "ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG", "ARS"
|
||||
]:
|
||||
test_ckpt_restore(use_object_store, name, failures)
|
||||
|
||||
assert not failures, failures
|
||||
print("All checkpoint restore tests passed!")
|
||||
|
||||
failures = []
|
||||
for name in ["DQN", "DDPG", "PPO", "A3C"]:
|
||||
for name in ["SAC", "DQN", "DDPG", "PPO", "A3C"]:
|
||||
test_export(name, failures)
|
||||
assert not failures, failures
|
||||
print("All export tests passed!")
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
# Experiment spec: run SAC on Pendulum-v0 with the stop criteria below.
pendulum-sac:
    env: Pendulum-v0
    run: SAC
    stop:
        episode_reward_mean: -300 # note that evaluation perf is higher
        timesteps_total: 15000
    config:
        evaluation_interval: 1 # logged under evaluation/* metric keys
        soft_horizon: true
        metrics_smoothing_episodes: 10
|
||||
@@ -76,6 +76,19 @@ def try_import_tf():
|
||||
return None
|
||||
|
||||
|
||||
def try_import_tfp():
    """Return the ``tensorflow_probability`` module, or None if unavailable.

    Returns None (after logging a warning) when the
    ``RLLIB_TEST_NO_TF_IMPORT`` environment variable is set, and None when
    the import raises ``ImportError``.
    """
    import_disabled = "RLLIB_TEST_NO_TF_IMPORT" in os.environ
    if import_disabled:
        logger.warning(
            "Not importing TensorFlow Probability for test purposes.")
        return None

    try:
        import tensorflow_probability
    except ImportError:
        return None
    return tensorflow_probability
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Filter",
|
||||
"FilterManager",
|
||||
|
||||
Reference in New Issue
Block a user