mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[Serve] Revert "Revert "[Serve] Fix ServeHandle serialization"" and disable failing Windows test (#13771)
This commit is contained in:
@@ -140,6 +140,7 @@ test_python() {
|
||||
python/ray/serve/...
|
||||
python/ray/tests/...
|
||||
-python/ray/serve:test_api # segfault on windows? https://github.com/ray-project/ray/issues/12541
|
||||
-python/ray/serve:test_handle # "fatal error" (?) https://github.com/ray-project/ray/pull/13695
|
||||
-python/ray/tests:test_actor_advanced # timeout
|
||||
-python/ray/tests:test_advanced_2
|
||||
-python/ray/tests:test_advanced_3 # test_invalid_unicode_in_worker_log() fails on Windows
|
||||
|
||||
@@ -66,6 +66,8 @@ def _ensure_connected(f: Callable) -> Callable:
|
||||
|
||||
class ThreadProxiedRouter:
|
||||
def __init__(self, controller_handle, sync: bool):
|
||||
self.controller_handle = controller_handle
|
||||
self.sync = sync
|
||||
self.router = Router(controller_handle)
|
||||
|
||||
if sync:
|
||||
@@ -92,6 +94,11 @@ class ThreadProxiedRouter:
|
||||
**kwargs)
|
||||
return coro
|
||||
|
||||
def __reduce__(self):
|
||||
deserializer = ThreadProxiedRouter
|
||||
serialized_data = (self.controller_handle, self.sync)
|
||||
return deserializer, serialized_data
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self,
|
||||
|
||||
@@ -6,7 +6,6 @@ from enum import Enum
|
||||
|
||||
from ray.serve.utils import get_random_letters
|
||||
from ray.util import metrics
|
||||
from ray.serve.router import Router
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -42,10 +41,11 @@ class RayServeHandle:
|
||||
# raises RayTaskError Exception
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
router: Router,
|
||||
endpoint_name,
|
||||
handle_options: Optional[HandleOptions] = None):
|
||||
def __init__(
|
||||
self,
|
||||
router, # ThreadProxiedRouter
|
||||
endpoint_name,
|
||||
handle_options: Optional[HandleOptions] = None):
|
||||
self.router = router
|
||||
self.endpoint_name = endpoint_name
|
||||
self.handle_options = handle_options or HandleOptions()
|
||||
@@ -91,7 +91,7 @@ class RayServeHandle:
|
||||
async def remote(self,
|
||||
request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
"""Issue an asynchronous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||
@@ -112,6 +112,12 @@ class RayServeHandle:
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')"
|
||||
|
||||
def __reduce__(self):
|
||||
deserializer = RayServeHandle
|
||||
serialized_data = (self.router, self.endpoint_name,
|
||||
self.handle_options)
|
||||
return deserializer, serialized_data
|
||||
|
||||
|
||||
class RayServeSyncHandle(RayServeHandle):
|
||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||
@@ -138,3 +144,9 @@ class RayServeSyncHandle(RayServeHandle):
|
||||
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
|
||||
coro, self.router.async_loop)
|
||||
return future.result()
|
||||
|
||||
def __reduce__(self):
|
||||
deserializer = RayServeSyncHandle
|
||||
serialized_data = (self.router, self.endpoint_name,
|
||||
self.handle_options)
|
||||
return deserializer, serialized_data
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
import requests
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
from ray import serve
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handle_serializable(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def f(_):
|
||||
return "hello"
|
||||
|
||||
client.create_backend("f", f)
|
||||
client.create_endpoint("f", backend="f")
|
||||
|
||||
@ray.remote
|
||||
class TaskActor:
|
||||
async def task(self, handle):
|
||||
ref = await handle.remote()
|
||||
output = await ref
|
||||
return output
|
||||
|
||||
handle = client.get_handle("f", sync=False)
|
||||
|
||||
task_actor = TaskActor.remote()
|
||||
result = await task_actor.task.remote(handle)
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
def test_sync_handle_serializable(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def f(_):
|
||||
return "hello"
|
||||
|
||||
client.create_backend("f", f)
|
||||
client.create_endpoint("f", backend="f")
|
||||
|
||||
@ray.remote
|
||||
def task(handle):
|
||||
return ray.get(handle.remote())
|
||||
|
||||
handle = client.get_handle("f", sync=True)
|
||||
result_ref = task.remote(handle)
|
||||
assert ray.get(result_ref) == "hello"
|
||||
|
||||
|
||||
def test_handle_in_endpoint(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
|
||||
Reference in New Issue
Block a user