mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 01:09:13 +08:00
[serve] Built-in support for imported backends (#13867)
This commit is contained in:
+11
-12
@@ -323,22 +323,23 @@ class Client:
|
||||
def create_backend(
|
||||
self,
|
||||
backend_tag: str,
|
||||
func_or_class: Union[Callable, Type[Callable]],
|
||||
*actor_init_args: Any,
|
||||
backend_def: Union[Callable, Type[Callable], str],
|
||||
*init_args: Any,
|
||||
ray_actor_options: Optional[Dict] = None,
|
||||
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
|
||||
env: Optional[CondaEnv] = None) -> None:
|
||||
"""Create a backend with the provided tag.
|
||||
|
||||
The backend will serve requests with func_or_class.
|
||||
|
||||
Args:
|
||||
backend_tag (str): a unique tag assign to identify this backend.
|
||||
func_or_class (callable, class): a function or a class implementing
|
||||
__call__, returning a JSON-serializable object or a
|
||||
Starlette Response object.
|
||||
*actor_init_args (optional): the arguments to pass to the class
|
||||
initialization method.
|
||||
backend_def (callable, class, str): a function or class
|
||||
implementing __call__ and returning a JSON-serializable object
|
||||
or a Starlette Response object. A string import path can also
|
||||
be provided (e.g., "my_module.MyClass"), in which case the
|
||||
underlying function or class will be imported dynamically in
|
||||
the worker replicas.
|
||||
*init_args (optional): the arguments to pass to the class
|
||||
initialization method. Not valid if backend_def is a function.
|
||||
ray_actor_options (optional): options to be passed into the
|
||||
@ray.remote decorator for the backend actor.
|
||||
config (dict, serve.BackendConfig, optional): configuration options
|
||||
@@ -386,9 +387,7 @@ class Client:
|
||||
ray_actor_options.update(
|
||||
override_environment_variables={"PYTHONHOME": conda_env_dir})
|
||||
replica_config = ReplicaConfig(
|
||||
func_or_class,
|
||||
*actor_init_args,
|
||||
ray_actor_options=ray_actor_options)
|
||||
backend_def, *init_args, ray_actor_options=ray_actor_options)
|
||||
metadata = BackendMetadata(
|
||||
accepts_batches=replica_config.accepts_batches,
|
||||
is_blocking=replica_config.is_blocking)
|
||||
|
||||
@@ -97,7 +97,7 @@ class BackendReplica:
|
||||
max_task_retries=-1,
|
||||
**backend_info.replica_config.ray_actor_options).remote(
|
||||
self._backend_tag, self._replica_tag,
|
||||
backend_info.replica_config.actor_init_args,
|
||||
backend_info.replica_config.init_args,
|
||||
backend_info.backend_config, self._controller_name)
|
||||
self._startup_obj_ref = self._actor_handle.ready.remote()
|
||||
self._state = ReplicaState.STARTING
|
||||
@@ -277,7 +277,7 @@ class BackendState:
|
||||
return None
|
||||
|
||||
backend_replica_class = create_backend_replica(
|
||||
replica_config.func_or_class)
|
||||
replica_config.backend_def)
|
||||
|
||||
# Save creator that starts replicas, the arguments to be passed in,
|
||||
# and the configuration for the backends.
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.actor import ActorHandle
|
||||
from ray.async_compat import sync_to_async
|
||||
|
||||
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
|
||||
unpack_future)
|
||||
unpack_future, import_attr)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.util import metrics
|
||||
from ray.serve.config import BackendConfig
|
||||
@@ -94,33 +94,40 @@ class BatchQueue:
|
||||
return batch
|
||||
|
||||
|
||||
def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||
def create_backend_replica(backend_def: Union[Callable, Type[Callable], str]):
|
||||
"""Creates a replica class wrapping the provided function or class.
|
||||
|
||||
This approach is picked over inheritance to avoid conflict between user
|
||||
provided class and the RayServeReplica class.
|
||||
"""
|
||||
|
||||
if inspect.isfunction(func_or_class):
|
||||
is_function = True
|
||||
elif inspect.isclass(func_or_class):
|
||||
is_function = False
|
||||
else:
|
||||
assert False, "func_or_class must be function or class."
|
||||
backend_def = backend_def
|
||||
|
||||
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
|
||||
class RayServeWrappedReplica(object):
|
||||
def __init__(self, backend_tag, replica_tag, init_args,
|
||||
backend_config: BackendConfig, controller_name: str):
|
||||
if isinstance(backend_def, str):
|
||||
backend = import_attr(backend_def)
|
||||
else:
|
||||
backend = backend_def
|
||||
|
||||
if inspect.isfunction(backend):
|
||||
is_function = True
|
||||
elif inspect.isclass(backend):
|
||||
is_function = False
|
||||
else:
|
||||
assert False, ("backend_def must be function, class, or "
|
||||
"corresponding import path.")
|
||||
|
||||
# Set the controller name so that serve.connect() in the user's
|
||||
# backend code will connect to the instance that this backend is
|
||||
# running in.
|
||||
ray.serve.api._set_internal_replica_context(
|
||||
backend_tag, replica_tag, controller_name)
|
||||
if is_function:
|
||||
_callable = func_or_class
|
||||
_callable = backend
|
||||
else:
|
||||
_callable = func_or_class(*init_args)
|
||||
_callable = backend(*init_args)
|
||||
|
||||
assert controller_name, "Must provide a valid controller_name"
|
||||
controller_handle = ray.get_actor(controller_name)
|
||||
@@ -144,8 +151,12 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||
async def drain_pending_queries(self):
|
||||
return await self.backend.drain_pending_queries()
|
||||
|
||||
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
|
||||
func_or_class.__name__)
|
||||
if isinstance(backend_def, str):
|
||||
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
|
||||
backend_def)
|
||||
else:
|
||||
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
|
||||
backend_def.__name__)
|
||||
return RayServeWrappedReplica
|
||||
|
||||
|
||||
@@ -415,8 +426,7 @@ class RayServeReplica:
|
||||
if user_config:
|
||||
if self.is_function:
|
||||
raise ValueError(
|
||||
"argument func_or_class must be a class to use user_config"
|
||||
)
|
||||
"backend_def must be a class to use user_config")
|
||||
elif not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD):
|
||||
raise RayServeException("user_config specified but backend " +
|
||||
self.backend_tag + " missing " +
|
||||
|
||||
+39
-29
@@ -5,22 +5,29 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel, confloat, PositiveFloat, PositiveInt, validator
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST,
|
||||
DEFAULT_HTTP_PORT)
|
||||
from ray.serve.constants import DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT
|
||||
|
||||
|
||||
def _callable_accepts_batch(func_or_class):
|
||||
if inspect.isfunction(func_or_class):
|
||||
return hasattr(func_or_class, "_serve_accept_batch")
|
||||
elif inspect.isclass(func_or_class):
|
||||
return hasattr(func_or_class.__call__, "_serve_accept_batch")
|
||||
def _callable_accepts_batch(backend_def):
|
||||
if inspect.isfunction(backend_def):
|
||||
return hasattr(backend_def, "_serve_accept_batch")
|
||||
elif inspect.isclass(backend_def):
|
||||
return hasattr(backend_def.__call__, "_serve_accept_batch")
|
||||
elif isinstance(backend_def, str):
|
||||
return True
|
||||
else:
|
||||
raise TypeError("backend_def must be function, class, or str.")
|
||||
|
||||
|
||||
def _callable_is_blocking(func_or_class):
|
||||
if inspect.isfunction(func_or_class):
|
||||
return not inspect.iscoroutinefunction(func_or_class)
|
||||
elif inspect.isclass(func_or_class):
|
||||
return not inspect.iscoroutinefunction(func_or_class.__call__)
|
||||
def _callable_is_blocking(backend_def):
|
||||
if inspect.isfunction(backend_def):
|
||||
return not inspect.iscoroutinefunction(backend_def)
|
||||
elif inspect.isclass(backend_def):
|
||||
return not inspect.iscoroutinefunction(backend_def.__call__)
|
||||
elif isinstance(backend_def, str):
|
||||
return False
|
||||
else:
|
||||
raise TypeError("backend_def must be function, class, or str.")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -105,8 +112,11 @@ class BackendConfig(BaseModel):
|
||||
# Pipeline/async mode: if the servable is not blocking,
|
||||
# router should just keep pushing queries to the replicas
|
||||
# until a high limit.
|
||||
# TODO(edoakes): setting this to a relatively low constant because
|
||||
# we can't determine if imported backends are sync or async, but we
|
||||
# may consider tweaking it in the future.
|
||||
if not values["internal_metadata"].is_blocking:
|
||||
v = ASYNC_CONCURRENCY
|
||||
v = 100
|
||||
|
||||
# Batch inference mode: user specifies non zero timeout to wait for
|
||||
# full batch. We will use 2*max_batch_size to perform double
|
||||
@@ -119,12 +129,11 @@ class BackendConfig(BaseModel):
|
||||
|
||||
|
||||
class ReplicaConfig:
|
||||
def __init__(self, func_or_class, *actor_init_args,
|
||||
ray_actor_options=None):
|
||||
self.func_or_class = func_or_class
|
||||
self.accepts_batches = _callable_accepts_batch(func_or_class)
|
||||
self.is_blocking = _callable_is_blocking(func_or_class)
|
||||
self.actor_init_args = list(actor_init_args)
|
||||
def __init__(self, backend_def, *init_args, ray_actor_options=None):
|
||||
self.backend_def = backend_def
|
||||
self.accepts_batches = _callable_accepts_batch(backend_def)
|
||||
self.is_blocking = _callable_is_blocking(backend_def)
|
||||
self.init_args = list(init_args)
|
||||
if ray_actor_options is None:
|
||||
self.ray_actor_options = {}
|
||||
else:
|
||||
@@ -134,27 +143,28 @@ class ReplicaConfig:
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
# Validate that func_or_class is a function or class.
|
||||
if inspect.isfunction(self.func_or_class):
|
||||
if len(self.actor_init_args) != 0:
|
||||
# Validate that backend_def is an import path, function, or class.
|
||||
if isinstance(self.backend_def, str):
|
||||
pass
|
||||
elif inspect.isfunction(self.backend_def):
|
||||
if len(self.init_args) != 0:
|
||||
raise ValueError(
|
||||
"actor_init_args not supported for function backend.")
|
||||
elif not inspect.isclass(self.func_or_class):
|
||||
"init_args not supported for function backend.")
|
||||
elif not inspect.isclass(self.backend_def):
|
||||
raise TypeError(
|
||||
"Backend must be a function or class, it is {}.".format(
|
||||
type(self.func_or_class)))
|
||||
type(self.backend_def)))
|
||||
|
||||
if not isinstance(self.ray_actor_options, dict):
|
||||
raise TypeError("ray_actor_options must be a dictionary.")
|
||||
elif "lifetime" in self.ray_actor_options:
|
||||
raise ValueError(
|
||||
"Specifying lifetime in actor_init_args is not allowed.")
|
||||
"Specifying lifetime in init_args is not allowed.")
|
||||
elif "name" in self.ray_actor_options:
|
||||
raise ValueError(
|
||||
"Specifying name in actor_init_args is not allowed.")
|
||||
raise ValueError("Specifying name in init_args is not allowed.")
|
||||
elif "max_restarts" in self.ray_actor_options:
|
||||
raise ValueError("Specifying max_restarts in "
|
||||
"actor_init_args is not allowed.")
|
||||
"init_args is not allowed.")
|
||||
else:
|
||||
# Ray defaults to zero CPUs for placement, we default to one here.
|
||||
if "num_cpus" not in self.ray_actor_options:
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import requests
|
||||
|
||||
from ray import serve
|
||||
from ray.serve.backends import ImportedBackend
|
||||
|
||||
client = serve.start()
|
||||
|
||||
# Include your class as input to the ImportedBackend constructor.
|
||||
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
|
||||
client.create_backend("imported", backend_class, "input_arg")
|
||||
import_path = "ray.serve.utils.MockImportedBackend"
|
||||
client.create_backend("imported", import_path, "input_arg")
|
||||
client.create_endpoint("imported", backend="imported", route="/imported")
|
||||
|
||||
print(requests.get("http://127.0.0.1:8000/imported").text)
|
||||
|
||||
@@ -16,7 +16,7 @@ pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def setup_worker(name,
|
||||
func_or_class,
|
||||
backend_def,
|
||||
init_args=None,
|
||||
backend_config=BackendConfig(),
|
||||
controller_name=""):
|
||||
@@ -26,7 +26,7 @@ def setup_worker(name,
|
||||
@ray.remote
|
||||
class WorkerActor:
|
||||
def __init__(self):
|
||||
self.worker = create_backend_replica(func_or_class)(
|
||||
self.worker = create_backend_replica(backend_def)(
|
||||
name, name + ":tag", init_args, backend_config,
|
||||
controller_name)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import pytest
|
||||
from ray import serve
|
||||
from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions,
|
||||
ReplicaConfig, BackendMetadata)
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@@ -42,7 +41,7 @@ def test_backend_config_validation():
|
||||
assert BackendConfig(
|
||||
max_batch_size=10,
|
||||
internal_metadata=BackendMetadata(
|
||||
is_blocking=False)).max_concurrent_queries == ASYNC_CONCURRENCY
|
||||
is_blocking=False)).max_concurrent_queries == 100
|
||||
assert BackendConfig(
|
||||
max_batch_size=7, batch_wait_timeout=1.0).max_concurrent_queries == 14
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import ray
|
||||
from ray.serve.backends import ImportedBackend
|
||||
from ray.serve.config import BackendConfig
|
||||
|
||||
|
||||
def test_imported_backend(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
|
||||
config = BackendConfig(user_config="config", max_batch_size=2)
|
||||
client.create_backend(
|
||||
"imported", backend_class, "input_arg", config=config)
|
||||
"imported",
|
||||
"ray.serve.utils.MockImportedBackend",
|
||||
"input_arg",
|
||||
config=config)
|
||||
client.create_endpoint("imported", backend="imported")
|
||||
|
||||
# Basic sanity check.
|
||||
@@ -27,3 +28,12 @@ def test_imported_backend(serve_instance):
|
||||
# Check that other call methods work.
|
||||
handle = handle.options(method_name="other_method")
|
||||
assert ray.get(handle.remote("hello")) == "hello"
|
||||
|
||||
# Check that functions work as well.
|
||||
client.create_backend(
|
||||
"imported_func",
|
||||
"ray.serve.utils.mock_imported_function",
|
||||
config=BackendConfig(max_batch_size=2))
|
||||
client.create_endpoint("imported_func", backend="imported_func")
|
||||
handle = client.get_handle("imported_func")
|
||||
assert ray.get(handle.remote("hello")) == "hello"
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
import ray
|
||||
from ray.serve.utils import (ServeEncoder, chain_future, unpack_future,
|
||||
try_schedule_resources_on_nodes,
|
||||
get_conda_env_dir, import_class)
|
||||
get_conda_env_dir, import_attr)
|
||||
|
||||
|
||||
def test_bytes_encoder():
|
||||
@@ -126,11 +126,11 @@ def test_get_conda_env_dir(tmp_path):
|
||||
os.environ["CONDA_PREFIX"] = ""
|
||||
|
||||
|
||||
def test_import_class():
|
||||
assert import_class("ray.serve.Client") == ray.serve.api.Client
|
||||
assert import_class("ray.serve.api.Client") == ray.serve.api.Client
|
||||
def test_import_attr():
|
||||
assert import_attr("ray.serve.Client") == ray.serve.api.Client
|
||||
assert import_attr("ray.serve.api.Client") == ray.serve.api.Client
|
||||
|
||||
policy_cls = import_class("ray.serve.controller.TrafficPolicy")
|
||||
policy_cls = import_attr("ray.serve.controller.TrafficPolicy")
|
||||
assert policy_cls == ray.serve.controller.TrafficPolicy
|
||||
|
||||
policy = policy_cls({"endpoint1": 0.5, "endpoint2": 0.5})
|
||||
@@ -140,6 +140,10 @@ def test_import_class():
|
||||
|
||||
print(repr(policy))
|
||||
|
||||
# Very meta...
|
||||
import_attr_2 = import_attr("ray.serve.utils.import_attr")
|
||||
assert import_attr_2 == import_attr
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
@@ -359,22 +359,26 @@ def get_node_id_for_actor(actor_handle):
|
||||
return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"]
|
||||
|
||||
|
||||
def import_class(full_path: str):
|
||||
"""Given a full import path to a class name, return the imported class.
|
||||
def import_attr(full_path: str):
    """Given a full import path to a module attr, return the imported attr.

    For example, the following are equivalent:
        MyClass = import_attr("module.submodule.MyClass")
        from module.submodule import MyClass

    Args:
        full_path (str): dotted path of the form "<module path>.<attr>".
            Everything before the last period is imported as a module and
            the part after it is looked up on that module.

    Returns:
        Imported attr

    Raises:
        ValueError: if full_path has no module part or no attribute part
            (e.g. "MyClass", "module." or ".MyClass").
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    # rpartition splits on the LAST period. Without a period the module part
    # comes back empty, which is rejected explicitly below; index-based
    # slicing would instead drop the final character of the name and try to
    # import an unrelated module.
    module_name, _, attr_name = full_path.rpartition(".")
    if not module_name or not attr_name:
        raise ValueError(
            "import path must look like 'module.attr', got {!r}".format(
                full_path))

    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
|
||||
|
||||
|
||||
async def mock_imported_function(batch):
    """Return the awaited ``body()`` of every request in ``batch``.

    Args:
        batch: an iterable of objects exposing an awaitable ``body()``.

    Returns:
        A list with one body per request, in the same order as ``batch``.
    """
    bodies = []
    # Bodies are awaited one at a time, preserving the batch order.
    for request in batch:
        payload = await request.body()
        bodies.append(payload)
    return bodies
|
||||
|
||||
|
||||
class MockImportedBackend:
|
||||
|
||||
Reference in New Issue
Block a user