diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 484410f1c..850cc2c60 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -21,7 +21,7 @@ class ImportThread(object): """A thread used to import exports from the driver or other workers. Note: The driver also has an import thread, which is used only to import - custom class definitions from calls to register_custom_serializer that + custom class definitions from calls to _register_custom_serializer that happen under the hood on workers. Attributes: diff --git a/python/ray/worker.py b/python/ray/worker.py index dfc09726d..0923d2029 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -327,39 +327,35 @@ class Worker(object): memcopy_threads=self.memcopy_threads) break except pyarrow.SerializationCallbackError as e: + cls_type = type(e.example_object) try: - register_custom_serializer( - type(e.example_object), use_dict=True) - warning_message = ("WARNING: Serializing objects of type " - "{} by expanding them as dictionaries " - "of their fields. This behavior may " - "be incorrect in some cases.".format( - type(e.example_object))) + _register_custom_serializer(cls_type, use_dict=True) + warning_message = ( + "WARNING: Serializing objects of type " + "{} by expanding them as dictionaries " + "of their fields. This behavior may " + "be incorrect in some cases.".format(cls_type)) logger.debug(warning_message) except (serialization.RayNotDictionarySerializable, serialization.CloudPickleError, pickle.pickle.PicklingError, Exception): # We also handle generic exceptions here because # cloudpickle can fail with many different types of errors. + warning_message = ( + "Falling back to serializing {} objects by using " + "pickle. 
Use `ray.register_custom_serializer({},...)` " + "to provide faster serialization.".format( + cls_type, cls_type)) try: - register_custom_serializer( - type(e.example_object), use_pickle=True) - warning_message = ("WARNING: Falling back to " - "serializing objects of type {} by " - "using pickle. This may be " - "inefficient.".format( - type(e.example_object))) + _register_custom_serializer(cls_type, use_pickle=True) logger.warning(warning_message) - except serialization.CloudPickleError: - register_custom_serializer( - type(e.example_object), - use_pickle=True, - local=True) + except (serialization.CloudPickleError, ValueError): + _register_custom_serializer( + cls_type, use_pickle=True, local=True) warning_message = ("WARNING: Pickling the class {} " "failed, so we are using pickle " "and only registering the class " - "locally.".format( - type(e.example_object))) + "locally.".format(cls_type)) logger.warning(warning_message) def put_object(self, object_id, value): @@ -470,7 +466,7 @@ class Worker(object): # may not be serializable for cloudpickle. So we need # these extra fallbacks here to start from the beginning. # Hopefully the object could have a `__reduce__` method. - register_custom_serializer(type(value), use_pickle=True) + _register_custom_serializer(type(value), use_pickle=True) warning_message = ("WARNING: Serializing the class {} failed, " "falling back to cloudpickle.".format( type(value))) @@ -1085,7 +1081,7 @@ def _initialize_serialization(job_id, worker=global_worker): return new_handle # We register this serializer on each worker instead of calling - # register_custom_serializer from the driver so that isinstance still + # _register_custom_serializer from the driver so that isinstance still # works. 
serialization_context.register_type( ray.actor.ActorHandle, @@ -1098,7 +1094,7 @@ def _initialize_serialization(job_id, worker=global_worker): if not worker.use_pickle: for error_cls in RAY_EXCEPTION_TYPES: - register_custom_serializer( + _register_custom_serializer( error_cls, use_dict=True, local=True, @@ -1106,14 +1102,14 @@ def _initialize_serialization(job_id, worker=global_worker): class_id=error_cls.__module__ + ". " + error_cls.__name__, ) # Tell Ray to serialize lambdas with pickle. - register_custom_serializer( + _register_custom_serializer( type(lambda: 0), use_pickle=True, local=True, job_id=job_id, class_id="lambda") # Tell Ray to serialize types with pickle. - register_custom_serializer( + _register_custom_serializer( type(int), use_pickle=True, local=True, @@ -1121,7 +1117,7 @@ def _initialize_serialization(job_id, worker=global_worker): class_id="type") # Tell Ray to serialize FunctionSignatures as dictionaries. This is # used when passing around actor handles. - register_custom_serializer( + _register_custom_serializer( ray.signature.FunctionSignature, use_dict=True, local=True, @@ -1131,7 +1127,7 @@ def _initialize_serialization(job_id, worker=global_worker): # Ray's default __dict__ serialization is incorrect for this type # (the object's __dict__ is empty and therefore doesn't # contain the full state of the object). - register_custom_serializer( + _register_custom_serializer( io.StringIO, use_pickle=True, local=True, @@ -1981,13 +1977,73 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): def register_custom_serializer(cls, - use_pickle=False, - use_dict=False, serializer=None, deserializer=None, - local=False, + use_pickle=False, + use_dict=False, + local=None, job_id=None, class_id=None): + """Registers custom functions for efficient object serialization. + + The serializer and deserializer are used when transferring objects of + `cls` across processes and nodes. This can be significantly faster than + the Ray default fallbacks. 
Wraps `_register_custom_serializer` underneath. + + `use_pickle` tells Ray to automatically use cloudpickle for serialization, + and `use_dict` automatically uses `cls.__dict__`. + + When calling this function, you can only provide one of the following: + + 1. serializer and deserializer + 2. `use_pickle` + 3. `use_dict` + + Args: + cls (type): The class that ray should use this custom serializer for. + serializer: The custom serializer that takes in a cls instance and + outputs a serialized representation. use_pickle and use_dict + must be False if provided. + deserializer: The custom deserializer that takes in a serialized + representation of the cls and outputs a cls instance. use_pickle + and use_dict must be False if provided. + use_pickle (bool): If true, objects of this class will be + serialized using pickle. Must be False if + use_dict is true. + use_dict (bool): If true, objects of this class will be serialized by turning + their __dict__ fields into a dictionary. Must be False if + use_pickle is true. + local: Deprecated. + job_id: Deprecated. + class_id (str): Unique ID of the class. Autogenerated if None. + """ + if job_id: + raise DeprecationWarning( + "`job_id` is no longer a valid parameter and will be removed in " + "future versions of Ray. If this breaks your application, " + "see `ray.worker._register_custom_serializer`.") + if local: + raise DeprecationWarning( + "`local` is no longer a valid parameter and will be removed in " + "future versions of Ray. If this breaks your application, " + "see `ray.worker._register_custom_serializer`.") + _register_custom_serializer( + cls, + use_pickle=use_pickle, + use_dict=use_dict, + serializer=serializer, + deserializer=deserializer, + class_id=class_id) + + +def _register_custom_serializer(cls, + use_pickle=False, + use_dict=False, + serializer=None, + deserializer=None, + local=False, + job_id=None, + class_id=None): """Enable serialization and deserialization for a particular class. 
This method runs the register_class function defined below on every worker, @@ -2008,13 +2064,12 @@ def register_custom_serializer(cls, local: True if the serializers should only be registered on the current worker. This should usually be False. job_id: ID of the job that we want to register the class for. - class_id: ID of the class that we are registering. If this is not - specified, we will calculate a new one inside the function. + class_id (str): Unique ID of the class. Autogenerated if None. Raises: - Exception: An exception is raised if pickle=False and the class cannot - be efficiently serialized by Ray. This can also raise an exception - if use_dict is true and cls is not pickleable. + RayNotDictionarySerializable: Raised if use_dict is true and cls cannot + be efficiently serialized by Ray. + ValueError: Raised if ray could not autogenerate a class_id. """ worker = global_worker assert (serializer is None) == (deserializer is None), ( @@ -2045,8 +2100,9 @@ def register_custom_serializer(cls, # result may be different on different workers. class_id = _try_to_compute_deterministic_class_id(cls) except Exception: - raise serialization.CloudPickleError("Failed to pickle class " - "'{}'".format(cls)) + raise ValueError( + "Failed to use pickle in generating a unique id for '{}'. " + "Provide a unique class_id.".format(cls)) else: # In this case, the class ID only needs to be meaningful on this # worker and not across workers.