[Serve] Cleanup Router Implementation (#8718)

This commit is contained in:
Simon Mo
2020-06-01 19:21:28 -07:00
committed by GitHub
parent dcf58a43dc
commit 4cef1ee591
3 changed files with 54 additions and 77 deletions
+7 -6
View File
@@ -17,7 +17,7 @@ class EndpointPolicy:
__metaclass__ = ABCMeta
@abstractmethod
async def flush(self, endpoint_queue, backend_queues):
def flush(self, endpoint_queue, backend_queues):
"""Flush the endpoint queue into the given backend queues.
This method should assign each query in the endpoint_queue to a
@@ -27,8 +27,8 @@ class EndpointPolicy:
knows which backend_queues to flush.
Arguments:
endpoint_queue: asyncio.Queue containing queries to assign.
backend_queues: Dict(str, asyncio.Queue) mapping backend tags to
endpoint_queue: deque containing queries to assign.
backend_queues: Dict(str, deque) mapping backend tags to
their corresponding query queues.
Returns:
@@ -51,19 +51,20 @@ class RandomEndpointPolicy(EndpointPolicy):
self.backend_names, self.backend_weights = zip(
*sorted(traffic_dict.items()))
async def flush(self, endpoint_queue, backend_queues):
def flush(self, endpoint_queue, backend_queues):
if len(self.backend_names) == 0:
logger.info("No backends to assign traffic to.")
return set()
assigned_backends = set()
while endpoint_queue.qsize():
query = await endpoint_queue.get()
while len(endpoint_queue) > 0:
query = endpoint_queue.pop()
if query.shard_key is None:
rstate = np.random
else:
sha256_seed = sha256(query.shard_key.encode("utf-8"))
seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32)
# Note(simon): This constructor takes 100+us, maybe cache this?
rstate = np.random.RandomState(seed)
chosen_backend = rstate.choice(
+46 -70
View File
@@ -1,18 +1,11 @@
import asyncio
import copy
from collections import defaultdict
from collections import defaultdict, deque
import time
from typing import DefaultDict, List
# Note on choosing blist instead of stdlib heapq
# 1. pop operation should be O(1) (amortized)
# (helpful even for batched pop)
# 2. There should not be significant overhead in
# maintaining the sorted list.
# 3. The blist implementation is fast and uses C extensions.
import blist
import ray
import ray.cloudpickle as pickle
from ray.exceptions import RayTaskError
@@ -91,22 +84,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future],
class Router:
"""A router that routes request to available workers.
The traffic policy is used to assign requests to workers.
Traffic policy splits the traffic among different replicas
probabilistically:
1. When all backends are ready to receive traffic, we will randomly
choose a backend based on the weights assigned by the traffic policy
dictionary.
2. When more than 1 but not all backends are ready, we will normalize the
weights of the ready backends to 1 and choose a backend via sampling.
3. When there is only 1 backend ready, we will only use that backend.
"""
"""A router that routes request to available workers."""
async def __init__(self, cluster_name=None):
# Note: Several queues are used in the router
@@ -122,11 +100,11 @@ class Router:
# -- Queues -- #
# endpoint_name -> request queue
self.endpoint_queues: DefaultDict[asyncio.Queue[Query]] = defaultdict(
asyncio.Queue)
# backend_name -> worker request queue
self.worker_queues: DefaultDict[asyncio.Queue[
ray.actor.ActorHandle]] = defaultdict(asyncio.Queue)
# We use FIFO (left to right) ordering. The new items should be added
# using appendleft. Old items should be removed via pop().
self.endpoint_queues: DefaultDict[deque[Query]] = defaultdict(deque)
# backend_name -> worker replica tag queue
self.worker_queues: DefaultDict[deque[str]] = defaultdict(deque)
# backend_name -> worker payload queue
self.backend_queues = defaultdict(blist.sortedlist)
@@ -150,6 +128,7 @@ class Router:
# batching policies.
self.flush_lock = asyncio.Lock()
# -- State Restoration -- #
# Fetch the worker handles, traffic policies, and backend configs from
# the master actor. We use a "pull-based" approach instead of pushing
# them from the master so that the router can transparently recover
@@ -173,6 +152,7 @@ class Router:
for backend, backend_config in backend_configs.items():
await self.set_backend_config(backend, backend_config)
# -- Metric Registration -- #
[metric_exporter] = retry_actor_failures(
master_actor.get_metric_exporter)
self.metric_client = MetricClient(metric_exporter)
@@ -215,9 +195,9 @@ class Router:
call_method=request_meta.call_method,
shard_key=request_meta.shard_key,
async_future=asyncio.get_event_loop().create_future())
await self.endpoint_queues[endpoint].put(query)
async with self.flush_lock:
await self.flush_endpoint_queue(endpoint)
self.endpoint_queues[endpoint].appendleft(query)
self.flush_endpoint_queue(endpoint)
# Note: a future change can be to directly return the ObjectID from
# replica task submission
@@ -241,39 +221,38 @@ class Router:
if backend_replica_tag not in self.replicas:
return
await self.worker_queues[backend_tag].put(backend_replica_tag)
async with self.flush_lock:
await self.flush_backend_queues([backend_tag])
self.worker_queues[backend_tag].appendleft(backend_replica_tag)
self.flush_backend_queues([backend_tag])
async def remove_worker(self, backend_tag, replica_tag):
backend_replica_tag = backend_tag + ":" + replica_tag
if backend_replica_tag not in self.replicas:
return
del self.replicas[backend_replica_tag]
# We need this lock because we modify worker_queue here.
async with self.flush_lock:
old_queue = self.worker_queues[backend_tag]
new_queue = asyncio.Queue()
del self.replicas[backend_replica_tag]
while not old_queue.empty():
curr_tag = await old_queue.get()
if curr_tag != backend_replica_tag:
await new_queue.put(curr_tag)
self.worker_queues[backend_tag] = new_queue
try:
self.worker_queues[backend_tag].remove(backend_replica_tag)
except ValueError:
# The replica is not in the idle worker queue. That is fine: it may
# still be processing a query, in which case its tag has not been
# put back on the queue yet.
pass
async def set_traffic(self, endpoint, traffic_dict):
logger.debug("Setting traffic for endpoint %s to %s", endpoint,
traffic_dict)
async with self.flush_lock:
self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict)
await self.flush_endpoint_queue(endpoint)
self.flush_endpoint_queue(endpoint)
async def remove_endpoint(self, endpoint):
logger.debug("Removing endpoint {}".format(endpoint))
async with self.flush_lock:
await self.flush_endpoint_queue(endpoint)
self.flush_endpoint_queue(endpoint)
if endpoint in self.endpoint_queues:
del self.endpoint_queues[endpoint]
if endpoint in self.traffic:
@@ -282,12 +261,13 @@ class Router:
async def set_backend_config(self, backend, config):
logger.debug("Setting backend config for "
"backend {} to {}.".format(backend, config))
self.backend_info[backend] = config
async with self.flush_lock:
self.backend_info[backend] = config
async def remove_backend(self, backend):
logger.debug("Removing backend {}".format(backend))
async with self.flush_lock:
await self.flush_backend_queues([backend])
self.flush_backend_queues([backend])
if backend in self.backend_info:
del self.backend_info[backend]
if backend in self.worker_queues:
@@ -295,30 +275,21 @@ class Router:
if backend in self.backend_queues:
del self.backend_queues[backend]
async def flush_endpoint_queue(self, endpoint):
def flush_endpoint_queue(self, endpoint):
"""Attempt to schedule any pending requests to available backends."""
assert self.flush_lock.locked()
if endpoint not in self.traffic:
return
backends_to_flush = await self.traffic[endpoint].flush(
backends_to_flush = self.traffic[endpoint].flush(
self.endpoint_queues[endpoint], self.backend_queues)
await self.flush_backend_queues(backends_to_flush)
def _get_available_backends(self, endpoint):
backends_in_policy = set(self.traffic[endpoint].keys())
available_workers = {
backend
for backend, queues in self.worker_queues.items()
if queues.qsize() > 0
}
return list(backends_in_policy.intersection(available_workers))
self.flush_backend_queues(backends_to_flush)
# Flushes the specified backend queues and assigns work to workers.
async def flush_backend_queues(self, backends_to_flush):
def flush_backend_queues(self, backends_to_flush):
assert self.flush_lock.locked()
for backend in backends_to_flush:
# No workers available.
if self.worker_queues[backend].qsize() == 0:
if len(self.worker_queues[backend]) == 0:
continue
# No work to do.
if len(self.backend_queues[backend]) == 0:
@@ -329,14 +300,14 @@ class Router:
logger.debug("Assigning queries for backend {} with buffer "
"queue size {} and worker queue size {}".format(
backend, len(buffer_queue), worker_queue.qsize()))
backend, len(buffer_queue), len(worker_queue)))
max_batch_size = None
if backend in self.backend_info:
max_batch_size = self.backend_info[backend].max_batch_size
await self._assign_query_to_worker(backend, buffer_queue,
worker_queue, max_batch_size)
self._assign_query_to_worker(backend, buffer_queue, worker_queue,
max_batch_size)
async def _do_query(self, backend, backend_replica_tag, req):
# If the worker died, this will be a RayActorError. Just return it and
@@ -353,14 +324,19 @@ class Router:
logger.debug("Got result in {:.2f}s".format(time.time() - start))
return result
async def _assign_query_to_worker(self,
backend,
buffer_queue,
worker_queue,
max_batch_size=None):
def _assign_query_to_worker(self,
backend,
buffer_queue,
worker_queue,
max_batch_size=None):
while len(buffer_queue) and len(worker_queue):
backend_replica_tag = worker_queue.pop()
# The replica might have been deleted already.
if backend_replica_tag not in self.replicas:
continue
while len(buffer_queue) and worker_queue.qsize():
backend_replica_tag = await worker_queue.get()
if max_batch_size is None: # No batching
request = buffer_queue.pop(0)
future = asyncio.get_event_loop().create_task(
+1 -1
View File
@@ -127,7 +127,7 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
async def test_queue_remove_replicas(serve_instance):
class TestRouter(Router):
def worker_queue_size(self, backend):
return self.worker_queues["backend-remove"].qsize()
return len(self.worker_queues["backend-remove"])
temp_actor = mock_task_runner()
q = ray.remote(TestRouter).remote()