mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[Serve] Handle Bug Fixes (#12971)
This commit is contained in:
+87
-20
@@ -4,22 +4,40 @@ import time
|
||||
from functools import wraps
|
||||
import os
|
||||
from uuid import UUID
|
||||
import threading
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
||||
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
||||
from ray.serve.controller import ServeController
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_random_letters, logger, get_conda_env_dir)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.serve.env import CondaEnv
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
_INTERNAL_CONTROLLER_NAME = None
|
||||
|
||||
global_async_loop = None
|
||||
|
||||
|
||||
def create_or_get_async_loop_in_thread():
|
||||
global global_async_loop
|
||||
if global_async_loop is None:
|
||||
global_async_loop = asyncio.new_event_loop()
|
||||
thread = threading.Thread(
|
||||
daemon=True,
|
||||
target=global_async_loop.run_forever,
|
||||
)
|
||||
thread.start()
|
||||
return global_async_loop
|
||||
|
||||
|
||||
def _set_internal_controller_name(name):
|
||||
global _INTERNAL_CONTROLLER_NAME
|
||||
@@ -36,6 +54,36 @@ def _ensure_connected(f: Callable) -> Callable:
|
||||
return check
|
||||
|
||||
|
||||
class ThreadProxiedRouter:
|
||||
def __init__(self, controller_handle, sync: bool):
|
||||
self.router = Router(controller_handle)
|
||||
|
||||
if sync:
|
||||
self.async_loop = create_or_get_async_loop_in_thread()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.router.setup_in_async_loop(),
|
||||
self.async_loop,
|
||||
)
|
||||
else:
|
||||
self.async_loop = asyncio.get_event_loop()
|
||||
self.async_loop.create_task(self.router.setup_in_async_loop())
|
||||
|
||||
def _remote(self, endpoint_name, handle_options, request_data,
|
||||
kwargs) -> Coroutine:
|
||||
request_metadata = RequestMetadata(
|
||||
get_random_letters(10), # Used for debugging.
|
||||
endpoint_name,
|
||||
TaskContext.Python,
|
||||
call_method=handle_options.method_name,
|
||||
shard_key=handle_options.shard_key,
|
||||
http_method=handle_options.http_method,
|
||||
http_headers=handle_options.http_headers,
|
||||
)
|
||||
coro = self.router.assign_request(request_metadata, request_data,
|
||||
**kwargs)
|
||||
return coro
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self,
|
||||
controller: ActorHandle,
|
||||
@@ -48,12 +96,8 @@ class Client:
|
||||
self._http_host, self._http_port = ray.get(
|
||||
controller.get_http_config.remote())
|
||||
|
||||
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
|
||||
# mostly grow in size, it will only shrink when user calls the
|
||||
# .remove_endpoint method. This is fine because we expect the number of
|
||||
# endpoints to be fairly small. However, in case this dictionary does
|
||||
# grow very big, we can replace it with a LRU cache instead.
|
||||
self._handle_cache: Dict[str, ActorHandle] = dict()
|
||||
self._sync_proxied_router = None
|
||||
self._async_proxied_router = None
|
||||
|
||||
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
|
||||
# when the interpreter is exiting so we can't rely on __del__ (it
|
||||
@@ -65,6 +109,18 @@ class Client:
|
||||
|
||||
atexit.register(shutdown_serve_client)
|
||||
|
||||
def _get_proxied_router(self, sync: bool):
|
||||
if sync:
|
||||
if self._sync_proxied_router is None:
|
||||
self._sync_proxied_router = ThreadProxiedRouter(
|
||||
self._controller, sync=True)
|
||||
return self._sync_proxied_router
|
||||
else:
|
||||
if self._async_proxied_router is None:
|
||||
self._async_proxied_router = ThreadProxiedRouter(
|
||||
self._controller, sync=False)
|
||||
return self._async_proxied_router
|
||||
|
||||
def __del__(self):
|
||||
if not self._detached:
|
||||
logger.debug("Shutting down Ray Serve because client went out of "
|
||||
@@ -198,8 +254,6 @@ class Client:
|
||||
|
||||
Does not delete any associated backends.
|
||||
"""
|
||||
if endpoint in self._handle_cache:
|
||||
del self._handle_cache[endpoint]
|
||||
self._get_result(self._controller.delete_endpoint.remote(endpoint))
|
||||
|
||||
@_ensure_connected
|
||||
@@ -410,10 +464,11 @@ class Client:
|
||||
proportion))
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> RayServeHandle:
|
||||
def get_handle(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
@@ -433,14 +488,26 @@ class Client:
|
||||
|
||||
if asyncio.get_event_loop().is_running() and sync:
|
||||
logger.warning(
|
||||
"You are retrieving a ServeHandle inside an asyncio loop. "
|
||||
"You are retrieving a sync handle inside an asyncio loop. "
|
||||
"Try getting client.get_handle(.., sync=False) to get better "
|
||||
"performance.")
|
||||
"performance. Learn more at https://docs.ray.io/en/master/"
|
||||
"serve/advanced.html#sync-and-async-handles")
|
||||
|
||||
if endpoint_name not in self._handle_cache:
|
||||
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
|
||||
self._handle_cache[endpoint_name] = handle
|
||||
return self._handle_cache[endpoint_name]
|
||||
if not asyncio.get_event_loop().is_running() and not sync:
|
||||
logger.warning(
|
||||
"You are retrieving an async handle outside an asyncio loop. "
|
||||
"You should make sure client.get_handle is called inside a "
|
||||
"running event loop. Or call client.get_handle(.., sync=True) "
|
||||
"to create sync handle. Learn more at https://docs.ray.io/en/"
|
||||
"master/serve/advanced.html#sync-and-async-handles")
|
||||
|
||||
if sync:
|
||||
handle = RayServeSyncHandle(
|
||||
self._get_proxied_router(sync=sync), endpoint_name)
|
||||
else:
|
||||
handle = RayServeHandle(
|
||||
self._get_proxied_router(sync=sync), endpoint_name)
|
||||
return handle
|
||||
|
||||
|
||||
def start(detached: bool = False,
|
||||
|
||||
+83
-118
@@ -1,27 +1,23 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import threading
|
||||
from typing import Any, Coroutine, Dict, Optional, Union
|
||||
|
||||
import ray
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.serve.utils import get_random_letters
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
global_async_loop = None
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
|
||||
def create_or_get_async_loop_in_thread():
|
||||
global global_async_loop
|
||||
if global_async_loop is None:
|
||||
global_async_loop = asyncio.new_event_loop()
|
||||
thread = threading.Thread(
|
||||
daemon=True,
|
||||
target=global_async_loop.run_forever,
|
||||
)
|
||||
thread.start()
|
||||
return global_async_loop
|
||||
@dataclass(frozen=True)
|
||||
class HandleOptions:
|
||||
"""Options for each ServeHandle instances. These fields are immutable."""
|
||||
method_name: str = "__call__"
|
||||
shard_key: Optional[str] = None
|
||||
http_method: str = "GET"
|
||||
http_headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Use a global singleton enum to emulate default options. We cannot use None
|
||||
# for those option because None is a valid new value.
|
||||
class DEFAULT(Enum):
|
||||
VALUE = 1
|
||||
|
||||
|
||||
class RayServeHandle:
|
||||
@@ -31,75 +27,59 @@ class RayServeHandle:
|
||||
an HTTP endpoint.
|
||||
|
||||
Example:
|
||||
>>> handle = serve.get_handle("my_endpoint")
|
||||
>>> handle = serve_client.get_handle("my_endpoint")
|
||||
>>> handle
|
||||
RayServeHandle(
|
||||
Endpoint="my_endpoint",
|
||||
Traffic=...
|
||||
)
|
||||
>>> handle.remote(my_request_content)
|
||||
RayServeHandle(endpoint="my_endpoint")
|
||||
>>> await handle.remote(my_request_content)
|
||||
ObjectRef(...)
|
||||
>>> ray.get(handle.remote(...))
|
||||
>>> ray.get(await handle.remote(...))
|
||||
# result
|
||||
>>> ray.get(handle.remote(let_it_crash_request))
|
||||
>>> ray.get(await handle.remote(let_it_crash_request))
|
||||
# raises RayTaskError Exception
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
controller_handle,
|
||||
endpoint_name,
|
||||
sync: bool,
|
||||
*,
|
||||
method_name=None,
|
||||
shard_key=None,
|
||||
http_method=None,
|
||||
http_headers=None,
|
||||
):
|
||||
self.controller_handle = controller_handle
|
||||
def __init__(self,
|
||||
router,
|
||||
endpoint_name,
|
||||
handle_options: Optional[HandleOptions] = None):
|
||||
self.router = router
|
||||
self.endpoint_name = endpoint_name
|
||||
self.handle_options = handle_options or HandleOptions()
|
||||
|
||||
self.method_name = method_name
|
||||
self.shard_key = shard_key
|
||||
self.http_method = http_method
|
||||
self.http_headers = http_headers
|
||||
def options(self,
|
||||
*,
|
||||
method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
|
||||
shard_key: Union[str, DEFAULT] = DEFAULT.VALUE,
|
||||
http_method: Union[str, DEFAULT] = DEFAULT.VALUE,
|
||||
http_headers: Union[Dict[str, str], DEFAULT] = DEFAULT.VALUE):
|
||||
"""Set options for this handle.
|
||||
|
||||
self.router = Router(self.controller_handle)
|
||||
self.sync = sync
|
||||
# In the synchrounous mode, we create a new event loop in a separate
|
||||
# thread and run the Router.setup in that loop. In the async mode, we
|
||||
# can just use the current loop we are in right now.
|
||||
if self.sync:
|
||||
self.async_loop = create_or_get_async_loop_in_thread()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.router.setup_in_async_loop(),
|
||||
self.async_loop,
|
||||
)
|
||||
else: # async
|
||||
self.async_loop = asyncio.get_event_loop()
|
||||
# create_task is not threadsafe.
|
||||
self.async_loop.create_task(self.router.setup_in_async_loop())
|
||||
Args:
|
||||
method_name(str): The method to invoke on the backend.
|
||||
http_method(str): The HTTP method to use for the request.
|
||||
shard_key(str): A string to use to deterministically map this
|
||||
request to a backend if there are multiple for this endpoint.
|
||||
"""
|
||||
new_options_dict = self.handle_options.__dict__.copy()
|
||||
user_modified_options_dict = {
|
||||
key: value
|
||||
for key, value in
|
||||
zip(["method_name", "shard_key", "http_method", "http_headers"],
|
||||
[method_name, shard_key, http_method, http_headers])
|
||||
if value != DEFAULT.VALUE
|
||||
}
|
||||
new_options_dict.update(user_modified_options_dict)
|
||||
new_options = HandleOptions(**new_options_dict)
|
||||
|
||||
def _remote(self, request_data, kwargs) -> Coroutine:
|
||||
request_metadata = RequestMetadata(
|
||||
get_random_letters(10), # Used for debugging.
|
||||
self.endpoint_name,
|
||||
TaskContext.Python,
|
||||
call_method=self.method_name or "__call__",
|
||||
shard_key=self.shard_key,
|
||||
http_method=self.http_method or "GET",
|
||||
http_headers=self.http_headers or dict(),
|
||||
)
|
||||
coro = self.router.assign_request(request_metadata, request_data,
|
||||
**kwargs)
|
||||
return coro
|
||||
return self.__class__(self.router, self.endpoint_name, new_options)
|
||||
|
||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchronous request to the endpoint.
|
||||
async def remote(self,
|
||||
request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get, respectively.
|
||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||
|
||||
Returns:
|
||||
ray.ObjectRef
|
||||
@@ -110,47 +90,32 @@ class RayServeHandle:
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.query_params``.
|
||||
"""
|
||||
if not self.sync:
|
||||
raise RayServeException(
|
||||
"You are trying to call handle.remote() with async handle. "
|
||||
"Please use `await handle.remote_async()` instead.")
|
||||
|
||||
coro = self._remote(request_data, kwargs)
|
||||
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
|
||||
coro, self.async_loop)
|
||||
|
||||
# Block until the result is ready.
|
||||
return future.result()
|
||||
|
||||
async def remote_async(self,
|
||||
request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs) -> ray.ObjectRef:
|
||||
"""Experimental API for enqueue a request in async context."""
|
||||
if not asyncio.get_event_loop().is_running():
|
||||
raise RayServeException(
|
||||
"remote_async must be called from a running event loop.")
|
||||
return await self._remote(request_data, kwargs)
|
||||
|
||||
def options(self,
|
||||
method_name: Optional[str] = None,
|
||||
*,
|
||||
shard_key: Optional[str] = None,
|
||||
http_method: Optional[str] = None,
|
||||
http_headers: Optional[Dict[str, str]] = None):
|
||||
"""Set options for this handle.
|
||||
|
||||
Args:
|
||||
method_name(str): The method to invoke on the backend.
|
||||
http_method(str): The HTTP method to use for the request.
|
||||
shard_key(str): A string to use to deterministically map this
|
||||
request to a backend if there are multiple for this endpoint.
|
||||
"""
|
||||
# Don't override default non-null values.
|
||||
self.method_name = self.method_name or method_name
|
||||
self.shard_key = self.shard_key or shard_key
|
||||
self.http_method = self.http_method or http_method
|
||||
self.http_headers = self.http_headers or http_headers
|
||||
return self
|
||||
return await self.router._remote(
|
||||
self.endpoint_name, self.handle_options, request_data, kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"RayServeHandle(endpoint='{self.endpoint_name}')"
|
||||
return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')"
|
||||
|
||||
|
||||
class RayServeSyncHandle(RayServeHandle):
|
||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||
|
||||
Returns:
|
||||
ray.ObjectRef
|
||||
Args:
|
||||
request_data(dict, Any): If it's a dictionary, the data will be
|
||||
available in ``request.json()`` or ``request.form()``.
|
||||
Otherwise, it will be available in ``request.data``.
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.args``.
|
||||
"""
|
||||
coro = self.router._remote(self.endpoint_name, self.handle_options,
|
||||
request_data, kwargs)
|
||||
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
|
||||
coro, self.router.async_loop)
|
||||
return future.result()
|
||||
|
||||
@@ -146,7 +146,7 @@ def test_call_method(serve_instance):
|
||||
|
||||
# Test serve handle path.
|
||||
handle = client.get_handle("endpoint")
|
||||
assert ray.get(handle.options("method").remote()) == "hello"
|
||||
assert ray.get(handle.options(method_name="method").remote()) == "hello"
|
||||
|
||||
|
||||
def test_no_route(serve_instance):
|
||||
|
||||
@@ -106,6 +106,36 @@ def test_handle_inject_starlette_request(serve_instance):
|
||||
assert request_type == "<class 'starlette.requests.Request'>"
|
||||
|
||||
|
||||
def test_handle_option_chaining(serve_instance):
|
||||
# https://github.com/ray-project/ray/issues/12802
|
||||
# https://github.com/ray-project/ray/issues/12798
|
||||
|
||||
client = serve_instance
|
||||
|
||||
class MultiMethod:
|
||||
def method_a(self, _):
|
||||
return "method_a"
|
||||
|
||||
def method_b(self, _):
|
||||
return "method_b"
|
||||
|
||||
def __call__(self, _):
|
||||
return "__call__"
|
||||
|
||||
client.create_backend("m", MultiMethod)
|
||||
client.create_endpoint("m", backend="m")
|
||||
|
||||
# get_handle should give you a clean handle
|
||||
handle1 = client.get_handle("m").options(method_name="method_a")
|
||||
handle2 = client.get_handle("m")
|
||||
# options().options() override should work
|
||||
handle3 = handle1.options(method_name="method_b")
|
||||
|
||||
assert ray.get(handle1.remote()) == "method_a"
|
||||
assert ray.get(handle2.remote()) == "__call__"
|
||||
assert ray.get(handle3.remote()) == "method_b"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user