mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:53:32 +08:00
[rllib] Fix race condition with multiple data loaders, fix stats
This commit is contained in:
@@ -368,15 +368,16 @@ class TFMultiGPULearner(LearnerThread):
|
||||
assert self.loader_thread.is_alive()
|
||||
with self.load_wait_timer:
|
||||
opt, released = self.minibatch_buffer.get()
|
||||
if released:
|
||||
self.idle_optimizers.put(opt)
|
||||
|
||||
with self.grad_timer:
|
||||
fetches = opt.optimize(self.sess, 0)
|
||||
self.weights_updated = True
|
||||
self.stats = fetches.get("stats", {})
|
||||
|
||||
self.outqueue.put(self.train_batch_size)
|
||||
if released:
|
||||
self.idle_optimizers.put(opt)
|
||||
|
||||
self.outqueue.put(opt.num_tuples_loaded)
|
||||
self.learner_queue_size.push(self.inqueue.qsize())
|
||||
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
|
||||
sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
|
||||
|
||||
self.num_tuples_loaded = truncated_len
|
||||
tuples_per_device = truncated_len // len(self.devices)
|
||||
assert tuples_per_device > 0, "No data loaded?"
|
||||
assert tuples_per_device % self._loaded_per_device_batch_size == 0
|
||||
|
||||
Reference in New Issue
Block a user