[serve] Create CurrentState & GoalState (#12369)

This commit is contained in:
Ian Rodney
2020-11-30 17:34:30 -08:00
committed by GitHub
parent 234df9091e
commit e422ace053
+97 -86
View File
@@ -43,6 +43,7 @@ BackendTag = str
EndpointTag = str
ReplicaTag = str
NodeId = str
GoalId = int
class TrafficPolicy:
@@ -91,13 +92,17 @@ class BackendInfo(BaseModel):
@dataclass
class ConfigurationStore:
class SystemState:
backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict)
traffic_policies: Dict[EndpointTag, TrafficPolicy] = field(
default_factory=dict)
routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field(
default_factory=dict)
backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict)
route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
tag: info.backend_config
@@ -107,9 +112,31 @@ class ConfigurationStore:
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
return self.backends.get(backend_tag)
def add_backend(self, backend_tag: BackendTag,
backend_info: BackendInfo) -> None:
def add_backend(self,
                backend_tag: BackendTag,
                backend_info: BackendInfo,
                goal_id: GoalId = 0) -> None:
    """Register (or overwrite) a backend and record its goal id.

    Args:
        backend_tag: Key under which the backend is stored.
        backend_info: Backend description saved in ``self.backends``.
        goal_id: Goal id associated with this backend; defaults to 0.
    """
    self.backends[backend_tag] = backend_info
    # Store the goal id per backend. Assigning the int directly to
    # ``self.backend_goal_ids`` would replace the whole mapping and lose
    # the goal ids of every previously registered backend.
    self.backend_goal_ids[backend_tag] = goal_id
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
    """Return a dictionary of endpoint tag to endpoint config.

    Each entry of ``self.routes`` maps a route key to an
    ``(endpoint, methods)`` pair. The resulting config for an endpoint
    contains:

    * ``"route"``: the route key if it starts with ``"/"``, else ``None``.
    * ``"methods"``: the methods stored alongside the endpoint.
    * ``"traffic"`` / ``"shadows"``: the dictionaries of the endpoint's
      traffic policy, or empty dictionaries when no policy is stored.
    """
    endpoints = {}
    for route, (endpoint, methods) in self.routes.items():
        # Single lookup instead of a membership test followed by indexing.
        traffic_policy = self.traffic_policies.get(endpoint)
        if traffic_policy is None:
            traffic_dict = {}
            shadow_dict = {}
        else:
            traffic_dict = traffic_policy.traffic_dict
            shadow_dict = traffic_policy.shadow_dict

        endpoints[endpoint] = {
            "route": route if route.startswith("/") else None,
            "methods": methods,
            "traffic": traffic_dict,
            "shadows": shadow_dict,
        }
    return endpoints
