diff --git a/.travis.yml b/.travis.yml index 9699f4dfa..d23cac22c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -162,7 +162,7 @@ matrix: script: - . ./ci/travis/ci.sh test_wheels - # RLlib: Learning tests (from rllib/tuned_examples/regression_tests/*.yaml). + # RLlib: Learning tests (from rllib/tuned_examples/*.yaml). - os: linux env: - RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS=1 @@ -178,7 +178,7 @@ matrix: script: - ./ci/keep_alive bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/... - # RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml). + # RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/*.yaml). # Requested by Edi (MS): Test all learning capabilities with tf1.x - os: linux env: @@ -195,7 +195,7 @@ matrix: script: - ./ci/keep_alive bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/... - # RLlib: Learning tests with torch (from rllib/tuned_examples/regression_tests/*.yaml). + # RLlib: Learning tests with torch (from rllib/tuned_examples/*.yaml). - os: linux env: - RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS_TORCH=1 @@ -250,7 +250,7 @@ matrix: - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=examples_E,examples_F,examples_G,examples_H,examples_I,examples_J,examples_K,examples_L,examples_M,examples_N,examples_O,examples_P rllib/... - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=examples_Q,examples_R,examples_S,examples_T,examples_U,examples_V,examples_W,examples_X,examples_Y,examples_Z rllib/... - # RLlib: tests_dir: Everything in rllib/tests/ directory (A-I). + # RLlib: tests_dir: Everything in rllib/tests/ directory (A-L). 
- os: linux env: - RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1 @@ -266,7 +266,7 @@ matrix: script: - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=tests_dir_A,tests_dir_B,tests_dir_C,tests_dir_D,tests_dir_E,tests_dir_F,tests_dir_G,tests_dir_H,tests_dir_I,tests_dir_J,tests_dir_K,tests_dir_L rllib/... - # RLlib: tests_dir: Everything in rllib/tests/ directory (J-Z). + # RLlib: tests_dir: Everything in rllib/tests/ directory (M-Z). - os: linux env: - RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1 diff --git a/rllib/__init__.py b/rllib/__init__.py index 5d6c9b90c..21207b600 100644 --- a/rllib/__init__.py +++ b/rllib/__init__.py @@ -6,9 +6,7 @@ from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv -from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.rollout_worker import RolloutWorker -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy @@ -61,11 +59,8 @@ _register_all() __all__ = [ "Policy", - "PolicyGraph", "TFPolicy", - "TFPolicyGraph", "RolloutWorker", - "PolicyEvaluator", "SampleBatch", "BaseEnv", "MultiAgentEnv", diff --git a/rllib/evaluation/__init__.py b/rllib/evaluation/__init__.py index d088ad19d..b67c0dcca 100644 --- a/rllib/evaluation/__init__.py +++ b/rllib/evaluation/__init__.py @@ -1,22 +1,14 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.rollout_worker import RolloutWorker -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph -from 
ray.rllib.evaluation.sample_batch import MultiAgentBatch from ray.rllib.evaluation.sample_batch_builder import ( SampleBatchBuilder, MultiAgentSampleBatchBuilder) from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch __all__ = [ "RolloutWorker", - "PolicyGraph", - "TFPolicyGraph", - "TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder", @@ -26,5 +18,4 @@ __all__ = [ "compute_advantages", "collect_metrics", "MultiAgentEpisode", - "PolicyEvaluator", ] diff --git a/rllib/evaluation/policy_evaluator.py b/rllib/evaluation/policy_evaluator.py deleted file mode 100644 index a61d4b254..000000000 --- a/rllib/evaluation/policy_evaluator.py +++ /dev/null @@ -1,5 +0,0 @@ -from ray.rllib.utils import renamed_class -from ray.rllib.evaluation import RolloutWorker - -PolicyEvaluator = renamed_class( - RolloutWorker, old_name="rllib.evaluation.PolicyEvaluator") diff --git a/rllib/evaluation/policy_graph.py b/rllib/evaluation/policy_graph.py deleted file mode 100644 index 68f6cc5a9..000000000 --- a/rllib/evaluation/policy_graph.py +++ /dev/null @@ -1,4 +0,0 @@ -from ray.rllib.policy.policy import Policy -from ray.rllib.utils import renamed_class - -PolicyGraph = renamed_class(Policy, old_name="PolicyGraph") diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 30f7bb29d..66d6374fc 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -420,11 +420,16 @@ class RolloutWorker(ParallelIteratorWorker): remote_env_batch_wait_ms=remote_env_batch_wait_ms) self.num_envs = num_envs + # `truncate_episodes`: Allow a batch to contain more than one episode + # (fragments) and always make the batch `rollout_fragment_length` + # long. 
if self.batch_mode == "truncate_episodes": - pack_episodes = True + pack = True + # `complete_episodes`: Never cut episodes and sampler will return + # exactly one (complete) episode per poll. elif self.batch_mode == "complete_episodes": - rollout_fragment_length = float("inf") # never cut episodes - pack_episodes = False # sampler will return 1 episode per poll + rollout_fragment_length = float("inf") + pack = False else: raise ValueError("Unsupported batch mode: {}".format( self.batch_mode)) @@ -450,37 +455,38 @@ class RolloutWorker(ParallelIteratorWorker): if sample_async: self.sampler = AsyncSampler( - self, - self.async_env, - self.policy_map, - policy_mapping_fn, - self.preprocessors, - self.filters, - clip_rewards, - rollout_fragment_length, - self.callbacks, + worker=self, + env=self.async_env, + policies=self.policy_map, + policy_mapping_fn=policy_mapping_fn, + preprocessors=self.preprocessors, + obs_filters=self.filters, + clip_rewards=clip_rewards, + rollout_fragment_length=rollout_fragment_length, + callbacks=self.callbacks, horizon=episode_horizon, - pack=pack_episodes, + pack_multiple_episodes_in_batch=pack, tf_sess=self.tf_sess, clip_actions=clip_actions, blackhole_outputs="simulation" in input_evaluation, soft_horizon=soft_horizon, no_done_at_end=no_done_at_end, observation_fn=observation_fn) + # Start the Sampler thread. 
self.sampler.start() else: self.sampler = SyncSampler( - self, - self.async_env, - self.policy_map, - policy_mapping_fn, - self.preprocessors, - self.filters, - clip_rewards, - rollout_fragment_length, - self.callbacks, + worker=self, + env=self.async_env, + policies=self.policy_map, + policy_mapping_fn=policy_mapping_fn, + preprocessors=self.preprocessors, + obs_filters=self.filters, + clip_rewards=clip_rewards, + rollout_fragment_length=rollout_fragment_length, + callbacks=self.callbacks, horizon=episode_horizon, - pack=pack_episodes, + pack_multiple_episodes_in_batch=pack, tf_sess=self.tf_sess, clip_actions=clip_actions, soft_horizon=soft_horizon, @@ -503,7 +509,7 @@ class RolloutWorker(ParallelIteratorWorker): This method must be implemented by subclasses. Returns: - SampleBatch|MultiAgentBatch: A columnar batch of experiences + Union[SampleBatch,MultiAgentBatch]: A columnar batch of experiences (e.g., tensors), or a multi-agent batch. Examples: diff --git a/rllib/evaluation/sample_batch.py b/rllib/evaluation/sample_batch.py deleted file mode 100644 index 740fff70c..000000000 --- a/rllib/evaluation/sample_batch.py +++ /dev/null @@ -1,7 +0,0 @@ -from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch -from ray.rllib.utils import renamed_class - -SampleBatch = renamed_class( - SampleBatch, old_name="rllib.evaluation.SampleBatch") -MultiAgentBatch = renamed_class( - MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch") diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 845391983..cf1409da8 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -75,30 +75,47 @@ class MultiAgentSampleBatchBuilder: def __init__(self, policy_map, clip_rewards, callbacks): """Initialize a MultiAgentSampleBatchBuilder. - Arguments: - policy_map (dict): Maps policy ids to policy instances. - clip_rewards (bool): Whether to clip rewards before postprocessing. 
+ Args: + policy_map (Dict[str,Policy]): Maps policy ids to policy instances. + clip_rewards (Union[bool,float]): Whether to clip rewards before + postprocessing (at +/-1.0) or the actual value to +/- clip. callbacks (DefaultCallbacks): RLlib callbacks. """ self.policy_map = policy_map self.clip_rewards = clip_rewards + # Build the Policies' SampleBatchBuilders. self.policy_builders = { k: SampleBatchBuilder() for k in policy_map.keys() } + # Whenever we observe a new agent, add a new SampleBatchBuilder for + # this agent. self.agent_builders = {} + # Internal agent-to-policy map. self.agent_to_policy = {} self.callbacks = callbacks - self.count = 0 # increment this manually + # Number of "inference" steps taken in the environment. + # Regardless of the number of agents involved in each of these steps. + self.count = 0 def total(self): - """Returns summed number of steps across all agent buffers.""" + """Returns the total number of steps taken in the env (all agents). + + Returns: + int: The number of steps taken in total in the environment over all + agents. + """ return sum(a.count for a in self.agent_builders.values()) def has_pending_agent_data(self): - """Returns whether there is pending unprocessed data.""" + """Returns whether there is pending unprocessed data. + + Returns: + bool: True if there is at least one per-agent builder (with data + in it). + """ return len(self.agent_builders) > 0 @@ -115,32 +132,37 @@ class MultiAgentSampleBatchBuilder: if agent_id not in self.agent_builders: self.agent_builders[agent_id] = SampleBatchBuilder() self.agent_to_policy[agent_id] = policy_id - builder = self.agent_builders[agent_id] - builder.add_values(**values) + self.agent_builders[agent_id].add_values(**values) - def postprocess_batch_so_far(self, episode): + def postprocess_batch_so_far(self, episode=None): """Apply policy postprocessors to any unprocessed rows. This pushes the postprocessed per-agent batches onto the per-policy builders, clearing per-agent state. 
Args: - episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode - object. + episode (Optional[MultiAgentEpisode]): The Episode object that + holds this MultiAgentBatchBuilder object. """ - # Materialize the batches so far + # Materialize the batches so far. pre_batches = {} for agent_id, builder in self.agent_builders.items(): pre_batches[agent_id] = ( self.policy_map[self.agent_to_policy[agent_id]], builder.build_and_reset()) - # Apply postprocessor + # Apply postprocessor. post_batches = {} - if self.clip_rewards: + if self.clip_rewards is True: for _, (_, pre_batch) in pre_batches.items(): pre_batch["rewards"] = np.sign(pre_batch["rewards"]) + elif self.clip_rewards: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.clip( + pre_batch["rewards"], + a_min=-self.clip_rewards, + a_max=self.clip_rewards) for agent_id, (_, pre_batch) in pre_batches.items(): other_batches = pre_batches.copy() del other_batches[agent_id] @@ -193,15 +215,19 @@ class MultiAgentSampleBatchBuilder: "Alternatively, set no_done_at_end=True to allow this.") @DeveloperAPI - def build_and_reset(self, episode): + def build_and_reset(self, episode=None): """Returns the accumulated sample batches for each policy. Any unprocessed rows will be first postprocessed with a policy postprocessor. The internal state of this builder will be reset. Args: - episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode - object. + episode (Optional[MultiAgentEpisode]): The Episode object that + holds this MultiAgentBatchBuilder object or None. + + Returns: + MultiAgentBatch: Returns the accumulated sample batches for each + policy. 
""" self.postprocess_batch_so_far(episode) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index dd1b83c10..6bb9c9685 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1,3 +1,4 @@ +from abc import abstractmethod, ABCMeta from collections import defaultdict, namedtuple import logging import numpy as np @@ -16,7 +17,7 @@ from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv from ray.rllib.offline import InputReader from ray.rllib.utils import try_import_tree -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \ unbatch @@ -49,7 +50,8 @@ class PerfStats: } -class SamplerInput(InputReader): +@DeveloperAPI +class SamplerInput(InputReader, metaclass=ABCMeta): """Reads input experiences from an existing sampler.""" @override(InputReader) @@ -61,9 +63,29 @@ class SamplerInput(InputReader): else: return batches[0] + @abstractmethod + @DeveloperAPI + def get_data(self): + raise NotImplementedError + @abstractmethod + @DeveloperAPI + def get_metrics(self): + raise NotImplementedError + + @abstractmethod + @DeveloperAPI + def get_extra_batches(self): + raise NotImplementedError + + +@DeveloperAPI class SyncSampler(SamplerInput): + """Sync SamplerInput that collects experiences when `get_data()` is called. + """ + def __init__(self, + *, worker, env, policies, @@ -74,12 +96,50 @@ class SyncSampler(SamplerInput): rollout_fragment_length, callbacks, horizon=None, - pack=False, + pack_multiple_episodes_in_batch=False, tf_sess=None, clip_actions=True, soft_horizon=False, no_done_at_end=False, observation_fn=None): + """Initializes a SyncSampler object. + + Args: + worker (RolloutWorker): The RolloutWorker that will use this + Sampler for sampling. 
+ env (Env): Any Env object. Will be converted into an RLlib BaseEnv. + policies (Dict[str,Policy]): Mapping from policy ID to Policy obj. + policy_mapping_fn (callable): Callable that takes an agent ID and + returns a policy ID. + preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to + Preprocessor object for the observations prior to filtering. + obs_filters (Dict[str,Filter]): Mapping from policy ID to + env Filter object. + clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual + float value for +/- value clipping. False for no clipping. + rollout_fragment_length (int): The length of a fragment to collect + before building a SampleBatch from the data and resetting + the SampleBatchBuilder object. + callbacks (Callbacks): The Callbacks object to use when episode + events happen during rollout. + horizon (Optional[int]): Hard-reset the Env after this many steps. + pack_multiple_episodes_in_batch (bool): Whether to pack multiple + episodes into each batch. This guarantees batches will be + exactly `rollout_fragment_length` in size. + tf_sess (Optional[tf.Session]): A tf.Session object to use (only if + framework=tf). + clip_actions (bool): Whether to clip actions according to the + given action_space's bounds. + soft_horizon (bool): If True, calculate bootstrapped values as if + episode had ended, but don't physically reset the environment + when the horizon is hit. + no_done_at_end (bool): Ignore the done=True at the end of the + episode and instead record done=False. + observation_fn (Optional[ObservationFunction]): Optional + multi-agent observation func to use for preprocessing + observations. + """ + self.base_env = BaseEnv.to_base_env(env) self.rollout_fragment_length = rollout_fragment_length self.horizon = horizon @@ -89,14 +149,16 @@ class SyncSampler(SamplerInput): self.obs_filters = obs_filters self.extra_batches = queue.Queue() self.perf_stats = PerfStats() + # Create the rollout generator to use for calls to `get_data()`. 
self.rollout_provider = _env_runner( worker, self.base_env, self.extra_batches.put, self.policies, self.policy_mapping_fn, self.rollout_fragment_length, self.horizon, self.preprocessors, self.obs_filters, clip_rewards, clip_actions, - pack, callbacks, tf_sess, self.perf_stats, soft_horizon, - no_done_at_end, observation_fn) + pack_multiple_episodes_in_batch, callbacks, tf_sess, + self.perf_stats, soft_horizon, no_done_at_end, observation_fn) self.metrics_queue = queue.Queue() + @override(SamplerInput) def get_data(self): while True: item = next(self.rollout_provider) @@ -105,6 +167,7 @@ class SyncSampler(SamplerInput): else: return item + @override(SamplerInput) def get_metrics(self): completed = [] while True: @@ -115,6 +178,7 @@ class SyncSampler(SamplerInput): break return completed + @override(SamplerInput) def get_extra_batches(self): extra = [] while True: @@ -125,8 +189,16 @@ class SyncSampler(SamplerInput): return extra +@DeveloperAPI class AsyncSampler(threading.Thread, SamplerInput): + """Async SamplerInput that collects experiences in thread and queues them. + + Once started, experiences are continuously collected and put into a Queue, + from where they can be unqueued by the caller of `get_data()`. + """ + def __init__(self, + *, worker, env, policies, @@ -137,13 +209,52 @@ class AsyncSampler(threading.Thread, SamplerInput): rollout_fragment_length, callbacks, horizon=None, - pack=False, + pack_multiple_episodes_in_batch=False, tf_sess=None, clip_actions=True, blackhole_outputs=False, soft_horizon=False, no_done_at_end=False, observation_fn=None): + """Initializes an AsyncSampler object. + + Args: + worker (RolloutWorker): The RolloutWorker that will use this + Sampler for sampling. + env (Env): Any Env object. Will be converted into an RLlib BaseEnv. + policies (Dict[str,Policy]): Mapping from policy ID to Policy obj. + policy_mapping_fn (callable): Callable that takes an agent ID and + returns a policy ID. 
+ preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to + Preprocessor object for the observations prior to filtering. + obs_filters (Dict[str,Filter]): Mapping from policy ID to + env Filter object. + clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual + float value for +/- value clipping. False for no clipping. + rollout_fragment_length (int): The length of a fragment to collect + before building a SampleBatch from the data and resetting + the SampleBatchBuilder object. + callbacks (Callbacks): The Callbacks object to use when episode + events happen during rollout. + horizon (Optional[int]): Hard-reset the Env after this many steps. + pack_multiple_episodes_in_batch (bool): Whether to pack multiple + episodes into each batch. This guarantees batches will be + exactly `rollout_fragment_length` in size. + tf_sess (Optional[tf.Session]): A tf.Session object to use (only if + framework=tf). + clip_actions (bool): Whether to clip actions according to the + given action_space's bounds. + blackhole_outputs (bool): Whether to collect samples, but then + not further process or store them (throw away all samples). + soft_horizon (bool): If True, calculate bootstrapped values as if + episode had ended, but don't physically reset the environment + when the horizon is hit. + no_done_at_end (bool): Ignore the done=True at the end of the + episode and instead record done=False. + observation_fn (Optional[ObservationFunction]): Optional + multi-agent observation func to use for preprocessing + observations. + """ for _, f in obs_filters.items(): assert getattr(f, "is_concurrent", False), \ "Observation Filter must support concurrent updates." 
@@ -161,7 +272,7 @@ class AsyncSampler(threading.Thread, SamplerInput): self.obs_filters = obs_filters self.clip_rewards = clip_rewards self.daemon = True - self.pack = pack + self.pack_multiple_episodes_in_batch = pack_multiple_episodes_in_batch self.tf_sess = tf_sess self.callbacks = callbacks self.clip_actions = clip_actions @@ -172,6 +283,7 @@ class AsyncSampler(threading.Thread, SamplerInput): self.shutdown = False self.observation_fn = observation_fn + @override(threading.Thread) def run(self): try: self._run() @@ -191,9 +303,9 @@ class AsyncSampler(threading.Thread, SamplerInput): self.worker, self.base_env, extra_batches_putter, self.policies, self.policy_mapping_fn, self.rollout_fragment_length, self.horizon, self.preprocessors, self.obs_filters, self.clip_rewards, - self.clip_actions, self.pack, self.callbacks, self.tf_sess, - self.perf_stats, self.soft_horizon, self.no_done_at_end, - self.observation_fn) + self.clip_actions, self.pack_multiple_episodes_in_batch, + self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon, + self.no_done_at_end, self.observation_fn) while not self.shutdown: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is @@ -204,6 +316,7 @@ class AsyncSampler(threading.Thread, SamplerInput): else: queue_putter(item) + @override(SamplerInput) def get_data(self): if not self.is_alive(): raise RuntimeError("Sampling thread has died") @@ -215,6 +328,7 @@ class AsyncSampler(threading.Thread, SamplerInput): return rollout + @override(SamplerInput) def get_metrics(self): completed = [] while True: @@ -225,6 +339,7 @@ class AsyncSampler(threading.Thread, SamplerInput): break return completed + @override(SamplerInput) def get_extra_batches(self): extra = [] while True: @@ -237,14 +352,14 @@ class AsyncSampler(threading.Thread, SamplerInput): def _env_runner(worker, base_env, extra_batch_callback, policies, policy_mapping_fn, rollout_fragment_length, horizon, 
- preprocessors, obs_filters, clip_rewards, clip_actions, pack, - callbacks, tf_sess, perf_stats, soft_horizon, no_done_at_end, - observation_fn): + preprocessors, obs_filters, clip_rewards, clip_actions, + pack_multiple_episodes_in_batch, callbacks, tf_sess, + perf_stats, soft_horizon, no_done_at_end, observation_fn): """This implements the common experience collection logic. Args: - worker (RolloutWorker): reference to the current rollout worker. - base_env (BaseEnv): env implementing BaseEnv. + worker (RolloutWorker): Reference to the current rollout worker. + base_env (BaseEnv): Env implementing BaseEnv. extra_batch_callback (fn): function to send extra batch data to. policies (dict): Map of policy ids to Policy instances. policy_mapping_fn (func): Function that maps agent ids to policy ids. @@ -259,9 +374,9 @@ def _env_runner(worker, base_env, extra_batch_callback, policies, obs_filters (dict): Map of policy id to filter used to process observations for the policy. clip_rewards (bool): Whether to clip rewards before postprocessing. - pack (bool): Whether to pack multiple episodes into each batch. This - guarantees batches will be exactly `rollout_fragment_length` in - size. + pack_multiple_episodes_in_batch (bool): Whether to pack multiple + episodes into each batch. This guarantees batches will be exactly + `rollout_fragment_length` in size. clip_actions (bool): Whether to clip actions to the space range. callbacks (DefaultCallbacks): User callbacks to run on episode events. tf_sess (Session|None): Optional tensorflow session to use for batching @@ -354,25 +469,47 @@ def _env_runner(worker, base_env, extra_batch_callback, policies, # Process observations and prepare for policy evaluation. 
t1 = time.time() active_envs, to_eval, outputs = _process_observations( - worker, base_env, policies, batch_builder_pool, active_episodes, - unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon, - preprocessors, obs_filters, rollout_fragment_length, pack, - callbacks, soft_horizon, no_done_at_end, observation_fn) + worker=worker, + base_env=base_env, + policies=policies, + batch_builder_pool=batch_builder_pool, + active_episodes=active_episodes, + unfiltered_obs=unfiltered_obs, + rewards=rewards, + dones=dones, + infos=infos, + horizon=horizon, + preprocessors=preprocessors, + obs_filters=obs_filters, + rollout_fragment_length=rollout_fragment_length, + pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch, + callbacks=callbacks, + soft_horizon=soft_horizon, + no_done_at_end=no_done_at_end, + observation_fn=observation_fn) perf_stats.processing_time += time.time() - t1 for o in outputs: yield o # Do batched policy eval (accross vectorized envs). t2 = time.time() - eval_results = _do_policy_eval(tf_sess, to_eval, policies, - active_episodes) + eval_results = _do_policy_eval( + to_eval=to_eval, + policies=policies, + active_episodes=active_episodes, + tf_sess=tf_sess) perf_stats.inference_time += time.time() - t2 # Process results and update episode state. t3 = time.time() actions_to_send = _process_policy_eval_results( - to_eval, eval_results, active_episodes, active_envs, - off_policy_actions, policies, clip_actions) + to_eval=to_eval, + eval_results=eval_results, + active_episodes=active_episodes, + active_envs=active_envs, + off_policy_actions=off_policy_actions, + policies=policies, + clip_actions=clip_actions) perf_stats.processing_time += time.time() - t3 # Return computed actions to ready envs. 
We also send to envs that have @@ -384,15 +521,51 @@ def _env_runner(worker, base_env, extra_batch_callback, policies, def _process_observations( worker, base_env, policies, batch_builder_pool, active_episodes, - unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon, - preprocessors, obs_filters, rollout_fragment_length, pack, callbacks, - soft_horizon, no_done_at_end, observation_fn): + unfiltered_obs, rewards, dones, infos, horizon, preprocessors, + obs_filters, rollout_fragment_length, pack_multiple_episodes_in_batch, + callbacks, soft_horizon, no_done_at_end, observation_fn): """Record new data from the environment and prepare for policy evaluation. + Args: + worker (RolloutWorker): Reference to the current rollout worker. + base_env (BaseEnv): Env implementing BaseEnv. + policies (dict): Map of policy ids to Policy instances. + batch_builder_pool (List[SampleBatchBuilder]): List of pooled + SampleBatchBuilder object for recycling. + active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from + episode ID to currently ongoing MultiAgentEpisode object. + unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids -> + unfiltered observation tensor, returned by a `BaseEnv.poll()` call. + rewards (dict): Doubly keyed dict of env-ids -> agent ids -> + rewards tensor, returned by a `BaseEnv.poll()` call. + dones (dict): Doubly keyed dict of env-ids -> agent ids -> + boolean done flags, returned by a `BaseEnv.poll()` call. + infos (dict): Doubly keyed dict of env-ids -> agent ids -> + info dicts, returned by a `BaseEnv.poll()` call. + horizon (int): Horizon of the episode. + preprocessors (dict): Map of policy id to preprocessor for the + observations prior to filtering. + obs_filters (dict): Map of policy id to filter used to process + observations for the policy. + rollout_fragment_length (int): Number of episode steps before + `SampleBatch` is yielded. Set to infinity to yield complete + episodes. 
+ pack_multiple_episodes_in_batch (bool): Whether to pack multiple + episodes into each batch. This guarantees batches will be exactly + `rollout_fragment_length` in size. + callbacks (DefaultCallbacks): User callbacks to run on episode events. + soft_horizon (bool): Calculate rewards but don't reset the + environment when the horizon is hit. + no_done_at_end (bool): Ignore the done=True at the end of the episode + and instead record done=False. + observation_fn (ObservationFunction): Optional multi-agent + observation func to use for preprocessing observations. + Returns: - active_envs: set of non-terminated env ids - to_eval: map of policy_id to list of agent PolicyEvalData - outputs: list of metrics and samples to return from the sampler + Tuple: + - active_envs: Set of non-terminated env ids. + - to_eval: Map of policy_id to list of agent PolicyEvalData. + - outputs: List of metrics and samples to return from the sampler. """ active_envs = set() @@ -487,7 +660,7 @@ def _process_observations( episode._set_last_raw_obs(agent_id, raw_obs) episode._set_last_info(agent_id, infos[env_id].get(agent_id, {})) - # Record transition info if applicable + # Record transition info if applicable. if (last_observation is not None and infos[env_id].get( agent_id, {}).get("training_enabled", True)): episode.batch_builder.add_values( @@ -515,13 +688,19 @@ def _process_observations( # Cut the batch if we're not packing multiple episodes into one, # or if we've exceeded the requested batch size. if episode.batch_builder.has_pending_agent_data(): + # Sanity check, whether all agents have done=True, if done[__all__] + # is True. if dones[env_id]["__all__"] and not no_done_at_end: episode.batch_builder.check_missing_dones() - if (all_agents_done and not pack) or \ + + # Reached end of episode and we are not allowed to pack the + # next episode into the same SampleBatch -> Build the SampleBatch + # and add it to "outputs". 
+ if (all_agents_done and not pack_multiple_episodes_in_batch) or \ episode.batch_builder.count >= rollout_fragment_length: outputs.append(episode.batch_builder.build_and_reset(episode)) + # Make sure postprocessor stays within one episode. elif all_agents_done: - # Make sure postprocessor stays within one episode episode.batch_builder.postprocess_batch_so_far(episode) if all_agents_done: @@ -584,8 +763,17 @@ def _process_observations( return active_envs, to_eval, outputs -def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): - """Call compute actions on observation batches to get next actions. +def _do_policy_eval(*, to_eval, policies, active_episodes, tf_sess=None): + """Call compute_actions on collected episode/model data to get next action. + + Args: + tf_sess (Optional[tf.Session]): Optional tensorflow session to use for + batching TF policy evaluations. + to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to + lists of PolicyEvalData objects. + policies (Dict[str,Policy]): Mapping from policy ID to Policy obj. + active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from + episode ID to currently ongoing MultiAgentEpisode object. Returns: eval_results: dict of policy to compute_action() outputs. @@ -606,6 +794,8 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): for policy_id, eval_data in to_eval.items(): rnn_in = [t.rnn_state for t in eval_data] policy = _get_or_raise(policies, policy_id) + # If tf (non eager) AND TFPolicy's compute_action method has not been + # overridden -> Use `policy._build_compute_actions()`. 
if builder and (policy.compute_actions.__code__ is TFPolicy.compute_actions.__code__): @@ -646,7 +836,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): return eval_results -def _process_policy_eval_results(to_eval, eval_results, active_episodes, +def _process_policy_eval_results(*, to_eval, eval_results, active_episodes, active_envs, off_policy_actions, policies, clip_actions): """Process the output of policy neural network evaluation. @@ -654,8 +844,22 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes, Records policy evaluation results into the given episode objects and returns replies to send back to agents in the env. + Args: + to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to + lists of PolicyEvalData objects. + eval_results (Dict[str,List]): Mapping of policy IDs to list of + actions, rnn-out states, extra-action-fetches dicts. + active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from + episode ID to currently ongoing MultiAgentEpisode object. + active_envs (Set[int]): Set of non-terminated env ids. + off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids -> + off-policy-action, returned by a `BaseEnv.poll()` call. + policies (Dict[str,Policy]): Mapping from policy ID to Policy obj. + clip_actions (bool): Whether to clip actions to the action space's + bounds. + Returns: - actions_to_send: nested dict of env id -> agent id -> agent replies. + actions_to_send: Nested dict of env id -> agent id -> agent replies. """ actions_to_send = defaultdict(dict) @@ -711,7 +915,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes, def _fetch_atari_metrics(base_env): """Atari games have multiple logical episodes, one per life. - However for metrics reporting we count full episodes all lives included. + However, for metrics reporting we count full episodes, all lives included. 
""" unwrapped = base_env.get_unwrapped() if not unwrapped: @@ -734,10 +938,16 @@ def _to_column_format(rnn_state_rows): def _get_or_raise(mapping, policy_id): """Returns a Policy object under key `policy_id` in `mapping`. - Throws an error if `policy_id` cannot be found. + Args: + mapping (dict): The mapping dict from policy id (str) to + actual Policy object. + policy_id (str): The policy ID to lookup. Returns: Policy: The found Policy object. + + Throws: + ValueError: If `policy_id` cannot be found. """ if policy_id not in mapping: raise ValueError( diff --git a/rllib/evaluation/tf_policy_graph.py b/rllib/evaluation/tf_policy_graph.py deleted file mode 100644 index 8d4895473..000000000 --- a/rllib/evaluation/tf_policy_graph.py +++ /dev/null @@ -1,4 +0,0 @@ -from ray.rllib.policy.tf_policy import TFPolicy -from ray.rllib.utils import renamed_class - -TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph") diff --git a/rllib/evaluation/torch_policy_graph.py b/rllib/evaluation/torch_policy_graph.py deleted file mode 100644 index 6eaa7d71b..000000000 --- a/rllib/evaluation/torch_policy_graph.py +++ /dev/null @@ -1,4 +0,0 @@ -from ray.rllib.policy.torch_policy import TorchPolicy -from ray.rllib.utils import renamed_class - -TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph") diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 1fb8c2857..da5ec8334 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -65,6 +65,15 @@ class SampleBatch: @staticmethod @PublicAPI def concat_samples(samples): + """Concatenates n data dicts or MultiAgentBatches. + + Args: + samples (List[Dict[np.ndarray]]]): List of dicts of data (numpy). + + Returns: + Union[SampleBatch,MultiAgentBatch]: A new (compressed) SampleBatch/ + MultiAgentBatch. 
+ """ if isinstance(samples[0], MultiAgentBatch): return MultiAgentBatch.concat_samples(samples) out = {} @@ -84,7 +93,10 @@ class SampleBatch: {"a": [1, 2, 3, 4, 5]} """ - assert self.keys() == other.keys(), "must have same columns" + if self.keys() != other.keys(): + raise ValueError( + "SampleBatches to concat must have same columns! {} vs {}". + format(list(self.keys()), list(other.keys()))) out = {} for k in self.keys(): out[k] = concat_aligned([self[k], other[k]]) @@ -117,7 +129,14 @@ class SampleBatch: @PublicAPI def columns(self, keys): - """Returns a list of just the specified columns. + """Returns a list of the batch-data in the specified columns. + + Args: + keys (List[str]): List of column names for which to return data. + + Returns: + List[any]: The list of data items ordered by the order of column + names in `keys`. Examples: >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) @@ -143,7 +162,7 @@ class SampleBatch: """Splits this batch's data by `eps_id`. Returns: - list of SampleBatch, one per distinct episode. + List[SampleBatch]: List of batches, one per distinct episode. """ slices = [] @@ -166,7 +185,7 @@ class SampleBatch: def slice(self, start, end): """Returns a slice of the row data of this batch. - Arguments: + Args: start (int): Starting index. end (int): Ending index. @@ -234,23 +253,37 @@ class SampleBatch: @PublicAPI class MultiAgentBatch: """A batch of experiences from multiple policies in the environment. - - Attributes: - policy_batches (dict): Mapping from policy id to a normal SampleBatch - of experiences. Note that these batches may be of different length. - count (int): The number of timesteps in the environment this batch - contains. This will be less than the number of transitions this - batch contains across all policies in total. """ @PublicAPI def __init__(self, policy_batches, count): + """Initializes a MultiAgentBatch object.
+ + Args: + policy_batches (Dict[str,SampleBatch]): Mapping from policy id + (str) to a SampleBatch of experiences. Note that these batches + may be of different length. + count (int): The number of timesteps in the environment this batch + contains. This will be less than the number of transitions this + batch contains across all policies in total. + """ self.policy_batches = policy_batches self.count = count @staticmethod @PublicAPI def wrap_as_needed(batches, count): + """Returns SampleBatch or MultiAgentBatch, depending on given policies. + + Args: + batches (Dict[str,SampleBatch]): Mapping from policy ID to + SampleBatch. + count (int): A count to use, when returning a MultiAgentBatch. + + Returns: + Union[SampleBatch,MultiAgentBatch]: The single default policy's + SampleBatch or a MultiAgentBatch (more than one policy). + """ if len(batches) == 1 and DEFAULT_POLICY_ID in batches: return batches[DEFAULT_POLICY_ID] return MultiAgentBatch(batches, count) @@ -258,10 +291,23 @@ class MultiAgentBatch: @staticmethod @PublicAPI def concat_samples(samples): + """Concatenates a list of MultiAgentBatches into a new MultiAgentBatch. + + Args: + samples (List[MultiAgentBatch]): List of MultiAgentBatch objects + to concatenate. + + Returns: + MultiAgentBatch: A new MultiAgentBatch consisting of the + concatenated inputs. + """ policy_batches = collections.defaultdict(list) total_count = 0 for s in samples: - assert isinstance(s, MultiAgentBatch) + if not isinstance(s, MultiAgentBatch): + raise ValueError( + "`MultiAgentBatch.concat_samples()` can only concat " + "MultiAgentBatch types, not {}!".format(type(s).__name__)) for policy_id, batch in s.policy_batches.items(): policy_batches[policy_id].append(batch) total_count += s.count @@ -272,12 +318,22 @@ class MultiAgentBatch: @PublicAPI def copy(self): + """Deep-copies self into a new MultiAgentBatch. + + Returns: + MultiAgentBatch: The copy of self with deep-copied data.
+ """ return MultiAgentBatch( {k: v.copy() for (k, v) in self.policy_batches.items()}, self.count) @PublicAPI def total(self): + """Calculates the sum of all step-counts over all policy batches. + + Returns: + int: The sum of counts over all policy batches. + """ ct = 0 for batch in self.policy_batches.values(): ct += batch.count @@ -285,11 +341,24 @@ class MultiAgentBatch: @DeveloperAPI def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): + """Compresses each policy batch. + + Args: + bulk (bool): Whether to compress across the batch dimension (0) + as well. If False will compress n separate list items, where n + is the batch size. + columns (Set[str]): Set of column names to compress. + """ for batch in self.policy_batches.values(): batch.compress(bulk=bulk, columns=columns) @DeveloperAPI def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): + """Decompresses each policy batch, if already compressed. + + Args: + columns (Set[str]): Set of column names to decompress. + """ for batch in self.policy_batches.values(): batch.decompress_if_needed(columns) return self