From 202e7bf19a34fc18839ce29515273e321ecd04af Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 1 Nov 2017 13:45:39 -0700 Subject: [PATCH] fix (#1174) --- python/ray/rllib/parallel.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/parallel.py b/python/ray/rllib/parallel.py index e2281f663..a0e9a1aeb 100644 --- a/python/ray/rllib/parallel.py +++ b/python/ray/rllib/parallel.py @@ -65,8 +65,12 @@ class LocalSyncParallelOptimizer(object): # Then setup the per-device loss graphs that use the shared weights self._batch_index = tf.placeholder(tf.int32) - data_splits = zip( - *[tf.split(ph, len(devices)) for ph in input_placeholders]) + + # Split on the CPU in case the data doesn't fit in GPU memory. + with tf.device("/cpu:0"): + data_splits = zip( + *[tf.split(ph, len(devices)) for ph in input_placeholders]) + self._towers = [] for device, device_placeholders in zip(self.devices, data_splits): self._towers.append(self._setup_device(device,