diff --git a/python/ray/experimental/tfutils.py b/python/ray/experimental/tfutils.py index 5133a49b9..c6aa4ba9a 100644 --- a/python/ray/experimental/tfutils.py +++ b/python/ray/experimental/tfutils.py @@ -9,7 +9,7 @@ def unflatten(vector, shapes): i = 0 arrays = [] for shape in shapes: - size = np.prod(shape) + size = np.prod(shape, dtype=np.int) array = vector[i:(i + size)].reshape(shape) arrays.append(array) i += size