[serve] Async controller (#13111)

This commit is contained in:
Ian Rodney
2020-12-31 08:51:33 -08:00
committed by GitHub
parent 7120f3a6ab
commit acb082fc47
2 changed files with 123 additions and 60 deletions
+117 -56
View File
@@ -49,7 +49,7 @@ BackendTag = str
EndpointTag = str
ReplicaTag = str
NodeId = str
GoalId = int
GoalId = UUID
Duration = float
@@ -172,12 +172,13 @@ class BackendInfo(BaseModel):
class BackendState:
def __init__(self, checkpoint: bytes = None):
self.backends: Dict[BackendTag, BackendInfo] = dict()
self.goals: Dict[BackendTag, GoalId] = dict()
if checkpoint is not None:
self.backends = pickle.loads(checkpoint)
self.backends, self.goals = pickle.loads(checkpoint)
def checkpoint(self):
return pickle.dumps(self.backends)
return pickle.dumps([self.backends, self.goals])
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
@@ -188,11 +189,38 @@ class BackendState:
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
return self.backends.get(backend_tag)
def add_backend(self,
backend_tag: BackendTag,
backend_info: BackendInfo,
goal_id: GoalId = 0) -> None:
def _set_backend_goal(self, backend_tag: BackendTag,
                      backend_info: Optional[BackendInfo],
                      goal_id: GoalId) -> Optional[GoalId]:
    """Record the desired state and the goal id for a backend.

    A truthy ``backend_info`` is stored as the desired state of
    ``backend_tag``; a falsy one removes the tag from ``self.backends``.
    ``goal_id`` always replaces whatever goal was tracked for the tag.

    Returns:
        The goal id previously tracked for ``backend_tag``, or None if
        there was none.
    """
    previous_goal = self.goals.get(backend_tag)
    if backend_info:
        self.backends[backend_tag] = backend_info
    else:
        # Nothing is desired for this tag any more: drop its metadata.
        self.backends.pop(backend_tag, None)
    self.goals[backend_tag] = goal_id
    return previous_goal
def completed_goals(
        self,
        current_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]
) -> List[GoalId]:
    """Return the goal ids of backends that have reached their desired state.

    Every tag present in ``current_replicas`` or ``self.backends`` is
    checked. A backend counts as done when either:

    * nothing is desired for it (no entry in ``self.backends``, or a
      desired ``num_replicas`` of 0) and it has no replicas, or
    * the number of its replicas equals the desired ``num_replicas``.

    Args:
        current_replicas: Replicas currently known for each backend tag.

    Returns:
        Goal ids recorded in ``self.goals`` for the backends that are
        done. Tags without a recorded goal are skipped.
    """
    completed_goals = []
    all_tags = set(current_replicas) | set(self.backends)

    for backend_tag in all_tags:
        desired_info = self.backends.get(backend_tag)
        existing_info = current_replicas.get(backend_tag)

        if (not desired_info
                or desired_info.backend_config.num_replicas == 0):
            # Deleting / scaling to zero: done once no replica remains.
            reached = not existing_info
        else:
            # Non-zero target: done once the replica count matches it.
            reached = bool(existing_info) and (
                desired_info.backend_config.num_replicas == len(
                    existing_info))
        if not reached:
            continue

        # A tag can appear here without ever having had a goal recorded
        # (e.g. an empty replica entry for an untracked backend); indexing
        # self.goals directly would raise KeyError in that case.
        goal_id = self.goals.get(backend_tag)
        if goal_id is not None:
            completed_goals.append(goal_id)
    return completed_goals
