mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[RLlib] Sample batch docs and cleanup. (#8778)
This commit is contained in:
+5
-5
@@ -162,7 +162,7 @@ matrix:
|
||||
script:
|
||||
- . ./ci/travis/ci.sh test_wheels
|
||||
|
||||
# RLlib: Learning tests (from rllib/tuned_examples/regression_tests/*.yaml).
|
||||
# RLlib: Learning tests (from rllib/tuned_examples/*.yaml).
|
||||
- os: linux
|
||||
env:
|
||||
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS=1
|
||||
@@ -178,7 +178,7 @@ matrix:
|
||||
script:
|
||||
- ./ci/keep_alive bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
|
||||
|
||||
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
|
||||
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/*.yaml).
|
||||
# Requested by Edi (MS): Test all learning capabilities with tf1.x
|
||||
- os: linux
|
||||
env:
|
||||
@@ -195,7 +195,7 @@ matrix:
|
||||
script:
|
||||
- ./ci/keep_alive bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
|
||||
|
||||
# RLlib: Learning tests with torch (from rllib/tuned_examples/regression_tests/*.yaml).
|
||||
# RLlib: Learning tests with torch (from rllib/tuned_examples/*.yaml).
|
||||
- os: linux
|
||||
env:
|
||||
- RLLIB_TESTING=1 RLLIB_REGRESSION_TESTS_TORCH=1
|
||||
@@ -250,7 +250,7 @@ matrix:
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=examples_E,examples_F,examples_G,examples_H,examples_I,examples_J,examples_K,examples_L,examples_M,examples_N,examples_O,examples_P rllib/...
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=examples_Q,examples_R,examples_S,examples_T,examples_U,examples_V,examples_W,examples_X,examples_Y,examples_Z rllib/...
|
||||
|
||||
# RLlib: tests_dir: Everything in rllib/tests/ directory (A-I).
|
||||
# RLlib: tests_dir: Everything in rllib/tests/ directory (A-L).
|
||||
- os: linux
|
||||
env:
|
||||
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_A_TO_L=1
|
||||
@@ -266,7 +266,7 @@ matrix:
|
||||
script:
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=tests_dir_A,tests_dir_B,tests_dir_C,tests_dir_D,tests_dir_E,tests_dir_F,tests_dir_G,tests_dir_H,tests_dir_I,tests_dir_J,tests_dir_K,tests_dir_L rllib/...
|
||||
|
||||
# RLlib: tests_dir: Everything in rllib/tests/ directory (J-Z).
|
||||
# RLlib: tests_dir: Everything in rllib/tests/ directory (M-Z).
|
||||
- os: linux
|
||||
env:
|
||||
- RLLIB_TESTING=1 RLLIB_TESTS_DIR_TESTS_M_TO_Z=1
|
||||
|
||||
@@ -6,9 +6,7 @@ from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
@@ -61,11 +59,8 @@ _register_all()
|
||||
|
||||
__all__ = [
|
||||
"Policy",
|
||||
"PolicyGraph",
|
||||
"TFPolicy",
|
||||
"TFPolicyGraph",
|
||||
"RolloutWorker",
|
||||
"PolicyEvaluator",
|
||||
"SampleBatch",
|
||||
"BaseEnv",
|
||||
"MultiAgentEnv",
|
||||
|
||||
@@ -1,22 +1,14 @@
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.evaluation.sample_batch_builder import (
|
||||
SampleBatchBuilder, MultiAgentSampleBatchBuilder)
|
||||
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
||||
__all__ = [
|
||||
"RolloutWorker",
|
||||
"PolicyGraph",
|
||||
"TFPolicyGraph",
|
||||
"TorchPolicyGraph",
|
||||
"SampleBatch",
|
||||
"MultiAgentBatch",
|
||||
"SampleBatchBuilder",
|
||||
@@ -26,5 +18,4 @@ __all__ = [
|
||||
"compute_advantages",
|
||||
"collect_metrics",
|
||||
"MultiAgentEpisode",
|
||||
"PolicyEvaluator",
|
||||
]
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.evaluation import RolloutWorker
|
||||
|
||||
PolicyEvaluator = renamed_class(
|
||||
RolloutWorker, old_name="rllib.evaluation.PolicyEvaluator")
|
||||
@@ -1,4 +0,0 @@
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
PolicyGraph = renamed_class(Policy, old_name="PolicyGraph")
|
||||
@@ -420,11 +420,16 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
|
||||
self.num_envs = num_envs
|
||||
|
||||
# `truncate_episodes`: Allow a batch to contain more than one episode
|
||||
# (fragments) and always make the batch `rollout_fragment_length`
|
||||
# long.
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
pack_episodes = True
|
||||
pack = True
|
||||
# `complete_episodes`: Never cut episodes and sampler will return
|
||||
# exactly one (complete) episode per poll.
|
||||
elif self.batch_mode == "complete_episodes":
|
||||
rollout_fragment_length = float("inf") # never cut episodes
|
||||
pack_episodes = False # sampler will return 1 episode per poll
|
||||
rollout_fragment_length = float("inf")
|
||||
pack = False
|
||||
else:
|
||||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
self.batch_mode))
|
||||
@@ -450,37 +455,38 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
self,
|
||||
self.async_env,
|
||||
self.policy_map,
|
||||
policy_mapping_fn,
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
rollout_fragment_length,
|
||||
self.callbacks,
|
||||
worker=self,
|
||||
env=self.async_env,
|
||||
policies=self.policy_map,
|
||||
policy_mapping_fn=policy_mapping_fn,
|
||||
preprocessors=self.preprocessors,
|
||||
obs_filters=self.filters,
|
||||
clip_rewards=clip_rewards,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
callbacks=self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
pack_multiple_episodes_in_batch=pack,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
blackhole_outputs="simulation" in input_evaluation,
|
||||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn)
|
||||
# Start the Sampler thread.
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
self,
|
||||
self.async_env,
|
||||
self.policy_map,
|
||||
policy_mapping_fn,
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
rollout_fragment_length,
|
||||
self.callbacks,
|
||||
worker=self,
|
||||
env=self.async_env,
|
||||
policies=self.policy_map,
|
||||
policy_mapping_fn=policy_mapping_fn,
|
||||
preprocessors=self.preprocessors,
|
||||
obs_filters=self.filters,
|
||||
clip_rewards=clip_rewards,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
callbacks=self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
pack_multiple_episodes_in_batch=pack,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
soft_horizon=soft_horizon,
|
||||
@@ -503,7 +509,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||
This method must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
SampleBatch|MultiAgentBatch: A columnar batch of experiences
|
||||
Union[SampleBatch,MultiAgentBatch]: A columnar batch of experiences
|
||||
(e.g., tensors), or a multi-agent batch.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
SampleBatch = renamed_class(
|
||||
SampleBatch, old_name="rllib.evaluation.SampleBatch")
|
||||
MultiAgentBatch = renamed_class(
|
||||
MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch")
|
||||
@@ -75,30 +75,47 @@ class MultiAgentSampleBatchBuilder:
|
||||
def __init__(self, policy_map, clip_rewards, callbacks):
|
||||
"""Initialize a MultiAgentSampleBatchBuilder.
|
||||
|
||||
Arguments:
|
||||
policy_map (dict): Maps policy ids to policy instances.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
Args:
|
||||
policy_map (Dict[str,Policy]): Maps policy ids to policy instances.
|
||||
clip_rewards (Union[bool,float]): Whether to clip rewards before
|
||||
postprocessing (at +/-1.0) or the actual value to +/- clip.
|
||||
callbacks (DefaultCallbacks): RLlib callbacks.
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.clip_rewards = clip_rewards
|
||||
# Build the Policies' SampleBatchBuilders.
|
||||
self.policy_builders = {
|
||||
k: SampleBatchBuilder()
|
||||
for k in policy_map.keys()
|
||||
}
|
||||
# Whenever we observe a new agent, add a new SampleBatchBuilder for
|
||||
# this agent.
|
||||
self.agent_builders = {}
|
||||
# Internal agent-to-policy map.
|
||||
self.agent_to_policy = {}
|
||||
self.callbacks = callbacks
|
||||
self.count = 0 # increment this manually
|
||||
# Number of "inference" steps taken in the environment.
|
||||
# Regardless of the number of agents involved in each of these steps.
|
||||
self.count = 0
|
||||
|
||||
def total(self):
|
||||
"""Returns summed number of steps across all agent buffers."""
|
||||
"""Returns the total number of steps taken in the env (all agents).
|
||||
|
||||
Returns:
|
||||
int: The number of steps taken in total in the environment over all
|
||||
agents.
|
||||
"""
|
||||
|
||||
return sum(a.count for a in self.agent_builders.values())
|
||||
|
||||
def has_pending_agent_data(self):
|
||||
"""Returns whether there is pending unprocessed data."""
|
||||
"""Returns whether there is pending unprocessed data.
|
||||
|
||||
Returns:
|
||||
bool: True if there is at least one per-agent builder (with data
|
||||
in it).
|
||||
"""
|
||||
|
||||
return len(self.agent_builders) > 0
|
||||
|
||||
@@ -115,32 +132,37 @@ class MultiAgentSampleBatchBuilder:
|
||||
if agent_id not in self.agent_builders:
|
||||
self.agent_builders[agent_id] = SampleBatchBuilder()
|
||||
self.agent_to_policy[agent_id] = policy_id
|
||||
builder = self.agent_builders[agent_id]
|
||||
builder.add_values(**values)
|
||||
self.agent_builders[agent_id].add_values(**values)
|
||||
|
||||
def postprocess_batch_so_far(self, episode):
|
||||
def postprocess_batch_so_far(self, episode=None):
|
||||
"""Apply policy postprocessors to any unprocessed rows.
|
||||
|
||||
This pushes the postprocessed per-agent batches onto the per-policy
|
||||
builders, clearing per-agent state.
|
||||
|
||||
Args:
|
||||
episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
|
||||
object.
|
||||
episode (Optional[MultiAgentEpisode]): The Episode object that
|
||||
holds this MultiAgentBatchBuilder object.
|
||||
"""
|
||||
|
||||
# Materialize the batches so far
|
||||
# Materialize the batches so far.
|
||||
pre_batches = {}
|
||||
for agent_id, builder in self.agent_builders.items():
|
||||
pre_batches[agent_id] = (
|
||||
self.policy_map[self.agent_to_policy[agent_id]],
|
||||
builder.build_and_reset())
|
||||
|
||||
# Apply postprocessor
|
||||
# Apply postprocessor.
|
||||
post_batches = {}
|
||||
if self.clip_rewards:
|
||||
if self.clip_rewards is True:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
||||
elif self.clip_rewards:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.clip(
|
||||
pre_batch["rewards"],
|
||||
a_min=-self.clip_rewards,
|
||||
a_max=self.clip_rewards)
|
||||
for agent_id, (_, pre_batch) in pre_batches.items():
|
||||
other_batches = pre_batches.copy()
|
||||
del other_batches[agent_id]
|
||||
@@ -193,15 +215,19 @@ class MultiAgentSampleBatchBuilder:
|
||||
"Alternatively, set no_done_at_end=True to allow this.")
|
||||
|
||||
@DeveloperAPI
|
||||
def build_and_reset(self, episode):
|
||||
def build_and_reset(self, episode=None):
|
||||
"""Returns the accumulated sample batches for each policy.
|
||||
|
||||
Any unprocessed rows will be first postprocessed with a policy
|
||||
postprocessor. The internal state of this builder will be reset.
|
||||
|
||||
Args:
|
||||
episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
|
||||
object.
|
||||
episode (Optional[MultiAgentEpisode]): The Episode object that
|
||||
holds this MultiAgentBatchBuilder object or None.
|
||||
|
||||
Returns:
|
||||
MultiAgentBatch: Returns the accumulated sample batches for each
|
||||
policy.
|
||||
"""
|
||||
|
||||
self.postprocess_batch_so_far(episode)
|
||||
|
||||
+251
-41
@@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from collections import defaultdict, namedtuple
|
||||
import logging
|
||||
import numpy as np
|
||||
@@ -16,7 +17,7 @@ from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
|
||||
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, \
|
||||
unbatch
|
||||
@@ -49,7 +50,8 @@ class PerfStats:
|
||||
}
|
||||
|
||||
|
||||
class SamplerInput(InputReader):
|
||||
@DeveloperAPI
|
||||
class SamplerInput(InputReader, metaclass=ABCMeta):
|
||||
"""Reads input experiences from an existing sampler."""
|
||||
|
||||
@override(InputReader)
|
||||
@@ -61,9 +63,29 @@ class SamplerInput(InputReader):
|
||||
else:
|
||||
return batches[0]
|
||||
|
||||
@abstractmethod
|
||||
@DeveloperAPI
|
||||
def get_data(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@DeveloperAPI
|
||||
def get_extra_batches(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class SyncSampler(SamplerInput):
|
||||
"""Sync SamplerInput that collects experiences when `get_data()` is called.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
worker,
|
||||
env,
|
||||
policies,
|
||||
@@ -74,12 +96,50 @@ class SyncSampler(SamplerInput):
|
||||
rollout_fragment_length,
|
||||
callbacks,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
pack_multiple_episodes_in_batch=False,
|
||||
tf_sess=None,
|
||||
clip_actions=True,
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False,
|
||||
observation_fn=None):
|
||||
"""Initializes a SyncSampler object.
|
||||
|
||||
Args:
|
||||
worker (RolloutWorker): The RolloutWorker that will use this
|
||||
Sampler for sampling.
|
||||
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
|
||||
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
|
||||
policy_mapping_fn (callable): Callable that takes an agent ID and
|
||||
returns a Policy object.
|
||||
preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to
|
||||
Preprocessor object for the observations prior to filtering.
|
||||
obs_filters (Dict[str,Filter]): Mapping from policy ID to
|
||||
env Filter object.
|
||||
clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
|
||||
float value for +/- value clipping. False for no clipping.
|
||||
rollout_fragment_length (int): The length of a fragment to collect
|
||||
before building a SampleBatch from the data and resetting
|
||||
the SampleBatchBuilder object.
|
||||
callbacks (Callbacks): The Callbacks object to use when episode
|
||||
events happen during rollout.
|
||||
horizon (Optional[int]): Hard-reset the Env
|
||||
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be
|
||||
exactly `rollout_fragment_length` in size.
|
||||
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
|
||||
framework=tf).
|
||||
clip_actions (bool): Whether to clip actions according to the
|
||||
given action_space's bounds.
|
||||
soft_horizon (bool): If True, calculate bootstrapped values as if
|
||||
episode had ended, but don't physically reset the environment
|
||||
when the horizon is hit.
|
||||
no_done_at_end (bool): Ignore the done=True at the end of the
|
||||
episode and instead record done=False.
|
||||
observation_fn (Optional[ObservationFunction]): Optional
|
||||
multi-agent observation func to use for preprocessing
|
||||
observations.
|
||||
"""
|
||||
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.horizon = horizon
|
||||
@@ -89,14 +149,16 @@ class SyncSampler(SamplerInput):
|
||||
self.obs_filters = obs_filters
|
||||
self.extra_batches = queue.Queue()
|
||||
self.perf_stats = PerfStats()
|
||||
# Create the rollout generator to use for calls to `get_data()`.
|
||||
self.rollout_provider = _env_runner(
|
||||
worker, self.base_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
|
||||
no_done_at_end, observation_fn)
|
||||
pack_multiple_episodes_in_batch, callbacks, tf_sess,
|
||||
self.perf_stats, soft_horizon, no_done_at_end, observation_fn)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
@override(SamplerInput)
|
||||
def get_data(self):
|
||||
while True:
|
||||
item = next(self.rollout_provider)
|
||||
@@ -105,6 +167,7 @@ class SyncSampler(SamplerInput):
|
||||
else:
|
||||
return item
|
||||
|
||||
@override(SamplerInput)
|
||||
def get_metrics(self):
|
||||
completed = []
|
||||
while True:
|
||||
@@ -115,6 +178,7 @@ class SyncSampler(SamplerInput):
|
||||
break
|
||||
return completed
|
||||
|
||||
@override(SamplerInput)
|
||||
def get_extra_batches(self):
|
||||
extra = []
|
||||
while True:
|
||||
@@ -125,8 +189,16 @@ class SyncSampler(SamplerInput):
|
||||
return extra
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class AsyncSampler(threading.Thread, SamplerInput):
|
||||
"""Async SamplerInput that collects experiences in thread and queues them.
|
||||
|
||||
Once started, experiences are continuously collected and put into a Queue,
|
||||
from where they can be unqueued by the caller of `get_data()`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
worker,
|
||||
env,
|
||||
policies,
|
||||
@@ -137,13 +209,52 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
rollout_fragment_length,
|
||||
callbacks,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
pack_multiple_episodes_in_batch=False,
|
||||
tf_sess=None,
|
||||
clip_actions=True,
|
||||
blackhole_outputs=False,
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False,
|
||||
observation_fn=None):
|
||||
"""Initializes a AsyncSampler object.
|
||||
|
||||
Args:
|
||||
worker (RolloutWorker): The RolloutWorker that will use this
|
||||
Sampler for sampling.
|
||||
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
|
||||
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
|
||||
policy_mapping_fn (callable): Callable that takes an agent ID and
|
||||
returns a Policy object.
|
||||
preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to
|
||||
Preprocessor object for the observations prior to filtering.
|
||||
obs_filters (Dict[str,Filter]): Mapping from policy ID to
|
||||
env Filter object.
|
||||
clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
|
||||
float value for +/- value clipping. False for no clipping.
|
||||
rollout_fragment_length (int): The length of a fragment to collect
|
||||
before building a SampleBatch from the data and resetting
|
||||
the SampleBatchBuilder object.
|
||||
callbacks (Callbacks): The Callbacks object to use when episode
|
||||
events happen during rollout.
|
||||
horizon (Optional[int]): Hard-reset the Env
|
||||
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be
|
||||
exactly `rollout_fragment_length` in size.
|
||||
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
|
||||
framework=tf).
|
||||
clip_actions (bool): Whether to clip actions according to the
|
||||
given action_space's bounds.
|
||||
blackhole_outputs (bool): Whether to collect samples, but then
|
||||
not further process or store them (throw away all samples).
|
||||
soft_horizon (bool): If True, calculate bootstrapped values as if
|
||||
episode had ended, but don't physically reset the environment
|
||||
when the horizon is hit.
|
||||
no_done_at_end (bool): Ignore the done=True at the end of the
|
||||
episode and instead record done=False.
|
||||
observation_fn (Optional[ObservationFunction]): Optional
|
||||
multi-agent observation func to use for preprocessing
|
||||
observations.
|
||||
"""
|
||||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
"Observation Filter must support concurrent updates."
|
||||
@@ -161,7 +272,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.obs_filters = obs_filters
|
||||
self.clip_rewards = clip_rewards
|
||||
self.daemon = True
|
||||
self.pack = pack
|
||||
self.pack_multiple_episodes_in_batch = pack_multiple_episodes_in_batch
|
||||
self.tf_sess = tf_sess
|
||||
self.callbacks = callbacks
|
||||
self.clip_actions = clip_actions
|
||||
@@ -172,6 +283,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.shutdown = False
|
||||
self.observation_fn = observation_fn
|
||||
|
||||
@override(threading.Thread)
|
||||
def run(self):
|
||||
try:
|
||||
self._run()
|
||||
@@ -191,9 +303,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.worker, self.base_env, extra_batches_putter, self.policies,
|
||||
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
|
||||
self.perf_stats, self.soft_horizon, self.no_done_at_end,
|
||||
self.observation_fn)
|
||||
self.clip_actions, self.pack_multiple_episodes_in_batch,
|
||||
self.callbacks, self.tf_sess, self.perf_stats, self.soft_horizon,
|
||||
self.no_done_at_end, self.observation_fn)
|
||||
while not self.shutdown:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -204,6 +316,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
else:
|
||||
queue_putter(item)
|
||||
|
||||
@override(SamplerInput)
|
||||
def get_data(self):
|
||||
if not self.is_alive():
|
||||
raise RuntimeError("Sampling thread has died")
|
||||
@@ -215,6 +328,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
|
||||
return rollout
|
||||
|
||||
@override(SamplerInput)
|
||||
def get_metrics(self):
|
||||
completed = []
|
||||
while True:
|
||||
@@ -225,6 +339,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
break
|
||||
return completed
|
||||
|
||||
@override(SamplerInput)
|
||||
def get_extra_batches(self):
|
||||
extra = []
|
||||
while True:
|
||||
@@ -237,14 +352,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
|
||||
def _env_runner(worker, base_env, extra_batch_callback, policies,
|
||||
policy_mapping_fn, rollout_fragment_length, horizon,
|
||||
preprocessors, obs_filters, clip_rewards, clip_actions, pack,
|
||||
callbacks, tf_sess, perf_stats, soft_horizon, no_done_at_end,
|
||||
observation_fn):
|
||||
preprocessors, obs_filters, clip_rewards, clip_actions,
|
||||
pack_multiple_episodes_in_batch, callbacks, tf_sess,
|
||||
perf_stats, soft_horizon, no_done_at_end, observation_fn):
|
||||
"""This implements the common experience collection logic.
|
||||
|
||||
Args:
|
||||
worker (RolloutWorker): reference to the current rollout worker.
|
||||
base_env (BaseEnv): env implementing BaseEnv.
|
||||
worker (RolloutWorker): Reference to the current rollout worker.
|
||||
base_env (BaseEnv): Env implementing BaseEnv.
|
||||
extra_batch_callback (fn): function to send extra batch data to.
|
||||
policies (dict): Map of policy ids to Policy instances.
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
@@ -259,9 +374,9 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
pack (bool): Whether to pack multiple episodes into each batch. This
|
||||
guarantees batches will be exactly `rollout_fragment_length` in
|
||||
size.
|
||||
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be exactly
|
||||
`rollout_fragment_length` in size.
|
||||
clip_actions (bool): Whether to clip actions to the space range.
|
||||
callbacks (DefaultCallbacks): User callbacks to run on episode events.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
@@ -354,25 +469,47 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
|
||||
# Process observations and prepare for policy evaluation.
|
||||
t1 = time.time()
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
worker, base_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
preprocessors, obs_filters, rollout_fragment_length, pack,
|
||||
callbacks, soft_horizon, no_done_at_end, observation_fn)
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
batch_builder_pool=batch_builder_pool,
|
||||
active_episodes=active_episodes,
|
||||
unfiltered_obs=unfiltered_obs,
|
||||
rewards=rewards,
|
||||
dones=dones,
|
||||
infos=infos,
|
||||
horizon=horizon,
|
||||
preprocessors=preprocessors,
|
||||
obs_filters=obs_filters,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
pack_multiple_episodes_in_batch=pack_multiple_episodes_in_batch,
|
||||
callbacks=callbacks,
|
||||
soft_horizon=soft_horizon,
|
||||
no_done_at_end=no_done_at_end,
|
||||
observation_fn=observation_fn)
|
||||
perf_stats.processing_time += time.time() - t1
|
||||
for o in outputs:
|
||||
yield o
|
||||
|
||||
# Do batched policy eval (accross vectorized envs).
|
||||
t2 = time.time()
|
||||
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
|
||||
active_episodes)
|
||||
eval_results = _do_policy_eval(
|
||||
to_eval=to_eval,
|
||||
policies=policies,
|
||||
active_episodes=active_episodes,
|
||||
tf_sess=tf_sess)
|
||||
perf_stats.inference_time += time.time() - t2
|
||||
|
||||
# Process results and update episode state.
|
||||
t3 = time.time()
|
||||
actions_to_send = _process_policy_eval_results(
|
||||
to_eval, eval_results, active_episodes, active_envs,
|
||||
off_policy_actions, policies, clip_actions)
|
||||
to_eval=to_eval,
|
||||
eval_results=eval_results,
|
||||
active_episodes=active_episodes,
|
||||
active_envs=active_envs,
|
||||
off_policy_actions=off_policy_actions,
|
||||
policies=policies,
|
||||
clip_actions=clip_actions)
|
||||
perf_stats.processing_time += time.time() - t3
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
@@ -384,15 +521,51 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
|
||||
|
||||
def _process_observations(
|
||||
worker, base_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
preprocessors, obs_filters, rollout_fragment_length, pack, callbacks,
|
||||
soft_horizon, no_done_at_end, observation_fn):
|
||||
unfiltered_obs, rewards, dones, infos, horizon, preprocessors,
|
||||
obs_filters, rollout_fragment_length, pack_multiple_episodes_in_batch,
|
||||
callbacks, soft_horizon, no_done_at_end, observation_fn):
|
||||
"""Record new data from the environment and prepare for policy evaluation.
|
||||
|
||||
Args:
|
||||
worker (RolloutWorker): Reference to the current rollout worker.
|
||||
base_env (BaseEnv): Env implementing BaseEnv.
|
||||
policies (dict): Map of policy ids to Policy instances.
|
||||
batch_builder_pool (List[SampleBatchBuilder]): List of pooled
|
||||
SampleBatchBuilder object for recycling.
|
||||
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
|
||||
episode ID to currently ongoing MultiAgentEpisode object.
|
||||
unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
|
||||
unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
|
||||
rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
|
||||
rewards tensor, returned by a `BaseEnv.poll()` call.
|
||||
dones (dict): Doubly keyed dict of env-ids -> agent ids ->
|
||||
boolean done flags, returned by a `BaseEnv.poll()` call.
|
||||
infos (dict): Doubly keyed dict of env-ids -> agent ids ->
|
||||
info dicts, returned by a `BaseEnv.poll()` call.
|
||||
horizon (int): Horizon of the episode.
|
||||
preprocessors (dict): Map of policy id to preprocessor for the
|
||||
observations prior to filtering.
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
rollout_fragment_length (int): Number of episode steps before
|
||||
`SampleBatch` is yielded. Set to infinity to yield complete
|
||||
episodes.
|
||||
pack_multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be exactly
|
||||
`rollout_fragment_length` in size.
|
||||
callbacks (DefaultCallbacks): User callbacks to run on episode events.
|
||||
soft_horizon (bool): Calculate rewards but don't reset the
|
||||
environment when the horizon is hit.
|
||||
no_done_at_end (bool): Ignore the done=True at the end of the episode
|
||||
and instead record done=False.
|
||||
observation_fn (ObservationFunction): Optional multi-agent
|
||||
observation func to use for preprocessing observations.
|
||||
|
||||
Returns:
|
||||
active_envs: set of non-terminated env ids
|
||||
to_eval: map of policy_id to list of agent PolicyEvalData
|
||||
outputs: list of metrics and samples to return from the sampler
|
||||
Tuple:
|
||||
- active_envs: Set of non-terminated env ids.
|
||||
- to_eval: Map of policy_id to list of agent PolicyEvalData.
|
||||
- outputs: List of metrics and samples to return from the sampler.
|
||||
"""
|
||||
|
||||
active_envs = set()
|
||||
@@ -487,7 +660,7 @@ def _process_observations(
|
||||
episode._set_last_raw_obs(agent_id, raw_obs)
|
||||
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
|
||||
|
||||
# Record transition info if applicable
|
||||
# Record transition info if applicable.
|
||||
if (last_observation is not None and infos[env_id].get(
|
||||
agent_id, {}).get("training_enabled", True)):
|
||||
episode.batch_builder.add_values(
|
||||
@@ -515,13 +688,19 @@ def _process_observations(
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_agent_data():
|
||||
# Sanity check, whether all agents have done=True, if done[__all__]
|
||||
# is True.
|
||||
if dones[env_id]["__all__"] and not no_done_at_end:
|
||||
episode.batch_builder.check_missing_dones()
|
||||
if (all_agents_done and not pack) or \
|
||||
|
||||
# Reached end of episode and we are not allowed to pack the
|
||||
# next episode into the same SampleBatch -> Build the SampleBatch
|
||||
# and add it to "outputs".
|
||||
if (all_agents_done and not pack_multiple_episodes_in_batch) or \
|
||||
episode.batch_builder.count >= rollout_fragment_length:
|
||||
outputs.append(episode.batch_builder.build_and_reset(episode))
|
||||
# Make sure postprocessor stays within one episode.
|
||||
elif all_agents_done:
|
||||
# Make sure postprocessor stays within one episode
|
||||
episode.batch_builder.postprocess_batch_so_far(episode)
|
||||
|
||||
if all_agents_done:
|
||||
@@ -584,8 +763,17 @@ def _process_observations(
|
||||
return active_envs, to_eval, outputs
|
||||
|
||||
|
||||
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
"""Call compute actions on observation batches to get next actions.
|
||||
def _do_policy_eval(*, to_eval, policies, active_episodes, tf_sess=None):
|
||||
"""Call compute_actions on collected episode/model data to get next action.
|
||||
|
||||
Args:
|
||||
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
|
||||
batching TF policy evaluations.
|
||||
to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to
|
||||
lists of PolicyEvalData objects.
|
||||
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
|
||||
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
|
||||
episode ID to currently ongoing MultiAgentEpisode object.
|
||||
|
||||
Returns:
|
||||
eval_results: dict of policy to compute_action() outputs.
|
||||
@@ -606,6 +794,8 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in = [t.rnn_state for t in eval_data]
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
# If tf (non eager) AND TFPolicy's compute_action method has not been
|
||||
# overridden -> Use `policy._build_compute_actions()`.
|
||||
if builder and (policy.compute_actions.__code__ is
|
||||
TFPolicy.compute_actions.__code__):
|
||||
|
||||
@@ -646,7 +836,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
return eval_results
|
||||
|
||||
|
||||
def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
def _process_policy_eval_results(*, to_eval, eval_results, active_episodes,
|
||||
active_envs, off_policy_actions, policies,
|
||||
clip_actions):
|
||||
"""Process the output of policy neural network evaluation.
|
||||
@@ -654,8 +844,22 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
Records policy evaluation results into the given episode objects and
|
||||
returns replies to send back to agents in the env.
|
||||
|
||||
Args:
|
||||
to_eval (Dict[str,List[PolicyEvalData]]): Mapping of policy IDs to
|
||||
lists of PolicyEvalData objects.
|
||||
eval_results (Dict[str,List]): Mapping of policy IDs to list of
|
||||
actions, rnn-out states, extra-action-fetches dicts.
|
||||
active_episodes (defaultdict[str,MultiAgentEpisode]): Mapping from
|
||||
episode ID to currently ongoing MultiAgentEpisode object.
|
||||
active_envs (Set[int]): Set of non-terminated env ids.
|
||||
off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids ->
|
||||
off-policy-action, returned by a `BaseEnv.poll()` call.
|
||||
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
|
||||
clip_actions (bool): Whether to clip actions to the action space's
|
||||
bounds.
|
||||
|
||||
Returns:
|
||||
actions_to_send: nested dict of env id -> agent id -> agent replies.
|
||||
actions_to_send: Nested dict of env id -> agent id -> agent replies.
|
||||
"""
|
||||
|
||||
actions_to_send = defaultdict(dict)
|
||||
@@ -711,7 +915,7 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
def _fetch_atari_metrics(base_env):
|
||||
"""Atari games have multiple logical episodes, one per life.
|
||||
|
||||
However for metrics reporting we count full episodes all lives included.
|
||||
However, for metrics reporting we count full episodes, all lives included.
|
||||
"""
|
||||
unwrapped = base_env.get_unwrapped()
|
||||
if not unwrapped:
|
||||
@@ -734,10 +938,16 @@ def _to_column_format(rnn_state_rows):
|
||||
def _get_or_raise(mapping, policy_id):
|
||||
"""Returns a Policy object under key `policy_id` in `mapping`.
|
||||
|
||||
Throws an error if `policy_id` cannot be found.
|
||||
Args:
|
||||
mapping (dict): The mapping dict from policy id (str) to
|
||||
actual Policy object.
|
||||
policy_id (str): The policy ID to lookup.
|
||||
|
||||
Returns:
|
||||
Policy: The found Policy object.
|
||||
|
||||
Throws:
|
||||
ValueError: If `policy_id` cannot be found.
|
||||
"""
|
||||
if policy_id not in mapping:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph")
|
||||
@@ -1,4 +0,0 @@
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph")
|
||||
@@ -65,6 +65,15 @@ class SampleBatch:
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
"""Concatenates n data dicts or MultiAgentBatches.
|
||||
|
||||
Args:
|
||||
samples (List[Dict[np.ndarray]]]): List of dicts of data (numpy).
|
||||
|
||||
Returns:
|
||||
Union[SampleBatch,MultiAgentBatch]: A new (compressed) SampleBatch/
|
||||
MultiAgentBatch.
|
||||
"""
|
||||
if isinstance(samples[0], MultiAgentBatch):
|
||||
return MultiAgentBatch.concat_samples(samples)
|
||||
out = {}
|
||||
@@ -84,7 +93,10 @@ class SampleBatch:
|
||||
{"a": [1, 2, 3, 4, 5]}
|
||||
"""
|
||||
|
||||
assert self.keys() == other.keys(), "must have same columns"
|
||||
if self.keys() != other.keys():
|
||||
raise ValueError(
|
||||
"SampleBatches to concat must have same columns! {} vs {}".
|
||||
format(list(self.keys()), list(other.keys())))
|
||||
out = {}
|
||||
for k in self.keys():
|
||||
out[k] = concat_aligned([self[k], other[k]])
|
||||
@@ -117,7 +129,14 @@ class SampleBatch:
|
||||
|
||||
@PublicAPI
|
||||
def columns(self, keys):
|
||||
"""Returns a list of just the specified columns.
|
||||
"""Returns a list of the batch-data in the specified columns.
|
||||
|
||||
Args:
|
||||
keys (List[str]): List of column names fo which to return the data.
|
||||
|
||||
Returns:
|
||||
List[any]: The list of data items ordered by the order of column
|
||||
names in `keys`.
|
||||
|
||||
Examples:
|
||||
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
|
||||
@@ -143,7 +162,7 @@ class SampleBatch:
|
||||
"""Splits this batch's data by `eps_id`.
|
||||
|
||||
Returns:
|
||||
list of SampleBatch, one per distinct episode.
|
||||
List[SampleBatch]: List of batches, one per distinct episode.
|
||||
"""
|
||||
|
||||
slices = []
|
||||
@@ -166,7 +185,7 @@ class SampleBatch:
|
||||
def slice(self, start, end):
|
||||
"""Returns a slice of the row data of this batch.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
start (int): Starting index.
|
||||
end (int): Ending index.
|
||||
|
||||
@@ -234,23 +253,37 @@ class SampleBatch:
|
||||
@PublicAPI
|
||||
class MultiAgentBatch:
|
||||
"""A batch of experiences from multiple policies in the environment.
|
||||
|
||||
Attributes:
|
||||
policy_batches (dict): Mapping from policy id to a normal SampleBatch
|
||||
of experiences. Note that these batches may be of different length.
|
||||
count (int): The number of timesteps in the environment this batch
|
||||
contains. This will be less than the number of transitions this
|
||||
batch contains across all policies in total.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, policy_batches, count):
|
||||
"""Initializes a MultiAgentBatch object.
|
||||
|
||||
Args:
|
||||
policy_batches (Dict[str,SampleBatch]): Mapping from policy id
|
||||
(str) to a SampleBatch of experiences. Note that these batches
|
||||
may be of different length.
|
||||
count (int): The number of timesteps in the environment this batch
|
||||
contains. This will be less than the number of transitions this
|
||||
batch contains across all policies in total.
|
||||
"""
|
||||
self.policy_batches = policy_batches
|
||||
self.count = count
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def wrap_as_needed(batches, count):
|
||||
"""Returns SampleBatch or MultiAgentBatch, depending on given policies.
|
||||
|
||||
Args:
|
||||
batches (Dict[str,SampleBatch]): Mapping from policy ID to
|
||||
SampleBatch.
|
||||
count (int): A count to use, when returning a MultiAgentBatch.
|
||||
|
||||
Returns:
|
||||
Union[SampleBatch,MultiAgentBatch]: The single default policy's
|
||||
SampleBatch or a MultiAgentBatch (more than one policy).
|
||||
"""
|
||||
if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
|
||||
return batches[DEFAULT_POLICY_ID]
|
||||
return MultiAgentBatch(batches, count)
|
||||
@@ -258,10 +291,23 @@ class MultiAgentBatch:
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
"""Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.
|
||||
|
||||
Args:
|
||||
samples (List[MultiAgentBatch]): List of MultiagentBatch objects
|
||||
to concatenate.
|
||||
|
||||
Returns:
|
||||
MultiAgentBatch: A new MultiAgentBatch consisting of the
|
||||
concatenated inputs.
|
||||
"""
|
||||
policy_batches = collections.defaultdict(list)
|
||||
total_count = 0
|
||||
for s in samples:
|
||||
assert isinstance(s, MultiAgentBatch)
|
||||
if not isinstance(s, MultiAgentBatch):
|
||||
raise ValueError(
|
||||
"`MultiAgentBatch.concat_samples()` can only concat "
|
||||
"MultiAgentBatch types, not {}!".format(type(s).__name__))
|
||||
for policy_id, batch in s.policy_batches.items():
|
||||
policy_batches[policy_id].append(batch)
|
||||
total_count += s.count
|
||||
@@ -272,12 +318,22 @@ class MultiAgentBatch:
|
||||
|
||||
@PublicAPI
|
||||
def copy(self):
|
||||
"""Deep-copies self into a new MultiAgentBatch.
|
||||
|
||||
Returns:
|
||||
MultiAgentBatch: The copy of self with deep-copied data.
|
||||
"""
|
||||
return MultiAgentBatch(
|
||||
{k: v.copy()
|
||||
for (k, v) in self.policy_batches.items()}, self.count)
|
||||
|
||||
@PublicAPI
|
||||
def total(self):
|
||||
"""Calculates the sum of all step-counts over all policy batches.
|
||||
|
||||
Returns:
|
||||
int: The sum of counts over all policy batches.
|
||||
"""
|
||||
ct = 0
|
||||
for batch in self.policy_batches.values():
|
||||
ct += batch.count
|
||||
@@ -285,11 +341,24 @@ class MultiAgentBatch:
|
||||
|
||||
@DeveloperAPI
|
||||
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
|
||||
"""Compresses each policy batch.
|
||||
|
||||
Args:
|
||||
bulk (bool): Whether to compress across the batch dimension (0)
|
||||
as well. If False will compress n separate list items, where n
|
||||
is the batch size.
|
||||
columns (Set[str]): Set of column names to compress.
|
||||
"""
|
||||
for batch in self.policy_batches.values():
|
||||
batch.compress(bulk=bulk, columns=columns)
|
||||
|
||||
@DeveloperAPI
|
||||
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
|
||||
"""Decompresses each policy batch, if already compressed.
|
||||
|
||||
Args:
|
||||
columns (Set[str]): Set of column names to decompress.
|
||||
"""
|
||||
for batch in self.policy_batches.values():
|
||||
batch.decompress_if_needed(columns)
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user