mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:05:26 +08:00
[serve] Create CurrentState & GoalState (#12369)
This commit is contained in:
@@ -43,6 +43,7 @@ BackendTag = str
|
||||
EndpointTag = str
|
||||
ReplicaTag = str
|
||||
NodeId = str
|
||||
GoalId = int
|
||||
|
||||
|
||||
class TrafficPolicy:
|
||||
@@ -91,13 +92,17 @@ class BackendInfo(BaseModel):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigurationStore:
|
||||
class SystemState:
|
||||
backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict)
|
||||
traffic_policies: Dict[EndpointTag, TrafficPolicy] = field(
|
||||
default_factory=dict)
|
||||
routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field(
|
||||
default_factory=dict)
|
||||
|
||||
backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
|
||||
traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict)
|
||||
route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
|
||||
|
||||
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
|
||||
return {
|
||||
tag: info.backend_config
|
||||
@@ -107,9 +112,31 @@ class ConfigurationStore:
|
||||
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
    """Look up the stored info for ``backend_tag``.

    Returns the entry from ``self.backends``, or ``None`` when the tag
    is not present.
    """
    backend_info = self.backends.get(backend_tag)
    return backend_info
|
||||
|
||||
def add_backend(self, backend_tag: BackendTag,
|
||||
backend_info: BackendInfo) -> None:
|
||||
def add_backend(self,
                backend_tag: BackendTag,
                backend_info: BackendInfo,
                goal_id: GoalId = 0) -> None:
    """Store ``backend_info`` under ``backend_tag`` and record its goal id.

    Args:
        backend_tag: Key under which the backend is stored in
            ``self.backends``.
        backend_info: The backend's info object to store.
        goal_id: Goal id to associate with this backend; defaults to 0.
    """
    self.backends[backend_tag] = backend_info
    # Record the goal id per backend tag. Assigning ``goal_id`` directly to
    # ``self.backend_goal_ids`` would replace the whole mapping with an int.
    self.backend_goal_ids[backend_tag] = goal_id
|
||||
|
||||
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
    """Summarize every endpoint found in ``self.routes``.

    Returns a dictionary keyed by endpoint name. Each value holds the
    endpoint's ``"route"``, ``"methods"``, ``"traffic"`` and ``"shadows"``.
    ``"route"`` is ``None`` when the routes key does not start with ``"/"``.
    Endpoints without an entry in ``self.traffic_policies`` get empty
    traffic and shadow dictionaries.
    """
    summary = {}
    for path, (name, http_methods) in self.routes.items():
        traffic, shadows = {}, {}
        if name in self.traffic_policies:
            policy = self.traffic_policies[name]
            traffic = policy.traffic_dict
            shadows = policy.shadow_dict

        # Only keys that look like URL paths are reported as routes.
        visible_route = path if path.startswith("/") else None
        summary[name] = dict(
            route=visible_route,
            methods=http_methods,
            traffic=traffic,
            shadows=shadows,
        )
    return summary
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -147,7 +174,7 @@ class ActorStateReconciler:
|
||||
]))
|
||||
|
||||
async def _start_pending_backend_replicas(
|
||||
self, config_store: ConfigurationStore) -> None:
|
||||
self, current_state: SystemState) -> None:
|
||||
"""Starts the pending backend replicas in self.backend_replicas_to_start.
|
||||
|
||||
Waits for replicas to start up, then removes them from
|
||||
@@ -158,7 +185,7 @@ class ActorStateReconciler:
|
||||
items():
|
||||
for replica_tag in replicas_to_create:
|
||||
replica_handle = await self._start_backend_replica(
|
||||
config_store, backend_tag, replica_tag)
|
||||
current_state, backend_tag, replica_tag)
|
||||
ready_future = replica_handle.ready.remote().as_future()
|
||||
fut_to_replica_info[ready_future] = (backend_tag, replica_tag,
|
||||
replica_handle)
|
||||
@@ -182,15 +209,14 @@ class ActorStateReconciler:
|
||||
|
||||
self.backend_replicas_to_start.clear()
|
||||
|
||||
async def _start_backend_replica(self, config_store: ConfigurationStore,
|
||||
async def _start_backend_replica(self, current_state: SystemState,
|
||||
backend_tag: BackendTag,
|
||||
replica_tag: ReplicaTag) -> ActorHandle:
|
||||
"""Start a replica and return its actor handle.
|
||||
|
||||
Checks if the named actor already exists before starting a new one.
|
||||
|
||||
Assumes that the backend configuration has already been registered
|
||||
in the ConfigurationStore.
|
||||
Assumes that the backend configuration is already in the Goal State.
|
||||
"""
|
||||
# NOTE(edoakes): the replicas may already be created if we
|
||||
# failed after creating them but before writing a
|
||||
@@ -201,7 +227,7 @@ class ActorStateReconciler:
|
||||
except ValueError:
|
||||
logger.debug("Starting replica '{}' for backend '{}'.".format(
|
||||
replica_tag, backend_tag))
|
||||
backend_info = config_store.get_backend(backend_tag)
|
||||
backend_info = current_state.get_backend(backend_tag)
|
||||
|
||||
replica_handle = ray.remote(backend_info.worker_class).options(
|
||||
name=replica_name,
|
||||
@@ -373,20 +399,19 @@ class ActorStateReconciler:
|
||||
replica_tag] = ray.get_actor(replica_name)
|
||||
|
||||
async def _recover_from_checkpoint(
|
||||
self, config_store: ConfigurationStore,
|
||||
controller: "ServeController"
|
||||
self, current_state: SystemState, controller: "ServeController"
|
||||
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
|
||||
self._recover_actor_handles()
|
||||
autoscaling_policies = dict()
|
||||
|
||||
for backend, info in config_store.backends.items():
|
||||
for backend, info in current_state.backends.items():
|
||||
metadata = info.backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
autoscaling_policies[backend] = BasicAutoscalingPolicy(
|
||||
backend, metadata.autoscaling_config)
|
||||
|
||||
# Start/stop any pending backend replicas.
|
||||
await self._start_pending_backend_replicas(config_store)
|
||||
await self._start_pending_backend_replicas(current_state)
|
||||
await self._stop_pending_backend_replicas()
|
||||
|
||||
return autoscaling_policies
|
||||
@@ -394,8 +419,10 @@ class ActorStateReconciler:
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
config: ConfigurationStore
|
||||
goal_state: SystemState
|
||||
current_state: SystemState
|
||||
reconciler: ActorStateReconciler
|
||||
# TODO(ilr) Rename reconciler to PendingState
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -432,8 +459,12 @@ class ServeController:
|
||||
detached: bool = False):
|
||||
# Used to read/write checkpoints.
|
||||
self.kv_store = RayInternalKVStore(namespace=controller_name)
|
||||
# ConfigurationStore
|
||||
self.configuration_store = ConfigurationStore()
|
||||
# Current State
|
||||
self.current_state = SystemState()
|
||||
# Goal State
|
||||
# TODO(ilr) This is currently *unused* until the refactor of the serve
|
||||
# controller.
|
||||
self.goal_state = SystemState()
|
||||
# ActorStateReconciler
|
||||
self.actor_reconciler = ActorStateReconciler(controller_name, detached)
|
||||
|
||||
@@ -497,12 +528,12 @@ class ServeController:
|
||||
})
|
||||
|
||||
def notify_traffic_policies_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"traffic_policies", self.configuration_store.traffic_policies)
|
||||
self.long_poll_host.notify_changed("traffic_policies",
|
||||
self.current_state.traffic_policies)
|
||||
|
||||
def notify_backend_configs_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"backend_configs", self.configuration_store.get_backend_configs())
|
||||
"backend_configs", self.current_state.get_backend_configs())
|
||||
|
||||
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
|
||||
"""Proxy long pull client's listen request.
|
||||
@@ -519,9 +550,9 @@ class ServeController:
|
||||
"""Returns a dictionary of node ID to router actor handles."""
|
||||
return self.actor_reconciler.routers_cache
|
||||
|
||||
def get_router_config(self) -> Dict[str, Dict[str, Tuple[str, List[str]]]]:
|
||||
def get_router_config(self) -> Dict[str, Tuple[str, List[str]]]:
|
||||
"""Called by the router on startup to fetch required state."""
|
||||
return self.configuration_store.routes
|
||||
return self.current_state.routes
|
||||
|
||||
def _checkpoint(self) -> None:
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
@@ -530,7 +561,8 @@ class ServeController:
|
||||
start = time.time()
|
||||
|
||||
checkpoint = pickle.dumps(
|
||||
Checkpoint(self.configuration_store, self.actor_reconciler))
|
||||
Checkpoint(self.goal_state, self.current_state,
|
||||
self.actor_reconciler))
|
||||
|
||||
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
|
||||
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
|
||||
@@ -559,14 +591,14 @@ class ServeController:
|
||||
logger.info("Recovering from checkpoint")
|
||||
|
||||
restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
|
||||
# Restore ConfigurationStore
|
||||
self.configuration_store = restored_checkpoint.config
|
||||
# Restore SystemState
|
||||
self.current_state = restored_checkpoint.current_state
|
||||
|
||||
# Restore ActorStateReconciler
|
||||
self.actor_reconciler = restored_checkpoint.reconciler
|
||||
|
||||
self.autoscaling_policies = await self.actor_reconciler.\
|
||||
_recover_from_checkpoint(self.configuration_store, self)
|
||||
_recover_from_checkpoint(self.current_state, self)
|
||||
|
||||
logger.info(
|
||||
"Recovered from checkpoint in {:.3f}s".format(time.time() - start))
|
||||
@@ -574,7 +606,7 @@ class ServeController:
|
||||
self.write_lock.release()
|
||||
|
||||
async def do_autoscale(self) -> None:
|
||||
for backend, info in self.configuration_store.backends.items():
|
||||
for backend, info in self.current_state.backends.items():
|
||||
if backend not in self.autoscaling_policies:
|
||||
continue
|
||||
|
||||
@@ -599,11 +631,11 @@ class ServeController:
|
||||
|
||||
def get_backend_configs(self) -> Dict[str, BackendConfig]:
|
||||
"""Fetched by the router on startup."""
|
||||
return self.configuration_store.get_backend_configs()
|
||||
return self.current_state.get_backend_configs()
|
||||
|
||||
def get_traffic_policies(self) -> Dict[str, TrafficPolicy]:
|
||||
"""Fetched by the router on startup."""
|
||||
return self.configuration_store.traffic_policies
|
||||
return self.current_state.traffic_policies
|
||||
|
||||
def _list_replicas(self, backend_tag: BackendTag) -> List[ReplicaTag]:
|
||||
"""Used only for testing."""
|
||||
@@ -611,7 +643,7 @@ class ServeController:
|
||||
|
||||
def get_traffic_policy(self, endpoint: str) -> TrafficPolicy:
|
||||
"""Fetched by serve handles."""
|
||||
return self.configuration_store.traffic_policies[endpoint]
|
||||
return self.current_state.traffic_policies[endpoint]
|
||||
|
||||
def get_all_replica_handles(self) -> Dict[str, Dict[str, ActorHandle]]:
|
||||
"""Fetched by the router on startup."""
|
||||
@@ -619,33 +651,14 @@ class ServeController:
|
||||
|
||||
def get_all_backends(self) -> Dict[str, BackendConfig]:
|
||||
"""Returns a dictionary of backend tag to backend config."""
|
||||
return self.configuration_store.get_backend_configs()
|
||||
return self.current_state.get_backend_configs()
|
||||
|
||||
def get_all_endpoints(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Returns a dictionary of endpoint to endpoint config."""
|
||||
endpoints = {}
|
||||
for route, (endpoint,
|
||||
methods) in self.configuration_store.routes.items():
|
||||
if endpoint in self.configuration_store.traffic_policies:
|
||||
traffic_policy = self.configuration_store.traffic_policies[
|
||||
endpoint]
|
||||
traffic_dict = traffic_policy.traffic_dict
|
||||
shadow_dict = traffic_policy.shadow_dict
|
||||
else:
|
||||
traffic_dict = {}
|
||||
shadow_dict = {}
|
||||
|
||||
endpoints[endpoint] = {
|
||||
"route": route if route.startswith("/") else None,
|
||||
"methods": methods,
|
||||
"traffic": traffic_dict,
|
||||
"shadows": shadow_dict,
|
||||
}
|
||||
return endpoints
|
||||
return self.current_state.get_endpoints()
|
||||
|
||||
async def _set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> None:
|
||||
if endpoint_name not in self.get_all_endpoints():
|
||||
if endpoint_name not in self.current_state.get_endpoints():
|
||||
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
|
||||
" that is not registered.".format(endpoint_name))
|
||||
|
||||
@@ -653,14 +666,13 @@ class ServeController:
|
||||
dict), "Traffic policy must be a dictionary."
|
||||
|
||||
for backend in traffic_dict:
|
||||
if self.configuration_store.get_backend(backend) is None:
|
||||
if self.current_state.get_backend(backend) is None:
|
||||
raise ValueError(
|
||||
"Attempted to assign traffic to a backend '{}' that "
|
||||
"is not registered.".format(backend))
|
||||
|
||||
traffic_policy = TrafficPolicy(traffic_dict)
|
||||
self.configuration_store.traffic_policies[
|
||||
endpoint_name] = traffic_policy
|
||||
self.current_state.traffic_policies[endpoint_name] = traffic_policy
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
@@ -679,18 +691,18 @@ class ServeController:
|
||||
proportion: float) -> None:
|
||||
"""Shadow traffic from the endpoint to the backend."""
|
||||
async with self.write_lock:
|
||||
if endpoint_name not in self.get_all_endpoints():
|
||||
if endpoint_name not in self.current_state.get_endpoints():
|
||||
raise ValueError("Attempted to shadow traffic from an "
|
||||
"endpoint '{}' that is not registered."
|
||||
.format(endpoint_name))
|
||||
|
||||
if self.configuration_store.get_backend(backend_tag) is None:
|
||||
if self.current_state.get_backend(backend_tag) is None:
|
||||
raise ValueError(
|
||||
"Attempted to shadow traffic to a backend '{}' that "
|
||||
"is not registered.".format(backend_tag))
|
||||
|
||||
self.configuration_store.traffic_policies[
|
||||
endpoint_name].set_shadow(backend_tag, proportion)
|
||||
self.current_state.traffic_policies[endpoint_name].set_shadow(
|
||||
backend_tag, proportion)
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
@@ -717,11 +729,10 @@ class ServeController:
|
||||
|
||||
# TODO(edoakes): move this to client side.
|
||||
err_prefix = "Cannot create endpoint."
|
||||
if route in self.configuration_store.routes:
|
||||
if route in self.current_state.routes:
|
||||
|
||||
# Ensures this method is idempotent
|
||||
if self.configuration_store.routes[route] == (endpoint,
|
||||
methods):
|
||||
if self.current_state.routes[route] == (endpoint, methods):
|
||||
return
|
||||
|
||||
else:
|
||||
@@ -729,7 +740,7 @@ class ServeController:
|
||||
"{} Route '{}' is already registered.".format(
|
||||
err_prefix, route))
|
||||
|
||||
if endpoint in self.get_all_endpoints():
|
||||
if endpoint in self.current_state.get_endpoints():
|
||||
raise ValueError(
|
||||
"{} Endpoint '{}' is already registered.".format(
|
||||
err_prefix, endpoint))
|
||||
@@ -738,12 +749,12 @@ class ServeController:
|
||||
"Registering route '{}' to endpoint '{}' with methods '{}'.".
|
||||
format(route, endpoint, methods))
|
||||
|
||||
self.configuration_store.routes[route] = (endpoint, methods)
|
||||
self.current_state.routes[route] = (endpoint, methods)
|
||||
|
||||
# NOTE(edoakes): checkpoint is written in self._set_traffic.
|
||||
await self._set_traffic(endpoint, traffic_dict)
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.configuration_store.routes)
|
||||
router.set_route_table.remote(self.current_state.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
|
||||
@@ -757,7 +768,7 @@ class ServeController:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified endpoint exists on the client.
|
||||
for route, (route_endpoint,
|
||||
_) in self.configuration_store.routes.items():
|
||||
_) in self.current_state.routes.items():
|
||||
if route_endpoint == endpoint:
|
||||
route_to_delete = route
|
||||
break
|
||||
@@ -766,11 +777,11 @@ class ServeController:
|
||||
return
|
||||
|
||||
# Remove the routing entry.
|
||||
del self.configuration_store.routes[route_to_delete]
|
||||
del self.current_state.routes[route_to_delete]
|
||||
|
||||
# Remove the traffic policy entry if it exists.
|
||||
if endpoint in self.configuration_store.traffic_policies:
|
||||
del self.configuration_store.traffic_policies[endpoint]
|
||||
if endpoint in self.current_state.traffic_policies:
|
||||
del self.current_state.traffic_policies[endpoint]
|
||||
|
||||
self.actor_reconciler.endpoints_to_remove.append(endpoint)
|
||||
|
||||
@@ -780,7 +791,7 @@ class ServeController:
|
||||
self._checkpoint()
|
||||
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.configuration_store.routes)
|
||||
router.set_route_table.remote(self.current_state.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
|
||||
@@ -790,7 +801,7 @@ class ServeController:
|
||||
"""Register a new backend under the specified tag."""
|
||||
async with self.write_lock:
|
||||
# Ensures this method is idempotent.
|
||||
backend_info = self.configuration_store.get_backend(backend_tag)
|
||||
backend_info = self.current_state.get_backend(backend_tag)
|
||||
if backend_info is not None:
|
||||
if (backend_info.backend_config == backend_config
|
||||
and backend_info.replica_config == replica_config):
|
||||
@@ -801,7 +812,7 @@ class ServeController:
|
||||
|
||||
# Save creator that starts replicas, the arguments to be passed in,
|
||||
# and the configuration for the backends.
|
||||
self.configuration_store.add_backend(
|
||||
self.current_state.add_backend(
|
||||
backend_tag,
|
||||
BackendInfo(
|
||||
worker_class=backend_replica,
|
||||
@@ -815,10 +826,10 @@ class ServeController:
|
||||
|
||||
try:
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.configuration_store.backends, backend_tag,
|
||||
self.current_state.backends, backend_tag,
|
||||
backend_config.num_replicas)
|
||||
except RayServeException as e:
|
||||
del self.configuration_store.backends[backend_tag]
|
||||
del self.current_state.backends[backend_tag]
|
||||
raise e
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before starting new
|
||||
@@ -826,7 +837,7 @@ class ServeController:
|
||||
# crash while making the change.
|
||||
self._checkpoint()
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
self.configuration_store)
|
||||
self.current_state)
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
|
||||
@@ -838,11 +849,11 @@ class ServeController:
|
||||
async with self.write_lock:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified backend exists on the client.
|
||||
if self.configuration_store.get_backend(backend_tag) is None:
|
||||
if self.current_state.get_backend(backend_tag) is None:
|
||||
return
|
||||
|
||||
# Check that the specified backend isn't used by any endpoints.
|
||||
for endpoint, traffic_policy in self.configuration_store.\
|
||||
for endpoint, traffic_policy in self.current_state.\
|
||||
traffic_policies.items():
|
||||
if (backend_tag in traffic_policy.traffic_dict
|
||||
or backend_tag in traffic_policy.shadow_dict):
|
||||
@@ -852,13 +863,13 @@ class ServeController:
|
||||
"again.".format(backend_tag, endpoint))
|
||||
|
||||
# Scale its replicas down to 0. This will also remove the backend
|
||||
# from self.configuration_store.backends and
|
||||
# from self.current_state.backends and
|
||||
# self.actor_reconciler.backend_replicas.
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.configuration_store.backends, backend_tag, 0)
|
||||
self.current_state.backends, backend_tag, 0)
|
||||
|
||||
# Remove the backend's metadata.
|
||||
del self.configuration_store.backends[backend_tag]
|
||||
del self.current_state.backends[backend_tag]
|
||||
if backend_tag in self.autoscaling_policies:
|
||||
del self.autoscaling_policies[backend_tag]
|
||||
|
||||
@@ -877,21 +888,21 @@ class ServeController:
|
||||
config_options: BackendConfig) -> None:
|
||||
"""Set the config for the specified backend."""
|
||||
async with self.write_lock:
|
||||
assert (self.configuration_store.get_backend(backend_tag)
|
||||
assert (self.current_state.get_backend(backend_tag)
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(config_options, BackendConfig)
|
||||
|
||||
stored_backend_config = self.configuration_store.get_backend(
|
||||
stored_backend_config = self.current_state.get_backend(
|
||||
backend_tag).backend_config
|
||||
backend_config = stored_backend_config.copy(
|
||||
update=config_options.dict(exclude_unset=True))
|
||||
backend_config._validate_complete()
|
||||
self.configuration_store.get_backend(
|
||||
self.current_state.get_backend(
|
||||
backend_tag).backend_config = backend_config
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.configuration_store.backends, backend_tag,
|
||||
self.current_state.backends, backend_tag,
|
||||
backend_config.num_replicas)
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
@@ -903,7 +914,7 @@ class ServeController:
|
||||
# (particularly for setting max_batch_size).
|
||||
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
self.configuration_store)
|
||||
self.current_state)
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
@@ -911,9 +922,9 @@ class ServeController:
|
||||
|
||||
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
|
||||
"""Get the current config for the specified backend."""
|
||||
assert (self.configuration_store.get_backend(backend_tag)
|
||||
assert (self.current_state.get_backend(backend_tag)
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
return self.configuration_store.get_backend(backend_tag).backend_config
|
||||
return self.current_state.get_backend(backend_tag).backend_config
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shuts down the serve instance completely."""
|
||||
|
||||
Reference in New Issue
Block a user