From 8395523f8102d513a5585cfd1de968d9f7148bcf Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 6 Dec 2018 18:01:11 -0800
Subject: [PATCH] [rllib] Copy data before passing to Ape-X learner thread (fixes transient plasma crashes) (#3484)

---
 python/ray/rllib/evaluation/sample_batch.py           | 10 ++++++++++
 python/ray/rllib/optimizers/async_replay_optimizer.py |  9 ++++-----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py
index 83f2f5ca1..5a0099530 100644
--- a/python/ray/rllib/evaluation/sample_batch.py
+++ b/python/ray/rllib/evaluation/sample_batch.py
@@ -200,6 +200,11 @@ class MultiAgentBatch(object):
             out[policy_id] = SampleBatch.concat_samples(batches)
         return MultiAgentBatch(out, total_count)
 
+    def copy(self):
+        return MultiAgentBatch(
+            {k: v.copy()
+             for (k, v) in self.policy_batches.items()}, self.count)
+
     def total(self):
         ct = 0
         for batch in self.policy_batches.values():
@@ -261,6 +266,11 @@ class SampleBatch(object):
             out[k] = np.concatenate([self[k], other[k]])
         return SampleBatch(out)
 
+    def copy(self):
+        return SampleBatch(
+            {k: np.array(v, copy=True)
+             for (k, v) in self.data.items()})
+
     def rows(self):
         """Returns an iterator over data rows, i.e. dicts with column values.
 
diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py
index 1f2167408..932800001 100644
--- a/python/ray/rllib/optimizers/async_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -154,9 +154,7 @@ class LearnerThread(threading.Thread):
                         info["td_error"])
                     if "stats" in info:
                         self.stats[pid] = info["stats"]
-            # send `replay` back also so that it gets released by the original
-            # thread: https://github.com/ray-project/ray/issues/2610
-            self.outqueue.put((ra, replay, prio_dict, replay.count))
+            self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
         self.weights_updated = True
 
@@ -293,11 +291,12 @@ class AsyncReplayOptimizer(PolicyOptimizer):
                 else:
                     with self.timers["get_samples"]:
                         samples = ray.get(replay)
-                    self.learner.inqueue.put((ra, samples))
+                    # Defensive copy against plasma crashes, see #2610 #3452
+                    self.learner.inqueue.put((ra, samples and samples.copy()))
 
         with self.timers["update_priorities"]:
             while not self.learner.outqueue.empty():
-                ra, _, prio_dict, count = self.learner.outqueue.get()
+                ra, prio_dict, count = self.learner.outqueue.get()
                 ra.update_priorities.remote(prio_dict)
                 train_timesteps += count
 