@dataclass
@@ -147,7 +174,7 @@ class ActorStateReconciler:
]))
async def _start_pending_backend_replicas(
self, config_store: ConfigurationStore) -> None:
self, current_state: SystemState) -> None:
"""Starts the pending backend replicas in self.backend_replicas_to_start.
Waits for replicas to start up, then removes them from
@@ -158,7 +185,7 @@ class ActorStateReconciler:
items():
for replica_tag in replicas_to_create:
replica_handle = await self._start_backend_replica(
config_store, backend_tag, replica_tag)
current_state, backend_tag, replica_tag)
ready_future = replica_handle.ready.remote().as_future()
fut_to_replica_info[ready_future] = (backend_tag, replica_tag,
replica_handle)
@@ -182,15 +209,14 @@ class ActorStateReconciler:
self.backend_replicas_to_start.clear()
async def _start_backend_replica(self, config_store: ConfigurationStore,
async def _start_backend_replica(self, current_state: SystemState,
backend_tag: BackendTag,
replica_tag: ReplicaTag) -> ActorHandle:
"""Start a replica and return its actor handle.
Checks if the named actor already exists before starting a new one.
Assumes that the backend configuration has already been registered
in the ConfigurationStore.
Assumes that the backend configuration is already in the Goal State.
"""
# NOTE(edoakes): the replicas may already be created if we
# failed after creating them but before writing a
@@ -201,7 +227,7 @@ class ActorStateReconciler:
except ValueError:
logger.debug("Starting replica '{}' for backend '{}'.".format(
replica_tag, backend_tag))
backend_info = config_store.get_backend(backend_tag)
backend_info = current_state.get_backend(backend_tag)
replica_handle = ray.remote(backend_info.worker_class).options(
name=replica_name,
@@ -373,20 +399,19 @@ class ActorStateReconciler:
replica_tag] = ray.get_actor(replica_name)
async def _recover_from_checkpoint(
self, config_store: ConfigurationStore,
controller: "ServeController"
self, current_state: SystemState, controller: "ServeController"
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
self._recover_actor_handles()
autoscaling_policies = dict()
for backend, info in config_store.backends.items():
for backend, info in current_state.backends.items():
metadata = info.backend_config.internal_metadata
if metadata.autoscaling_config is not None:
autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, metadata.autoscaling_config)
# Start/stop any pending backend replicas.
await self._start_pending_backend_replicas(config_store)
await self._start_pending_backend_replicas(current_state)
await self._stop_pending_backend_replicas()
return autoscaling_policies
@@ -394,8 +419,10 @@ class ActorStateReconciler:
@dataclass
class Checkpoint:
config: ConfigurationStore
goal_state: SystemState
current_state: SystemState
reconciler: ActorStateReconciler
# TODO(ilr) Rename reconciler to PendingState
@ray.remote
@@ -432,8 +459,12 @@ class ServeController:
detached: bool = False):
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name)
# ConfigurationStore
self.configuration_store = ConfigurationStore()
# Current State
self.current_state = SystemState()
# Goal State
# TODO(ilr) This is currently *unused* until the refactor of the serve
# controller.
self.goal_state = SystemState()
# ActorStateReconciler
self.actor_reconciler = ActorStateReconciler(controller_name, detached)
@@ -497,12 +528,12 @@ class ServeController:
})
def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed(
"traffic_policies", self.configuration_store.traffic_policies)
self.long_poll_host.notify_changed("traffic_policies",
self.current_state.traffic_policies)
def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed(
"backend_configs", self.configuration_store.get_backend_configs())
"backend_configs", self.current_state.get_backend_configs())
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy long poll client's listen request.
@@ -519,9 +550,9 @@ class ServeController:
"""Returns a dictionary of node ID to router actor handles."""
return self.actor_reconciler.routers_cache
def get_router_config(self) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
def get_router_config(self) -> Dict[str, Tuple[str, List[str]]]:
"""Called by the router on startup to fetch required state."""
return self.configuration_store.routes
return self.current_state.routes
def _checkpoint(self) -> None:
"""Checkpoint internal state and write it to the KV store."""
@@ -530,7 +561,8 @@ class ServeController:
start = time.time()
checkpoint = pickle.dumps(
Checkpoint(self.configuration_store, self.actor_reconciler))
Checkpoint(self.goal_state, self.current_state,
self.actor_reconciler))
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
@@ -559,14 +591,14 @@ class ServeController:
logger.info("Recovering from checkpoint")
restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
# Restore ConfigurationStore
self.configuration_store = restored_checkpoint.config
# Restore SystemState
self.current_state = restored_checkpoint.current_state
# Restore ActorStateReconciler
self.actor_reconciler = restored_checkpoint.reconciler
self.autoscaling_policies = await self.actor_reconciler.\
_recover_from_checkpoint(self.configuration_store, self)
_recover_from_checkpoint(self.current_state, self)
logger.info(
"Recovered from checkpoint in {:.3f}s".format(time.time() - start))
@@ -574,7 +606,7 @@ class ServeController:
self.write_lock.release()
async def do_autoscale(self) -> None:
for backend, info in self.configuration_store.backends.items():
for backend, info in self.current_state.backends.items():
if backend not in self.autoscaling_policies:
continue
@@ -599,11 +631,11 @@ class ServeController:
def get_backend_configs(self) -> Dict[str, BackendConfig]:
"""Fetched by the router on startup."""
return self.configuration_store.get_backend_configs()
return self.current_state.get_backend_configs()
def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
"""Fetched by the router on startup."""
return self.configuration_store.traffic_policies
return self.current_state.traffic_policies
def _list_replicas(self, backend_tag: BackendTag) -> List[ReplicaTag]:
"""Used only for testing."""
@@ -611,7 +643,7 @@ class ServeController:
def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
"""Fetched by serve handles."""
return self.configuration_store.traffic_policies[endpoint]
return self.current_state.traffic_policies[endpoint]
def get_all_replica_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
"""Fetched by the router on startup."""
@@ -619,33 +651,14 @@ class ServeController:
def get_all_backends(self) -> Dict[str, BackendConfig]:
"""Returns a dictionary of backend tag to backend config."""
return self.configuration_store.get_backend_configs()
return self.current_state.get_backend_configs()
def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
"""Returns a dictionary of endpoint to endpoint config."""
endpoints = {}
for route, (endpoint,
methods) in self.configuration_store.routes.items():
if endpoint in self.configuration_store.traffic_policies:
traffic_policy = self.configuration_store.traffic_policies[
endpoint]
traffic_dict = traffic_policy.traffic_dict
shadow_dict = traffic_policy.shadow_dict
else:
traffic_dict = {}
shadow_dict = {}
endpoints[endpoint] = {
"route": route if route.startswith("/") else None,
"methods": methods,
"traffic": traffic_dict,
"shadows": shadow_dict,
}
return endpoints
return self.current_state.get_endpoints()
async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> None:
if endpoint_name not in self.get_all_endpoints():
if endpoint_name not in self.current_state.get_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name))
@@ -653,14 +666,13 @@ class ServeController:
dict), "Traffic policy must be a dictionary."
for backend in traffic_dict:
if self.configuration_store.get_backend(backend) is None:
if self.current_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))
traffic_policy = TrafficPolicy(traffic_dict)
self.configuration_store.traffic_policies[
endpoint_name] = traffic_policy
self.current_state.traffic_policies[endpoint_name] = traffic_policy
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
@@ -679,18 +691,18 @@ class ServeController:
proportion: float) -> None:
"""Shadow traffic from the endpoint to the backend."""
async with self.write_lock:
if endpoint_name not in self.get_all_endpoints():
if endpoint_name not in self.current_state.get_endpoints():
raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered."
.format(endpoint_name))
if self.configuration_store.get_backend(backend_tag) is None:
if self.current_state.get_backend(backend_tag) is None:
raise ValueError(
"Attempted to shadow traffic to a backend '{}' that "
"is not registered.".format(backend_tag))
self.configuration_store.traffic_policies[
endpoint_name].set_shadow(backend_tag, proportion)
self.current_state.traffic_policies[endpoint_name].set_shadow(
backend_tag, proportion)
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
@@ -717,11 +729,10 @@ class ServeController:
# TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint."
if route in self.configuration_store.routes:
if route in self.current_state.routes:
# Ensures this method is idempotent
if self.configuration_store.routes[route] == (endpoint,
methods):
if self.current_state.routes[route] == (endpoint, methods):
return
else:
@@ -729,7 +740,7 @@ class ServeController:
"{} Route '{}' is already registered.".format(
err_prefix, route))
if endpoint in self.get_all_endpoints():
if endpoint in self.current_state.get_endpoints():
raise ValueError(
"{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
@@ -738,12 +749,12 @@ class ServeController:
"Registering route '{}' to endpoint '{}' with methods '{}'.".
format(route, endpoint, methods))
self.configuration_store.routes[route] = (endpoint, methods)
self.current_state.routes[route] = (endpoint, methods)
# NOTE(edoakes): checkpoint is written in self._set_traffic.
await self._set_traffic(endpoint, traffic_dict)
await asyncio.gather(*[
router.set_route_table.remote(self.configuration_store.routes)
router.set_route_table.remote(self.current_state.routes)
for router in self.actor_reconciler.router_handles()
])
@@ -757,7 +768,7 @@ class ServeController:
# This method must be idempotent. We should validate that the
# specified endpoint exists on the client.
for route, (route_endpoint,
_) in self.configuration_store.routes.items():
_) in self.current_state.routes.items():
if route_endpoint == endpoint:
route_to_delete = route
break
@@ -766,11 +777,11 @@ class ServeController:
return
# Remove the routing entry.
del self.configuration_store.routes[route_to_delete]
del self.current_state.routes[route_to_delete]
# Remove the traffic policy entry if it exists.
if endpoint in self.configuration_store.traffic_policies:
del self.configuration_store.traffic_policies[endpoint]
if endpoint in self.current_state.traffic_policies:
del self.current_state.traffic_policies[endpoint]
self.actor_reconciler.endpoints_to_remove.append(endpoint)
@@ -780,7 +791,7 @@ class ServeController:
self._checkpoint()
await asyncio.gather(*[
router.set_route_table.remote(self.configuration_store.routes)
router.set_route_table.remote(self.current_state.routes)
for router in self.actor_reconciler.router_handles()
])
@@ -790,7 +801,7 @@ class ServeController:
"""Register a new backend under the specified tag."""
async with self.write_lock:
# Ensures this method is idempotent.
backend_info = self.configuration_store.get_backend(backend_tag)
backend_info = self.current_state.get_backend(backend_tag)
if backend_info is not None:
if (backend_info.backend_config == backend_config
and backend_info.replica_config == replica_config):
@@ -801,7 +812,7 @@ class ServeController:
# Save creator that starts replicas, the arguments to be passed in,
# and the configuration for the backends.
self.configuration_store.add_backend(
self.current_state.add_backend(
backend_tag,
BackendInfo(
worker_class=backend_replica,
@@ -815,10 +826,10 @@ class ServeController:
try:
self.actor_reconciler._scale_backend_replicas(
self.configuration_store.backends, backend_tag,
self.current_state.backends, backend_tag,
backend_config.num_replicas)
except RayServeException as e:
del self.configuration_store.backends[backend_tag]
del self.current_state.backends[backend_tag]
raise e
# NOTE(edoakes): we must write a checkpoint before starting new
@@ -826,7 +837,7 @@ class ServeController:
# crash while making the change.
self._checkpoint()
await self.actor_reconciler._start_pending_backend_replicas(
self.configuration_store)
self.current_state)
self.notify_replica_handles_changed()
@@ -838,11 +849,11 @@ class ServeController:
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified backend exists on the client.
if self.configuration_store.get_backend(backend_tag) is None:
if self.current_state.get_backend(backend_tag) is None:
return
# Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_policy in self.configuration_store.\
for endpoint, traffic_policy in self.current_state.\
traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict):
@@ -852,13 +863,13 @@ class ServeController:
"again.".format(backend_tag, endpoint))
# Scale its replicas down to 0. This will also remove the backend
# from self.configuration_store.backends and
# from self.current_state.backends and
# self.actor_reconciler.backend_replicas.
self.actor_reconciler._scale_backend_replicas(
self.configuration_store.backends, backend_tag, 0)
self.current_state.backends, backend_tag, 0)
# Remove the backend's metadata.
del self.configuration_store.backends[backend_tag]
del self.current_state.backends[backend_tag]
if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag]
@@ -877,21 +888,21 @@ class ServeController:
config_options: BackendConfig) -> None:
"""Set the config for the specified backend."""
async with self.write_lock:
assert (self.configuration_store.get_backend(backend_tag)
assert (self.current_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, BackendConfig)
stored_backend_config = self.configuration_store.get_backend(
stored_backend_config = self.current_state.get_backend(
backend_tag).backend_config
backend_config = stored_backend_config.copy(
update=config_options.dict(exclude_unset=True))
backend_config._validate_complete()
self.configuration_store.get_backend(
self.current_state.get_backend(
backend_tag).backend_config = backend_config
# Scale the replicas with the new configuration.
self.actor_reconciler._scale_backend_replicas(
self.configuration_store.backends, backend_tag,
self.current_state.backends, backend_tag,
backend_config.num_replicas)
# NOTE(edoakes): we must write a checkpoint before pushing the
@@ -903,7 +914,7 @@ class ServeController:
# (particularly for setting max_batch_size).
await self.actor_reconciler._start_pending_backend_replicas(
self.configuration_store)
self.current_state)
await self.actor_reconciler._stop_pending_backend_replicas()
self.notify_replica_handles_changed()
@@ -911,9 +922,9 @@ class ServeController:
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
"""Get the current config for the specified backend."""
assert (self.configuration_store.get_backend(backend_tag)
assert (self.current_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag)
return self.configuration_store.get_backend(backend_tag).backend_config
return self.current_state.get_backend(backend_tag).backend_config
async def shutdown(self) -> None:
"""Shuts down the serve instance completely."""