Deserialize Args in Event Loop Thread (#7806)

This commit is contained in:
Simon Mo
2020-03-30 18:28:13 -07:00
committed by GitHub
parent f86e623095
commit dc9b62e007
3 changed files with 53 additions and 32 deletions
+39 -21
View File
@@ -297,22 +297,6 @@ cdef void prepare_args(
CTaskArg.PassByReference((CObjectID.FromBinary(
core_worker.put_serialized_object(serialized_arg)))))
cdef deserialize_args(
const c_vector[shared_ptr[CRayObject]] &c_args,
const c_vector[CObjectID] &arg_reference_ids):
if c_args.empty():
return [], {}
args = ray.worker.global_worker.deserialize_objects(
RayObjectsToDataMetadataPairs(c_args),
VectorToObjectIDs(arg_reference_ids))
for arg in args:
if isinstance(arg, RayError):
raise arg
return ray.signature.recover_args(args)
cdef execute_task(
CTaskType task_type,
@@ -331,7 +315,7 @@ cdef execute_task(
CoreWorker core_worker = worker.core_worker
JobID job_id = core_worker.get_current_job_id()
CTaskID task_id = core_worker.core_worker.get().GetCurrentTaskId()
CFiberEvent fiber_event
CFiberEvent task_done_event
# Automatically restrict the GPUs available to this task.
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
@@ -410,13 +394,13 @@ cdef execute_task(
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
def callback(future):
fiber_event.Notify()
task_done_event.Notify()
monitor_state.unregister_coroutine(coroutine)
future.add_done_callback(callback)
with nogil:
(core_worker.core_worker.get()
.YieldCurrentFiber(fiber_event))
.YieldCurrentFiber(task_done_event))
return future.result()
@@ -431,7 +415,30 @@ cdef execute_task(
worker.memory_monitor.raise_if_low_memory()
with core_worker.profile_event(b"task:deserialize_arguments"):
args, kwargs = deserialize_args(c_args, c_arg_reference_ids)
if c_args.empty():
args, kwargs = [], {}
else:
metadata_pairs = RayObjectsToDataMetadataPairs(c_args)
object_ids = VectorToObjectIDs(c_arg_reference_ids)
if core_worker.current_actor_is_asyncio():
# We deserialize objects in event loop thread to
# prevent segfaults. See #7799
def deserialize_args():
return (ray.worker.global_worker
.deserialize_objects(
metadata_pairs, object_ids))
args = core_worker.run_function_in_event_loop(
deserialize_args)
else:
args = ray.worker.global_worker.deserialize_objects(
metadata_pairs, object_ids)
for arg in args:
if isinstance(arg, RayError):
raise arg
args, kwargs = ray.signature.recover_args(args)
if (<int>task_type == <int>TASK_TYPE_ACTOR_CREATION_TASK):
actor = worker.actors[core_worker.get_actor_id()]
class_name = actor.__class__.__name__
@@ -446,7 +453,6 @@ cdef execute_task(
task_exception = False
if c_return_ids.size() == 1:
outputs = (outputs,)
# Store the outputs in the object store.
with core_worker.profile_event(b"task:store_outputs"):
core_worker.store_task_outputs(
@@ -1107,6 +1113,18 @@ cdef class CoreWorker:
return self.async_event_loop
def run_function_in_event_loop(self, func):
cdef:
CFiberEvent event
loop = self.create_or_get_event_loop()
coroutine = sync_to_async(func)()
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
future.add_done_callback(lambda _: event.Notify())
with nogil:
(self.core_worker.get()
.YieldCurrentFiber(event))
return future.result()
def destory_event_loop_if_exists(self):
if self.async_event_loop is not None:
self.async_event_loop.stop()
+10 -10
View File
@@ -37,7 +37,7 @@ async def test_runner_actor(serve_instance):
PRODUCER_NAME = "prod"
runner = TaskRunnerActor.remote(echo)
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -68,7 +68,7 @@ async def test_ray_serve_mixin(serve_instance):
runner = CustomActor.remote(3)
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -92,7 +92,7 @@ async def test_task_runner_check_context(serve_instance):
runner = TaskRunnerActor.remote(echo)
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -122,7 +122,7 @@ async def test_task_runner_custom_method_single(serve_instance):
runner = CustomActor.remote()
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -163,11 +163,11 @@ async def test_task_runner_custom_method_batch(serve_instance):
runner = CustomActor.remote()
runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
q.set_backend_config.remote(
CONSUMER_NAME, BackendConfig(max_batch_size=2).__dict__)
await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
await q.set_backend_config.remote(
CONSUMER_NAME, BackendConfig(max_batch_size=10).__dict__)
a_query_param = RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
@@ -177,7 +177,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
runner._ray_serve_fetch.remote()
await runner._ray_serve_fetch.remote()
gathered = await asyncio.gather(*futures)
assert gathered == ["a-0", "a-1", "b-0", "b-1"]
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}