mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:18:33 +08:00
[rllib] Copy data before passing to Ape-X learner thread (fixes transient plasma crashes) (#3484)
This commit is contained in:
@@ -200,6 +200,11 @@ class MultiAgentBatch(object):
|
||||
out[policy_id] = SampleBatch.concat_samples(batches)
|
||||
return MultiAgentBatch(out, total_count)
|
||||
|
||||
def copy(self):
|
||||
return MultiAgentBatch(
|
||||
{k: v.copy()
|
||||
for (k, v) in self.policy_batches.items()}, self.count)
|
||||
|
||||
def total(self):
|
||||
ct = 0
|
||||
for batch in self.policy_batches.values():
|
||||
@@ -261,6 +266,11 @@ class SampleBatch(object):
|
||||
out[k] = np.concatenate([self[k], other[k]])
|
||||
return SampleBatch(out)
|
||||
|
||||
def copy(self):
|
||||
return SampleBatch(
|
||||
{k: np.array(v, copy=True)
|
||||
for (k, v) in self.data.items()})
|
||||
|
||||
def rows(self):
|
||||
"""Returns an iterator over data rows, i.e. dicts with column values.
|
||||
|
||||
|
||||
@@ -154,9 +154,7 @@ class LearnerThread(threading.Thread):
|
||||
info["td_error"])
|
||||
if "stats" in info:
|
||||
self.stats[pid] = info["stats"]
|
||||
# send `replay` back also so that it gets released by the original
|
||||
# thread: https://github.com/ray-project/ray/issues/2610
|
||||
self.outqueue.put((ra, replay, prio_dict, replay.count))
|
||||
self.outqueue.put((ra, prio_dict, replay.count))
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
self.weights_updated = True
|
||||
|
||||
@@ -293,11 +291,12 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
else:
|
||||
with self.timers["get_samples"]:
|
||||
samples = ray.get(replay)
|
||||
self.learner.inqueue.put((ra, samples))
|
||||
# Defensive copy against plasma crashes, see #2610 #3452
|
||||
self.learner.inqueue.put((ra, samples and samples.copy()))
|
||||
|
||||
with self.timers["update_priorities"]:
|
||||
while not self.learner.outqueue.empty():
|
||||
ra, _, prio_dict, count = self.learner.outqueue.get()
|
||||
ra, prio_dict, count = self.learner.outqueue.get()
|
||||
ra.update_priorities.remote(prio_dict)
|
||||
train_timesteps += count
|
||||
|
||||
|
||||
Reference in New Issue
Block a user