From 7af0c999f3f97230fa3140b07f9cd4ca1d234596 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 4 Feb 2021 15:09:12 -0600 Subject: [PATCH] [serve] Built-in support for imported backends (#13867) --- doc/source/serve/advanced.rst | 9 ++- doc/source/serve/package-ref.rst | 4 -- python/ray/serve/api.py | 23 +++---- python/ray/serve/backend_state.py | 4 +- python/ray/serve/backend_worker.py | 40 +++++++---- python/ray/serve/config.py | 68 +++++++++++-------- .../serve/examples/doc/imported_backend.py | 5 +- python/ray/serve/tests/test_backend_worker.py | 4 +- python/ray/serve/tests/test_config.py | 3 +- .../ray/serve/tests/test_imported_backend.py | 16 ++++- python/ray/serve/tests/test_util.py | 14 ++-- python/ray/serve/utils.py | 16 +++-- 12 files changed, 118 insertions(+), 88 deletions(-) diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 542a3ce18..7a6027ad5 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -398,10 +398,9 @@ as shown below. The dependencies required in the backend may be different than the dependencies installed in the driver program (the one running Serve API -calls). In this case, you can use an -:mod:`ImportedBackend ` to specify a -backend based on a class that is installed in the Python environment that -the workers will run in. Example: +calls). In this case, you can pass the backend in as an import path that will +be imported in the Python environment in the workers, but not the driver. +Example: .. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py @@ -421,4 +420,4 @@ in :mod:`serve.start `: .. note:: Using the "EveryNode" option, you can point a cloud load balancer to the instance group of Ray cluster to achieve high availability of Serve's HTTP - proxies. \ No newline at end of file + proxies. diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 3df9c2915..20ed340be 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -37,7 +37,3 @@ objects instead of Starlette requests. Batching Requests ----------------- .. autofunction:: ray.serve.accept_batch - -Built-in Backends ------------------ -.. autoclass:: ray.serve.backends.ImportedBackend diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index b42cd7846..2e0490631 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -323,22 +323,23 @@ class Client: def create_backend( self, backend_tag: str, - func_or_class: Union[Callable, Type[Callable]], - *actor_init_args: Any, + backend_def: Union[Callable, Type[Callable], str], + *init_args: Any, ray_actor_options: Optional[Dict] = None, config: Optional[Union[BackendConfig, Dict[str, Any]]] = None, env: Optional[CondaEnv] = None) -> None: """Create a backend with the provided tag. - The backend will serve requests with func_or_class. - Args: backend_tag (str): a unique tag assign to identify this backend. - func_or_class (callable, class): a function or a class implementing - __call__, returning a JSON-serializable object or a - Starlette Response object. - *actor_init_args (optional): the arguments to pass to the class - initialization method. + backend_def (callable, class, str): a function or class + implementing __call__ and returning a JSON-serializable object + or a Starlette Response object. A string import path can also + be provided (e.g., "my_module.MyClass"), in which case the + underlying function or class will be imported dynamically in + the worker replicas. + *init_args (optional): the arguments to pass to the class + initialization method. Not valid if backend_def is a function. ray_actor_options (optional): options to be passed into the @ray.remote decorator for the backend actor. config (dict, serve.BackendConfig, optional): configuration options @@ -386,9 +387,7 @@ class Client: ray_actor_options.update( override_environment_variables={"PYTHONHOME": conda_env_dir}) replica_config = ReplicaConfig( - func_or_class, - *actor_init_args, - ray_actor_options=ray_actor_options) + backend_def, *init_args, ray_actor_options=ray_actor_options) metadata = BackendMetadata( accepts_batches=replica_config.accepts_batches, is_blocking=replica_config.is_blocking) diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 418ab3b2a..ba6e2260f 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -97,7 +97,7 @@ class BackendReplica: max_task_retries=-1, **backend_info.replica_config.ray_actor_options).remote( self._backend_tag, self._replica_tag, - backend_info.replica_config.actor_init_args, + backend_info.replica_config.init_args, backend_info.backend_config, self._controller_name) self._startup_obj_ref = self._actor_handle.ready.remote() self._state = ReplicaState.STARTING @@ -277,7 +277,7 @@ class BackendState: return None backend_replica_class = create_backend_replica( - replica_config.func_or_class) + replica_config.backend_def) # Save creator that starts replicas, the arguments to be passed in, # and the configuration for the backends. diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index da087efa5..5740cf4f5 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -13,7 +13,7 @@ from ray.actor import ActorHandle from ray.async_compat import sync_to_async from ray.serve.utils import (parse_request_item, _get_logger, chain_future, - unpack_future) + unpack_future, import_attr) from ray.serve.exceptions import RayServeException from ray.util import metrics from ray.serve.config import BackendConfig @@ -94,33 +94,40 @@ class BatchQueue: return batch -def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]): +def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]): """Creates a replica class wrapping the provided function or class. This approach is picked over inheritance to avoid conflict between user provided class and the RayServeReplica class. """ - - if inspect.isfunction(func_or_class): - is_function = True - elif inspect.isclass(func_or_class): - is_function = False - else: - assert False, "func_or_class must be function or class." + backend_def = backend_def # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedReplica(object): def __init__(self, backend_tag, replica_tag, init_args, backend_config: BackendConfig, controller_name: str): + if isinstance(backend_def, str): + backend = import_attr(backend_def) + else: + backend = backend_def + + if inspect.isfunction(backend): + is_function = True + elif inspect.isclass(backend): + is_function = False + else: + assert False, ("backend_def must be function, class, or " + "corresponding import path.") + # Set the controller name so that serve.connect() in the user's # backend code will connect to the instance that this backend is # running in. ray.serve.api._set_internal_replica_context( backend_tag, replica_tag, controller_name) if is_function: - _callable = func_or_class + _callable = backend else: - _callable = func_or_class(*init_args) + _callable = backend(*init_args) assert controller_name, "Must provide a valid controller_name" controller_handle = ray.get_actor(controller_name) @@ -144,8 +151,12 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]): async def drain_pending_queries(self): return await self.backend.drain_pending_queries() - RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format( - func_or_class.__name__) + if isinstance(backend_def, str): + RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format( + backend_def) + else: + RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format( + backend_def.__name__) return RayServeWrappedReplica @@ -415,8 +426,7 @@ class RayServeReplica: if user_config: if self.is_function: raise ValueError( - "argument func_or_class must be a class to use user_config" - ) + "backend_def must be a class to use user_config") elif not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD): raise RayServeException("user_config specified but backend " + self.backend_tag + " missing " + diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 41a1eca08..8060b406f 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -5,22 +5,29 @@ from typing import Any, Dict, List, Optional import pydantic from pydantic import BaseModel, confloat, PositiveFloat, PositiveInt, validator -from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST, - DEFAULT_HTTP_PORT) +from ray.serve.constants import DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT -def _callable_accepts_batch(func_or_class): - if inspect.isfunction(func_or_class): - return hasattr(func_or_class, "_serve_accept_batch") - elif inspect.isclass(func_or_class): - return hasattr(func_or_class.__call__, "_serve_accept_batch") +def _callable_accepts_batch(backend_def): + if inspect.isfunction(backend_def): + return hasattr(backend_def, "_serve_accept_batch") + elif inspect.isclass(backend_def): + return hasattr(backend_def.__call__, "_serve_accept_batch") + elif isinstance(backend_def, str): + return True + else: + raise TypeError("backend_def must be function, class, or str.") -def _callable_is_blocking(func_or_class): - if inspect.isfunction(func_or_class): - return not inspect.iscoroutinefunction(func_or_class) - elif inspect.isclass(func_or_class): - return not inspect.iscoroutinefunction(func_or_class.__call__) +def _callable_is_blocking(backend_def): + if inspect.isfunction(backend_def): + return not inspect.iscoroutinefunction(backend_def) + elif inspect.isclass(backend_def): + return not inspect.iscoroutinefunction(backend_def.__call__) + elif isinstance(backend_def, str): + return False + else: + raise TypeError("backend_def must be function, class, or str.") @dataclass @@ -105,8 +112,11 @@ class BackendConfig(BaseModel): # Pipeline/async mode: if the servable is not blocking, # router should just keep pushing queries to the replicas # until a high limit. + # TODO(edoakes): setting this to a relatively low constant because + # we can't determine if imported backends are sync or async, but we + # may consider tweaking it in the future. if not values["internal_metadata"].is_blocking: - v = ASYNC_CONCURRENCY + v = 100 # Batch inference mode: user specifies non zero timeout to wait for # full batch. We will use 2*max_batch_size to perform double @@ -119,12 +129,11 @@ class BackendConfig(BaseModel): class ReplicaConfig: - def __init__(self, func_or_class, *actor_init_args, - ray_actor_options=None): - self.func_or_class = func_or_class - self.accepts_batches = _callable_accepts_batch(func_or_class) - self.is_blocking = _callable_is_blocking(func_or_class) - self.actor_init_args = list(actor_init_args) + def __init__(self, backend_def, *init_args, ray_actor_options=None): + self.backend_def = backend_def + self.accepts_batches = _callable_accepts_batch(backend_def) + self.is_blocking = _callable_is_blocking(backend_def) + self.init_args = list(init_args) if ray_actor_options is None: self.ray_actor_options = {} else: @@ -134,27 +143,28 @@ class ReplicaConfig: self._validate() def _validate(self): - # Validate that func_or_class is a function or class. - if inspect.isfunction(self.func_or_class): - if len(self.actor_init_args) != 0: + # Validate that backend_def is an import path, function, or class. + if isinstance(self.backend_def, str): + pass + elif inspect.isfunction(self.backend_def): + if len(self.init_args) != 0: raise ValueError( - "actor_init_args not supported for function backend.") - elif not inspect.isclass(self.func_or_class): + "init_args not supported for function backend.") + elif not inspect.isclass(self.backend_def): raise TypeError( "Backend must be a function or class, it is {}.".format( - type(self.func_or_class))) + type(self.backend_def))) if not isinstance(self.ray_actor_options, dict): raise TypeError("ray_actor_options must be a dictionary.") elif "lifetime" in self.ray_actor_options: raise ValueError( - "Specifying lifetime in actor_init_args is not allowed.") + "Specifying lifetime in init_args is not allowed.") elif "name" in self.ray_actor_options: - raise ValueError( - "Specifying name in actor_init_args is not allowed.") + raise ValueError("Specifying name in init_args is not allowed.") elif "max_restarts" in self.ray_actor_options: raise ValueError("Specifying max_restarts in " - "actor_init_args is not allowed.") + "init_args is not allowed.") else: # Ray defaults to zero CPUs for placement, we default to one here. if "num_cpus" not in self.ray_actor_options: diff --git a/python/ray/serve/examples/doc/imported_backend.py b/python/ray/serve/examples/doc/imported_backend.py index d80d73b4a..596604aaa 100644 --- a/python/ray/serve/examples/doc/imported_backend.py +++ b/python/ray/serve/examples/doc/imported_backend.py @@ -1,13 +1,12 @@ import requests from ray import serve -from ray.serve.backends import ImportedBackend client = serve.start() # Include your class as input to the ImportedBackend constructor. -backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend") -client.create_backend("imported", backend_class, "input_arg") +import_path = "ray.serve.utils.MockImportedBackend" +client.create_backend("imported", import_path, "input_arg") client.create_endpoint("imported", backend="imported", route="/imported") print(requests.get("http://127.0.0.1:8000/imported").text) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 74c5418df..11c22e02e 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -16,7 +16,7 @@ pytestmark = pytest.mark.asyncio def setup_worker(name, - func_or_class, + backend_def, init_args=None, backend_config=BackendConfig(), controller_name=""): @@ -26,7 +26,7 @@ def setup_worker(name, @ray.remote class WorkerActor: def __init__(self): - self.worker = create_backend_replica(func_or_class)( + self.worker = create_backend_replica(backend_def)( name, name + ":tag", init_args, backend_config, controller_name) diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 40942ad76..5227b3ff5 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -3,7 +3,6 @@ import pytest from ray import serve from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions, ReplicaConfig, BackendMetadata) -from ray.serve.constants import ASYNC_CONCURRENCY from pydantic import ValidationError @@ -42,7 +41,7 @@ def test_backend_config_validation(): assert BackendConfig( max_batch_size=10, internal_metadata=BackendMetadata( - is_blocking=False)).max_concurrent_queries == ASYNC_CONCURRENCY + is_blocking=False)).max_concurrent_queries == 100 assert BackendConfig( max_batch_size=7, batch_wait_timeout=1.0).max_concurrent_queries == 14 diff --git a/python/ray/serve/tests/test_imported_backend.py b/python/ray/serve/tests/test_imported_backend.py index 99f08a04b..4b1398072 100644 --- a/python/ray/serve/tests/test_imported_backend.py +++ b/python/ray/serve/tests/test_imported_backend.py @@ -1,15 +1,16 @@ import ray -from ray.serve.backends import ImportedBackend from ray.serve.config import BackendConfig def test_imported_backend(serve_instance): client = serve_instance - backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend") config = BackendConfig(user_config="config", max_batch_size=2) client.create_backend( - "imported", backend_class, "input_arg", config=config) + "imported", + "ray.serve.utils.MockImportedBackend", + "input_arg", + config=config) client.create_endpoint("imported", backend="imported") # Basic sanity check. @@ -27,3 +28,12 @@ def test_imported_backend(serve_instance): # Check that other call methods work. handle = handle.options(method_name="other_method") assert ray.get(handle.remote("hello")) == "hello" + + # Check that functions work as well. + client.create_backend( + "imported_func", + "ray.serve.utils.mock_imported_function", + config=BackendConfig(max_batch_size=2)) + client.create_endpoint("imported_func", backend="imported_func") + handle = client.get_handle("imported_func") + assert ray.get(handle.remote("hello")) == "hello" diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index 9893bc4ce..95f526c31 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -9,7 +9,7 @@ import pytest import ray from ray.serve.utils import (ServeEncoder, chain_future, unpack_future, try_schedule_resources_on_nodes, - get_conda_env_dir, import_class) + get_conda_env_dir, import_attr) def test_bytes_encoder(): @@ -126,11 +126,11 @@ def test_get_conda_env_dir(tmp_path): os.environ["CONDA_PREFIX"] = "" -def test_import_class(): - assert import_class("ray.serve.Client") == ray.serve.api.Client - assert import_class("ray.serve.api.Client") == ray.serve.api.Client +def test_import_attr(): + assert import_attr("ray.serve.Client") == ray.serve.api.Client + assert import_attr("ray.serve.api.Client") == ray.serve.api.Client - policy_cls = import_class("ray.serve.controller.TrafficPolicy") + policy_cls = import_attr("ray.serve.controller.TrafficPolicy") assert policy_cls == ray.serve.controller.TrafficPolicy policy = policy_cls({"endpoint1": 0.5, "endpoint2": 0.5}) @@ -140,6 +140,10 @@ def test_import_class(): print(repr(policy)) + # Very meta... + import_attr_2 = import_attr("ray.serve.utils.import_attr") + assert import_attr_2 == import_attr + if __name__ == "__main__": import sys diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 10753fcb5..1d19593e6 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -359,22 +359,26 @@ def get_node_id_for_actor(actor_handle): return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"] -def import_class(full_path: str): - """Given a full import path to a class name, return the imported class. +def import_attr(full_path: str): + """Given a full import path to a module attr, return the imported attr. For example, the following are equivalent: - MyClass = import_class("module.submodule.MyClass") + MyClass = import_attr("module.submodule.MyClass") from module.submodule import MyClass Returns: - Imported class + Imported attr """ last_period_idx = full_path.rfind(".") - class_name = full_path[last_period_idx + 1:] + attr_name = full_path[last_period_idx + 1:] module_name = full_path[:last_period_idx] module = importlib.import_module(module_name) - return getattr(module, class_name) + return getattr(module, attr_name) + + +async def mock_imported_function(batch): + return [await request.body() for request in batch] class MockImportedBackend: