diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 947af0ca5..e4a8fa7ce 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -262,7 +262,7 @@ class Client: Does not delete any associated backends. """ - self._get_result(self._controller.delete_endpoint.remote(endpoint)) + ray.get(self._controller.delete_endpoint.remote(endpoint)) @_ensure_connected def list_endpoints(self) -> Dict[str, Dict[str, Any]]: @@ -447,7 +447,7 @@ class Client: traffic_policy_dictionary (dict): a dictionary maps backend names to their traffic weights. The weights must sum to 1. """ - self._get_result( + ray.get( self._controller.set_traffic.remote(endpoint_name, traffic_policy_dictionary)) @@ -473,7 +473,7 @@ class Client: (float, int)) or not 0 <= proportion <= 1: raise TypeError("proportion must be a float from 0 to 1.") - self._get_result( + ray.get( self._controller.shadow_traffic.remote(endpoint_name, backend_tag, proportion)) diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 079c65fec..20b70a9e0 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -19,6 +19,12 @@ _RESOURCE_CHECK_ENABLED = True class BackendState: + """Manages all state for backends in the system. + + This class is *not* thread safe, so any state-modifying methods should be + called with a lock held. + """ + def __init__(self, controller_name: str, detached: bool, diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index c9745545e..93c88681c 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -4,17 +4,25 @@ import random import time from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, Optional -from uuid import UUID, uuid4 +from typing import Dict, Any, List, Optional +from uuid import uuid4, UUID +import ray import ray.cloudpickle as pickle from ray.actor import ActorHandle from ray.serve.backend_state import BackendState from ray.serve.backend_worker import create_backend_replica -from ray.serve.common import (BackendInfo, BackendTag, EndpointTag, GoalId, - NodeId, ReplicaTag, TrafficPolicy) -from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig) from ray.serve.constants import LongPollKey +from ray.serve.common import ( + BackendInfo, + BackendTag, + EndpointTag, + GoalId, + NodeId, + ReplicaTag, + TrafficPolicy, +) +from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig from ray.serve.endpoint_state import EndpointState from ray.serve.exceptions import RayServeException from ray.serve.http_state import HTTPState @@ -22,8 +30,6 @@ from ray.serve.kv_store import RayInternalKVStore from ray.serve.long_poll import LongPollHost from ray.serve.utils import logger -import ray - # Used for testing purposes only. If this is set, the controller will crash # after writing each checkpoint with the specified probability. _CRASH_AFTER_CHECKPOINT_PROBABILITY = 0 @@ -32,8 +38,6 @@ CHECKPOINT_KEY = "serve-controller-checkpoint" # How often to call the control loop on the controller. CONTROL_LOOP_PERIOD_S = 1.0 -REPLICA_STARTUP_TIME_WARNING_S = 5 - @dataclass class FutureResult: @@ -43,7 +47,6 @@ class FutureResult: @dataclass class Checkpoint: - endpoint_state_checkpoint: bytes backend_state_checkpoint: bytes # TODO(ilr) Rename reconciler to PendingState inflight_reqs: Dict[uuid4, FutureResult] @@ -94,28 +97,6 @@ class ServeController: self.inflight_results: Dict[UUID, asyncio.Event] = dict() self._serializable_inflight_results: Dict[UUID, FutureResult] = dict() - # HTTP state doesn't currently require a checkpoint. - self.http_state = HTTPState(controller_name, detached, http_config) - - checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY) - if checkpoint_bytes is None: - logger.debug("No checkpoint found") - self.backend_state = BackendState(controller_name, detached) - self.endpoint_state = EndpointState() - else: - checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) - self.backend_state = BackendState( - controller_name, - detached, - checkpoint=checkpoint.backend_state_checkpoint) - self.endpoint_state = EndpointState( - checkpoint=checkpoint.endpoint_state_checkpoint) - - self._serializable_inflight_results = checkpoint.inflight_reqs - for uuid, fut_result in self._serializable_inflight_results.items( - ): - self._create_event_with_result(fut_result.requested_goal, uuid) - # NOTE(simon): Currently we do all-to-all broadcast. This means # any listeners will receive notification for all changes. This # can be problem at scale, e.g. updating a single backend config @@ -123,10 +104,27 @@ class ServeController: # optimize the logic to support subscription by key. self.long_poll_host = LongPollHost() + self.http_state = HTTPState(controller_name, detached, http_config) + self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host) + + checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY) + if checkpoint_bytes is None: + logger.debug("No checkpoint found") + self.backend_state = BackendState(controller_name, detached) + else: + checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) + self.backend_state = BackendState( + controller_name, + detached, + checkpoint=checkpoint.backend_state_checkpoint) + + self._serializable_inflight_results = checkpoint.inflight_reqs + for uuid, fut_result in self._serializable_inflight_results.items( + ): + self._create_event_with_result(fut_result.requested_goal, uuid) + self.notify_backend_configs_changed() self.notify_replica_handles_changed() - self.notify_traffic_policies_changed() - self.notify_route_table_changed() asyncio.get_event_loop().create_task(self.run_control_loop()) @@ -168,21 +166,11 @@ class ServeController: self.backend_state.backend_replicas.items() }) - def notify_traffic_policies_changed(self): - self.long_poll_host.notify_changed( - LongPollKey.TRAFFIC_POLICIES, - self.endpoint_state.traffic_policies, - ) - def notify_backend_configs_changed(self): self.long_poll_host.notify_changed( LongPollKey.BACKEND_CONFIGS, self.backend_state.get_backend_configs()) - def notify_route_table_changed(self): - self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE, - self.endpoint_state.routes) - async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): """Proxy long pull client's listen request. @@ -205,8 +193,7 @@ class ServeController: start = time.time() checkpoint = pickle.dumps( - Checkpoint(self.endpoint_state.checkpoint(), - self.backend_state.checkpoint(), + Checkpoint(self.backend_state.checkpoint(), self._serializable_inflight_results)) self.kv_store.put(CHECKPOINT_KEY, checkpoint) @@ -254,155 +241,74 @@ class ServeController: """Returns a dictionary of backend tag to backend config.""" return self.endpoint_state.get_endpoints() - async def _set_traffic(self, endpoint_name: str, - traffic_dict: Dict[str, float]) -> UUID: - if endpoint_name not in self.endpoint_state.get_endpoints(): - raise ValueError("Attempted to assign traffic for an endpoint '{}'" - " that is not registered.".format(endpoint_name)) - - assert isinstance(traffic_dict, - dict), "Traffic policy must be a dictionary." - + def _set_traffic(self, endpoint_name: str, + traffic_dict: Dict[str, float]) -> UUID: for backend in traffic_dict: if self.backend_state.get_backend(backend) is None: raise ValueError( "Attempted to assign traffic to a backend '{}' that " "is not registered.".format(backend)) - traffic_policy = TrafficPolicy(traffic_dict) - self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy + self.endpoint_state.set_traffic_policy(endpoint_name, + TrafficPolicy(traffic_dict)) - return_uuid = self._create_event_with_result({ - endpoint_name: traffic_policy - }) - # NOTE(edoakes): we must write a checkpoint before pushing the - # update to avoid inconsistent state if we crash after pushing the - # update. - self._checkpoint() - self.notify_traffic_policies_changed() - self.set_goal_id(return_uuid) - return return_uuid + def _validate_traffic_dict(self, traffic_dict: Dict[str, float]): + for backend in traffic_dict: + if self.backend_state.get_backend(backend) is None: + raise ValueError( + "Attempted to assign traffic to a backend '{}' that " + "is not registered.".format(backend)) async def set_traffic(self, endpoint_name: str, - traffic_dict: Dict[str, float]) -> UUID: + traffic_dict: Dict[str, float]) -> None: """Sets the traffic policy for the specified endpoint.""" async with self.write_lock: - return_uuid = await self._set_traffic(endpoint_name, traffic_dict) - return return_uuid + self._validate_traffic_dict(traffic_dict) + self._set_traffic(endpoint_name, traffic_dict) async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag, proportion: float) -> UUID: """Shadow traffic from the endpoint to the backend.""" async with self.write_lock: - if endpoint_name not in self.endpoint_state.get_endpoints(): - raise ValueError("Attempted to shadow traffic from an " - "endpoint '{}' that is not registered." - .format(endpoint_name)) - if self.backend_state.get_backend(backend_tag) is None: raise ValueError( "Attempted to shadow traffic to a backend '{}' that " "is not registered.".format(backend_tag)) - self.endpoint_state.traffic_policies[endpoint_name].set_shadow( - backend_tag, proportion) + logger.info( + "Shadowing '{}' of traffic to endpoint '{}' to backend '{}'.". + format(proportion, endpoint_name, backend_tag)) - traffic_policy = self.endpoint_state.traffic_policies[ - endpoint_name] - - return_uuid = self._create_event_with_result({ - endpoint_name: traffic_policy - }) - # NOTE(edoakes): we must write a checkpoint before pushing the - # update to avoid inconsistent state if we crash after pushing the - # update. - self._checkpoint() - self.notify_traffic_policies_changed() - self.set_goal_id(return_uuid) - return return_uuid + self.endpoint_state.shadow_traffic(endpoint_name, backend_tag, + proportion) # TODO(architkulkarni): add Optional for route after cloudpickle upgrade async def create_endpoint(self, endpoint: str, traffic_dict: Dict[str, float], route, - methods) -> UUID: + methods: List[str]) -> UUID: """Create a new endpoint with the specified route and methods. If the route is None, this is a "headless" endpoint that will not be exposed over HTTP and can only be accessed via a handle. """ async with self.write_lock: - # If this is a headless endpoint with no route, key the endpoint - # based on its name. - # TODO(edoakes): we should probably just store routes and endpoints - # separately. - if route is None: - route = endpoint - - # TODO(edoakes): move this to client side. - err_prefix = "Cannot create endpoint." - if route in self.endpoint_state.routes: - - # Ensures this method is idempotent - if self.endpoint_state.routes[route] == (endpoint, methods): - return - - else: - raise ValueError( - "{} Route '{}' is already registered.".format( - err_prefix, route)) - - if endpoint in self.endpoint_state.get_endpoints(): - raise ValueError( - "{} Endpoint '{}' is already registered.".format( - err_prefix, endpoint)) + self._validate_traffic_dict(traffic_dict) logger.info( "Registering route '{}' to endpoint '{}' with methods '{}'.". format(route, endpoint, methods)) - self.endpoint_state.routes[route] = (endpoint, methods) + self.endpoint_state.create_endpoint(endpoint, route, methods, + TrafficPolicy(traffic_dict)) - # NOTE(edoakes): checkpoint is written in self._set_traffic. - return_uuid = await self._set_traffic(endpoint, traffic_dict) - self.notify_route_table_changed() - return return_uuid - - async def delete_endpoint(self, endpoint: str) -> UUID: + async def delete_endpoint(self, endpoint: str) -> None: """Delete the specified endpoint. Does not modify any corresponding backends. """ logger.info("Deleting endpoint '{}'".format(endpoint)) async with self.write_lock: - # This method must be idempotent. We should validate that the - # specified endpoint exists on the client. - for route, (route_endpoint, - _) in self.endpoint_state.routes.items(): - if route_endpoint == endpoint: - route_to_delete = route - break - else: - logger.info("Endpoint '{}' doesn't exist".format(endpoint)) - return - - # Remove the routing entry. - del self.endpoint_state.routes[route_to_delete] - - # Remove the traffic policy entry if it exists. - if endpoint in self.endpoint_state.traffic_policies: - del self.endpoint_state.traffic_policies[endpoint] - - return_uuid = self._create_event_with_result({ - route_to_delete: None, - endpoint: None - }) - # NOTE(edoakes): we must write a checkpoint before pushing the - # updates to the proxies to avoid inconsistent state if we crash - # after pushing the update. - self._checkpoint() - self.notify_route_table_changed() - self.set_goal_id(return_uuid) - return return_uuid + self.endpoint_state.delete_endpoint(endpoint) async def set_backend_goal(self, backend_tag: BackendTag, backend_info: BackendInfo, @@ -466,19 +372,15 @@ class ServeController: return # Check that the specified backend isn't used by any endpoints. - for endpoint, traffic_policy in self.endpoint_state.\ - traffic_policies.items(): - if (backend_tag in traffic_policy.traffic_dict - or backend_tag in traffic_policy.shadow_dict): + for endpoint, info in self.endpoint_state.get_endpoints().items(): + if (backend_tag in info["traffic"] + or backend_tag in info["shadows"]): raise ValueError("Backend '{}' is used by endpoint '{}' " "and cannot be deleted. Please remove " "the backend from all endpoints and try " "again.".format(backend_tag, endpoint)) - # Scale its replicas down to 0. This will also remove the backend - # from self.backend_state.backends and - - # This should be a call to the control loop + # Scale its replicas down to 0. self.backend_state.scale_backend_replicas(backend_tag, 0, force_kill) diff --git a/python/ray/serve/endpoint_state.py b/python/ray/serve/endpoint_state.py index 4fc880a95..bdbfe2c39 100644 --- a/python/ray/serve/endpoint_state.py +++ b/python/ray/serve/endpoint_state.py @@ -1,25 +1,107 @@ -from typing import Dict, Any, Tuple +from typing import Dict, Any, List, Optional, Tuple import ray.cloudpickle as pickle from ray.serve.common import BackendTag, EndpointTag, TrafficPolicy +from ray.serve.constants import LongPollKey +from ray.serve.kv_store import RayInternalKVStore +from ray.serve.long_poll import LongPollHost + +CHECKPOINT_KEY = "serve-endpoint-state-checkpoint" class EndpointState: - def __init__(self, checkpoint: bytes = None): - self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() - self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() + """Manages all state for endpoints in the system. + This class is *not* thread safe, so any state-modifying methods should be + called with a lock held. + """ + + def __init__(self, kv_store: RayInternalKVStore, + long_poll_host: LongPollHost): + self._kv_store = kv_store + self._long_poll_host = long_poll_host + self._routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() + self._traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() + + checkpoint = self._kv_store.get(CHECKPOINT_KEY) if checkpoint is not None: - self.routes, self.traffic_policies = pickle.loads(checkpoint) + self._routes, self._traffic_policies = pickle.loads(checkpoint) - def checkpoint(self): - return pickle.dumps((self.routes, self.traffic_policies)) + self._notify_route_table_changed() + self._notify_traffic_policies_changed() + + def _checkpoint(self): + self._kv_store.put( + CHECKPOINT_KEY, pickle.dumps((self._routes, + self._traffic_policies))) + + def _notify_route_table_changed(self): + self._long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE, + self._routes) + + def _notify_traffic_policies_changed(self): + self._long_poll_host.notify_changed( + LongPollKey.TRAFFIC_POLICIES, + self._traffic_policies, + ) + + def create_endpoint(self, endpoint: EndpointTag, route: Optional[str], + methods: List[str], traffic_policy: TrafficPolicy): + # If this is a headless endpoint with no route, key the endpoint + # based on its name. + # TODO(edoakes): we should probably just store routes and endpoints + # separately. + if route is None: + route = endpoint + + err_prefix = "Cannot create endpoint." + if route in self._routes: + # Ensures this method is idempotent + if self._routes[route] == (endpoint, methods): + return + else: + raise ValueError("{} Route '{}' is already registered.".format( + err_prefix, route)) + + if endpoint in self._traffic_policies: + raise ValueError("{} Endpoint '{}' is already registered.".format( + err_prefix, endpoint)) + + self._routes[route] = (endpoint, methods) + self._traffic_policies[endpoint] = traffic_policy + + self._checkpoint() + self._notify_route_table_changed() + self._notify_traffic_policies_changed() + + def set_traffic_policy(self, endpoint: EndpointTag, + traffic_policy: TrafficPolicy): + if endpoint not in self._traffic_policies: + raise ValueError("Attempted to assign traffic for an endpoint '{}'" + " that is not registered.".format(endpoint)) + + self._traffic_policies[endpoint] = traffic_policy + + self._checkpoint() + self._notify_traffic_policies_changed() + + def shadow_traffic(self, endpoint: EndpointTag, backend: BackendTag, + proportion: float): + if endpoint not in self._traffic_policies: + raise ValueError("Attempted to shadow traffic from an " + "endpoint '{}' that is not registered." + .format(endpoint)) + + self._traffic_policies[endpoint].set_shadow(backend, proportion) + + self._checkpoint() + self._notify_traffic_policies_changed() def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: endpoints = {} - for route, (endpoint, methods) in self.routes.items(): - if endpoint in self.traffic_policies: - traffic_policy = self.traffic_policies[endpoint] + for route, (endpoint, methods) in self._routes.items(): + if endpoint in self._traffic_policies: + traffic_policy = self._traffic_policies[endpoint] traffic_dict = traffic_policy.traffic_dict shadow_dict = traffic_policy.shadow_dict else: @@ -33,3 +115,19 @@ class EndpointState: "shadows": shadow_dict, } return endpoints + + def delete_endpoint(self, endpoint: EndpointTag) -> None: + # This method must be idempotent. We should validate that the + # specified endpoint exists on the client. + for route, (route_endpoint, _) in self._routes.items(): + if route_endpoint == endpoint: + route_to_delete = route + break + else: + return + + del self._routes[route_to_delete] + del self._traffic_policies[endpoint] + + self._checkpoint() + self._notify_route_table_changed() diff --git a/python/ray/serve/http_state.py b/python/ray/serve/http_state.py index 7e2b0cf9c..ecf4c6783 100644 --- a/python/ray/serve/http_state.py +++ b/python/ray/serve/http_state.py @@ -11,6 +11,12 @@ from ray.serve.common import NodeId class HTTPState: + """Manages all state for HTTP proxies in the system. + + This class is *not* thread safe, so any state-modifying methods should be + called with a lock held. + """ + def __init__(self, controller_name: str, detached: bool, config: HTTPOptions): self._controller_name = controller_name