diff --git a/python/ray/cloudpickle/__init__.py b/python/ray/cloudpickle/__init__.py index 73e0054da..a2f166ac5 100644 --- a/python/ray/cloudpickle/__init__.py +++ b/python/ray/cloudpickle/__init__.py @@ -1,12 +1,7 @@ from __future__ import absolute_import import sys -# TODO(suquark): This is a temporary flag for -# the new serialization implementation. -# Remove it when the old one is deprecated. -USE_NEW_SERIALIZER = False - -if USE_NEW_SERIALIZER and sys.version_info[:2] >= (3, 8): +if sys.version_info[:2] >= (3, 8): from ray.cloudpickle.cloudpickle_fast import * FAST_CLOUDPICKLE_USED = True else: diff --git a/python/ray/node.py b/python/ray/node.py index 28e01f586..e16e24453 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -264,6 +264,10 @@ class Node(object): def load_code_from_local(self): return self._ray_params.load_code_from_local + @property + def use_pickle(self): + return self._ray_params.use_pickle + @property def object_id_seed(self): """Get the seed for deterministic generation of object IDs""" @@ -520,7 +524,7 @@ class Node(object): include_java=self._ray_params.include_java, java_worker_options=self._ray_params.java_worker_options, load_code_from_local=self._ray_params.load_code_from_local, - ) + use_pickle=self._ray_params.use_pickle) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] diff --git a/python/ray/parameter.py b/python/ray/parameter.py index 929038c5d..5b4c7f451 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -74,6 +74,7 @@ class RayParams(object): Java worker. java_worker_options (str): The command options for Java worker. load_code_from_local: Whether load code from local file or from GCS. + use_pickle: Whether data objects should be serialized with cloudpickle. _internal_config (str): JSON configuration for overriding RayConfig defaults. For testing purposes ONLY. """ @@ -113,6 +114,7 @@ class RayParams(object): include_java=False, java_worker_options=None, load_code_from_local=False, + use_pickle=False, _internal_config=None): self.object_id_seed = object_id_seed self.redis_address = redis_address @@ -146,6 +148,7 @@ class RayParams(object): self.include_java = include_java self.java_worker_options = java_worker_options self.load_code_from_local = load_code_from_local + self.use_pickle = use_pickle self._internal_config = _internal_config self._check_usage() diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 0cee431c7..062762dc2 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -225,6 +225,11 @@ def cli(logging_level, logging_format): is_flag=True, default=False, help="Specify whether load code from local file or GCS serialization.") +@click.option( + "--use-pickle", + is_flag=True, + default=False, + help="Use pickle for serialization.") def start(node_ip_address, redis_address, address, redis_port, num_redis_shards, redis_max_clients, redis_password, redis_shard_ports, object_manager_port, node_manager_port, memory, @@ -232,7 +237,8 @@ def start(node_ip_address, redis_address, address, redis_port, head, include_webui, block, plasma_directory, huge_pages, autoscaling_config, no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir, include_java, - java_worker_options, load_code_from_local, internal_config): + java_worker_options, load_code_from_local, use_pickle, + internal_config): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -273,6 +279,7 @@ def start(node_ip_address, redis_address, address, redis_port, include_webui=include_webui, java_worker_options=java_worker_options, load_code_from_local=load_code_from_local, + use_pickle=use_pickle, _internal_config=internal_config) if head: diff --git a/python/ray/services.py b/python/ray/services.py index 69ee0dc70..3d6df5bf2 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1060,7 +1060,8 @@ def start_raylet(redis_address, config=None, include_java=False, java_worker_options=None, - load_code_from_local=False): + load_code_from_local=False, + use_pickle=False): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -1092,6 +1093,7 @@ def start_raylet(redis_address, include_java (bool): If True, the raylet backend can also support Java worker. java_worker_options (str): The command options for Java worker. + use_pickle (bool): If True, use cloudpickle for serialization. Returns: ProcessInfo for the process that was started. """ @@ -1155,6 +1157,8 @@ def start_raylet(redis_address, if load_code_from_local: start_worker_command += " --load-code-from-local " + if use_pickle: + start_worker_command += " --use-pickle " command = [ RAYLET_EXECUTABLE, diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 1e334e1b4..360334868 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -130,7 +130,7 @@ def test_fair_queueing(shutdown_only): assert len(ready) == 1000, len(ready) -def test_complex_serialization(ray_start_regular): +def complex_serialization(use_pickle): def assert_equal(obj1, obj2): module_numpy = (type(obj1).__module__ == np.__name__ or type(obj2).__module__ == np.__name__) @@ -340,6 +340,15 @@ def test_complex_serialization(ray_start_regular): assert ray.get(ray.put(s)).readline() == line +def test_complex_serialization(ray_start_regular): + complex_serialization(use_pickle=False) + + +def test_complex_serialization_with_pickle(shutdown_only): + ray.init(use_pickle=True) + complex_serialization(use_pickle=True) + + def test_nested_functions(ray_start_regular): # Make sure that remote functions can use other values that are defined # after the remote function but before the first function invocation. @@ -410,7 +419,7 @@ def test_ray_recursive_objects(ray_start_regular): # Create a list of recursive objects. recursive_objects = [lst, a1, a2, a3, d1] - if ray.worker.USE_NEW_SERIALIZER: + if ray.worker.global_worker.use_pickle: # Serialize the recursive objects. for obj in recursive_objects: ray.put(obj) diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index b62bc01de..199892201 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -551,3 +551,23 @@ print("success") # Make sure we can still talk with the raylet. ray.get(f.remote()) + + +@pytest.mark.parametrize( + "call_ray_start", ["ray start --head --num-cpus=1 --use-pickle"], + indirect=True) +def test_use_pickle(call_ray_start): + address = call_ray_start + + ray.init(address=address, use_pickle=True) + + assert ray.worker.global_worker.use_pickle + x = (2, "hello") + + @ray.remote + def f(x): + assert x == (2, "hello") + assert ray.worker.global_worker.use_pickle + return (3, "world") + + assert ray.get(f.remote(x)) == (3, "world") diff --git a/python/ray/worker.py b/python/ray/worker.py index e3e94ac05..dfc09726d 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -26,7 +26,6 @@ import random import pyarrow import pyarrow.plasma as plasma import ray.cloudpickle as pickle -from ray.cloudpickle import USE_NEW_SERIALIZER import ray.experimental.signal as ray_signal import ray.experimental.no_return import ray.gcs_utils @@ -176,6 +175,11 @@ class Worker(object): self.check_connected() return self.node.load_code_from_local + @property + def use_pickle(self): + self.check_connected() + return self.node.use_pickle + @property def task_context(self): """A thread-local that contains the following attributes. @@ -391,7 +395,7 @@ class Worker(object): for attempt in reversed( range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)): try: - if USE_NEW_SERIALIZER: + if self.use_pickle: self.store_with_plasma(object_id, value) else: self._try_store_and_register(object_id, value) @@ -433,8 +437,13 @@ class Worker(object): value, object_id, memcopy_threads=self.memcopy_threads) else: writer = Pickle5Writer() - inband = pickle.dumps( - value, protocol=5, buffer_callback=writer.buffer_callback) + if ray.cloudpickle.FAST_CLOUDPICKLE_USED: + inband = pickle.dumps( + value, + protocol=5, + buffer_callback=writer.buffer_callback) + else: + inband = pickle.dumps(value) self.core_worker.put_pickle5_buffers(object_id, inband, writer, self.memcopy_threads) except pyarrow.plasma.PlasmaObjectExists: @@ -512,10 +521,12 @@ class Worker(object): def _deserialize_object_from_arrow(self, data, metadata, object_id, serialization_context): if metadata: - if (USE_NEW_SERIALIZER - and metadata == ray_constants.PICKLE5_BUFFER_METADATA): + if metadata == ray_constants.PICKLE5_BUFFER_METADATA: in_band, buffers = unpack_pickle5_buffers(data) - return pickle.loads(in_band, buffers=buffers) + if len(buffers) > 0: + return pickle.loads(in_band, buffers=buffers) + else: + return pickle.loads(in_band) # Check if the object should be returned as raw bytes. if metadata == ray_constants.RAW_BUFFER_METADATA: if data is None: @@ -1085,7 +1096,7 @@ def _initialize_serialization(job_id, worker=global_worker): worker.serialization_context_map[job_id] = serialization_context - if not USE_NEW_SERIALIZER: + if not worker.use_pickle: for error_cls in RAY_EXCEPTION_TYPES: register_custom_serializer( error_cls, @@ -1158,6 +1169,7 @@ def init(address=None, raylet_socket_name=None, temp_dir=None, load_code_from_local=False, + use_pickle=False, _internal_config=None): """Connect to an existing Ray cluster or start one and connect to it. @@ -1242,6 +1254,7 @@ def init(address=None, directory for the Ray process. load_code_from_local: Whether code should be loaded from a local module or from the GCS. + use_pickle: Whether data objects should be serialized with cloudpickle. _internal_config (str): JSON configuration for overriding RayConfig defaults. For testing purposes ONLY. @@ -1316,6 +1329,7 @@ def init(address=None, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir, load_code_from_local=load_code_from_local, + use_pickle=use_pickle, _internal_config=_internal_config, ) # Start the Ray processes. We set shutdown_at_exit=False because we @@ -1372,7 +1386,8 @@ def init(address=None, redis_password=redis_password, object_id_seed=object_id_seed, temp_dir=temp_dir, - load_code_from_local=load_code_from_local) + load_code_from_local=load_code_from_local, + use_pickle=use_pickle) _global_node = ray.node.Node( ray_params, head=False, shutdown_at_exit=False, connect_only=True) @@ -2045,7 +2060,7 @@ def register_custom_serializer(cls, assert isinstance(job_id, JobID) def register_class_for_serialization(worker_info): - if USE_NEW_SERIALIZER: + if worker_info["worker"].use_pickle: if pickle.FAST_CLOUDPICKLE_USED: # construct a reducer pickle.CloudPickler.dispatch[ diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 508a38fae..594218480 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -62,6 +62,11 @@ parser.add_argument( default=False, action="store_true", help="True if code is loaded from local files, as opposed to the GCS.") +parser.add_argument( + "--use-pickle", + default=False, + action="store_true", + help="True if cloudpickle should be used for serialization.") if __name__ == "__main__": args = parser.parse_args() @@ -75,7 +80,8 @@ if __name__ == "__main__": plasma_store_socket_name=args.object_store_name, raylet_socket_name=args.raylet_name, temp_dir=args.temp_dir, - load_code_from_local=args.load_code_from_local) + load_code_from_local=args.load_code_from_local, + use_pickle=args.use_pickle) node = ray.node.Node( ray_params, head=False, shutdown_at_exit=False, connect_only=True)