diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 3c3ca3940..fdf2684e7 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -267,6 +267,26 @@ That's it. Let's take a look at an example: .. literalinclude:: ../../../python/ray/serve/examples/doc/snippet_model_composition.py + +.. _serve-sync-async-handles: + +Sync and Async Handles +====================== + +Ray Serve offers two types of ``ServeHandle``. You can use the ``client.get_handle(..., sync=True|False)`` +flag to toggle between them. + +- When you set ``sync=True`` (the default), a synchronous handle is returned. + Calling ``handle.remote()`` should return a Ray ObjectRef. +- When you set ``sync=False``, an asyncio-based handle is returned. You need to + call it with ``await handle.remote()`` to return a Ray ObjectRef. To use ``await``, + you have to run ``client.get_handle`` and ``handle.remote`` in a Python asyncio event loop. + +The async handle has a performance advantage because it uses asyncio directly, as compared +to the sync handle, which talks to an asyncio event loop in a thread. To learn more about +the reasoning behind these, check out our `architecture documentation <./architecture.html>`_. 
+ + Monitoring ========== diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index ec1593654..3e4b53b28 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -4,22 +4,40 @@ import time from functools import wraps import os from uuid import UUID +import threading +from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union + +from ray.serve.context import TaskContext import ray from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT) from ray.serve.controller import ServeController -from ray.serve.handle import RayServeHandle +from ray.serve.handle import RayServeHandle, RayServeSyncHandle from ray.serve.utils import (block_until_http_ready, format_actor_name, get_random_letters, logger, get_conda_env_dir) from ray.serve.exceptions import RayServeException from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata from ray.serve.env import CondaEnv +from ray.serve.router import RequestMetadata, Router from ray.actor import ActorHandle -from typing import Any, Callable, Dict, List, Optional, Type, Union _INTERNAL_CONTROLLER_NAME = None +global_async_loop = None + + +def create_or_get_async_loop_in_thread(): + global global_async_loop + if global_async_loop is None: + global_async_loop = asyncio.new_event_loop() + thread = threading.Thread( + daemon=True, + target=global_async_loop.run_forever, + ) + thread.start() + return global_async_loop + def _set_internal_controller_name(name): global _INTERNAL_CONTROLLER_NAME @@ -36,6 +54,36 @@ def _ensure_connected(f: Callable) -> Callable: return check +class ThreadProxiedRouter: + def __init__(self, controller_handle, sync: bool): + self.router = Router(controller_handle) + + if sync: + self.async_loop = create_or_get_async_loop_in_thread() + asyncio.run_coroutine_threadsafe( + self.router.setup_in_async_loop(), + self.async_loop, + ) + else: + self.async_loop = asyncio.get_event_loop() + 
self.async_loop.create_task(self.router.setup_in_async_loop()) + + def _remote(self, endpoint_name, handle_options, request_data, + kwargs) -> Coroutine: + request_metadata = RequestMetadata( + get_random_letters(10), # Used for debugging. + endpoint_name, + TaskContext.Python, + call_method=handle_options.method_name, + shard_key=handle_options.shard_key, + http_method=handle_options.http_method, + http_headers=handle_options.http_headers, + ) + coro = self.router.assign_request(request_metadata, request_data, + **kwargs) + return coro + + class Client: def __init__(self, controller: ActorHandle, @@ -48,12 +96,8 @@ class Client: self._http_host, self._http_port = ray.get( controller.get_http_config.remote()) - # NOTE(simon): Used to cache client.get_handle(endpoint) call. It will - # mostly grow in size, it will only shrink when user calls the - # .remove_endpoint method. This is fine because we expect the number of - # endpoints to be fairly small. However, in case this dictionary does - # grow very big, we can replace it with a LRU cache instead. - self._handle_cache: Dict[str, ActorHandle] = dict() + self._sync_proxied_router = None + self._async_proxied_router = None # NOTE(edoakes): Need this because the shutdown order isn't guaranteed # when the interpreter is exiting so we can't rely on __del__ (it @@ -65,6 +109,18 @@ class Client: atexit.register(shutdown_serve_client) + def _get_proxied_router(self, sync: bool): + if sync: + if self._sync_proxied_router is None: + self._sync_proxied_router = ThreadProxiedRouter( + self._controller, sync=True) + return self._sync_proxied_router + else: + if self._async_proxied_router is None: + self._async_proxied_router = ThreadProxiedRouter( + self._controller, sync=False) + return self._async_proxied_router + def __del__(self): if not self._detached: logger.debug("Shutting down Ray Serve because client went out of " @@ -198,8 +254,6 @@ class Client: Does not delete any associated backends. 
""" - if endpoint in self._handle_cache: - del self._handle_cache[endpoint] self._get_result(self._controller.delete_endpoint.remote(endpoint)) @_ensure_connected @@ -410,10 +464,11 @@ class Client: proportion)) @_ensure_connected - def get_handle(self, - endpoint_name: str, - missing_ok: Optional[bool] = False, - sync: bool = True) -> RayServeHandle: + def get_handle( + self, + endpoint_name: str, + missing_ok: Optional[bool] = False, + sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]: """Retrieve RayServeHandle for service endpoint to invoke it from Python. Args: @@ -433,14 +488,26 @@ class Client: if asyncio.get_event_loop().is_running() and sync: logger.warning( - "You are retrieving a ServeHandle inside an asyncio loop. " + "You are retrieving a sync handle inside an asyncio loop. " "Try getting client.get_handle(.., sync=False) to get better " - "performance.") + "performance. Learn more at https://docs.ray.io/en/master/" + "serve/advanced.html#sync-and-async-handles") - if endpoint_name not in self._handle_cache: - handle = RayServeHandle(self._controller, endpoint_name, sync=sync) - self._handle_cache[endpoint_name] = handle - return self._handle_cache[endpoint_name] + if not asyncio.get_event_loop().is_running() and not sync: + logger.warning( + "You are retrieving an async handle outside an asyncio loop. " + "You should make sure client.get_handle is called inside a " + "running event loop. Or call client.get_handle(.., sync=True) " + "to create sync handle. 
Learn more at https://docs.ray.io/en/" + "master/serve/advanced.html#sync-and-async-handles") + + if sync: + handle = RayServeSyncHandle( + self._get_proxied_router(sync=sync), endpoint_name) + else: + handle = RayServeHandle( + self._get_proxied_router(sync=sync), endpoint_name) + return handle def start(detached: bool = False, diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 4bfd663fd..381c8b833 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,27 +1,23 @@ import asyncio import concurrent.futures -import threading -from typing import Any, Coroutine, Dict, Optional, Union - -import ray -from ray.serve.context import TaskContext -from ray.serve.router import RequestMetadata, Router -from ray.serve.utils import get_random_letters -from ray.serve.exceptions import RayServeException - -global_async_loop = None +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Union +from enum import Enum -def create_or_get_async_loop_in_thread(): - global global_async_loop - if global_async_loop is None: - global_async_loop = asyncio.new_event_loop() - thread = threading.Thread( - daemon=True, - target=global_async_loop.run_forever, - ) - thread.start() - return global_async_loop +@dataclass(frozen=True) +class HandleOptions: + """Options for each ServeHandle instances. These fields are immutable.""" + method_name: str = "__call__" + shard_key: Optional[str] = None + http_method: str = "GET" + http_headers: Dict[str, str] = field(default_factory=dict) + + +# Use a global singleton enum to emulate default options. We cannot use None +# for those option because None is a valid new value. +class DEFAULT(Enum): + VALUE = 1 class RayServeHandle: @@ -31,75 +27,59 @@ class RayServeHandle: an HTTP endpoint. Example: - >>> handle = serve.get_handle("my_endpoint") + >>> handle = serve_client.get_handle("my_endpoint") >>> handle - RayServeHandle( - Endpoint="my_endpoint", - Traffic=... 
- ) - >>> handle.remote(my_request_content) + RayServeHandle(endpoint="my_endpoint") + >>> await handle.remote(my_request_content) ObjectRef(...) - >>> ray.get(handle.remote(...)) + >>> ray.get(await handle.remote(...)) # result - >>> ray.get(handle.remote(let_it_crash_request)) + >>> ray.get(await handle.remote(let_it_crash_request)) # raises RayTaskError Exception """ - def __init__( - self, - controller_handle, - endpoint_name, - sync: bool, - *, - method_name=None, - shard_key=None, - http_method=None, - http_headers=None, - ): - self.controller_handle = controller_handle + def __init__(self, + router, + endpoint_name, + handle_options: Optional[HandleOptions] = None): + self.router = router self.endpoint_name = endpoint_name + self.handle_options = handle_options or HandleOptions() - self.method_name = method_name - self.shard_key = shard_key - self.http_method = http_method - self.http_headers = http_headers + def options(self, + *, + method_name: Union[str, DEFAULT] = DEFAULT.VALUE, + shard_key: Union[str, DEFAULT] = DEFAULT.VALUE, + http_method: Union[str, DEFAULT] = DEFAULT.VALUE, + http_headers: Union[Dict[str, str], DEFAULT] = DEFAULT.VALUE): + """Set options for this handle. - self.router = Router(self.controller_handle) - self.sync = sync - # In the synchrounous mode, we create a new event loop in a separate - # thread and run the Router.setup in that loop. In the async mode, we - # can just use the current loop we are in right now. - if self.sync: - self.async_loop = create_or_get_async_loop_in_thread() - asyncio.run_coroutine_threadsafe( - self.router.setup_in_async_loop(), - self.async_loop, - ) - else: # async - self.async_loop = asyncio.get_event_loop() - # create_task is not threadsafe. - self.async_loop.create_task(self.router.setup_in_async_loop()) + Args: + method_name(str): The method to invoke on the backend. + http_method(str): The HTTP method to use for the request. 
+ shard_key(str): A string to use to deterministically map this + request to a backend if there are multiple for this endpoint. + """ + new_options_dict = self.handle_options.__dict__.copy() + user_modified_options_dict = { + key: value + for key, value in + zip(["method_name", "shard_key", "http_method", "http_headers"], + [method_name, shard_key, http_method, http_headers]) + if value != DEFAULT.VALUE + } + new_options_dict.update(user_modified_options_dict) + new_options = HandleOptions(**new_options_dict) - def _remote(self, request_data, kwargs) -> Coroutine: - request_metadata = RequestMetadata( - get_random_letters(10), # Used for debugging. - self.endpoint_name, - TaskContext.Python, - call_method=self.method_name or "__call__", - shard_key=self.shard_key, - http_method=self.http_method or "GET", - http_headers=self.http_headers or dict(), - ) - coro = self.router.assign_request(request_metadata, request_data, - **kwargs) - return coro + return self.__class__(self.router, self.endpoint_name, new_options) - def remote(self, request_data: Optional[Union[Dict, Any]] = None, - **kwargs): - """Issue an asynchronous request to the endpoint. + async def remote(self, + request_data: Optional[Union[Dict, Any]] = None, + **kwargs): + """Issue an asynchronous request to the endpoint. Returns a Ray ObjectRef whose results can be waited for or retrieved - using ray.wait or ray.get, respectively. + using ray.wait or ray.get (or ``await object_ref``), respectively. Returns: ray.ObjectRef @@ -110,47 +90,32 @@ class RayServeHandle: ``**kwargs``: All keyword arguments will be available in ``request.query_params``. """ - if not self.sync: - raise RayServeException( - "You are trying to call handle.remote() with async handle. " - "Please use `await handle.remote_async()` instead.") - - coro = self._remote(request_data, kwargs) - future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - coro, self.async_loop) - - # Block until the result is ready. 
- return future.result() - - async def remote_async(self, - request_data: Optional[Union[Dict, Any]] = None, - **kwargs) -> ray.ObjectRef: - """Experimental API for enqueue a request in async context.""" - if not asyncio.get_event_loop().is_running(): - raise RayServeException( - "remote_async must be called from a running event loop.") - return await self._remote(request_data, kwargs) - - def options(self, - method_name: Optional[str] = None, - *, - shard_key: Optional[str] = None, - http_method: Optional[str] = None, - http_headers: Optional[Dict[str, str]] = None): - """Set options for this handle. - - Args: - method_name(str): The method to invoke on the backend. - http_method(str): The HTTP method to use for the request. - shard_key(str): A string to use to deterministically map this - request to a backend if there are multiple for this endpoint. - """ - # Don't override default non-null values. - self.method_name = self.method_name or method_name - self.shard_key = self.shard_key or shard_key - self.http_method = self.http_method or http_method - self.http_headers = self.http_headers or http_headers - return self + return await self.router._remote( + self.endpoint_name, self.handle_options, request_data, kwargs) def __repr__(self): - return f"RayServeHandle(endpoint='{self.endpoint_name}')" + return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')" + + +class RayServeSyncHandle(RayServeHandle): + def remote(self, request_data: Optional[Union[Dict, Any]] = None, + **kwargs): + """Issue an asynchronous request to the endpoint. + + Returns a Ray ObjectRef whose results can be waited for or retrieved + using ray.wait or ray.get (or ``await object_ref``), respectively. + + Returns: + ray.ObjectRef + Args: + request_data(dict, Any): If it's a dictionary, the data will be + available in ``request.json()`` or ``request.form()``. + Otherwise, it will be available in ``request.data``. 
+ ``**kwargs``: All keyword arguments will be available in + ``request.args``. + """ + coro = self.router._remote(self.endpoint_name, self.handle_options, + request_data, kwargs) + future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( + coro, self.router.async_loop) + return future.result() diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index ea0f6f35d..318c93732 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -146,7 +146,7 @@ def test_call_method(serve_instance): # Test serve handle path. handle = client.get_handle("endpoint") - assert ray.get(handle.options("method").remote()) == "hello" + assert ray.get(handle.options(method_name="method").remote()) == "hello" def test_no_route(serve_instance): diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index cc6b1e72b..c17db7686 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -106,6 +106,36 @@ def test_handle_inject_starlette_request(serve_instance): assert request_type == "" +def test_handle_option_chaining(serve_instance): + # https://github.com/ray-project/ray/issues/12802 + # https://github.com/ray-project/ray/issues/12798 + + client = serve_instance + + class MultiMethod: + def method_a(self, _): + return "method_a" + + def method_b(self, _): + return "method_b" + + def __call__(self, _): + return "__call__" + + client.create_backend("m", MultiMethod) + client.create_endpoint("m", backend="m") + + # get_handle should give you a clean handle + handle1 = client.get_handle("m").options(method_name="method_a") + handle2 = client.get_handle("m") + # options().options() override should work + handle3 = handle1.options(method_name="method_b") + + assert ray.get(handle1.remote()) == "method_a" + assert ray.get(handle2.remote()) == "__call__" + assert ray.get(handle3.remote()) == "method_b" + + if __name__ == "__main__": import sys import 
pytest