[Serve] add type hints for controller and backend_worker (#10288)

This commit is contained in:
architkulkarni
2020-08-27 10:20:36 -07:00
committed by GitHub
parent f75dfd60a3
commit eea7a86163
2 changed files with 75 additions and 59 deletions
+22 -19
View File
@@ -5,7 +5,7 @@ from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
from typing import Union
from typing import Union, List, Any, Callable, Type
import time
import ray
@@ -20,32 +20,33 @@ from ray.serve.exceptions import RayServeException
from ray.experimental import metrics
from ray.serve.config import BackendConfig
from ray.serve.router import Query
from ray.exceptions import RayTaskError
logger = _get_logger()
class BatchQueue:
def __init__(self, max_batch_size, timeout_s):
def __init__(self, max_batch_size: int, timeout_s: float) -> None:
self.queue = asyncio.Queue()
self.full_batch_event = asyncio.Event()
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
def set_config(self, max_batch_size, timeout_s):
def set_config(self, max_batch_size: int, timeout_s: float) -> None:
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
def put(self, request):
def put(self, request: Query) -> None:
self.queue.put_nowait(request)
# Signal when the full batch is ready. The event will be reset
# in wait_for_batch.
if self.queue.qsize() == self.max_batch_size:
self.full_batch_event.set()
def qsize(self):
def qsize(self) -> int:
return self.queue.qsize()
async def wait_for_batch(self):
async def wait_for_batch(self) -> List[Query]:
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
Returns a batch of up to self.max_batch_size items, waiting for up
@@ -89,7 +90,7 @@ class BatchQueue:
return batch
def create_backend_worker(func_or_class):
def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]):
"""Creates a worker class wrapping the provided function or class."""
if inspect.isfunction(func_or_class):
@@ -99,6 +100,7 @@ def create_backend_worker(func_or_class):
else:
assert False, "func_or_class must be function or class."
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
class RayServeWrappedWorker(object):
def __init__(self,
backend_tag,
@@ -129,7 +131,7 @@ def create_backend_worker(func_or_class):
return RayServeWrappedWorker
def wrap_to_ray_error(exception):
def wrap_to_ray_error(exception: Exception) -> RayTaskError:
"""Utility method to wrap exceptions in user code."""
try:
@@ -140,7 +142,7 @@ def wrap_to_ray_error(exception):
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
def ensure_async(func):
def ensure_async(func: Callable) -> Callable:
if inspect.iscoroutinefunction(func):
return func
else:
@@ -150,8 +152,8 @@ def ensure_async(func):
class RayServeWorker:
"""Handles requests with the provided callable."""
def __init__(self, backend_tag, replica_tag, _callable,
backend_config: BackendConfig, is_function):
def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable,
backend_config: BackendConfig, is_function: bool) -> None:
self.backend_tag = backend_tag
self.replica_tag = replica_tag
self.callable = _callable
@@ -182,7 +184,7 @@ class RayServeWorker:
asyncio.get_event_loop().create_task(self.main_loop())
def get_runner_method(self, request_item):
def get_runner_method(self, request_item: Query) -> Callable:
method_name = request_item.call_method
if not hasattr(self.callable, method_name):
raise RayServeException("Backend doesn't have method {} "
@@ -193,7 +195,7 @@ class RayServeWorker:
return self.callable
return getattr(self.callable, method_name)
def has_positional_args(self, f):
def has_positional_args(self, f: Callable) -> bool:
# NOTE:
# In the case of simple functions, not actors, the f will be
# function.__call__, but we need to inspect the function itself.
@@ -207,13 +209,13 @@ class RayServeWorker:
return True
return False
def _reset_context(self):
def _reset_context(self) -> None:
# NOTE(simon): context management won't work in async mode because
# many concurrent queries might be running at the same time.
serve_context.web = None
serve_context.batch_size = None
async def invoke_single(self, request_item):
async def invoke_single(self, request_item: Query) -> Any:
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
@@ -231,7 +233,7 @@ class RayServeWorker:
return result
async def invoke_batch(self, request_item_list):
async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]:
arg_list = []
kwargs_list = defaultdict(list)
context_flags = set()
@@ -308,7 +310,7 @@ class RayServeWorker:
self._reset_context()
return [wrapped_exception for _ in range(batch_size)]
async def main_loop(self):
async def main_loop(self) -> None:
while True:
# NOTE(simon): There's an issue when user updated batch size and
# batch wait timeout during the execution, these values will not be
@@ -338,12 +340,13 @@ class RayServeWorker:
# it will not be raised.
await asyncio.wait(all_evaluated_futures)
def update_config(self, new_config: BackendConfig):
def update_config(self, new_config: BackendConfig) -> None:
self.config = new_config
self.batch_queue.set_config(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
async def handle_request(self, request: Union[Query, bytes]):
async def handle_request(self,
request: Union[Query, bytes]) -> asyncio.Future:
if isinstance(request, bytes):
request = Query.ray_deserialize(request)
logger.debug("Worker {} got request {}".format(self.replica_tag,
+53 -40
View File
@@ -14,6 +14,9 @@ from ray.serve.kv_store import RayInternalKVStore
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes, get_all_node_ids)
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.actor import ActorHandle
from typing import Dict, List, Any, Tuple
import numpy as np
@@ -31,12 +34,12 @@ CONTROL_LOOP_PERIOD_S = 1.0
class TrafficPolicy:
def __init__(self, traffic_dict):
def __init__(self, traffic_dict: Dict[str, float]) -> None:
self.traffic_dict = dict()
self.shadow_dict = dict()
self.set_traffic_dict(traffic_dict)
def set_traffic_dict(self, traffic_dict):
def set_traffic_dict(self, traffic_dict: Dict[str, float]) -> None:
prob = 0
for backend, weight in traffic_dict.items():
if weight < 0:
@@ -52,7 +55,7 @@ class TrafficPolicy:
"currently they sum to {}".format(prob))
self.traffic_dict = traffic_dict
def set_shadow(self, backend, proportion):
def set_shadow(self, backend: str, proportion: float):
if proportion == 0 and backend in self.shadow_dict:
del self.shadow_dict[backend]
else:
@@ -89,8 +92,8 @@ class ServeController:
requires all implementations here to be idempotent.
"""
async def __init__(self, instance_name, http_host, http_port,
_http_middlewares):
async def __init__(self, instance_name: str, http_host: str,
http_port: str, _http_middlewares: List[Any]) -> None:
# Unique name of the serve instance managed by this actor. Used to
# namespace child actors and checkpoints.
self.instance_name = instance_name
@@ -160,7 +163,7 @@ class ServeController:
asyncio.get_event_loop().create_task(self.run_control_loop())
def _start_routers_if_needed(self):
def _start_routers_if_needed(self) -> None:
"""Start a router on every node if it doesn't already exist."""
for node_id, node_resource in get_all_node_ids():
if node_id in self.routers:
@@ -192,7 +195,7 @@ class ServeController:
self.routers[node_id] = router
def _stop_routers_if_needed(self):
def _stop_routers_if_needed(self) -> bool:
"""Removes router actors from any nodes that no longer exist.
Returns whether or not any actors were removed (a checkpoint should
@@ -214,15 +217,15 @@ class ServeController:
return checkpoint_required
def get_routers(self):
def get_routers(self) -> Dict[str, ActorHandle]:
"""Returns a dictionary of node ID to router actor handles."""
return self.routers
def get_router_config(self):
def get_router_config(self) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
"""Called by the router on startup to fetch required state."""
return self.routes
def _checkpoint(self):
def _checkpoint(self) -> None:
"""Checkpoint internal state and write it to the KV store."""
assert self.write_lock.locked()
logger.debug("Writing checkpoint")
@@ -240,7 +243,7 @@ class ServeController:
logger.warning("Intentionally crashing after checkpoint")
os._exit(0)
async def _recover_from_checkpoint(self, checkpoint_bytes):
async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None:
"""Recover the instance state from the provided checkpoint.
Performs the following operations:
@@ -332,7 +335,7 @@ class ServeController:
self.write_lock.release()
async def do_autoscale(self):
async def do_autoscale(self) -> None:
for backend in self.backends:
if backend not in self.autoscaling_policies:
continue
@@ -344,7 +347,7 @@ class ServeController:
await self.update_backend_config(
backend, {"num_replicas": new_num_replicas})
async def run_control_loop(self):
async def run_control_loop(self) -> None:
while True:
await self.do_autoscale()
async with self.write_lock:
@@ -355,26 +358,27 @@ class ServeController:
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
def get_backend_configs(self):
def get_backend_configs(self) -> Dict[str, BackendConfig]:
"""Fetched by the router on startup."""
backend_configs = {}
for backend, info in self.backends.items():
backend_configs[backend] = info.backend_config
return backend_configs
def get_traffic_policies(self):
def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
"""Fetched by the router on startup."""
return self.traffic_policies
def _list_replicas(self, backend_tag):
def _list_replicas(self, backend_tag: str) -> List[str]:
"""Used only for testing."""
return self.replicas[backend_tag]
def get_traffic_policy(self, endpoint):
def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
"""Fetched by serve handles."""
return self.traffic_policies[endpoint]
async def _start_backend_worker(self, backend_tag, replica_tag):
async def _start_backend_worker(self, backend_tag: str,
replica_tag: str) -> ActorHandle:
"""Creates a backend worker and waits for it to start up.
Assumes that the backend configuration has already been registered
@@ -399,7 +403,7 @@ class ServeController:
await worker_handle.ready.remote()
return worker_handle
async def _start_replica(self, backend_tag, replica_tag):
async def _start_replica(self, backend_tag: str, replica_tag: str) -> None:
# NOTE(edoakes): the replicas may already be created if we
# failed after creating them but before writing a
# checkpoint.
@@ -419,7 +423,7 @@ class ServeController:
for router in self.routers.values()
])
async def _start_pending_replicas(self):
async def _start_pending_replicas(self) -> None:
"""Starts the pending backend replicas in self.replicas_to_start.
Starts the worker, then pushes an update to the router to add it to
@@ -439,7 +443,7 @@ class ServeController:
self.replicas_to_start.clear()
async def _stop_pending_replicas(self):
async def _stop_pending_replicas(self) -> None:
"""Stops the pending backend replicas in self.replicas_to_stop.
Removes workers from the router, kills them, and clears
@@ -469,7 +473,7 @@ class ServeController:
self.replicas_to_stop.clear()
async def _remove_pending_backends(self):
async def _remove_pending_backends(self) -> None:
"""Removes the pending backends in self.backends_to_remove.
Clears self.backends_to_remove.
@@ -481,7 +485,7 @@ class ServeController:
])
self.backends_to_remove.clear()
async def _remove_pending_endpoints(self):
async def _remove_pending_endpoints(self) -> None:
"""Removes the pending endpoints in self.endpoints_to_remove.
Clears self.endpoints_to_remove.
@@ -493,7 +497,7 @@ class ServeController:
])
self.endpoints_to_remove.clear()
def _scale_replicas(self, backend_tag, num_replicas):
def _scale_replicas(self, backend_tag: str, num_replicas: int) -> None:
"""Scale the given backend to the number of replicas.
NOTE: this does not actually start or stop the replicas, but instead
@@ -552,18 +556,18 @@ class ServeController:
self.replicas_to_stop[backend_tag].append(replica_tag)
def get_all_worker_handles(self):
def get_all_worker_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
"""Fetched by the router on startup."""
return self.workers
def get_all_backends(self):
def get_all_backends(self) -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of backend tag to backend config dict."""
backends = {}
for backend_tag, backend_info in self.backends.items():
backends[backend_tag] = backend_info.backend_config.__dict__
return backends
def get_all_endpoints(self):
def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of endpoint to endpoint config."""
endpoints = {}
for route, (endpoint, methods) in self.routes.items():
@@ -583,7 +587,8 @@ class ServeController:
}
return endpoints
async def _set_traffic(self, endpoint_name, traffic_dict):
async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> None:
if endpoint_name not in self.get_all_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name))
@@ -609,12 +614,14 @@ class ServeController:
for router in self.routers.values()
])
async def set_traffic(self, endpoint_name, traffic_dict):
async def set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> None:
"""Sets the traffic policy for the specified endpoint."""
async with self.write_lock:
await self._set_traffic(endpoint_name, traffic_dict)
async def shadow_traffic(self, endpoint_name, backend_tag, proportion):
async def shadow_traffic(self, endpoint_name: str, backend_tag: str,
proportion: float) -> None:
"""Shadow traffic from the endpoint to the backend."""
async with self.write_lock:
if endpoint_name not in self.get_all_endpoints():
@@ -641,7 +648,10 @@ class ServeController:
) for router in self.routers.values()
])
async def create_endpoint(self, endpoint, traffic_dict, route, methods):
# TODO(architkulkarni): add optional type hints after upgrading cloudpickle
async def create_endpoint(self, endpoint: str,
traffic_dict: Dict[str, float], route,
methods) -> None:
"""Create a new endpoint with the specified route and methods.
If the route is None, this is a "headless" endpoint that will not
@@ -686,7 +696,7 @@ class ServeController:
for router in self.routers.values()
])
async def delete_endpoint(self, endpoint):
async def delete_endpoint(self, endpoint: str) -> None:
"""Delete the specified endpoint.
Does not modify any corresponding backends.
@@ -723,8 +733,9 @@ class ServeController:
])
await self._remove_pending_endpoints()
async def create_backend(self, backend_tag, backend_config,
replica_config):
async def create_backend(self, backend_tag: str,
backend_config: BackendConfig,
replica_config: ReplicaConfig) -> None:
"""Register a new backend under the specified tag."""
async with self.write_lock:
# Ensures this method is idempotent.
@@ -766,7 +777,7 @@ class ServeController:
])
await self.broadcast_backend_config(backend_tag)
async def delete_backend(self, backend_tag):
async def delete_backend(self, backend_tag: str) -> None:
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified backend exists on the client.
@@ -801,7 +812,8 @@ class ServeController:
await self._stop_pending_replicas()
await self._remove_pending_backends()
async def update_backend_config(self, backend_tag, config_options):
async def update_backend_config(self, backend_tag: str,
config_options: Dict[str, Any]) -> None:
"""Set the config for the specified backend."""
async with self.write_lock:
assert (backend_tag in self.backends
@@ -831,7 +843,7 @@ class ServeController:
await self.broadcast_backend_config(backend_tag)
async def broadcast_backend_config(self, backend_tag):
async def broadcast_backend_config(self, backend_tag: str) -> None:
backend_config = self.backends[backend_tag].backend_config
broadcast_futures = []
for replica_tag in self.replicas[backend_tag]:
@@ -845,13 +857,13 @@ class ServeController:
if len(broadcast_futures) > 0:
await asyncio.gather(*broadcast_futures)
def get_backend_config(self, backend_tag):
def get_backend_config(self, backend_tag: str) -> BackendConfig:
"""Get the current config for the specified backend."""
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
return self.backends[backend_tag].backend_config
async def shutdown(self):
async def shutdown(self) -> None:
"""Shuts down the serve instance completely."""
async with self.write_lock:
for router in self.routers.values():
@@ -861,7 +873,8 @@ class ServeController:
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)
async def report_queue_lengths(self, router_name, queue_lengths):
async def report_queue_lengths(self, router_name: str,
queue_lengths: Dict[str, int]):
# TODO: remove old router stats when removing them.
for backend, queue_length in queue_lengths.items():
self.backend_stats[backend][router_name] = queue_length