mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:23:15 +08:00
fix (#1174)
This commit is contained in:
committed by
Philipp Moritz
parent
dc66a2d7d5
commit
202e7bf19a
@@ -65,8 +65,12 @@ class LocalSyncParallelOptimizer(object):
|
||||
|
||||
# Then setup the per-device loss graphs that use the shared weights
|
||||
self._batch_index = tf.placeholder(tf.int32)
|
||||
data_splits = zip(
|
||||
*[tf.split(ph, len(devices)) for ph in input_placeholders])
|
||||
|
||||
# Split on the CPU in case the data doesn't fit in GPU memory.
|
||||
with tf.device("/cpu:0"):
|
||||
data_splits = zip(
|
||||
*[tf.split(ph, len(devices)) for ph in input_placeholders])
|
||||
|
||||
self._towers = []
|
||||
for device, device_placeholders in zip(self.devices, data_splits):
|
||||
self._towers.append(self._setup_device(device,
|
||||
|
||||
Reference in New Issue
Block a user