From b52cce6632dc0a91c72e2fae6312eac4e2ad4d0d Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 21 Dec 2020 20:39:13 -0600 Subject: [PATCH] [serve] Refactor SystemState into EndpointState and BackendState (#13018) --- python/ray/serve/controller.py | 198 +++++++++++++++++---------------- 1 file changed, 100 insertions(+), 98 deletions(-) diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 17a543048..4a4b754ff 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -93,33 +93,16 @@ class BackendInfo(BaseModel): arbitrary_types_allowed = True -@dataclass -class SystemState: - backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict) - traffic_policies: Dict[EndpointTag, TrafficPolicy] = field( - default_factory=dict) - routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field( - default_factory=dict) +class EndpointState: + def __init__(self, checkpoint: bytes = None): + self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() + self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() - backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict) - traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict) - route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict) + if checkpoint is not None: + self.routes, self.traffic_policies = pickle.loads(checkpoint) - def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: - return { - tag: info.backend_config - for tag, info in self.backends.items() - } - - def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: - return self.backends.get(backend_tag) - - def add_backend(self, - backend_tag: BackendTag, - backend_info: BackendInfo, - goal_id: GoalId = 0) -> None: - self.backends[backend_tag] = backend_info - self.backend_goal_ids = goal_id + def checkpoint(self): + return pickle.dumps((self.routes, self.traffic_policies)) def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: endpoints = {} @@ -141,6 +124,32 @@ class SystemState: return endpoints +class BackendState: + def __init__(self, checkpoint: bytes = None): + self.backends: Dict[BackendTag, BackendInfo] = dict() + + if checkpoint is not None: + self.backends = pickle.loads(checkpoint) + + def checkpoint(self): + return pickle.dumps(self.backends) + + def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: + return { + tag: info.backend_config + for tag, info in self.backends.items() + } + + def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: + return self.backends.get(backend_tag) + + def add_backend(self, + backend_tag: BackendTag, + backend_info: BackendInfo, + goal_id: GoalId = 0) -> None: + self.backends[backend_tag] = backend_info + + @dataclass class ActorStateReconciler: controller_name: str = field(init=True) @@ -192,7 +201,7 @@ class ActorStateReconciler: for replica_dict in self.backend_replicas.values() ])) - async def _start_backend_replica(self, current_state: SystemState, + async def _start_backend_replica(self, backend_state: BackendState, backend_tag: BackendTag, replica_tag: ReplicaTag) -> ActorHandle: """Start a replica and return its actor handle. @@ -210,7 +219,7 @@ class ActorStateReconciler: except ValueError: logger.debug("Starting replica '{}' for backend '{}'.".format( replica_tag, backend_tag)) - backend_info = current_state.get_backend(backend_tag) + backend_info = backend_state.get_backend(backend_tag) replica_handle = ray.remote(backend_info.worker_class).options( name=replica_name, @@ -284,12 +293,12 @@ class ActorStateReconciler: self.backend_replicas_to_stop[backend_tag].append(replica_tag) async def _enqueue_pending_scale_changes_loop(self, - current_state: SystemState): + backend_state: BackendState): for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ items(): for replica_tag in replicas_to_create: replica_handle = await self._start_backend_replica( - current_state, backend_tag, replica_tag) + backend_state, backend_tag, replica_tag) ready_future = replica_handle.ready.remote().as_future() self.currently_starting_replicas[ready_future] = ( backend_tag, replica_tag, replica_handle) @@ -456,19 +465,19 @@ class ActorStateReconciler: replica_tag] = ray.get_actor(replica_name) async def _recover_from_checkpoint( - self, current_state: SystemState, controller: "ServeController" + self, backend_state: BackendState, controller: "ServeController" ) -> Dict[BackendTag, BasicAutoscalingPolicy]: self._recover_actor_handles() autoscaling_policies = dict() - for backend, info in current_state.backends.items(): + for backend, info in backend_state.backends.items(): metadata = info.backend_config.internal_metadata if metadata.autoscaling_config is not None: autoscaling_policies[backend] = BasicAutoscalingPolicy( backend, metadata.autoscaling_config) # Start/stop any pending backend replicas. - await self._enqueue_pending_scale_changes_loop(current_state) + await self._enqueue_pending_scale_changes_loop(backend_state) await self.backend_control_loop() return autoscaling_policies @@ -482,8 +491,8 @@ class FutureResult: @dataclass class Checkpoint: - goal_state: SystemState - current_state: SystemState + endpoint_state_checkpoint: bytes + backend_state_checkpoint: bytes reconciler: ActorStateReconciler # TODO(ilr) Rename reconciler to PendingState inflight_reqs: Dict[uuid4, FutureResult] @@ -523,13 +532,6 @@ class ServeController: detached: bool = False): # Used to read/write checkpoints. self.kv_store = RayInternalKVStore(namespace=controller_name) - # Current State - self.current_state = SystemState() - # Goal State - # TODO(ilr) This is currently *unused* until the refactor of the serve - # controller. - self.goal_state = SystemState() - # ActorStateReconciler self.actor_reconciler = ActorStateReconciler(controller_name, detached) # backend -> AutoscalingPolicy @@ -556,10 +558,17 @@ class ServeController: self.inflight_results: Dict[UUID, asyncio.Event] = dict() self._serializable_inflight_results: Dict[UUID, FutureResult] = dict() - checkpoint = self.kv_store.get(CHECKPOINT_KEY) - if checkpoint is None: + checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY) + if checkpoint_bytes is None: logger.debug("No checkpoint found") + self.backend_state = BackendState() + self.endpoint_state = EndpointState() else: + checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) + self.backend_state = BackendState( + checkpoint=checkpoint.backend_state_checkpoint) + self.endpoint_state = EndpointState( + checkpoint=checkpoint.endpoint_state_checkpoint) await self._recover_from_checkpoint(checkpoint) # NOTE(simon): Currently we do all-to-all broadcast. This means @@ -618,17 +627,17 @@ class ServeController: def notify_traffic_policies_changed(self): self.long_poll_host.notify_changed( LongPollKey.TRAFFIC_POLICIES, - self.current_state.traffic_policies, + self.endpoint_state.traffic_policies, ) def notify_backend_configs_changed(self): self.long_poll_host.notify_changed( LongPollKey.BACKEND_CONFIGS, - self.current_state.get_backend_configs()) + self.backend_state.get_backend_configs()) def notify_route_table_changed(self): self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE, - self.current_state.routes) + self.endpoint_state.routes) async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): """Proxy long pull client's listen request. @@ -652,19 +661,19 @@ class ServeController: start = time.time() checkpoint = pickle.dumps( - Checkpoint(self.goal_state, self.current_state, - self.actor_reconciler, + Checkpoint(self.endpoint_state.checkpoint(), + self.backend_state.checkpoint(), self.actor_reconciler, self._serializable_inflight_results)) self.kv_store.put(CHECKPOINT_KEY, checkpoint) - logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start)) + logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start)) if random.random( ) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached: logger.warning("Intentionally crashing after checkpoint") os._exit(0) - async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None: + async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None: """Recover the instance state from the provided checkpoint. This should be called in the constructor to ensure that the internal @@ -679,12 +688,9 @@ class ServeController: start = time.time() logger.info("Recovering from checkpoint") - restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) - self.current_state = restored_checkpoint.current_state + self.actor_reconciler = checkpoint.reconciler - self.actor_reconciler = restored_checkpoint.reconciler - - self._serializable_inflight_results = restored_checkpoint.inflight_reqs + self._serializable_inflight_results = checkpoint.inflight_reqs for uuid, fut_result in self._serializable_inflight_results.items(): self._create_event_with_result(fut_result.requested_goal, uuid) @@ -704,7 +710,7 @@ class ServeController: async def finish_recover_from_checkpoint(): assert self.write_lock.locked() self.autoscaling_policies = await self.actor_reconciler.\ - _recover_from_checkpoint(self.current_state, self) + _recover_from_checkpoint(self.backend_state, self) self.write_lock.release() logger.info( "Recovered from checkpoint in {:.3f}s".format(time.time() - @@ -714,7 +720,7 @@ class ServeController: asyncio.get_event_loop().create_task(finish_recover_from_checkpoint()) async def do_autoscale(self) -> None: - for backend, info in self.current_state.backends.items(): + for backend, info in self.backend_state.backends.items(): if backend not in self.autoscaling_policies: continue @@ -726,9 +732,6 @@ class ServeController: async def reconcile_current_and_goal_backends(self): pass - # backends_to_delete = set( - # self.current_state.backends.keys()).difference( - # self.goal_state.backends.keys()) async def run_control_loop(self) -> None: while True: @@ -750,15 +753,15 @@ class ServeController: def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: """Returns a dictionary of backend tag to backend config.""" - return self.current_state.get_backend_configs() + return self.backend_state.get_backend_configs() def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]: """Returns a dictionary of backend tag to backend config.""" - return self.current_state.get_endpoints() + return self.endpoint_state.get_endpoints() async def _set_traffic(self, endpoint_name: str, traffic_dict: Dict[str, float]) -> UUID: - if endpoint_name not in self.current_state.get_endpoints(): + if endpoint_name not in self.endpoint_state.get_endpoints(): raise ValueError("Attempted to assign traffic for an endpoint '{}'" " that is not registered.".format(endpoint_name)) @@ -766,13 +769,13 @@ class ServeController: dict), "Traffic policy must be a dictionary." for backend in traffic_dict: - if self.current_state.get_backend(backend) is None: + if self.backend_state.get_backend(backend) is None: raise ValueError( "Attempted to assign traffic to a backend '{}' that " "is not registered.".format(backend)) traffic_policy = TrafficPolicy(traffic_dict) - self.current_state.traffic_policies[endpoint_name] = traffic_policy + self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy return_uuid = self._create_event_with_result({ endpoint_name: traffic_policy @@ -795,20 +798,21 @@ class ServeController: proportion: float) -> UUID: """Shadow traffic from the endpoint to the backend.""" async with self.write_lock: - if endpoint_name not in self.current_state.get_endpoints(): + if endpoint_name not in self.endpoint_state.get_endpoints(): raise ValueError("Attempted to shadow traffic from an " "endpoint '{}' that is not registered." .format(endpoint_name)) - if self.current_state.get_backend(backend_tag) is None: + if self.backend_state.get_backend(backend_tag) is None: raise ValueError( "Attempted to shadow traffic to a backend '{}' that " "is not registered.".format(backend_tag)) - self.current_state.traffic_policies[endpoint_name].set_shadow( + self.endpoint_state.traffic_policies[endpoint_name].set_shadow( backend_tag, proportion) - traffic_policy = self.current_state.traffic_policies[endpoint_name] + traffic_policy = self.endpoint_state.traffic_policies[ + endpoint_name] return_uuid = self._create_event_with_result({ endpoint_name: traffic_policy @@ -839,10 +843,10 @@ class ServeController: # TODO(edoakes): move this to client side. err_prefix = "Cannot create endpoint." - if route in self.current_state.routes: + if route in self.endpoint_state.routes: # Ensures this method is idempotent - if self.current_state.routes[route] == (endpoint, methods): + if self.endpoint_state.routes[route] == (endpoint, methods): return else: @@ -850,7 +854,7 @@ class ServeController: "{} Route '{}' is already registered.".format( err_prefix, route)) - if endpoint in self.current_state.get_endpoints(): + if endpoint in self.endpoint_state.get_endpoints(): raise ValueError( "{} Endpoint '{}' is already registered.".format( err_prefix, endpoint)) @@ -859,7 +863,7 @@ class ServeController: "Registering route '{}' to endpoint '{}' with methods '{}'.". format(route, endpoint, methods)) - self.current_state.routes[route] = (endpoint, methods) + self.endpoint_state.routes[route] = (endpoint, methods) # NOTE(edoakes): checkpoint is written in self._set_traffic. return_uuid = await self._set_traffic(endpoint, traffic_dict) @@ -876,7 +880,7 @@ class ServeController: # This method must be idempotent. We should validate that the # specified endpoint exists on the client. for route, (route_endpoint, - _) in self.current_state.routes.items(): + _) in self.endpoint_state.routes.items(): if route_endpoint == endpoint: route_to_delete = route break @@ -885,11 +889,11 @@ class ServeController: return # Remove the routing entry. - del self.current_state.routes[route_to_delete] + del self.endpoint_state.routes[route_to_delete] # Remove the traffic policy entry if it exists. - if endpoint in self.current_state.traffic_policies: - del self.current_state.traffic_policies[endpoint] + if endpoint in self.endpoint_state.traffic_policies: + del self.endpoint_state.traffic_policies[endpoint] return_uuid = self._create_event_with_result({ route_to_delete: None, @@ -908,7 +912,7 @@ class ServeController: """Register a new backend under the specified tag.""" async with self.write_lock: # Ensures this method is idempotent. - backend_info = self.current_state.get_backend(backend_tag) + backend_info = self.backend_state.get_backend(backend_tag) if backend_info is not None: if (backend_info.backend_config == backend_config and backend_info.replica_config == replica_config): @@ -923,7 +927,7 @@ class ServeController: worker_class=backend_replica, backend_config=backend_config, replica_config=replica_config) - self.current_state.add_backend(backend_tag, backend_info) + self.backend_state.add_backend(backend_tag, backend_info) metadata = backend_config.internal_metadata if metadata.autoscaling_config is not None: self.autoscaling_policies[ @@ -933,10 +937,10 @@ class ServeController: try: # This call should be to run control loop self.actor_reconciler._scale_backend_replicas( - self.current_state.backends, backend_tag, + self.backend_state.backends, backend_tag, backend_config.num_replicas) except RayServeException as e: - del self.current_state.backends[backend_tag] + del self.backend_state.backends[backend_tag] raise e return_uuid = self._create_event_with_result({ @@ -947,7 +951,7 @@ class ServeController: # crash while making the change. self._checkpoint() await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.current_state) + self.backend_state) await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -961,11 +965,11 @@ class ServeController: async with self.write_lock: # This method must be idempotent. We should validate that the # specified backend exists on the client. - if self.current_state.get_backend(backend_tag) is None: + if self.backend_state.get_backend(backend_tag) is None: return # Check that the specified backend isn't used by any endpoints. - for endpoint, traffic_policy in self.current_state.\ + for endpoint, traffic_policy in self.endpoint_state.\ traffic_policies.items(): if (backend_tag in traffic_policy.traffic_dict or backend_tag in traffic_policy.shadow_dict): @@ -975,17 +979,15 @@ class ServeController: "again.".format(backend_tag, endpoint)) # Scale its replicas down to 0. This will also remove the backend - # from self.current_state.backends and + # from self.backend_state.backends and # self.actor_reconciler.backend_replicas. - self.goal_state.backends[backend_tag] = None - # This should be a call to the control loop self.actor_reconciler._scale_backend_replicas( - self.current_state.backends, backend_tag, 0) + self.backend_state.backends, backend_tag, 0) # Remove the backend's metadata. - del self.current_state.backends[backend_tag] + del self.backend_state.backends[backend_tag] if backend_tag in self.autoscaling_policies: del self.autoscaling_policies[backend_tag] @@ -998,7 +1000,7 @@ class ServeController: # after pushing the update. self._checkpoint() await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.current_state) + self.backend_state) await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -1008,24 +1010,24 @@ class ServeController: config_options: BackendConfig) -> UUID: """Set the config for the specified backend.""" async with self.write_lock: - assert (self.current_state.get_backend(backend_tag) + assert (self.backend_state.get_backend(backend_tag) ), "Backend {} is not registered.".format(backend_tag) assert isinstance(config_options, BackendConfig) - stored_backend_config = self.current_state.get_backend( + stored_backend_config = self.backend_state.get_backend( backend_tag).backend_config backend_config = stored_backend_config.copy( update=config_options.dict(exclude_unset=True)) backend_config._validate_complete() - self.current_state.get_backend( + self.backend_state.get_backend( backend_tag).backend_config = backend_config - backend_info = self.current_state.get_backend(backend_tag) + backend_info = self.backend_state.get_backend(backend_tag) # Scale the replicas with the new configuration. # This should be to run the control loop self.actor_reconciler._scale_backend_replicas( - self.current_state.backends, backend_tag, + self.backend_state.backends, backend_tag, backend_config.num_replicas) return_uuid = self._create_event_with_result({ @@ -1040,7 +1042,7 @@ class ServeController: # (particularly for setting max_batch_size). await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.current_state) + self.backend_state) await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -1049,9 +1051,9 @@ class ServeController: def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig: """Get the current config for the specified backend.""" - assert (self.current_state.get_backend(backend_tag) + assert (self.backend_state.get_backend(backend_tag) ), "Backend {} is not registered.".format(backend_tag) - return self.current_state.get_backend(backend_tag).backend_config + return self.backend_state.get_backend(backend_tag).backend_config def get_http_config(self): """Return the HTTP proxy configuration."""