From e968b52cb7de4e6c2fcc6e7d5ccb98d984745715 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 21 Aug 2020 12:35:16 +0200 Subject: [PATCH] [RLlib] Trajectory view API - 03 Fast LSTM + prev actions/rewards (#9950) --- rllib/BUILD | 34 +- rllib/agents/callbacks.py | 17 +- rllib/agents/ppo/ppo_torch_policy.py | 11 +- rllib/agents/trainer.py | 5 + rllib/env/policy_client.py | 2 +- .../multi_agent_sample_collector.py | 30 +- .../evaluation/per_policy_sample_collector.py | 64 +- rllib/evaluation/rollout_worker.py | 4 +- rllib/evaluation/sampler.py | 650 ++++++++++++++---- .../tests/test_trajectory_view_api.py | 277 ++++++++ rllib/examples/env/debug_counter_env.py | 43 ++ .../policy/episode_env_aware_policy.py | 66 ++ rllib/models/catalog.py | 3 + rllib/models/modelv2.py | 34 +- rllib/models/tf/recurrent_net.py | 5 +- rllib/models/torch/recurrent_net.py | 70 +- rllib/policy/policy.py | 24 +- rllib/policy/rnn_sequencing.py | 77 ++- rllib/policy/sample_batch.py | 60 +- .../policy/tests/test_trajectory_view_api.py | 84 --- rllib/policy/torch_policy.py | 39 +- rllib/policy/torch_policy_template.py | 11 +- rllib/policy/view_requirement.py | 6 +- rllib/utils/sgd.py | 24 +- rllib/utils/typing.py | 3 + 25 files changed, 1230 insertions(+), 413 deletions(-) create mode 100644 rllib/evaluation/tests/test_trajectory_view_api.py create mode 100644 rllib/examples/policy/episode_env_aware_policy.py delete mode 100644 rllib/policy/tests/test_trajectory_view_api.py diff --git a/rllib/BUILD b/rllib/BUILD index 222fa482a..1eab1ae5f 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -22,13 +22,14 @@ # (problems: 10min timeout, not respecting ray/ci/keep_alive.sh, or even # `travis_wait n`, etc..). 
-# Our travis.yml file executes all these tests in 6 different jobs, which are: +# Our travis.yml file executes all these tests in 7 different jobs, which are: # 1) everything in a) using tf2.x # 2) everything in a) using tf1.x -# 3) everything in b) c) d) and e) -# 4) everything in g) -# 5) f), BUT only those tagged `tests_dir_A` to `tests_dir_L` -# 6) f), BUT only those tagged `tests_dir_M` to `tests_dir_Z` +# 3) everything in a) using torch +# 4) everything in b) c) d) and e) +# 5) everything in g) +# 6) f), BUT only those tagged `tests_dir_A` to `tests_dir_[some letter]` +# 7) f), BUT only those tagged `tests_dir_[some letter]` to `tests_dir_Z` # -------------------------------------------------------------------- @@ -1024,6 +1025,22 @@ py_test( srcs = ["models/tests/test_attention_nets.py"] ) + +# -------------------------------------------------------------------- +# Evaluation components +# rllib/evaluation/ +# +# Tag: evaluation +# -------------------------------------------------------------------- +# mysteriously times out on travis. 
+#py_test( +# name = "evaluation/tests/test_trajectory_view_api", +# tags = ["evaluation"], +# size = "medium", +# srcs = ["evaluation/tests/test_trajectory_view_api.py"] +#) + + # -------------------------------------------------------------------- # Optimizers and Memories # rllib/execution/ @@ -1059,13 +1076,6 @@ py_test( srcs = ["policy/tests/test_compute_log_likelihoods.py"] ) -py_test( - name = "policy/tests/test_trajectory_view_api", - tags = ["policy"], - size = "small", - srcs = ["policy/tests/test_trajectory_view_api.py"] -) - # -------------------------------------------------------------------- # Utils: # rllib/utils/ diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index 2111032cb..53921a0b1 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -1,13 +1,16 @@ -from typing import Dict +from typing import Dict, TYPE_CHECKING from ray.rllib.env import BaseEnv from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker +from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.typing import AgentID, PolicyID +if TYPE_CHECKING: + from ray.rllib.evaluation import RolloutWorker + @PublicAPI class DefaultCallbacks: @@ -27,7 +30,7 @@ class DefaultCallbacks: "a class extending rllib.agents.callbacks.DefaultCallbacks") self.legacy_callbacks = legacy_callbacks_dict or {} - def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv, + def on_episode_start(self, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, **kwargs): """Callback run on the rollout worker before each episode starts. 
@@ -52,7 +55,7 @@ class DefaultCallbacks: "episode": episode, }) - def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv, + def on_episode_step(self, worker: "RolloutWorker", base_env: BaseEnv, episode: MultiAgentEpisode, **kwargs): """Runs on each episode step. @@ -73,7 +76,7 @@ class DefaultCallbacks: "episode": episode }) - def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv, + def on_episode_end(self, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, **kwargs): """Runs when an episode is done. @@ -99,7 +102,7 @@ class DefaultCallbacks: }) def on_postprocess_trajectory( - self, worker: RolloutWorker, episode: MultiAgentEpisode, + self, worker: "RolloutWorker", episode: MultiAgentEpisode, agent_id: AgentID, policy_id: PolicyID, policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch, original_batches: Dict[AgentID, SampleBatch], **kwargs): @@ -133,7 +136,7 @@ class DefaultCallbacks: "all_pre_batches": original_batches, }) - def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch, + def on_sample_end(self, worker: "RolloutWorker", samples: SampleBatch, **kwargs): """Called at the end RolloutWorker.sample(). diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index c415f7b8a..b3ac1a6f3 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -117,7 +117,10 @@ def ppo_surrogate_loss(policy, model, dist_class, train_batch): mask = None if state: max_seq_len = torch.max(train_batch["seq_lens"]) - mask = sequence_mask(train_batch["seq_lens"], max_seq_len) + mask = sequence_mask( + train_batch["seq_lens"], + max_seq_len, + time_major=model.is_time_major()) mask = torch.reshape(mask, [-1]) policy.loss_obj = PPOLoss( @@ -221,6 +224,12 @@ def training_view_requirements_fn(policy): SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1), # VF preds are needed for the loss. 
SampleBatch.VF_PREDS: ViewRequirement(shift=0), + # Needed for postprocessing. + SampleBatch.ACTION_DIST_INPUTS: ViewRequirement(shift=0), + SampleBatch.ACTION_LOGP: ViewRequirement(shift=0), + # Created during postprocessing. + Postprocessing.ADVANTAGES: ViewRequirement(shift=0), + Postprocessing.VALUE_TARGETS: ViewRequirement(shift=0), } diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 926802c5d..7904245c1 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -1082,6 +1082,11 @@ class Trainer(Trainable): raise ValueError( "`_use_trajectory_view_api` only supported for PyTorch so " "far!") + elif not config.get("_use_trajectory_view_api") and \ + config.get("model", {}).get("_time_major"): + raise ValueError("`model._time_major` only supported " + "iff `_use_trajectory_view_api` is True!") + if "policy_graphs" in config["multiagent"]: deprecation_warning("policy_graphs", "policies") # Backwards compatibility. diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 5aa17fae0..b8efa2eb9 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -10,7 +10,6 @@ import time from typing import Union, Optional import ray.cloudpickle as pickle -from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI @@ -337,6 +336,7 @@ def _create_embedded_rollout_worker(kwargs, send_fn): real_env_creator = kwargs["env_creator"] kwargs["env_creator"] = _auto_wrap_external(real_env_creator) + from ray.rllib.evaluation.rollout_worker import RolloutWorker rollout_worker = RolloutWorker(**kwargs) inference_thread = _LocalInferenceThread(rollout_worker, send_fn) inference_thread.start() diff --git a/rllib/evaluation/multi_agent_sample_collector.py b/rllib/evaluation/multi_agent_sample_collector.py index dd81b9e38..7c21b0bec 100644 --- 
a/rllib/evaluation/multi_agent_sample_collector.py +++ b/rllib/evaluation/multi_agent_sample_collector.py @@ -1,7 +1,6 @@ import logging -from typing import Dict, Optional +from typing import Dict, Optional, TYPE_CHECKING -from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.per_policy_sample_collector import \ @@ -16,6 +15,9 @@ from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \ TensorType from ray.util.debug import log_once +if TYPE_CHECKING: + from ray.rllib.agents.callbacks import DefaultCallbacks + logger = logging.getLogger(__name__) @@ -38,7 +40,7 @@ class _MultiAgentSampleCollector(_SampleCollector): def __init__( self, policy_map: Dict[PolicyID, Policy], - callbacks: DefaultCallbacks, + callbacks: "DefaultCallbacks", # TODO: (sven) make `num_agents` flexibly grow in size. num_agents: int = 100, num_timesteps=None, @@ -64,8 +66,8 @@ class _MultiAgentSampleCollector(_SampleCollector): num_agents = 1000 self.num_agents = int(num_agents) - # Collect SampleBatches per-policy in PolicyTrajectories objects. - self.rollout_sample_collectors = {} + # Collect SampleBatches per-policy in _PerPolicySampleCollectors. + self.policy_sample_collectors = {} for pid, policy in policy_map.items(): # Figure out max-shifts (before and after). view_reqs = policy.training_view_requirements @@ -86,7 +88,7 @@ class _MultiAgentSampleCollector(_SampleCollector): elif num_timesteps is not None: kwargs["num_timesteps"] = num_timesteps - self.rollout_sample_collectors[pid] = _PerPolicySampleCollector( + self.policy_sample_collectors[pid] = _PerPolicySampleCollector( num_agents=self.num_agents, shift_before=-max_shift_before, shift_after=max_shift_after, @@ -109,7 +111,7 @@ class _MultiAgentSampleCollector(_SampleCollector): assert self.agent_to_policy[agent_id] == policy_id # Add initial obs to Trajectory. 
- self.rollout_sample_collectors[policy_id].add_init_obs( + self.policy_sample_collectors[policy_id].add_init_obs( episode_id, agent_id, env_id, chunk_num=0, init_obs=obs) @override(_SampleCollector) @@ -117,7 +119,7 @@ class _MultiAgentSampleCollector(_SampleCollector): agent_id: AgentID, env_id: EnvID, policy_id: PolicyID, agent_done: bool, values: Dict[str, TensorType]) -> None: - assert policy_id in self.rollout_sample_collectors + assert policy_id in self.policy_sample_collectors # Make sure our mappings are up to date. if agent_id not in self.agent_to_policy: @@ -130,13 +132,13 @@ class _MultiAgentSampleCollector(_SampleCollector): values["agent_id"] = agent_id # Add action/reward/next-obs (and other data) to Trajectory. - self.rollout_sample_collectors[policy_id].add_action_reward_next_obs( + self.policy_sample_collectors[policy_id].add_action_reward_next_obs( episode_id, agent_id, env_id, agent_done, values) @override(_SampleCollector) def total_env_steps(self) -> int: return sum(a.timesteps_since_last_reset - for a in self.rollout_sample_collectors.values()) + for a in self.policy_sample_collectors.values()) def total(self): # TODO: (sven) deprecate; use `self.total_env_steps`, instead. 
@@ -148,7 +150,7 @@ class _MultiAgentSampleCollector(_SampleCollector): Dict[str, TensorType]: policy = self.policy_map[policy_id] view_reqs = policy.model.inference_view_requirements - return self.rollout_sample_collectors[ + return self.policy_sample_collectors[ policy_id].get_inference_input_dict(view_reqs) @override(_SampleCollector) @@ -161,7 +163,7 @@ class _MultiAgentSampleCollector(_SampleCollector): # Loop through each per-policy collector and create a view (for each # agent as SampleBatch) from its buffers for post-processing all_agent_batches = {} - for pid, rc in self.rollout_sample_collectors.items(): + for pid, rc in self.policy_sample_collectors.items(): policy = self.policy_map[pid] view_reqs = policy.training_view_requirements agent_batches = rc.get_postprocessing_sample_batches( @@ -211,7 +213,7 @@ class _MultiAgentSampleCollector(_SampleCollector): @override(_SampleCollector) def check_missing_dones(self, episode_id: EpisodeID) -> None: - for pid, rc in self.rollout_sample_collectors.items(): + for pid, rc in self.policy_sample_collectors.items(): for agent_key in rc.agent_key_to_slot.keys(): # Only check for given episode and only for last chunk # (all previous chunks for that agent in the episode are @@ -235,7 +237,7 @@ class _MultiAgentSampleCollector(_SampleCollector): def get_multi_agent_batch_and_reset(self): self.postprocess_trajectories_so_far() policy_batches = {} - for pid, rc in self.rollout_sample_collectors.items(): + for pid, rc in self.policy_sample_collectors.items(): policy = self.policy_map[pid] view_reqs = policy.training_view_requirements policy_batches[pid] = rc.get_train_sample_batch_and_reset( diff --git a/rllib/evaluation/per_policy_sample_collector.py b/rllib/evaluation/per_policy_sample_collector.py index 58e95231e..3ef853ad5 100644 --- a/rllib/evaluation/per_policy_sample_collector.py +++ b/rllib/evaluation/per_policy_sample_collector.py @@ -103,7 +103,13 @@ class _PerPolicySampleCollector: self._next_agent_slot() if 
SampleBatch.OBS not in self.buffers: - self._build_buffers(single_row={SampleBatch.OBS: init_obs}) + self._build_buffers( + single_row={ + SampleBatch.OBS: init_obs, + SampleBatch.EPS_ID: episode_id, + SampleBatch.AGENT_INDEX: agent_id, + "env_id": env_id, + }) if self.time_major: self.buffers[SampleBatch.OBS][self.shift_before-1, agent_slot] = \ init_obs @@ -262,12 +268,12 @@ class _PerPolicySampleCollector: batch = sample_batch_data[agent_key] for view_col, view_req in view_reqs.items(): + data_col = view_req.data_col or view_col # Skip columns that will only get added through postprocessing # (these may not even exist yet). - if view_req.created_during_postprocessing: + if data_col not in self.buffers: continue - data_col = view_req.data_col or view_col shift = view_req.shift if data_col == SampleBatch.OBS: shift -= 1 @@ -289,20 +295,22 @@ class _PerPolicySampleCollector: SampleBatch: Returns the accumulated sample batch for this policy. """ - seq_lens = [ + seq_lens_w_0s = [ self.agent_key_to_timestep[k] - self.shift_before for k in self.slot_to_agent_key if k is not None ] - first_zero_len = len(seq_lens) - if seq_lens[-1] == 0: - first_zero_len = seq_lens.index(0) + # We have an agent-axis buffer "rollover" (new SampleBatch will be + # built from last n agent records plus first m agent records in + # buffer). + if self.agent_slot_cursor < self.sample_batch_offset: + rollover = -(self.num_agents - self.sample_batch_offset) + seq_lens_w_0s = seq_lens_w_0s[rollover:] + seq_lens_w_0s[:rollover] + first_zero_len = len(seq_lens_w_0s) + if seq_lens_w_0s[-1] == 0: + first_zero_len = seq_lens_w_0s.index(0) # Assert that all zeros lie at the end of the seq_lens array. 
- try: - assert all(seq_lens[i] == 0 - for i in range(first_zero_len, len(seq_lens))) - except AssertionError as e: - print() - raise e + assert all(seq_lens_w_0s[i] == 0 + for i in range(first_zero_len, len(seq_lens_w_0s))) t_start = self.shift_before t_end = t_start + self.num_timesteps @@ -311,8 +319,8 @@ class _PerPolicySampleCollector: # actually already has at least 1 timestep of data (thus it excludes # just-rolled over chunks (which only have the initial obs in them)). valid_agent_cursor = \ - (self.agent_slot_cursor - (len(seq_lens) - first_zero_len)) % \ - self.num_agents + (self.agent_slot_cursor - + (len(seq_lens_w_0s) - first_zero_len)) % self.num_agents # Construct the view dict. view = {} @@ -320,12 +328,13 @@ class _PerPolicySampleCollector: data_col = view_req.data_col or view_col assert data_col in self.buffers # For OBS, indices must be shifted by -1. - extra_shift = 0 if data_col != SampleBatch.OBS else -1 + shift = view_req.shift + shift += 0 if data_col != SampleBatch.OBS else -1 # If agent_slot has been rolled-over to beginning, we have to copy # here. if valid_agent_cursor < self.sample_batch_offset: - time_slice = self.buffers[data_col][t_start + extra_shift: - t_end + extra_shift] + time_slice = self.buffers[data_col][t_start + shift:t_end + + shift] one_ = time_slice[:, self.sample_batch_offset:] two_ = time_slice[:, :valid_agent_cursor] if torch and isinstance(time_slice, torch.Tensor): @@ -335,17 +344,15 @@ class _PerPolicySampleCollector: else: view[view_col] = \ self.buffers[data_col][ - t_start + extra_shift:t_end + extra_shift, + t_start + shift:t_end + shift, self.sample_batch_offset:valid_agent_cursor] # Copy all still ongoing trajectories to new agent slots # (including the ones that just started (are seq_len=0)). 
new_chunk_args = [] - for i, seq_len in enumerate(seq_lens): + for i, seq_len in enumerate(seq_lens_w_0s): if seq_len < self.num_timesteps: - agent_slot = self.sample_batch_offset + i - if agent_slot >= self.num_agents: - agent_slot = agent_slot % self.num_agents + agent_slot = (self.sample_batch_offset + i) % self.num_agents if not self.buffers[SampleBatch. DONES][seq_len - 1 + self.shift_before][agent_slot]: @@ -354,9 +361,9 @@ class _PerPolicySampleCollector: (agent_slot, agent_key, self.agent_key_to_timestep[agent_key])) # Cut out all 0 seq-lens. - seq_lens = seq_lens[:first_zero_len] + seq_lens = seq_lens_w_0s[:first_zero_len] batch = SampleBatch( - view, _seq_lens=np.array(seq_lens), _time_major=True) + view, _seq_lens=np.array(seq_lens), _time_major=self.time_major) # Reset everything for new data. self.postprocessed_agents = [False] * self.num_agents @@ -376,9 +383,14 @@ class _PerPolicySampleCollector: def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: """Builds the internal data buffers based on a single given row. + This may be called several times in the lifetime of this instance + to add new columns to the buffer. Columns in `single_row` that already + exist in the buffer will be ignored. + Args: single_row (Dict[str, TensorType]): A single datarow with one or - more columns (str as key, np.ndarray|tensor as data). + more columns (str as key, np.ndarray|tensor as data) to be used + as template to build the pre-allocated buffer. 
""" time_size = self.num_timesteps + self.shift_before + self.shift_after for col, data in single_row.items(): diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index f4bf4ccd6..73e69bd55 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -515,7 +515,7 @@ class RolloutWorker(ParallelIteratorWorker): rollout_fragment_length=rollout_fragment_length, callbacks=self.callbacks, horizon=episode_horizon, - pack_multiple_episodes_in_batch=pack, + multiple_episodes_in_batch=pack, tf_sess=self.tf_sess, clip_actions=clip_actions, blackhole_outputs="simulation" in input_evaluation, @@ -538,7 +538,7 @@ class RolloutWorker(ParallelIteratorWorker): rollout_fragment_length=rollout_fragment_length, callbacks=self.callbacks, horizon=episode_horizon, - pack_multiple_episodes_in_batch=pack, + multiple_episodes_in_batch=pack, tf_sess=self.tf_sess, clip_actions=clip_actions, soft_horizon=soft_horizon, diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index bb6071062..eba66e450 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -5,14 +5,17 @@ import numpy as np import queue import threading import time -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, \ +from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\ TYPE_CHECKING, Union from ray.util.debug import log_once from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.evaluation.multi_agent_sample_collector import \ + _MultiAgentSampleCollector from ray.rllib.evaluation.rollout_metrics import RolloutMetrics from ray.rllib.evaluation.sample_batch_builder import \ MultiAgentSampleBatchBuilder +from ray.rllib.evaluation.sample_collector import _SampleCollector from ray.rllib.policy.policy import clip_action, Policy from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.models.preprocessors import Preprocessor @@ -22,6 +25,7 @@ from 
ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv from ray.rllib.offline import InputReader from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import summarize +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \ unbatch from ray.rllib.utils.tf_run_builder import TFRunBuilder @@ -51,14 +55,23 @@ class _PerfStats: def __init__(self): self.iters = 0 self.env_wait_time = 0.0 - self.processing_time = 0.0 + self.raw_obs_processing_time = 0.0 self.inference_time = 0.0 + self.action_processing_time = 0.0 def get(self): + # Mean multiplicator (1000 = ms -> sec). + factor = 1000 / self.iters return { - "mean_env_wait_ms": self.env_wait_time * 1000 / self.iters, - "mean_processing_ms": self.processing_time * 1000 / self.iters, - "mean_inference_ms": self.inference_time * 1000 / self.iters + # Waiting for environment (during poll). + "mean_env_wait_ms": self.env_wait_time * factor, + # Raw observation preprocessing. + "mean_raw_obs_processing_ms": self.raw_obs_processing_time * + factor, + # Computing actions through policy. + "mean_inference_ms": self.inference_time * factor, + # Processing actions (to be sent to env, e.g. clipping). + "mean_action_processing_ms": self.action_processing_time * factor, } @@ -108,7 +121,7 @@ class SyncSampler(SamplerInput): rollout_fragment_length: int, callbacks: "DefaultCallbacks", horizon: int = None, - pack_multiple_episodes_in_batch: bool = False, + multiple_episodes_in_batch: bool = False, tf_sess=None, clip_actions: bool = True, soft_horizon: bool = False, @@ -136,7 +149,7 @@ class SyncSampler(SamplerInput): callbacks (Callbacks): The Callbacks object to use when episode events happen during rollout. horizon (Optional[int]): Hard-reset the Env - pack_multiple_episodes_in_batch (bool): Whether to pack multiple + multiple_episodes_in_batch (bool): Whether to pack multiple episodes into each batch. 
This guarantees batches will be exactly `rollout_fragment_length` in size. tf_sess (Optional[tf.Session]): A tf.Session object to use (only if @@ -165,14 +178,20 @@ class SyncSampler(SamplerInput): self.obs_filters = obs_filters self.extra_batches = queue.Queue() self.perf_stats = _PerfStats() + if _use_trajectory_view_api: + self.sample_collector = _MultiAgentSampleCollector( + policies, callbacks) + else: + self.sample_collector = None + # Create the rollout generator to use for calls to `get_data()`. self.rollout_provider = _env_runner( worker, self.base_env, self.extra_batches.put, self.policies, self.policy_mapping_fn, self.rollout_fragment_length, self.horizon, self.preprocessors, self.obs_filters, clip_rewards, clip_actions, - pack_multiple_episodes_in_batch, callbacks, tf_sess, - self.perf_stats, soft_horizon, no_done_at_end, observation_fn, - _use_trajectory_view_api) + multiple_episodes_in_batch, callbacks, tf_sess, self.perf_stats, + soft_horizon, no_done_at_end, observation_fn, + _use_trajectory_view_api, self.sample_collector) self.metrics_queue = queue.Queue() @override(SamplerInput) @@ -226,7 +245,7 @@ class AsyncSampler(threading.Thread, SamplerInput): rollout_fragment_length: int, callbacks: "DefaultCallbacks", horizon: int = None, - pack_multiple_episodes_in_batch: bool = False, + multiple_episodes_in_batch: bool = False, tf_sess=None, clip_actions: bool = True, blackhole_outputs: bool = False, @@ -255,7 +274,7 @@ class AsyncSampler(threading.Thread, SamplerInput): callbacks (Callbacks): The Callbacks object to use when episode events happen during rollout. horizon (Optional[int]): Hard-reset the Env - pack_multiple_episodes_in_batch (bool): Whether to pack multiple + multiple_episodes_in_batch (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `rollout_fragment_length` in size. 
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if @@ -293,7 +312,7 @@ class AsyncSampler(threading.Thread, SamplerInput): self.obs_filters = obs_filters self.clip_rewards = clip_rewards self.daemon = True - self.pack_multiple_episodes_in_batch = pack_multiple_episodes_in_batch + self.multiple_episodes_in_batch = multiple_episodes_in_batch self.tf_sess = tf_sess self.callbacks = callbacks self.clip_actions = clip_actions @@ -304,6 +323,11 @@ class AsyncSampler(threading.Thread, SamplerInput): self.shutdown = False self.observation_fn = observation_fn self._use_trajectory_view_api = _use_trajectory_view_api + if _use_trajectory_view_api: + self.sample_collector = _MultiAgentSampleCollector( + policies, callbacks) + else: + self.sample_collector = None @override(threading.Thread) def run(self): @@ -325,8 +349,8 @@ class AsyncSampler(threading.Thread, SamplerInput): self.worker, self.base_env, extra_batches_putter, self.policies, self.policy_mapping_fn, self.rollout_fragment_length, self.horizon, self.preprocessors, self.obs_filters, self.clip_rewards, - self.clip_actions, self.pack_multiple_episodes_in_batch, - self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon, + self.clip_actions, self.multiple_episodes_in_batch, self.callbacks, + self.tf_sess, self.perf_stats, self.soft_horizon, self.no_done_at_end, self.observation_fn, self._use_trajectory_view_api) while not self.shutdown: @@ -385,14 +409,16 @@ def _env_runner( obs_filters: Dict[PolicyID, Filter], clip_rewards: bool, clip_actions: bool, - pack_multiple_episodes_in_batch: bool, + multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks", tf_sess: Optional["tf.Session"], perf_stats: _PerfStats, soft_horizon: bool, no_done_at_end: bool, observation_fn: "ObservationFunction", - _use_trajectory_view_api: bool = False) -> Iterable[SampleBatchType]: + _use_trajectory_view_api: bool = False, + _sample_collector: Optional[_SampleCollector] = None, +) -> Iterable[SampleBatchType]: 
"""This implements the common experience collection logic. Args: @@ -413,7 +439,7 @@ def _env_runner( obs_filters (dict): Map of policy id to filter used to process observations for the policy. clip_rewards (bool): Whether to clip rewards before postprocessing. - pack_multiple_episodes_in_batch (bool): Whether to pack multiple + multiple_episodes_in_batch (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `rollout_fragment_length` in size. clip_actions (bool): Whether to clip actions to the space range. @@ -430,6 +456,8 @@ def _env_runner( _use_trajectory_view_api (bool): Whether to use the (experimental) `_use_trajectory_view_api` to make generic trajectory views available to Models. Default: False. + _sample_collector (Optional[_SampleCollector]): An optional + _SampleCollector object to use Yields: rollout (SampleBatch): Object containing state, action, reward, @@ -471,6 +499,8 @@ def _env_runner( def get_batch_builder(): if batch_builder_pool: return batch_builder_pool.pop() + elif _use_trajectory_view_api: + return None else: return MultiAgentSampleBatchBuilder(policies, clip_rewards, callbacks) @@ -495,6 +525,7 @@ def _env_runner( return episode active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode) + eval_results = None while True: perf_stats.iters += 1 @@ -514,39 +545,73 @@ def _env_runner( t1 = time.time() # type: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], # List[Union[RolloutMetrics, SampleBatchType]] - active_envs, to_eval, outputs = _process_observations( - worker=worker, - base_env=base_env, - policies=policies, - batch_builder_pool=batch_builder_pool, - active_episodes=active_episodes, - unfiltered_obs=unfiltered_obs, - rewards=rewards, - dones=dones, - infos=infos, - horizon=horizon, - preprocessors=preprocessors, - obs_filters=obs_filters, - rollout_fragment_length=rollout_fragment_length, - pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch, - callbacks=callbacks, 
- soft_horizon=soft_horizon, - no_done_at_end=no_done_at_end, - observation_fn=observation_fn, - _use_trajectory_view_api=_use_trajectory_view_api) - perf_stats.processing_time += time.time() - t1 + if _use_trajectory_view_api: + active_envs, to_eval, outputs = \ + _process_observations_w_trajectory_view_api( + worker=worker, + base_env=base_env, + policies=policies, + active_episodes=active_episodes, + prev_policy_outputs=eval_results, + unfiltered_obs=unfiltered_obs, + rewards=rewards, + dones=dones, + infos=infos, + horizon=horizon, + preprocessors=preprocessors, + obs_filters=obs_filters, + rollout_fragment_length=rollout_fragment_length, + multiple_episodes_in_batch=multiple_episodes_in_batch, + callbacks=callbacks, + soft_horizon=soft_horizon, + no_done_at_end=no_done_at_end, + observation_fn=observation_fn, + perf_stats=perf_stats, + _sample_collector=_sample_collector, + ) + else: + active_envs, to_eval, outputs = _process_observations( + worker=worker, + base_env=base_env, + policies=policies, + batch_builder_pool=batch_builder_pool, + active_episodes=active_episodes, + unfiltered_obs=unfiltered_obs, + rewards=rewards, + dones=dones, + infos=infos, + horizon=horizon, + preprocessors=preprocessors, + obs_filters=obs_filters, + rollout_fragment_length=rollout_fragment_length, + multiple_episodes_in_batch=multiple_episodes_in_batch, + callbacks=callbacks, + soft_horizon=soft_horizon, + no_done_at_end=no_done_at_end, + observation_fn=observation_fn, + perf_stats=perf_stats, + ) + perf_stats.raw_obs_processing_time += time.time() - t1 for o in outputs: yield o # Do batched policy eval (accross vectorized envs). 
t2 = time.time() # type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]] - eval_results = _do_policy_eval( - to_eval=to_eval, - policies=policies, - active_episodes=active_episodes, - tf_sess=tf_sess, - _use_trajectory_view_api=_use_trajectory_view_api) + if _use_trajectory_view_api: + eval_results = _do_policy_eval_w_trajectory_view_api( + to_eval=to_eval, + policies=policies, + _sample_collector=_sample_collector, + tf_sess=tf_sess, + ) + else: + eval_results = _do_policy_eval( + to_eval=to_eval, + policies=policies, + active_episodes=active_episodes, + tf_sess=tf_sess, + ) perf_stats.inference_time += time.time() - t2 # Process results and update episode state. @@ -560,8 +625,10 @@ def _env_runner( off_policy_actions=off_policy_actions, policies=policies, clip_actions=clip_actions, - _use_trajectory_view_api=_use_trajectory_view_api) - perf_stats.processing_time += time.time() - t3 + _use_trajectory_view_api=_use_trajectory_view_api, + _sample_collector=_sample_collector, + ) + perf_stats.action_processing_time += time.time() - t3 # Return computed actions to ready envs. We also send to envs that have # taken off-policy actions; those envs are free to ignore the action. @@ -571,6 +638,7 @@ def _env_runner( def _process_observations( + *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy], @@ -584,12 +652,12 @@ def _process_observations( preprocessors: Dict[PolicyID, Preprocessor], obs_filters: Dict[PolicyID, Filter], rollout_fragment_length: int, - pack_multiple_episodes_in_batch: bool, + multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks", soft_horizon: bool, no_done_at_end: bool, observation_fn: "ObservationFunction", - _use_trajectory_view_api: bool = False + perf_stats: _PerfStats, ) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[ RolloutMetrics, SampleBatchType]]]: """Record new data from the environment and prepare for policy evaluation. 
@@ -602,8 +670,11 @@ def _process_observations( SampleBatchBuilder object for recycling. active_episodes (Dict[str, MultiAgentEpisode]): Mapping from episode ID to currently ongoing MultiAgentEpisode object. - unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids -> - unfiltered observation tensor, returned by a `BaseEnv.poll()` call. + prev_policy_outputs (Dict[str,List]): The prev policy output dict + (by policy-id -> List[action, state outs, extra fetches]). + unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids + -> unfiltered observation tensor, returned by a `BaseEnv.poll()` + call. rewards (dict): Doubly keyed dict of env-ids -> agent ids -> rewards tensor, returned by a `BaseEnv.poll()` call. dones (dict): Doubly keyed dict of env-ids -> agent ids -> @@ -618,7 +689,7 @@ def _process_observations( rollout_fragment_length (int): Number of episode steps before `SampleBatch` is yielded. Set to infinity to yield complete episodes. - pack_multiple_episodes_in_batch (bool): Whether to pack multiple + multiple_episodes_in_batch (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `rollout_fragment_length` in size. callbacks (DefaultCallbacks): User callbacks to run on episode events. @@ -628,9 +699,6 @@ def _process_observations( and instead record done=False. observation_fn (ObservationFunction): Optional multi-agent observation func to use for preprocessing observations. - _use_trajectory_view_api (bool): Whether to use the (experimental) - `_use_trajectory_view_api` to make generic trajectory views - available to Models. Default: False. 
Returns: Tuple: @@ -652,20 +720,21 @@ def _process_observations( for env_id, agent_obs in unfiltered_obs.items(): is_new_episode: bool = env_id not in active_episodes episode: MultiAgentEpisode = active_episodes[env_id] + batch_builder = episode.batch_builder if not is_new_episode: episode.length += 1 - episode.batch_builder.count += 1 + batch_builder.count += 1 episode._add_agent_rewards(rewards[env_id]) - if (episode.batch_builder.total() > large_batch_threshold + if (batch_builder.total() > large_batch_threshold and log_once("large_batch_warning")): logger.warning( "More than {} observations for {} env steps ".format( - episode.batch_builder.total(), - episode.batch_builder.count) + "are buffered in " - "the sampler. If this is more than you expected, check " - "that you set a horizon on your environment correctly and " - "that it terminates at some point. " + batch_builder.total(), batch_builder.count) + + "are buffered in " + "the sampler. If this is more than you expected, check that " + "that you set a horizon on your environment correctly and that" + " it terminates at some point. " "Note: In multi-agent environments, `rollout_fragment_length` " "sets the batch size based on environment steps, not the " "steps of " @@ -725,12 +794,12 @@ def _process_observations( agent_done = bool(all_agents_done or dones[env_id].get(agent_id)) if not agent_done: - to_eval[policy_id].append( - PolicyEvalData(env_id, agent_id, filtered_obs, - infos[env_id].get(agent_id, {}), - episode.rnn_state_for(agent_id), - episode.last_action_for(agent_id), - rewards[env_id][agent_id] or 0.0)) + item = PolicyEvalData(env_id, agent_id, filtered_obs, + infos[env_id].get(agent_id, {}), + episode.rnn_state_for(agent_id), + episode.last_action_for(agent_id), + rewards[env_id][agent_id] or 0.0) + to_eval[policy_id].append(item) last_observation: EnvObsType = episode.last_observation_for( agent_id) @@ -741,7 +810,7 @@ def _process_observations( # Record transition info if applicable. 
if (last_observation is not None and infos[env_id].get( agent_id, {}).get("training_enabled", True)): - episode.batch_builder.add_values( + batch_builder.add_values( agent_id, policy_id, t=episode.length - 1, @@ -767,26 +836,26 @@ def _process_observations( # - all-agents-done and not packing multiple episodes into one # (batch_mode="complete_episodes") # - or if we've exceeded the rollout_fragment_length. - if episode.batch_builder.has_pending_agent_data(): + if batch_builder.has_pending_agent_data(): # Sanity check, whether all agents have done=True, if done[__all__] # is True. if dones[env_id]["__all__"] and not no_done_at_end: - episode.batch_builder.check_missing_dones() + batch_builder.check_missing_dones() - # Reached end of episode and we are not allowed to pack the - # next episode into the same SampleBatch -> Build the SampleBatch - # and add it to "outputs". - if (all_agents_done and not pack_multiple_episodes_in_batch) or \ - episode.batch_builder.count >= rollout_fragment_length: - outputs.append(episode.batch_builder.build_and_reset(episode)) - # Make sure postprocessor stays within one episode. - elif all_agents_done: - episode.batch_builder.postprocess_batch_so_far(episode) + # Reached end of episode and we are not allowed to pack the + # next episode into the same SampleBatch -> Build the SampleBatch + # and add it to "outputs". + if (all_agents_done and not multiple_episodes_in_batch) or \ + batch_builder.count >= rollout_fragment_length: + outputs.append(batch_builder.build_and_reset(episode)) + # Make sure postprocessor stays within one episode. + elif all_agents_done: + batch_builder.postprocess_batch_so_far(episode) # Episode is done. if all_agents_done: - # Handle episode termination. - batch_builder_pool.append(episode.batch_builder) + # We can pass the BatchBuilder to recycling. + batch_builder_pool.append(batch_builder) # Call each policy's Exploration.on_episode_end method. 
for p in policies.values(): if getattr(p, "exploration", None) is not None: @@ -834,14 +903,262 @@ def _process_observations( filtered_obs: EnvObsType = _get_or_raise( obs_filters, policy_id)(prep_obs) episode._set_last_observation(agent_id, filtered_obs) - to_eval[policy_id].append( - PolicyEvalData( - env_id, agent_id, filtered_obs, - episode.last_info_for(agent_id) or {}, - episode.rnn_state_for(agent_id), - np.zeros_like( - flatten_to_single_ndarray( - policy.action_space.sample())), 0.0)) + + item = PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.last_info_for(agent_id) or {}, + episode.rnn_state_for(agent_id), + np.zeros_like( + flatten_to_single_ndarray( + policy.action_space.sample())), 0.0) + to_eval[policy_id].append(item) + + return active_envs, to_eval, outputs + + +def _process_observations_w_trajectory_view_api( + *, + worker: "RolloutWorker", + base_env: BaseEnv, + policies: Dict[PolicyID, Policy], + active_episodes: Dict[str, MultiAgentEpisode], + prev_policy_outputs: Dict[PolicyID, Tuple[TensorStructType, StateBatch, + dict]], + unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]], + rewards: Dict[EnvID, Dict[AgentID, float]], + dones: Dict[EnvID, Dict[AgentID, bool]], + infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]], + horizon: int, + preprocessors: Dict[PolicyID, Preprocessor], + obs_filters: Dict[PolicyID, Filter], + rollout_fragment_length: int, + multiple_episodes_in_batch: bool, + callbacks: "DefaultCallbacks", + soft_horizon: bool, + no_done_at_end: bool, + observation_fn: "ObservationFunction", + perf_stats: _PerfStats, + _sample_collector: _SampleCollector, +) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[ + RolloutMetrics, SampleBatchType]]]: + """Trajectory View API version of `_process_observations()`. + TODO: (sven) Move docstring here once original function is deprecated. + """ + + # Output objects. 
+ active_envs: Set[EnvID] = set() + to_eval: Set[PolicyID] = set() + outputs: List[Union[RolloutMetrics, SampleBatchType]] = [] + + large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \ + rollout_fragment_length != float("inf") else 5000 + + # For each environment. + # type: EnvID, Dict[AgentID, EnvObsType] + for env_id, agent_obs in unfiltered_obs.items(): + is_new_episode: bool = env_id not in active_episodes + episode: MultiAgentEpisode = active_episodes[env_id] + + if not is_new_episode: + episode.length += 1 + _sample_collector.count += 1 + episode._add_agent_rewards(rewards[env_id]) + + if (_sample_collector.total_env_steps() > large_batch_threshold + and log_once("large_batch_warning")): + logger.warning( + "More than {} observations for {} env steps ".format( + _sample_collector.total_env_steps(), + _sample_collector.count) + "are buffered in " + "the sampler. If this is more than you expected, check that " + "that you set a horizon on your environment correctly and that" + " it terminates at some point. " + "Note: In multi-agent environments, `rollout_fragment_length` " + "sets the batch size based on environment steps, not the " + "steps of " + "individual agents, which can result in unexpectedly large " + "batches. Also, you may be in evaluation waiting for your Env " + "to terminate (batch_mode=`complete_episodes`). Make sure it " + "does at some point.") + + # Check episode termination conditions. 
+ if dones[env_id]["__all__"] or episode.length >= horizon: + hit_horizon = (episode.length >= horizon + and not dones[env_id]["__all__"]) + all_agents_done = True + atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics( + base_env) + if atari_metrics is not None: + for m in atari_metrics: + outputs.append( + m._replace(custom_metrics=episode.custom_metrics)) + else: + outputs.append( + RolloutMetrics(episode.length, episode.total_reward, + dict(episode.agent_rewards), + episode.custom_metrics, {}, + episode.hist_data)) + else: + hit_horizon = False + all_agents_done = False + active_envs.add(env_id) + + # Custom observation function is applied before preprocessing. + if observation_fn: + agent_obs: Dict[AgentID, EnvObsType] = observation_fn( + agent_obs=agent_obs, + worker=worker, + base_env=base_env, + policies=policies, + episode=episode) + if not isinstance(agent_obs, dict): + raise ValueError( + "observe() must return a dict of agent observations") + + # For each agent in the environment. + # type: AgentID, EnvObsType + for agent_id, raw_obs in agent_obs.items(): + assert agent_id != "__all__" + policy_id: PolicyID = episode.policy_for(agent_id) + prep_obs: EnvObsType = _get_or_raise(preprocessors, + policy_id).transform(raw_obs) + if log_once("prep_obs"): + logger.info("Preprocessed obs: {}".format(summarize(prep_obs))) + + filtered_obs: EnvObsType = _get_or_raise(obs_filters, + policy_id)(prep_obs) + if log_once("filtered_obs"): + logger.info("Filtered obs: {}".format(summarize(filtered_obs))) + + agent_done = bool(all_agents_done or dones[env_id].get(agent_id)) + + last_observation: EnvObsType = episode.last_observation_for( + agent_id) + episode._set_last_observation(agent_id, filtered_obs) + episode._set_last_raw_obs(agent_id, raw_obs) + episode._set_last_info(agent_id, infos[env_id].get(agent_id, {})) + + # Record transition info if applicable. 
+ if last_observation is None: + _sample_collector.add_init_obs(episode.episode_id, agent_id, + env_id, policy_id, filtered_obs) + else: + rc = _sample_collector.policy_sample_collectors[policy_id] + eval_idx = rc.agent_key_to_forward_pass_index[( + agent_id, episode.episode_id)] + values_dict = { + "t": episode.length - 1, + "eps_id": episode.episode_id, + "agent_index": episode._agent_index(agent_id), + # Action (slot 0) taken at timestep t. + "actions": prev_policy_outputs[policy_id][0][eval_idx], + # Reward received after taking a at timestep t. + "rewards": rewards[env_id][agent_id], + # After taking a, did we reach terminal? + "dones": (False if (no_done_at_end + or (hit_horizon and soft_horizon)) else + agent_done), + # Next observation. + "new_obs": filtered_obs, + } + # TODO: (sven) add env infos to buffers as well. + for k, v in prev_policy_outputs[policy_id][2].items(): + values_dict[k] = v[eval_idx] + for i, v in enumerate(prev_policy_outputs[policy_id][1]): + values_dict["state_out_{}".format(i)] = v[eval_idx] + _sample_collector.add_action_reward_next_obs( + episode.episode_id, agent_id, env_id, policy_id, + agent_done, values_dict) + + if not agent_done: + to_eval.add(policy_id) + + # Invoke the step callback after the step is logged to the episode + callbacks.on_episode_step( + worker=worker, base_env=base_env, episode=episode) + + # Cut the batch if ... + # - all-agents-done and not packing multiple episodes into one + # (batch_mode="complete_episodes") + # - or if we've exceeded the rollout_fragment_length. + if _sample_collector.has_non_postprocessed_data(): + # Sanity check, whether all agents have done=True, if done[__all__] + # is True. + if dones[env_id]["__all__"] and not no_done_at_end: + _sample_collector.check_missing_dones( + episode_id=episode.episode_id) + + # Reached end of episode and we are not allowed to pack the + # next episode into the same SampleBatch -> Build the SampleBatch + # and add it to "outputs". 
+ if (all_agents_done and not multiple_episodes_in_batch) or \ + _sample_collector.count >= rollout_fragment_length: + # TODO: (sven) Case: rollout_fragment_length reached: Do not + # store any data in `episode` anymore + # (useless for get_view_requirements when t<<-1, e.g. + # attention), but keep last episode data around in + # SampleBatchBuilder + # to be able to still reference into it + # should a model require this. + outputs.append(_sample_collector.get_multi_agent_batch_and_reset()) + # Make sure postprocessor stays within one episode. + elif all_agents_done: + _sample_collector.postprocess_trajectories_so_far(episode) + + # Episode is done. + if all_agents_done: + # Call each policy's Exploration.on_episode_end method. + for p in policies.values(): + if getattr(p, "exploration", None) is not None: + p.exploration.on_episode_end( + policy=p, + environment=base_env, + episode=episode, + tf_sess=getattr(p, "_sess", None)) + # Call custom on_episode_end callback. + callbacks.on_episode_end( + worker=worker, + base_env=base_env, + policies=policies, + episode=episode) + if hit_horizon and soft_horizon: + episode.soft_reset() + resetted_obs: Dict[AgentID, EnvObsType] = agent_obs + else: + del active_episodes[env_id] + resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset( + env_id) + if resetted_obs is None: + # Reset not supported, drop this env from the ready list. + if horizon != float("inf"): + raise ValueError( + "Setting episode horizon requires reset() support " + "from the environment.") + elif resetted_obs != ASYNC_RESET_RETURN: + # Creates a new episode if this is not async return. 
+ # If reset is async, we will get its result in some future poll + episode: MultiAgentEpisode = active_episodes[env_id] + if observation_fn: + resetted_obs: Dict[AgentID, EnvObsType] = observation_fn( + agent_obs=resetted_obs, + worker=worker, + base_env=base_env, + policies=policies, + episode=episode) + # type: AgentID, EnvObsType + for agent_id, raw_obs in resetted_obs.items(): + policy_id: PolicyID = episode.policy_for(agent_id) + prep_obs: EnvObsType = _get_or_raise( + preprocessors, policy_id).transform(raw_obs) + filtered_obs: EnvObsType = _get_or_raise( + obs_filters, policy_id)(prep_obs) + episode._set_last_observation(agent_id, filtered_obs) + + # Add initial obs to buffer. + _sample_collector.add_init_obs(episode.episode_id, + agent_id, env_id, policy_id, + filtered_obs) + to_eval.add(policy_id) return active_envs, to_eval, outputs @@ -852,7 +1169,6 @@ def _do_policy_eval( policies: Dict[PolicyID, Policy], active_episodes: Dict[str, MultiAgentEpisode], tf_sess=None, - _use_trajectory_view_api=False ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]: """Call compute_actions on collected episode/model data to get next action. @@ -866,9 +1182,6 @@ def _do_policy_eval( episode ID to currently ongoing MultiAgentEpisode object. tf_sess (Optional[tf.Session]): Optional tensorflow session to use for batching TF policy evaluations. - _use_trajectory_view_api (bool): Whether to use the (experimental) - `_use_trajectory_view_api` procedure to collect samples. - Default: False. Returns: eval_results: dict of policy to compute_action() outputs. @@ -888,15 +1201,15 @@ def _do_policy_eval( # type: PolicyID, PolicyEvalData for policy_id, eval_data in to_eval.items(): - rnn_in: List[List[Any]] = [t.rnn_state for t in eval_data] policy: Policy = _get_or_raise(policies, policy_id) - # If tf (non eager) AND TFPolicy's compute_action method has not been - # overridden -> Use `policy._build_compute_actions()`. 
+ # If tf (non eager) AND TFPolicy's compute_action method has not + # been overridden -> Use `policy._build_compute_actions()`. if builder and (policy.compute_actions.__code__ is TFPolicy.compute_actions.__code__): obs_batch: List[EnvObsType] = [t.obs for t in eval_data] - state_batches: StateBatch = _to_column_format(rnn_in) + state_batches: StateBatch = _to_column_format( + [t.rnn_state for t in eval_data]) # TODO(ekl): how can we make info batch available to TF code? prev_action_batch = [t.prev_action for t in eval_data] prev_reward_batch = [t.prev_reward for t in eval_data] @@ -909,6 +1222,7 @@ def _do_policy_eval( prev_reward_batch=prev_reward_batch, timestep=policy.global_timestep) else: + rnn_in = [t.rnn_state for t in eval_data] rnn_in_cols: StateBatch = [ np.stack([row[i] for row in rnn_in]) for i in range(len(rnn_in[0])) @@ -921,6 +1235,61 @@ def _do_policy_eval( info_batch=[t.info for t in eval_data], episodes=[active_episodes[t.env_id] for t in eval_data], timestep=policy.global_timestep) + + if builder: + # type: PolicyID, Tuple[TensorStructType, StateBatch, dict] + for pid, v in pending_fetches.items(): + eval_results[pid] = builder.get(v) + + if log_once("compute_actions_result"): + logger.info("Outputs of compute_actions():\n\n{}\n".format( + summarize(eval_results))) + + return eval_results + + +def _do_policy_eval_w_trajectory_view_api( + *, + to_eval: Dict[PolicyID, List[PolicyEvalData]], + policies: Dict[PolicyID, Policy], + _sample_collector, + tf_sess=None, +) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]: + """Call compute_actions on collected episode/model data to get next action. + + Args: + to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy + IDs to lists of PolicyEvalData objects (items in these lists will + be the batch's items for the model forward pass). + policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy + obj. + _sample_collector (SampleCollector): The SampleCollector object to use. 
+ tf_sess (Optional[tf.Session]): Optional tensorflow session to use for + batching TF policy evaluations. + + Returns: + eval_results: dict of policy to compute_action() outputs. + """ + + eval_results: Dict[PolicyID, TensorStructType] = {} + + if tf_sess: + builder = TFRunBuilder(tf_sess, "policy_eval") + pending_fetches: Dict[PolicyID, Any] = {} + else: + builder = None + + if log_once("compute_actions_input"): + logger.info("Inputs to compute_actions():\n\n{}\n".format( + summarize(to_eval))) + + for policy_id in to_eval: + policy: Policy = _get_or_raise(policies, policy_id) + input_dict = _sample_collector.get_inference_input_dict(policy_id) + eval_results[policy_id] = \ + policy.compute_actions_from_input_dict( + input_dict, timestep=policy.global_timestep) + if builder: # type: PolicyID, Tuple[TensorStructType, StateBatch, dict] for pid, v in pending_fetches.items(): @@ -943,7 +1312,8 @@ def _process_policy_eval_results( off_policy_actions: MultiEnvDict, policies: Dict[PolicyID, Policy], clip_actions: bool, - _use_trajectory_view_api: bool = False + _use_trajectory_view_api: bool = False, + _sample_collector=None, ) -> Dict[EnvID, Dict[AgentID, EnvActionType]]: """Process the output of policy neural network evaluation. 
@@ -980,11 +1350,10 @@ def _process_policy_eval_results( actions_to_send[env_id] = {} # at minimum send empty dict # type: PolicyID, List[PolicyEvalData] - for policy_id, eval_data in to_eval.items(): - rnn_in_cols: StateBatch = _to_column_format( - [t.rnn_state for t in eval_data]) - + for policy_id in to_eval: actions: TensorStructType = eval_results[policy_id][0] + actions = convert_to_numpy(actions) + rnn_out_cols: StateBatch = eval_results[policy_id][1] pi_info_cols: dict = eval_results[policy_id][2] @@ -993,40 +1362,58 @@ def _process_policy_eval_results( if isinstance(actions, list): actions = np.array(actions) - if len(rnn_in_cols) != len(rnn_out_cols): - raise ValueError("Length of RNN in did not match RNN out, got: " - "{} vs {}".format(rnn_in_cols, rnn_out_cols)) - # Add RNN state info - for f_i, column in enumerate(rnn_in_cols): - pi_info_cols["state_in_{}".format(f_i)] = column - for f_i, column in enumerate(rnn_out_cols): - pi_info_cols["state_out_{}".format(f_i)] = column + # Add RNN state info. + eval_data = None + if not _use_trajectory_view_api: + eval_data = to_eval[policy_id] + rnn_in_cols: StateBatch = _to_column_format( + [t.rnn_state for t in eval_data]) + + if len(rnn_in_cols) != len(rnn_out_cols): + raise ValueError( + "Length of RNN in did not match RNN out, got: " + "{} vs {}".format(rnn_in_cols, rnn_out_cols)) + for f_i, column in enumerate(rnn_in_cols): + pi_info_cols["state_in_{}".format(f_i)] = column + for f_i, column in enumerate(rnn_out_cols): + pi_info_cols["state_out_{}".format(f_i)] = column policy: Policy = _get_or_raise(policies, policy_id) # Split action-component batches into single action rows. actions: List[EnvActionType] = unbatch(actions) # type: int, EnvActionType for i, action in enumerate(actions): - env_id: int = eval_data[i].env_id - agent_id: AgentID = eval_data[i].agent_id # Clip if necessary. 
if clip_actions: clipped_action = clip_action(action, policy.action_space_struct) else: clipped_action = action - actions_to_send[env_id][agent_id] = clipped_action - episode: MultiAgentEpisode = active_episodes[env_id] - episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) - episode._set_last_pi_info( - agent_id, {k: v[i] - for k, v in pi_info_cols.items()}) - if env_id in off_policy_actions and \ - agent_id in off_policy_actions[env_id]: - episode._set_last_action(agent_id, - off_policy_actions[env_id][agent_id]) + + # Trajectory View API: Do not store data directly in episode + # (entire episode is stored in Trajectory and kept until + # end of episode). + if _use_trajectory_view_api: + agent_id, episode_id, env_id = \ + _sample_collector.policy_sample_collectors[ + policy_id].forward_pass_index_to_agent_info[i] else: - episode._set_last_action(agent_id, action) + env_id: int = eval_data[i].env_id + agent_id: AgentID = eval_data[i].agent_id + episode: MultiAgentEpisode = active_episodes[env_id] + episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) + episode._set_last_pi_info( + agent_id, {k: v[i] + for k, v in pi_info_cols.items()}) + if env_id in off_policy_actions and \ + agent_id in off_policy_actions[env_id]: + episode._set_last_action( + agent_id, off_policy_actions[env_id][agent_id]) + else: + episode._set_last_action(agent_id, action) + + assert agent_id not in actions_to_send[env_id] + actions_to_send[env_id][agent_id] = clipped_action return actions_to_send @@ -1054,20 +1441,21 @@ def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch: return [[row[i] for row in rnn_state_rows] for i in range(num_cols)] -def _get_or_raise(mapping: Dict[PolicyID, Policy], - policy_id: PolicyID) -> Policy: - """Returns a Policy object under key `policy_id` in `mapping`. 
+def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], + policy_id: PolicyID) -> Union[Policy, Preprocessor, Filter]: + """Returns an object under key `policy_id` in `mapping`. Args: - mapping (dict): The mapping dict from policy id (str) to - actual Policy object. + mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The + mapping dict from policy id (str) to actual object (Policy, + Preprocessor, etc.). policy_id (str): The policy ID to lookup. Returns: - Policy: The found Policy object. + Union[Policy, Preprocessor, Filter]: The found object. Throws: - ValueError: If `policy_id` cannot be found. + ValueError: If `policy_id` cannot be found in `mapping`. """ if policy_id not in mapping: raise ValueError( diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py new file mode 100644 index 000000000..4a1822da3 --- /dev/null +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -0,0 +1,277 @@ +import copy +from gym.spaces import Box, Discrete +import time +import unittest + +import ray +import ray.rllib.agents.ppo as ppo +from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.examples.policy.episode_env_aware_policy import \ + EpisodeEnvAwarePolicy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.test_utils import framework_iterator + + +class TestTrajectoryViewAPI(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_traj_view_normal_case(self): + """Tests, whether Model and Policy return the correct ViewRequirements. 
+ """ + config = ppo.DEFAULT_CONFIG.copy() + for _ in framework_iterator(config, frameworks="torch"): + trainer = ppo.PPOTrainer(config, env="CartPole-v0") + policy = trainer.get_policy() + view_req_model = policy.model.inference_view_requirements + view_req_policy = policy.training_view_requirements + assert len(view_req_model) == 1 + assert len(view_req_policy) == 10 + for key in [ + SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, + SampleBatch.DONES, SampleBatch.NEXT_OBS, + SampleBatch.VF_PREDS, "advantages", "value_targets", + SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP + ]: + assert key in view_req_policy + # None of the view cols has a special underlying data_col, + # except next-obs. + if key != SampleBatch.NEXT_OBS: + assert view_req_policy[key].data_col is None + else: + assert view_req_policy[key].data_col == SampleBatch.OBS + assert view_req_policy[key].shift == 1 + trainer.stop() + + def test_traj_view_lstm_prev_actions_and_rewards(self): + """Tests, whether Policy/Model return correct LSTM ViewRequirements. + """ + config = ppo.DEFAULT_CONFIG.copy() + config["model"] = config["model"].copy() + # Activate LSTM + prev-action + rewards. 
+ config["model"]["use_lstm"] = True + config["model"]["lstm_use_prev_action_reward"] = True + + for _ in framework_iterator(config, frameworks="torch"): + trainer = ppo.PPOTrainer(config, env="CartPole-v0") + policy = trainer.get_policy() + view_req_model = policy.model.inference_view_requirements + view_req_policy = policy.training_view_requirements + assert len(view_req_model) == 7 # obs, prev_a, prev_r, 4xstates + assert len(view_req_policy) == 16 + for key in [ + SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, + SampleBatch.DONES, SampleBatch.NEXT_OBS, + SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS, + SampleBatch.PREV_REWARDS, "advantages", "value_targets", + SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP + ]: + assert key in view_req_policy + + if key == SampleBatch.PREV_ACTIONS: + assert view_req_policy[key].data_col == SampleBatch.ACTIONS + assert view_req_policy[key].shift == -1 + elif key == SampleBatch.PREV_REWARDS: + assert view_req_policy[key].data_col == SampleBatch.REWARDS + assert view_req_policy[key].shift == -1 + elif key not in [ + SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS, + SampleBatch.PREV_REWARDS + ]: + assert view_req_policy[key].data_col is None + else: + assert view_req_policy[key].data_col == SampleBatch.OBS + assert view_req_policy[key].shift == 1 + trainer.stop() + + def test_traj_view_lstm_performance(self): + """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`. 
+ """ + config = copy.deepcopy(ppo.DEFAULT_CONFIG) + action_space = Discrete(2) + obs_space = Box(-1.0, 1.0, shape=(700, )) + + from ray.rllib.examples.env.random_env import RandomMultiAgentEnv + + from ray.tune import register_env + register_env("ma_env", lambda c: RandomMultiAgentEnv({ + "num_agents": 2, + "p_done": 0.01, + "action_space": action_space, + "observation_space": obs_space + })) + + config["num_workers"] = 3 + config["num_envs_per_worker"] = 8 + config["num_sgd_iter"] = 6 + config["model"]["use_lstm"] = True + config["model"]["lstm_use_prev_action_reward"] = True + config["model"]["max_seq_len"] = 100 + + policies = { + "pol0": (None, obs_space, action_space, {}), + } + + def policy_fn(agent_id): + return "pol0" + + config["multiagent"] = { + "policies": policies, + "policy_mapping_fn": policy_fn, + } + num_iterations = 1 + # Only works in torch so far. + for _ in framework_iterator(config, frameworks="torch"): + print("w/ traj. view API (and time-major)") + config["_use_trajectory_view_api"] = True + config["model"]["_time_major"] = True + trainer = ppo.PPOTrainer(config=config, env="ma_env") + learn_time_w = 0.0 + sampler_perf = {} + start = time.time() + for i in range(num_iterations): + out = trainer.train() + sampler_perf_ = out["sampler_perf"] + sampler_perf = { + k: sampler_perf.get(k, 0.0) + sampler_perf_[k] + for k, v in sampler_perf_.items() + } + delta = out["timers"]["learn_time_ms"] / 1000 + learn_time_w += delta + print("{}={}s".format(i, delta)) + sampler_perf = { + k: sampler_perf[k] / (num_iterations if "mean_" in k else 1) + for k, v in sampler_perf.items() + } + duration_w = time.time() - start + print("Duration: {}s " + "sampler-perf.={} learn-time/iter={}s".format( + duration_w, sampler_perf, learn_time_w / num_iterations)) + trainer.stop() + + print("w/o traj. 
view API (and w/o time-major)") + config["_use_trajectory_view_api"] = False + config["model"]["_time_major"] = False + trainer = ppo.PPOTrainer(config=config, env="ma_env") + learn_time_wo = 0.0 + sampler_perf = {} + start = time.time() + for i in range(num_iterations): + out = trainer.train() + sampler_perf_ = out["sampler_perf"] + sampler_perf = { + k: sampler_perf.get(k, 0.0) + sampler_perf_[k] + for k, v in sampler_perf_.items() + } + delta = out["timers"]["learn_time_ms"] / 1000 + learn_time_wo += delta + print("{}={}s".format(i, delta)) + sampler_perf = { + k: sampler_perf[k] / (num_iterations if "mean_" in k else 1) + for k, v in sampler_perf.items() + } + duration_wo = time.time() - start + print("Duration: {}s " + "sampler-perf.={} learn-time/iter={}s".format( + duration_wo, sampler_perf, + learn_time_wo / num_iterations)) + trainer.stop() + + # Assert `_use_trajectory_view_api` is much faster. + self.assertLess(duration_w, duration_wo) + self.assertLess(learn_time_w, learn_time_wo * 0.6) + + def test_traj_view_lstm_functionality(self): + action_space = Box(-float("inf"), float("inf"), shape=(2, )) + obs_space = Box(float("-inf"), float("inf"), (4, )) + max_seq_len = 50 + policies = { + "pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}), + } + + def policy_fn(agent_id): + return "pol0" + + rollout_worker = RolloutWorker( + env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), + policy_config={ + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_fn, + }, + "_use_trajectory_view_api": True, + "model": { + "use_lstm": True, + "_time_major": True, + "max_seq_len": max_seq_len, + }, + }, + policy=policies, + policy_mapping_fn=policy_fn, + num_envs=1, + ) + for i in range(100): + pc = rollout_worker.sampler.sample_collector. 
\ + policy_sample_collectors["pol0"] + sample_batch_offset_before = pc.sample_batch_offset + buffers = pc.buffers + result = rollout_worker.sample() + pol_batch = result.policy_batches["pol0"] + + self.assertTrue(result.count == 100) + self.assertTrue(pol_batch.count >= 100) + self.assertFalse(0 in pol_batch.seq_lens) + # Check prev_reward/action, next_obs consistency. + for t in range(max_seq_len): + obs_t = pol_batch["obs"][t] + r_t = pol_batch["rewards"][t] + if t > 0: + next_obs_t_m_1 = pol_batch["new_obs"][t - 1] + self.assertTrue((obs_t == next_obs_t_m_1).all()) + if t < max_seq_len - 1: + prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1] + self.assertTrue((r_t == prev_rewards_t_p_1).all()) + + # Check the sanity of all the buffers in the underlying + # PerPolicy collector. + for sample_batch_slot, agent_slot in enumerate( + range(sample_batch_offset_before, pc.sample_batch_offset)): + t_buf = buffers["t"][:, agent_slot] + obs_buf = buffers["obs"][:, agent_slot] + # Skip empty seqs at end (these won't be part of the batch + # and have been copied to new agent-slots (even if seq-len=0)). + if sample_batch_slot < len(pol_batch.seq_lens): + seq_len = pol_batch.seq_lens[sample_batch_slot] + # Make sure timesteps are always increasing within the seq. + assert all(t_buf[1] + j == n + 1 + for j, n in enumerate(t_buf) + if j < seq_len and j != 0) + # Make sure all obs within seq are non-0.0. + assert all( + any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1)) + + # Check seq-lens. + for agent_slot, seq_len in enumerate(pol_batch.seq_lens): + if seq_len < max_seq_len - 1: + # At least in the beginning, the next slots should always + # be empty (once all agent slots have been used once, these + # may be filled with "old" values (from longer sequences)). 
+ if i < 10: + self.assertTrue( + (pol_batch["obs"][seq_len + + 1][agent_slot] == 0.0).all()) + print(end="") + self.assertFalse( + (pol_batch["obs"][seq_len][agent_slot] == 0.0).all()) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/examples/env/debug_counter_env.py b/rllib/examples/env/debug_counter_env.py index 0fc3ad972..0d2adcba8 100644 --- a/rllib/examples/env/debug_counter_env.py +++ b/rllib/examples/env/debug_counter_env.py @@ -1,4 +1,7 @@ import gym +import numpy as np + +from ray.rllib.env.multi_agent_env import MultiAgentEnv class DebugCounterEnv(gym.Env): @@ -21,3 +24,43 @@ class DebugCounterEnv(gym.Env): def step(self, action): self.i += 1 return [self.i], self.i % 3, self.i >= 15, {} + + +class MultiAgentDebugCounterEnv(MultiAgentEnv): + def __init__(self, config): + self.num_agents = config["num_agents"] + self.p_done = config.get("p_done", 0.02) + # Actions are always: + # (episodeID, envID) as floats. + self.action_space = \ + gym.spaces.Box(-float("inf"), float("inf"), shape=(2, )) + # Observation dims: + # 0=agent ID. + # 1=episode ID (0.0 for obs after reset). + # 2=env ID (0.0 for obs after reset). + # 3=ts (of the agent). 
+ self.observation_space = \ + gym.spaces.Box(float("-inf"), float("inf"), (4, )) + self.timesteps = [0] * self.num_agents + self.dones = set() + + def reset(self): + self.dones = set() + return { + i: np.array([i, 0.0, 0.0, 0.0], dtype=np.float32) + for i in range(self.num_agents) + } + + def step(self, action_dict): + obs, rew, done = {}, {}, {} + for i, action in action_dict.items(): + self.timesteps[i] += 1 + obs[i] = np.array([i, action[0], action[1], self.timesteps[i]]) + rew[i] = self.timesteps[i] % 3 + done[i] = bool( + np.random.choice( + [True, False], p=[self.p_done, 1.0 - self.p_done])) + if done[i]: + self.dones.add(i) + done["__all__"] = len(self.dones) == self.num_agents + return obs, rew, done, {} diff --git a/rllib/examples/policy/episode_env_aware_policy.py b/rllib/examples/policy/episode_env_aware_policy.py new file mode 100644 index 000000000..59018b856 --- /dev/null +++ b/rllib/examples/policy/episode_env_aware_policy.py @@ -0,0 +1,66 @@ +import numpy as np + +from ray.rllib.examples.policy.random_policy import RandomPolicy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override + + +class EpisodeEnvAwarePolicy(RandomPolicy): + """A Policy that always knows the current EpisodeID and EnvID and + returns these in its actions.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.episode_id = None + self.env_id = None + + class _fake_model: + pass + + self.model = _fake_model() + self.model.time_major = True + self.model.inference_view_requirements = { + SampleBatch.EPS_ID: ViewRequirement(), + "env_id": ViewRequirement(), + SampleBatch.OBS: ViewRequirement(), + SampleBatch.PREV_ACTIONS: ViewRequirement( + SampleBatch.ACTIONS, space=self.action_space, shift=-1), + SampleBatch.PREV_REWARDS: ViewRequirement( + SampleBatch.REWARDS, shift=-1), + } + 
self.training_view_requirements = dict( + **{ + SampleBatch.NEXT_OBS: ViewRequirement( + SampleBatch.OBS, shift=1), + SampleBatch.ACTIONS: ViewRequirement(space=self.action_space), + SampleBatch.REWARDS: ViewRequirement(), + SampleBatch.DONES: ViewRequirement(), + }, + **self.model.inference_view_requirements) + + @override(Policy) + def is_recurrent(self): + return True + + @override(Policy) + def compute_actions_from_input_dict(self, + input_dict, + explore=None, + timestep=None, + **kwargs): + self.episode_id = input_dict[SampleBatch.EPS_ID][0] + self.env_id = input_dict["env_id"][0] + # Always return (episodeID, envID) + return [ + np.array([self.episode_id, self.env_id]) for _ in input_dict["obs"] + ], [], {} + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + sample_batch["postprocessed_column"] = sample_batch["obs"] + 1.0 + return sample_batch diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 37df032a0..cc2bfa6f4 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -65,6 +65,9 @@ MODEL_DEFAULTS: ModelConfigDict = { "lstm_cell_size": 256, # Whether to feed a_{t-1}, r_{t-1} to LSTM. "lstm_use_prev_action_reward": False, + # Experimental (only works with `_use_trajectory_view_api`=True): + # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..). + "_time_major": False, # When using modelv1 models with a modelv2 algorithm, you may have to # define the state shape here (e.g., [256, 256]). 
"state_shape": None, diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 5ead72459..15dbfb8c1 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -58,6 +58,10 @@ class ModelV2: self.name: str = name or "default_model" self.framework: str = framework self._last_output = None + self.time_major = self.model_config.get("_time_major") + self.inference_view_requirements = { + SampleBatch.OBS: ViewRequirement(shift=0), + } @PublicAPI def get_initial_state(self) -> List[np.ndarray]: @@ -246,26 +250,6 @@ class ModelV2: i += 1 return self.__call__(input_dict, states, train_batch.get("seq_lens")) - def inference_view_requirements(self) -> Dict[str, ViewRequirement]: - """Returns a dict of ViewRequirements for this Model. - - Note: This is an experimental API method. - - The view requirements dict is used to generate input_dicts and - train batches for 1) action computations, 2) postprocessing, and 3) - generating training batches. - - Returns: - Dict[str, ViewRequirement]: The view requirements dict, mapping - each view key (which will be available in input_dicts) to - an underlying requirement (actual data, timestep shift, etc..). - """ - # Default implementation for simple RL model: - # Single requirement: Pass current obs as input. - return { - SampleBatch.OBS: ViewRequirement(shift=0), - } - def import_from_h5(self, h5_file: str) -> None: """Imports weights from an h5 file. @@ -322,6 +306,16 @@ class ModelV2: """ raise NotImplementedError + @PublicAPI + def is_time_major(self) -> bool: + """If True, data for calling this ModelV2 must be in time-major format. + + Returns + bool: Whether this ModelV2 requires a time-major (TxBx...) data + format. 
+ """ + return self.time_major is True + class NullContextManager: """No-op context manager""" diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index 355213800..d6c9bfab8 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -54,10 +54,11 @@ class RecurrentNetwork(TFModelV2): You should implement forward_rnn() in your subclass.""" assert seq_lens is not None - + padded_inputs = input_dict["obs_flat"] + max_seq_len = tf.shape(padded_inputs)[0] // tf.shape(seq_lens)[0] output, new_state = self.forward_rnn( add_time_dimension( - input_dict["obs_flat"], seq_lens, framework="tf"), state, + padded_inputs, max_seq_len=max_seq_len, framework="tf"), state, seq_lens) return tf.reshape(output, [-1, self.num_outputs]), new_state diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py index 1684d11c6..3e37d60bf 100644 --- a/rllib/models/torch/recurrent_net.py +++ b/rllib/models/torch/recurrent_net.py @@ -1,5 +1,5 @@ +from gym.spaces import Box import numpy as np -from typing import Dict from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.misc import SlimFC @@ -63,13 +63,20 @@ class RecurrentNetwork(TorchModelV2): """Adds time dimension to batch before sending inputs to forward_rnn(). 
You should implement forward_rnn() in your subclass.""" + flat_inputs = input_dict["obs_flat"].float() if isinstance(seq_lens, np.ndarray): seq_lens = torch.Tensor(seq_lens).int() - output, new_state = self.forward_rnn( - add_time_dimension( - input_dict["obs_flat"].float(), seq_lens, framework="torch"), - state, seq_lens) - return torch.reshape(output, [-1, self.num_outputs]), new_state + max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0] + self.time_major = self.model_config.get("_time_major", False) + inputs = add_time_dimension( + flat_inputs, + max_seq_len=max_seq_len, + framework="torch", + time_major=self.time_major, + ) + output, new_state = self.forward_rnn(inputs, state, seq_lens) + output = torch.reshape(output, [-1, self.num_outputs]) + return output, new_state def forward_rnn(self, inputs, state, seq_lens): """Call the model with the given input tensors and state. @@ -104,13 +111,15 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): super().__init__(obs_space, action_space, None, model_config, name) self.cell_size = model_config["lstm_cell_size"] + self.time_major = model_config.get("_time_major", False) self.use_prev_action_reward = model_config[ "lstm_use_prev_action_reward"] self.action_dim = int(np.product(action_space.shape)) # Add prev-action/reward nodes to input to LSTM. 
if self.use_prev_action_reward: self.num_outputs += 1 + self.action_dim - self.lstm = nn.LSTM(self.num_outputs, self.cell_size, batch_first=True) + self.lstm = nn.LSTM( + self.num_outputs, self.cell_size, batch_first=not self.time_major) self.num_outputs = num_outputs @@ -126,6 +135,26 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): activation_fn=None, initializer=torch.nn.init.xavier_uniform_) + self.inference_view_requirements.update( + dict( + **{ + SampleBatch.OBS: ViewRequirement(shift=0), + SampleBatch.PREV_REWARDS: ViewRequirement( + SampleBatch.REWARDS, shift=-1), + SampleBatch.PREV_ACTIONS: ViewRequirement( + SampleBatch.ACTIONS, space=self.action_space, + shift=-1), + })) + for i in range(2): + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift=-1, + space=Box(-1.0, 1.0, shape=(self.cell_size,))) + self.inference_view_requirements["state_out_{}".format(i)] = \ + ViewRequirement( + space=Box(-1.0, 1.0, shape=(self.cell_size,))) + @override(RecurrentNetwork) def forward(self, input_dict, state, seq_lens): assert seq_lens is not None @@ -150,10 +179,24 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): @override(RecurrentNetwork) def forward_rnn(self, inputs, state, seq_lens): + # Don't show paddings to RNN(?) + # TODO: (sven) For now, only allow, iff time_major=True to not break + # anything retrospectively (time_major not supported previously). + # max_seq_len = inputs.shape[0] + # time_major = self.model_config["_time_major"] + # if time_major and max_seq_len > 1: + # inputs = torch.nn.utils.rnn.pack_padded_sequence( + # inputs, seq_lens, + # batch_first=not time_major, enforce_sorted=False) self._features, [h, c] = self.lstm( inputs, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]) + # Re-apply paddings. 
+ # if time_major and max_seq_len > 1: + # self._features, _ = torch.nn.utils.rnn.pad_packed_sequence( + # self._features, + # batch_first=not time_major) model_out = self._logits_branch(self._features) return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)] @@ -171,16 +214,3 @@ class LSTMWrapper(RecurrentNetwork, nn.Module): def value_function(self): assert self._features is not None, "must call forward() first" return torch.reshape(self._value_branch(self._features), [-1]) - - @override(ModelV2) - def inference_view_requirements(self) -> Dict[str, ViewRequirement]: - req = super().inference_view_requirements() - # Optional: prev-actions/rewards for forward pass. - if self.model_config["lstm_use_prev_action_reward"]: - req.update({ - SampleBatch.PREV_REWARDS: ViewRequirement( - SampleBatch.REWARDS, shift=-1), - SampleBatch.PREV_ACTIONS: ViewRequirement( - SampleBatch.ACTIONS, space=self.action_space, shift=-1), - }) - return req diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index a5ffb3168..0414216f7 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -70,6 +70,11 @@ class Policy(metaclass=ABCMeta): # The action distribution class to use for action sampling, if any. # Child classes may set this. self.dist_class = None + # View requirements dict for a `learn_on_batch()` call. + # Child classes need to add their specific requirements here (usually + # a combination of a Model's inference_view_- and the + # Policy's loss function-requirements). + self.training_view_requirements = {} @abstractmethod @DeveloperAPI @@ -283,25 +288,6 @@ class Policy(metaclass=ABCMeta): """ raise NotImplementedError - @DeveloperAPI - def training_view_requirements(self): - """Returns a dict of view requirements for operating on this Policy. - - Note: This is an experimental API method. - - The view requirements dict is used to generate input_dicts and - SampleBatches for 1) action computations, 2) postprocessing, and 3) - generating training batches. 
- The Policy may ask its Model(s) as well for possible additional - requirements (e.g. prev-action/reward in an LSTM). - - Returns: - Dict[str, ViewRequirement]: The view requirements dict, mapping - each view key (which will be available in input_dicts) to - an underlying requirement (actual data, timestep shift, etc..). - """ - return {} - @DeveloperAPI def postprocess_trajectory( self, diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 8e9ad1205..c0b999974 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -13,12 +13,14 @@ current algorithms: https://github.com/ray-project/ray/issues/2992 import logging import numpy as np +from typing import List, Optional -from ray.util import log_once from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.typing import TensorType +from ray.util import log_once tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -27,11 +29,14 @@ logger = logging.getLogger(__name__) @DeveloperAPI -def pad_batch_to_sequences_of_same_size(batch, - max_seq_len, - shuffle=False, - batch_divisibility_req=1, - feature_keys=None): +def pad_batch_to_sequences_of_same_size( + batch: SampleBatch, + max_seq_len: int, + shuffle: bool = False, + batch_divisibility_req: int = 1, + feature_keys: Optional[List[str]] = None, + _use_trajectory_view_api: bool = False, +): """Applies padding to `batch` so it's choppable into same-size sequences. Shuffles `batch` (if desired), makes sure divisibility requirement is met, @@ -51,7 +56,26 @@ def pad_batch_to_sequences_of_same_size(batch, feature_keys (Optional[List[str]]): An optional list of keys to apply sequence-chopping to. If None, use all keys in batch that are not "state_in/out_"-type keys. 
+ _use_trajectory_view_api (bool): Whether we are using the Trajectory + View API to collect and process samples. """ + if _use_trajectory_view_api: + if batch.time_major is not None: + batch["seq_lens"] = torch.tensor(batch.seq_lens) + t = 0 if batch.time_major else 1 + for col in batch.data.keys(): + # Cut time-dim from states. + if "state_" in col[:6]: + batch[col] = batch[col][t] + # Flatten all other data. + else: + # Cut time-dim at `max_seq_len`. + if batch.time_major: + batch[col] = batch[col][:batch.max_seq_len] + batch[col] = batch[col].reshape((-1, ) + + batch[col].shape[2:]) + return + if batch_divisibility_req > 1: meets_divisibility_reqs = ( len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0 @@ -61,7 +85,7 @@ def pad_batch_to_sequences_of_same_size(batch, meets_divisibility_reqs = True # RNN-case. - if "state_in_0" in batch: + if "state_in_0" in batch or "state_out_0" in batch: dynamic_max = True # Multi-agent case. elif not meets_divisibility_reqs: @@ -109,31 +133,32 @@ def pad_batch_to_sequences_of_same_size(batch, @DeveloperAPI -def add_time_dimension(padded_inputs, - seq_lens, - framework="tf", - time_major=False): +def add_time_dimension(padded_inputs: TensorType, + *, + max_seq_len: int, + framework: str = "tf", + time_major: bool = False): """Adds a time dimension to padded inputs. - Arguments: - padded_inputs (Tensor): a padded batch of sequences. That is, + Args: + padded_inputs (TensorType): a padded batch of sequences. That is, for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where A, B, C are sequence elements and * denotes padding. - seq_lens (Tensor): the sequence lengths within the input batch, - suitable for passing to tf.nn.dynamic_rnn(). + max_seq_len (int): The max. sequence length in padded_inputs. + framework (str): The framework string ("tf2", "tf", "tfe", "torch"). + time_major (bool): Whether data should be returned in time-major (TxB) + format or not (BxT). 
Returns: - Reshaped tensor of shape [NUM_SEQUENCES, MAX_SEQ_LEN, ...]. + TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...]. """ # Sequence lengths have to be specified for LSTM batch inputs. The # input batch must be padded to the max seq length given here. That is, # batch_size == len(seq_lens) * max(seq_lens) - if framework == "tf": + if framework in ["tf2", "tf", "tfe"]: assert time_major is False, "time-major not supported yet for tf!" padded_batch_size = tf.shape(padded_inputs)[0] - max_seq_len = padded_batch_size // tf.shape(seq_lens)[0] - # Dynamically reshape the padded batch to introduce a time dimension. new_batch_size = padded_batch_size // max_seq_len new_shape = ([new_batch_size, max_seq_len] + @@ -142,7 +167,6 @@ def add_time_dimension(padded_inputs, else: assert framework == "torch", "`framework` must be either tf or torch!" padded_batch_size = padded_inputs.shape[0] - max_seq_len = padded_batch_size // seq_lens.shape[0] # Dynamically reshape the padded batch to introduce a time dimension. new_batch_size = padded_batch_size // max_seq_len @@ -153,6 +177,9 @@ def add_time_dimension(padded_inputs, return torch.reshape(padded_inputs, new_shape) +# NOTE: This function will be deprecated once chunks already come padded and +# correctly chopped from the _SampleCollector object (in time-major fashion +# or not). It is already no longer used iff `_use_trajectory_view_api` = True. @DeveloperAPI def chop_into_sequences(episode_ids, unroll_ids, @@ -166,11 +193,11 @@ def chop_into_sequences(episode_ids, """Truncate and pad experiences into fixed-length sequences. Args: - episode_ids (list): List of episode ids for each step. - unroll_ids (list): List of identifiers for the sample batch. This is - used to make sure sequences are cut between sample batches. - agent_indices (list): List of agent ids for each step. Note that this - has to be combined with episode_ids for uniqueness. + episode_ids (List[EpisodeID]): List of episode ids for each step. 
+ unroll_ids (List[UnrollID]): List of identifiers for the sample batch. + This is used to make sure sequences are cut between sample batches. + agent_indices (List[AgentID]): List of agent ids for each step. Note + that this has to be combined with episode_ids for uniqueness. feature_columns (list): List of arrays containing features. state_columns (list): List of arrays containing LSTM state values. max_seq_len (int): Max length of sequences before truncation. diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 32e9c84a6..3436922ef 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -59,19 +59,33 @@ class SampleBatch: def __init__(self, *args, **kwargs): """Constructs a sample batch (same params as dict constructor).""" - self._initial_inputs = kwargs.pop("_initial_inputs", {}) + # Possible seq_lens (TxB or BxT) setup. + self.time_major = kwargs.pop("_time_major", None) + self.seq_lens = kwargs.pop("_seq_lens", None) + self.max_seq_len = None + if self.seq_lens is not None and len(self.seq_lens) > 0: + self.max_seq_len = max(self.seq_lens) + # The actual data, accessible by column name (str). self.data = dict(*args, **kwargs) + lengths = [] for k, v in self.data.copy().items(): assert isinstance(k, str), self lengths.append(len(v)) - self.data[k] = np.array(v, copy=False) + if isinstance(v, list): + self.data[k] = np.array(v) if not lengths: raise ValueError("Empty sample batch") - assert len(set(lengths)) == 1, ("data columns must be same length", - self.data, lengths) - self.count = lengths[0] + assert len(set(lengths)) == 1, \ + "Data columns must be same length, but lens are {}".format(lengths) + if self.seq_lens is not None and len(self.seq_lens) > 0: + self.count = sum(self.seq_lens) + else: + self.count = len(self.data[k]) + + # Keeps track of new columns added after initial ones. 
+ self.new_columns = [] @staticmethod @PublicAPI @@ -88,11 +102,21 @@ class SampleBatch: """ if isinstance(samples[0], MultiAgentBatch): return MultiAgentBatch.concat_samples(samples) + seq_lens = [] + concat_samples = [] + for s in samples: + if s.count > 0: + concat_samples.append(s) + if s.seq_lens is not None: + seq_lens.extend(s.seq_lens) + out = {} - samples = [s for s in samples if s.count > 0] - for k in samples[0].keys(): - out[k] = concat_aligned([s[k] for s in samples]) - return SampleBatch(out) + for k in concat_samples[0].keys(): + out[k] = concat_aligned( + [s[k] for s in concat_samples], + time_major=concat_samples[0].time_major) + return SampleBatch( + out, _seq_lens=seq_lens, _time_major=concat_samples[0].time_major) @PublicAPI def concat(self, other: "SampleBatch") -> "SampleBatch": @@ -222,8 +246,18 @@ class SampleBatch: SampleBatch: A new SampleBatch, which has a slice of this batch's data. """ - - return SampleBatch({k: v[start:end] for k, v in self.data.items()}) + if self.time_major is not None: + return SampleBatch( + {k: v[:, start:end] + for k, v in self.data.items()}, + _seq_lens=self.seq_lens[start:end], + _time_major=self.time_major) + else: + return SampleBatch( + {k: v[start:end] + for k, v in self.data.items()}, + _seq_lens=None, + _time_major=self.time_major) @PublicAPI def timeslices(self, k: int) -> List["SampleBatch"]: @@ -290,7 +324,7 @@ class SampleBatch: key (str): The key (column name) to return. Returns: - TensorType]: The data under the given key. + TensorType: The data under the given key. """ return self.data[key] @@ -302,6 +336,8 @@ class SampleBatch: key (str): The column name to set a value for. item (TensorType): The data to insert. 
""" + if key not in self.data: + self.new_columns.append(key) self.data[key] = item @DeveloperAPI diff --git a/rllib/policy/tests/test_trajectory_view_api.py b/rllib/policy/tests/test_trajectory_view_api.py deleted file mode 100644 index 91e10650e..000000000 --- a/rllib/policy/tests/test_trajectory_view_api.py +++ /dev/null @@ -1,84 +0,0 @@ -import unittest - -import ray -import ray.rllib.agents.ppo as ppo -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.test_utils import framework_iterator - - -class TestTrajectoryViewAPI(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init() - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - - def test_plain(self): - config = ppo.DEFAULT_CONFIG.copy() - for _ in framework_iterator(config, frameworks="torch"): - trainer = ppo.PPOTrainer(config, env="CartPole-v0") - policy = trainer.get_policy() - view_req_model = policy.model.inference_view_requirements() - view_req_policy = policy.training_view_requirements() - assert len(view_req_model) == 1 - assert len(view_req_policy) == 6 - for key in [ - SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, - SampleBatch.DONES, SampleBatch.NEXT_OBS, - SampleBatch.VF_PREDS - ]: - assert key in view_req_policy - # None of the view cols has a special underlying data_col, - # except next-obs. - if key != SampleBatch.NEXT_OBS: - assert view_req_policy[key].data_col is None - else: - assert view_req_policy[key].data_col == SampleBatch.OBS - assert view_req_policy[key].shift == 1 - trainer.stop() - - def test_lstm_prev_actions_and_rewards(self): - config = ppo.DEFAULT_CONFIG.copy() - config["model"] = config["model"].copy() - # Activate LSTM + prev-action + rewards. 
- config["model"]["use_lstm"] = True - config["model"]["lstm_use_prev_action_reward"] = True - - for _ in framework_iterator(config, frameworks="torch"): - trainer = ppo.PPOTrainer(config, env="CartPole-v0") - policy = trainer.get_policy() - view_req_model = policy.model.inference_view_requirements() - view_req_policy = policy.training_view_requirements() - assert len(view_req_model) == 3 # obs, prev_a, prev_r - assert len(view_req_policy) == 8 - for key in [ - SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS, - SampleBatch.DONES, SampleBatch.NEXT_OBS, - SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS, - SampleBatch.PREV_REWARDS - ]: - assert key in view_req_policy - - if key == SampleBatch.PREV_ACTIONS: - assert view_req_policy[key].data_col == SampleBatch.ACTIONS - assert view_req_policy[key].shift == -1 - elif key == SampleBatch.PREV_REWARDS: - assert view_req_policy[key].data_col == SampleBatch.REWARDS - assert view_req_policy[key].shift == -1 - elif key not in [ - SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS, - SampleBatch.PREV_REWARDS - ]: - assert view_req_policy[key].data_col is None - else: - assert view_req_policy[key].data_col == SampleBatch.OBS - assert view_req_policy[key].shift == 1 - trainer.stop() - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 280064a4e..1a908ec86 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -104,16 +104,21 @@ class TorchPolicy(Policy): """ self.framework = "torch" super().__init__(observation_space, action_space, config) - if torch.cuda.is_available() and ray.get_gpu_ids(as_str=True): + if torch.cuda.is_available() and ray.get_gpu_ids(): self.device = torch.device("cuda") else: self.device = torch.device("cpu") self.model = model.to(self.device) # Combine view_requirements for Model and Policy. 
- self.view_requirements = { - **self.model.inference_view_requirements(), - **self.training_view_requirements(), - } + self.training_view_requirements = dict( + **{ + SampleBatch.ACTIONS: ViewRequirement( + space=self.action_space, shift=0), + SampleBatch.REWARDS: ViewRequirement(shift=0), + SampleBatch.DONES: ViewRequirement(shift=0), + }, + **self.model.inference_view_requirements) + self.exploration = self._create_exploration() self.unwrapped_model = model # used to support DistributedDataParallel self._loss = loss @@ -131,17 +136,6 @@ class TorchPolicy(Policy): callable(get_batch_divisibility_req) else \ (get_batch_divisibility_req or 1) - @override(Policy) - def training_view_requirements(self): - if hasattr(self, "view_requirements"): - return self.view_requirements - return { - SampleBatch.ACTIONS: ViewRequirement( - space=self.action_space, shift=0), - SampleBatch.REWARDS: ViewRequirement(shift=0), - SampleBatch.DONES: ViewRequirement(shift=0), - } - @override(Policy) @DeveloperAPI def compute_actions( @@ -204,9 +198,11 @@ class TorchPolicy(Policy): with torch.no_grad(): # Pass lazy (torch) tensor dict to Model as `input_dict`. input_dict = self._lazy_tensor_dict(input_dict) - # TODO: (sven) support RNNs w/ fast sampling. 
- state_batches = [] - seq_lens = None + state_batches = [ + input_dict[k] for k in input_dict.keys() if "state_" in k[:6] + ] + seq_lens = np.array([1] * len(input_dict["obs"])) \ + if state_batches else None actions, state_out, extra_fetches, logp = \ self._compute_action_helper( @@ -340,7 +336,9 @@ class TorchPolicy(Policy): postprocessed_batch, max_seq_len=self.max_seq_len, shuffle=False, - batch_divisibility_req=self.batch_divisibility_req) + batch_divisibility_req=self.batch_divisibility_req, + _use_trajectory_view_api=self.config["_use_trajectory_view_api"], + ) train_batch = self._lazy_tensor_dict(postprocessed_batch) @@ -359,6 +357,7 @@ class TorchPolicy(Policy): loss_out, train_batch) assert len(loss_out) == len(self._optimizers) + # assert not any(torch.isnan(l) for l in loss_out) fetches = self.extra_compute_grad_fetches() diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index 11cee46bb..1e1fd4806 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -224,16 +224,13 @@ def build_torch_policy( get_batch_divisibility_req=get_batch_divisibility_req, ) + if callable(training_view_requirements_fn): + self.training_view_requirements.update( + training_view_requirements_fn(self)) + if after_init: after_init(self, obs_space, action_space, config) - @override(TorchPolicy) - def training_view_requirements(self): - req = super().training_view_requirements() - if callable(training_view_requirements_fn): - req.update(training_view_requirements_fn(self)) - return req - @override(Policy) def postprocess_trajectory(self, sample_batch, diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index 687d76024..df79b6340 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -29,8 +29,7 @@ class ViewRequirement: def __init__(self, data_col: Optional[str] = None, space: gym.Space = None, - shift: Union[int, List[int]] = 0, - 
created_during_postprocessing: bool = False): + shift: Union[int, List[int]] = 0): """Initializes a ViewRequirement object. Args: @@ -47,11 +46,8 @@ class ViewRequirement: Example: For a view column "obs" in an Atari framestacking fashion, you can set `data_col="obs"` and `shift=[-3, -2, -1, 0]`. - created_during_postprocessing (bool): Whether this column only gets - created during postprocessing. """ self.data_col = data_col self.space = space or gym.spaces.Box( float("-inf"), float("inf"), shape=()) self.shift = shift - self.created_during_postprocessing = created_during_postprocessing diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index bc2e920b0..3a2b711fc 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -63,7 +63,8 @@ def minibatches(samples, sgd_minibatch_size): raise NotImplementedError( "Minibatching not implemented for multi-agent in simple mode") - if "state_in_0" in samples.data: + # Replace with `if samples.seq_lens` check. + if "state_in_0" in samples.data or "state_out_0" in samples.data: if log_once("not_shuffling_rnn_data_in_simple_mode"): logger.warning("Not shuffling RNN data for SGD in simple mode") else: @@ -71,9 +72,22 @@ def minibatches(samples, sgd_minibatch_size): i = 0 slices = [] - while i < samples.count: - slices.append((i, i + sgd_minibatch_size)) - i += sgd_minibatch_size + if samples.seq_lens: + seq_no = 0 + while i < samples.count: + seq_no_end = seq_no + actual_count = 0 + while actual_count < sgd_minibatch_size and len( + samples.seq_lens) > seq_no_end: + actual_count += samples.seq_lens[seq_no_end] + seq_no_end += 1 + slices.append((seq_no, seq_no_end)) + i += actual_count + seq_no = seq_no_end + else: + while i < samples.count: + slices.append((i, i + sgd_minibatch_size)) + i += sgd_minibatch_size random.shuffle(slices) for i, j in slices: @@ -100,7 +114,7 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count) fetches = {} - 
for policy_id, policy in policies.items(): + for policy_id in policies.keys(): if policy_id not in samples.policy_batches: continue diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 96a8893ef..513b00c21 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -43,6 +43,9 @@ EnvID = int # Represents an episode id. EpisodeID = int +# Represents an "unroll" (maybe across different sub-envs in a vector env). +UnrollID = int + # A dict keyed by agent ids, e.g. {"agent-1": value}. MultiAgentDict = Dict[AgentID, Any]