From 202fbdf38c48f7db54994e7143232a75490c9fdb Mon Sep 17 00:00:00 2001 From: architkulkarni Date: Wed, 27 Jan 2021 12:11:31 -0800 Subject: [PATCH] [Serve] Fix ServeHandle serialization (#13695) --- python/ray/serve/api.py | 7 +++++ python/ray/serve/handle.py | 25 ++++++++++----- python/ray/serve/tests/test_handle.py | 44 ++++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index b42cd7846..19783dc37 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -66,6 +66,8 @@ def _ensure_connected(f: Callable) -> Callable: class ThreadProxiedRouter: def __init__(self, controller_handle, sync: bool): + self.controller_handle = controller_handle + self.sync = sync self.router = Router(controller_handle) if sync: @@ -92,6 +94,11 @@ class ThreadProxiedRouter: **kwargs) return coro + def __reduce__(self): + deserializer = ThreadProxiedRouter + serialized_data = (self.controller_handle, self.sync) + return deserializer, serialized_data + class Client: def __init__(self, diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index c6951c638..4ee2624a8 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -4,8 +4,6 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional, Union from enum import Enum -from ray.serve.router import Router - @dataclass(frozen=True) class HandleOptions: @@ -40,10 +38,11 @@ class RayServeHandle: # raises RayTaskError Exception """ - def __init__(self, - router: Router, - endpoint_name, - handle_options: Optional[HandleOptions] = None): + def __init__( + self, + router, # ThreadProxiedRouter + endpoint_name, + handle_options: Optional[HandleOptions] = None): self.router = router self.endpoint_name = endpoint_name self.handle_options = handle_options or HandleOptions() @@ -78,7 +77,7 @@ class RayServeHandle: async def remote(self, request_data: Optional[Union[Dict, Any]] = None, **kwargs): - """Issue an asynchrounous request to the endpoint. + """Issue an asynchronous request to the endpoint. Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get (or ``await object_ref``), respectively. @@ -98,6 +97,12 @@ class RayServeHandle: def __repr__(self): return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')" + def __reduce__(self): + deserializer = RayServeHandle + serialized_data = (self.router, self.endpoint_name, + self.handle_options) + return deserializer, serialized_data + class RayServeSyncHandle(RayServeHandle): def remote(self, request_data: Optional[Union[Dict, Any]] = None, @@ -123,3 +128,9 @@ class RayServeSyncHandle(RayServeHandle): future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( coro, self.router.async_loop) return future.result() + + def __reduce__(self): + deserializer = RayServeSyncHandle + serialized_data = (self.router, self.endpoint_name, + self.handle_options) + return deserializer, serialized_data diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index c17db7686..88ab9d2c2 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -1,9 +1,51 @@ import requests - +import pytest import ray from ray import serve +@pytest.mark.asyncio +async def test_async_handle_serializable(serve_instance): + client = serve_instance + + def f(_): + return "hello" + + client.create_backend("f", f) + client.create_endpoint("f", backend="f") + + @ray.remote + class TaskActor: + async def task(self, handle): + ref = await handle.remote() + output = await ref + return output + + handle = client.get_handle("f", sync=False) + + task_actor = TaskActor.remote() + result = await task_actor.task.remote(handle) + assert result == "hello" + + +def test_sync_handle_serializable(serve_instance): + client = serve_instance + + def f(_): + return "hello" + + client.create_backend("f", f) + client.create_endpoint("f", backend="f") + + @ray.remote + def task(handle): + return ray.get(handle.remote()) + + handle = client.get_handle("f", sync=True) + result_ref = task.remote(handle) + assert ray.get(result_ref) == "hello" + + def test_handle_in_endpoint(serve_instance): client = serve_instance