[Serve] Reconfigure backend class at runtime (#11709)

This commit is contained in:
architkulkarni
2020-11-09 12:04:51 -08:00
committed by GitHub
parent 287aba6dc3
commit adcaabcd64
8 changed files with 127 additions and 4 deletions
+8
View File
@@ -176,6 +176,9 @@ class Client:
- "max_concurrent_queries": the maximum number of queries
that will be sent to a replica of this backend
without receiving a response.
- "user_config" (experimental): Arguments to pass to the
reconfigure method of the backend. The reconfigure method is
called if "user_config" is not None.
"""
if not isinstance(config_options, (BackendConfig, dict)):
@@ -228,6 +231,9 @@ class Client:
- "max_concurrent_queries": the maximum number of queries that
will be sent to a replica of this backend without receiving a
response.
- "user_config" (experimental): Arguments to pass to the
reconfigure method of the backend. The reconfigure method is
called if "user_config" is not None.
env (serve.CondaEnv, optional): conda environment to run this
backend in. Requires the caller to be running in an activated
conda environment (not necessarily ``env``), and requires
@@ -263,6 +269,7 @@ class Client:
metadata = BackendMetadata(
accepts_batches=replica_config.accepts_batches,
is_blocking=replica_config.is_blocking)
if isinstance(config, dict):
backend_config = BackendConfig.parse_obj({
**config, "internal_metadata": metadata
@@ -272,6 +279,7 @@ class Client:
update={"internal_metadata": metadata})
else:
raise TypeError("config must be a BackendConfig or a dictionary.")
backend_config._validate_complete()
ray.get(
self._controller.create_backend.remote(backend_tag, backend_config,
+18 -1
View File
@@ -15,7 +15,8 @@ from ray.serve.exceptions import RayServeException
from ray.util import metrics
from ray.serve.config import BackendConfig
from ray.serve.router import Query
from ray.serve.constants import DEFAULT_LATENCY_BUCKET_MS
from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS,
BACKEND_RECONFIGURE_METHOD)
from ray.exceptions import RayTaskError
logger = _get_logger()
@@ -152,6 +153,7 @@ class RayServeWorker:
self.config = backend_config
self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
self.reconfigure(self.config.user_config)
self.num_ongoing_requests = 0
@@ -347,10 +349,25 @@ class RayServeWorker:
# it will not be raised.
await asyncio.wait(all_evaluated_futures)
def reconfigure(self, user_config) -> None:
    """Forward ``user_config`` to the backend's reconfigure method.

    Nothing happens when ``user_config`` is None. Any other value,
    including falsy ones such as ``{}``, ``0`` or ``""``, is passed as the
    single argument to the method named by BACKEND_RECONFIGURE_METHOD on
    ``self.callable``.

    Args:
        user_config: Value to hand to the backend, or None to skip
            reconfiguration.

    Raises:
        ValueError: If ``self.is_function`` is true (the backend is not a
            class) and a user_config was given.
        RayServeException: If ``self.callable`` has no attribute named
            BACKEND_RECONFIGURE_METHOD.
    """
    # Compare against None explicitly rather than relying on truthiness:
    # the config docs state that reconfigure is called whenever
    # user_config is not None, so an empty dict must not be dropped.
    if user_config is None:
        return

    if self.is_function:
        raise ValueError(
            "argument func_or_class must be a class to use user_config")

    if not hasattr(self.callable, BACKEND_RECONFIGURE_METHOD):
        raise RayServeException(
            f"user_config specified but backend {self.backend_tag} "
            f"missing {BACKEND_RECONFIGURE_METHOD} method")

    reconfigure_method = getattr(self.callable, BACKEND_RECONFIGURE_METHOD)
    reconfigure_method(user_config)
def update_config(self, new_config: BackendConfig) -> None:
    """Install ``new_config`` and propagate its settings.

    The config is stored on the worker, the batch queue is given the new
    batch size and wait timeout, and the user_config is then forwarded
    to ``self.reconfigure``.
    """
    self.config = new_config

    # A missing (falsy) max_batch_size falls back to a batch size of 1.
    batch_size = self.config.max_batch_size or 1
    wait_timeout = self.config.batch_wait_timeout
    self.batch_queue.set_config(batch_size, wait_timeout)

    self.reconfigure(self.config.user_config)
async def handle_request(self,
request: Union[Query, bytes]) -> asyncio.Future:
+5
View File
@@ -45,6 +45,10 @@ class BackendConfig(BaseModel):
sent to a replica of this backend without receiving a response.
Defaults to None (no maximum).
:type max_concurrent_queries: int, optional
:param user_config: Arguments to pass to the reconfigure method of the
backend. The reconfigure method is called if user_config is not
None.
:type user_config: Any, optional
"""
internal_metadata: BackendMetadata = BackendMetadata()
@@ -52,6 +56,7 @@ class BackendConfig(BaseModel):
max_batch_size: Optional[PositiveInt] = None
batch_wait_timeout: float = 0
max_concurrent_queries: Optional[int] = None
user_config: Any = None
class Config:
validate_assignment = True
+3
View File
@@ -34,3 +34,6 @@ DEFAULT_LATENCY_BUCKET_MS = [
2000,
5000,
]
#: Name of backend reconfiguration method implemented by user.
BACKEND_RECONFIGURE_METHOD = "reconfigure"
+3 -1
View File
@@ -138,7 +138,9 @@ class ActorStateReconciler:
return_list = []
for replica_tag in self.replicas.get(backend_tag, []):
try:
return_list.append(ray.get_actor(replica_tag))
replica_name = format_actor_name(replica_tag,
self.controller_name)
return_list.append(ray.get_actor(replica_name))
except ValueError:
pass
return return_list
@@ -0,0 +1,33 @@
import requests
import random
import ray
from ray import serve
from ray.serve import BackendConfig
# Initialize Ray, then start Serve; `client` is used below to create and
# update the backend and its endpoint.
ray.init()
client = serve.start()
class Threshold:
    """Toy backend: reports whether a random draw exceeds a threshold."""

    def __init__(self):
        # Stand-in for some heavyweight model. reconfigure() never touches
        # it, so it is left unchanged by user_config updates.
        self.model = random.Random()

    def reconfigure(self, config):
        """Adopt the "threshold" entry of ``config``.

        Called when the class is created and whenever the user_config
        field of BackendConfig is updated.
        """
        new_threshold = config["threshold"]
        self.threshold = new_threshold

    def __call__(self, request):
        """Return True when a fresh random sample is above the threshold."""
        sample = self.model.random()
        return sample > self.threshold
# Create the backend with an initial user_config; per the comments on
# Threshold.reconfigure, this config is applied when the class is created.
backend_config = BackendConfig(user_config={"threshold": 0.01})
client.create_backend("threshold", Threshold, config=backend_config)
client.create_endpoint("threshold", backend="threshold", route="/threshold")
# With a 0.01 threshold, almost every random draw is above it.
print(requests.get("http://127.0.0.1:8000/threshold").text) # true, probably
# Change only the user_config of the existing backend; the new threshold is
# delivered through reconfigure().
backend_config = BackendConfig(user_config={"threshold": 0.99})
client.update_backend_config("threshold", backend_config)
# With a 0.99 threshold, almost every random draw is below it.
print(requests.get("http://127.0.0.1:8000/threshold").text) # false, probably
+37 -1
View File
@@ -1,7 +1,7 @@
import asyncio
from collections import defaultdict
import time
import os
import pytest
import requests
@@ -48,6 +48,42 @@ def test_e2e(serve_instance):
assert resp == "POST"
def test_backend_user_config(serve_instance):
    """user_config is applied at creation and again on later updates."""
    client = serve_instance

    class Counter:
        def __init__(self):
            self.count = 10

        def __call__(self, flask_request):
            return self.count, os.getpid()

        def reconfigure(self, config):
            self.count = config["count"]

    initial_config = BackendConfig(
        num_replicas=2, user_config={"count": 123, "b": 2})
    client.create_backend("counter", Counter, config=initial_config)
    client.create_endpoint("counter", backend="counter", route="/counter")
    handle = client.get_handle("counter")

    def check(val, num_replicas):
        # Issue 100 requests: every reply must carry the expected count,
        # and the number of distinct PIDs must equal the replica count.
        pids_seen = set()
        for _ in range(100):
            reply = ray.get(handle.remote())
            assert (str(reply[0]) == val), reply[0]
            pids_seen.add(reply[1])
        assert (len(pids_seen) == num_replicas)

    # The count comes from user_config, not the constructor default of 10.
    check("123", 2)

    # Scaling up without a new user_config keeps the configured count.
    scaled_config = BackendConfig(num_replicas=3)
    client.update_backend_config("counter", scaled_config)
    check("123", 3)

    # A new user_config is visible on all three replicas.
    updated_config = BackendConfig(user_config={"count": 456})
    client.update_backend_config("counter", updated_config)
    check("456", 3)
def test_call_method(serve_instance):
client = serve_instance