[serve] Built-in support for imported backends (#13867)

This commit is contained in:
Edward Oakes
2021-02-04 15:09:12 -06:00
committed by GitHub
parent db59736b1a
commit 7af0c999f3
12 changed files with 118 additions and 88 deletions
+11 -12
View File
@@ -323,22 +323,23 @@ class Client:
def create_backend(
self,
backend_tag: str,
func_or_class: Union[Callable, Type[Callable]],
*actor_init_args: Any,
backend_def: Union[Callable, Type[Callable], str],
*init_args: Any,
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
env: Optional[CondaEnv] = None) -> None:
"""Create a backend with the provided tag.
The backend will serve requests with backend_def.
Args:
backend_tag (str): a unique tag assigned to identify this backend.
func_or_class (callable, class): a function or a class implementing
__call__, returning a JSON-serializable object or a
Starlette Response object.
*actor_init_args (optional): the arguments to pass to the class
initialization method.
backend_def (callable, class, str): a function or class
implementing __call__ and returning a JSON-serializable object
or a Starlette Response object. A string import path can also
be provided (e.g., "my_module.MyClass"), in which case the
underlying function or class will be imported dynamically in
the worker replicas.
*init_args (optional): the arguments to pass to the class
initialization method. Not valid if backend_def is a function.
ray_actor_options (optional): options to be passed into the
@ray.remote decorator for the backend actor.
config (dict, serve.BackendConfig, optional): configuration options
@@ -386,9 +387,7 @@ class Client:
ray_actor_options.update(
override_environment_variables={"PYTHONHOME": conda_env_dir})
replica_config = ReplicaConfig(
func_or_class,
*actor_init_args,
ray_actor_options=ray_actor_options)
backend_def, *init_args, ray_actor_options=ray_actor_options)
metadata = BackendMetadata(
accepts_batches=replica_config.accepts_batches,
is_blocking=replica_config.is_blocking)
+2 -2
View File
@@ -97,7 +97,7 @@ class BackendReplica:
max_task_retries=-1,
**backend_info.replica_config.ray_actor_options).remote(
self._backend_tag, self._replica_tag,
backend_info.replica_config.actor_init_args,
backend_info.replica_config.init_args,
backend_info.backend_config, self._controller_name)
self._startup_obj_ref = self._actor_handle.ready.remote()
self._state = ReplicaState.STARTING
@@ -277,7 +277,7 @@ class BackendState:
return None
backend_replica_class = create_backend_replica(
replica_config.func_or_class)
replica_config.backend_def)
# Save creator that starts replicas, the arguments to be passed in,
# and the configuration for the backends.
+25 -15
View File
@@ -13,7 +13,7 @@ from ray.actor import ActorHandle
from ray.async_compat import sync_to_async
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
unpack_future)
unpack_future, import_attr)
from ray.serve.exceptions import RayServeException
from ray.util import metrics
from ray.serve.config import BackendConfig
@@ -94,33 +94,40 @@ class BatchQueue:
return batch
def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]):
"""Creates a replica class wrapping the provided function or class.
This approach is picked over inheritance to avoid conflict between user
provided class and the RayServeReplica class.
"""
if inspect.isfunction(func_or_class):
is_function = True
elif inspect.isclass(func_or_class):
is_function = False
else:
assert False, "func_or_class must be function or class."
backend_def = backend_def
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
class RayServeWrappedReplica(object):
def __init__(self, backend_tag, replica_tag, init_args,
backend_config: BackendConfig, controller_name: str):
if isinstance(backend_def, str):
backend = import_attr(backend_def)
else:
backend = backend_def
if inspect.isfunction(backend):
is_function = True
elif inspect.isclass(backend):
is_function = False
else:
assert False, ("backend_def must be function, class, or "
"corresponding import path.")
# Set the controller name so that serve.connect() in the user's
# backend code will connect to the instance that this backend is
# running in.
ray.serve.api._set_internal_replica_context(
backend_tag, replica_tag, controller_name)
if is_function:
_callable = func_or_class
_callable = backend
else:
_callable = func_or_class(*init_args)
_callable = backend(*init_args)
assert controller_name, "Must provide a valid controller_name"
controller_handle = ray.get_actor(controller_name)
@@ -144,8 +151,12 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
async def drain_pending_queries(self):
return await self.backend.drain_pending_queries()
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
func_or_class.__name__)
if isinstance(backend_def, str):
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
backend_def)
else:
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
backend_def.__name__)
return RayServeWrappedReplica
@@ -415,8 +426,7 @@ class RayServeReplica:
if user_config:
if self.is_function:
raise ValueError(
"argument func_or_class must be a class to use user_config"
)
"backend_def must be a class to use user_config")
elif not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD):
raise RayServeException("user_config specified but backend " +
self.backend_tag + " missing " +
+39 -29
View File
@@ -5,22 +5,29 @@ from typing import Any, Dict, List, Optional
import pydantic
from pydantic import BaseModel, confloat, PositiveFloat, PositiveInt, validator
from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST,
DEFAULT_HTTP_PORT)
from ray.serve.constants import DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT
def _callable_accepts_batch(func_or_class):
if inspect.isfunction(func_or_class):
return hasattr(func_or_class, "_serve_accept_batch")
elif inspect.isclass(func_or_class):
return hasattr(func_or_class.__call__, "_serve_accept_batch")
def _callable_accepts_batch(backend_def):
if inspect.isfunction(backend_def):
return hasattr(backend_def, "_serve_accept_batch")
elif inspect.isclass(backend_def):
return hasattr(backend_def.__call__, "_serve_accept_batch")
elif isinstance(backend_def, str):
return True
else:
raise TypeError("backend_def must be function, class, or str.")
def _callable_is_blocking(func_or_class):
if inspect.isfunction(func_or_class):
return not inspect.iscoroutinefunction(func_or_class)
elif inspect.isclass(func_or_class):
return not inspect.iscoroutinefunction(func_or_class.__call__)
def _callable_is_blocking(backend_def):
if inspect.isfunction(backend_def):
return not inspect.iscoroutinefunction(backend_def)
elif inspect.isclass(backend_def):
return not inspect.iscoroutinefunction(backend_def.__call__)
elif isinstance(backend_def, str):
return False
else:
raise TypeError("backend_def must be function, class, or str.")
@dataclass
@@ -105,8 +112,11 @@ class BackendConfig(BaseModel):
# Pipeline/async mode: if the servable is not blocking,
# router should just keep pushing queries to the replicas
# until a high limit.
# TODO(edoakes): setting this to a relatively low constant because
# we can't determine if imported backends are sync or async, but we
# may consider tweaking it in the future.
if not values["internal_metadata"].is_blocking:
v = ASYNC_CONCURRENCY
v = 100
# Batch inference mode: user specifies non zero timeout to wait for
# full batch. We will use 2*max_batch_size to perform double
@@ -119,12 +129,11 @@ class BackendConfig(BaseModel):
class ReplicaConfig:
def __init__(self, func_or_class, *actor_init_args,
ray_actor_options=None):
self.func_or_class = func_or_class
self.accepts_batches = _callable_accepts_batch(func_or_class)
self.is_blocking = _callable_is_blocking(func_or_class)
self.actor_init_args = list(actor_init_args)
def __init__(self, backend_def, *init_args, ray_actor_options=None):
self.backend_def = backend_def
self.accepts_batches = _callable_accepts_batch(backend_def)
self.is_blocking = _callable_is_blocking(backend_def)
self.init_args = list(init_args)
if ray_actor_options is None:
self.ray_actor_options = {}
else:
@@ -134,27 +143,28 @@ class ReplicaConfig:
self._validate()
def _validate(self):
# Validate that func_or_class is a function or class.
if inspect.isfunction(self.func_or_class):
if len(self.actor_init_args) != 0:
# Validate that backend_def is an import path, function, or class.
if isinstance(self.backend_def, str):
pass
elif inspect.isfunction(self.backend_def):
if len(self.init_args) != 0:
raise ValueError(
"actor_init_args not supported for function backend.")
elif not inspect.isclass(self.func_or_class):
"init_args not supported for function backend.")
elif not inspect.isclass(self.backend_def):
raise TypeError(
"Backend must be a function or class, it is {}.".format(
type(self.func_or_class)))
type(self.backend_def)))
if not isinstance(self.ray_actor_options, dict):
raise TypeError("ray_actor_options must be a dictionary.")
elif "lifetime" in self.ray_actor_options:
raise ValueError(
"Specifying lifetime in actor_init_args is not allowed.")
"Specifying lifetime in init_args is not allowed.")
elif "name" in self.ray_actor_options:
raise ValueError(
"Specifying name in actor_init_args is not allowed.")
raise ValueError("Specifying name in init_args is not allowed.")
elif "max_restarts" in self.ray_actor_options:
raise ValueError("Specifying max_restarts in "
"actor_init_args is not allowed.")
"init_args is not allowed.")
else:
# Ray defaults to zero CPUs for placement, we default to one here.
if "num_cpus" not in self.ray_actor_options:
@@ -1,13 +1,12 @@
import requests

from ray import serve

client = serve.start()

# Pass the import path of your function or class directly to create_backend;
# it is imported dynamically in the worker replicas.
import_path = "ray.serve.utils.MockImportedBackend"
client.create_backend("imported", import_path, "input_arg")
client.create_endpoint("imported", backend="imported", route="/imported")

# Query the endpoint over HTTP and print the response body.
print(requests.get("http://127.0.0.1:8000/imported").text)
@@ -16,7 +16,7 @@ pytestmark = pytest.mark.asyncio
def setup_worker(name,
func_or_class,
backend_def,
init_args=None,
backend_config=BackendConfig(),
controller_name=""):
@@ -26,7 +26,7 @@ def setup_worker(name,
@ray.remote
class WorkerActor:
def __init__(self):
self.worker = create_backend_replica(func_or_class)(
self.worker = create_backend_replica(backend_def)(
name, name + ":tag", init_args, backend_config,
controller_name)
+1 -2
View File
@@ -3,7 +3,6 @@ import pytest
from ray import serve
from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions,
ReplicaConfig, BackendMetadata)
from ray.serve.constants import ASYNC_CONCURRENCY
from pydantic import ValidationError
@@ -42,7 +41,7 @@ def test_backend_config_validation():
assert BackendConfig(
max_batch_size=10,
internal_metadata=BackendMetadata(
is_blocking=False)).max_concurrent_queries == ASYNC_CONCURRENCY
is_blocking=False)).max_concurrent_queries == 100
assert BackendConfig(
max_batch_size=7, batch_wait_timeout=1.0).max_concurrent_queries == 14
@@ -1,15 +1,16 @@
import ray
from ray.serve.backends import ImportedBackend
from ray.serve.config import BackendConfig
def test_imported_backend(serve_instance):
client = serve_instance
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
config = BackendConfig(user_config="config", max_batch_size=2)
client.create_backend(
"imported", backend_class, "input_arg", config=config)
"imported",
"ray.serve.utils.MockImportedBackend",
"input_arg",
config=config)
client.create_endpoint("imported", backend="imported")
# Basic sanity check.
@@ -27,3 +28,12 @@ def test_imported_backend(serve_instance):
# Check that other call methods work.
handle = handle.options(method_name="other_method")
assert ray.get(handle.remote("hello")) == "hello"
# Check that functions work as well.
client.create_backend(
"imported_func",
"ray.serve.utils.mock_imported_function",
config=BackendConfig(max_batch_size=2))
client.create_endpoint("imported_func", backend="imported_func")
handle = client.get_handle("imported_func")
assert ray.get(handle.remote("hello")) == "hello"
+9 -5
View File
@@ -9,7 +9,7 @@ import pytest
import ray
from ray.serve.utils import (ServeEncoder, chain_future, unpack_future,
try_schedule_resources_on_nodes,
get_conda_env_dir, import_class)
get_conda_env_dir, import_attr)
def test_bytes_encoder():
@@ -126,11 +126,11 @@ def test_get_conda_env_dir(tmp_path):
os.environ["CONDA_PREFIX"] = ""
def test_import_class():
assert import_class("ray.serve.Client") == ray.serve.api.Client
assert import_class("ray.serve.api.Client") == ray.serve.api.Client
def test_import_attr():
assert import_attr("ray.serve.Client") == ray.serve.api.Client
assert import_attr("ray.serve.api.Client") == ray.serve.api.Client
policy_cls = import_class("ray.serve.controller.TrafficPolicy")
policy_cls = import_attr("ray.serve.controller.TrafficPolicy")
assert policy_cls == ray.serve.controller.TrafficPolicy
policy = policy_cls({"endpoint1": 0.5, "endpoint2": 0.5})
@@ -140,6 +140,10 @@ def test_import_class():
print(repr(policy))
# Very meta...
import_attr_2 = import_attr("ray.serve.utils.import_attr")
assert import_attr_2 == import_attr
if __name__ == "__main__":
import sys
+10 -6
View File
@@ -359,22 +359,26 @@ def get_node_id_for_actor(actor_handle):
return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"]
def import_attr(full_path: str):
    """Given a full import path to a module attr, return the imported attr.

    For example, the following are equivalent:
        MyClass = import_attr("module.submodule.MyClass")
        from module.submodule import MyClass

    Args:
        full_path (str): dotted path of the form
            "<module path>.<attribute name>".

    Returns:
        Imported attr

    Raises:
        ValueError: if full_path is missing either the module path or the
            attribute name (e.g. it contains no ".").
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    # Split on the last period only: everything before it is the module,
    # everything after it is the attribute.
    module_name, separator, attr_name = full_path.rpartition(".")
    if not (separator and module_name and attr_name):
        # Without this check a dot-less path such as "os" would be sliced
        # into module "o" and fail with a misleading import error.
        raise ValueError(
            "full_path must be of the form 'module.attr', got {!r}.".format(
                full_path))

    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
async def mock_imported_function(batch):
    """Return the awaited body of every request in a batch.

    Args:
        batch: an iterable of request objects, each exposing an awaitable
            ``body()`` method (presumably Starlette requests -- confirm
            against the caller).

    Returns:
        list: the result of ``await request.body()`` for each request, in
            the same order as the input batch.
    """
    return [await request.body() for request in batch]
class MockImportedBackend: