Patch summary (free text ahead of the first "diff --git" header; patch tools
skip it):

  * sample_batch.py: adds MultiAgentBatch.copy() and SampleBatch.copy().
    SampleBatch.copy() rebuilds every column with np.array(v, copy=True);
    MultiAgentBatch.copy() calls copy() on each per-policy batch and keeps
    the same count.
  * async_replay_optimizer.py: LearnerThread no longer puts `replay` into
    its outqueue tuples, and AsyncReplayOptimizer enqueues
    `samples and samples.copy()` instead of `samples` and unpacks the
    resulting 3-tuples.

NOTE(review): this patch was recovered from a copy that had been flattened
onto two lines. The leading indentation of the context and changed lines
below is reconstructed from conventional 4-space nesting -- confirm with
`git apply --check` against the target tree before applying.

diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py
index 83f2f5ca1..5a0099530 100644
--- a/python/ray/rllib/evaluation/sample_batch.py
+++ b/python/ray/rllib/evaluation/sample_batch.py
@@ -200,6 +200,11 @@ class MultiAgentBatch(object):
             out[policy_id] = SampleBatch.concat_samples(batches)
         return MultiAgentBatch(out, total_count)
 
+    def copy(self):
+        return MultiAgentBatch(
+            {k: v.copy()
+             for (k, v) in self.policy_batches.items()}, self.count)
+
     def total(self):
         ct = 0
         for batch in self.policy_batches.values():
@@ -261,6 +266,11 @@ class SampleBatch(object):
             out[k] = np.concatenate([self[k], other[k]])
         return SampleBatch(out)
 
+    def copy(self):
+        return SampleBatch(
+            {k: np.array(v, copy=True)
+             for (k, v) in self.data.items()})
+
     def rows(self):
         """Returns an iterator over data rows, i.e. dicts with column values.
 
diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py
index 1f2167408..932800001 100644
--- a/python/ray/rllib/optimizers/async_replay_optimizer.py
+++ b/python/ray/rllib/optimizers/async_replay_optimizer.py
@@ -154,9 +154,7 @@ class LearnerThread(threading.Thread):
                         info["td_error"])
                     if "stats" in info:
                         self.stats[pid] = info["stats"]
-            # send `replay` back also so that it gets released by the original
-            # thread: https://github.com/ray-project/ray/issues/2610
-            self.outqueue.put((ra, replay, prio_dict, replay.count))
+            self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
         self.weights_updated = True
 
@@ -293,11 +291,12 @@ class AsyncReplayOptimizer(PolicyOptimizer):
                 else:
                     with self.timers["get_samples"]:
                         samples = ray.get(replay)
-                    self.learner.inqueue.put((ra, samples))
+                    # Defensive copy against plasma crashes, see #2610 #3452
+                    self.learner.inqueue.put((ra, samples and samples.copy()))
 
         with self.timers["update_priorities"]:
             while not self.learner.outqueue.empty():
-                ra, _, prio_dict, count = self.learner.outqueue.get()
+                ra, prio_dict, count = self.learner.outqueue.get()
                 ra.update_priorities.remote(prio_dict)
                 train_timesteps += count
 