diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 92de766aa..f027db6f8 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -35,7 +35,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi- .. _`+LSTM auto-wrapping`: rllib-models.html#built-in-models .. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces .. _`+RNN`: rllib-models.html#recurrent-models -.. _`+Transformer`: rllib-models.html#attention-networks +.. _`+Transformer`: rllib-models.html#attention-networks-transformers .. _`A2C, A3C`: rllib-algorithms.html#a3c .. _`APEX-DQN`: rllib-algorithms.html#apex .. _`APEX-DDPG`: rllib-algorithms.html#apex @@ -304,16 +304,22 @@ SpaceInvaders 650 1001 1025 Policy Gradients ---------------- -|pytorch| |tensorflow| -`[paper] `__ `[implementation] `__ We include a vanilla policy gradients implementation as an example algorithm. +|pytorch| |tensorflow| An `implementation `__ of a vanilla policy gradient algorithm for TensorFlow and PyTorch. + +**Papers**: +`[1] - Policy Gradient Methods for Reinforcement Learning with Function Approximation. `__ +and +`[2] - Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning. `__ + .. figure:: a2c-arch.svg Policy gradients architecture (same as A2C) -Tuned examples: `CartPole-v0 `__ +**Tuned examples**: `CartPole-v0 `__ -**PG-specific configs** (see also `common configs `__): +**PG-specific configs**: The following updates will overwrite/be added to the +(base) Trainer config in `rllib/agents/trainer.py `__ (*COMMON_CONFIG* dict): .. literalinclude:: ../../rllib/agents/pg/pg.py :language: python diff --git a/rllib/__init__.py b/rllib/__init__.py index 9eff863f3..d27194f69 100644 --- a/rllib/__init__.py +++ b/rllib/__init__.py @@ -10,7 +10,7 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy - +from ray.rllib.policy.torch_policy import TorchPolicy from ray.tune.registry import register_trainable @@ -60,6 +60,7 @@ _register_all() __all__ = [ "Policy", "TFPolicy", + "TorchPolicy", "RolloutWorker", "SampleBatch", "BaseEnv", diff --git a/rllib/agents/pg/README.md b/rllib/agents/pg/README.md new file mode 100644 index 000000000..407435ce5 --- /dev/null +++ b/rllib/agents/pg/README.md @@ -0,0 +1,8 @@ +Policy Gradient (PG) +==================== + +An implementation of a vanilla policy gradient algorithm for TensorFlow and PyTorch. + +**[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#pg)** + +**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py)** diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py index 4f14aa48f..2206c26e3 100644 --- a/rllib/agents/pg/pg.py +++ b/rllib/agents/pg/pg.py @@ -1,29 +1,57 @@ +""" +Policy Gradient (PG) +==================== + +This file defines the distributed Trainer class for policy gradients. +See `pg_[tf|torch]_policy.py` for the definition of the policy loss. + +Detailed documentation: https://docs.ray.io/en/latest/rllib-algorithms.html#pg +""" + +from typing import Optional, Type + from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy +from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.typing import TrainerConfigDict # yapf: disable # __sphinx_doc_begin__ + +# Adds the following updates to the (base) `Trainer` config in +# rllib/agents/trainer.py (`COMMON_CONFIG` dict). DEFAULT_CONFIG = with_common_config({ # No remote workers by default. "num_workers": 0, # Learning rate. "lr": 0.0004, }) + # __sphinx_doc_end__ # yapf: enable -def get_policy_class(config): +def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]: + """Policy class picker function. Class is chosen based on DL-framework. + + Args: + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + Optional[Type[Policy]]: The Policy class to use with PGTrainer. + If None, use `default_policy` provided in build_trainer(). + """ if config["framework"] == "torch": - from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy return PGTorchPolicy - else: - return PGTFPolicy +# Build a child class of `Trainer`, which uses the framework specific Policy +# determined in `get_policy_class()` above. PGTrainer = build_trainer( name="PG", default_config=DEFAULT_CONFIG, default_policy=PGTFPolicy, - get_policy_class=get_policy_class) + get_policy_class=get_policy_class, +) diff --git a/rllib/agents/pg/pg_tf_policy.py b/rllib/agents/pg/pg_tf_policy.py index 3b1aa8374..fa5e781ba 100644 --- a/rllib/agents/pg/pg_tf_policy.py +++ b/rllib/agents/pg/pg_tf_policy.py @@ -1,35 +1,54 @@ +""" +TensorFlow policy class used for PG. +""" + +from typing import List, Type, Union + import ray -from ray.rllib.evaluation.postprocessing import Postprocessing, \ - compute_advantages +from ray.rllib.agents.pg.utils import post_process_advantages +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy import Policy from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType tf1, tf, tfv = try_import_tf() -def post_process_advantages(policy, - sample_batch, - other_agent_batches=None, - episode=None): - """This adds the "advantages" column to the sample train_batch.""" - return compute_advantages( - sample_batch, - 0.0, - policy.config["gamma"], - use_gae=False, - use_critic=False) +def pg_tf_loss( + policy: Policy, model: ModelV2, dist_class: Type[ActionDistribution], + train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: + """The basic policy gradients loss function. + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]: The action distr. class. + train_batch (SampleBatch): The training data. -def pg_tf_loss(policy, model, dist_class, train_batch): - """The basic policy gradients loss.""" - logits, _ = model.from_batch(train_batch) - action_dist = dist_class(logits, model) + Returns: + Union[TensorType, List[TensorType]]: A single loss tensor or a list + of loss tensors. + """ + # Pass the training data through our model to get distribution parameters. + dist_inputs, _ = model.from_batch(train_batch) + + # Create an action distribution object. + action_dist = dist_class(dist_inputs, model) + + # Calculate the vanilla PG loss based on: + # L = -E[ log(pi(a|s)) * A] return -tf.reduce_mean( action_dist.logp(train_batch[SampleBatch.ACTIONS]) * tf.cast( train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32)) +# Build a child class of `TFPolicy`, given the extra options: +# - trajectory post-processing function (to calculate advantages) +# - PG loss function PGTFPolicy = build_tf_policy( name="PGTFPolicy", get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py index f63b76cb1..93cb2f4ac 100644 --- a/rllib/agents/pg/pg_torch_policy.py +++ b/rllib/agents/pg/pg_torch_policy.py @@ -1,31 +1,77 @@ +""" +PyTorch policy class used for PG. +""" + +from typing import Dict, List, Type, Union + import ray -from ray.rllib.agents.pg.pg_tf_policy import post_process_advantages +from ray.rllib.agents.pg.utils import post_process_advantages from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy_template import build_torch_policy from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType torch, _ = try_import_torch() -def pg_torch_loss(policy, model, dist_class, train_batch): - """The basic policy gradients loss.""" - logits, _ = model.from_batch(train_batch) - action_dist = dist_class(logits, model) +def pg_torch_loss( + policy: Policy, model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: + """The basic policy gradients loss function. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + dist_class (Type[ActionDistribution]: The action distr. class. + train_batch (SampleBatch): The training data. + + Returns: + Union[TensorType, List[TensorType]]: A single loss tensor or a list + of loss tensors. + """ + # Pass the training data through our model to get distribution parameters. + dist_inputs, _ = model.from_batch(train_batch) + + # Create an action distribution object. + action_dist = dist_class(dist_inputs, model) + + # Calculate the vanilla PG loss based on: + # L = -E[ log(pi(a|s)) * A] log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS]) - # Save the error in the policy object. - # policy.pi_err = -train_batch[Postprocessing.ADVANTAGES].dot( - # log_probs.reshape(-1)) / len(log_probs) + + # Save the loss in the policy object for the stats_fn below. policy.pi_err = -torch.mean( log_probs * train_batch[Postprocessing.ADVANTAGES]) + return policy.pi_err -def pg_loss_stats(policy, train_batch): - """ The error is recorded when computing the loss.""" - return {"policy_loss": policy.pi_err.item()} +def pg_loss_stats(policy: Policy, + train_batch: SampleBatch) -> Dict[str, TensorType]: + """Returns the calculated loss in a stats dict. + + Args: + policy (Policy): The Policy object. + train_batch (SampleBatch): The data used for training. + + Returns: + Dict[str, TensorType]: The stats dict. + """ + + return { + # `pi_err` (the loss) is stored inside `pg_torch_loss()`. + "policy_loss": policy.pi_err.item(), + } +# Build a child class of `TFPolicy`, given the extra options: +# - trajectory post-processing function (to calculate advantages) +# - PG loss function PGTorchPolicy = build_torch_policy( name="PGTorchPolicy", get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, diff --git a/rllib/agents/pg/utils.py b/rllib/agents/pg/utils.py new file mode 100644 index 000000000..309f8003d --- /dev/null +++ b/rllib/agents/pg/utils.py @@ -0,0 +1,36 @@ +from typing import List, Optional + +from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch + + +def post_process_advantages( + policy: Policy, + sample_batch: SampleBatch, + other_agent_batches: Optional[List[SampleBatch]] = None, + episode: Optional[MultiAgentEpisode] = None) -> SampleBatch: + """Adds the "advantages" column to `sample_batch`. + + Args: + policy (Policy): The Policy object to do post-processing for. + sample_batch (SampleBatch): The actual sample batch to post-process. + other_agent_batches (Optional[List[SampleBatch]]): Optional list of + other agents' SampleBatch objects. + episode (MultiAgentEpisode): The multi-agent episode object, from which + `sample_batch` was generated. + + Returns: + SampleBatch: The SampleBatch enhanced by the added ADVANTAGES field. + """ + + # Calculates advantage values based on the rewards in the sample batch. + # The value of the last observation is assumed to be 0.0 (no value function + # estimation at the end of the sampled chunk). + return compute_advantages( + rollout=sample_batch, + last_r=0.0, + gamma=policy.config["gamma"], + use_gae=False, + use_critic=False) diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 5179836c7..4c014faa9 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -1,62 +1,65 @@ from ray.rllib.agents.impala.impala import validate_config from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy from ray.rllib.agents.ppo.ppo import UpdateKL -from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, _get_shared_metrics # yapf: disable # __sphinx_doc_begin__ -DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, { - # Whether to use V-trace weighted advantages. If false, PPO GAE advantages - # will be used instead. - "vtrace": False, +DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs( + impala.DEFAULT_CONFIG, # See keys in impala.py, which are also supported. + { + # Whether to use V-trace weighted advantages. If false, PPO GAE + # advantages will be used instead. + "vtrace": False, - # == These two options only apply if vtrace: False == - # Should use a critic as a baseline (otherwise don't use value baseline; - # required for using GAE). - "use_critic": True, - # If true, use the Generalized Advantage Estimator (GAE) - # with a value function, see https://arxiv.org/pdf/1506.02438.pdf. - "use_gae": True, - # GAE(lambda) parameter - "lambda": 1.0, + # == These two options only apply if vtrace: False == + # Should use a critic as a baseline (otherwise don't use value + # baseline; required for using GAE). + "use_critic": True, + # If true, use the Generalized Advantage Estimator (GAE) + # with a value function, see https://arxiv.org/pdf/1506.02438.pdf. + "use_gae": True, + # GAE(lambda) parameter + "lambda": 1.0, - # == PPO surrogate loss options == - "clip_param": 0.4, + # == PPO surrogate loss options == + "clip_param": 0.4, - # == PPO KL Loss options == - "use_kl_loss": False, - "kl_coeff": 1.0, - "kl_target": 0.01, + # == PPO KL Loss options == + "use_kl_loss": False, + "kl_coeff": 1.0, + "kl_target": 0.01, - # == IMPALA optimizer params (see documentation in impala.py) == - "rollout_fragment_length": 50, - "train_batch_size": 500, - "min_iter_time_s": 10, - "num_workers": 2, - "num_gpus": 0, - "num_data_loader_buffers": 1, - "minibatch_buffer_size": 1, - "num_sgd_iter": 1, - "replay_proportion": 0.0, - "replay_buffer_num_slots": 100, - "learner_queue_size": 16, - "learner_queue_timeout": 300, - "max_sample_requests_in_flight_per_worker": 2, - "broadcast_interval": 1, - "grad_clip": 40.0, - "opt_type": "adam", - "lr": 0.0005, - "lr_schedule": None, - "decay": 0.99, - "momentum": 0.0, - "epsilon": 0.1, - "vf_loss_coeff": 0.5, - "entropy_coeff": 0.01, - "entropy_coeff_schedule": None, -}) + # == IMPALA optimizer params (see documentation in impala.py) == + "rollout_fragment_length": 50, + "train_batch_size": 500, + "min_iter_time_s": 10, + "num_workers": 2, + "num_gpus": 0, + "num_data_loader_buffers": 1, + "minibatch_buffer_size": 1, + "num_sgd_iter": 1, + "replay_proportion": 0.0, + "replay_buffer_num_slots": 100, + "learner_queue_size": 16, + "learner_queue_timeout": 300, + "max_sample_requests_in_flight_per_worker": 2, + "broadcast_interval": 1, + "grad_clip": 40.0, + "opt_type": "adam", + "lr": 0.0005, + "lr_schedule": None, + "decay": 0.99, + "momentum": 0.0, + "epsilon": 0.1, + "vf_loss_coeff": 0.5, + "entropy_coeff": 0.01, + "entropy_coeff_schedule": None, + }, + _allow_unknown_configs=True, +) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/ppo/ddppo.py b/rllib/agents/ppo/ddppo.py index cb16b9b3b..182c1234e 100644 --- a/rllib/agents/ppo/ddppo.py +++ b/rllib/agents/ppo/ddppo.py @@ -19,7 +19,6 @@ import time import ray from ray.rllib.agents.ppo import ppo -from ray.rllib.agents.trainer import with_base_config from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ @@ -32,31 +31,42 @@ logger = logging.getLogger(__name__) # yapf: disable # __sphinx_doc_begin__ -DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, { - # During the sampling phase, each rollout worker will collect a batch - # `rollout_fragment_length * num_envs_per_worker` steps in size. - "rollout_fragment_length": 100, - # Vectorize the env (should enable by default since each worker has a GPU). - "num_envs_per_worker": 5, - # During the SGD phase, workers iterate over minibatches of this size. - # The effective minibatch size will be `sgd_minibatch_size * num_workers`. - "sgd_minibatch_size": 50, - # Number of SGD epochs per optimization round. - "num_sgd_iter": 10, - # Download weights between each training step. This adds a bit of overhead - # but allows the user to access the weights from the trainer. - "keep_local_weights_in_sync": True, +DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs( + ppo.DEFAULT_CONFIG, + { + # During the sampling phase, each rollout worker will collect a batch + # `rollout_fragment_length * num_envs_per_worker` steps in size. + "rollout_fragment_length": 100, + # Vectorize the env (should enable by default since each worker has + # a GPU). + "num_envs_per_worker": 5, + # During the SGD phase, workers iterate over minibatches of this size. + # The effective minibatch size will be: + # `sgd_minibatch_size * num_workers`. + "sgd_minibatch_size": 50, + # Number of SGD epochs per optimization round. + "num_sgd_iter": 10, + # Download weights between each training step. This adds a bit of + # overhead but allows the user to access the weights from the trainer. + "keep_local_weights_in_sync": True, - # *** WARNING: configs below are DDPPO overrides over PPO; you - # shouldn't need to adjust them. *** - "framework": "torch", # DDPPO requires PyTorch distributed. - "num_gpus": 0, # Learning is no longer done on the driver process, so - # giving GPUs to the driver does not make sense! - "num_gpus_per_worker": 1, # Each rollout worker gets a GPU. - "truncate_episodes": True, # Require evenly sized batches. Otherwise, - # collective allreduce could fail. - "train_batch_size": -1, # This is auto set based on sample batch size. -}) + # *** WARNING: configs below are DDPPO overrides over PPO; you + # shouldn't need to adjust them. *** + # DDPPO requires PyTorch distributed. + "framework": "torch", + # Learning is no longer done on the driver process, so + # giving GPUs to the driver does not make sense! + "num_gpus": 0, + # Each rollout worker gets a GPU. + "num_gpus_per_worker": 1, + # Require evenly sized batches. Otherwise, + # collective allreduce could fail. + "truncate_episodes": True, + # This is auto set based on sample batch size. + "train_batch_size": -1, + }, + _allow_unknown_configs=True, +) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index dc7c73ce5..926802c5d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -7,7 +7,7 @@ import os import pickle import time import tempfile -from typing import Callable, List, Dict, Union +from typing import Callable, Dict, List, Optional, Type, Union import ray from ray.exceptions import RayError @@ -390,19 +390,18 @@ COMMON_CONFIG: TrainerConfigDict = { @DeveloperAPI def with_common_config( extra_config: PartialTrainerConfigDict) -> TrainerConfigDict: - """Returns the given config dict merged with common agent confs.""" + """Returns the given config dict merged with common agent confs. - return with_base_config(COMMON_CONFIG, extra_config) + Args: + extra_config (PartialTrainerConfigDict): A user defined partial config + which will get merged with COMMON_CONFIG and returned. - -def with_base_config( - base_config: TrainerConfigDict, - extra_config: PartialTrainerConfigDict) -> TrainerConfigDict: - """Returns the given config dict merged with a base agent conf.""" - - config = copy.deepcopy(base_config) - config.update(extra_config) - return config + Returns: + TrainerConfigDict: The merged config dict resulting of COMMON_CONFIG + plus `extra_config`. + """ + return Trainer.merge_trainer_configs( + COMMON_CONFIG, extra_config, _allow_unknown_configs=True) @PublicAPI @@ -664,7 +663,7 @@ class Trainer(Trainable): self.evaluation_workers = self._make_workers( self.env_creator, - self._policy, + self._policy_class, merge_dicts(self.config, extra_config), num_workers=self.config["evaluation_num_workers"]) self.evaluation_metrics = {} @@ -691,7 +690,7 @@ class Trainer(Trainable): @DeveloperAPI def _make_workers(self, env_creator: Callable[[EnvContext], EnvType], - policy: type, config: TrainerConfigDict, + policy_class: Type[Policy], config: TrainerConfigDict, num_workers: int) -> WorkerSet: """Default factory method for a WorkerSet running under this Trainer. @@ -701,9 +700,9 @@ class Trainer(Trainable): Args: env_creator (callable): A function that return and Env given an env config. - policy (class): The Policy class to use for creating the policies - of the workers. - config (dict): The Trainer's config. + policy (Type[Policy]): The Policy class to use for creating the + policies of the workers. + config (TrainerConfigDict): The Trainer's config. num_workers (int): Number of remote rollout workers to create. 0 for local only. @@ -711,9 +710,9 @@ class Trainer(Trainable): WorkerSet: The created WorkerSet. """ return WorkerSet( - env_creator, - policy, - config, + env_creator=env_creator, + policy_class=policy_class, + trainer_config=config, num_workers=num_workers, logdir=self.logdir) @@ -1044,8 +1043,11 @@ class Trainer(Trainable): "The config of this agent is: {}".format(config)) @classmethod - def merge_trainer_configs(cls, config1: TrainerConfigDict, - config2: PartialTrainerConfigDict) -> dict: + def merge_trainer_configs(cls, + config1: TrainerConfigDict, + config2: PartialTrainerConfigDict, + _allow_unknown_configs: Optional[bool] = None + ) -> TrainerConfigDict: config1 = copy.deepcopy(config1) # Error if trainer default has deprecated value. if config1["sample_batch_size"] != DEPRECATED_VALUE: @@ -1067,7 +1069,9 @@ class Trainer(Trainable): legacy_callbacks_dict=legacy_callbacks_dict) config2["callbacks"] = make_callbacks - return deep_update(config1, config2, cls._allow_unknown_configs, + if _allow_unknown_configs is None: + _allow_unknown_configs = cls._allow_unknown_configs + return deep_update(config1, config2, _allow_unknown_configs, cls._allow_unknown_subkeys, cls._override_all_subkeys_if_type_changes) diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index c835b776f..84c6f5e9c 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, List, Iterable +from typing import Callable, Iterable, List, Optional, Type from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.evaluation.worker_set import WorkerSet @@ -9,7 +9,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.policy import Policy from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.typing import TrainerConfigDict, ResultDict +from ray.rllib.utils.typing import EnvConfigDict, EnvType, ResultDict, \ + TrainerConfigDict logger = logging.getLogger(__name__) @@ -33,17 +34,19 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict): @DeveloperAPI def build_trainer( name: str, - default_policy: Optional[Policy], *, default_config: TrainerConfigDict = None, validate_config: Callable[[TrainerConfigDict], None] = None, - get_policy_class: Callable[[TrainerConfigDict], Policy] = None, - before_init: Callable[[Trainer], None] = None, - after_init: Callable[[Trainer], None] = None, - before_evaluate_fn: Callable[[Trainer], None] = None, - mixins: List[type] = None, - execution_plan: Callable[[WorkerSet, TrainerConfigDict], Iterable[ - ResultDict]] = default_execution_plan): + default_policy: Optional[Type[Policy]] = None, + get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[ + Policy]]]] = None, + before_init: Optional[Callable[[Trainer], None]] = None, + after_init: Optional[Callable[[Trainer], None]] = None, + before_evaluate_fn: Optional[Callable[[Trainer], None]] = None, + mixins: Optional[List[type]] = None, + execution_plan: Optional[Callable[[ + WorkerSet, TrainerConfigDict + ], Iterable[ResultDict]]] = default_execution_plan): """Helper function for defining a custom trainer. Functions will be run in this order to initialize the trainer: @@ -51,22 +54,30 @@ def build_trainer( 2. Worker setup: before_init, execution_plan 3. Post setup: after_init - Arguments: + Args: name (str): name of the trainer (e.g., "PPO") - default_policy (cls): the default Policy class to use - default_config (dict): The default config dict of the algorithm, - otherwise uses the Trainer default config. + default_config (TrainerConfigDict): The default config dict + of the algorithm, otherwise uses the Trainer default config. validate_config (Optional[callable]): Optional callable that takes the config to check for correctness. It may mutate the config as needed. - get_policy_class (Optional[callable]): Optional callable that takes a - config and returns the policy class to override the default with. - before_init (Optional[callable]): Optional callable to run at the start - of trainer init that takes the trainer instance as argument. - after_init (Optional[callable]): Optional callable to run at the end of - trainer init that takes the trainer instance as argument. - before_evaluate_fn (Optional[callable]): callback to run before - evaluation. This takes the trainer instance as argument. + default_policy (Optional[Type[Policy]]): The default Policy class to + use. + get_policy_class (Optional[Callable[ + TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable + that takes a config and returns the policy class or None. If None + is returned, will use `default_policy` (which must be provided + then). + before_init (Optional[Callable[[Trainer], None]]): Optional callable to + run before anything is constructed inside Trainer (Workers with + Policies, execution plan, etc..). Takes the Trainer instance as + argument. + after_init (Optional[Callable[[Trainer], None]]): Optional callable to + run at the end of trainer init (after all Workers and the exec. + plan have been constructed). Takes the Trainer instance as + argument. + before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to + run before evaluation. This takes the trainer instance as argument. mixins (list): list of any class mixins for the returned trainer class. These mixins will be applied in order and will have higher precedence than the Trainer class. @@ -82,26 +93,37 @@ def build_trainer( class trainer_cls(base): _name = name _default_config = default_config or COMMON_CONFIG - _policy = default_policy + _policy_class = default_policy def __init__(self, config=None, env=None, logger_creator=None): Trainer.__init__(self, config, env, logger_creator) - def _init(self, config, env_creator): + def _init(self, config: TrainerConfigDict, + env_creator: Callable[[EnvConfigDict], EnvType]): + # Validate config via custom validation function. if validate_config: validate_config(config) if get_policy_class is None: - self._policy = default_policy + if not config["multiagent"]["policies"]: + assert default_policy is not None + self._policy_class = default_policy else: - self._policy = get_policy_class(config) + self._policy_class = get_policy_class(config) + if self._policy_class is None: + assert default_policy is not None + self._policy_class = default_policy + if before_init: before_init(self) + # Creating all workers (excluding evaluation workers). - self.workers = self._make_workers( - env_creator, self._policy, config, self.config["num_workers"]) + self.workers = self._make_workers(env_creator, self._policy_class, + config, + self.config["num_workers"]) self.execution_plan = execution_plan self.train_exec_impl = execution_plan(self.workers, config) + if after_init: after_init(self) diff --git a/rllib/evaluation/multi_agent_sample_collector.py b/rllib/evaluation/multi_agent_sample_collector.py index 6b546e810..dd81b9e38 100644 --- a/rllib/evaluation/multi_agent_sample_collector.py +++ b/rllib/evaluation/multi_agent_sample_collector.py @@ -12,7 +12,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.debug import summarize -from ray.rllib.utils.types import AgentID, EnvID, EpisodeID, PolicyID, \ +from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \ TensorType from ray.util.debug import log_once diff --git a/rllib/evaluation/per_policy_sample_collector.py b/rllib/evaluation/per_policy_sample_collector.py index 0834c0fac..58e95231e 100644 --- a/rllib/evaluation/per_policy_sample_collector.py +++ b/rllib/evaluation/per_policy_sample_collector.py @@ -6,7 +6,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.types import AgentID, EnvID, EpisodeID, TensorType +from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index 11a31562a..93008a313 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -23,16 +23,16 @@ def compute_advantages(rollout: SampleBatch, use_gae: bool = True, use_critic: bool = True): """ - Given a rollout, compute its value targets and the advantage. + Given a rollout, compute its value targets and the advantages. Args: - rollout (SampleBatch): SampleBatch of a single trajectory - last_r (float): Value estimation for last observation + rollout (SampleBatch): SampleBatch of a single trajectory. + last_r (float): Value estimation for last observation. gamma (float): Discount factor. - lambda_ (float): Parameter for GAE - use_gae (bool): Using Generalized Advantage Estimation + lambda_ (float): Parameter for GAE. + use_gae (bool): Using Generalized Advantage Estimation. use_critic (bool): Whether to use critic (value estimates). Setting - this to False will use 0 as baseline. + this to False will use 0 as baseline. Returns: SampleBatch (SampleBatch): Object with experience from rollout and diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index a1b269b4c..d79e0c4f1 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -1,6 +1,6 @@ import logging from types import FunctionType -from typing import TypeVar, Callable, List, Union +from typing import Callable, List, Optional, Type, TypeVar, Union import ray from ray.rllib.utils.annotations import DeveloperAPI @@ -30,21 +30,23 @@ class WorkerSet: """ def __init__(self, - env_creator: Callable[[EnvContext], EnvType], - policy: type, - trainer_config: TrainerConfigDict = None, + *, + env_creator: Optional[Callable[[EnvContext], EnvType]] = None, + policy_class: Optional[Type[Policy]] = None, + trainer_config: Optional[TrainerConfigDict] = None, num_workers: int = 0, - logdir: str = None, + logdir: Optional[str] = None, _setup: bool = True): """Create a new WorkerSet and initialize its workers. Arguments: - env_creator (func): Function that returns env given env config. - policy (cls): rllib.policy.Policy class. - trainer_config (dict): Optional dict that extends the common - config of the Trainer class. + env_creator (Optional[Callable[[EnvContext], EnvType]]): Function + that returns env given env config. + policy (Optional[Type[Policy]]): A rllib.policy.Policy class. + trainer_config (Optional[TrainerConfigDict]): Optional dict that + extends the common config of the Trainer class. num_workers (int): Number of remote rollout workers to create. - logdir (str): Optional logging directory for workers. + logdir (Optional[str]): Optional logging directory for workers. _setup (bool): Whether to setup workers. This is only for testing. """ @@ -53,7 +55,7 @@ class WorkerSet: trainer_config = COMMON_CONFIG self._env_creator = env_creator - self._policy = policy + self._policy_class = policy_class self._remote_config = trainer_config self._logdir = logdir @@ -63,8 +65,9 @@ class WorkerSet: {"tf_session_args": trainer_config["local_tf_session_args"]}) # Always create a local worker - self._local_worker = self._make_worker( - RolloutWorker, env_creator, policy, 0, self._local_config) + self._local_worker = self._make_worker(RolloutWorker, env_creator, + self._policy_class, 0, + self._local_config) # Create a number of remote workers self._remote_workers = [] @@ -102,8 +105,9 @@ class WorkerSet: } cls = RolloutWorker.as_remote(**remote_args).remote self._remote_workers.extend([ - self._make_worker(cls, self._env_creator, self._policy, i + 1, - self._remote_config) for i in range(num_workers) + self._make_worker(cls, self._env_creator, self._policy_class, + i + 1, self._remote_config) + for i in range(num_workers) ]) def reset(self, new_remote_workers: List["ActorHandle"]) -> None: @@ -190,14 +194,18 @@ class WorkerSet: @staticmethod def _from_existing(local_worker: RolloutWorker, remote_workers: List["ActorHandle"] = None): - workers = WorkerSet(None, None, {}, _setup=False) + workers = WorkerSet( + env_creator=None, + policy_class=None, + trainer_config={}, + _setup=False) workers._local_worker = local_worker workers._remote_workers = remote_workers or [] return workers def _make_worker( self, cls: Callable, env_creator: Callable[[EnvContext], EnvType], - policy: Policy, worker_index: int, + policy: Type[Policy], worker_index: int, config: TrainerConfigDict) -> Union[RolloutWorker, "ActorHandle"]: def session_creator(): logger.debug("Creating TF session {}".format( diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index 6242ed611..e27c30b7c 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -1,6 +1,7 @@ import gym -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy import eager_tf_policy @@ -17,12 +18,15 @@ from ray.rllib.utils.typing import ModelGradients, TensorType, \ def build_tf_policy( name: str, *, - loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType], + loss_fn: Callable[[ + Policy, ModelV2, Type[TFActionDistribution], SampleBatch + ], Union[TensorType, List[TensorType]]], get_default_config: Optional[Callable[[None], TrainerConfigDict]] = None, postprocess_fn: Optional[Callable[[ - Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode" - ], None]] = None, + Policy, SampleBatch, Optional[List[SampleBatch]], Optional[ + "MultiAgentEpisode"] + ], SampleBatch]] = None, stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[ str, TensorType]]] = None, optimizer_fn: Optional[Callable[[ @@ -81,8 +85,10 @@ def build_tf_policy( Args: name (str): Name of the policy (e.g., "PPOTFPolicy"). - loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]): - Callable for calculating a loss tensor. + loss_fn (Callable[[ + Policy, ModelV2, Type[TFActionDistribution], SampleBatch], + Union[TensorType, List[TensorType]]]): Callable for calculating a + loss tensor. get_default_config (Optional[Callable[[None], TrainerConfigDict]]): Optional callable that returns the default config to merge with any overrides. If None, uses only(!) the user-provided diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index 13ed32024..11cee46bb 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -1,5 +1,5 @@ import gym -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -22,7 +22,9 @@ torch, _ = try_import_torch() def build_torch_policy( name: str, *, - loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType], + loss_fn: Callable[[ + Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch + ], Union[TensorType, List[TensorType]]], get_default_config: Optional[Callable[[], TrainerConfigDict]] = None, stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[ str, TensorType]]] = None, @@ -80,8 +82,9 @@ def build_torch_policy( super's `postprocess_trajectory` method). stats_fn (Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]]): Optional callable that returns a dict of - values given the policy and batch input tensors. If None, - will use `TorchPolicy.extra_grad_info()` instead. + values given the policy and training batch. If None, + will use `TorchPolicy.extra_grad_info()` instead. The stats dict is + used for logging (e.g. in TensorBoard). extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType, List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str, TensorType]]]): Optional callable that returns a dict of extra