diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index 6267a2321..ee339ead2 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -140,6 +140,7 @@ test_python() { python/ray/serve/... python/ray/tests/... -python/ray/serve:test_api # segfault on windows? https://github.com/ray-project/ray/issues/12541 + -python/ray/serve:test_handle # "fatal error" (?) https://github.com/ray-project/ray/pull/13695 -python/ray/tests:test_actor_advanced # timeout -python/ray/tests:test_advanced_2 -python/ray/tests:test_advanced_3 # test_invalid_unicode_in_worker_log() fails on Windows diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 2e0490631..4c0a0a91f 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -66,6 +66,8 @@ def _ensure_connected(f: Callable) -> Callable: class ThreadProxiedRouter: def __init__(self, controller_handle, sync: bool): + self.controller_handle = controller_handle + self.sync = sync self.router = Router(controller_handle) if sync: @@ -92,6 +94,11 @@ class ThreadProxiedRouter: **kwargs) return coro + def __reduce__(self): + deserializer = ThreadProxiedRouter + serialized_data = (self.controller_handle, self.sync) + return deserializer, serialized_data + class Client: def __init__(self, diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 475f64556..3659e5978 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -6,7 +6,6 @@ from enum import Enum from ray.serve.utils import get_random_letters from ray.util import metrics -from ray.serve.router import Router @dataclass(frozen=True) @@ -42,10 +41,11 @@ class RayServeHandle: # raises RayTaskError Exception """ - def __init__(self, - router: Router, - endpoint_name, - handle_options: Optional[HandleOptions] = None): + def __init__( + self, + router, # ThreadProxiedRouter + endpoint_name, + handle_options: Optional[HandleOptions] = None): self.router = router self.endpoint_name = endpoint_name self.handle_options = handle_options or HandleOptions() @@ -91,7 +91,7 @@ class RayServeHandle: async def remote(self, request_data: Optional[Union[Dict, Any]] = None, **kwargs): - """Issue an asynchrounous request to the endpoint. + """Issue an asynchronous request to the endpoint. Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get (or ``await object_ref``), respectively. @@ -112,6 +112,12 @@ class RayServeHandle: def __repr__(self): return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')" + def __reduce__(self): + deserializer = RayServeHandle + serialized_data = (self.router, self.endpoint_name, + self.handle_options) + return deserializer, serialized_data + class RayServeSyncHandle(RayServeHandle): def remote(self, request_data: Optional[Union[Dict, Any]] = None, @@ -138,3 +144,9 @@ class RayServeSyncHandle(RayServeHandle): future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( coro, self.router.async_loop) return future.result() + + def __reduce__(self): + deserializer = RayServeSyncHandle + serialized_data = (self.router, self.endpoint_name, + self.handle_options) + return deserializer, serialized_data diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index c17db7686..88ab9d2c2 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -1,9 +1,51 @@ import requests - +import pytest import ray from ray import serve +@pytest.mark.asyncio +async def test_async_handle_serializable(serve_instance): + client = serve_instance + + def f(_): + return "hello" + + client.create_backend("f", f) + client.create_endpoint("f", backend="f") + + @ray.remote + class TaskActor: + async def task(self, handle): + ref = await handle.remote() + output = await ref + return output + + handle = client.get_handle("f", sync=False) + + task_actor = TaskActor.remote() + result = await task_actor.task.remote(handle) + assert result == "hello" + + +def test_sync_handle_serializable(serve_instance): + client = serve_instance + + def f(_): + return "hello" + + client.create_backend("f", f) + client.create_endpoint("f", backend="f") + + @ray.remote + def task(handle): + return ray.get(handle.remote()) + + handle = client.get_handle("f", sync=True) + result_ref = task.remote(handle) + assert ray.get(result_ref) == "hello" + + def test_handle_in_endpoint(serve_instance): client = serve_instance