From 0a505ca83d18a340e7e2dd270d82aeff620ebc77 Mon Sep 17 00:00:00 2001
From: "Siyuan (Ryans) Zhuang"
Date: Thu, 26 Nov 2020 16:09:54 -0800
Subject: [PATCH] [Core] zero-copy serializer for pytorch (#12344)

* zero-copy serializer for pytorch

* address possible bottleneck

* add tests & device support
---
Review notes (ignored by "git am"):
- Line structure of the patch restored.
- Dense tensors now ship a flat array spanning their whole storage, so
  views with a storage offset or non-contiguous strides round-trip.
- Tensors that do not use the dense rebuild function, or whose dtype has
  no numpy equivalent, fall back to the generic rebuild path.
- The blob hashes on the "index" lines were not recomputed and are stale.

 python/ray/serialization.py            |   2 +
 python/ray/serialization_addons.py     | 111 +++++++++++++++++++++++++
 python/ray/tests/test_serialization.py |  30 ++++++-
 3 files changed, 142 insertions(+), 1 deletion(-)
 create mode 100644 python/ray/serialization_addons.py

diff --git a/python/ray/serialization.py b/python/ray/serialization.py
index f85b07afa..ef1ec50f2 100644
--- a/python/ray/serialization.py
+++ b/python/ray/serialization.py
@@ -26,6 +26,7 @@ from ray._raylet import (
     MessagePackSerializedObject,
     RawSerializedObject,
 )
+from ray import serialization_addons
 
 logger = logging.getLogger(__name__)
 
@@ -155,6 +156,7 @@ class SerializationContext:
         # Because objects have default __reduce__ method, we only need to
         # treat ObjectRef specifically.
         self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
+        serialization_addons.apply(self)
 
     def _register_cloudpickle_reducer(self, cls, reducer):
         pickle.CloudPickler.dispatch[cls] = reducer
diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py
new file mode 100644
index 000000000..3d57a9137
--- /dev/null
+++ b/python/ray/serialization_addons.py
@@ -0,0 +1,111 @@
+"""
+This module is intended for implementing internal serializers for some
+site packages.
+"""
+
+import warnings
+
+try:
+    import torch
+
+    # Cleared after the first dense rebuild so that the warning filter
+    # below is only installed once.
+    _TORCH_WARNING_FILTER_ACTIVATE = True
+
+    class _TorchTensorReducingHelper:
+        """Pickle proxy that ships a dense tensor's storage as a numpy array.
+
+        Pickling a helper produces a rebuild call that recreates the wrapped
+        tensor inside a new helper; ``_unwrap_tensor`` extracts it again.
+        """
+
+        def __init__(self, tensor):
+            self.tensor = tensor
+
+        @classmethod
+        def rebuild_tensor(cls, rebuild_func, device, ndarray, params):
+            """Rebuild a dense tensor on top of ``ndarray``.
+
+            Args:
+                rebuild_func: Rebuild function returned by
+                    ``torch.Tensor.__reduce_ex__``.
+                device: Device the original tensor lived on.
+                ndarray: Flat numpy array spanning the tensor's storage.
+                params: Arguments of ``rebuild_func`` after the storage.
+
+            Returns:
+                A helper wrapping the rebuilt tensor.
+            """
+            global _TORCH_WARNING_FILTER_ACTIVATE
+            # filtering warning messages would be the bottleneck for
+            # deserializing torch tensors. Since the warning only prompts once,
+            # we would only deal with it for the first time.
+            if _TORCH_WARNING_FILTER_ACTIVATE:
+                with warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore",
+                        category=UserWarning,
+                        message="The given NumPy array is not writeable")
+                    _tensor = torch.from_numpy(ndarray)
+                _TORCH_WARNING_FILTER_ACTIVATE = False
+            else:
+                _tensor = torch.from_numpy(ndarray)
+            if device != torch.device("cpu"):
+                _tensor = _tensor.to(device)
+            tensor = rebuild_func(_tensor.storage(), *params)
+            return cls(tensor)
+
+        @classmethod
+        def rebuild_sparse_tensor(cls, rebuild_func, content):
+            """Rebuild a tensor by calling ``rebuild_func(*content)``.
+
+            Used for sparse tensors and for every tensor that cannot take
+            the numpy path.
+            """
+            tensor = rebuild_func(*content)
+            return cls(tensor)
+
+        def __reduce_ex__(self, protocol):
+            """Reduce to ``rebuild_tensor`` or ``rebuild_sparse_tensor``."""
+            _rebuild_func, content = self.tensor.__reduce_ex__(protocol)
+            if (self.tensor.is_sparse
+                    or _rebuild_func is not torch._utils._rebuild_tensor_v2):
+                # Torch will help us reduce the sparse tensor into
+                # several continuous tensors. Any other rebuild function
+                # does not have the argument layout assumed below, so its
+                # tensors are passed through unchanged as well.
+                return self.rebuild_sparse_tensor, (_rebuild_func, content)
+            # By only replacing the storage with a numpy array, we can reuse
+            # zero-copy serialization while keeping all other params of the
+            # torch tensor. The array has to span the whole storage because
+            # the remaining params (storage offset, size, stride) index into
+            # the storage, not into the tensor's own, possibly sliced, view.
+            flat = torch.empty(
+                0, dtype=self.tensor.dtype,
+                device=self.tensor.device).set_(content[0])
+            try:
+                ndarray = flat.cpu().numpy()
+            except TypeError:
+                # The dtype has no numpy equivalent (e.g. bfloat16).
+                return self.rebuild_sparse_tensor, (_rebuild_func, content)
+            return self.rebuild_tensor, (_rebuild_func, self.tensor.device,
+                                         ndarray, content[1:])
+
+    def _unwrap_tensor(s):
+        return s.tensor
+
+    def torch_tensor_reducer(tensor):
+        return _unwrap_tensor, (_TorchTensorReducingHelper(tensor), )
+
+except ImportError:
+    pass
+
+
+def apply(serialization_context):
+    """Register the add-on reducers on ``serialization_context``."""
+    try:
+        import torch
+        serialization_context._register_cloudpickle_reducer(
+            torch.Tensor, torch_tensor_reducer)
+    except ImportError:
+        pass
diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py
index 35b6e09fe..500b1ed84 100644
--- a/python/ray/tests/test_serialization.py
+++ b/python/ray/tests/test_serialization.py
@@ -543,7 +543,7 @@ def test_reducer_override_no_reference_cycle(ray_start_shared_local_modes):
     assert new_obj() is None
 
 
