[rllib] Copy data before passing to Ape-X learner thread (fixes transient plasma crashes) (#3484)

This commit is contained in:
Eric Liang
2018-12-06 18:01:11 -08:00
committed by GitHub
parent c2c501bbe6
commit 8395523f81
2 changed files with 14 additions and 5 deletions
@@ -200,6 +200,11 @@ class MultiAgentBatch(object):
out[policy_id] = SampleBatch.concat_samples(batches)
return MultiAgentBatch(out, total_count)
def copy(self):
return MultiAgentBatch(
{k: v.copy()
for (k, v) in self.policy_batches.items()}, self.count)
def total(self):
ct = 0
for batch in self.policy_batches.values():
@@ -261,6 +266,11 @@ class SampleBatch(object):
out[k] = np.concatenate([self[k], other[k]])
return SampleBatch(out)
def copy(self):
return SampleBatch(
{k: np.array(v, copy=True)
for (k, v) in self.data.items()})
def rows(self):
"""Returns an iterator over data rows, i.e. dicts with column values.
@@ -154,9 +154,7 @@ class LearnerThread(threading.Thread):
info["td_error"])
if "stats" in info:
self.stats[pid] = info["stats"]
# send `replay` back also so that it gets released by the original
# thread: https://github.com/ray-project/ray/issues/2610
self.outqueue.put((ra, replay, prio_dict, replay.count))
self.outqueue.put((ra, prio_dict, replay.count))
self.learner_queue_size.push(self.inqueue.qsize())
self.weights_updated = True
@@ -293,11 +291,12 @@ class AsyncReplayOptimizer(PolicyOptimizer):
else:
with self.timers["get_samples"]:
samples = ray.get(replay)
self.learner.inqueue.put((ra, samples))
# Defensive copy against plasma crashes, see #2610 #3452
self.learner.inqueue.put((ra, samples and samples.copy()))
with self.timers["update_priorities"]:
while not self.learner.outqueue.empty():
ra, _, prio_dict, count = self.learner.outqueue.get()
ra, prio_dict, count = self.learner.outqueue.get()
ra.update_priorities.remote(prio_dict)
train_timesteps += count