From 24649726dcbad87b4f72e6dd665fd5916f69fd26 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 27 Jul 2018 16:44:21 -0700 Subject: [PATCH] [rllib] Use batch.count in async samples optimizer (#2488) Using the actual batch size reduces the risk of mis-accounting. Here, we under-counted samples since in truncate_episodes mode we were doubling the batch size by accident in policy_evaluator. --- python/ray/rllib/agents/dqn/apex.py | 1 - .../ray/rllib/evaluation/policy_evaluator.py | 27 +++++++++++++++---- .../optimizers/async_samples_optimizer.py | 22 +++++++-------- .../ray/rllib/optimizers/policy_optimizer.py | 3 ++- python/ray/rllib/utils/actors.py | 10 +++++-- 5 files changed, 43 insertions(+), 20 deletions(-) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index 9321a70ff..138ad106c 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -23,7 +23,6 @@ APEX_DEFAULT_CONFIG = merge_dicts( "learning_starts": 50000, "train_batch_size": 512, "sample_batch_size": 50, - "max_weight_sync_delay": 400, "target_network_update_freq": 500000, "timesteps_per_iteration": 25000, "per_worker_exploration": True, diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index a3245c800..4cc13852d 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -122,10 +122,12 @@ class PolicyEvaluator(EvaluatorInterface): in each sample batch returned from this evaluator. batch_mode (str): One of the following batch modes: "truncate_episodes": Each call to sample() will return a batch - of exactly `batch_steps` in size. Episodes may be truncated - in order to meet this size requirement. When - `num_envs > 1`, episodes will be truncated to sequences of - `batch_size / num_envs` in length. + of at most `batch_steps` in size. The batch will be exactly + `batch_steps` in size if postprocessing does not change + batch sizes. Episodes may be truncated in order to meet + this size requirement. When `num_envs > 1`, episodes will + be truncated to sequences of `batch_size / num_envs` in + length. "complete_episodes": Each call to sample() will return a batch of at least `batch_steps in size. Episodes will not be truncated, but multiple episodes may be packed within one @@ -220,6 +222,7 @@ class PolicyEvaluator(EvaluatorInterface): # Always use vector env for consistency even if num_envs = 1 self.async_env = AsyncVectorEnv.wrap_async( self.env, make_env=make_env, num_envs=num_envs) + self.num_envs = num_envs if self.batch_mode == "truncate_episodes": if batch_steps % num_envs != 0: @@ -276,7 +279,15 @@ class PolicyEvaluator(EvaluatorInterface): batches = [self.sampler.get_data()] steps_so_far = batches[0].count - while steps_so_far < self.batch_steps: + + # In truncate_episodes mode, never pull more than 1 batch per env. + # This avoids over-running the target batch size. + if self.batch_mode == "truncate_episodes": + max_batches = self.num_envs + else: + max_batches = float("inf") + + while steps_so_far < self.batch_steps and len(batches) < max_batches: batch = self.sampler.get_data() steps_so_far += batch.count batches.append(batch) @@ -293,6 +304,12 @@ class PolicyEvaluator(EvaluatorInterface): return batch + @ray.method(num_return_vals=2) + def sample_with_count(self): + """Same as sample() but returns the count as a separate future.""" + batch = self.sample() + return batch, batch.count + def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): """Apply the given function to the specified policy graph.""" diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index e37901c46..ebc8676cd 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -58,6 +58,7 @@ class ReplayActor(object): return os.uname()[1] def add_batch(self, batch): + PolicyOptimizer._check_not_multiagent(batch) with self.add_batch_timer: for row in batch.rows(): self.replay_buffer.add(row["obs"], row["actions"], @@ -131,7 +132,7 @@ class LearnerThread(threading.Thread): with self.grad_timer: td_error = self.local_evaluator.compute_apply(replay)[ "td_error"] - self.outqueue.put((ra, replay, td_error)) + self.outqueue.put((ra, replay, td_error, replay.count)) self.learner_queue_size.push(self.inqueue.qsize()) self.weights_updated = True @@ -164,8 +165,6 @@ class AsyncSamplesOptimizer(PolicyOptimizer): self.replay_starts = learning_starts self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps - self.train_batch_size = train_batch_size - self.sample_batch_size = sample_batch_size self.max_weight_sync_delay = max_weight_sync_delay self.learner = LearnerThread(self.local_evaluator) @@ -205,7 +204,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer): ev.set_weights.remote(weights) self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): - self.sample_tasks.add(ev, ev.sample.remote()) + self.sample_tasks.add(ev, ev.sample_with_count.remote()) def step(self): start = time.time() @@ -226,16 +225,17 @@ class AsyncSamplesOptimizer(PolicyOptimizer): weights = None with self.timers["sample_processing"]: - for ev, sample_batch in self.sample_tasks.completed(): - self._check_not_multiagent(sample_batch) - sample_timesteps += self.sample_batch_size + completed = list(self.sample_tasks.completed()) + counts = ray.get([c[1][1] for c in completed]) + for i, (ev, (sample_batch, count)) in enumerate(completed): + sample_timesteps += counts[i] # Send the data to the replay buffer random.choice( self.replay_actors).add_batch.remote(sample_batch) # Update weights if needed - self.steps_since_update[ev] += self.sample_batch_size + self.steps_since_update[ev] += counts[i] if self.steps_since_update[ev] >= self.max_weight_sync_delay: # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors @@ -249,7 +249,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer): self.steps_since_update[ev] = 0 # Kick off another sample request - self.sample_tasks.add(ev, ev.sample.remote()) + self.sample_tasks.add(ev, ev.sample_with_count.remote()) with self.timers["replay_processing"]: for ra, replay in self.replay_tasks.completed(): @@ -261,9 +261,9 @@ class AsyncSamplesOptimizer(PolicyOptimizer): with self.timers["update_priorities"]: while not self.learner.outqueue.empty(): - ra, replay, td_error = self.learner.outqueue.get() + ra, replay, td_error, count = self.learner.outqueue.get() ra.update_priorities.remote(replay["batch_indexes"], td_error) - train_timesteps += self.train_batch_size + train_timesteps += count return sample_timesteps, train_timesteps diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index d0e9720bb..2943102a4 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -121,7 +121,8 @@ class PolicyOptimizer(object): ]) return local_result + remote_results - def _check_not_multiagent(self, sample_batch): + @staticmethod + def _check_not_multiagent(sample_batch): if isinstance(sample_batch, MultiAgentBatch): raise NotImplementedError( "This optimizer does not support multi-agent yet.") diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index a7f604bc2..c663087eb 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -11,16 +11,22 @@ class TaskPool(object): def __init__(self): self._tasks = {} + self._objects = {} - def add(self, worker, obj_id): + def add(self, worker, all_obj_ids): + if isinstance(all_obj_ids, list): + obj_id = all_obj_ids[0] + else: + obj_id = all_obj_ids self._tasks[obj_id] = worker + self._objects[obj_id] = all_obj_ids def completed(self): pending = list(self._tasks) if pending: ready, _ = ray.wait(pending, num_returns=len(pending), timeout=10) for obj_id in ready: - yield (self._tasks.pop(obj_id), obj_id) + yield (self._tasks.pop(obj_id), self._objects.pop(obj_id)) @property def count(self):