-def test_buffer_alignment():
+def test_buffer_alignment(ray_start_shared_local_modes):
     # Deserialized large numpy arrays should be 64-byte aligned.
     x = np.random.normal(size=(10, 20, 30))
     y = ray.get(ray.put(x))
@@ -568,6 +568,34 @@ def test_buffer_alignment():
     assert y.ctypes.data % 8 == 0
 
 
+def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes):
+    import torch
+    # test dense tensor
+    tensor = torch.rand(32, 3, 64, 64)
+    ref = ray.put(tensor)
+    tensor_1, tensor_2 = ray.get([ref] * 2)
+    assert tensor_1.data_ptr() == tensor_2.data_ptr()
+
+    # test views whose storage offset / strides differ from a fresh tensor
+    tensor = torch.arange(10.)[5::2]
+    assert torch.equal(ray.get(ray.put(tensor)), tensor)
+
+    # test sparse tensor
+    i = torch.arange(0, 1024 * 1024, 4).view(1, -1)
+    v = torch.rand(1024 * 1024 // 4)
+    k = torch.sparse_coo_tensor(i, v, size=(1024 * 1024, ))
+    ref = ray.put(k)
+    tensor_1, tensor_2 = ray.get([ref] * 2)
+    assert tensor_1._indices().data_ptr() == tensor_2._indices().data_ptr()
+    assert tensor_1._values().data_ptr() == tensor_2._values().data_ptr()
+
+    # test attributes
+    tensor = torch.rand(4).requires_grad_(True)
+    ref = ray.put(tensor)
+    tensor = ray.get(ref)
+    assert tensor.requires_grad
+
+
 if __name__ == "__main__":
     import pytest
     sys.exit(pytest.main(["-v", __file__]))