[rllib] Simplify sample batch size and num envs config, n_step adjustment (#2995)

* simplify vec batch requirements

* Update rllib-training.rst

* Update rllib-training.rst

* Update rllib-training.rst

* Update rllib-training.rst

* Update rllib-training.rst

* Update rllib-models.rst
This commit is contained in:
Eric Liang
2018-09-30 18:36:22 -07:00
committed by GitHub
parent 8aa736572b
commit 814c35b7d7
12 changed files with 68 additions and 57 deletions
+2 -2
View File
@@ -135,8 +135,8 @@ class DQNAgent(Agent):
def _init(self):
# Update effective batch size to include n-step
adjusted_batch_size = (
self.config["sample_batch_size"] + self.config["n_step"] - 1)
adjusted_batch_size = max(self.config["sample_batch_size"],
self.config["n_step"])
self.config["sample_batch_size"] = adjusted_batch_size
self.exploration0 = self._make_exploration_schedule(0)
@@ -126,8 +126,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
else:
# Important: chop the tensor into batches at known episode cut
# boundaries. TODO(ekl) this is kind of a hack
T = (self.config["sample_batch_size"] //
self.config["num_envs_per_worker"])
T = self.config["sample_batch_size"]
B = tf.shape(tensor)[0] // T
rs = tf.reshape(tensor,
tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
+14 -20
View File
@@ -124,16 +124,14 @@ class PolicyEvaluator(EvaluatorInterface):
in each sample batch returned from this evaluator.
batch_mode (str): One of the following batch modes:
"truncate_episodes": Each call to sample() will return a batch
of at most `batch_steps` in size. The batch will be exactly
`batch_steps` in size if postprocessing does not change
batch sizes. Episodes may be truncated in order to meet
this size requirement. When `num_envs > 1`, episodes will
be truncated to sequences of `batch_size / num_envs` in
length.
of at most `batch_steps * num_envs` in size. The batch will
be exactly `batch_steps * num_envs` in size if
postprocessing does not change batch sizes. Episodes may be
truncated in order to meet this size requirement.
"complete_episodes": Each call to sample() will return a batch
of at least `batch_steps in size. Episodes will not be
truncated, but multiple episodes may be packed within one
batch to meet the batch size. Note that when
of at least `batch_steps * num_envs` in size. Episodes will
not be truncated, but multiple episodes may be packed
within one batch to meet the batch size. Note that when
`num_envs > 1`, episode steps will be buffered until the
episode completes, and hence batches may contain
significant amounts of off-policy data.
@@ -171,7 +169,7 @@ class PolicyEvaluator(EvaluatorInterface):
policy_mapping_fn = (policy_mapping_fn
or (lambda agent_id: DEFAULT_POLICY_ID))
self.env_creator = env_creator
self.batch_steps = batch_steps
self.sample_batch_size = batch_steps * num_envs
self.batch_mode = batch_mode
self.compress_observations = compress_observations
@@ -246,15 +244,10 @@ class PolicyEvaluator(EvaluatorInterface):
self.num_envs = num_envs
if self.batch_mode == "truncate_episodes":
if batch_steps % num_envs != 0:
raise ValueError(
"In 'truncate_episodes' batch mode, `batch_steps` must be "
"evenly divisible by `num_envs`. Got {} and {}.".format(
batch_steps, num_envs))
batch_steps = batch_steps // num_envs
unroll_length = batch_steps
pack_episodes = True
elif self.batch_mode == "complete_episodes":
batch_steps = float("inf") # never cut episodes
unroll_length = float("inf") # never cut episodes
pack_episodes = False # sampler will return 1 episode per poll
else:
raise ValueError("Unsupported batch mode: {}".format(
@@ -266,7 +259,7 @@ class PolicyEvaluator(EvaluatorInterface):
policy_mapping_fn,
self.filters,
clip_rewards,
batch_steps,
unroll_length,
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess)
@@ -278,7 +271,7 @@ class PolicyEvaluator(EvaluatorInterface):
policy_mapping_fn,
self.filters,
clip_rewards,
batch_steps,
unroll_length,
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess)
@@ -310,7 +303,8 @@ class PolicyEvaluator(EvaluatorInterface):
else:
max_batches = float("inf")
while steps_so_far < self.batch_steps and len(batches) < max_batches:
while steps_so_far < self.sample_batch_size and len(
batches) < max_batches:
batch = self.sampler.get_data()
steps_so_far += batch.count
batches.append(batch)
+10 -10
View File
@@ -36,12 +36,12 @@ class SyncSampler(object):
policy_mapping_fn,
obs_filters,
clip_rewards,
num_local_steps,
unroll_length,
horizon=None,
pack=False,
tf_sess=None):
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
self.num_local_steps = num_local_steps
self.unroll_length = unroll_length
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
@@ -49,7 +49,7 @@ class SyncSampler(object):
self.extra_batches = queue.Queue()
self.rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.num_local_steps, self.horizon,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self._obs_filters, clip_rewards, pack, tf_sess)
self.metrics_queue = queue.Queue()
@@ -92,7 +92,7 @@ class AsyncSampler(threading.Thread):
policy_mapping_fn,
obs_filters,
clip_rewards,
num_local_steps,
unroll_length,
horizon=None,
pack=False,
tf_sess=None):
@@ -104,7 +104,7 @@ class AsyncSampler(threading.Thread):
self.queue = queue.Queue(5)
self.extra_batches = queue.Queue()
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.unroll_length = unroll_length
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
@@ -124,7 +124,7 @@ class AsyncSampler(threading.Thread):
def _run(self):
rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.num_local_steps, self.horizon,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self._obs_filters, self.clip_rewards, self.pack, self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
@@ -182,7 +182,7 @@ def _env_runner(async_vector_env,
extra_batch_callback,
policies,
policy_mapping_fn,
num_local_steps,
unroll_length,
horizon,
obs_filters,
clip_rewards,
@@ -197,14 +197,14 @@ def _env_runner(async_vector_env,
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
agent is then "bound" to the returned policy for the episode.
num_local_steps (int): Number of episode steps before `SampleBatch` is
unroll_length (int): Number of episode steps before `SampleBatch` is
yielded. Set to infinity to yield complete episodes.
horizon (int): Horizon of the episode.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
clip_rewards (bool): Whether to clip rewards before postprocessing.
pack (bool): Whether to pack multiple episodes into each batch. This
guarantees batches will be exactly `num_local_steps` in size.
guarantees batches will be exactly `unroll_length` in size.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
@@ -306,7 +306,7 @@ def _env_runner(async_vector_env,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= num_local_steps:
episode.batch_builder.count >= unroll_length:
yield episode.batch_builder.build_and_reset()
elif all_done:
# Make sure postprocessor stays within one episode
+6 -14
View File
@@ -129,9 +129,10 @@ class TestPolicyEvaluator(unittest.TestCase):
"num_workers": 2,
"sample_batch_size": 5
})
results = pg.optimizer.foreach_evaluator(lambda ev: ev.batch_steps)
results = pg.optimizer.foreach_evaluator(
lambda ev: ev.sample_batch_size)
results2 = pg.optimizer.foreach_evaluator_with_index(
lambda ev, i: (i, ev.batch_steps))
lambda ev, i: (i, ev.sample_batch_size))
self.assertEqual(results, [5, 5, 5])
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
@@ -198,7 +199,7 @@ class TestPolicyEvaluator(unittest.TestCase):
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=16,
batch_steps=2,
num_envs=8)
for _ in range(8):
batch = ev.sample()
@@ -216,21 +217,12 @@ class TestPolicyEvaluator(unittest.TestCase):
indices.append(env.unwrapped.config.vector_index)
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
def testBatchDivisibilityCheck(self):
self.assertRaises(
ValueError,
lambda: PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=15, num_envs=4))
def testBatchesSmallerWhenVectorized(self):
def testBatchesLargerWhenVectorized(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=16,
batch_steps=4,
num_envs=4)
batch = ev.sample()
self.assertEqual(batch.count, 16)
@@ -9,7 +9,7 @@ atari-a2c:
- SpaceInvadersNoFrameskip-v4
run: A2C
config:
sample_batch_size: 100
sample_batch_size: 20
clip_rewards: True
num_workers: 5
num_envs_per_worker: 5
@@ -28,7 +28,7 @@ apex:
# APEX
num_workers: 8
num_envs_per_worker: 8
sample_batch_size: 158
sample_batch_size: 20
train_batch_size: 512
target_network_update_freq: 50000
timesteps_per_iteration: 25000
@@ -9,7 +9,7 @@ atari-impala:
- SpaceInvadersNoFrameskip-v4
run: IMPALA
config:
sample_batch_size: 250 # 50 * num_envs_per_worker
sample_batch_size: 50
train_batch_size: 500
num_workers: 32
num_envs_per_worker: 5
@@ -16,7 +16,7 @@ atari-ppo:
vf_clip_param: 10.0
entropy_coeff: 0.01
train_batch_size: 5000
sample_batch_size: 500
sample_batch_size: 100
sgd_minibatch_size: 500
num_sgd_iter: 10
num_workers: 10
@@ -5,7 +5,7 @@ pong-impala-vectorized:
env: PongNoFrameskip-v4
run: IMPALA
config:
sample_batch_size: 500 # 50 * num_envs_per_worker
sample_batch_size: 50
train_batch_size: 500
num_workers: 32
num_envs_per_worker: 10