mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 02:30:34 +08:00
[RLlib] DDPG PyTorch version. (#7953)
The DDPG/TD3 algorithms currently do not have a PyTorch implementation. This PR adds PyTorch support for DDPG/TD3 to RLlib. This PR: - Depends on the re-factor PR for DDPG (Functional Algorithm API). - Adds learning regression tests for the PyTorch version of DDPG and a DDPG (torch) - Updates the documentation to reflect that DDPG and TD3 now support PyTorch. * Learning Pendulum-v0 on torch version (same config as tf). Wall time a little slower (~20% than tf). * Fix GPU target model problem.
This commit is contained in:
@@ -8,25 +8,27 @@ RLlib Algorithms
|
||||
Feature Compatibility Matrix
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
=================== ======================= ============== =========== =====================
|
||||
Algorithm Discrete Actions Continuous Multi-Agent Model Support
|
||||
=================== ======================= ============== =========== =====================
|
||||
`A2C, A3C`_ **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`PPO`_, `APPO`_ **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`PG`_ **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`IMPALA`_ **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`DQN`_, `Rainbow`_ **Yes** `+parametric`_ No **Yes**
|
||||
`DDPG`_, `TD3`_ No **Yes** **Yes**
|
||||
`APEX-DQN`_ **Yes** `+parametric`_ No **Yes**
|
||||
`APEX-DDPG`_ No **Yes** **Yes**
|
||||
`SAC`_ **Yes** **Yes** **Yes**
|
||||
`ES`_ **Yes** **Yes** No
|
||||
`ARS`_ **Yes** **Yes** No
|
||||
`QMIX`_ **Yes** No **Yes** `+RNN`_
|
||||
`MARWIL`_ **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
`LinUCB`_, `LinTS`_ **Yes** `+parametric`_ No **Yes**
|
||||
`AlphaZero`_ **Yes** `+parametric`_ No No
|
||||
=================== ======================= ============== =========== =====================
|
||||
=================== ========== ======================= ================== =========== =====================
|
||||
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
|
||||
=================== ========== ======================= ================== =========== =====================
|
||||
`A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`ARS`_ tf **Yes** **Yes** No
|
||||
`ES`_ tf **Yes** **Yes** No
|
||||
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
|
||||
`APEX-DDPG`_ tf No **Yes** **Yes**
|
||||
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
|
||||
`APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes**
|
||||
`IMPALA`_ tf **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
`PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`QMIX`_ torch **Yes** No **Yes** `+RNN`_
|
||||
`SAC`_ tf + torch **Yes** **Yes** **Yes**
|
||||
------------------- ---------- ----------------------- ------------------ ----------- ---------------------
|
||||
`AlphaZero`_ torch **Yes** `+parametric`_ No No
|
||||
`LinUCB`_, `LinTS`_ torch **Yes** `+parametric`_ No **Yes**
|
||||
`MADDPG`_ tf No **Yes** **Yes**
|
||||
=================== ========== ======================= ================== =========== =====================
|
||||
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
.. _`+RNN`: rllib-models.html#recurrent-models
|
||||
@@ -233,7 +235,7 @@ SpaceInvaders 692 ~600
|
||||
|
||||
Deep Deterministic Policy Gradients (DDPG, TD3)
|
||||
-----------------------------------------------
|
||||
|tensorflow|
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/1509.02971>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ddpg/ddpg.py>`__
|
||||
DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers, switching to AsyncGradientsOptimizer, or using Ape-X. The improvements from `TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`__ are available as ``TD3``.
|
||||
|
||||
@@ -370,7 +372,7 @@ HalfCheetah 9664 ~7700
|
||||
|
||||
Soft Actor Critic (SAC)
|
||||
------------------------
|
||||
|tensorflow|
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/pdf/1801.01290>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac.py>`__
|
||||
|
||||
.. figure:: dqn-arch.svg
|
||||
@@ -476,7 +478,7 @@ Tuned examples: `Multi-Agent Particle Environment <https://github.com/wsjeon/mad
|
||||
|
||||
Advantage Re-Weighted Imitation Learning (MARWIL)
|
||||
-------------------------------------------------
|
||||
|tensorflow|
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/marwil.py>`__ MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data. When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning. MARWIL requires the `offline datasets API <rllib-offline.html>`__ to be used.
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/cartpole-marwil.yaml>`__
|
||||
|
||||
@@ -259,22 +259,6 @@ def make_ddpg_optimizers(policy, config):
|
||||
learning_rate=config["critic_lr"])
|
||||
return None
|
||||
|
||||
# TFPolicy.__init__(
|
||||
# self,
|
||||
# observation_space,
|
||||
# action_space,
|
||||
# self.config,
|
||||
# self.sess,
|
||||
# #obs_input=self.cur_observations,
|
||||
# sampled_action=self.output_actions,
|
||||
# loss=self.actor_loss + self.critic_loss,
|
||||
# loss_inputs=self.loss_inputs,
|
||||
# update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops,
|
||||
# explore=explore,
|
||||
# dist_inputs=self._distribution_inputs,
|
||||
# dist_class=Deterministic,
|
||||
# timestep=timestep)
|
||||
|
||||
|
||||
def build_apply_op(policy, optimizer, grads_and_vars):
|
||||
# For policy gradient, update policy net one time v.s.
|
||||
@@ -2,4 +2,9 @@ from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.ddpg.td3 import TD3Trainer
|
||||
|
||||
__all__ = ["ApexDDPGTrainer", "DDPGTrainer", "DEFAULT_CONFIG", "TD3Trainer"]
|
||||
__all__ = [
|
||||
"ApexDDPGTrainer",
|
||||
"DDPGTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
"TD3Trainer",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
|
||||
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy
|
||||
from ray.rllib.agents.ddpg.ddpg_tf_policy import DDPGTFPolicy
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, \
|
||||
DEPRECATED_VALUE
|
||||
from ray.rllib.utils.exploration.per_worker_ornstein_uhlenbeck_noise import \
|
||||
@@ -153,10 +153,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
# PyTorch check.
|
||||
if config["use_pytorch"]:
|
||||
raise ValueError("DDPG does not support PyTorch yet! Use tf instead.")
|
||||
|
||||
# TODO(sven): Remove at some point.
|
||||
# Backward compatibility of noise-based exploration config.
|
||||
schedule_max_timesteps = None
|
||||
@@ -202,10 +198,18 @@ def validate_config(config):
|
||||
config["batch_mode"] = "complete_episodes"
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.ddpg.ddpg_torch_policy import DDPGTorchPolicy
|
||||
return DDPGTorchPolicy
|
||||
else:
|
||||
return DDPGTFPolicy
|
||||
|
||||
|
||||
DDPGTrainer = GenericOffPolicyTrainer.with_updates(
|
||||
name="DDPG",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=DDPGTFPolicy,
|
||||
get_policy_class=None,
|
||||
get_policy_class=get_policy_class,
|
||||
validate_config=validate_config,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from ray.rllib.utils import try_import_tf
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class DDPGModel(TFModelV2):
|
||||
class DDPGTFModel(TFModelV2):
|
||||
"""Extension of standard TFModel to provide DDPG action- and q-outputs.
|
||||
|
||||
Data flow:
|
||||
@@ -43,8 +43,8 @@ class DDPGModel(TFModelV2):
|
||||
should be defined in subclasses of DDPGActionModel.
|
||||
"""
|
||||
|
||||
super(DDPGModel, self).__init__(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
super(DDPGTFModel, self).__init__(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
actor_hidden_activation = getattr(tf.nn, actor_hidden_activation,
|
||||
tf.nn.relu)
|
||||
@@ -84,8 +84,8 @@ class DDPGModel(TFModelV2):
|
||||
|
||||
actor_out = tf.keras.layers.Lambda(lambda_)(actor_out)
|
||||
|
||||
self.action_model = tf.keras.Model(self.model_out, actor_out)
|
||||
self.register_variables(self.action_model.variables)
|
||||
self.policy_model = tf.keras.Model(self.model_out, actor_out)
|
||||
self.register_variables(self.policy_model.variables)
|
||||
|
||||
# Build the Q-model(s).
|
||||
self.actions_input = tf.keras.layers.Input(
|
||||
@@ -111,15 +111,15 @@ class DDPGModel(TFModelV2):
|
||||
q_net([observations, actions]))
|
||||
return q_net
|
||||
|
||||
self.q_net = build_q_net("q", self.model_out, self.actions_input)
|
||||
self.register_variables(self.q_net.variables)
|
||||
self.q_model = build_q_net("q", self.model_out, self.actions_input)
|
||||
self.register_variables(self.q_model.variables)
|
||||
|
||||
if twin_q:
|
||||
self.twin_q_net = build_q_net("twin_q", self.model_out,
|
||||
self.actions_input)
|
||||
self.register_variables(self.twin_q_net.variables)
|
||||
self.twin_q_model = build_q_net("twin_q", self.model_out,
|
||||
self.actions_input)
|
||||
self.register_variables(self.twin_q_model.variables)
|
||||
else:
|
||||
self.twin_q_net = None
|
||||
self.twin_q_model = None
|
||||
|
||||
def get_q_values(self, model_out, actions):
|
||||
"""Return the Q estimates for the most recent forward pass.
|
||||
@@ -136,9 +136,9 @@ class DDPGModel(TFModelV2):
|
||||
tensor of shape [BATCH_SIZE].
|
||||
"""
|
||||
if actions is not None:
|
||||
return self.q_net([model_out, actions])
|
||||
return self.q_model([model_out, actions])
|
||||
else:
|
||||
return self.q_net(model_out)
|
||||
return self.q_model(model_out)
|
||||
|
||||
def get_twin_q_values(self, model_out, actions):
|
||||
"""Same as get_q_values but using the twin Q net.
|
||||
@@ -155,9 +155,9 @@ class DDPGModel(TFModelV2):
|
||||
tensor of shape [BATCH_SIZE].
|
||||
"""
|
||||
if actions is not None:
|
||||
return self.twin_q_net([model_out, actions])
|
||||
return self.twin_q_model([model_out, actions])
|
||||
else:
|
||||
return self.twin_q_net(model_out)
|
||||
return self.twin_q_model(model_out)
|
||||
|
||||
def get_policy_output(self, model_out):
|
||||
"""Return the action output for the most recent forward pass.
|
||||
@@ -172,14 +172,14 @@ class DDPGModel(TFModelV2):
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE, action_out_size]
|
||||
"""
|
||||
return self.action_model(model_out)
|
||||
return self.policy_model(model_out)
|
||||
|
||||
def policy_variables(self):
|
||||
"""Return the list of variables for the policy net."""
|
||||
return list(self.action_model.variables)
|
||||
return list(self.policy_model.variables)
|
||||
|
||||
def q_variables(self):
|
||||
"""Return the list of variables for Q / twin Q nets."""
|
||||
|
||||
return self.q_net.variables + (self.twin_q_net.variables
|
||||
if self.twin_q_net else [])
|
||||
return self.q_model.variables + (self.twin_q_model.variables
|
||||
if self.twin_q_model else [])
|
||||
@@ -0,0 +1,438 @@
|
||||
from gym.spaces import Box
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.agents.ddpg.ddpg_tf_model import DDPGTFModel
|
||||
from ray.rllib.agents.ddpg.ddpg_torch_model import DDPGTorchModel
|
||||
from ray.rllib.agents.ddpg.noop_model import NoopModel, TorchNoopModel
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
|
||||
PRIO_WEIGHTS
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.tf.tf_action_dist import Deterministic
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \
|
||||
make_tf_callable
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACTION_SCOPE = "action"
|
||||
POLICY_SCOPE = "policy"
|
||||
POLICY_TARGET_SCOPE = "target_policy"
|
||||
Q_SCOPE = "critic"
|
||||
Q_TARGET_SCOPE = "target_critic"
|
||||
TWIN_Q_SCOPE = "twin_critic"
|
||||
TWIN_Q_TARGET_SCOPE = "twin_target_critic"
|
||||
|
||||
|
||||
def build_ddpg_models(policy, observation_space, action_space, config):
|
||||
if config["model"]["custom_model"]:
|
||||
logger.warning(
|
||||
"Setting use_state_preprocessor=True since a custom model "
|
||||
"was specified.")
|
||||
config["use_state_preprocessor"] = True
|
||||
|
||||
if not isinstance(action_space, Box):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DDPG.".format(action_space))
|
||||
elif len(action_space.shape) > 1:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space has multiple dimensions "
|
||||
"{}. ".format(action_space.shape) +
|
||||
"Consider reshaping this into a single dimension, "
|
||||
"using a Tuple action space, or the multi-agent API.")
|
||||
|
||||
if policy.config["use_state_preprocessor"]:
|
||||
default_model = None # catalog decides
|
||||
num_outputs = 256 # arbitrary
|
||||
config["model"]["no_final_linear"] = True
|
||||
else:
|
||||
default_model = TorchNoopModel if config["use_pytorch"] else NoopModel
|
||||
num_outputs = int(np.product(observation_space.shape))
|
||||
|
||||
policy.model = ModelCatalog.get_model_v2(
|
||||
obs_space=observation_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework="torch" if config["use_pytorch"] else "tf",
|
||||
model_interface=DDPGTorchModel
|
||||
if config["use_pytorch"] else DDPGTFModel,
|
||||
default_model=default_model,
|
||||
name="ddpg_model",
|
||||
actor_hidden_activation=config["actor_hidden_activation"],
|
||||
actor_hiddens=config["actor_hiddens"],
|
||||
critic_hidden_activation=config["critic_hidden_activation"],
|
||||
critic_hiddens=config["critic_hiddens"],
|
||||
twin_q=config["twin_q"],
|
||||
add_layer_norm=(policy.config["exploration_config"].get("type") ==
|
||||
"ParameterNoise"),
|
||||
)
|
||||
|
||||
policy.target_model = ModelCatalog.get_model_v2(
|
||||
obs_space=observation_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework="torch" if config["use_pytorch"] else "tf",
|
||||
model_interface=DDPGTorchModel
|
||||
if config["use_pytorch"] else DDPGTFModel,
|
||||
default_model=default_model,
|
||||
name="target_ddpg_model",
|
||||
actor_hidden_activation=config["actor_hidden_activation"],
|
||||
actor_hiddens=config["actor_hiddens"],
|
||||
critic_hidden_activation=config["critic_hidden_activation"],
|
||||
critic_hiddens=config["critic_hiddens"],
|
||||
twin_q=config["twin_q"],
|
||||
add_layer_norm=(policy.config["exploration_config"].get("type") ==
|
||||
"ParameterNoise"),
|
||||
)
|
||||
|
||||
return policy.model
|
||||
|
||||
|
||||
def get_distribution_inputs_and_class(policy,
|
||||
model,
|
||||
obs_batch,
|
||||
*,
|
||||
explore=True,
|
||||
is_training=False,
|
||||
**kwargs):
|
||||
model_out, _ = model({
|
||||
"obs": obs_batch,
|
||||
"is_training": is_training,
|
||||
}, [], None)
|
||||
dist_inputs = model.get_policy_output(model_out)
|
||||
|
||||
return dist_inputs,\
|
||||
TorchDeterministic if policy.config["use_pytorch"] else Deterministic,\
|
||||
[] # []=state out
|
||||
|
||||
|
||||
def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
||||
twin_q = policy.config["twin_q"]
|
||||
gamma = policy.config["gamma"]
|
||||
n_step = policy.config["n_step"]
|
||||
use_huber = policy.config["use_huber"]
|
||||
huber_threshold = policy.config["huber_threshold"]
|
||||
l2_reg = policy.config["l2_reg"]
|
||||
|
||||
input_dict = {
|
||||
"obs": train_batch[SampleBatch.CUR_OBS],
|
||||
"is_training": True,
|
||||
}
|
||||
input_dict_next = {
|
||||
"obs": train_batch[SampleBatch.NEXT_OBS],
|
||||
"is_training": True,
|
||||
}
|
||||
|
||||
model_out_t, _ = model(input_dict, [], None)
|
||||
model_out_tp1, _ = model(input_dict_next, [], None)
|
||||
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
|
||||
|
||||
# Policy network evaluation.
|
||||
with tf.variable_scope(POLICY_SCOPE, reuse=True):
|
||||
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
policy_t = model.get_policy_output(model_out_t)
|
||||
# policy_batchnorm_update_ops = list(
|
||||
# set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
with tf.variable_scope(POLICY_TARGET_SCOPE):
|
||||
policy_tp1 = \
|
||||
policy.target_model.get_policy_output(target_model_out_tp1)
|
||||
|
||||
# Action outputs.
|
||||
with tf.variable_scope(ACTION_SCOPE, reuse=True):
|
||||
if policy.config["smooth_target_policy"]:
|
||||
target_noise_clip = policy.config["target_noise_clip"]
|
||||
clipped_normal_sample = tf.clip_by_value(
|
||||
tf.random_normal(
|
||||
tf.shape(policy_tp1),
|
||||
stddev=policy.config["target_noise"]), -target_noise_clip,
|
||||
target_noise_clip)
|
||||
policy_tp1_smoothed = tf.clip_by_value(
|
||||
policy_tp1 + clipped_normal_sample,
|
||||
policy.action_space.low * tf.ones_like(policy_tp1),
|
||||
policy.action_space.high * tf.ones_like(policy_tp1))
|
||||
else:
|
||||
# No smoothing, just use deterministic actions.
|
||||
policy_tp1_smoothed = policy_tp1
|
||||
|
||||
# Q-net(s) evaluation.
|
||||
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
with tf.variable_scope(Q_SCOPE):
|
||||
# Q-values for given actions & observations in given current
|
||||
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
|
||||
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
# Q-values for current policy (no noise) in given current state
|
||||
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
|
||||
|
||||
if twin_q:
|
||||
with tf.variable_scope(TWIN_Q_SCOPE):
|
||||
twin_q_t = model.get_twin_q_values(
|
||||
model_out_t, train_batch[SampleBatch.ACTIONS])
|
||||
# q_batchnorm_update_ops = list(
|
||||
# set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
# Target q-net(s) evaluation.
|
||||
with tf.variable_scope(Q_TARGET_SCOPE):
|
||||
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
|
||||
policy_tp1_smoothed)
|
||||
|
||||
if twin_q:
|
||||
with tf.variable_scope(TWIN_Q_TARGET_SCOPE):
|
||||
twin_q_tp1 = policy.target_model.get_twin_q_values(
|
||||
target_model_out_tp1, policy_tp1_smoothed)
|
||||
|
||||
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
|
||||
if twin_q:
|
||||
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
|
||||
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
|
||||
|
||||
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
|
||||
q_tp1_best_masked = \
|
||||
(1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
|
||||
q_tp1_best
|
||||
|
||||
# Compute RHS of bellman equation.
|
||||
q_t_selected_target = tf.stop_gradient(train_batch[SampleBatch.REWARDS] +
|
||||
gamma**n_step * q_tp1_best_masked)
|
||||
|
||||
# Compute the error (potentially clipped).
|
||||
if twin_q:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold) \
|
||||
+ huber_loss(twin_td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(twin_td_error)
|
||||
else:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * tf.square(td_error)
|
||||
|
||||
critic_loss = tf.reduce_mean(train_batch[PRIO_WEIGHTS] * errors)
|
||||
actor_loss = -tf.reduce_mean(q_t_det_policy)
|
||||
|
||||
# Add l2-regularization if required.
|
||||
if l2_reg is not None:
|
||||
for var in policy.model.policy_variables():
|
||||
if "bias" not in var.name:
|
||||
actor_loss += (l2_reg * tf.nn.l2_loss(var))
|
||||
for var in policy.model.q_variables():
|
||||
if "bias" not in var.name:
|
||||
critic_loss += (l2_reg * tf.nn.l2_loss(var))
|
||||
|
||||
# Model self-supervised losses.
|
||||
if policy.config["use_state_preprocessor"]:
|
||||
# Expand input_dict in case custom_loss' need them.
|
||||
input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
|
||||
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
|
||||
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
|
||||
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
|
||||
if log_once("ddpg_custom_loss"):
|
||||
logger.warning(
|
||||
"You are using a state-preprocessor with DDPG and "
|
||||
"therefore, `custom_loss` will be called on your Model! "
|
||||
"Please be aware that DDPG now uses the ModelV2 API, which "
|
||||
"merges all previously separate sub-models (policy_model, "
|
||||
"q_model, and twin_q_model) into one ModelV2, on which "
|
||||
"`custom_loss` is called, passing it "
|
||||
"[actor_loss, critic_loss] as 1st argument. "
|
||||
"You may have to change your custom loss function to handle "
|
||||
"this.")
|
||||
[actor_loss, critic_loss] = model.custom_loss(
|
||||
[actor_loss, critic_loss], input_dict)
|
||||
|
||||
# Store values for stats function.
|
||||
policy.actor_loss = actor_loss
|
||||
policy.critic_loss = critic_loss
|
||||
policy.td_error = td_error
|
||||
policy.q_t = q_t
|
||||
|
||||
# Return one loss value (even though we treat them separately in our
|
||||
# 2 optimizers: actor and critic).
|
||||
return policy.critic_loss + policy.actor_loss
|
||||
|
||||
|
||||
def make_ddpg_optimizers(policy, config):
|
||||
# Create separate optimizers for actor & critic losses.
|
||||
policy._actor_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["actor_lr"])
|
||||
policy._critic_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["critic_lr"])
|
||||
return None
|
||||
|
||||
# TFPolicy.__init__(
|
||||
# self,
|
||||
# observation_space,
|
||||
# action_space,
|
||||
# self.config,
|
||||
# self.sess,
|
||||
# #obs_input=self.cur_observations,
|
||||
# sampled_action=self.output_actions,
|
||||
# loss=self.actor_loss + self.critic_loss,
|
||||
# loss_inputs=self.loss_inputs,
|
||||
# update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops,
|
||||
# explore=explore,
|
||||
# dist_inputs=self._distribution_inputs,
|
||||
# dist_class=Deterministic,
|
||||
# timestep=timestep)
|
||||
|
||||
|
||||
def build_apply_op(policy, optimizer, grads_and_vars):
|
||||
# For policy gradient, update policy net one time v.s.
|
||||
# update critic net `policy_delay` time(s).
|
||||
should_apply_actor_opt = tf.equal(
|
||||
tf.mod(policy.global_step, policy.config["policy_delay"]), 0)
|
||||
|
||||
def make_apply_op():
|
||||
return policy._actor_optimizer.apply_gradients(
|
||||
policy._actor_grads_and_vars)
|
||||
|
||||
actor_op = tf.cond(
|
||||
should_apply_actor_opt,
|
||||
true_fn=make_apply_op,
|
||||
false_fn=lambda: tf.no_op())
|
||||
critic_op = policy._critic_optimizer.apply_gradients(
|
||||
policy._critic_grads_and_vars)
|
||||
# Increment global step & apply ops.
|
||||
with tf.control_dependencies([tf.assign_add(policy.global_step, 1)]):
|
||||
return tf.group(actor_op, critic_op)
|
||||
|
||||
|
||||
def gradients_fn(policy, optimizer, loss):
|
||||
if policy.config["grad_norm_clipping"] is not None:
|
||||
actor_grads_and_vars = minimize_and_clip(
|
||||
policy._actor_optimizer,
|
||||
policy.actor_loss,
|
||||
var_list=policy.model.policy_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
critic_grads_and_vars = minimize_and_clip(
|
||||
policy._critic_optimizer,
|
||||
policy.critic_loss,
|
||||
var_list=policy.model.q_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
else:
|
||||
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
|
||||
policy.actor_loss, var_list=policy.model.policy_variables())
|
||||
critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
|
||||
policy.critic_loss, var_list=policy.model.q_variables())
|
||||
# Save these for later use in build_apply_op.
|
||||
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
|
||||
if g is not None]
|
||||
policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
|
||||
if g is not None]
|
||||
grads_and_vars = policy._actor_grads_and_vars + \
|
||||
policy._critic_grads_and_vars
|
||||
return grads_and_vars
|
||||
|
||||
|
||||
def build_ddpg_stats(policy, batch):
|
||||
stats = {
|
||||
"mean_q": tf.reduce_mean(policy.q_t),
|
||||
"max_q": tf.reduce_max(policy.q_t),
|
||||
"min_q": tf.reduce_min(policy.q_t),
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def before_init_fn(policy, obs_space, action_space, config):
|
||||
# Create global step for counting the number of update operations.
|
||||
policy.global_step = tf.train.get_or_create_global_step()
|
||||
|
||||
|
||||
class ComputeTDErrorMixin:
|
||||
def __init__(self, loss_fn):
|
||||
@make_tf_callable(self.get_session(), dynamic_shape=True)
|
||||
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
# Do forward pass on loss to update td errors attribute
|
||||
# (one TD-error value per item in batch to update PR weights).
|
||||
loss_fn(
|
||||
self, self.model, None, {
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
|
||||
SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
|
||||
SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
|
||||
SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
|
||||
SampleBatch.DONES: tf.convert_to_tensor(done_mask),
|
||||
PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
|
||||
})
|
||||
# `self.td_error` is set in loss_fn.
|
||||
return self.td_error
|
||||
|
||||
self.compute_td_error = compute_td_error
|
||||
|
||||
|
||||
def setup_mid_mixins(policy, obs_space, action_space, config):
|
||||
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
|
||||
|
||||
|
||||
class TargetNetworkMixin:
|
||||
def __init__(self, config):
|
||||
@make_tf_callable(self.get_session())
|
||||
def update_target_fn(tau):
|
||||
tau = tf.convert_to_tensor(tau, dtype=tf.float32)
|
||||
update_target_expr = []
|
||||
model_vars = self.model.trainable_variables()
|
||||
target_model_vars = self.target_model.trainable_variables()
|
||||
assert len(model_vars) == len(target_model_vars), \
|
||||
(model_vars, target_model_vars)
|
||||
for var, var_target in zip(model_vars, target_model_vars):
|
||||
update_target_expr.append(
|
||||
var_target.assign(tau * var + (1.0 - tau) * var_target))
|
||||
logger.debug("Update target op {}".format(var_target))
|
||||
return tf.group(*update_target_expr)
|
||||
|
||||
# Hard initial update.
|
||||
self._do_update = update_target_fn
|
||||
self.update_target(tau=1.0)
|
||||
|
||||
# Support both hard and soft sync.
|
||||
def update_target(self, tau=None):
|
||||
self._do_update(np.float32(tau or self.config.get("tau")))
|
||||
|
||||
@override(TFPolicy)
|
||||
def variables(self):
|
||||
return self.model.variables() + self.target_model.variables()
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
|
||||
TargetNetworkMixin.__init__(policy, config)
|
||||
|
||||
|
||||
DDPGTFPolicy = build_tf_policy(
|
||||
name="DQNTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
|
||||
make_model=build_ddpg_models,
|
||||
action_distribution_fn=get_distribution_inputs_and_class,
|
||||
loss_fn=ddpg_actor_critic_loss,
|
||||
stats_fn=build_ddpg_stats,
|
||||
postprocess_fn=postprocess_nstep_and_prio,
|
||||
optimizer_fn=make_ddpg_optimizers,
|
||||
gradients_fn=gradients_fn,
|
||||
apply_gradients_fn=build_apply_op,
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
|
||||
before_init=before_init_fn,
|
||||
before_loss_init=setup_mid_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
obs_include_prev_action_reward=False,
|
||||
mixins=[
|
||||
TargetNetworkMixin,
|
||||
ComputeTDErrorMixin,
|
||||
])
|
||||
@@ -0,0 +1,176 @@
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.torch.misc import SlimFC
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.framework import try_import_torch, get_activation_fn
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
class DDPGTorchModel(TorchModelV2, nn.Module):
|
||||
"""Extension of standard TorchModelV2 for DDPG.
|
||||
|
||||
Data flow:
|
||||
obs -> forward() -> model_out
|
||||
model_out -> get_policy_output() -> pi(s)
|
||||
model_out, actions -> get_q_values() -> Q(s, a)
|
||||
model_out, actions -> get_twin_q_values() -> Q_twin(s, a)
|
||||
|
||||
Note that this class by itself is not a valid model unless you
|
||||
implement forward() in a subclass."""
|
||||
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
model_config,
|
||||
name,
|
||||
actor_hidden_activation="relu",
|
||||
actor_hiddens=(256, 256),
|
||||
critic_hidden_activation="relu",
|
||||
critic_hiddens=(256, 256),
|
||||
twin_q=False,
|
||||
add_layer_norm=False):
|
||||
"""Initialize variables of this model.
|
||||
|
||||
Extra model kwargs:
|
||||
actor_hidden_activation (str): activation for actor network
|
||||
actor_hiddens (list): hidden layers sizes for actor network
|
||||
critic_hidden_activation (str): activation for critic network
|
||||
critic_hiddens (list): hidden layers sizes for critic network
|
||||
twin_q (bool): build twin Q networks.
|
||||
add_layer_norm (bool): Enable layer norm (for param noise).
|
||||
|
||||
Note that the core layers for forward() are not defined here, this
|
||||
only defines the layers for the output heads. Those layers for
|
||||
forward() should be defined in subclasses of DDPGTorchModel.
|
||||
"""
|
||||
TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.action_dim = np.product(action_space.shape)
|
||||
|
||||
# Build the policy network.
|
||||
self.policy_model = nn.Sequential()
|
||||
ins = obs_space.shape[-1]
|
||||
self.obs_ins = ins
|
||||
activation = get_activation_fn(
|
||||
actor_hidden_activation, framework="torch")
|
||||
for i, n in enumerate(actor_hiddens):
|
||||
self.policy_model.add_module(
|
||||
"action_{}".format(i),
|
||||
SlimFC(
|
||||
ins,
|
||||
n,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=activation))
|
||||
# Add LayerNorm after each Dense.
|
||||
if add_layer_norm:
|
||||
self.policy_model.add_module("LayerNorm_A_{}".format(i),
|
||||
nn.LayerNorm(n))
|
||||
ins = n
|
||||
|
||||
self.policy_model.add_module(
|
||||
"action_out",
|
||||
SlimFC(
|
||||
ins,
|
||||
self.action_dim,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=None))
|
||||
|
||||
# Build the Q-net(s), including target Q-net(s).
|
||||
def build_q_net(name_):
|
||||
activation = get_activation_fn(
|
||||
critic_hidden_activation, framework="torch")
|
||||
# For continuous actions: Feed obs and actions (concatenated)
|
||||
# through the NN. For discrete actions, only obs.
|
||||
q_net = nn.Sequential()
|
||||
ins = self.obs_ins + self.action_dim
|
||||
for i, n in enumerate(critic_hiddens):
|
||||
q_net.add_module(
|
||||
"{}_hidden_{}".format(name_, i),
|
||||
SlimFC(
|
||||
ins,
|
||||
n,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=activation))
|
||||
ins = n
|
||||
|
||||
q_net.add_module(
|
||||
"{}_out".format(name_),
|
||||
SlimFC(
|
||||
ins,
|
||||
1,
|
||||
initializer=torch.nn.init.xavier_uniform_,
|
||||
activation_fn=None))
|
||||
return q_net
|
||||
|
||||
self.q_model = build_q_net("q")
|
||||
if twin_q:
|
||||
self.twin_q_model = build_q_net("twin_q")
|
||||
else:
|
||||
self.twin_q_model = None
|
||||
|
||||
def get_q_values(self, model_out, actions):
|
||||
"""Return the Q estimates for the most recent forward pass.
|
||||
|
||||
This implements Q(s, a).
|
||||
|
||||
Arguments:
|
||||
model_out (Tensor): obs embeddings from the model layers, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
actions (Tensor): Actions to return the Q-values for.
|
||||
Shape: [BATCH_SIZE, action_dim].
|
||||
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE].
|
||||
"""
|
||||
return self.q_model(torch.cat([model_out, actions], -1))
|
||||
|
||||
def get_twin_q_values(self, model_out, actions):
|
||||
"""Same as get_q_values but using the twin Q net.
|
||||
|
||||
This implements the twin Q(s, a).
|
||||
|
||||
Arguments:
|
||||
model_out (Tensor): obs embeddings from the model layers, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
actions (Optional[Tensor]): Actions to return the Q-values for.
|
||||
Shape: [BATCH_SIZE, action_dim].
|
||||
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE].
|
||||
"""
|
||||
return self.twin_q_model(torch.cat([model_out, actions], -1))
|
||||
|
||||
def get_policy_output(self, model_out):
|
||||
"""Return the action output for the most recent forward pass.
|
||||
|
||||
This outputs the support for pi(s). For continuous action spaces, this
|
||||
is the action directly. For discrete, is is the mean / std dev.
|
||||
|
||||
Arguments:
|
||||
model_out (Tensor): obs embeddings from the model layers, of shape
|
||||
[BATCH_SIZE, num_outputs].
|
||||
|
||||
Returns:
|
||||
tensor of shape [BATCH_SIZE, action_out_size]
|
||||
"""
|
||||
return self.policy_model(model_out)
|
||||
|
||||
def policy_variables(self, as_dict=False):
|
||||
"""Return the list of variables for the policy net."""
|
||||
if as_dict:
|
||||
return self.policy_model.state_dict()
|
||||
return list(self.policy_model.parameters())
|
||||
|
||||
def q_variables(self, as_dict=False):
|
||||
"""Return the list of variables for Q / twin Q nets."""
|
||||
if as_dict:
|
||||
return {
|
||||
**self.q_model.state_dict(),
|
||||
**(self.twin_q_model.state_dict() if self.twin_q_model else {})
|
||||
}
|
||||
return list(self.q_model.parameters()) + \
|
||||
(list(self.twin_q_model.parameters()) if self.twin_q_model else [])
|
||||
@@ -0,0 +1,273 @@
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
|
||||
get_distribution_inputs_and_class
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
|
||||
PRIO_WEIGHTS
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import huber_loss, minimize_and_clip, l2_loss
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_ddpg_models_and_action_dist(policy, obs_space, action_space, config):
|
||||
model = build_ddpg_models(policy, obs_space, action_space, config)
|
||||
# TODO(sven): Unify this once we generically support creating more than
|
||||
# one Model per policy. Note: Device placement is done automatically
|
||||
# already for `policy.model` (but not for the target model).
|
||||
device = (torch.device("cuda")
|
||||
if torch.cuda.is_available() else torch.device("cpu"))
|
||||
policy.target_model = policy.target_model.to(device)
|
||||
return model, TorchDeterministic
|
||||
|
||||
|
||||
def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
||||
twin_q = policy.config["twin_q"]
|
||||
gamma = policy.config["gamma"]
|
||||
n_step = policy.config["n_step"]
|
||||
use_huber = policy.config["use_huber"]
|
||||
huber_threshold = policy.config["huber_threshold"]
|
||||
l2_reg = policy.config["l2_reg"]
|
||||
|
||||
input_dict = {
|
||||
"obs": train_batch[SampleBatch.CUR_OBS],
|
||||
"is_training": True,
|
||||
}
|
||||
input_dict_next = {
|
||||
"obs": train_batch[SampleBatch.NEXT_OBS],
|
||||
"is_training": True,
|
||||
}
|
||||
|
||||
model_out_t, _ = model(input_dict, [], None)
|
||||
model_out_tp1, _ = model(input_dict_next, [], None)
|
||||
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
|
||||
|
||||
# Policy network evaluation.
|
||||
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
policy_t = model.get_policy_output(model_out_t)
|
||||
# policy_batchnorm_update_ops = list(
|
||||
# set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
policy_tp1 = \
|
||||
policy.target_model.get_policy_output(target_model_out_tp1)
|
||||
|
||||
# Action outputs.
|
||||
if policy.config["smooth_target_policy"]:
|
||||
target_noise_clip = policy.config["target_noise_clip"]
|
||||
clipped_normal_sample = torch.clamp(
|
||||
torch.normal(
|
||||
mean=torch.zeros(policy_tp1.size()),
|
||||
std=policy.config["target_noise"]), -target_noise_clip,
|
||||
target_noise_clip)
|
||||
policy_tp1_smoothed = torch.clamp(
|
||||
policy_tp1 + clipped_normal_sample,
|
||||
policy.action_space.low * torch.ones_like(policy_tp1),
|
||||
policy.action_space.high * torch.ones_like(policy_tp1))
|
||||
else:
|
||||
# No smoothing, just use deterministic actions.
|
||||
policy_tp1_smoothed = policy_tp1
|
||||
|
||||
# Q-net(s) evaluation.
|
||||
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
# Q-values for given actions & observations in given current
|
||||
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
|
||||
|
||||
# Q-values for current policy (no noise) in given current state
|
||||
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
|
||||
|
||||
if twin_q:
|
||||
twin_q_t = model.get_twin_q_values(model_out_t,
|
||||
train_batch[SampleBatch.ACTIONS])
|
||||
# q_batchnorm_update_ops = list(
|
||||
# set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
|
||||
# Target q-net(s) evaluation.
|
||||
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
|
||||
policy_tp1_smoothed)
|
||||
|
||||
if twin_q:
|
||||
twin_q_tp1 = policy.target_model.get_twin_q_values(
|
||||
target_model_out_tp1, policy_tp1_smoothed)
|
||||
|
||||
q_t_selected = torch.squeeze(q_t, axis=len(q_t.shape) - 1)
|
||||
if twin_q:
|
||||
twin_q_t_selected = torch.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
|
||||
q_tp1 = torch.min(q_tp1, twin_q_tp1)
|
||||
|
||||
q_tp1_best = torch.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
|
||||
q_tp1_best_masked = \
|
||||
(1.0 - train_batch[SampleBatch.DONES].float()) * \
|
||||
q_tp1_best
|
||||
|
||||
# Compute RHS of bellman equation.
|
||||
q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
|
||||
gamma**n_step * q_tp1_best_masked).detach()
|
||||
|
||||
# Compute the error (potentially clipped).
|
||||
if twin_q:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold) \
|
||||
+ huber_loss(twin_td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * \
|
||||
(torch.pow(td_error, 2.0) + torch.pow(twin_td_error, 2.0))
|
||||
else:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
if use_huber:
|
||||
errors = huber_loss(td_error, huber_threshold)
|
||||
else:
|
||||
errors = 0.5 * torch.pow(td_error, 2.0)
|
||||
|
||||
critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors)
|
||||
actor_loss = -torch.mean(q_t_det_policy)
|
||||
|
||||
# Add l2-regularization if required.
|
||||
if l2_reg is not None:
|
||||
for name, var in policy.model.policy_variables(as_dict=True).items():
|
||||
if "bias" not in name:
|
||||
actor_loss += (l2_reg * l2_loss(var))
|
||||
for name, var in policy.model.q_variables(as_dict=True).items():
|
||||
if "bias" not in name:
|
||||
critic_loss += (l2_reg * l2_loss(var))
|
||||
|
||||
# Model self-supervised losses.
|
||||
if policy.config["use_state_preprocessor"]:
|
||||
# Expand input_dict in case custom_loss' need them.
|
||||
input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
|
||||
input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
|
||||
input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
|
||||
input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
|
||||
[actor_loss, critic_loss] = model.custom_loss(
|
||||
[actor_loss, critic_loss], input_dict)
|
||||
|
||||
# Store values for stats function.
|
||||
policy.actor_loss = actor_loss
|
||||
policy.critic_loss = critic_loss
|
||||
policy.td_error = td_error
|
||||
policy.q_t = q_t
|
||||
|
||||
# Return one loss value (even though we treat them separately in our
|
||||
# 2 optimizers: actor and critic).
|
||||
return policy.actor_loss, policy.critic_loss
|
||||
|
||||
|
||||
def make_ddpg_optimizers(policy, config):
|
||||
# Create separate optimizers for actor & critic losses.
|
||||
policy._actor_optimizer = torch.optim.Adam(
|
||||
params=policy.model.policy_variables(), lr=config["actor_lr"])
|
||||
policy._critic_optimizer = torch.optim.Adam(
|
||||
params=policy.model.q_variables(), lr=config["critic_lr"])
|
||||
return policy._actor_optimizer, policy._critic_optimizer
|
||||
|
||||
|
||||
def apply_gradients_fn(policy):
|
||||
# For policy gradient, update policy net one time v.s.
|
||||
# update critic net `policy_delay` time(s).
|
||||
if policy.global_step % policy.config["policy_delay"] == 0:
|
||||
policy._actor_optimizer.step()
|
||||
|
||||
policy._critic_optimizer.step()
|
||||
|
||||
# Increment global step & apply ops.
|
||||
policy.global_step += 1
|
||||
|
||||
|
||||
def gradients_fn(policy, optimizer, loss):
|
||||
if policy.config["grad_norm_clipping"] is not None:
|
||||
minimize_and_clip(optimizer, policy.config["grad_norm_clipping"])
|
||||
return {}
|
||||
|
||||
|
||||
def build_ddpg_stats(policy, batch):
|
||||
stats = {
|
||||
"actor_loss": policy.actor_loss,
|
||||
"critic_loss": policy.critic_loss,
|
||||
"mean_q": torch.mean(policy.q_t),
|
||||
"max_q": torch.max(policy.q_t),
|
||||
"min_q": torch.min(policy.q_t),
|
||||
"td_error": policy.td_error
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def before_init_fn(policy, obs_space, action_space, config):
|
||||
# Create global step for counting the number of update operations.
|
||||
policy.global_step = 0
|
||||
|
||||
|
||||
class ComputeTDErrorMixin:
|
||||
def __init__(self, loss_fn):
|
||||
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
|
||||
importance_weights):
|
||||
input_dict = self._lazy_tensor_dict({
|
||||
SampleBatch.CUR_OBS: obs_t,
|
||||
SampleBatch.ACTIONS: act_t,
|
||||
SampleBatch.REWARDS: rew_t,
|
||||
SampleBatch.NEXT_OBS: obs_tp1,
|
||||
SampleBatch.DONES: done_mask,
|
||||
PRIO_WEIGHTS: importance_weights,
|
||||
})
|
||||
# Do forward pass on loss to update td errors attribute
|
||||
# (one TD-error value per item in batch to update PR weights).
|
||||
loss_fn(self, self.model, None, input_dict)
|
||||
|
||||
# Self.td_error is set within actor_critic_loss call.
|
||||
return self.td_error
|
||||
|
||||
self.compute_td_error = compute_td_error
|
||||
|
||||
|
||||
class TargetNetworkMixin:
|
||||
def __init__(self):
|
||||
# Hard initial update from Q-net(s) to target Q-net(s).
|
||||
self.update_target(tau=1.0)
|
||||
|
||||
def update_target(self, tau=None):
|
||||
tau = tau or self.config.get("tau")
|
||||
# Update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network, using (soft) tau-synching.
|
||||
# Full sync from Q-model to target Q-model.
|
||||
if tau == 1.0:
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
# Partial (soft) sync using tau-synching.
|
||||
else:
|
||||
model_vars = self.model.variables()
|
||||
target_model_vars = self.target_model.variables()
|
||||
assert len(model_vars) == len(target_model_vars), \
|
||||
(model_vars, target_model_vars)
|
||||
for var, var_target in zip(model_vars, target_model_vars):
|
||||
var_target.data = tau * var.data + \
|
||||
(1.0 - tau) * var_target.data
|
||||
|
||||
|
||||
def setup_late_mixins(policy, obs_space, action_space, config):
|
||||
ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)
|
||||
TargetNetworkMixin.__init__(policy)
|
||||
|
||||
|
||||
DDPGTorchPolicy = build_torch_policy(
|
||||
name="DDPGTorchPolicy",
|
||||
loss_fn=ddpg_actor_critic_loss,
|
||||
get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
|
||||
stats_fn=build_ddpg_stats,
|
||||
postprocess_fn=postprocess_nstep_and_prio,
|
||||
extra_grad_process_fn=gradients_fn,
|
||||
optimizer_fn=make_ddpg_optimizers,
|
||||
before_init=before_init_fn,
|
||||
after_init=setup_late_mixins,
|
||||
action_distribution_fn=get_distribution_inputs_and_class,
|
||||
make_model_and_action_dist=build_ddpg_models_and_action_dist,
|
||||
apply_gradients_fn=apply_gradients_fn,
|
||||
mixins=[
|
||||
TargetNetworkMixin,
|
||||
ComputeTDErrorMixin,
|
||||
])
|
||||
@@ -1,4 +1,5 @@
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -13,3 +14,13 @@ class NoopModel(TFModelV2):
|
||||
@override(TFModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
return tf.cast(input_dict["obs_flat"], tf.float32), state
|
||||
|
||||
|
||||
class TorchNoopModel(TorchModelV2):
|
||||
"""Trivial model that just returns the obs flattened.
|
||||
|
||||
This is the model used if use_state_preprocessor=False."""
|
||||
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
return input_dict["obs_flat"].float(), state
|
||||
|
||||
@@ -60,5 +60,4 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
|
||||
TD3Trainer = DDPGTrainer.with_updates(
|
||||
name="TD3",
|
||||
default_config=TD3_DEFAULT_CONFIG,
|
||||
get_policy_class=None,
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ class TestDDPG(unittest.TestCase):
|
||||
num_iterations = 2
|
||||
|
||||
# Test against all frameworks.
|
||||
for _ in framework_iterator(config, "tf"):
|
||||
for _ in framework_iterator(config, ("torch", "tf")):
|
||||
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
@@ -25,12 +25,13 @@ class TestDDPG(unittest.TestCase):
|
||||
|
||||
def test_ddpg_exploration_and_with_random_prerun(self):
|
||||
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""
|
||||
config = ddpg.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
core_config = ddpg.DEFAULT_CONFIG.copy()
|
||||
core_config["num_workers"] = 0 # Run locally.
|
||||
obs = np.array([0.0, 0.1, -0.1])
|
||||
|
||||
# Test against all frameworks.
|
||||
for _ in framework_iterator(config, "tf"):
|
||||
for _ in framework_iterator(core_config, ("torch", "tf")):
|
||||
config = core_config.copy()
|
||||
# Default OUNoise setup.
|
||||
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v0")
|
||||
# Setting explore=False should always return the same action.
|
||||
|
||||
@@ -125,8 +125,10 @@ class DQNTorchModel(TorchModelV2):
|
||||
"""
|
||||
in_size = int(action_in.shape[1])
|
||||
|
||||
epsilon_in = torch.normal(size=[in_size])
|
||||
epsilon_out = torch.normal(size=[out_size])
|
||||
epsilon_in = torch.normal(
|
||||
mean=torch.zeros([in_size]), std=torch.ones([in_size]))
|
||||
epsilon_out = torch.normal(
|
||||
mean=torch.zeros([out_size]), std=torch.ones([out_size]))
|
||||
epsilon_in = self._f_epsilon(epsilon_in)
|
||||
epsilon_out = self._f_epsilon(epsilon_out)
|
||||
epsilon_w = torch.matmul(
|
||||
|
||||
@@ -156,7 +156,7 @@ def build_q_losses(policy, model, _, train_batch):
|
||||
policy.q_model,
|
||||
train_batch[SampleBatch.CUR_OBS],
|
||||
explore=False,
|
||||
is_training=False)
|
||||
is_training=True)
|
||||
|
||||
# target q network evalution
|
||||
q_tp1 = compute_q_values(
|
||||
@@ -164,7 +164,7 @@ def build_q_losses(policy, model, _, train_batch):
|
||||
policy.target_q_model,
|
||||
train_batch[SampleBatch.NEXT_OBS],
|
||||
explore=False,
|
||||
is_training=False)
|
||||
is_training=True)
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS],
|
||||
@@ -178,7 +178,7 @@ def build_q_losses(policy, model, _, train_batch):
|
||||
policy.q_model,
|
||||
train_batch[SampleBatch.NEXT_OBS],
|
||||
explore=False,
|
||||
is_training=False)
|
||||
is_training=True)
|
||||
q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net,
|
||||
policy.action_space.n)
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.ddpg.ddpg_policy import ComputeTDErrorMixin, \
|
||||
from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \
|
||||
TargetNetworkMixin
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio
|
||||
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
|
||||
|
||||
+4
-3
@@ -1,9 +1,10 @@
|
||||
pendulum-ddpg:
|
||||
pendulum-ddpg-tf:
|
||||
env: Pendulum-v0
|
||||
run: DDPG
|
||||
stop:
|
||||
episode_reward_mean: -900
|
||||
timesteps_total: 100000
|
||||
config:
|
||||
use_huber: True
|
||||
clip_rewards: False
|
||||
use_pytorch: false
|
||||
use_huber: true
|
||||
clip_rewards: false
|
||||
@@ -0,0 +1,10 @@
|
||||
pendulum-ddpg-torch:
|
||||
env: Pendulum-v0
|
||||
run: DDPG
|
||||
stop:
|
||||
episode_reward_mean: -900
|
||||
timesteps_total: 100000
|
||||
config:
|
||||
use_pytorch: true
|
||||
use_huber: true
|
||||
clip_rewards: false
|
||||
@@ -153,10 +153,10 @@ class OrnsteinUhlenbeckNoise(GaussianNoise):
|
||||
ou_new = self.ou_theta * -self.ou_state + \
|
||||
self.ou_sigma * gaussian_sample
|
||||
self.ou_state += ou_new
|
||||
high_low = torch.from_numpy(self.action_space.high -
|
||||
self.action_space.low).to(
|
||||
self.device)
|
||||
noise = scale * self.ou_base_scale * self.ou_state * high_low
|
||||
high_m_low = torch.from_numpy(
|
||||
self.action_space.high - self.action_space.low). \
|
||||
to(self.device)
|
||||
noise = scale * self.ou_base_scale * self.ou_state * high_m_low
|
||||
action = torch.clamp(det_actions + noise,
|
||||
self.action_space.low[0],
|
||||
self.action_space.high[0])
|
||||
|
||||
@@ -24,15 +24,14 @@ def do_test_explorations(run,
|
||||
expected_mean_action=None):
|
||||
"""Calls an Agent's `compute_actions` with different `explore` options."""
|
||||
|
||||
config = config.copy()
|
||||
core_config = config.copy()
|
||||
if run not in [a3c.A3CTrainer]:
|
||||
config["num_workers"] = 0
|
||||
core_config["num_workers"] = 0
|
||||
|
||||
# Test all frameworks.
|
||||
for fw in framework_iterator(config):
|
||||
for fw in framework_iterator(core_config):
|
||||
if fw == "torch" and \
|
||||
run in [ddpg.DDPGTrainer, impala.ImpalaTrainer,
|
||||
sac.SACTrainer, td3.TD3Trainer]:
|
||||
run in [impala.ImpalaTrainer, sac.SACTrainer]:
|
||||
continue
|
||||
elif fw == "eager" and run in [
|
||||
ddpg.DDPGTrainer, sac.SACTrainer, td3.TD3Trainer
|
||||
@@ -44,14 +43,15 @@ def do_test_explorations(run,
|
||||
# Test for both the default Agent's exploration AND the `Random`
|
||||
# exploration class.
|
||||
for exploration in [None, "Random"]:
|
||||
local_config = core_config.copy()
|
||||
if exploration == "Random":
|
||||
# TODO(sven): Random doesn't work for IMPALA yet.
|
||||
if run is impala.ImpalaTrainer:
|
||||
continue
|
||||
config["exploration_config"] = {"type": "Random"}
|
||||
local_config["exploration_config"] = {"type": "Random"}
|
||||
print("exploration={}".format(exploration or "default"))
|
||||
|
||||
trainer = run(config=config, env=env)
|
||||
trainer = run(config=local_config, env=env)
|
||||
|
||||
# Make sure all actions drawn are the same, given same
|
||||
# observations.
|
||||
|
||||
@@ -40,7 +40,7 @@ class TestParameterNoise(unittest.TestCase):
|
||||
|
||||
config = core_config.copy()
|
||||
|
||||
# DQN with ParameterNoise exploration (config["explore"]=True).
|
||||
# Algo with ParameterNoise exploration (config["explore"]=True).
|
||||
# ----
|
||||
config["exploration_config"] = {"type": "ParameterNoise"}
|
||||
config["explore"] = True
|
||||
|
||||
@@ -192,7 +192,9 @@ def get_variable(value,
|
||||
tf_name, initializer=value, dtype=dtype, trainable=trainable)
|
||||
elif framework == "torch" and torch_tensor is True:
|
||||
torch, _ = try_import_torch()
|
||||
var_ = torch.from_numpy(value).to(device)
|
||||
var_ = torch.from_numpy(value)
|
||||
if device:
|
||||
var_ = var_.to(device)
|
||||
var_.requires_grad = trainable
|
||||
return var_
|
||||
# torch or None: Return python primitive.
|
||||
|
||||
+17
-10
@@ -21,6 +21,14 @@ def huber_loss(x, delta=1.0):
|
||||
torch.pow(x, 2.0) * 0.5, delta * (torch.abs(x) - 0.5 * delta))
|
||||
|
||||
|
||||
def l2_loss(x):
|
||||
"""Computes half the L2 norm of a tensor without the sqrt.
|
||||
|
||||
output = sum(x ** 2) / 2
|
||||
"""
|
||||
return torch.sum(torch.pow(x, 2.0)) / 2.0
|
||||
|
||||
|
||||
def reduce_mean_ignore_inf(x, axis):
|
||||
"""Same as torch.mean() but ignores -inf values."""
|
||||
mask = torch.ne(x, float("-inf"))
|
||||
@@ -28,17 +36,16 @@ def reduce_mean_ignore_inf(x, axis):
|
||||
return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)
|
||||
|
||||
|
||||
def minimize_and_clip(optimizer, objective, var_list, clip_val=10):
|
||||
"""Minimized `objective` using `optimizer` w.r.t. variables in
|
||||
`var_list` while ensure the norm of the gradients for each
|
||||
variable is clipped to `clip_val`
|
||||
def minimize_and_clip(optimizer, clip_val=10):
|
||||
"""Clips gradients found in `optimizer.param_groups` to given value.
|
||||
|
||||
Ensures the norm of the gradients for each variable is clipped to
|
||||
`clip_val`
|
||||
"""
|
||||
gradients = optimizer.compute_gradients(objective, var_list=var_list)
|
||||
for i, (grad, var) in enumerate(gradients):
|
||||
if grad is not None:
|
||||
gradients[i] = (torch.nn.utils.clip_grad_norm_(grad, clip_val),
|
||||
var)
|
||||
return gradients
|
||||
for param_group in optimizer.param_groups:
|
||||
for p in param_group["params"]:
|
||||
if p.grad is not None:
|
||||
torch.nn.utils.clip_grad_norm_(p.grad, clip_val)
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen, dtype=None):
|
||||
|
||||
Reference in New Issue
Block a user