From eea7a861631c2d9d8dc20a318aedc82bd99055bf Mon Sep 17 00:00:00 2001 From: architkulkarni Date: Thu, 27 Aug 2020 10:20:36 -0700 Subject: [PATCH] [Serve] add type hints for controller and backend_worker (#10288) --- python/ray/serve/backend_worker.py | 41 +++++++------ python/ray/serve/controller.py | 93 +++++++++++++++++------------- 2 files changed, 75 insertions(+), 59 deletions(-) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 736458699..11cd5bdc7 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -5,7 +5,7 @@ from collections.abc import Iterable from collections import defaultdict from itertools import groupby from operator import attrgetter -from typing import Union +from typing import Union, List, Any, Callable, Type import time import ray @@ -20,32 +20,33 @@ from ray.serve.exceptions import RayServeException from ray.experimental import metrics from ray.serve.config import BackendConfig from ray.serve.router import Query +from ray.exceptions import RayTaskError logger = _get_logger() class BatchQueue: - def __init__(self, max_batch_size, timeout_s): + def __init__(self, max_batch_size: int, timeout_s: float) -> None: self.queue = asyncio.Queue() self.full_batch_event = asyncio.Event() self.max_batch_size = max_batch_size self.timeout_s = timeout_s - def set_config(self, max_batch_size, timeout_s): + def set_config(self, max_batch_size: int, timeout_s: float) -> None: self.max_batch_size = max_batch_size self.timeout_s = timeout_s - def put(self, request): + def put(self, request: Query) -> None: self.queue.put_nowait(request) # Signal when the full batch is ready. The event will be reset # in wait_for_batch. if self.queue.qsize() == self.max_batch_size: self.full_batch_event.set() - def qsize(self): + def qsize(self) -> int: return self.queue.qsize() - async def wait_for_batch(self): + async def wait_for_batch(self) -> List[Query]: """Wait for batch respecting self.max_batch_size and self.timeout_s. Returns a batch of up to self.max_batch_size items, waiting for up @@ -89,7 +90,7 @@ class BatchQueue: return batch -def create_backend_worker(func_or_class): +def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]): """Creates a worker class wrapping the provided function or class.""" if inspect.isfunction(func_or_class): @@ -99,6 +100,7 @@ def create_backend_worker(func_or_class): else: assert False, "func_or_class must be function or class." + # TODO(architkulkarni): Add type hints after upgrading cloudpickle class RayServeWrappedWorker(object): def __init__(self, backend_tag, @@ -129,7 +131,7 @@ def create_backend_worker(func_or_class): return RayServeWrappedWorker -def wrap_to_ray_error(exception): +def wrap_to_ray_error(exception: Exception) -> RayTaskError: """Utility method to wrap exceptions in user code.""" try: @@ -140,7 +142,7 @@ def wrap_to_ray_error(exception): return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__) -def ensure_async(func): +def ensure_async(func: Callable) -> Callable: if inspect.iscoroutinefunction(func): return func else: @@ -150,8 +152,8 @@ def ensure_async(func): class RayServeWorker: """Handles requests with the provided callable.""" - def __init__(self, backend_tag, replica_tag, _callable, - backend_config: BackendConfig, is_function): + def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable, + backend_config: BackendConfig, is_function: bool) -> None: self.backend_tag = backend_tag self.replica_tag = replica_tag self.callable = _callable @@ -182,7 +184,7 @@ class RayServeWorker: asyncio.get_event_loop().create_task(self.main_loop()) - def get_runner_method(self, request_item): + def get_runner_method(self, request_item: Query) -> Callable: method_name = request_item.call_method if not hasattr(self.callable, method_name): raise RayServeException("Backend doesn't have method {} " @@ -193,7 +195,7 @@ class RayServeWorker: return self.callable return getattr(self.callable, method_name) - def has_positional_args(self, f): + def has_positional_args(self, f: Callable) -> bool: # NOTE: # In the case of simple functions, not actors, the f will be # function.__call__, but we need to inspect the function itself. @@ -207,13 +209,13 @@ class RayServeWorker: return True return False - def _reset_context(self): + def _reset_context(self) -> None: # NOTE(simon): context management won't work in async mode because # many concurrent queries might be running at the same time. serve_context.web = None serve_context.batch_size = None - async def invoke_single(self, request_item): + async def invoke_single(self, request_item: Query) -> Any: args, kwargs, is_web_context = parse_request_item(request_item) serve_context.web = is_web_context @@ -231,7 +233,7 @@ class RayServeWorker: return result - async def invoke_batch(self, request_item_list): + async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]: arg_list = [] kwargs_list = defaultdict(list) context_flags = set() @@ -308,7 +310,7 @@ class RayServeWorker: self._reset_context() return [wrapped_exception for _ in range(batch_size)] - async def main_loop(self): + async def main_loop(self) -> None: while True: # NOTE(simon): There's an issue when user updated batch size and # batch wait timeout during the execution, these values will not be @@ -338,12 +340,13 @@ class RayServeWorker: # it will not be raised. await asyncio.wait(all_evaluated_futures) - def update_config(self, new_config: BackendConfig): + def update_config(self, new_config: BackendConfig) -> None: self.config = new_config self.batch_queue.set_config(self.config.max_batch_size or 1, self.config.batch_wait_timeout) - async def handle_request(self, request: Union[Query, bytes]): + async def handle_request(self, + request: Union[Query, bytes]) -> asyncio.Future: if isinstance(request, bytes): request = Query.ray_deserialize(request) logger.debug("Worker {} got request {}".format(self.replica_tag, diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index c66f2c53b..5ff50acc3 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -14,6 +14,9 @@ from ray.serve.kv_store import RayInternalKVStore from ray.serve.exceptions import RayServeException from ray.serve.utils import (format_actor_name, get_random_letters, logger, try_schedule_resources_on_nodes, get_all_node_ids) +from ray.serve.config import BackendConfig, ReplicaConfig +from ray.actor import ActorHandle +from typing import Dict, List, Any, Tuple import numpy as np @@ -31,12 +34,12 @@ CONTROL_LOOP_PERIOD_S = 1.0 class TrafficPolicy: - def __init__(self, traffic_dict): + def __init__(self, traffic_dict: Dict[str, float]) -> None: self.traffic_dict = dict() self.shadow_dict = dict() self.set_traffic_dict(traffic_dict) - def set_traffic_dict(self, traffic_dict): + def set_traffic_dict(self, traffic_dict: Dict[str, float]) -> None: prob = 0 for backend, weight in traffic_dict.items(): if weight < 0: @@ -52,7 +55,7 @@ class TrafficPolicy: "currently they sum to {}".format(prob)) self.traffic_dict = traffic_dict - def set_shadow(self, backend, proportion): + def set_shadow(self, backend: str, proportion: float): if proportion == 0 and backend in self.shadow_dict: del self.shadow_dict[backend] else: @@ -89,8 +92,8 @@ class ServeController: requires all implementations here to be idempotent. """ - async def __init__(self, instance_name, http_host, http_port, - _http_middlewares): + async def __init__(self, instance_name: str, http_host: str, + http_port: str, _http_middlewares: List[Any]) -> None: # Unique name of the serve instance managed by this actor. Used to # namespace child actors and checkpoints. self.instance_name = instance_name @@ -160,7 +163,7 @@ class ServeController: asyncio.get_event_loop().create_task(self.run_control_loop()) - def _start_routers_if_needed(self): + def _start_routers_if_needed(self) -> None: """Start a router on every node if it doesn't already exist.""" for node_id, node_resource in get_all_node_ids(): if node_id in self.routers: @@ -192,7 +195,7 @@ class ServeController: self.routers[node_id] = router - def _stop_routers_if_needed(self): + def _stop_routers_if_needed(self) -> bool: """Removes router actors from any nodes that no longer exist. Returns whether or not any actors were removed (a checkpoint should @@ -214,15 +217,15 @@ class ServeController: return checkpoint_required - def get_routers(self): + def get_routers(self) -> Dict[str, ActorHandle]: """Returns a dictionary of node ID to router actor handles.""" return self.routers - def get_router_config(self): + def get_router_config(self) -> Dict[str, Dict[str, Tuple[str, List[str]]]]: """Called by the router on startup to fetch required state.""" return self.routes - def _checkpoint(self): + def _checkpoint(self) -> None: """Checkpoint internal state and write it to the KV store.""" assert self.write_lock.locked() logger.debug("Writing checkpoint") @@ -240,7 +243,7 @@ class ServeController: logger.warning("Intentionally crashing after checkpoint") os._exit(0) - async def _recover_from_checkpoint(self, checkpoint_bytes): + async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None: """Recover the instance state from the provided checkpoint. Performs the following operations: @@ -332,7 +335,7 @@ class ServeController: self.write_lock.release() - async def do_autoscale(self): + async def do_autoscale(self) -> None: for backend in self.backends: if backend not in self.autoscaling_policies: continue @@ -344,7 +347,7 @@ class ServeController: await self.update_backend_config( backend, {"num_replicas": new_num_replicas}) - async def run_control_loop(self): + async def run_control_loop(self) -> None: while True: await self.do_autoscale() async with self.write_lock: @@ -355,26 +358,27 @@ class ServeController: await asyncio.sleep(CONTROL_LOOP_PERIOD_S) - def get_backend_configs(self): + def get_backend_configs(self) -> Dict[str, BackendConfig]: """Fetched by the router on startup.""" backend_configs = {} for backend, info in self.backends.items(): backend_configs[backend] = info.backend_config return backend_configs - def get_traffic_policies(self): + def get_traffic_policies(self) -> Dict[str, TrafficPolicy]: """Fetched by the router on startup.""" return self.traffic_policies - def _list_replicas(self, backend_tag): + def _list_replicas(self, backend_tag: str) -> List[str]: """Used only for testing.""" return self.replicas[backend_tag] - def get_traffic_policy(self, endpoint): + def get_traffic_policy(self, endpoint: str) -> TrafficPolicy: """Fetched by serve handles.""" return self.traffic_policies[endpoint] - async def _start_backend_worker(self, backend_tag, replica_tag): + async def _start_backend_worker(self, backend_tag: str, + replica_tag: str) -> ActorHandle: """Creates a backend worker and waits for it to start up. Assumes that the backend configuration has already been registered @@ -399,7 +403,7 @@ class ServeController: await worker_handle.ready.remote() return worker_handle - async def _start_replica(self, backend_tag, replica_tag): + async def _start_replica(self, backend_tag: str, replica_tag: str) -> None: # NOTE(edoakes): the replicas may already be created if we # failed after creating them but before writing a # checkpoint. @@ -419,7 +423,7 @@ class ServeController: for router in self.routers.values() ]) - async def _start_pending_replicas(self): + async def _start_pending_replicas(self) -> None: """Starts the pending backend replicas in self.replicas_to_start. Starts the worker, then pushes an update to the router to add it to @@ -439,7 +443,7 @@ class ServeController: self.replicas_to_start.clear() - async def _stop_pending_replicas(self): + async def _stop_pending_replicas(self) -> None: """Stops the pending backend replicas in self.replicas_to_stop. Removes workers from the router, kills them, and clears @@ -469,7 +473,7 @@ class ServeController: self.replicas_to_stop.clear() - async def _remove_pending_backends(self): + async def _remove_pending_backends(self) -> None: """Removes the pending backends in self.backends_to_remove. Clears self.backends_to_remove. @@ -481,7 +485,7 @@ class ServeController: ]) self.backends_to_remove.clear() - async def _remove_pending_endpoints(self): + async def _remove_pending_endpoints(self) -> None: """Removes the pending endpoints in self.endpoints_to_remove. Clears self.endpoints_to_remove. @@ -493,7 +497,7 @@ class ServeController: ]) self.endpoints_to_remove.clear() - def _scale_replicas(self, backend_tag, num_replicas): + def _scale_replicas(self, backend_tag: str, num_replicas: int) -> None: """Scale the given backend to the number of replicas. NOTE: this does not actually start or stop the replicas, but instead @@ -552,18 +556,18 @@ class ServeController: self.replicas_to_stop[backend_tag].append(replica_tag) - def get_all_worker_handles(self): + def get_all_worker_handles(self) -> Dict[str, Dict[str, ActorHandle]]: """Fetched by the router on startup.""" return self.workers - def get_all_backends(self): + def get_all_backends(self) -> Dict[str, Dict[str, Any]]: """Returns a dictionary of backend tag to backend config dict.""" backends = {} for backend_tag, backend_info in self.backends.items(): backends[backend_tag] = backend_info.backend_config.__dict__ return backends - def get_all_endpoints(self): + def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]: """Returns a dictionary of endpoint to endpoint config.""" endpoints = {} for route, (endpoint, methods) in self.routes.items(): @@ -583,7 +587,8 @@ class ServeController: } return endpoints - async def _set_traffic(self, endpoint_name, traffic_dict): + async def _set_traffic(self, endpoint_name: str, + traffic_dict: Dict[str, float]) -> None: if endpoint_name not in self.get_all_endpoints(): raise ValueError("Attempted to assign traffic for an endpoint '{}'" " that is not registered.".format(endpoint_name)) @@ -609,12 +614,14 @@ class ServeController: for router in self.routers.values() ]) - async def set_traffic(self, endpoint_name, traffic_dict): + async def set_traffic(self, endpoint_name: str, + traffic_dict: Dict[str, float]) -> None: """Sets the traffic policy for the specified endpoint.""" async with self.write_lock: await self._set_traffic(endpoint_name, traffic_dict) - async def shadow_traffic(self, endpoint_name, backend_tag, proportion): + async def shadow_traffic(self, endpoint_name: str, backend_tag: str, + proportion: float) -> None: """Shadow traffic from the endpoint to the backend.""" async with self.write_lock: if endpoint_name not in self.get_all_endpoints(): @@ -641,7 +648,10 @@ class ServeController: ) for router in self.routers.values() ]) - async def create_endpoint(self, endpoint, traffic_dict, route, methods): + # TODO(architkulkarni): add optional type hints after upgrading cloudpickle + async def create_endpoint(self, endpoint: str, + traffic_dict: Dict[str, float], route, + methods) -> None: """Create a new endpoint with the specified route and methods. If the route is None, this is a "headless" endpoint that will not @@ -686,7 +696,7 @@ class ServeController: for router in self.routers.values() ]) - async def delete_endpoint(self, endpoint): + async def delete_endpoint(self, endpoint: str) -> None: """Delete the specified endpoint. Does not modify any corresponding backends. @@ -723,8 +733,9 @@ class ServeController: ]) await self._remove_pending_endpoints() - async def create_backend(self, backend_tag, backend_config, - replica_config): + async def create_backend(self, backend_tag: str, + backend_config: BackendConfig, + replica_config: ReplicaConfig) -> None: """Register a new backend under the specified tag.""" async with self.write_lock: # Ensures this method is idempotent. @@ -766,7 +777,7 @@ class ServeController: ]) await self.broadcast_backend_config(backend_tag) - async def delete_backend(self, backend_tag): + async def delete_backend(self, backend_tag: str) -> None: async with self.write_lock: # This method must be idempotent. We should validate that the # specified backend exists on the client. @@ -801,7 +812,8 @@ class ServeController: await self._stop_pending_replicas() await self._remove_pending_backends() - async def update_backend_config(self, backend_tag, config_options): + async def update_backend_config(self, backend_tag: str, + config_options: Dict[str, Any]) -> None: """Set the config for the specified backend.""" async with self.write_lock: assert (backend_tag in self.backends @@ -831,7 +843,7 @@ class ServeController: await self.broadcast_backend_config(backend_tag) - async def broadcast_backend_config(self, backend_tag): + async def broadcast_backend_config(self, backend_tag: str) -> None: backend_config = self.backends[backend_tag].backend_config broadcast_futures = [] for replica_tag in self.replicas[backend_tag]: @@ -845,13 +857,13 @@ class ServeController: if len(broadcast_futures) > 0: await asyncio.gather(*broadcast_futures) - def get_backend_config(self, backend_tag): + def get_backend_config(self, backend_tag: str) -> BackendConfig: """Get the current config for the specified backend.""" assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) return self.backends[backend_tag].backend_config - async def shutdown(self): + async def shutdown(self) -> None: """Shuts down the serve instance completely.""" async with self.write_lock: for router in self.routers.values(): @@ -861,7 +873,8 @@ class ServeController: ray.kill(replica, no_restart=True) self.kv_store.delete(CHECKPOINT_KEY) - async def report_queue_lengths(self, router_name, queue_lengths): + async def report_queue_lengths(self, router_name: str, + queue_lengths: Dict[str, int]): # TODO: remove old router stats when removing them. for backend, queue_length in queue_lengths.items(): self.backend_stats[backend][router_name] = queue_length