[rllib] Use batch.count in async samples optimizer (#2488)

Using the actual batch size reduces the risk of mis-accounting. Here, we under-counted samples since in truncate_episodes mode we were doubling the batch size by accident in policy_evaluator.
This commit is contained in:
Eric Liang
2018-07-27 16:44:21 -07:00
committed by GitHub
parent 1e6b130b90
commit 24649726dc
5 changed files with 43 additions and 20 deletions
-1
View File
@@ -23,7 +23,6 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"learning_starts": 50000,
"train_batch_size": 512,
"sample_batch_size": 50,
"max_weight_sync_delay": 400,
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"per_worker_exploration": True,
@@ -122,10 +122,12 @@ class PolicyEvaluator(EvaluatorInterface):
in each sample batch returned from this evaluator.
batch_mode (str): One of the following batch modes:
"truncate_episodes": Each call to sample() will return a batch
of exactly `batch_steps` in size. Episodes may be truncated
in order to meet this size requirement. When
`num_envs > 1`, episodes will be truncated to sequences of
`batch_size / num_envs` in length.
of at most `batch_steps` in size. The batch will be exactly
`batch_steps` in size if postprocessing does not change
batch sizes. Episodes may be truncated in order to meet
this size requirement. When `num_envs > 1`, episodes will
be truncated to sequences of `batch_size / num_envs` in
length.
"complete_episodes": Each call to sample() will return a batch
of at least `batch_steps in size. Episodes will not be
truncated, but multiple episodes may be packed within one
@@ -220,6 +222,7 @@ class PolicyEvaluator(EvaluatorInterface):
# Always use vector env for consistency even if num_envs = 1
self.async_env = AsyncVectorEnv.wrap_async(
self.env, make_env=make_env, num_envs=num_envs)
self.num_envs = num_envs
if self.batch_mode == "truncate_episodes":
if batch_steps % num_envs != 0:
@@ -276,7 +279,15 @@ class PolicyEvaluator(EvaluatorInterface):
batches = [self.sampler.get_data()]
steps_so_far = batches[0].count
while steps_so_far < self.batch_steps:
# In truncate_episodes mode, never pull more than 1 batch per env.
# This avoids over-running the target batch size.
if self.batch_mode == "truncate_episodes":
max_batches = self.num_envs
else:
max_batches = float("inf")
while steps_so_far < self.batch_steps and len(batches) < max_batches:
batch = self.sampler.get_data()
steps_so_far += batch.count
batches.append(batch)
@@ -293,6 +304,12 @@ class PolicyEvaluator(EvaluatorInterface):
return batch
@ray.method(num_return_vals=2)
def sample_with_count(self):
"""Same as sample() but returns the count as a separate future."""
batch = self.sample()
return batch, batch.count
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
"""Apply the given function to the specified policy graph."""
@@ -58,6 +58,7 @@ class ReplayActor(object):
return os.uname()[1]
def add_batch(self, batch):
PolicyOptimizer._check_not_multiagent(batch)
with self.add_batch_timer:
for row in batch.rows():
self.replay_buffer.add(row["obs"], row["actions"],
@@ -131,7 +132,7 @@ class LearnerThread(threading.Thread):
with self.grad_timer:
td_error = self.local_evaluator.compute_apply(replay)[
"td_error"]
self.outqueue.put((ra, replay, td_error))
self.outqueue.put((ra, replay, td_error, replay.count))
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True
@@ -164,8 +165,6 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
self.replay_starts = learning_starts
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.train_batch_size = train_batch_size
self.sample_batch_size = sample_batch_size
self.max_weight_sync_delay = max_weight_sync_delay
self.learner = LearnerThread(self.local_evaluator)
@@ -205,7 +204,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
ev.set_weights.remote(weights)
self.steps_since_update[ev] = 0
for _ in range(SAMPLE_QUEUE_DEPTH):
self.sample_tasks.add(ev, ev.sample.remote())
self.sample_tasks.add(ev, ev.sample_with_count.remote())
def step(self):
start = time.time()
@@ -226,16 +225,17 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
weights = None
with self.timers["sample_processing"]:
for ev, sample_batch in self.sample_tasks.completed():
self._check_not_multiagent(sample_batch)
sample_timesteps += self.sample_batch_size
completed = list(self.sample_tasks.completed())
counts = ray.get([c[1][1] for c in completed])
for i, (ev, (sample_batch, count)) in enumerate(completed):
sample_timesteps += counts[i]
# Send the data to the replay buffer
random.choice(
self.replay_actors).add_batch.remote(sample_batch)
# Update weights if needed
self.steps_since_update[ev] += self.sample_batch_size
self.steps_since_update[ev] += counts[i]
if self.steps_since_update[ev] >= self.max_weight_sync_delay:
# Note that it's important to pull new weights once
# updated to avoid excessive correlation between actors
@@ -249,7 +249,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
self.steps_since_update[ev] = 0
# Kick off another sample request
self.sample_tasks.add(ev, ev.sample.remote())
self.sample_tasks.add(ev, ev.sample_with_count.remote())
with self.timers["replay_processing"]:
for ra, replay in self.replay_tasks.completed():
@@ -261,9 +261,9 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
with self.timers["update_priorities"]:
while not self.learner.outqueue.empty():
ra, replay, td_error = self.learner.outqueue.get()
ra, replay, td_error, count = self.learner.outqueue.get()
ra.update_priorities.remote(replay["batch_indexes"], td_error)
train_timesteps += self.train_batch_size
train_timesteps += count
return sample_timesteps, train_timesteps
@@ -121,7 +121,8 @@ class PolicyOptimizer(object):
])
return local_result + remote_results
def _check_not_multiagent(self, sample_batch):
@staticmethod
def _check_not_multiagent(sample_batch):
if isinstance(sample_batch, MultiAgentBatch):
raise NotImplementedError(
"This optimizer does not support multi-agent yet.")
+8 -2
View File
@@ -11,16 +11,22 @@ class TaskPool(object):
def __init__(self):
self._tasks = {}
self._objects = {}
def add(self, worker, obj_id):
def add(self, worker, all_obj_ids):
if isinstance(all_obj_ids, list):
obj_id = all_obj_ids[0]
else:
obj_id = all_obj_ids
self._tasks[obj_id] = worker
self._objects[obj_id] = all_obj_ids
def completed(self):
pending = list(self._tasks)
if pending:
ready, _ = ray.wait(pending, num_returns=len(pending), timeout=10)
for obj_id in ready:
yield (self._tasks.pop(obj_id), obj_id)
yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))
@property
def count(self):