class EndpointState:
@@ -362,10 +390,9 @@ class ActorStateReconciler:
-delta_num_replicas, backend_tag))
assert len(
self.backend_replicas[backend_tag]) >= delta_num_replicas
replicas_copy = self.backend_replicas.copy()
for _ in range(-delta_num_replicas):
replica_tag, _ = self.backend_replicas[backend_tag].popitem()
if len(self.backend_replicas[backend_tag]) == 0:
del self.backend_replicas[backend_tag]
replica_tag, _ = replicas_copy[backend_tag].popitem()
graceful_timeout_s = (backend_info.backend_config.
experimental_graceful_shutdown_timeout_s)
@@ -450,39 +477,48 @@ class ActorStateReconciler:
(backend_tag,
replica_tag) = self.currently_stopping_replicas.pop(fut)
backend = self.backend_replicas_to_stop.get(backend_tag)
backend_to_stop = self.backend_replicas_to_stop.get(
backend_tag)
if backend:
if backend_to_stop:
try:
backend.remove(replica_tag)
backend_to_stop.remove(replica_tag)
except ValueError:
pass
if len(backend) == 0:
if len(backend_to_stop) == 0:
del self.backend_replicas_to_stop[backend_tag]
backend = self.backend_replicas.get(backend_tag)
if backend:
try:
del backend[replica_tag]
except KeyError:
pass
if len(self.backend_replicas[backend_tag]) == 0:
del self.backend_replicas[backend_tag]
return len(in_flight)
async def backend_control_loop(self):
start = time.time()
prev_warning = start
need_to_continue = True
num_pending_starts, num_pending_stops = 0, 0
while need_to_continue:
if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
prev_warning = time.time()
delta = time.time() - start
logger.warning(
f"Waited {delta:.2f}s for {num_pending_starts} replicas "
f"to start up or {num_pending_stops} replicas to shutdown."
" Make sure there are enough resources to create the "
"replicas.")
async def update_actor_state(self, start_time: float) -> bool:
    """Poll in-flight replica starts/stops once and report any change.

    Args:
        start_time: Timestamp (as produced by ``time.time()``) from which
            the elapsed time reported in the slow-startup warning is
            measured.

    Returns:
        True if the number of currently starting or currently stopping
        replicas differs from what it was when this call began.
    """
    num_starting = len(self.currently_starting_replicas)
    num_stopping = len(self.currently_stopping_replicas)

    # Poll each queue exactly once per call.
    num_pending_starts = await self._check_currently_starting_replicas()
    num_pending_stops = await self._check_currently_stopping_replicas()

    time_running = int(time.time() - start_time)
    if (time_running > 0
            and time_running % REPLICA_STARTUP_TIME_WARNING_S == 0):
        delta = time.time() - start_time
        logger.warning(
            f"Waited {delta:.2f}s for {num_pending_starts} replicas "
            f"to start up or {num_pending_stops} replicas to shutdown."
            " Make sure there are enough resources to create the "
            "replicas.")

    # No sleep here: a bare `asyncio.sleep(1)` without `await` never
    # pauses, it only creates a coroutine that is never run.
    return (len(self.currently_starting_replicas) != num_starting) or \
        (len(self.currently_stopping_replicas) != num_stopping)
def _recover_actor_handles(self) -> None:
# Fetch actor handles for all of the backend replicas in the system.
@@ -510,7 +546,6 @@ class ActorStateReconciler:
# Start/stop any pending backend replicas.
await self._enqueue_pending_scale_changes_loop(backend_state)
await self.backend_control_loop()
return autoscaling_policies
@@ -613,7 +648,9 @@ class ServeController:
asyncio.get_event_loop().create_task(self.run_control_loop())
async def wait_for_event(self, uuid: UUID) -> bool:
start = time.time()
if uuid not in self.inflight_results:
logger.debug(f"UUID ({uuid}) not found!!!")
return True
event = self.inflight_results[uuid]
await event.wait()
@@ -621,6 +658,7 @@ class ServeController:
self._serializable_inflight_results.pop(uuid)
async with self.write_lock:
self._checkpoint()
logger.debug(f"Waiting for {uuid} took {time.time() - start} seconds")
return True
@@ -631,8 +669,8 @@ class ServeController:
# NOTE(ilr) Must be called before checkpointing!
event = asyncio.Event()
event.result = FutureResult(goal_state)
event.set()
uuid_val = recreation_uuid or uuid4()
logger.debug(f"Creating uuid {uuid_val} for result of {goal_state}")
self.inflight_results[uuid_val] = event
self._serializable_inflight_results[uuid_val] = event.result
return uuid_val
@@ -757,12 +795,30 @@ class ServeController:
async def reconcile_current_and_goal_backends(self):
    """Placeholder hook; it currently performs no work."""
def set_goal_id(self, goal_id: UUID) -> None:
    """Signal the in-flight event registered under ``goal_id``, if any.

    A goal id with no entry in ``self.inflight_results`` is ignored after
    the debug message has been logged.
    """
    pending_event = self.inflight_results.get(goal_id)
    logger.debug(f"Setting Goal Id: {goal_id}")
    if not pending_event:
        return
    pending_event.set()
async def run_control_loop(self) -> None:
    """Run the controller's periodic reconciliation loop; never returns.

    Each iteration runs autoscaling and then, holding ``self.write_lock``,
    refreshes the HTTP state, advances replica starts/stops and signals
    every goal that ``self.backend_state`` reports as completed. It then
    sleeps for ``CONTROL_LOOP_PERIOD_S`` before the next iteration.
    """
    start_time = time.time()
    while True:
        await self.do_autoscale()
        async with self.write_lock:
            self.http_state.update()
            replicas_changed = (
                await self.actor_reconciler.update_actor_state(start_time))
            if not replicas_changed:
                # NOTE(review): the reference time is reset whenever the
                # replica counts did NOT change, so elapsed time only
                # accumulates across iterations that saw a change --
                # confirm this is the intended warning behaviour.
                start_time = time.time()
            else:
                # Replica membership moved: publish it, then persist.
                self.notify_replica_handles_changed()
                self.notify_backend_configs_changed()
                self._checkpoint()

            finished_goals = self.backend_state.completed_goals(
                self.actor_reconciler.backend_replicas)
            for goal_id in finished_goals:
                self.set_goal_id(goal_id)
        await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
