diff --git a/python/ray/rllib/parallel.py b/python/ray/rllib/parallel.py index e2281f663..a0e9a1aeb 100644 --- a/python/ray/rllib/parallel.py +++ b/python/ray/rllib/parallel.py @@ -65,8 +65,12 @@ class LocalSyncParallelOptimizer(object): # Then setup the per-device loss graphs that use the shared weights self._batch_index = tf.placeholder(tf.int32) - data_splits = zip( - *[tf.split(ph, len(devices)) for ph in input_placeholders]) + + # Split on the CPU in case the data doesn't fit in GPU memory. + with tf.device("/cpu:0"): + data_splits = zip( + *[tf.split(ph, len(devices)) for ph in input_placeholders]) + self._towers = [] for device, device_placeholders in zip(self.devices, data_splits): self._towers.append(self._setup_device(device,