mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[rllib] Simplify sample batch size and num envs config, n_step adjustment (#2995)
* simplify vec batch requirements
* Update rllib-training.rst
* Update rllib-training.rst
* Update rllib-training.rst
* Update rllib-training.rst
* Update rllib-training.rst
* Update rllib-models.rst
This commit is contained in:
@@ -135,8 +135,8 @@ class DQNAgent(Agent):
|
||||
|
||||
def _init(self):
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = (
|
||||
self.config["sample_batch_size"] + self.config["n_step"] - 1)
|
||||
adjusted_batch_size = max(self.config["sample_batch_size"],
|
||||
self.config["n_step"])
|
||||
self.config["sample_batch_size"] = adjusted_batch_size
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(0)
|
||||
|
||||
@@ -126,8 +126,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
else:
|
||||
# Important: chop the tensor into batches at known episode cut
|
||||
# boundaries. TODO(ekl) this is kind of a hack
|
||||
T = (self.config["sample_batch_size"] //
|
||||
self.config["num_envs_per_worker"])
|
||||
T = self.config["sample_batch_size"]
|
||||
B = tf.shape(tensor)[0] // T
|
||||
rs = tf.reshape(tensor,
|
||||
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
|
||||
|
||||
@@ -124,16 +124,14 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
in each sample batch returned from this evaluator.
|
||||
batch_mode (str): One of the following batch modes:
|
||||
"truncate_episodes": Each call to sample() will return a batch
|
||||
of at most `batch_steps` in size. The batch will be exactly
|
||||
`batch_steps` in size if postprocessing does not change
|
||||
batch sizes. Episodes may be truncated in order to meet
|
||||
this size requirement. When `num_envs > 1`, episodes will
|
||||
be truncated to sequences of `batch_size / num_envs` in
|
||||
length.
|
||||
of at most `batch_steps * num_envs` in size. The batch will
|
||||
be exactly `batch_steps * num_envs` in size if
|
||||
postprocessing does not change batch sizes. Episodes may be
|
||||
truncated in order to meet this size requirement.
|
||||
"complete_episodes": Each call to sample() will return a batch
|
||||
of at least `batch_steps in size. Episodes will not be
|
||||
truncated, but multiple episodes may be packed within one
|
||||
batch to meet the batch size. Note that when
|
||||
of at least `batch_steps * num_envs` in size. Episodes will
|
||||
not be truncated, but multiple episodes may be packed
|
||||
within one batch to meet the batch size. Note that when
|
||||
`num_envs > 1`, episode steps will be buffered until the
|
||||
episode completes, and hence batches may contain
|
||||
significant amounts of off-policy data.
|
||||
@@ -171,7 +169,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
policy_mapping_fn = (policy_mapping_fn
|
||||
or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
self.env_creator = env_creator
|
||||
self.batch_steps = batch_steps
|
||||
self.sample_batch_size = batch_steps * num_envs
|
||||
self.batch_mode = batch_mode
|
||||
self.compress_observations = compress_observations
|
||||
|
||||
@@ -246,15 +244,10 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.num_envs = num_envs
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
if batch_steps % num_envs != 0:
|
||||
raise ValueError(
|
||||
"In 'truncate_episodes' batch mode, `batch_steps` must be "
|
||||
"evenly divisible by `num_envs`. Got {} and {}.".format(
|
||||
batch_steps, num_envs))
|
||||
batch_steps = batch_steps // num_envs
|
||||
unroll_length = batch_steps
|
||||
pack_episodes = True
|
||||
elif self.batch_mode == "complete_episodes":
|
||||
batch_steps = float("inf") # never cut episodes
|
||||
unroll_length = float("inf") # never cut episodes
|
||||
pack_episodes = False # sampler will return 1 episode per poll
|
||||
else:
|
||||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
@@ -266,7 +259,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
policy_mapping_fn,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
batch_steps,
|
||||
unroll_length,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess)
|
||||
@@ -278,7 +271,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
policy_mapping_fn,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
batch_steps,
|
||||
unroll_length,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess)
|
||||
@@ -310,7 +303,8 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
else:
|
||||
max_batches = float("inf")
|
||||
|
||||
while steps_so_far < self.batch_steps and len(batches) < max_batches:
|
||||
while steps_so_far < self.sample_batch_size and len(
|
||||
batches) < max_batches:
|
||||
batch = self.sampler.get_data()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
|
||||
@@ -36,12 +36,12 @@ class SyncSampler(object):
|
||||
policy_mapping_fn,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
num_local_steps,
|
||||
unroll_length,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
tf_sess=None):
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
self.num_local_steps = num_local_steps
|
||||
self.unroll_length = unroll_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
@@ -49,7 +49,7 @@ class SyncSampler(object):
|
||||
self.extra_batches = queue.Queue()
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.num_local_steps, self.horizon,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self._obs_filters, clip_rewards, pack, tf_sess)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
@@ -92,7 +92,7 @@ class AsyncSampler(threading.Thread):
|
||||
policy_mapping_fn,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
num_local_steps,
|
||||
unroll_length,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
tf_sess=None):
|
||||
@@ -104,7 +104,7 @@ class AsyncSampler(threading.Thread):
|
||||
self.queue = queue.Queue(5)
|
||||
self.extra_batches = queue.Queue()
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.unroll_length = unroll_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
@@ -124,7 +124,7 @@ class AsyncSampler(threading.Thread):
|
||||
def _run(self):
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.num_local_steps, self.horizon,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self._obs_filters, self.clip_rewards, self.pack, self.tf_sess)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
@@ -182,7 +182,7 @@ def _env_runner(async_vector_env,
|
||||
extra_batch_callback,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
num_local_steps,
|
||||
unroll_length,
|
||||
horizon,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
@@ -197,14 +197,14 @@ def _env_runner(async_vector_env,
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
This is called when an agent first enters the environment. The
|
||||
agent is then "bound" to the returned policy for the episode.
|
||||
num_local_steps (int): Number of episode steps before `SampleBatch` is
|
||||
unroll_length (int): Number of episode steps before `SampleBatch` is
|
||||
yielded. Set to infinity to yield complete episodes.
|
||||
horizon (int): Horizon of the episode.
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
pack (bool): Whether to pack multiple episodes into each batch. This
|
||||
guarantees batches will be exactly `num_local_steps` in size.
|
||||
guarantees batches will be exactly `unroll_length` in size.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
TF policy evaluations.
|
||||
|
||||
@@ -306,7 +306,7 @@ def _env_runner(async_vector_env,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= num_local_steps:
|
||||
episode.batch_builder.count >= unroll_length:
|
||||
yield episode.batch_builder.build_and_reset()
|
||||
elif all_done:
|
||||
# Make sure postprocessor stays within one episode
|
||||
|
||||
@@ -129,9 +129,10 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
"num_workers": 2,
|
||||
"sample_batch_size": 5
|
||||
})
|
||||
results = pg.optimizer.foreach_evaluator(lambda ev: ev.batch_steps)
|
||||
results = pg.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.sample_batch_size)
|
||||
results2 = pg.optimizer.foreach_evaluator_with_index(
|
||||
lambda ev, i: (i, ev.batch_steps))
|
||||
lambda ev, i: (i, ev.sample_batch_size))
|
||||
self.assertEqual(results, [5, 5, 5])
|
||||
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
|
||||
|
||||
@@ -198,7 +199,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=16,
|
||||
batch_steps=2,
|
||||
num_envs=8)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
@@ -216,21 +217,12 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
indices.append(env.unwrapped.config.vector_index)
|
||||
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
def testBatchDivisibilityCheck(self):
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
lambda: PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=15, num_envs=4))
|
||||
|
||||
def testBatchesSmallerWhenVectorized(self):
|
||||
def testBatchesLargerWhenVectorized(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=16,
|
||||
batch_steps=4,
|
||||
num_envs=4)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
|
||||
@@ -9,7 +9,7 @@ atari-a2c:
|
||||
- SpaceInvadersNoFrameskip-v4
|
||||
run: A2C
|
||||
config:
|
||||
sample_batch_size: 100
|
||||
sample_batch_size: 20
|
||||
clip_rewards: True
|
||||
num_workers: 5
|
||||
num_envs_per_worker: 5
|
||||
|
||||
@@ -28,7 +28,7 @@ apex:
|
||||
# APEX
|
||||
num_workers: 8
|
||||
num_envs_per_worker: 8
|
||||
sample_batch_size: 158
|
||||
sample_batch_size: 20
|
||||
train_batch_size: 512
|
||||
target_network_update_freq: 50000
|
||||
timesteps_per_iteration: 25000
|
||||
|
||||
@@ -9,7 +9,7 @@ atari-impala:
|
||||
- SpaceInvadersNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 250 # 50 * num_envs_per_worker
|
||||
sample_batch_size: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 32
|
||||
num_envs_per_worker: 5
|
||||
|
||||
@@ -16,7 +16,7 @@ atari-ppo:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 500
|
||||
sample_batch_size: 100
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 10
|
||||
|
||||
@@ -5,7 +5,7 @@ pong-impala-vectorized:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 500 # 50 * num_envs_per_worker
|
||||
sample_batch_size: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 32
|
||||
num_envs_per_worker: 10
|
||||
|
||||
Reference in New Issue
Block a user