mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:18:59 +08:00
371 lines
15 KiB
Python
371 lines
15 KiB
Python
import asyncio
|
|
import traceback
|
|
import inspect
|
|
from collections.abc import Iterable
|
|
from itertools import groupby
|
|
from typing import Union, List, Any, Callable, Type
|
|
import time
|
|
|
|
import ray
|
|
from ray.async_compat import sync_to_async
|
|
|
|
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
|
|
unpack_future)
|
|
from ray.serve.exceptions import RayServeException
|
|
from ray.util import metrics
|
|
from ray.serve.config import BackendConfig
|
|
from ray.serve.router import Query
|
|
from ray.serve.constants import DEFAULT_LATENCY_BUCKET_MS
|
|
from ray.exceptions import RayTaskError
|
|
|
|
logger = _get_logger()
|
|
|
|
|
|
class BatchQueue:
|
|
def __init__(self, max_batch_size: int, timeout_s: float) -> None:
|
|
self.queue = asyncio.Queue()
|
|
self.full_batch_event = asyncio.Event()
|
|
self.max_batch_size = max_batch_size
|
|
self.timeout_s = timeout_s
|
|
|
|
def set_config(self, max_batch_size: int, timeout_s: float) -> None:
|
|
self.max_batch_size = max_batch_size
|
|
self.timeout_s = timeout_s
|
|
|
|
def put(self, request: Query) -> None:
|
|
self.queue.put_nowait(request)
|
|
# Signal when the full batch is ready. The event will be reset
|
|
# in wait_for_batch.
|
|
if self.queue.qsize() == self.max_batch_size:
|
|
self.full_batch_event.set()
|
|
|
|
def qsize(self) -> int:
|
|
return self.queue.qsize()
|
|
|
|
async def wait_for_batch(self) -> List[Query]:
|
|
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
|
|
|
|
Returns a batch of up to self.max_batch_size items, waiting for up
|
|
to self.timeout_s for a full batch. After the timeout, returns as many
|
|
items as are ready.
|
|
|
|
Always returns a batch with at least one item - will block
|
|
indefinitely until an item comes in.
|
|
"""
|
|
curr_timeout = self.timeout_s
|
|
batch = []
|
|
while len(batch) == 0:
|
|
loop_start = time.time()
|
|
|
|
# If the timeout is 0, wait for any item to be available on the
|
|
# queue.
|
|
if curr_timeout == 0:
|
|
batch.append(await self.queue.get())
|
|
# If the timeout is nonzero, wait for either the timeout to occur
|
|
# or the max batch size to be ready.
|
|
else:
|
|
try:
|
|
await asyncio.wait_for(self.full_batch_event.wait(),
|
|
curr_timeout)
|
|
except asyncio.TimeoutError:
|
|
pass
|
|
|
|
# Pull up to the max_batch_size requests off the queue.
|
|
while len(batch) < self.max_batch_size and not self.queue.empty():
|
|
batch.append(self.queue.get_nowait())
|
|
|
|
# Reset the event if there are fewer than max_batch_size requests
|
|
# in the queue.
|
|
if (self.queue.qsize() < self.max_batch_size
|
|
and self.full_batch_event.is_set()):
|
|
self.full_batch_event.clear()
|
|
|
|
# Adjust the timeout based on the time spent in this iteration.
|
|
curr_timeout = max(0, curr_timeout - (time.time() - loop_start))
|
|
|
|
return batch
|
|
|
|
|
|
def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]):
|
|
"""Creates a worker class wrapping the provided function or class."""
|
|
|
|
if inspect.isfunction(func_or_class):
|
|
is_function = True
|
|
elif inspect.isclass(func_or_class):
|
|
is_function = False
|
|
else:
|
|
assert False, "func_or_class must be function or class."
|
|
|
|
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
|
|
class RayServeWrappedWorker(object):
|
|
def __init__(self, backend_tag, replica_tag, init_args,
|
|
backend_config: BackendConfig, controller_name: str):
|
|
# Set the controller name so that serve.connect() will connect to
|
|
# the instance that this backend is running in.
|
|
ray.serve.api._set_internal_controller_name(controller_name)
|
|
if is_function:
|
|
_callable = func_or_class
|
|
else:
|
|
_callable = func_or_class(*init_args)
|
|
|
|
self.backend = RayServeWorker(backend_tag, replica_tag, _callable,
|
|
backend_config, is_function)
|
|
|
|
async def handle_request(self, request):
|
|
return await self.backend.handle_request(request)
|
|
|
|
def update_config(self, new_config: BackendConfig):
|
|
return self.backend.update_config(new_config)
|
|
|
|
def ready(self):
|
|
pass
|
|
|
|
RayServeWrappedWorker.__name__ = "RayServeWorker_" + func_or_class.__name__
|
|
return RayServeWrappedWorker
|
|
|
|
|
|
def wrap_to_ray_error(exception: Exception) -> RayTaskError:
|
|
"""Utility method to wrap exceptions in user code."""
|
|
|
|
try:
|
|
# Raise and catch so we can access traceback.format_exc()
|
|
raise exception
|
|
except Exception as e:
|
|
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
|
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
|
|
|
|
|
|
def ensure_async(func: Callable) -> Callable:
|
|
return sync_to_async(func)
|
|
|
|
|
|
class RayServeWorker:
|
|
"""Handles requests with the provided callable."""
|
|
|
|
def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable,
|
|
backend_config: BackendConfig, is_function: bool) -> None:
|
|
self.backend_tag = backend_tag
|
|
self.replica_tag = replica_tag
|
|
self.callable = _callable
|
|
self.is_function = is_function
|
|
|
|
self.config = backend_config
|
|
self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
|
|
self.config.batch_wait_timeout)
|
|
|
|
self.num_ongoing_requests = 0
|
|
|
|
self.request_counter = metrics.Count(
|
|
"backend_request_counter",
|
|
description=("Number of queries that have been "
|
|
"processed in this replica"),
|
|
tag_keys=("backend", ))
|
|
self.request_counter.set_default_tags({"backend": self.backend_tag})
|
|
|
|
self.error_counter = metrics.Count(
|
|
"backend_error_counter",
|
|
description=("Number of exceptions that have "
|
|
"occurred in the backend"),
|
|
tag_keys=("backend", ))
|
|
self.error_counter.set_default_tags({"backend": self.backend_tag})
|
|
|
|
self.restart_counter = metrics.Count(
|
|
"backend_worker_starts",
|
|
description=("The number of time this replica workers "
|
|
"has been restarted due to failure."),
|
|
tag_keys=("backend", "replica_tag"))
|
|
self.restart_counter.set_default_tags({
|
|
"backend": self.backend_tag,
|
|
"replica_tag": self.replica_tag
|
|
})
|
|
|
|
self.queuing_latency_tracker = metrics.Histogram(
|
|
"backend_queuing_latency_ms",
|
|
description=(
|
|
"The latency for queries waiting in the replica's queue "
|
|
"waiting to be processed or batched."),
|
|
boundaries=DEFAULT_LATENCY_BUCKET_MS,
|
|
tag_keys=("backend", "replica_tag"))
|
|
self.queuing_latency_tracker.set_default_tags({
|
|
"backend": self.backend_tag,
|
|
"replica_tag": self.replica_tag
|
|
})
|
|
|
|
self.processing_latency_tracker = metrics.Histogram(
|
|
"backend_processing_latency_ms",
|
|
description="The latency for queries to be processed",
|
|
boundaries=DEFAULT_LATENCY_BUCKET_MS,
|
|
tag_keys=("backend", "replica_tag", "batch_size"))
|
|
self.processing_latency_tracker.set_default_tags({
|
|
"backend": self.backend_tag,
|
|
"replica_tag": self.replica_tag
|
|
})
|
|
|
|
self.num_queued_items = metrics.Gauge(
|
|
"replica_queued_queries",
|
|
description=("Current number of queries queued in the "
|
|
"the backend replicas"),
|
|
tag_keys=("backend", "replica_tag"))
|
|
self.num_queued_items.set_default_tags({
|
|
"backend": self.backend_tag,
|
|
"replica_tag": self.replica_tag
|
|
})
|
|
|
|
self.num_processing_items = metrics.Gauge(
|
|
"replica_processing_queries",
|
|
description="Current number of queries being processed",
|
|
tag_keys=("backend", "replica_tag"))
|
|
self.num_processing_items.set_default_tags({
|
|
"backend": self.backend_tag,
|
|
"replica_tag": self.replica_tag
|
|
})
|
|
|
|
self.restart_counter.record(1)
|
|
|
|
asyncio.get_event_loop().create_task(self.main_loop())
|
|
|
|
def get_runner_method(self, request_item: Query) -> Callable:
|
|
method_name = request_item.metadata.call_method
|
|
if not hasattr(self.callable, method_name):
|
|
raise RayServeException("Backend doesn't have method {} "
|
|
"which is specified in the request. "
|
|
"The available methods are {}".format(
|
|
method_name, dir(self.callable)))
|
|
if self.is_function:
|
|
return self.callable
|
|
return getattr(self.callable, method_name)
|
|
|
|
async def invoke_single(self, request_item: Query) -> Any:
|
|
method_to_call = ensure_async(self.get_runner_method(request_item))
|
|
arg = parse_request_item(request_item)
|
|
|
|
start = time.time()
|
|
try:
|
|
result = await method_to_call(arg)
|
|
self.request_counter.record(1)
|
|
except Exception as e:
|
|
import os
|
|
if "RAY_PDB" in os.environ:
|
|
ray.util.pdb.post_mortem()
|
|
result = wrap_to_ray_error(e)
|
|
self.error_counter.record(1)
|
|
|
|
self.processing_latency_tracker.record(
|
|
(time.time() - start) * 1000, tags={"batch_size": "1"})
|
|
|
|
return result
|
|
|
|
async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]:
|
|
args = []
|
|
call_methods = set()
|
|
batch_size = len(request_item_list)
|
|
|
|
# Construct the batch of requests
|
|
for item in request_item_list:
|
|
args.append(parse_request_item(item))
|
|
call_methods.add(self.get_runner_method(item))
|
|
|
|
timing_start = time.time()
|
|
try:
|
|
if len(call_methods) != 1:
|
|
raise RayServeException(
|
|
f"Queries contain mixed calling methods: {call_methods}. "
|
|
"Please only send the same type of requests in batching "
|
|
"mode.")
|
|
|
|
self.request_counter.record(batch_size)
|
|
|
|
call_method = ensure_async(call_methods.pop())
|
|
result_list = await call_method(args)
|
|
|
|
if not isinstance(result_list, Iterable) or isinstance(
|
|
result_list, (dict, set)):
|
|
error_message = ("RayServe expects an ordered iterable object "
|
|
"but the worker returned a {}".format(
|
|
type(result_list)))
|
|
raise RayServeException(error_message)
|
|
|
|
# Normalize the result into a list type. This operation is fast
|
|
# in Python because it doesn't copy anything.
|
|
result_list = list(result_list)
|
|
|
|
if (len(result_list) != batch_size):
|
|
error_message = ("Worker doesn't preserve batch size. The "
|
|
"input has length {} but the returned list "
|
|
"has length {}. Please return a list of "
|
|
"results with length equal to the batch size"
|
|
".".format(batch_size, len(result_list)))
|
|
raise RayServeException(error_message)
|
|
except Exception as e:
|
|
wrapped_exception = wrap_to_ray_error(e)
|
|
self.error_counter.record(1)
|
|
result_list = [wrapped_exception for _ in range(batch_size)]
|
|
|
|
self.processing_latency_tracker.record(
|
|
(time.time() - timing_start) * 1000,
|
|
tags={"batch_size": str(batch_size)})
|
|
|
|
return result_list
|
|
|
|
async def main_loop(self) -> None:
|
|
while True:
|
|
# NOTE(simon): There's an issue when user updated batch size and
|
|
# batch wait timeout during the execution, these values will not be
|
|
# updated until after the current iteration.
|
|
batch = await self.batch_queue.wait_for_batch()
|
|
|
|
# Record metrics
|
|
self.num_queued_items.record(self.batch_queue.qsize())
|
|
self.num_processing_items.record(self.num_ongoing_requests -
|
|
self.batch_queue.qsize())
|
|
for query in batch:
|
|
queuing_time = (time.time() - query.tick_enter_replica) * 1000
|
|
self.queuing_latency_tracker.record(queuing_time)
|
|
|
|
all_evaluated_futures = []
|
|
|
|
if not self.config.internal_metadata.accepts_batches:
|
|
query = batch[0]
|
|
evaluated = asyncio.ensure_future(self.invoke_single(query))
|
|
all_evaluated_futures = [evaluated]
|
|
chain_future(evaluated, query.async_future)
|
|
else:
|
|
get_call_method = (
|
|
lambda query: query.metadata.call_method # noqa: E731
|
|
)
|
|
sorted_batch = sorted(batch, key=get_call_method)
|
|
for _, group in groupby(sorted_batch, key=get_call_method):
|
|
group = list(group)
|
|
evaluated = asyncio.ensure_future(self.invoke_batch(group))
|
|
all_evaluated_futures.append(evaluated)
|
|
result_futures = [q.async_future for q in group]
|
|
chain_future(
|
|
unpack_future(evaluated, len(group)), result_futures)
|
|
|
|
if self.config.internal_metadata.is_blocking:
|
|
# We use asyncio.wait here so if the result is exception,
|
|
# it will not be raised.
|
|
await asyncio.wait(all_evaluated_futures)
|
|
|
|
def update_config(self, new_config: BackendConfig) -> None:
|
|
self.config = new_config
|
|
self.batch_queue.set_config(self.config.max_batch_size or 1,
|
|
self.config.batch_wait_timeout)
|
|
|
|
async def handle_request(self,
|
|
request: Union[Query, bytes]) -> asyncio.Future:
|
|
if isinstance(request, bytes):
|
|
request = Query.ray_deserialize(request)
|
|
|
|
request.tick_enter_replica = time.time()
|
|
logger.debug("Worker {} got request {}".format(self.replica_tag,
|
|
request))
|
|
request.async_future = asyncio.get_event_loop().create_future()
|
|
self.num_ongoing_requests += 1
|
|
|
|
self.batch_queue.put(request)
|
|
result = await request.async_future
|
|
|
|
self.num_ongoing_requests -= 1
|
|
return result
|