[RLlib] Sample batch docs and cleanup. (#8778)

This commit is contained in:
Sven Mika
2020-06-04 22:47:32 +02:00
committed by GitHub
parent aee01133cd
commit 368088be85
12 changed files with 411 additions and 138 deletions
+5 -5
View File
@@ -162,7 +162,7 @@ matrix:
script:
- . ./ci/travis/ci.sh test_wheels
# RLlib: Learning tests (from rllib/tuned_examples/regression_tests/*.yaml).
# RLlib: Learning tests (from rllib/tuned_examples/*.yaml).
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS=1
@@ -178,7 +178,7 @@ matrix:
script:
- ./ci/keep_alive bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/*.yaml).
# Requested by Edi (MS): Test all learning capabilities with tf1.x
- os: linux
env:
@@ -195,7 +195,7 @@ matrix:
script:
- ./ci/keep_alive bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
# RLlib: Learning tests with torch (from rllib/tuned_examples/regression_tests/*.yaml).
# RLlib: Learning tests with torch (from rllib/tuned_examples/*.yaml).
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS_TORCH=1
@@ -250,7 +250,7 @@ matrix:
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=examples_E,examples_F,examples_G,examples_H,examples_I,examples_J,examples_K,examples_L,examples_M,examples_N,examples_O,examples_P rllib/...
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=examples_Q,examples_R,examples_S,examples_T,examples_U,examples_V,examples_W,examples_X,examples_Y,examples_Z rllib/...
# RLlib: tests_dir: Everything in rllib/tests/ directory (A-I).
# RLlib: tests_dir: Everything in rllib/tests/ directory (A-L).
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1
@@ -266,7 +266,7 @@ matrix:
script:
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=tests_dir_A,tests_dir_B,tests_dir_C,tests_dir_D,tests_dir_E,tests_dir_F,tests_dir_G,tests_dir_H,tests_dir_I,tests_dir_J,tests_dir_K,tests_dir_L rllib/...
# RLlib: tests_dir: Everything in rllib/tests/ directory (J-Z).
# RLlib: tests_dir: Everything in rllib/tests/ directory (M-Z).
- os: linux
env:
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1
-5
View File
@@ -6,9 +6,7 @@ from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
@@ -61,11 +59,8 @@ _register_all()
__all__ = [
"Policy",
"PolicyGraph",
"TFPolicy",
"TFPolicyGraph",
"RolloutWorker",
"PolicyEvaluator",
"SampleBatch",
"BaseEnv",
"MultiAgentEnv",
+1 -10
View File
@@ -1,22 +1,14 @@
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.evaluation.sample_batch_builder import (
SampleBatchBuilder, MultiAgentSampleBatchBuilder)
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
__all__ = [
"RolloutWorker",
"PolicyGraph",
"TFPolicyGraph",
"TorchPolicyGraph",
"SampleBatch",
"MultiAgentBatch",
"SampleBatchBuilder",
@@ -26,5 +18,4 @@ __all__ = [
"compute_advantages",
"collect_metrics",
"MultiAgentEpisode",
"PolicyEvaluator",
]
-5
View File
@@ -1,5 +0,0 @@
from ray.rllib.utils import renamed_class
from ray.rllib.evaluation import RolloutWorker
PolicyEvaluator = renamed_class(
RolloutWorker, old_name="rllib.evaluation.PolicyEvaluator")
-4
View File
@@ -1,4 +0,0 @@
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import renamed_class
PolicyGraph = renamed_class(Policy, old_name="PolicyGraph")
+30 -24
View File
@@ -420,11 +420,16 @@ class RolloutWorker(ParallelIteratorWorker):
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
self.num_envs = num_envs
# `truncate_episodes`: Allow a batch to contain more than one episode
# (fragments) and always make the batch `rollout_fragment_length`
# long.
if self.batch_mode == "truncate_episodes":
pack_episodes = True
pack = True
# `complete_episodes`: Never cut episodes and sampler will return
# exactly one (complete) episode per poll.
elif self.batch_mode == "complete_episodes":
rollout_fragment_length = float("inf") # never cut episodes
pack_episodes = False # sampler will return 1 episode per poll
rollout_fragment_length = float("inf")
pack = False
else:
raise ValueError("Unsupported batch mode: {}".format(
self.batch_mode))
@@ -450,37 +455,38 @@ class RolloutWorker(ParallelIteratorWorker):
if sample_async:
self.sampler = AsyncSampler(
self,
self.async_env,
self.policy_map,
policy_mapping_fn,
self.preprocessors,
self.filters,
clip_rewards,
rollout_fragment_length,
self.callbacks,
worker=self,
env=self.async_env,
policies=self.policy_map,
policy_mapping_fn=policy_mapping_fn,
preprocessors=self.preprocessors,
obs_filters=self.filters,
clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length,
callbacks=self.callbacks,
horizon=episode_horizon,
pack=pack_episodes,
pack_multiple_episodes_in_batch=pack,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn)
# Start the Sampler thread.
self.sampler.start()
else:
self.sampler = SyncSampler(
self,
self.async_env,
self.policy_map,
policy_mapping_fn,
self.preprocessors,
self.filters,
clip_rewards,
rollout_fragment_length,
self.callbacks,
worker=self,
env=self.async_env,
policies=self.policy_map,
policy_mapping_fn=policy_mapping_fn,
preprocessors=self.preprocessors,
obs_filters=self.filters,
clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length,
callbacks=self.callbacks,
horizon=episode_horizon,
pack=pack_episodes,
pack_multiple_episodes_in_batch=pack,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
soft_horizon=soft_horizon,
@@ -503,7 +509,7 @@ class RolloutWorker(ParallelIteratorWorker):
This method must be implemented by subclasses.
Returns:
SampleBatch|MultiAgentBatch: A columnar batch of experiences
Union[SampleBatch,MultiAgentBatch]: A columnar batch of experiences
(e.g., tensors), or a multi-agent batch.
Examples:
-7
View File
@@ -1,7 +0,0 @@
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils import renamed_class
SampleBatch = renamed_class(
SampleBatch, old_name="rllib.evaluation.SampleBatch")
MultiAgentBatch = renamed_class(
MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch")
+43 -17
View File
@@ -75,30 +75,47 @@ class MultiAgentSampleBatchBuilder:
def __init__(self, policy_map, clip_rewards, callbacks):
"""Initialize a MultiAgentSampleBatchBuilder.
Arguments:
policy_map (dict): Maps policy ids to policy instances.
clip_rewards (bool): Whether to clip rewards before postprocessing.
Args:
policy_map (Dict[str,Policy]): Maps policy ids to policy instances.
clip_rewards (Union[bool,float]): Whether to clip rewards before
postprocessing (at +/-1.0) or the actual value to +/- clip.
callbacks (DefaultCallbacks): RLlib callbacks.
"""
self.policy_map = policy_map
self.clip_rewards = clip_rewards
# Build the Policies' SampleBatchBuilders.
self.policy_builders = {
k: SampleBatchBuilder()
for k in policy_map.keys()
}
# Whenever we observe a new agent, add a new SampleBatchBuilder for
# this agent.
self.agent_builders = {}
# Internal agent-to-policy map.
self.agent_to_policy = {}
self.callbacks = callbacks
self.count = 0 # increment this manually
# Number of "inference" steps taken in the environment.
# Regardless of the number of agents involved in each of these steps.
self.count = 0
def total(self):
"""Returns summed number of steps across all agent buffers."""
"""Returns the total number of steps taken in the env (all agents).
Returns:
int: The number of steps taken in total in the environment over all
agents.
"""
return sum(a.count for a in self.agent_builders.values())
def has_pending_agent_data(self):
"""Returns whether there is pending unprocessed data."""
"""Returns whether there is pending unprocessed data.
Returns:
bool: True if there is at least one per-agent builder (with data
in it).
"""
return len(self.agent_builders) > 0
@@ -115,32 +132,37 @@ class MultiAgentSampleBatchBuilder:
if agent_id not in self.agent_builders:
self.agent_builders[agent_id] = SampleBatchBuilder()
self.agent_to_policy[agent_id] = policy_id
builder = self.agent_builders[agent_id]
builder.add_values(**values)
self.agent_builders[agent_id].add_values(**values)
def postprocess_batch_so_far(self, episode):
def postprocess_batch_so_far(self, episode=None):
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Args:
episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
object.
episode (Optional[MultiAgentEpisode]): The Episode object that
holds this MultiAgentBatchBuilder object.
"""
# Materialize the batches so far
# Materialize the batches so far.
pre_batches = {}
for agent_id, builder in self.agent_builders.items():
pre_batches[agent_id] = (
self.policy_map[self.agent_to_policy[agent_id]],
builder.build_and_reset())
# Apply postprocessor
# Apply postprocessor.
post_batches = {}
if self.clip_rewards:
if self.clip_rewards is True:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
elif self.clip_rewards:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.clip(
pre_batch["rewards"],
a_min=-self.clip_rewards,
a_max=self.clip_rewards)
for agent_id, (_, pre_batch) in pre_batches.items():
other_batches = pre_batches.copy()
del other_batches[agent_id]
@@ -193,15 +215,19 @@ class MultiAgentSampleBatchBuilder:
"Alternatively, set no_done_at_end=True to allow this.")
@DeveloperAPI
def build_and_reset(self, episode):
def build_and_reset(self, episode=None):
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Args:
episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
object.
episode (Optional[MultiAgentEpisode]): The Episode object that
holds this MultiAgentBatchBuilder object or None.
Returns:
MultiAgentBatch: Returns the accumulated sample batches for each
policy.
"""
self.postprocess_batch_so_far(episode)
+251 -41
View File
@@ -1,3 +1,4 @@
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import logging
import numpy as np
@@ -16,7 +17,7 @@ from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.offline import InputReader
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \
unbatch
@@ -49,7 +50,8 @@ class PerfStats:
}
class SamplerInput(InputReader):
@DeveloperAPI
class SamplerInput(InputReader, metaclass=ABCMeta):
"""Reads input experiences from an existing sampler."""
@override(InputReader)
@@ -61,9 +63,29 @@ class SamplerInput(InputReader):
else:
return batches[0]
@abstractmethod
@DeveloperAPI
def get_data(self):
raise NotImplementedError
@abstractmethod
@DeveloperAPI
def get_metrics(self):
raise NotImplementedError
@abstractmethod
@DeveloperAPI
def get_extra_batches(self):
raise NotImplementedError
@DeveloperAPI
class SyncSampler(SamplerInput):
"""Sync SamplerInput that collects experiences when `get_data()` is called.
"""
def __init__(self,
*,
worker,
env,
policies,
@@ -74,12 +96,50 @@ class SyncSampler(SamplerInput):
rollout_fragment_length,
callbacks,
horizon=None,
pack=False,
pack_multiple_episodes_in_batch=False,
tf_sess=None,
clip_actions=True,
soft_horizon=False,
no_done_at_end=False,
observation_fn=None):
"""Initializes a SyncSampler object.
Args:
worker (RolloutWorker): The RolloutWorker that will use this
Sampler for sampling.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
policy_mapping_fn (callable): Callable that takes an agent ID and
returns a Policy object.
preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to
Preprocessor object for the observations prior to filtering.
obs_filters (Dict[str,Filter]): Mapping from policy ID to
env Filter object.
clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
float value for +/- value clipping. False for no clipping.
rollout_fragment_length (int): The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout.
horizon (Optional[int]): Hard-reset the Env
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
framework=tf).
clip_actions (bool): Whether to clip actions according to the
given action_space's bounds.
soft_horizon (bool): If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment
when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the
episode and instead record done=False.
observation_fn (Optional[ObservationFunction]): Optional
multi-agent observation func to use for preprocessing
observations.
"""
self.base_env = BaseEnv.to_base_env(env)
self.rollout_fragment_length = rollout_fragment_length
self.horizon = horizon
@@ -89,14 +149,16 @@ class SyncSampler(SamplerInput):
self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.perf_stats = PerfStats()
# Create the rollout generator to use for calls to `get_data()`.
self.rollout_provider = _env_runner(
worker, self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
no_done_at_end, observation_fn)
pack_multiple_episodes_in_batch, callbacks, tf_sess,
self.perf_stats, soft_horizon, no_done_at_end, observation_fn)
self.metrics_queue = queue.Queue()
@override(SamplerInput)
def get_data(self):
while True:
item = next(self.rollout_provider)
@@ -105,6 +167,7 @@ class SyncSampler(SamplerInput):
else:
return item
@override(SamplerInput)
def get_metrics(self):
completed = []
while True:
@@ -115,6 +178,7 @@ class SyncSampler(SamplerInput):
break
return completed
@override(SamplerInput)
def get_extra_batches(self):
extra = []
while True:
@@ -125,8 +189,16 @@ class SyncSampler(SamplerInput):
return extra
@DeveloperAPI
class AsyncSampler(threading.Thread, SamplerInput):
"""Async SamplerInput that collects experiences in thread and queues them.
Once started, experiences are continuously collected and put into a Queue,
from where they can be unqueued by the caller of `get_data()`.
"""
def __init__(self,
*,
worker,
env,
policies,
@@ -137,13 +209,52 @@ class AsyncSampler(threading.Thread, SamplerInput):
rollout_fragment_length,
callbacks,
horizon=None,
pack=False,
pack_multiple_episodes_in_batch=False,
tf_sess=None,
clip_actions=True,
blackhole_outputs=False,
soft_horizon=False,
no_done_at_end=False,
observation_fn=None):
"""Initializes a AsyncSampler object.
Args:
worker (RolloutWorker): The RolloutWorker that will use this
Sampler for sampling.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
policy_mapping_fn (callable): Callable that takes an agent ID and
returns a Policy object.
preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to
Preprocessor object for the observations prior to filtering.
obs_filters (Dict[str,Filter]): Mapping from policy ID to
env Filter object.
clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
float value for +/- value clipping. False for no clipping.
rollout_fragment_length (int): The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
callbacks (Callbacks): The Callbacks object to use when episode
events happen during rollout.
horizon (Optional[int]): Hard-reset the Env
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
framework=tf).
clip_actions (bool): Whether to clip actions according to the
given action_space's bounds.
blackhole_outputs (bool): Whether to collect samples, but then
not further process or store them (throw away all samples).
soft_horizon (bool): If True, calculate bootstrapped values as if
episode had ended, but don't physically reset the environment
when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the
episode and instead record done=False.
observation_fn (Optional[ObservationFunction]): Optional
multi-agent observation func to use for preprocessing
observations.
"""
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
@@ -161,7 +272,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.obs_filters = obs_filters
self.clip_rewards = clip_rewards
self.daemon = True
self.pack = pack
self.pack_multiple_episodes_in_batch = pack_multiple_episodes_in_batch
self.tf_sess = tf_sess
self.callbacks = callbacks
self.clip_actions = clip_actions
@@ -172,6 +283,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.shutdown = False
self.observation_fn = observation_fn
@override(threading.Thread)
def run(self):
try:
self._run()
@@ -191,9 +303,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.worker, self.base_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
self.perf_stats, self.soft_horizon, self.no_done_at_end,
self.observation_fn)
self.clip_actions, self.pack_multiple_episodes_in_batch,
self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon,
self.no_done_at_end, self.observation_fn)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -204,6 +316,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
else:
queue_putter(item)
@override(SamplerInput)
def get_data(self):
if not self.is_alive():
raise RuntimeError("Sampling thread has died")
@@ -215,6 +328,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
return rollout
@override(SamplerInput)
def get_metrics(self):
completed = []
while True:
@@ -225,6 +339,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
break
return completed
@override(SamplerInput)
def get_extra_batches(self):
extra = []
while True:
@@ -237,14 +352,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
def _env_runner(worker, base_env, extra_batch_callback, policies,
policy_mapping_fn, rollout_fragment_length, horizon,
preprocessors, obs_filters, clip_rewards, clip_actions, pack,
callbacks, tf_sess, perf_stats, soft_horizon, no_done_at_end,
observation_fn):
preprocessors, obs_filters, clip_rewards, clip_actions,
pack_multiple_episodes_in_batch, callbacks, tf_sess,
perf_stats, soft_horizon, no_done_at_end, observation_fn):
"""This implements the common experience collection logic.
Args:
worker (RolloutWorker): reference to the current rollout worker.
base_env (BaseEnv): env implementing BaseEnv.
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (dict): Map of policy ids to Policy instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
@@ -259,9 +374,9 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
clip_rewards (bool): Whether to clip rewards before postprocessing.
pack (bool): Whether to pack multiple episodes into each batch. This
guarantees batches will be exactly `rollout_fragment_length` in
size.
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
clip_actions (bool): Whether to clip actions to the space range.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
tf_sess (Session|None): Optional tensorflow session to use for batching
@@ -354,25 +469,47 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
# Process observations and prepare for policy evaluation.
t1 = time.time()
active_envs, to_eval, outputs = _process_observations(
worker, base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, rollout_fragment_length, pack,
callbacks, soft_horizon, no_done_at_end, observation_fn)
worker=worker,
base_env=base_env,
policies=policies,
batch_builder_pool=batch_builder_pool,
active_episodes=active_episodes,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
infos=infos,
horizon=horizon,
preprocessors=preprocessors,
obs_filters=obs_filters,
rollout_fragment_length=rollout_fragment_length,
pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch,
callbacks=callbacks,
soft_horizon=soft_horizon,
no_done_at_end=no_done_at_end,
observation_fn=observation_fn)
perf_stats.processing_time += time.time() - t1
for o in outputs:
yield o
# Do batched policy eval (accross vectorized envs).
t2 = time.time()
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
active_episodes)
eval_results = _do_policy_eval(
to_eval=to_eval,
policies=policies,
active_episodes=active_episodes,
tf_sess=tf_sess)
perf_stats.inference_time += time.time() - t2
# Process results and update episode state.
t3 = time.time()
actions_to_send = _process_policy_eval_results(
to_eval, eval_results, active_episodes, active_envs,
off_policy_actions, policies, clip_actions)
to_eval=to_eval,
eval_results=eval_results,
active_episodes=active_episodes,
active_envs=active_envs,
off_policy_actions=off_policy_actions,
policies=policies,
clip_actions=clip_actions)
perf_stats.processing_time += time.time() - t3
# Return computed actions to ready envs. We also send to envs that have
@@ -384,15 +521,51 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
def _process_observations(
worker, base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, rollout_fragment_length, pack, callbacks,
soft_horizon, no_done_at_end, observation_fn):
unfiltered_obs, rewards, dones, infos, horizon, preprocessors,
obs_filters, rollout_fragment_length, pack_multiple_episodes_in_batch,
callbacks, soft_horizon, no_done_at_end, observation_fn):
"""Record new data from the environment and prepare for policy evaluation.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
policies (dict): Map of policy ids to Policy instances.
batch_builder_pool (List[SampleBatchBuilder]): List of pooled
SampleBatchBuilder object for recycling.
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
rewards tensor, returned by a `BaseEnv.poll()` call.
dones (dict): Doubly keyed dict of env-ids -> agent ids ->
boolean done flags, returned by a `BaseEnv.poll()` call.
infos (dict): Doubly keyed dict of env-ids -> agent ids ->
info dicts, returned by a `BaseEnv.poll()` call.
horizon (int): Horizon of the episode.
preprocessors (dict): Map of policy id to preprocessor for the
observations prior to filtering.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
rollout_fragment_length (int): Number of episode steps before
`SampleBatch` is yielded. Set to infinity to yield complete
episodes.
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
no_done_at_end (bool): Ignore the done=True at the end of the episode
and instead record done=False.
observation_fn (ObservationFunction): Optional multi-agent
observation func to use for preprocessing observations.
Returns:
active_envs: set of non-terminated env ids
to_eval: map of policy_id to list of agent PolicyEvalData
outputs: list of metrics and samples to return from the sampler
Tuple:
- active_envs: Set of non-terminated env ids.
- to_eval: Map of policy_id to list of agent PolicyEvalData.
- outputs: List of metrics and samples to return from the sampler.
"""
active_envs = set()
@@ -487,7 +660,7 @@ def _process_observations(
episode._set_last_raw_obs(agent_id, raw_obs)
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
# Record transition info if applicable
# Record transition info if applicable.
if (last_observation is not None and infos[env_id].get(
agent_id, {}).get("training_enabled", True)):
episode.batch_builder.add_values(
@@ -515,13 +688,19 @@ def _process_observations(
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_agent_data():
# Sanity check, whether all agents have done=True, if done[__all__]
# is True.
if dones[env_id]["__all__"] and not no_done_at_end:
episode.batch_builder.check_missing_dones()
if (all_agents_done and not pack) or \
# Reached end of episode and we are not allowed to pack the
# next episode into the same SampleBatch -> Build the SampleBatch
# and add it to "outputs".
if (all_agents_done and not pack_multiple_episodes_in_batch) or \
episode.batch_builder.count >= rollout_fragment_length:
outputs.append(episode.batch_builder.build_and_reset(episode))
# Make sure postprocessor stays within one episode.
elif all_agents_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far(episode)
if all_agents_done:
@@ -584,8 +763,17 @@ def _process_observations(
return active_envs, to_eval, outputs
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
"""Call compute actions on observation batches to get next actions.
def _do_policy_eval(*, to_eval, policies, active_episodes, tf_sess=None):
"""Call compute_actions on collected episode/model data to get next action.
Args:
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
batching TF policy evaluations.
to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to
lists of PolicyEvalData objects.
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
Returns:
eval_results: dict of policy to compute_action() outputs.
@@ -606,6 +794,8 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
for policy_id, eval_data in to_eval.items():
rnn_in = [t.rnn_state for t in eval_data]
policy = _get_or_raise(policies, policy_id)
# If tf (non eager) AND TFPolicy's compute_action method has not been
# overridden -> Use `policy._build_compute_actions()`.
if builder and (policy.compute_actions.__code__ is
TFPolicy.compute_actions.__code__):
@@ -646,7 +836,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
return eval_results
def _process_policy_eval_results(to_eval, eval_results, active_episodes,
def _process_policy_eval_results(*, to_eval, eval_results, active_episodes,
active_envs, off_policy_actions, policies,
clip_actions):
"""Process the output of policy neural network evaluation.
@@ -654,8 +844,22 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
Records policy evaluation results into the given episode objects and
returns replies to send back to agents in the env.
Args:
to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to
lists of PolicyEvalData objects.
eval_results (Dict[str,List]): Mapping of policy IDs to list of
actions, rnn-out states, extra-action-fetches dicts.
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
episode ID to currently ongoing MultiAgentEpisode object.
active_envs (Set[int]): Set of non-terminated env ids.
off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids ->
off-policy-action, returned by a `BaseEnv.poll()` call.
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
clip_actions (bool): Whether to clip actions to the action space's
bounds.
Returns:
actions_to_send: nested dict of env id -> agent id -> agent replies.
actions_to_send: Nested dict of env id -> agent id -> agent replies.
"""
actions_to_send = defaultdict(dict)
@@ -711,7 +915,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
def _fetch_atari_metrics(base_env):
"""Atari games have multiple logical episodes, one per life.
However for metrics reporting we count full episodes all lives included.
However, for metrics reporting we count full episodes, all lives included.
"""
unwrapped = base_env.get_unwrapped()
if not unwrapped:
@@ -734,10 +938,16 @@ def _to_column_format(rnn_state_rows):
def _get_or_raise(mapping, policy_id):
"""Returns a Policy object under key `policy_id` in `mapping`.
Throws an error if `policy_id` cannot be found.
Args:
mapping (dict): The mapping dict from policy id (str) to
actual Policy object.
policy_id (str): The policy ID to lookup.
Returns:
Policy: The found Policy object.
Throws:
ValueError: If `policy_id` cannot be found.
"""
if policy_id not in mapping:
raise ValueError(
-4
View File
@@ -1,4 +0,0 @@
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import renamed_class
TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph")
-4
View File
@@ -1,4 +0,0 @@
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import renamed_class
TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph")
+81 -12
View File
@@ -65,6 +65,15 @@ class SampleBatch:
@staticmethod
@PublicAPI
def concat_samples(samples):
"""Concatenates n data dicts or MultiAgentBatches.
Args:
samples (List[Dict[np.ndarray]]]): List of dicts of data (numpy).
Returns:
Union[SampleBatch,MultiAgentBatch]: A new (compressed) SampleBatch/
MultiAgentBatch.
"""
if isinstance(samples[0], MultiAgentBatch):
return MultiAgentBatch.concat_samples(samples)
out = {}
@@ -84,7 +93,10 @@ class SampleBatch:
{"a": [1, 2, 3, 4, 5]}
"""
assert self.keys() == other.keys(), "must have same columns"
if self.keys() != other.keys():
raise ValueError(
"SampleBatches to concat must have same columns! {} vs {}".
format(list(self.keys()), list(other.keys())))
out = {}
for k in self.keys():
out[k] = concat_aligned([self[k], other[k]])
@@ -117,7 +129,14 @@ class SampleBatch:
@PublicAPI
def columns(self, keys):
"""Returns a list of just the specified columns.
"""Returns a list of the batch-data in the specified columns.
Args:
keys (List[str]): List of column names fo which to return the data.
Returns:
List[any]: The list of data items ordered by the order of column
names in `keys`.
Examples:
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
@@ -143,7 +162,7 @@ class SampleBatch:
"""Splits this batch's data by `eps_id`.
Returns:
list of SampleBatch, one per distinct episode.
List[SampleBatch]: List of batches, one per distinct episode.
"""
slices = []
@@ -166,7 +185,7 @@ class SampleBatch:
def slice(self, start, end):
"""Returns a slice of the row data of this batch.
Arguments:
Args:
start (int): Starting index.
end (int): Ending index.
@@ -234,23 +253,37 @@ class SampleBatch:
@PublicAPI
class MultiAgentBatch:
"""A batch of experiences from multiple policies in the environment.
Attributes:
policy_batches (dict): Mapping from policy id to a normal SampleBatch
of experiences. Note that these batches may be of different length.
count (int): The number of timesteps in the environment this batch
contains. This will be less than the number of transitions this
batch contains across all policies in total.
"""
@PublicAPI
def __init__(self, policy_batches, count):
"""Initializes a MultiAgentBatch object.
Args:
policy_batches (Dict[str,SampleBatch]): Mapping from policy id
(str) to a SampleBatch of experiences. Note that these batches
may be of different length.
count (int): The number of timesteps in the environment this batch
contains. This will be less than the number of transitions this
batch contains across all policies in total.
"""
self.policy_batches = policy_batches
self.count = count
@staticmethod
@PublicAPI
def wrap_as_needed(batches, count):
"""Returns SampleBatch or MultiAgentBatch, depending on given policies.
Args:
batches (Dict[str,SampleBatch]): Mapping from policy ID to
SampleBatch.
count (int): A count to use, when returning a MultiAgentBatch.
Returns:
Union[SampleBatch,MultiAgentBatch]: The single default policy's
SampleBatch or a MultiAgentBatch (more than one policy).
"""
if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
return batches[DEFAULT_POLICY_ID]
return MultiAgentBatch(batches, count)
@@ -258,10 +291,23 @@ class MultiAgentBatch:
@staticmethod
@PublicAPI
def concat_samples(samples):
"""Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.
Args:
samples (List[MultiAgentBatch]): List of MultiagentBatch objects
to concatenate.
Returns:
MultiAgentBatch: A new MultiAgentBatch consisting of the
concatenated inputs.
"""
policy_batches = collections.defaultdict(list)
total_count = 0
for s in samples:
assert isinstance(s, MultiAgentBatch)
if not isinstance(s, MultiAgentBatch):
raise ValueError(
"`MultiAgentBatch.concat_samples()` can only concat "
"MultiAgentBatch types, not {}!".format(type(s).__name__))
for policy_id, batch in s.policy_batches.items():
policy_batches[policy_id].append(batch)
total_count += s.count
@@ -272,12 +318,22 @@ class MultiAgentBatch:
@PublicAPI
def copy(self):
"""Deep-copies self into a new MultiAgentBatch.
Returns:
MultiAgentBatch: The copy of self with deep-copied data.
"""
return MultiAgentBatch(
{k: v.copy()
for (k, v) in self.policy_batches.items()}, self.count)
@PublicAPI
def total(self):
"""Calculates the sum of all step-counts over all policy batches.
Returns:
int: The sum of counts over all policy batches.
"""
ct = 0
for batch in self.policy_batches.values():
ct += batch.count
@@ -285,11 +341,24 @@ class MultiAgentBatch:
@DeveloperAPI
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
"""Compresses each policy batch.
Args:
bulk (bool): Whether to compress across the batch dimension (0)
as well. If False will compress n separate list items, where n
is the batch size.
columns (Set[str]): Set of column names to compress.
"""
for batch in self.policy_batches.values():
batch.compress(bulk=bulk, columns=columns)
@DeveloperAPI
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
"""Decompresses each policy batch, if already compressed.
Args:
columns (Set[str]): Set of column names to decompress.
"""
for batch in self.policy_batches.values():
batch.decompress_if_needed(columns)
return self