Fixed calculation of num_steps_trained for multi_gpu_optimizer (#4364)

This commit is contained in:
Leon Sievers
2019-03-15 03:46:02 +01:00
committed by Eric Liang
parent 2c1131e8b2
commit 6b93ec3034
2 changed files with 2 additions and 2 deletions
@@ -188,7 +188,7 @@ class LocalSyncParallelOptimizer(object):
sess.run([t.init_op for t in self._towers], feed_dict=feed_dict)
tuples_per_device = truncated_len / len(self.devices)
tuples_per_device = truncated_len // len(self.devices)
assert tuples_per_device > 0, "No data loaded?"
assert tuples_per_device % self._loaded_per_device_batch_size == 0
return tuples_per_device
@@ -196,7 +196,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
fetches[policy_id] = _averaged(iter_extra_fetches)
self.num_steps_sampled += samples.count
self.num_steps_trained += samples.count
self.num_steps_trained += tuples_per_device * len(self.devices)
return fetches
@override(PolicyOptimizer)