mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:31:15 +08:00
Deserialize Args in Event Loop Thread (#7806)
This commit is contained in:
+39
-21
@@ -297,22 +297,6 @@ cdef void prepare_args(
|
||||
CTaskArg.PassByReference((CObjectID.FromBinary(
|
||||
core_worker.put_serialized_object(serialized_arg)))))
|
||||
|
||||
cdef deserialize_args(
|
||||
const c_vector[shared_ptr[CRayObject]] &c_args,
|
||||
const c_vector[CObjectID] &arg_reference_ids):
|
||||
if c_args.empty():
|
||||
return [], {}
|
||||
|
||||
args = ray.worker.global_worker.deserialize_objects(
|
||||
RayObjectsToDataMetadataPairs(c_args),
|
||||
VectorToObjectIDs(arg_reference_ids))
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, RayError):
|
||||
raise arg
|
||||
|
||||
return ray.signature.recover_args(args)
|
||||
|
||||
|
||||
cdef execute_task(
|
||||
CTaskType task_type,
|
||||
@@ -331,7 +315,7 @@ cdef execute_task(
|
||||
CoreWorker core_worker = worker.core_worker
|
||||
JobID job_id = core_worker.get_current_job_id()
|
||||
CTaskID task_id = core_worker.core_worker.get().GetCurrentTaskId()
|
||||
CFiberEvent fiber_event
|
||||
CFiberEvent task_done_event
|
||||
|
||||
# Automatically restrict the GPUs available to this task.
|
||||
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
|
||||
@@ -410,13 +394,13 @@ cdef execute_task(
|
||||
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
|
||||
def callback(future):
|
||||
fiber_event.Notify()
|
||||
task_done_event.Notify()
|
||||
monitor_state.unregister_coroutine(coroutine)
|
||||
|
||||
future.add_done_callback(callback)
|
||||
with nogil:
|
||||
(core_worker.core_worker.get()
|
||||
.YieldCurrentFiber(fiber_event))
|
||||
.YieldCurrentFiber(task_done_event))
|
||||
|
||||
return future.result()
|
||||
|
||||
@@ -431,7 +415,30 @@ cdef execute_task(
|
||||
worker.memory_monitor.raise_if_low_memory()
|
||||
|
||||
with core_worker.profile_event(b"task:deserialize_arguments"):
|
||||
args, kwargs = deserialize_args(c_args, c_arg_reference_ids)
|
||||
if c_args.empty():
|
||||
args, kwargs = [], {}
|
||||
else:
|
||||
metadata_pairs = RayObjectsToDataMetadataPairs(c_args)
|
||||
object_ids = VectorToObjectIDs(c_arg_reference_ids)
|
||||
|
||||
if core_worker.current_actor_is_asyncio():
|
||||
# We deserialize objects in event loop thread to
|
||||
# prevent segfaults. See #7799
|
||||
def deserialize_args():
|
||||
return (ray.worker.global_worker
|
||||
.deserialize_objects(
|
||||
metadata_pairs, object_ids))
|
||||
args = core_worker.run_function_in_event_loop(
|
||||
deserialize_args)
|
||||
else:
|
||||
args = ray.worker.global_worker.deserialize_objects(
|
||||
metadata_pairs, object_ids)
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, RayError):
|
||||
raise arg
|
||||
args, kwargs = ray.signature.recover_args(args)
|
||||
|
||||
if (<int>task_type == <int>TASK_TYPE_ACTOR_CREATION_TASK):
|
||||
actor = worker.actors[core_worker.get_actor_id()]
|
||||
class_name = actor.__class__.__name__
|
||||
@@ -446,7 +453,6 @@ cdef execute_task(
|
||||
task_exception = False
|
||||
if c_return_ids.size() == 1:
|
||||
outputs = (outputs,)
|
||||
|
||||
# Store the outputs in the object store.
|
||||
with core_worker.profile_event(b"task:store_outputs"):
|
||||
core_worker.store_task_outputs(
|
||||
@@ -1107,6 +1113,18 @@ cdef class CoreWorker:
|
||||
|
||||
return self.async_event_loop
|
||||
|
||||
def run_function_in_event_loop(self, func):
|
||||
cdef:
|
||||
CFiberEvent event
|
||||
loop = self.create_or_get_event_loop()
|
||||
coroutine = sync_to_async(func)()
|
||||
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
future.add_done_callback(lambda _: event.Notify())
|
||||
with nogil:
|
||||
(self.core_worker.get()
|
||||
.YieldCurrentFiber(event))
|
||||
return future.result()
|
||||
|
||||
def destory_event_loop_if_exists(self):
|
||||
if self.async_event_loop is not None:
|
||||
self.async_event_loop.stop()
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_runner_actor(serve_instance):
|
||||
PRODUCER_NAME = "prod"
|
||||
|
||||
runner = TaskRunnerActor.remote(echo)
|
||||
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
|
||||
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
|
||||
runner._ray_serve_fetch.remote()
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
@@ -68,7 +68,7 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
|
||||
runner = CustomActor.remote(3)
|
||||
|
||||
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
|
||||
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
|
||||
runner._ray_serve_fetch.remote()
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
@@ -92,7 +92,7 @@ async def test_task_runner_check_context(serve_instance):
|
||||
|
||||
runner = TaskRunnerActor.remote(echo)
|
||||
|
||||
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
|
||||
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
|
||||
runner._ray_serve_fetch.remote()
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
@@ -122,7 +122,7 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
|
||||
runner = CustomActor.remote()
|
||||
|
||||
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
|
||||
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
|
||||
runner._ray_serve_fetch.remote()
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
@@ -163,11 +163,11 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
|
||||
runner = CustomActor.remote()
|
||||
|
||||
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
|
||||
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
q.set_backend_config.remote(
|
||||
CONSUMER_NAME, BackendConfig(max_batch_size=2).__dict__)
|
||||
await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
await q.set_backend_config.remote(
|
||||
CONSUMER_NAME, BackendConfig(max_batch_size=10).__dict__)
|
||||
|
||||
a_query_param = RequestMetadata(
|
||||
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
|
||||
@@ -177,7 +177,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
|
||||
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
|
||||
|
||||
runner._ray_serve_fetch.remote()
|
||||
await runner._ray_serve_fetch.remote()
|
||||
|
||||
gathered = await asyncio.gather(*futures)
|
||||
assert gathered == ["a-0", "a-1", "b-0", "b-1"]
|
||||
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
|
||||
|
||||
Reference in New Issue
Block a user