This commit is contained in:
Eric Liang
2017-11-01 13:45:39 -07:00
committed by Philipp Moritz
parent dc66a2d7d5
commit 202e7bf19a
+6 -2
View File
@@ -65,8 +65,12 @@ class LocalSyncParallelOptimizer(object):
# Then set up the per-device loss graphs that use the shared weights
self._batch_index = tf.placeholder(tf.int32)
data_splits = zip(
*[tf.split(ph, len(devices)) for ph in input_placeholders])
# Split on the CPU in case the data doesn't fit in GPU memory.
with tf.device("/cpu:0"):
data_splits = zip(
*[tf.split(ph, len(devices)) for ph in input_placeholders])
self._towers = []
for device, device_placeholders in zip(self.devices, data_splits):
self._towers.append(self._setup_device(device,