mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:16:06 +08:00
[Serve] Reconfigure backend class at runtime (#11709)
This commit is contained in:
@@ -176,6 +176,9 @@ class Client:
|
||||
- "max_concurrent_queries": the maximum number of queries
|
||||
that will be sent to a replica of this backend
|
||||
without receiving a response.
|
||||
- "user_config" (experimental): Arguments to pass to the
|
||||
reconfigure method of the backend. The reconfigure method is
|
||||
called if "user_config" is not None.
|
||||
"""
|
||||
|
||||
if not isinstance(config_options, (BackendConfig, dict)):
|
||||
@@ -228,6 +231,9 @@ class Client:
|
||||
- "max_concurrent_queries": the maximum number of queries that
|
||||
will be sent to a replica of this backend without receiving a
|
||||
response.
|
||||
- "user_config" (experimental): Arguments to pass to the
|
||||
reconfigure method of the backend. The reconfigure method is
|
||||
called if "user_config" is not None.
|
||||
env (serve.CondaEnv, optional): conda environment to run this
|
||||
backend in. Requires the caller to be running in an activated
|
||||
conda environment (not necessarily ``env``), and requires
|
||||
@@ -263,6 +269,7 @@ class Client:
|
||||
metadata = BackendMetadata(
|
||||
accepts_batches=replica_config.accepts_batches,
|
||||
is_blocking=replica_config.is_blocking)
|
||||
|
||||
if isinstance(config, dict):
|
||||
backend_config = BackendConfig.parse_obj({
|
||||
**config, "internal_metadata": metadata
|
||||
@@ -272,6 +279,7 @@ class Client:
|
||||
update={"internal_metadata": metadata})
|
||||
else:
|
||||
raise TypeError("config must be a BackendConfig or a dictionary.")
|
||||
|
||||
backend_config._validate_complete()
|
||||
ray.get(
|
||||
self._controller.create_backend.remote(backend_tag, backend_config,
|
||||
|
||||
@@ -15,7 +15,8 @@ from ray.serve.exceptions import RayServeException
|
||||
from ray.util import metrics
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.router import Query
|
||||
from ray.serve.constants import DEFAULT_LATENCY_BUCKET_MS
|
||||
from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS,
|
||||
BACKEND_RECONFIGURE_METHOD)
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
logger = _get_logger()
|
||||
@@ -152,6 +153,7 @@ class RayServeWorker:
|
||||
self.config = backend_config
|
||||
self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
|
||||
self.config.batch_wait_timeout)
|
||||
self.reconfigure(self.config.user_config)
|
||||
|
||||
self.num_ongoing_requests = 0
|
||||
|
||||
@@ -347,10 +349,25 @@ class RayServeWorker:
|
||||
# it will not be raised.
|
||||
await asyncio.wait(all_evaluated_futures)
|
||||
|
||||
def reconfigure(self, user_config) -> None:
    """Pass ``user_config`` to the backend's reconfigure method.

    Nothing happens when ``user_config`` is None. Any other value,
    including falsy ones such as ``{}`` or ``0``, is forwarded to the
    method named by ``BACKEND_RECONFIGURE_METHOD`` on ``self.callable``.

    Args:
        user_config: Arguments for the backend's reconfigure method, or
            None to skip reconfiguration.

    Raises:
        ValueError: If a user_config is given but the backend is a
            function rather than a class.
        RayServeException: If a user_config is given but the backend
            class does not define the reconfigure method.
    """
    # Compare against None explicitly: the documented contract is that
    # reconfigure runs whenever user_config "is not None", so an empty
    # dict or a zero must not be silently dropped.
    if user_config is None:
        return

    if self.is_function:
        raise ValueError(
            "argument func_or_class must be a class to use user_config")

    if not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD):
        raise RayServeException("user_config specified but backend " +
                                self.backend_tag + " missing " +
                                BACKEND_RECONFIGURE_METHOD + " method")

    reconfigure_method = getattr(self.callable, BACKEND_RECONFIGURE_METHOD)
    reconfigure_method(user_config)
|
||||
|
||||
def update_config(self, new_config: BackendConfig) -> None:
    """Install ``new_config`` and apply it to this worker.

    Stores the config, resizes the batch queue to match it, and then
    hands its ``user_config`` to ``reconfigure``.
    """
    self.config = new_config
    active = self.config
    # A missing max_batch_size means "no batching", i.e. batches of one.
    batch_size = active.max_batch_size or 1
    self.batch_queue.set_config(batch_size, active.batch_wait_timeout)
    self.reconfigure(active.user_config)
|
||||
|
||||
async def handle_request(self,
|
||||
request: Union[Query, bytes]) -> asyncio.Future:
|
||||
|
||||
@@ -45,6 +45,10 @@ class BackendConfig(BaseModel):
|
||||
sent to a replica of this backend without receiving a response.
|
||||
Defaults to None (no maximum).
|
||||
:type max_concurrent_queries: int, optional
|
||||
:param user_config: Arguments to pass to the reconfigure method of the
|
||||
backend. The reconfigure method is called if user_config is not
|
||||
None.
|
||||
:type user_config: Any, optional
|
||||
"""
|
||||
|
||||
internal_metadata: BackendMetadata = BackendMetadata()
|
||||
@@ -52,6 +56,7 @@ class BackendConfig(BaseModel):
|
||||
max_batch_size: Optional[PositiveInt] = None
|
||||
batch_wait_timeout: float = 0
|
||||
max_concurrent_queries: Optional[int] = None
|
||||
user_config: Any = None
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
|
||||
@@ -34,3 +34,6 @@ DEFAULT_LATENCY_BUCKET_MS = [
|
||||
2000,
|
||||
5000,
|
||||
]
|
||||
|
||||
#: Name of backend reconfiguration method implemented by user.
|
||||
BACKEND_RECONFIGURE_METHOD = "reconfigure"
|
||||
|
||||
@@ -138,7 +138,9 @@ class ActorStateReconciler:
|
||||
return_list = []
|
||||
for replica_tag in self.replicas.get(backend_tag, []):
|
||||
try:
|
||||
return_list.append(ray.get_actor(replica_tag))
|
||||
replica_name = format_actor_name(replica_tag,
|
||||
self.controller_name)
|
||||
return_list.append(ray.get_actor(replica_name))
|
||||
except ValueError:
|
||||
pass
|
||||
return return_list
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import requests
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve import BackendConfig
|
||||
|
||||
ray.init()
|
||||
client = serve.start()
|
||||
|
||||
|
||||
class Threshold:
    """Example backend whose decision threshold can change at runtime."""

    def __init__(self):
        # Stand-in for a heavyweight model; reconfigure never replaces it.
        self.model = random.Random()

    def reconfigure(self, config):
        """Store the threshold found under the "threshold" key of config.

        Invoked when the class is created and again whenever the
        user_config field of BackendConfig is updated.
        """
        self.threshold = config["threshold"]

    def __call__(self, request):
        """Return True when a fresh random sample exceeds the threshold."""
        sample = self.model.random()
        return sample > self.threshold
|
||||
|
||||
|
||||
backend_config = BackendConfig(user_config={"threshold": 0.01})
|
||||
client.create_backend("threshold", Threshold, config=backend_config)
|
||||
client.create_endpoint("threshold", backend="threshold", route="/threshold")
|
||||
print(requests.get("http://127.0.0.1:8000/threshold").text) # true, probably
|
||||
|
||||
backend_config = BackendConfig(user_config={"threshold": 0.99})
|
||||
client.update_backend_config("threshold", backend_config)
|
||||
print(requests.get("http://127.0.0.1:8000/threshold").text) # false, probably
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import time
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
@@ -48,6 +48,42 @@ def test_e2e(serve_instance):
|
||||
assert resp == "POST"
|
||||
|
||||
|
||||
def test_backend_user_config(serve_instance):
    """Check that user_config reaches every replica and can be updated.

    Covers three steps: the initial user_config is applied on creation,
    it is still in effect after scaling the replica count, and a new
    user_config is applied to the running replicas.
    """
    client = serve_instance

    class Counter:
        def __init__(self):
            self.count = 10

        def reconfigure(self, config):
            self.count = config["count"]

        def __call__(self, flask_request):
            return self.count, os.getpid()

    initial_config = BackendConfig(
        num_replicas=2, user_config={"count": 123, "b": 2})
    client.create_backend("counter", Counter, config=initial_config)
    client.create_endpoint("counter", backend="counter", route="/counter")
    handle = client.get_handle("counter")

    def check(val, num_replicas):
        # Send enough requests that every replica is expected to answer
        # at least once, then count the distinct worker processes seen.
        pids_seen = set()
        for _ in range(100):
            reply = ray.get(handle.remote())
            count = reply[0]
            assert str(count) == val, count
            pids_seen.add(reply[1])
        assert len(pids_seen) == num_replicas

    check("123", 2)

    # Scaling up must leave the configured count unchanged.
    client.update_backend_config("counter", BackendConfig(num_replicas=3))
    check("123", 3)

    # A new user_config must reach all three replicas.
    updated_config = BackendConfig(user_config={"count": 456})
    client.update_backend_config("counter", updated_config)
    check("456", 3)
|
||||
|
||||
|
||||
def test_call_method(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
|
||||
Reference in New Issue
Block a user