mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:02:43 +08:00
[rllib] Use batch.count in async samples optimizer (#2488)
Using the actual batch size reduces the risk of mis-accounting. Here, we under-counted samples since in truncate_episodes mode we were doubling the batch size by accident in policy_evaluator.
This commit is contained in:
@@ -23,7 +23,6 @@ APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
"learning_starts": 50000,
|
||||
"train_batch_size": 512,
|
||||
"sample_batch_size": 50,
|
||||
"max_weight_sync_delay": 400,
|
||||
"target_network_update_freq": 500000,
|
||||
"timesteps_per_iteration": 25000,
|
||||
"per_worker_exploration": True,
|
||||
|
||||
@@ -122,10 +122,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
in each sample batch returned from this evaluator.
|
||||
batch_mode (str): One of the following batch modes:
|
||||
"truncate_episodes": Each call to sample() will return a batch
|
||||
of exactly `batch_steps` in size. Episodes may be truncated
|
||||
in order to meet this size requirement. When
|
||||
`num_envs > 1`, episodes will be truncated to sequences of
|
||||
`batch_size / num_envs` in length.
|
||||
of at most `batch_steps` in size. The batch will be exactly
|
||||
`batch_steps` in size if postprocessing does not change
|
||||
batch sizes. Episodes may be truncated in order to meet
|
||||
this size requirement. When `num_envs > 1`, episodes will
|
||||
be truncated to sequences of `batch_size / num_envs` in
|
||||
length.
|
||||
"complete_episodes": Each call to sample() will return a batch
|
||||
of at least `batch_steps in size. Episodes will not be
|
||||
truncated, but multiple episodes may be packed within one
|
||||
@@ -220,6 +222,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
# Always use vector env for consistency even if num_envs = 1
|
||||
self.async_env = AsyncVectorEnv.wrap_async(
|
||||
self.env, make_env=make_env, num_envs=num_envs)
|
||||
self.num_envs = num_envs
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
if batch_steps % num_envs != 0:
|
||||
@@ -276,7 +279,15 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
batches = [self.sampler.get_data()]
|
||||
steps_so_far = batches[0].count
|
||||
while steps_so_far < self.batch_steps:
|
||||
|
||||
# In truncate_episodes mode, never pull more than 1 batch per env.
|
||||
# This avoids over-running the target batch size.
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
max_batches = self.num_envs
|
||||
else:
|
||||
max_batches = float("inf")
|
||||
|
||||
while steps_so_far < self.batch_steps and len(batches) < max_batches:
|
||||
batch = self.sampler.get_data()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
@@ -293,6 +304,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
return batch
|
||||
|
||||
@ray.method(num_return_vals=2)
|
||||
def sample_with_count(self):
|
||||
"""Same as sample() but returns the count as a separate future."""
|
||||
batch = self.sample()
|
||||
return batch, batch.count
|
||||
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Apply the given function to the specified policy graph."""
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ class ReplayActor(object):
|
||||
return os.uname()[1]
|
||||
|
||||
def add_batch(self, batch):
|
||||
PolicyOptimizer._check_not_multiagent(batch)
|
||||
with self.add_batch_timer:
|
||||
for row in batch.rows():
|
||||
self.replay_buffer.add(row["obs"], row["actions"],
|
||||
@@ -131,7 +132,7 @@ class LearnerThread(threading.Thread):
|
||||
with self.grad_timer:
|
||||
td_error = self.local_evaluator.compute_apply(replay)[
|
||||
"td_error"]
|
||||
self.outqueue.put((ra, replay, td_error))
|
||||
self.outqueue.put((ra, replay, td_error, replay.count))
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
self.weights_updated = True
|
||||
|
||||
@@ -164,8 +165,6 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
self.replay_starts = learning_starts
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
self.train_batch_size = train_batch_size
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.max_weight_sync_delay = max_weight_sync_delay
|
||||
|
||||
self.learner = LearnerThread(self.local_evaluator)
|
||||
@@ -205,7 +204,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
ev.set_weights.remote(weights)
|
||||
self.steps_since_update[ev] = 0
|
||||
for _ in range(SAMPLE_QUEUE_DEPTH):
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
self.sample_tasks.add(ev, ev.sample_with_count.remote())
|
||||
|
||||
def step(self):
|
||||
start = time.time()
|
||||
@@ -226,16 +225,17 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
weights = None
|
||||
|
||||
with self.timers["sample_processing"]:
|
||||
for ev, sample_batch in self.sample_tasks.completed():
|
||||
self._check_not_multiagent(sample_batch)
|
||||
sample_timesteps += self.sample_batch_size
|
||||
completed = list(self.sample_tasks.completed())
|
||||
counts = ray.get([c[1][1] for c in completed])
|
||||
for i, (ev, (sample_batch, count)) in enumerate(completed):
|
||||
sample_timesteps += counts[i]
|
||||
|
||||
# Send the data to the replay buffer
|
||||
random.choice(
|
||||
self.replay_actors).add_batch.remote(sample_batch)
|
||||
|
||||
# Update weights if needed
|
||||
self.steps_since_update[ev] += self.sample_batch_size
|
||||
self.steps_since_update[ev] += counts[i]
|
||||
if self.steps_since_update[ev] >= self.max_weight_sync_delay:
|
||||
# Note that it's important to pull new weights once
|
||||
# updated to avoid excessive correlation between actors
|
||||
@@ -249,7 +249,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
self.steps_since_update[ev] = 0
|
||||
|
||||
# Kick off another sample request
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
self.sample_tasks.add(ev, ev.sample_with_count.remote())
|
||||
|
||||
with self.timers["replay_processing"]:
|
||||
for ra, replay in self.replay_tasks.completed():
|
||||
@@ -261,9 +261,9 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.timers["update_priorities"]:
|
||||
while not self.learner.outqueue.empty():
|
||||
ra, replay, td_error = self.learner.outqueue.get()
|
||||
ra, replay, td_error, count = self.learner.outqueue.get()
|
||||
ra.update_priorities.remote(replay["batch_indexes"], td_error)
|
||||
train_timesteps += self.train_batch_size
|
||||
train_timesteps += count
|
||||
|
||||
return sample_timesteps, train_timesteps
|
||||
|
||||
|
||||
@@ -121,7 +121,8 @@ class PolicyOptimizer(object):
|
||||
])
|
||||
return local_result + remote_results
|
||||
|
||||
def _check_not_multiagent(self, sample_batch):
|
||||
@staticmethod
|
||||
def _check_not_multiagent(sample_batch):
|
||||
if isinstance(sample_batch, MultiAgentBatch):
|
||||
raise NotImplementedError(
|
||||
"This optimizer does not support multi-agent yet.")
|
||||
|
||||
@@ -11,16 +11,22 @@ class TaskPool(object):
|
||||
|
||||
def __init__(self):
|
||||
self._tasks = {}
|
||||
self._objects = {}
|
||||
|
||||
def add(self, worker, obj_id):
|
||||
def add(self, worker, all_obj_ids):
|
||||
if isinstance(all_obj_ids, list):
|
||||
obj_id = all_obj_ids[0]
|
||||
else:
|
||||
obj_id = all_obj_ids
|
||||
self._tasks[obj_id] = worker
|
||||
self._objects[obj_id] = all_obj_ids
|
||||
|
||||
def completed(self):
|
||||
pending = list(self._tasks)
|
||||
if pending:
|
||||
ready, _ = ray.wait(pending, num_returns=len(pending), timeout=10)
|
||||
for obj_id in ready:
|
||||
yield (self._tasks.pop(obj_id), obj_id)
|
||||
yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
|
||||
Reference in New Issue
Block a user