[Serve] Handle Bug Fixes (#12971)

This commit is contained in:
Simon Mo
2020-12-22 19:13:16 -08:00
committed by GitHub
parent 81d3cbaa77
commit bc68260144
5 changed files with 221 additions and 139 deletions
+87 -20
View File
@@ -4,22 +4,40 @@ import time
from functools import wraps
import os
from uuid import UUID
import threading
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from ray.serve.context import TaskContext
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
from ray.serve.controller import ServeController
from ray.serve.handle import RayServeHandle
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters, logger, get_conda_env_dir)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.serve.env import CondaEnv
from ray.serve.router import RequestMetadata, Router
from ray.actor import ActorHandle
from typing import Any, Callable, Dict, List, Optional, Type, Union
_INTERNAL_CONTROLLER_NAME = None
global_async_loop = None
def create_or_get_async_loop_in_thread():
global global_async_loop
if global_async_loop is None:
global_async_loop = asyncio.new_event_loop()
thread = threading.Thread(
daemon=True,
target=global_async_loop.run_forever,
)
thread.start()
return global_async_loop
def _set_internal_controller_name(name):
global _INTERNAL_CONTROLLER_NAME
@@ -36,6 +54,36 @@ def _ensure_connected(f: Callable) -> Callable:
return check
class ThreadProxiedRouter:
def __init__(self, controller_handle, sync: bool):
self.router = Router(controller_handle)
if sync:
self.async_loop = create_or_get_async_loop_in_thread()
asyncio.run_coroutine_threadsafe(
self.router.setup_in_async_loop(),
self.async_loop,
)
else:
self.async_loop = asyncio.get_event_loop()
self.async_loop.create_task(self.router.setup_in_async_loop())
def _remote(self, endpoint_name, handle_options, request_data,
kwargs) -> Coroutine:
request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging.
endpoint_name,
TaskContext.Python,
call_method=handle_options.method_name,
shard_key=handle_options.shard_key,
http_method=handle_options.http_method,
http_headers=handle_options.http_headers,
)
coro = self.router.assign_request(request_metadata, request_data,
**kwargs)
return coro
class Client:
def __init__(self,
controller: ActorHandle,
@@ -48,12 +96,8 @@ class Client:
self._http_host, self._http_port = ray.get(
controller.get_http_config.remote())
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
# mostly grow in size, it will only shrink when user calls the
# .remove_endpoint method. This is fine because we expect the number of
# endpoints to be fairly small. However, in case this dictionary does
# grow very big, we can replace it with a LRU cache instead.
self._handle_cache: Dict[str, ActorHandle] = dict()
self._sync_proxied_router = None
self._async_proxied_router = None
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
# when the interpreter is exiting so we can't rely on __del__ (it
@@ -65,6 +109,18 @@ class Client:
atexit.register(shutdown_serve_client)
def _get_proxied_router(self, sync: bool):
if sync:
if self._sync_proxied_router is None:
self._sync_proxied_router = ThreadProxiedRouter(
self._controller, sync=True)
return self._sync_proxied_router
else:
if self._async_proxied_router is None:
self._async_proxied_router = ThreadProxiedRouter(
self._controller, sync=False)
return self._async_proxied_router
def __del__(self):
if not self._detached:
logger.debug("Shutting down Ray Serve because client went out of "
@@ -198,8 +254,6 @@ class Client:
Does not delete any associated backends.
"""
if endpoint in self._handle_cache:
del self._handle_cache[endpoint]
self._get_result(self._controller.delete_endpoint.remote(endpoint))
@_ensure_connected
@@ -410,10 +464,11 @@ class Client:
proportion))
@_ensure_connected
def get_handle(self,
endpoint_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True) -> RayServeHandle:
def get_handle(
self,
endpoint_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
@@ -433,14 +488,26 @@ class Client:
if asyncio.get_event_loop().is_running() and sync:
logger.warning(
"You are retrieving a ServeHandle inside an asyncio loop. "
"You are retrieving a sync handle inside an asyncio loop. "
"Try getting client.get_handle(.., sync=False) to get better "
"performance.")
"performance. Learn more at https://docs.ray.io/en/master/"
"serve/advanced.html#sync-and-async-handles")
if endpoint_name not in self._handle_cache:
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
self._handle_cache[endpoint_name] = handle
return self._handle_cache[endpoint_name]
if not asyncio.get_event_loop().is_running() and not sync:
logger.warning(
"You are retrieving an async handle outside an asyncio loop. "
"You should make sure client.get_handle is called inside a "
"running event loop. Or call client.get_handle(.., sync=True) "
"to create sync handle. Learn more at https://docs.ray.io/en/"
"master/serve/advanced.html#sync-and-async-handles")
if sync:
handle = RayServeSyncHandle(
self._get_proxied_router(sync=sync), endpoint_name)
else:
handle = RayServeHandle(
self._get_proxied_router(sync=sync), endpoint_name)
return handle
def start(detached: bool = False,
+83 -118
View File
@@ -1,27 +1,23 @@
import asyncio
import concurrent.futures
import threading
from typing import Any, Coroutine, Dict, Optional, Union
import ray
from ray.serve.context import TaskContext
from ray.serve.router import RequestMetadata, Router
from ray.serve.utils import get_random_letters
from ray.serve.exceptions import RayServeException
global_async_loop = None
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
from enum import Enum
def create_or_get_async_loop_in_thread():
global global_async_loop
if global_async_loop is None:
global_async_loop = asyncio.new_event_loop()
thread = threading.Thread(
daemon=True,
target=global_async_loop.run_forever,
)
thread.start()
return global_async_loop
@dataclass(frozen=True)
class HandleOptions:
"""Options for each ServeHandle instances. These fields are immutable."""
method_name: str = "__call__"
shard_key: Optional[str] = None
http_method: str = "GET"
http_headers: Dict[str, str] = field(default_factory=dict)
# Use a global singleton enum to emulate default options. We cannot use None
# for those option because None is a valid new value.
class DEFAULT(Enum):
VALUE = 1
class RayServeHandle:
@@ -31,75 +27,59 @@ class RayServeHandle:
an HTTP endpoint.
Example:
>>> handle = serve.get_handle("my_endpoint")
>>> handle = serve_client.get_handle("my_endpoint")
>>> handle
RayServeHandle(
Endpoint="my_endpoint",
Traffic=...
)
>>> handle.remote(my_request_content)
RayServeHandle(endpoint="my_endpoint")
>>> await handle.remote(my_request_content)
ObjectRef(...)
>>> ray.get(handle.remote(...))
>>> ray.get(await handle.remote(...))
# result
>>> ray.get(handle.remote(let_it_crash_request))
>>> ray.get(await handle.remote(let_it_crash_request))
# raises RayTaskError Exception
"""
def __init__(
self,
controller_handle,
endpoint_name,
sync: bool,
*,
method_name=None,
shard_key=None,
http_method=None,
http_headers=None,
):
self.controller_handle = controller_handle
def __init__(self,
router,
endpoint_name,
handle_options: Optional[HandleOptions] = None):
self.router = router
self.endpoint_name = endpoint_name
self.handle_options = handle_options or HandleOptions()
self.method_name = method_name
self.shard_key = shard_key
self.http_method = http_method
self.http_headers = http_headers
def options(self,
*,
method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
shard_key: Union[str, DEFAULT] = DEFAULT.VALUE,
http_method: Union[str, DEFAULT] = DEFAULT.VALUE,
http_headers: Union[Dict[str, str], DEFAULT] = DEFAULT.VALUE):
"""Set options for this handle.
self.router = Router(self.controller_handle)
self.sync = sync
# In the synchrounous mode, we create a new event loop in a separate
# thread and run the Router.setup in that loop. In the async mode, we
# can just use the current loop we are in right now.
if self.sync:
self.async_loop = create_or_get_async_loop_in_thread()
asyncio.run_coroutine_threadsafe(
self.router.setup_in_async_loop(),
self.async_loop,
)
else: # async
self.async_loop = asyncio.get_event_loop()
# create_task is not threadsafe.
self.async_loop.create_task(self.router.setup_in_async_loop())
Args:
method_name(str): The method to invoke on the backend.
http_method(str): The HTTP method to use for the request.
shard_key(str): A string to use to deterministically map this
request to a backend if there are multiple for this endpoint.
"""
new_options_dict = self.handle_options.__dict__.copy()
user_modified_options_dict = {
key: value
for key, value in
zip(["method_name", "shard_key", "http_method", "http_headers"],
[method_name, shard_key, http_method, http_headers])
if value != DEFAULT.VALUE
}
new_options_dict.update(user_modified_options_dict)
new_options = HandleOptions(**new_options_dict)
def _remote(self, request_data, kwargs) -> Coroutine:
request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging.
self.endpoint_name,
TaskContext.Python,
call_method=self.method_name or "__call__",
shard_key=self.shard_key,
http_method=self.http_method or "GET",
http_headers=self.http_headers or dict(),
)
coro = self.router.assign_request(request_metadata, request_data,
**kwargs)
return coro
return self.__class__(self.router, self.endpoint_name, new_options)
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchronous request to the endpoint.
async def remote(self,
request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get, respectively.
using ray.wait or ray.get (or ``await object_ref``), respectively.
Returns:
ray.ObjectRef
@@ -110,47 +90,32 @@ class RayServeHandle:
``**kwargs``: All keyword arguments will be available in
``request.query_params``.
"""
if not self.sync:
raise RayServeException(
"You are trying to call handle.remote() with async handle. "
"Please use `await handle.remote_async()` instead.")
coro = self._remote(request_data, kwargs)
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
coro, self.async_loop)
# Block until the result is ready.
return future.result()
async def remote_async(self,
request_data: Optional[Union[Dict, Any]] = None,
**kwargs) -> ray.ObjectRef:
"""Experimental API for enqueue a request in async context."""
if not asyncio.get_event_loop().is_running():
raise RayServeException(
"remote_async must be called from a running event loop.")
return await self._remote(request_data, kwargs)
def options(self,
method_name: Optional[str] = None,
*,
shard_key: Optional[str] = None,
http_method: Optional[str] = None,
http_headers: Optional[Dict[str, str]] = None):
"""Set options for this handle.
Args:
method_name(str): The method to invoke on the backend.
http_method(str): The HTTP method to use for the request.
shard_key(str): A string to use to deterministically map this
request to a backend if there are multiple for this endpoint.
"""
# Don't override default non-null values.
self.method_name = self.method_name or method_name
self.shard_key = self.shard_key or shard_key
self.http_method = self.http_method or http_method
self.http_headers = self.http_headers or http_headers
return self
return await self.router._remote(
self.endpoint_name, self.handle_options, request_data, kwargs)
def __repr__(self):
return f"RayServeHandle(endpoint='{self.endpoint_name}')"
return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')"
class RayServeSyncHandle(RayServeHandle):
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get (or ``await object_ref``), respectively.
Returns:
ray.ObjectRef
Args:
request_data(dict, Any): If it's a dictionary, the data will be
available in ``request.json()`` or ``request.form()``.
Otherwise, it will be available in ``request.data``.
``**kwargs``: All keyword arguments will be available in
``request.args``.
"""
coro = self.router._remote(self.endpoint_name, self.handle_options,
request_data, kwargs)
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
coro, self.router.async_loop)
return future.result()
+1 -1
View File
@@ -146,7 +146,7 @@ def test_call_method(serve_instance):
# Test serve handle path.
handle = client.get_handle("endpoint")
assert ray.get(handle.options("method").remote()) == "hello"
assert ray.get(handle.options(method_name="method").remote()) == "hello"
def test_no_route(serve_instance):
+30
View File
@@ -106,6 +106,36 @@ def test_handle_inject_starlette_request(serve_instance):
assert request_type == "<class 'starlette.requests.Request'>"
def test_handle_option_chaining(serve_instance):
# https://github.com/ray-project/ray/issues/12802
# https://github.com/ray-project/ray/issues/12798
client = serve_instance
class MultiMethod:
def method_a(self, _):
return "method_a"
def method_b(self, _):
return "method_b"
def __call__(self, _):
return "__call__"
client.create_backend("m", MultiMethod)
client.create_endpoint("m", backend="m")
# get_handle should give you a clean handle
handle1 = client.get_handle("m").options(method_name="method_a")
handle2 = client.get_handle("m")
# options().options() override should work
handle3 = handle1.options(method_name="method_b")
assert ray.get(handle1.remote()) == "method_a"
assert ray.get(handle2.remote()) == "__call__"
assert ray.get(handle3.remote()) == "method_b"
if __name__ == "__main__":
import sys
import pytest