mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:19:38 +08:00
Fixed calculation of num_steps_trained for multi_gpu_optimizer (#4364)
This commit is contained in:
@@ -188,7 +188,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
|
||||
sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
|
||||
|
||||
tuples_per_device = truncated_len / len(self.devices)
|
||||
tuples_per_device = truncated_len // len(self.devices)
|
||||
assert tuples_per_device > 0, "No data loaded?"
|
||||
assert tuples_per_device % self._loaded_per_device_batch_size == 0
|
||||
return tuples_per_device
|
||||
|
||||
@@ -196,7 +196,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
fetches[policy_id] = _averaged(iter_extra_fetches)
|
||||
|
||||
self.num_steps_sampled += samples.count
|
||||
self.num_steps_trained += samples.count
|
||||
self.num_steps_trained += tuples_per_device * len(self.devices)
|
||||
return fetches
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
|
||||
Reference in New Issue
Block a user