mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 01:40:21 +08:00
Use different serialization context for each driver. (#2406)
This commit is contained in:
@@ -157,13 +157,13 @@ class ImportThread(object):
|
||||
|
||||
def fetch_and_execute_function_to_run(self, key):
|
||||
"""Run an arbitrary function on the worker."""
|
||||
driver_id, serialized_function = self.redis_client.hmget(
|
||||
key, ["driver_id", "function"])
|
||||
(driver_id, serialized_function,
|
||||
run_on_other_drivers) = self.redis_client.hmget(
|
||||
key, ["driver_id", "function", "run_on_other_drivers"])
|
||||
|
||||
if (self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
|
||||
if (run_on_other_drivers == "False"
|
||||
and self.worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]
|
||||
and driver_id != self.worker.task_driver_id.id()):
|
||||
# This export was from a different driver and there's no need for
|
||||
# this driver to import it.
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
+113
-49
@@ -245,6 +245,25 @@ class Worker(object):
|
||||
self.original_gpu_ids = ray.utils.get_cuda_visible_devices()
|
||||
self.profiler = profiling.Profiler(self)
|
||||
self.state_lock = threading.Lock()
|
||||
# A dictionary that maps from driver id to SerializationContext
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = None
|
||||
|
||||
def get_serialization_context(self, driver_id):
|
||||
"""Get the SerializationContext of the driver that this worker is processing.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver that indicates which driver to get
|
||||
the serialization context for.
|
||||
|
||||
Returns:
|
||||
The serialization context of the given driver.
|
||||
"""
|
||||
if driver_id not in self.serialization_context_map:
|
||||
_initialize_serialization(driver_id)
|
||||
return self.serialization_context_map[driver_id]
|
||||
|
||||
def check_connected(self):
|
||||
"""Check if the worker is connected.
|
||||
@@ -308,7 +327,8 @@ class Worker(object):
|
||||
value,
|
||||
object_id=pyarrow.plasma.ObjectID(object_id.id()),
|
||||
memcopy_threads=self.memcopy_threads,
|
||||
serialization_context=self.serialization_context)
|
||||
serialization_context=self.get_serialization_context(
|
||||
self.task_driver_id))
|
||||
break
|
||||
except pyarrow.SerializationCallbackError as e:
|
||||
try:
|
||||
@@ -400,7 +420,8 @@ class Worker(object):
|
||||
results += self.plasma_client.get(
|
||||
object_ids[i:(
|
||||
i + ray._config.worker_get_request_size())],
|
||||
timeout, self.serialization_context)
|
||||
timeout,
|
||||
self.get_serialization_context(self.task_driver_id))
|
||||
return results
|
||||
except pyarrow.lib.ArrowInvalid:
|
||||
# TODO(ekl): the local scheduler could include relevant
|
||||
@@ -690,7 +711,8 @@ class Worker(object):
|
||||
})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
|
||||
def run_function_on_all_workers(self, function):
|
||||
def run_function_on_all_workers(self, function,
|
||||
run_on_other_drivers=False):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
|
||||
This function will first be run on the driver, and then it will be
|
||||
@@ -702,6 +724,9 @@ class Worker(object):
|
||||
function (Callable): The function to run on all of the workers. It
|
||||
should not take any arguments. If it returns anything, its
|
||||
return values will not be used.
|
||||
run_on_other_drivers: The boolean that indicates whether we want to
|
||||
run this function on other drivers. One case is we may need to
|
||||
share objects across drivers.
|
||||
"""
|
||||
# If ray.init has not been called yet, then cache the function and
|
||||
# export it when connect is called. Otherwise, run the function on all
|
||||
@@ -734,7 +759,8 @@ class Worker(object):
|
||||
key, {
|
||||
"driver_id": self.task_driver_id.id(),
|
||||
"function_id": function_to_run_id,
|
||||
"function": pickled_function
|
||||
"function": pickled_function,
|
||||
"run_on_other_drivers": run_on_other_drivers
|
||||
})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
# TODO(rkn): If the worker fails after it calls setnx and before it
|
||||
@@ -1209,17 +1235,17 @@ def error_info(worker=global_worker):
|
||||
return errors
|
||||
|
||||
|
||||
def _initialize_serialization(worker=global_worker):
|
||||
def _initialize_serialization(driver_id, worker=global_worker):
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells ray to
|
||||
serialize several exception classes that we define for error handling.
|
||||
"""
|
||||
worker.serialization_context = pyarrow.default_serialization_context()
|
||||
serialization_context = pyarrow.default_serialization_context()
|
||||
# Tell the serialization context to use the cloudpickle version that we
|
||||
# ship with Ray.
|
||||
worker.serialization_context.set_pickle(pickle.dumps, pickle.loads)
|
||||
pyarrow.register_torch_serialization_handlers(worker.serialization_context)
|
||||
serialization_context.set_pickle(pickle.dumps, pickle.loads)
|
||||
pyarrow.register_torch_serialization_handlers(serialization_context)
|
||||
|
||||
# Define a custom serializer and deserializer for handling Object IDs.
|
||||
def object_id_custom_serializer(obj):
|
||||
@@ -1231,7 +1257,7 @@ def _initialize_serialization(worker=global_worker):
|
||||
# We register this serializer on each worker instead of calling
|
||||
# register_custom_serializer from the driver so that isinstance still
|
||||
# works.
|
||||
worker.serialization_context.register_type(
|
||||
serialization_context.register_type(
|
||||
ray.ObjectID,
|
||||
"ray.ObjectID",
|
||||
pickle=False,
|
||||
@@ -1249,28 +1275,55 @@ def _initialize_serialization(worker=global_worker):
|
||||
# We register this serializer on each worker instead of calling
|
||||
# register_custom_serializer from the driver so that isinstance still
|
||||
# works.
|
||||
worker.serialization_context.register_type(
|
||||
serialization_context.register_type(
|
||||
ray.actor.ActorHandle,
|
||||
"ray.ActorHandle",
|
||||
pickle=False,
|
||||
custom_serializer=actor_handle_serializer,
|
||||
custom_deserializer=actor_handle_deserializer)
|
||||
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
# These should only be called on the driver because
|
||||
# register_custom_serializer will export the class to all of the
|
||||
# workers.
|
||||
register_custom_serializer(RayTaskError, use_dict=True)
|
||||
register_custom_serializer(RayGetError, use_dict=True)
|
||||
register_custom_serializer(RayGetArgumentError, use_dict=True)
|
||||
# Tell Ray to serialize lambdas with pickle.
|
||||
register_custom_serializer(type(lambda: 0), use_pickle=True)
|
||||
# Tell Ray to serialize types with pickle.
|
||||
register_custom_serializer(type(int), use_pickle=True)
|
||||
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
|
||||
# used when passing around actor handles.
|
||||
register_custom_serializer(
|
||||
ray.signature.FunctionSignature, use_dict=True)
|
||||
worker.serialization_context_map[driver_id] = serialization_context
|
||||
|
||||
register_custom_serializer(
|
||||
RayTaskError,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="ray.RayTaskError")
|
||||
register_custom_serializer(
|
||||
RayGetError,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="ray.RayGetError")
|
||||
register_custom_serializer(
|
||||
RayGetArgumentError,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="ray.RayGetArgumentError")
|
||||
# Tell Ray to serialize lambdas with pickle.
|
||||
register_custom_serializer(
|
||||
type(lambda: 0),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="lambda")
|
||||
# Tell Ray to serialize types with pickle.
|
||||
register_custom_serializer(
|
||||
type(int),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="type")
|
||||
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
|
||||
# used when passing around actor handles.
|
||||
register_custom_serializer(
|
||||
ray.signature.FunctionSignature,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="ray.signature.FunctionSignature")
|
||||
|
||||
|
||||
def get_address_info_from_redis_helper(redis_address,
|
||||
@@ -2167,10 +2220,6 @@ def connect(info,
|
||||
# driver task.
|
||||
worker.current_task_id = driver_task.task_id()
|
||||
|
||||
# Initialize the serialization library. This registers some classes, and so
|
||||
# it must be run before we export all of the cached remote functions.
|
||||
_initialize_serialization()
|
||||
|
||||
# Start the import thread
|
||||
import_thread.ImportThread(worker, mode).start()
|
||||
|
||||
@@ -2242,7 +2291,7 @@ def disconnect(worker=global_worker):
|
||||
worker.connected = False
|
||||
worker.cached_functions_to_run = []
|
||||
worker.cached_remote_functions_and_actors = []
|
||||
worker.serialization_context = pyarrow.SerializationContext()
|
||||
worker.serialization_context_map.clear()
|
||||
|
||||
|
||||
def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
@@ -2293,6 +2342,8 @@ def register_custom_serializer(cls,
|
||||
serializer=None,
|
||||
deserializer=None,
|
||||
local=False,
|
||||
driver_id=None,
|
||||
class_id=None,
|
||||
worker=global_worker):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
@@ -2313,6 +2364,9 @@ def register_custom_serializer(cls,
|
||||
if and only if use_pickle and use_dict are False.
|
||||
local: True if the serializers should only be registered on the current
|
||||
worker. This should usually be False.
|
||||
driver_id: ID of the driver that we want to register the class for.
|
||||
class_id: ID of the class that we are registering. If this is not
|
||||
specified, we will calculate a new one inside the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if pickle=False and the class cannot
|
||||
@@ -2332,25 +2386,32 @@ def register_custom_serializer(cls,
|
||||
# Raise an exception if cls cannot be serialized efficiently by Ray.
|
||||
serialization.check_serializable(cls)
|
||||
|
||||
if not local:
|
||||
# In this case, the class ID will be used to deduplicate the class
|
||||
# across workers. Note that cloudpickle unfortunately does not produce
|
||||
# deterministic strings, so these IDs could be different on different
|
||||
# workers. We could use something weaker like cls.__name__, however
|
||||
# that would run the risk of having collisions. TODO(rkn): We should
|
||||
# improve this.
|
||||
try:
|
||||
# Attempt to produce a class ID that will be the same on each
|
||||
# worker. However, determinism is not guaranteed, and the result
|
||||
# may be different on different workers.
|
||||
class_id = _try_to_compute_deterministic_class_id(cls)
|
||||
except Exception:
|
||||
raise serialization.CloudPickleError("Failed to pickle class "
|
||||
"'{}'".format(cls))
|
||||
if class_id is None:
|
||||
if not local:
|
||||
# In this case, the class ID will be used to deduplicate the class
|
||||
# across workers. Note that cloudpickle unfortunately does not
|
||||
# produce deterministic strings, so these IDs could be different
|
||||
# on different workers. We could use something weaker like
|
||||
# cls.__name__, however that would run the risk of having
|
||||
# collisions.
|
||||
# TODO(rkn): We should improve this.
|
||||
try:
|
||||
# Attempt to produce a class ID that will be the same on each
|
||||
# worker. However, determinism is not guaranteed, and the
|
||||
# result may be different on different workers.
|
||||
class_id = _try_to_compute_deterministic_class_id(cls)
|
||||
except Exception as e:
|
||||
raise serialization.CloudPickleError("Failed to pickle class "
|
||||
"'{}'".format(cls))
|
||||
else:
|
||||
# In this case, the class ID only needs to be meaningful on this
|
||||
# worker and not across workers.
|
||||
class_id = random_string()
|
||||
|
||||
if driver_id is None:
|
||||
driver_id_bytes = worker.task_driver_id.id()
|
||||
else:
|
||||
# In this case, the class ID only needs to be meaningful on this worker
|
||||
# and not across workers.
|
||||
class_id = random_string()
|
||||
driver_id_bytes = driver_id.id()
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
# TODO(rkn): We need to be more thoughtful about what to do if custom
|
||||
@@ -2358,7 +2419,10 @@ def register_custom_serializer(cls,
|
||||
# we may want to use the last user-defined serializers and ignore
|
||||
# subsequent calls to register_custom_serializer that were made by the
|
||||
# system.
|
||||
worker_info["worker"].serialization_context.register_type(
|
||||
|
||||
serialization_context = worker_info[
|
||||
"worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
|
||||
serialization_context.register_type(
|
||||
cls,
|
||||
class_id,
|
||||
pickle=use_pickle,
|
||||
|
||||
Reference in New Issue
Block a user