diff --git a/python/requirements_rllib.txt b/python/requirements_rllib.txt index ac43812a5..0b7543587 100644 --- a/python/requirements_rllib.txt +++ b/python/requirements_rllib.txt @@ -5,6 +5,7 @@ torch>=1.6.0 # Version requirement to match Tune torchvision>=0.6.0 smart_open + # For testing in MuJoCo-like envs (in PyBullet). pybullet # For tests on PettingZoo's multi-agent envs. diff --git a/rllib/BUILD b/rllib/BUILD index 67747559c..8af609982 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -462,7 +462,7 @@ py_test( py_test( name = "test_ddpg", tags = ["agents_dir"], - size = "medium", + size = "large", srcs = ["agents/ddpg/tests/test_ddpg.py"] ) @@ -531,7 +531,6 @@ py_test( ) # MBMPOTrainer -# Removed due to Higher API conflicts with Pytorch-Import tests #py_test( # name = "test_mbmpo", # tags = ["agents_dir"], diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py index eafa146f5..88e91bf82 100644 --- a/rllib/agents/a3c/a3c.py +++ b/rllib/agents/a3c/a3c.py @@ -37,9 +37,6 @@ DEFAULT_CONFIG = with_common_config({ # Workers sample async. Note that this increases the effective # rollout_fragment_length by up to 5x due to async buffering of batches. "sample_async": True, - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/ars/ars.py b/rllib/agents/ars/ars.py index ef2195213..7c42d37b7 100644 --- a/rllib/agents/ars/ars.py +++ b/rllib/agents/ars/ars.py @@ -47,10 +47,6 @@ DEFAULT_CONFIG = with_common_config({ "num_envs_per_worker": 1, "observation_filter": "NoFilter" }, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/ddpg/ddpg.py b/rllib/agents/ddpg/ddpg.py index 3729e3455..9e580c0f8 100644 --- a/rllib/agents/ddpg/ddpg.py +++ b/rllib/agents/ddpg/ddpg.py @@ -145,10 +145,6 @@ DEFAULT_CONFIG = with_common_config({ "worker_side_prioritization": False, # Prevent iterations from going lower than this time span "min_iter_time_s": 1, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index 15e549a20..73d24e2bb 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -132,10 +132,6 @@ DEFAULT_CONFIG = with_common_config({ "worker_side_prioritization": False, # Prevent iterations from going lower than this time span "min_iter_time_s": 1, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/dqn/simple_q.py b/rllib/agents/dqn/simple_q.py index da2fffd20..f2fbc59b8 100644 --- a/rllib/agents/dqn/simple_q.py +++ b/rllib/agents/dqn/simple_q.py @@ -90,10 +90,6 @@ DEFAULT_CONFIG = with_common_config({ "num_workers": 0, # Prevent iterations from going lower than this time span "min_iter_time_s": 1, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/es/es.py b/rllib/agents/es/es.py index 065f141f5..28202e3bd 100644 --- a/rllib/agents/es/es.py +++ b/rllib/agents/es/es.py @@ -45,10 +45,6 @@ DEFAULT_CONFIG = with_common_config({ "num_envs_per_worker": 1, "observation_filter": "NoFilter" }, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index a46d3c39e..7a09b1f9a 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -91,10 +91,6 @@ DEFAULT_CONFIG = with_common_config({ # Callback for APPO to use to update KL, target network periodically. # The input to the callback is the learner fetches dict. "after_train_step": None, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py index 08a4fa510..2a9522cd8 100644 --- a/rllib/agents/pg/pg.py +++ b/rllib/agents/pg/pg.py @@ -30,9 +30,6 @@ DEFAULT_CONFIG = with_common_config({ "num_workers": 0, # Learning rate. "lr": 0.0004, - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index cad2050a7..c8f6db43f 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -89,10 +89,6 @@ DEFAULT_CONFIG = with_common_config({ # Whether to fake GPUs (using CPUs). # Set this to True for debugging on non-GPU machines (set `num_gpus` > 0). "_fake_gpus": False, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ diff --git a/rllib/agents/qmix/model.py b/rllib/agents/qmix/model.py index 0c7a6d117..42e55fe7b 100644 --- a/rllib/agents/qmix/model.py +++ b/rllib/agents/qmix/model.py @@ -1,6 +1,9 @@ +from gym.spaces import Box + from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch @@ -20,6 +23,14 @@ class RNNModel(TorchModelV2, nn.Module): self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim) self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs) + self.n_agents = model_config["n_agents"] + + self.inference_view_requirements.update({ + "state_in_0": ViewRequirement( + "state_out_0", + data_rel_pos=-1, + space=Box(-1.0, 1.0, (self.n_agents, self.rnn_hidden_dim))) + }) @override(ModelV2) def get_initial_state(self): diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 76974d269..a22518a04 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -162,6 +162,7 @@ class QMixTorchPolicy(Policy): self.framework = "torch" super().__init__(obs_space, action_space, config) self.n_agents = len(obs_space.original_space.spaces) + config["model"]["n_agents"] = self.n_agents self.n_actions = action_space.spaces[0].n self.h_size = config["model"]["lstm_cell_size"] self.has_env_global_state = False @@ -214,6 +215,9 @@ class QMixTorchPolicy(Policy): name="target_model", default_model=RNNModel).to(self.device) + # Combine view_requirements for Model and Policy. + self.view_requirements.update(self.model.inference_view_requirements) + self.exploration = self._create_exploration() # Setup the mixer network. diff --git a/rllib/agents/sac/sac.py b/rllib/agents/sac/sac.py index 78ad888b8..daf66f88a 100644 --- a/rllib/agents/sac/sac.py +++ b/rllib/agents/sac/sac.py @@ -134,10 +134,6 @@ DEFAULT_CONFIG = with_common_config({ # Use a Beta-distribution instead of a SquashedGaussian for bounded, # continuous action spaces (not recommended, for debugging only). "_use_beta_distribution": False, - - # Use the new "trajectory view API" to collect samples and produce - # model- and policy inputs. - "_use_trajectory_view_api": True, }) # __sphinx_doc_end__ # yapf: enable diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 0c5c9cf94..b5751e264 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -222,8 +222,7 @@ COMMON_CONFIG: TrainerConfigDict = { # Experimental flag to speed up sampling and use "trajectory views" as # generic ModelV2 `input_dicts` that can be requested by the model to # contain different information on the ongoing episode. - # NOTE: Only supported for PyTorch so far. - "_use_trajectory_view_api": False, + "_use_trajectory_view_api": True, # Element-wise observation filter, either "NoFilter" or "MeanStdFilter". "observation_filter": "NoFilter", diff --git a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py index 2ae81299f..4b96f3d77 100644 --- a/rllib/contrib/alpha_zero/core/alpha_zero_policy.py +++ b/rllib/contrib/alpha_zero/core/alpha_zero_policy.py @@ -38,15 +38,27 @@ class AlphaZeroPolicy(TorchPolicy): episodes=None, **kwargs): + input_dict = {"obs": obs_batch} + if prev_action_batch: + input_dict["prev_actions"] = prev_action_batch + if prev_reward_batch: + input_dict["prev_rewards"] = prev_reward_batch + + return self.compute_actions_from_input_dict( + input_dict=input_dict, + episodes=episodes, + state_batches=state_batches, + ) + + @override(Policy) + def compute_actions_from_input_dict(self, + input_dict, + explore=None, + timestep=None, + episodes=None, + **kwargs): with torch.no_grad(): - input_dict = {"obs": obs_batch} - if prev_action_batch: - input_dict["prev_actions"] = prev_action_batch - if prev_reward_batch: - input_dict["prev_rewards"] = prev_reward_batch - actions = [] - for i, episode in enumerate(episodes): if episode.length == 0: # if first time step of episode, get initial env state @@ -89,7 +101,7 @@ class AlphaZeroPolicy(TorchPolicy): episode.user_data["mcts_policies"].append(mcts_policy) return np.array(actions), [], self.extra_action_out( - input_dict, state_batches, self.model, None) + input_dict, kwargs.get("state_batches", []), self.model, None) @override(Policy) def postprocess_trajectory(self, diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 7db62c532..a6be8c72f 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -251,10 +251,12 @@ class _PolicyCollector: training). """ for view_col, data in batch.items(): + # TODO(ekl) how do we handle this for policies that don't extend + # Torch / TF Policy template (no inference of view reqs)? # Skip columns that are not used for training. - if view_col not in view_requirements or \ - not view_requirements[view_col].used_for_training: - continue + # if view_col not in view_requirements or \ + # not view_requirements[view_col].used_for_training: + # continue self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. self.count += batch.count diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 804c269aa..fd4219366 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1063,7 +1063,8 @@ def _process_observations_w_trajectory_view_api( # Add extra-action-fetches to collectors. pol = policies[policy_id] for key, value in episode.last_pi_info_for(agent_id).items(): - values_dict[key] = value + if key in pol.view_requirements: + values_dict[key] = value # Env infos for this agent. if "infos" in pol.view_requirements: values_dict["infos"] = agent_infos diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 48d25c132..2ddbaf33b 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -313,6 +313,7 @@ class GTrXLNet(RecurrentNetwork): return logits, [observations] + memory_outs + # TODO: (sven) Deprecate this once trajectory view API has fully matured. @override(RecurrentNetwork) def get_initial_state(self) -> List[np.ndarray]: # State is the T last observations concat'd together into one Tensor. diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 96ee284e5..900d6cb57 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -80,6 +80,8 @@ class DynamicTFPolicy(TFPolicy): ], Tuple[TensorType, type, List[TensorType]]]] = None, existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None, existing_model: Optional[ModelV2] = None, + view_requirements_fn: Optional[Callable[[Policy], Dict[ + str, ViewRequirement]]] = None, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, obs_include_prev_action_reward: bool = True): @@ -388,6 +390,7 @@ class DynamicTFPolicy(TFPolicy): instance._grad_stats_fn(instance, input_dict, instance._grads)) return instance + # TODO: (sven) deprecate once _use_trajectory_view_api is always True. @override(Policy) @DeveloperAPI def get_initial_state(self) -> List[TensorType]: @@ -545,7 +548,8 @@ class DynamicTFPolicy(TFPolicy): for i, si in enumerate(self._state_inputs): train_batch["state_in_{}".format(i)] = si else: - train_batch = UsageTrackingDict(self._input_dict) + train_batch = UsageTrackingDict( + dict(self._input_dict, **self._loss_input_dict)) if self._state_inputs: train_batch["seq_lens"] = self._seq_lens diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index dda13e563..62dbe3148 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -194,6 +194,7 @@ def build_eager_tf_policy(name, action_sampler_fn=None, action_distribution_fn=None, mixins=None, + view_requirements_fn=None, obs_include_prev_action_reward=True, get_batch_divisibility_req=None): """Build an eager TF policy. @@ -264,6 +265,9 @@ def build_eager_tf_policy(name, for s in self.model.get_initial_state() ] + # Update this Policy's ViewRequirements (if function given). + if callable(view_requirements_fn): + self.view_requirements.update(view_requirements_fn(self)) # Combine view_requirements for Model and Policy. self.view_requirements.update( self.model.inference_view_requirements) diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index 11f2680bc..2a738ddc9 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -8,6 +8,7 @@ from ray.rllib.policy import eager_tf_policy from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_tf @@ -65,6 +66,8 @@ def build_tf_policy( Policy, ModelV2, TensorType, TensorType, TensorType ], Tuple[TensorType, type, List[TensorType]]]] = None, mixins: Optional[List[type]] = None, + view_requirements_fn: Optional[Callable[[Policy], Dict[ + str, ViewRequirement]]] = None, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, # TODO: (sven) deprecate once _use_trajectory_view_api is always True. obs_include_prev_action_reward: bool = True, @@ -170,6 +173,9 @@ def build_tf_policy( mixins (Optional[List[type]]): Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the DynamicTFPolicy class. + view_requirements_fn (Callable[[Policy], + Dict[str, ViewRequirement]]): An optional callable to retrieve + additional train view requirements for this policy. get_batch_divisibility_req (Optional[Callable[[Policy], int]]): Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1. @@ -208,6 +214,8 @@ def build_tf_policy( else: policy._extra_action_fetches = extra_action_fetches_fn( policy) + policy._extra_action_fetches = extra_action_fetches_fn( + policy) DynamicTFPolicy.__init__( self, @@ -223,6 +231,7 @@ def build_tf_policy( action_distribution_fn=action_distribution_fn, existing_inputs=existing_inputs, existing_model=existing_model, + view_requirements_fn=view_requirements_fn, get_batch_divisibility_req=get_batch_divisibility_req, obs_include_prev_action_reward=obs_include_prev_action_reward) diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index c1320b9d4..dd2aadfa6 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -8,6 +8,7 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils import add_mixins, force_list from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.framework import try_import_torch @@ -69,6 +70,8 @@ def build_torch_policy( apply_gradients_fn: Optional[Callable[ [Policy, "torch.optim.Optimizer"], None]] = None, mixins: Optional[List[type]] = None, + view_requirements_fn: Optional[Callable[[Policy], Dict[ + str, ViewRequirement]]] = None, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None ) -> Type[TorchPolicy]: """Helper function for creating a torch policy class at runtime. @@ -171,6 +174,9 @@ def build_torch_policy( mixins (Optional[List[type]]): Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the TorchPolicy class. + view_requirements_fn (Optional[Callable[[Policy], + Dict[str, ViewRequirement]]]): An optional callable to retrieve + additional train view requirements for this policy. get_batch_divisibility_req (Optional[Callable[[Policy], int]]): Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1. @@ -236,6 +242,10 @@ def build_torch_policy( get_batch_divisibility_req=get_batch_divisibility_req, ) + # Update this Policy's ViewRequirements (if function given). + if callable(view_requirements_fn): + self.view_requirements.update(view_requirements_fn(self)) + # Merge Model's view requirements into Policy's. self.view_requirements.update( self.model.inference_view_requirements) @@ -244,6 +254,7 @@ def build_torch_policy( _before_loss_init(self, self.observation_space, self.action_space, config) + # Perform test runs through postprocessing- and loss functions. self._initialize_loss_from_dummy_batch( auto_remove_unneeded_view_reqs=True, stats_fn=stats_fn,