[rllib] Work around actor creation hang edge case for Ape-X (#2661)

* apex hang

* fix

* move pyt to end
This commit is contained in:
Eric Liang
2018-08-16 18:03:50 -07:00
committed by GitHub
parent 5f430da180
commit 6670880f03
3 changed files with 41 additions and 22 deletions
+18 -5
View File
@@ -137,14 +137,27 @@ class DQNAgent(Agent):
self.local_evaluator = self.make_local_evaluator(
self.env_creator, self._policy_graph)
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, self._policy_graph, self.config["num_workers"], {
"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"]
})
def create_remote_evaluators():
return self.make_remote_evaluators(
self.env_creator, self._policy_graph,
self.config["num_workers"], {
"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"]
})
if self.config["optimizer_class"] != "AsyncReplayOptimizer":
self.remote_evaluators = create_remote_evaluators()
else:
# Hack to work around https://github.com/ray-project/ray/issues/2541
self.remote_evaluators = None
self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
self.local_evaluator, self.remote_evaluators,
self.config["optimizer"])
# Create the remote evaluators *after* the replay actors
if self.remote_evaluators is None:
self.remote_evaluators = create_remote_evaluators()
self.optimizer.set_evaluators(self.remote_evaluators)
self.last_target_update_ts = 0
self.num_target_updates = 0
@@ -27,7 +27,7 @@ REPLAY_QUEUE_DEPTH = 4
LEARNER_QUEUE_MAX_SIZE = 16
@ray.remote
@ray.remote(num_cpus=0)
class ReplayActor(object):
"""A replay buffer shard.
@@ -175,7 +175,6 @@ class AsyncReplayOptimizer(PolicyOptimizer):
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps, clip_rewards
], num_replay_buffer_shards)
assert len(self.remote_evaluators) > 0
# Stats
self.timers = {
@@ -199,6 +198,12 @@ class AsyncReplayOptimizer(PolicyOptimizer):
# Kick off async background sampling
self.sample_tasks = TaskPool()
if self.remote_evaluators:
self.set_evaluators(self.remote_evaluators)
# Exists only as a workaround for https://github.com/ray-project/ray/issues/2541
def set_evaluators(self, remote_evaluators):
self.remote_evaluators = remote_evaluators
weights = self.local_evaluator.get_weights()
for ev in self.remote_evaluators:
ev.set_weights.remote(weights)
@@ -207,6 +212,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.sample_tasks.add(ev, ev.sample_with_count.remote())
def step(self):
assert len(self.remote_evaluators) > 0
start = time.time()
sample_timesteps, train_timesteps = self._step()
time_delta = time.time() - start