diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py index a38eff247..1ede5706c 100644 --- a/rllib/agents/a3c/a3c_tf_policy.py +++ b/rllib/agents/a3c/a3c_tf_policy.py @@ -1,17 +1,34 @@ """Note: Keep in sync with changes to VTraceTFPolicy.""" import ray +from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.evaluation.postprocessing import compute_advantages, \ +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ Postprocessing from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy import LearningRateSchedule +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable +from ray.rllib.utils.tf_ops import explained_variance tf1, tf, tfv = try_import_tf() +def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + + # Stub serving backward compatibility. + deprecation_warning( + old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages", + new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch", + error=False) + + return compute_gae_for_sample_batch(policy, sample_batch, + other_agent_batches, episode) + + class A3CLoss: def __init__(self, action_dist, @@ -45,46 +62,10 @@ def actor_critic_loss(policy, model, dist_class, train_batch): return policy.loss.total_loss -def postprocess_advantages(policy, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch[SampleBatch.DONES][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(policy.num_state_tensors()): - next_state.append(sample_batch["state_out_{}".format(i)][-1]) - last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - return compute_advantages( - sample_batch, last_r, policy.config["gamma"], policy.config["lambda"], - policy.config["use_gae"], policy.config["use_critic"]) - - def add_value_function_fetch(policy): return {SampleBatch.VF_PREDS: policy.model.value_function()} -class ValueNetworkMixin: - def __init__(self): - @make_tf_callable(self.get_session()) - def value(ob, prev_action, prev_reward, *state): - model_out, _ = self.model({ - SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]), - SampleBatch.PREV_ACTIONS: tf.convert_to_tensor([prev_action]), - SampleBatch.PREV_REWARDS: tf.convert_to_tensor([prev_reward]), - "is_training": tf.convert_to_tensor(False), - }, [tf.convert_to_tensor([s]) for s in state], - tf.convert_to_tensor([1])) - return self.model.value_function()[0] - - self._value = value - - def stats(policy, train_batch): return { "cur_lr": tf.cast(policy.cur_lr, tf.float64), @@ -115,7 +96,7 @@ def clip_gradients(policy, optimizer, loss): def setup_mixins(policy, obs_space, action_space, config): - ValueNetworkMixin.__init__(policy) + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) @@ -126,7 +107,7 @@ A3CTFPolicy = build_tf_policy( stats_fn=stats, grad_stats_fn=grad_stats, gradients_fn=clip_gradients, - postprocess_fn=postprocess_advantages, + postprocess_fn=compute_gae_for_sample_batch, extra_action_fetches_fn=add_value_function_fetch, before_loss_init=setup_mixins, mixins=[ValueNetworkMixin, LearningRateSchedule]) diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index 5eb83eb5f..603c37aab 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -1,13 +1,35 @@ +import gym + import ray -from ray.rllib.evaluation.postprocessing import compute_advantages, \ +from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ Postprocessing +from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import apply_grad_clipping +from ray.rllib.utils.typing import TrainerConfigDict torch, nn = try_import_torch() +def add_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + + # Stub serving backward compatibility. + deprecation_warning( + old="rllib.agents.a3c.a3c_torch_policy.add_advantages", + new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch", + error=False) + + return compute_gae_for_sample_batch(policy, sample_batch, + other_agent_batches, episode) + + def actor_critic_loss(policy, model, dist_class, train_batch): logits, _ = model.from_batch(train_batch) values = model.value_function() @@ -36,52 +58,27 @@ def loss_and_entropy_stats(policy, train_batch): } -def add_advantages(policy, - sample_batch, - other_agent_batches=None, - episode=None): - - completed = sample_batch[SampleBatch.DONES][-1] - if completed: - last_r = 0.0 - else: - last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1]) - - return compute_advantages( - sample_batch, last_r, policy.config["gamma"], policy.config["lambda"], - policy.config["use_gae"], policy.config["use_critic"]) - - def model_value_predictions(policy, input_dict, state_batches, model, action_dist): return {SampleBatch.VF_PREDS: model.value_function()} -def apply_grad_clipping(policy, optimizer, loss): - info = {} - if policy.config["grad_clip"]: - for param_group in optimizer.param_groups: - # Make sure we only pass params with grad != None into torch - # clip_grad_norm_. Would fail otherwise. - params = list( - filter(lambda p: p.grad is not None, param_group["params"])) - if params: - grad_gnorm = nn.utils.clip_grad_norm_( - params, policy.config["grad_clip"]) - if isinstance(grad_gnorm, torch.Tensor): - grad_gnorm = grad_gnorm.cpu().numpy() - info["grad_gnorm"] = grad_gnorm - return info - - def torch_optimizer(policy, config): return torch.optim.Adam(policy.model.parameters(), lr=config["lr"]) -class ValueNetworkMixin: - def _value(self, obs): - _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1]) - return self.model.value_function()[0] +def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict) -> None: + """Call all mixin classes' constructors before PPOPolicy initialization. + + Args: + policy (Policy): The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config (TrainerConfigDict): The Policy's config. + """ + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) A3CTorchPolicy = build_policy_class( @@ -90,9 +87,10 @@ A3CTorchPolicy = build_policy_class( get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, loss_fn=actor_critic_loss, stats_fn=loss_and_entropy_stats, - postprocess_fn=add_advantages, + postprocess_fn=compute_gae_for_sample_batch, extra_action_out_fn=model_value_predictions, extra_grad_process_fn=apply_grad_clipping, optimizer_fn=torch_optimizer, + before_loss_init=setup_mixins, mixins=[ValueNetworkMixin], ) diff --git a/rllib/agents/cql/cql_torch_policy.py b/rllib/agents/cql/cql_torch_policy.py index 14812230f..aace3e8b0 100644 --- a/rllib/agents/cql/cql_torch_policy.py +++ b/rllib/agents/cql/cql_torch_policy.py @@ -8,7 +8,6 @@ from typing import Dict, List, Tuple, Type, Union import ray import ray.experimental.tf_utils -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \ validate_spaces from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \ @@ -22,7 +21,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ TrainerConfigDict -from ray.rllib.utils.torch_ops import convert_to_torch_tensor +from ray.rllib.utils.torch_ops import apply_grad_clipping, \ + convert_to_torch_tensor torch, nn = try_import_torch() F = nn.functional diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 79be4cce8..f6c73f912 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -1,7 +1,6 @@ import logging import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \ get_distribution_inputs_and_class, validate_spaces from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ @@ -10,7 +9,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDeterministic from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import huber_loss, l2_loss +from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss torch, nn = try_import_torch() diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index d1d6d4570..874f34ba8 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -301,17 +301,11 @@ def adam_optimizer(policy: Policy, config: TrainerConfigDict def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer", loss: TensorType) -> ModelGradients: - if policy.config["grad_clip"] is not None: - grads_and_vars = minimize_and_clip( - optimizer, - loss, - var_list=policy.q_func_vars, - clip_val=policy.config["grad_clip"]) - else: - grads_and_vars = optimizer.compute_gradients( - loss, var_list=policy.q_func_vars) - grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] - return grads_and_vars + return minimize_and_clip( + optimizer, + loss, + var_list=policy.q_func_vars, + clip_val=policy.config["grad_clip"]) def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 1ed468e1d..b6800dafa 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -4,7 +4,6 @@ from typing import Dict, List, Tuple import gym import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.dqn.dqn_tf_policy import ( PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE, postprocess_nstep_and_prio) from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel @@ -20,9 +19,8 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import (FLOAT_MIN, huber_loss, - reduce_mean_ignore_inf, - softmax_cross_entropy_with_logits) +from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \ + huber_loss, reduce_mean_ignore_inf, softmax_cross_entropy_with_logits from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() diff --git a/rllib/agents/dreamer/dreamer_torch_policy.py b/rllib/agents/dreamer/dreamer_torch_policy.py index d23ad9c30..cc0c1e2a2 100644 --- a/rllib/agents/dreamer/dreamer_torch_policy.py +++ b/rllib/agents/dreamer/dreamer_torch_policy.py @@ -1,11 +1,11 @@ import logging import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.dreamer.utils import FreezeParameters from ray.rllib.models.catalog import ModelCatalog from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import apply_grad_clipping torch, nn = try_import_torch() if torch: diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index 0bcddf4f3..580cf1ad3 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -1,7 +1,6 @@ import logging import ray -from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.agents.trainer_template import build_trainer @@ -160,6 +159,7 @@ def get_policy_class(config): if config["vtrace"]: return VTraceTFPolicy else: + from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy return A3CTFPolicy diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index c6b8c2634..7fc4398dd 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -3,7 +3,6 @@ import logging import numpy as np import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.policy_template import build_policy_class @@ -11,8 +10,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule, \ EntropyCoeffSchedule from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ - sequence_mask +from ray.rllib.utils.torch_ops import apply_grad_clipping, \ + explained_variance, global_norm, sequence_mask torch, nn = try_import_torch() diff --git a/rllib/agents/maml/maml_tf_policy.py b/rllib/agents/maml/maml_tf_policy.py index b9e4d0775..d5fa2568f 100644 --- a/rllib/agents/maml/maml_tf_policy.py +++ b/rllib/agents/maml/maml_tf_policy.py @@ -1,10 +1,10 @@ import logging import ray -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ - vf_preds_fetches, compute_and_clip_gradients, setup_config, \ - ValueNetworkMixin -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.agents.ppo.ppo_tf_policy import vf_preds_fetches, \ + compute_and_clip_gradients, setup_config, ValueNetworkMixin +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ + Postprocessing from ray.rllib.models.utils import get_activation_fn from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy @@ -422,7 +422,7 @@ MAMLTFPolicy = build_tf_policy( stats_fn=maml_stats, optimizer_fn=maml_optimizer_fn, extra_action_fetches_fn=vf_preds_fetches, - postprocess_fn=postprocess_ppo_gae, + postprocess_fn=compute_gae_for_sample_batch, gradients_fn=compute_and_clip_gradients, before_init=setup_config, before_loss_init=setup_mixins, diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py index 478d95ba6..2e0e1e208 100644 --- a/rllib/agents/maml/maml_torch_policy.py +++ b/rllib/agents/maml/maml_torch_policy.py @@ -1,15 +1,15 @@ import logging import ray -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ + Postprocessing from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ - setup_config +from ray.rllib.agents.ppo.ppo_tf_policy import setup_config from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ ValueNetworkMixin -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import apply_grad_clipping torch, nn = try_import_torch() @@ -355,7 +355,7 @@ MAMLTorchPolicy = build_policy_class( stats_fn=maml_stats, optimizer_fn=maml_optimizer_fn, extra_action_out_fn=vf_preds_fetches, - postprocess_fn=postprocess_ppo_gae, + postprocess_fn=compute_gae_for_sample_batch, extra_grad_process_fn=apply_grad_clipping, before_init=setup_config, after_init=setup_mixins, diff --git a/rllib/agents/mbmpo/mbmpo_torch_policy.py b/rllib/agents/mbmpo/mbmpo_torch_policy.py index f43d06ebe..06e65042e 100644 --- a/rllib/agents/mbmpo/mbmpo_torch_policy.py +++ b/rllib/agents/mbmpo/mbmpo_torch_policy.py @@ -3,18 +3,18 @@ import logging from typing import Tuple, Type import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \ maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ - setup_config +from ray.rllib.agents.ppo.ppo_tf_policy import setup_config from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import apply_grad_clipping from ray.rllib.utils.typing import TrainerConfigDict torch, nn = try_import_torch() @@ -85,7 +85,7 @@ MBMPOTorchPolicy = build_policy_class( stats_fn=maml_stats, optimizer_fn=maml_optimizer_fn, extra_action_out_fn=vf_preds_fetches, - postprocess_fn=postprocess_ppo_gae, + postprocess_fn=compute_gae_for_sample_batch, extra_grad_process_fn=apply_grad_clipping, before_init=setup_config, after_init=setup_mixins, diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index c45cfb0ba..96d677e61 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -13,9 +13,9 @@ from typing import Dict, List, Optional, Type, Union from ray.rllib.agents.impala import vtrace_tf as vtrace from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \ clip_gradients, choose_optimizer -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ + Postprocessing from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch @@ -338,8 +338,8 @@ def postprocess_trajectory( SampleBatch: The postprocessed, modified SampleBatch (or a new one). """ if not policy.config["vtrace"]: - sample_batch = postprocess_ppo_gae(policy, sample_batch, - other_agent_batches, episode) + sample_batch = compute_gae_for_sample_batch( + policy, sample_batch, other_agent_batches, episode) # TODO: (sven) remove this del once we have trajectory view API fully in # place. diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py index 461886dbe..81d748331 100644 --- a/rllib/agents/ppo/appo_torch_policy.py +++ b/rllib/agents/ppo/appo_torch_policy.py @@ -10,7 +10,6 @@ import numpy as np import logging from typing import Type -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping import ray.rllib.agents.impala.vtrace_torch as vtrace from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \ choose_optimizer @@ -27,8 +26,8 @@ from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ - sequence_mask +from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance,\ + global_norm, sequence_mask from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 1a8f0be71..57874ba29 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Type, Union import ray from ray.rllib.evaluation.episode import MultiAgentEpisode -from ray.rllib.evaluation.postprocessing import compute_advantages, \ +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ Postprocessing from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution @@ -160,71 +160,6 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]: } -def postprocess_ppo_gae( - policy: Policy, - sample_batch: SampleBatch, - other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, - episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: - """Postprocesses a trajectory and returns the processed trajectory. - - The trajectory contains only data from one episode and from one agent. - - If `config.batch_mode=truncate_episodes` (default), sample_batch may - contain a truncated (at-the-end) episode, in case the - `config.rollout_fragment_length` was reached by the sampler. - - If `config.batch_mode=complete_episodes`, sample_batch will contain - exactly one episode (no matter how long). - New columns can be added to sample_batch and existing ones may be altered. - - Args: - policy (Policy): The Policy used to generate the trajectory - (`sample_batch`) - sample_batch (SampleBatch): The SampleBatch to postprocess. - other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional - dict of AgentIDs mapping to other agents' trajectory data (from the - same episode). NOTE: The other agents use the same policy. - episode (Optional[MultiAgentEpisode]): Optional multi-agent episode - object in which the agents operated. - - Returns: - SampleBatch: The postprocessed, modified SampleBatch (or a new one). - """ - - # Trajectory is actually complete -> last r=0.0. - if sample_batch[SampleBatch.DONES][-1]: - last_r = 0.0 - # Trajectory has been truncated -> last r=VF estimate of last obs. - else: - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - if policy.config["_use_trajectory_view_api"]: - # Create an input dict according to the Model's requirements. - input_dict = policy.model.get_input_dict( - sample_batch, index="last") - last_r = policy._value(**input_dict) - # TODO: (sven) Remove once trajectory view API is all-algo default. - else: - next_state = [] - for i in range(policy.num_state_tensors()): - next_state.append(sample_batch["state_out_{}".format(i)][-1]) - last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - - # Adds the policy logits, VF preds, and advantages to the batch, - # using GAE ("generalized advantage estimation") or not. - batch = compute_advantages( - sample_batch, - last_r, - policy.config["gamma"], - policy.config["lambda"], - use_gae=policy.config["use_gae"], - use_critic=policy.config.get("use_critic", True)) - - return batch - - def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer, loss: TensorType) -> ModelGradients: """Gradients computing function (from loss tensor, using local optimizer). @@ -392,13 +327,29 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space, LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) +def postprocess_ppo_gae( + policy: Policy, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: + + # Stub serving backward compatibility. + deprecation_warning( + old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae", + new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch", + error=False) + + return compute_gae_for_sample_batch(policy, sample_batch, + other_agent_batches, episode) + + # Build a child class of `DynamicTFPolicy`, given the custom functions defined # above. PPOTFPolicy = build_tf_policy( name="PPOTFPolicy", loss_fn=ppo_surrogate_loss, get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, - postprocess_fn=postprocess_ppo_gae, + postprocess_fn=compute_gae_for_sample_batch, stats_fn=kl_and_loss_stats, gradients_fn=compute_and_clip_gradients, extra_action_fetches_fn=vf_preds_fetches, diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index d73f53666..8bd07824b 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -7,10 +7,9 @@ import numpy as np from typing import Dict, List, Type, Union import ray -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ - setup_config -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.agents.ppo.ppo_tf_policy import setup_config +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ + Postprocessing from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy @@ -19,8 +18,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \ LearningRateSchedule from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \ - explained_variance, sequence_mask +from ray.rllib.utils.torch_ops import apply_grad_clipping, \ + convert_to_torch_tensor, explained_variance, sequence_mask from ray.rllib.utils.typing import TensorType, TrainerConfigDict torch, nn = try_import_torch() @@ -279,7 +278,7 @@ PPOTorchPolicy = build_policy_class( loss_fn=ppo_surrogate_loss, stats_fn=kl_and_loss_stats, extra_action_out_fn=vf_preds_fetches, - postprocess_fn=postprocess_ppo_gae, + postprocess_fn=compute_gae_for_sample_batch, extra_grad_process_fn=apply_grad_clipping, before_init=setup_config, before_loss_init=setup_mixins, diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index cb304c57e..de0bc90f6 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -5,11 +5,12 @@ import unittest import ray from ray.rllib.agents.callbacks import DefaultCallbacks import ray.rllib.agents.ppo as ppo -from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \ - postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf -from ray.rllib.agents.ppo.ppo_torch_policy import postprocess_ppo_gae as \ - postprocess_ppo_gae_torch, ppo_surrogate_loss as ppo_surrogate_loss_torch -from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \ + ppo_surrogate_loss_tf +from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \ + ppo_surrogate_loss_torch +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \ + Postprocessing from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_action_dist import TorchCategorical @@ -212,11 +213,8 @@ class TestPPO(unittest.TestCase): # Check the variable is initially zero. init_std = get_value() assert init_std == 0.0, init_std - - if fw in ["tf2", "tf", "tfe"]: - batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy()) - else: - batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH.copy()) + batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy()) + if fw == "torch": batch = policy._lazy_tensor_dict(batch) policy.learn_on_batch(batch) @@ -255,11 +253,9 @@ class TestPPO(unittest.TestCase): # to train_batch dict. # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] = # [0.50005, -0.505, 0.5] - if fw in ["tf2", "tf", "tfe"]: - train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy()) - else: - train_batch = postprocess_ppo_gae_torch( - policy, FAKE_BATCH.copy()) + train_batch = compute_gae_for_sample_batch(policy, + FAKE_BATCH.copy()) + if fw == "torch": train_batch = policy._lazy_tensor_dict(train_batch) # Check Advantage values. diff --git a/rllib/agents/sac/sac.py b/rllib/agents/sac/sac.py index daf66f88a..5c476248c 100644 --- a/rllib/agents/sac/sac.py +++ b/rllib/agents/sac/sac.py @@ -73,8 +73,7 @@ DEFAULT_CONFIG = with_common_config({ "timesteps_per_iteration": 100, # === Replay buffer === - # Size of the replay buffer. Note that if async_updates is set, then - # each worker will have a replay buffer of this size. + # Size of the replay buffer (in time steps). "buffer_size": int(1e6), # If True prioritized replay buffer will be used. "prioritized_replay": False, @@ -104,9 +103,7 @@ DEFAULT_CONFIG = with_common_config({ # Update the replay buffer with this many samples at once. Note that this # setting applies per-worker if num_workers > 1. "rollout_fragment_length": 1, - # Size of a batched sampled from replay buffer for training. Note that - # if async_updates is set, then each worker returns gradients for a - # batch of this size. + # Size of a batched sampled from replay buffer for training. "train_batch_size": 256, # Update the target network every `target_network_update_freq` steps. "target_network_update_freq": 0, diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 49bc65557..d000e1839 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -9,7 +9,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union import ray import ray.experimental.tf_utils -from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.sac.sac_tf_policy import build_sac_model, \ postprocess_trajectory, validate_spaces from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS @@ -23,7 +22,7 @@ from ray.rllib.models.torch.torch_action_dist import ( TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.spaces.simplex import Simplex -from ray.rllib.utils.torch_ops import huber_loss +from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ TrainerConfigDict diff --git a/rllib/contrib/maddpg/maddpg_policy.py b/rllib/contrib/maddpg/maddpg_policy.py index 913c3f530..6615498ec 100644 --- a/rllib/contrib/maddpg/maddpg_policy.py +++ b/rllib/contrib/maddpg/maddpg_policy.py @@ -265,17 +265,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy): @override(TFPolicy) def gradients(self, optimizer, loss): - if self.config["grad_norm_clipping"] is not None: - self.gvs = { - k: minimize_and_clip(optimizer, self.losses[k], self.vars[k], - self.config["grad_norm_clipping"]) - for k, optimizer in self.optimizers.items() - } - else: - self.gvs = { - k: optimizer.compute_gradients(self.losses[k], self.vars[k]) - for k, optimizer in self.optimizers.items() - } + self.gvs = { + k: minimize_and_clip(optimizer, self.losses[k], self.vars[k], + self.config["grad_norm_clipping"]) + for k, optimizer in self.optimizers.items() + } return self.gvs["critic"] + self.gvs["actor"] @override(TFPolicy) diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index 0cb25d5c7..7d1801cf6 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -1,22 +1,12 @@ import numpy as np import scipy.signal +from typing import Dict, Optional + +from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI - - -def discount_cumsum(x: np.ndarray, gamma: float) -> float: - """Calculates the discounted cumulative sum over a reward sequence `x`. - - y[t] - discount*y[t+1] = x[t] - reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t] - - Args: - gamma (float): The discount factor gamma. - - Returns: - float: The discounted cumulative sum over the reward sequence `x`. - """ - return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1] +from ray.rllib.utils.typing import AgentID class Postprocessing: @@ -89,3 +79,83 @@ def compute_advantages(rollout: SampleBatch, Postprocessing.ADVANTAGES].astype(np.float32) return rollout + + +def compute_gae_for_sample_batch( + policy: Policy, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: + """Adds GAE (generalized advantage estimations) to a trajectory. + + The trajectory contains only data from one episode and from one agent. + - If `config.batch_mode=truncate_episodes` (default), sample_batch may + contain a truncated (at-the-end) episode, in case the + `config.rollout_fragment_length` was reached by the sampler. + - If `config.batch_mode=complete_episodes`, sample_batch will contain + exactly one episode (no matter how long). + New columns can be added to sample_batch and existing ones may be altered. + + Args: + policy (Policy): The Policy used to generate the trajectory + (`sample_batch`) + sample_batch (SampleBatch): The SampleBatch to postprocess. + other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional + dict of AgentIDs mapping to other agents' trajectory data (from the + same episode). NOTE: The other agents use the same policy. + episode (Optional[MultiAgentEpisode]): Optional multi-agent episode + object in which the agents operated. + + Returns: + SampleBatch: The postprocessed, modified SampleBatch (or a new one). + """ + + # Trajectory is actually complete -> last r=0.0. + if sample_batch[SampleBatch.DONES][-1]: + last_r = 0.0 + # Trajectory has been truncated -> last r=VF estimate of last obs. + else: + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + if policy.config.get("_use_trajectory_view_api"): + # Create an input dict according to the Model's requirements. + input_dict = policy.model.get_input_dict( + sample_batch, index="last") + last_r = policy._value(**input_dict) + # TODO: (sven) Remove once trajectory view API is all-algo default. + else: + next_state = [] + for i in range(policy.num_state_tensors()): + next_state.append(sample_batch["state_out_{}".format(i)][-1]) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) + + # Adds the policy logits, VF preds, and advantages to the batch, + # using GAE ("generalized advantage estimation") or not. + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"], + use_critic=policy.config.get("use_critic", True)) + + return batch + + +def discount_cumsum(x: np.ndarray, gamma: float) -> float: + """Calculates the discounted cumulative sum over a reward sequence `x`. + + y[t] - discount*y[t+1] = x[t] + reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t] + + Args: + gamma (float): The discount factor gamma. + + Returns: + float: The discounted cumulative sum over the reward sequence `x`. + """ + return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1] diff --git a/rllib/tuned_examples/sac/mspacman-sac.yaml b/rllib/tuned_examples/sac/mspacman-sac.yaml index c47e384db..50883b114 100644 --- a/rllib/tuned_examples/sac/mspacman-sac.yaml +++ b/rllib/tuned_examples/sac/mspacman-sac.yaml @@ -14,11 +14,11 @@ mspacman-sac-tf: # state-preprocessor=Our default Atari Conv2D-net. use_state_preprocessor: true Q_model: - hidden_activation: relu - hidden_layer_sizes: [512] + fcnet_hiddens: [512] + fcnet_activation: relu policy_model: - hidden_activation: relu - hidden_layer_sizes: [512] + fcnet_hiddens: [512] + fcnet_activation: relu # Do hard syncs. # Soft-syncs seem to work less reliably for discrete action spaces. tau: 1.0 diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index f74926aaa..5b75bc5e1 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -92,7 +92,7 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0): variable is clipped to `clip_val` """ # Accidentally passing values < 0.0 will break all gradients. - assert clip_val > 0.0, clip_val + assert clip_val is None or clip_val > 0.0, clip_val if tf.executing_eagerly(): tape = optimizer.tape @@ -102,10 +102,8 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0): grads_and_vars = optimizer.compute_gradients( objective, var_list=var_list) - for i, (grad, var) in enumerate(grads_and_vars): - if grad is not None: - grads_and_vars[i] = (tf.clip_by_norm(grad, clip_val), var) - return grads_and_vars + return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v) + for (g, v) in grads_and_vars if g is not None] def make_tf_callable(session_or_none, dynamic_shape=False): diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index ce6c86a16..487b8c246 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -14,6 +14,30 @@ FLOAT_MIN = -3.4e38 FLOAT_MAX = 3.4e38 +def apply_grad_clipping(policy, optimizer, loss): + """Applies gradient clipping to already computed grads inside `optimizer`. + + Args: + policy (TorchPolicy): The TorchPolicy, which calculated `loss`. + optimizer (torch.optim.Optimizer): A local torch optimizer object. + loss (torch.Tensor): The torch loss tensor. + """ + info = {} + if policy.config["grad_clip"]: + for param_group in optimizer.param_groups: + # Make sure we only pass params with grad != None into torch + # clip_grad_norm_. Would fail otherwise. + params = list( + filter(lambda p: p.grad is not None, param_group["params"])) + if params: + grad_gnorm = nn.utils.clip_grad_norm_( + params, policy.config["grad_clip"]) + if isinstance(grad_gnorm, torch.Tensor): + grad_gnorm = grad_gnorm.cpu().numpy() + info["grad_gnorm"] = grad_gnorm + return info + + def atanh(x): return 0.5 * torch.log((1 + x) / (1 - x))