def _all_replica_handles(
@@ -804,6 +860,7 @@ class ServeController:
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
self.set_goal_id(return_uuid)
return return_uuid
async def set_traffic(self, endpoint_name: str,
@@ -841,6 +898,7 @@ class ServeController:
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
self.set_goal_id(return_uuid)
return return_uuid
# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
@@ -923,8 +981,18 @@ class ServeController:
# after pushing the update.
self._checkpoint()
self.notify_route_table_changed()
self.set_goal_id(return_uuid)
return return_uuid
async def set_backend_goal(self, backend_tag: BackendTag,
                           backend_info: BackendInfo,
                           new_id: GoalId) -> None:
    """Install a new goal for a backend and signal the goal it replaces.

    Delegates to ``self.backend_state._set_backend_goal`` and, when that
    returns a truthy previous goal id, passes it to ``self.set_goal_id``.
    """
    # NOTE(ilr) Must checkpoint after doing this!
    superseded_id = self.backend_state._set_backend_goal(
        backend_tag, backend_info, new_id)
    if not superseded_id:
        return
    self.set_goal_id(superseded_id)
async def create_backend(self, backend_tag: BackendTag,
backend_config: BackendConfig,
replica_config: ReplicaConfig) -> UUID:
@@ -946,13 +1014,18 @@ class ServeController:
worker_class=backend_replica,
backend_config=backend_config,
replica_config=replica_config)
self.backend_state.add_backend(backend_tag, backend_info)
metadata = backend_config.internal_metadata
if metadata.autoscaling_config is not None:
self.autoscaling_policies[
backend_tag] = BasicAutoscalingPolicy(
backend_tag, metadata.autoscaling_config)
return_uuid = self._create_event_with_result({
backend_tag: backend_info
})
await self.set_backend_goal(backend_tag, backend_info, return_uuid)
try:
# This call should be to run control loop
self.actor_reconciler._scale_backend_replicas(
@@ -962,22 +1035,13 @@ class ServeController:
del self.backend_state.backends[backend_tag]
raise e
return_uuid = self._create_event_with_result({
backend_tag: backend_info
})
# NOTE(edoakes): we must write a checkpoint before starting new
# or pushing the updated config to avoid inconsistent state if we
# crash while making the change.
self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
# Set the backend config inside routers
# (particularly for max_concurrent_queries).
self.notify_backend_configs_changed()
return return_uuid
async def delete_backend(self,
@@ -1001,7 +1065,6 @@ class ServeController:
# Scale its replicas down to 0. This will also remove the backend
# from self.backend_state.backends and
# self.actor_reconciler.backend_replicas.
# This should be a call to the control loop
self.actor_reconciler._scale_backend_replicas(
@@ -1016,15 +1079,14 @@ class ServeController:
self.actor_reconciler.backends_to_remove.append(backend_tag)
return_uuid = self._create_event_with_result({backend_tag: None})
# Remove the backend's metadata.
await self.set_backend_goal(backend_tag, None, return_uuid)
# NOTE(edoakes): we must write a checkpoint before removing the
# backend from the routers to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
return return_uuid
async def update_backend_config(self, backend_tag: BackendTag,
@@ -1044,6 +1106,11 @@ class ServeController:
backend_tag).backend_config = backend_config
backend_info = self.backend_state.get_backend(backend_tag)
return_uuid = self._create_event_with_result({
backend_tag: backend_info
})
await self.set_backend_goal(backend_tag, backend_info, return_uuid)
# Scale the replicas with the new configuration.
# This should be to run the control loop
@@ -1051,9 +1118,6 @@ class ServeController:
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
return_uuid = self._create_event_with_result({
backend_tag: backend_info
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
@@ -1064,10 +1128,7 @@ class ServeController:
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
self.notify_backend_configs_changed()
return return_uuid
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
+6 -4
View File
@@ -226,8 +226,9 @@ def test_create_backend_idempotent(serve_instance):
for i in range(10):
ray.get(
controller.create_backend.remote("my_backend", backend_config,
replica_config))
controller.wait_for_event.remote(
controller.create_backend.remote("my_backend", backend_config,
replica_config)))
assert len(ray.get(controller.get_all_backends.remote())) == 1
client.create_endpoint(
@@ -248,8 +249,9 @@ def test_create_endpoint_idempotent(serve_instance):
for i in range(10):
ray.get(
controller.create_endpoint.remote(
"my_endpoint", {"my_backend": 1.0}, "/my_route", ["GET"]))
controller.wait_for_event.remote(
controller.create_endpoint.remote(
"my_endpoint", {"my_backend": 1.0}, "/my_route", ["GET"])))
assert len(ray.get(controller.get_all_endpoints.remote())) == 1
assert requests.get("http://127.0.0.1:8000/my_route").text == "hello"