mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:53:32 +08:00
[Serve] add type hints for controller and backend_worker (#10288)
This commit is contained in:
@@ -5,7 +5,7 @@ from collections.abc import Iterable
|
||||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
from operator import attrgetter
|
||||
from typing import Union
|
||||
from typing import Union, List, Any, Callable, Type
|
||||
import time
|
||||
|
||||
import ray
|
||||
@@ -20,32 +20,33 @@ from ray.serve.exceptions import RayServeException
|
||||
from ray.experimental import metrics
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.router import Query
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
logger = _get_logger()
|
||||
|
||||
|
||||
class BatchQueue:
|
||||
def __init__(self, max_batch_size, timeout_s):
|
||||
def __init__(self, max_batch_size: int, timeout_s: float) -> None:
|
||||
self.queue = asyncio.Queue()
|
||||
self.full_batch_event = asyncio.Event()
|
||||
self.max_batch_size = max_batch_size
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def set_config(self, max_batch_size, timeout_s):
|
||||
def set_config(self, max_batch_size: int, timeout_s: float) -> None:
|
||||
self.max_batch_size = max_batch_size
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def put(self, request):
|
||||
def put(self, request: Query) -> None:
|
||||
self.queue.put_nowait(request)
|
||||
# Signal when the full batch is ready. The event will be reset
|
||||
# in wait_for_batch.
|
||||
if self.queue.qsize() == self.max_batch_size:
|
||||
self.full_batch_event.set()
|
||||
|
||||
def qsize(self):
|
||||
def qsize(self) -> int:
|
||||
return self.queue.qsize()
|
||||
|
||||
async def wait_for_batch(self):
|
||||
async def wait_for_batch(self) -> List[Query]:
|
||||
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
|
||||
|
||||
Returns a batch of up to self.max_batch_size items, waiting for up
|
||||
@@ -89,7 +90,7 @@ class BatchQueue:
|
||||
return batch
|
||||
|
||||
|
||||
def create_backend_worker(func_or_class):
|
||||
def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]):
|
||||
"""Creates a worker class wrapping the provided function or class."""
|
||||
|
||||
if inspect.isfunction(func_or_class):
|
||||
@@ -99,6 +100,7 @@ def create_backend_worker(func_or_class):
|
||||
else:
|
||||
assert False, "func_or_class must be function or class."
|
||||
|
||||
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
|
||||
class RayServeWrappedWorker(object):
|
||||
def __init__(self,
|
||||
backend_tag,
|
||||
@@ -129,7 +131,7 @@ def create_backend_worker(func_or_class):
|
||||
return RayServeWrappedWorker
|
||||
|
||||
|
||||
def wrap_to_ray_error(exception):
|
||||
def wrap_to_ray_error(exception: Exception) -> RayTaskError:
|
||||
"""Utility method to wrap exceptions in user code."""
|
||||
|
||||
try:
|
||||
@@ -140,7 +142,7 @@ def wrap_to_ray_error(exception):
|
||||
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
|
||||
|
||||
|
||||
def ensure_async(func):
|
||||
def ensure_async(func: Callable) -> Callable:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return func
|
||||
else:
|
||||
@@ -150,8 +152,8 @@ def ensure_async(func):
|
||||
class RayServeWorker:
|
||||
"""Handles requests with the provided callable."""
|
||||
|
||||
def __init__(self, backend_tag, replica_tag, _callable,
|
||||
backend_config: BackendConfig, is_function):
|
||||
def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable,
|
||||
backend_config: BackendConfig, is_function: bool) -> None:
|
||||
self.backend_tag = backend_tag
|
||||
self.replica_tag = replica_tag
|
||||
self.callable = _callable
|
||||
@@ -182,7 +184,7 @@ class RayServeWorker:
|
||||
|
||||
asyncio.get_event_loop().create_task(self.main_loop())
|
||||
|
||||
def get_runner_method(self, request_item):
|
||||
def get_runner_method(self, request_item: Query) -> Callable:
|
||||
method_name = request_item.call_method
|
||||
if not hasattr(self.callable, method_name):
|
||||
raise RayServeException("Backend doesn't have method {} "
|
||||
@@ -193,7 +195,7 @@ class RayServeWorker:
|
||||
return self.callable
|
||||
return getattr(self.callable, method_name)
|
||||
|
||||
def has_positional_args(self, f):
|
||||
def has_positional_args(self, f: Callable) -> bool:
|
||||
# NOTE:
|
||||
# In the case of simple functions, not actors, the f will be
|
||||
# function.__call__, but we need to inspect the function itself.
|
||||
@@ -207,13 +209,13 @@ class RayServeWorker:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _reset_context(self):
|
||||
def _reset_context(self) -> None:
|
||||
# NOTE(simon): context management won't work in async mode because
|
||||
# many concurrent queries might be running at the same time.
|
||||
serve_context.web = None
|
||||
serve_context.batch_size = None
|
||||
|
||||
async def invoke_single(self, request_item):
|
||||
async def invoke_single(self, request_item: Query) -> Any:
|
||||
args, kwargs, is_web_context = parse_request_item(request_item)
|
||||
serve_context.web = is_web_context
|
||||
|
||||
@@ -231,7 +233,7 @@ class RayServeWorker:
|
||||
|
||||
return result
|
||||
|
||||
async def invoke_batch(self, request_item_list):
|
||||
async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]:
|
||||
arg_list = []
|
||||
kwargs_list = defaultdict(list)
|
||||
context_flags = set()
|
||||
@@ -308,7 +310,7 @@ class RayServeWorker:
|
||||
self._reset_context()
|
||||
return [wrapped_exception for _ in range(batch_size)]
|
||||
|
||||
async def main_loop(self):
|
||||
async def main_loop(self) -> None:
|
||||
while True:
|
||||
# NOTE(simon): There's an issue when user updated batch size and
|
||||
# batch wait timeout during the execution, these values will not be
|
||||
@@ -338,12 +340,13 @@ class RayServeWorker:
|
||||
# it will not be raised.
|
||||
await asyncio.wait(all_evaluated_futures)
|
||||
|
||||
def update_config(self, new_config: BackendConfig):
|
||||
def update_config(self, new_config: BackendConfig) -> None:
|
||||
self.config = new_config
|
||||
self.batch_queue.set_config(self.config.max_batch_size or 1,
|
||||
self.config.batch_wait_timeout)
|
||||
|
||||
async def handle_request(self, request: Union[Query, bytes]):
|
||||
async def handle_request(self,
|
||||
request: Union[Query, bytes]) -> asyncio.Future:
|
||||
if isinstance(request, bytes):
|
||||
request = Query.ray_deserialize(request)
|
||||
logger.debug("Worker {} got request {}".format(self.replica_tag,
|
||||
|
||||
@@ -14,6 +14,9 @@ from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Dict, List, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -31,12 +34,12 @@ CONTROL_LOOP_PERIOD_S = 1.0
|
||||
|
||||
|
||||
class TrafficPolicy:
|
||||
def __init__(self, traffic_dict):
|
||||
def __init__(self, traffic_dict: Dict[str, float]) -> None:
|
||||
self.traffic_dict = dict()
|
||||
self.shadow_dict = dict()
|
||||
self.set_traffic_dict(traffic_dict)
|
||||
|
||||
def set_traffic_dict(self, traffic_dict):
|
||||
def set_traffic_dict(self, traffic_dict: Dict[str, float]) -> None:
|
||||
prob = 0
|
||||
for backend, weight in traffic_dict.items():
|
||||
if weight < 0:
|
||||
@@ -52,7 +55,7 @@ class TrafficPolicy:
|
||||
"currently they sum to {}".format(prob))
|
||||
self.traffic_dict = traffic_dict
|
||||
|
||||
def set_shadow(self, backend, proportion):
|
||||
def set_shadow(self, backend: str, proportion: float):
|
||||
if proportion == 0 and backend in self.shadow_dict:
|
||||
del self.shadow_dict[backend]
|
||||
else:
|
||||
@@ -89,8 +92,8 @@ class ServeController:
|
||||
requires all implementations here to be idempotent.
|
||||
"""
|
||||
|
||||
async def __init__(self, instance_name, http_host, http_port,
|
||||
_http_middlewares):
|
||||
async def __init__(self, instance_name: str, http_host: str,
|
||||
http_port: str, _http_middlewares: List[Any]) -> None:
|
||||
# Unique name of the serve instance managed by this actor. Used to
|
||||
# namespace child actors and checkpoints.
|
||||
self.instance_name = instance_name
|
||||
@@ -160,7 +163,7 @@ class ServeController:
|
||||
|
||||
asyncio.get_event_loop().create_task(self.run_control_loop())
|
||||
|
||||
def _start_routers_if_needed(self):
|
||||
def _start_routers_if_needed(self) -> None:
|
||||
"""Start a router on every node if it doesn't already exist."""
|
||||
for node_id, node_resource in get_all_node_ids():
|
||||
if node_id in self.routers:
|
||||
@@ -192,7 +195,7 @@ class ServeController:
|
||||
|
||||
self.routers[node_id] = router
|
||||
|
||||
def _stop_routers_if_needed(self):
|
||||
def _stop_routers_if_needed(self) -> bool:
|
||||
"""Removes router actors from any nodes that no longer exist.
|
||||
|
||||
Returns whether or not any actors were removed (a checkpoint should
|
||||
@@ -214,15 +217,15 @@ class ServeController:
|
||||
|
||||
return checkpoint_required
|
||||
|
||||
def get_routers(self):
|
||||
def get_routers(self) -> Dict[str, ActorHandle]:
|
||||
"""Returns a dictionary of node ID to router actor handles."""
|
||||
return self.routers
|
||||
|
||||
def get_router_config(self):
|
||||
def get_router_config(self) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
|
||||
"""Called by the router on startup to fetch required state."""
|
||||
return self.routes
|
||||
|
||||
def _checkpoint(self):
|
||||
def _checkpoint(self) -> None:
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
assert self.write_lock.locked()
|
||||
logger.debug("Writing checkpoint")
|
||||
@@ -240,7 +243,7 @@ class ServeController:
|
||||
logger.warning("Intentionally crashing after checkpoint")
|
||||
os._exit(0)
|
||||
|
||||
async def _recover_from_checkpoint(self, checkpoint_bytes):
|
||||
async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None:
|
||||
"""Recover the instance state from the provided checkpoint.
|
||||
|
||||
Performs the following operations:
|
||||
@@ -332,7 +335,7 @@ class ServeController:
|
||||
|
||||
self.write_lock.release()
|
||||
|
||||
async def do_autoscale(self):
|
||||
async def do_autoscale(self) -> None:
|
||||
for backend in self.backends:
|
||||
if backend not in self.autoscaling_policies:
|
||||
continue
|
||||
@@ -344,7 +347,7 @@ class ServeController:
|
||||
await self.update_backend_config(
|
||||
backend, {"num_replicas": new_num_replicas})
|
||||
|
||||
async def run_control_loop(self):
|
||||
async def run_control_loop(self) -> None:
|
||||
while True:
|
||||
await self.do_autoscale()
|
||||
async with self.write_lock:
|
||||
@@ -355,26 +358,27 @@ class ServeController:
|
||||
|
||||
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
|
||||
|
||||
def get_backend_configs(self):
|
||||
def get_backend_configs(self) -> Dict[str, BackendConfig]:
|
||||
"""Fetched by the router on startup."""
|
||||
backend_configs = {}
|
||||
for backend, info in self.backends.items():
|
||||
backend_configs[backend] = info.backend_config
|
||||
return backend_configs
|
||||
|
||||
def get_traffic_policies(self):
|
||||
def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
|
||||
"""Fetched by the router on startup."""
|
||||
return self.traffic_policies
|
||||
|
||||
def _list_replicas(self, backend_tag):
|
||||
def _list_replicas(self, backend_tag: str) -> List[str]:
|
||||
"""Used only for testing."""
|
||||
return self.replicas[backend_tag]
|
||||
|
||||
def get_traffic_policy(self, endpoint):
|
||||
def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
|
||||
"""Fetched by serve handles."""
|
||||
return self.traffic_policies[endpoint]
|
||||
|
||||
async def _start_backend_worker(self, backend_tag, replica_tag):
|
||||
async def _start_backend_worker(self, backend_tag: str,
|
||||
replica_tag: str) -> ActorHandle:
|
||||
"""Creates a backend worker and waits for it to start up.
|
||||
|
||||
Assumes that the backend configuration has already been registered
|
||||
@@ -399,7 +403,7 @@ class ServeController:
|
||||
await worker_handle.ready.remote()
|
||||
return worker_handle
|
||||
|
||||
async def _start_replica(self, backend_tag, replica_tag):
|
||||
async def _start_replica(self, backend_tag: str, replica_tag: str) -> None:
|
||||
# NOTE(edoakes): the replicas may already be created if we
|
||||
# failed after creating them but before writing a
|
||||
# checkpoint.
|
||||
@@ -419,7 +423,7 @@ class ServeController:
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def _start_pending_replicas(self):
|
||||
async def _start_pending_replicas(self) -> None:
|
||||
"""Starts the pending backend replicas in self.replicas_to_start.
|
||||
|
||||
Starts the worker, then pushes an update to the router to add it to
|
||||
@@ -439,7 +443,7 @@ class ServeController:
|
||||
|
||||
self.replicas_to_start.clear()
|
||||
|
||||
async def _stop_pending_replicas(self):
|
||||
async def _stop_pending_replicas(self) -> None:
|
||||
"""Stops the pending backend replicas in self.replicas_to_stop.
|
||||
|
||||
Removes workers from the router, kills them, and clears
|
||||
@@ -469,7 +473,7 @@ class ServeController:
|
||||
|
||||
self.replicas_to_stop.clear()
|
||||
|
||||
async def _remove_pending_backends(self):
|
||||
async def _remove_pending_backends(self) -> None:
|
||||
"""Removes the pending backends in self.backends_to_remove.
|
||||
|
||||
Clears self.backends_to_remove.
|
||||
@@ -481,7 +485,7 @@ class ServeController:
|
||||
])
|
||||
self.backends_to_remove.clear()
|
||||
|
||||
async def _remove_pending_endpoints(self):
|
||||
async def _remove_pending_endpoints(self) -> None:
|
||||
"""Removes the pending endpoints in self.endpoints_to_remove.
|
||||
|
||||
Clears self.endpoints_to_remove.
|
||||
@@ -493,7 +497,7 @@ class ServeController:
|
||||
])
|
||||
self.endpoints_to_remove.clear()
|
||||
|
||||
def _scale_replicas(self, backend_tag, num_replicas):
|
||||
def _scale_replicas(self, backend_tag: str, num_replicas: int) -> None:
|
||||
"""Scale the given backend to the number of replicas.
|
||||
|
||||
NOTE: this does not actually start or stop the replicas, but instead
|
||||
@@ -552,18 +556,18 @@ class ServeController:
|
||||
|
||||
self.replicas_to_stop[backend_tag].append(replica_tag)
|
||||
|
||||
def get_all_worker_handles(self):
|
||||
def get_all_worker_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
|
||||
"""Fetched by the router on startup."""
|
||||
return self.workers
|
||||
|
||||
def get_all_backends(self):
|
||||
def get_all_backends(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of backend tag to backend config dict."""
|
||||
backends = {}
|
||||
for backend_tag, backend_info in self.backends.items():
|
||||
backends[backend_tag] = backend_info.backend_config.__dict__
|
||||
return backends
|
||||
|
||||
def get_all_endpoints(self):
|
||||
def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of endpoint to endpoint config."""
|
||||
endpoints = {}
|
||||
for route, (endpoint, methods) in self.routes.items():
|
||||
@@ -583,7 +587,8 @@ class ServeController:
|
||||
}
|
||||
return endpoints
|
||||
|
||||
async def _set_traffic(self, endpoint_name, traffic_dict):
|
||||
async def _set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> None:
|
||||
if endpoint_name not in self.get_all_endpoints():
|
||||
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
|
||||
" that is not registered.".format(endpoint_name))
|
||||
@@ -609,12 +614,14 @@ class ServeController:
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def set_traffic(self, endpoint_name, traffic_dict):
|
||||
async def set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> None:
|
||||
"""Sets the traffic policy for the specified endpoint."""
|
||||
async with self.write_lock:
|
||||
await self._set_traffic(endpoint_name, traffic_dict)
|
||||
|
||||
async def shadow_traffic(self, endpoint_name, backend_tag, proportion):
|
||||
async def shadow_traffic(self, endpoint_name: str, backend_tag: str,
|
||||
proportion: float) -> None:
|
||||
"""Shadow traffic from the endpoint to the backend."""
|
||||
async with self.write_lock:
|
||||
if endpoint_name not in self.get_all_endpoints():
|
||||
@@ -641,7 +648,10 @@ class ServeController:
|
||||
) for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def create_endpoint(self, endpoint, traffic_dict, route, methods):
|
||||
# TODO(architkulkarni): add optional type hints after upgrading cloudpickle
|
||||
async def create_endpoint(self, endpoint: str,
|
||||
traffic_dict: Dict[str, float], route,
|
||||
methods) -> None:
|
||||
"""Create a new endpoint with the specified route and methods.
|
||||
|
||||
If the route is None, this is a "headless" endpoint that will not
|
||||
@@ -686,7 +696,7 @@ class ServeController:
|
||||
for router in self.routers.values()
|
||||
])
|
||||
|
||||
async def delete_endpoint(self, endpoint):
|
||||
async def delete_endpoint(self, endpoint: str) -> None:
|
||||
"""Delete the specified endpoint.
|
||||
|
||||
Does not modify any corresponding backends.
|
||||
@@ -723,8 +733,9 @@ class ServeController:
|
||||
])
|
||||
await self._remove_pending_endpoints()
|
||||
|
||||
async def create_backend(self, backend_tag, backend_config,
|
||||
replica_config):
|
||||
async def create_backend(self, backend_tag: str,
|
||||
backend_config: BackendConfig,
|
||||
replica_config: ReplicaConfig) -> None:
|
||||
"""Register a new backend under the specified tag."""
|
||||
async with self.write_lock:
|
||||
# Ensures this method is idempotent.
|
||||
@@ -766,7 +777,7 @@ class ServeController:
|
||||
])
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def delete_backend(self, backend_tag):
|
||||
async def delete_backend(self, backend_tag: str) -> None:
|
||||
async with self.write_lock:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified backend exists on the client.
|
||||
@@ -801,7 +812,8 @@ class ServeController:
|
||||
await self._stop_pending_replicas()
|
||||
await self._remove_pending_backends()
|
||||
|
||||
async def update_backend_config(self, backend_tag, config_options):
|
||||
async def update_backend_config(self, backend_tag: str,
|
||||
config_options: Dict[str, Any]) -> None:
|
||||
"""Set the config for the specified backend."""
|
||||
async with self.write_lock:
|
||||
assert (backend_tag in self.backends
|
||||
@@ -831,7 +843,7 @@ class ServeController:
|
||||
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def broadcast_backend_config(self, backend_tag):
|
||||
async def broadcast_backend_config(self, backend_tag: str) -> None:
|
||||
backend_config = self.backends[backend_tag].backend_config
|
||||
broadcast_futures = []
|
||||
for replica_tag in self.replicas[backend_tag]:
|
||||
@@ -845,13 +857,13 @@ class ServeController:
|
||||
if len(broadcast_futures) > 0:
|
||||
await asyncio.gather(*broadcast_futures)
|
||||
|
||||
def get_backend_config(self, backend_tag):
|
||||
def get_backend_config(self, backend_tag: str) -> BackendConfig:
|
||||
"""Get the current config for the specified backend."""
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
return self.backends[backend_tag].backend_config
|
||||
|
||||
async def shutdown(self):
|
||||
async def shutdown(self) -> None:
|
||||
"""Shuts down the serve instance completely."""
|
||||
async with self.write_lock:
|
||||
for router in self.routers.values():
|
||||
@@ -861,7 +873,8 @@ class ServeController:
|
||||
ray.kill(replica, no_restart=True)
|
||||
self.kv_store.delete(CHECKPOINT_KEY)
|
||||
|
||||
async def report_queue_lengths(self, router_name, queue_lengths):
|
||||
async def report_queue_lengths(self, router_name: str,
|
||||
queue_lengths: Dict[str, int]):
|
||||
# TODO: remove old router stats when removing them.
|
||||
for backend, queue_length in queue_lengths.items():
|
||||
self.backend_stats[backend][router_name] = queue_length
|
||||
|
||||
Reference in New Issue
Block a user