mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:07:41 +08:00
[rllib] Workaround actor creation hang edge case for ape-X (#2661)
* apex hang * fix * move pyt to end
This commit is contained in:
@@ -137,14 +137,27 @@ class DQNAgent(Agent):
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, self._policy_graph)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
self.env_creator, self._policy_graph, self.config["num_workers"], {
|
||||
"num_cpus": self.config["num_cpus_per_worker"],
|
||||
"num_gpus": self.config["num_gpus_per_worker"]
|
||||
})
|
||||
|
||||
def create_remote_evaluators():
|
||||
return self.make_remote_evaluators(
|
||||
self.env_creator, self._policy_graph,
|
||||
self.config["num_workers"], {
|
||||
"num_cpus": self.config["num_cpus_per_worker"],
|
||||
"num_gpus": self.config["num_gpus_per_worker"]
|
||||
})
|
||||
|
||||
if self.config["optimizer_class"] != "AsyncReplayOptimizer":
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
else:
|
||||
# Hack to work around https://github.com/ray-project/ray/issues/2541:
# defer creating the remote evaluators until after the optimizer is built.
|
||||
self.remote_evaluators = None
|
||||
self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
self.config["optimizer"])
|
||||
# Create the remote evaluators *after* the replay actors
|
||||
if self.remote_evaluators is None:
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
self.optimizer.set_evaluators(self.remote_evaluators)
|
||||
|
||||
self.last_target_update_ts = 0
|
||||
self.num_target_updates = 0
|
||||
|
||||
@@ -27,7 +27,7 @@ REPLAY_QUEUE_DEPTH = 4
|
||||
LEARNER_QUEUE_MAX_SIZE = 16
|
||||
|
||||
|
||||
@ray.remote
|
||||
@ray.remote(num_cpus=0)
|
||||
class ReplayActor(object):
|
||||
"""A replay buffer shard.
|
||||
|
||||
@@ -175,7 +175,6 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
train_batch_size, prioritized_replay_alpha,
|
||||
prioritized_replay_beta, prioritized_replay_eps, clip_rewards
|
||||
], num_replay_buffer_shards)
|
||||
assert len(self.remote_evaluators) > 0
|
||||
|
||||
# Stats
|
||||
self.timers = {
|
||||
@@ -199,6 +198,12 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
# Kick off async background sampling
|
||||
self.sample_tasks = TaskPool()
|
||||
if self.remote_evaluators:
|
||||
self.set_evaluators(self.remote_evaluators)
|
||||
|
||||
# Exists only as part of the workaround for https://github.com/ray-project/ray/issues/2541
|
||||
def set_evaluators(self, remote_evaluators):
|
||||
self.remote_evaluators = remote_evaluators
|
||||
weights = self.local_evaluator.get_weights()
|
||||
for ev in self.remote_evaluators:
|
||||
ev.set_weights.remote(weights)
|
||||
@@ -207,6 +212,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.sample_tasks.add(ev, ev.sample_with_count.remote())
|
||||
|
||||
def step(self):
|
||||
assert len(self.remote_evaluators) > 0
|
||||
start = time.time()
|
||||
sample_timesteps, train_timesteps = self._step()
|
||||
time_delta = time.time() - start
|
||||
|
||||
Reference in New Issue
Block a user