[serve] Clean up EndpointState interface, move checkpointing inside of EndpointState (#13215)

This commit is contained in:
Edward Oakes
2021-01-08 22:36:19 -06:00
committed by GitHub
parent c5ae30d1d4
commit d434ba6518
5 changed files with 184 additions and 172 deletions
+3 -3
View File
@@ -262,7 +262,7 @@ class Client:
Does not delete any associated backends.
"""
self._get_result(self._controller.delete_endpoint.remote(endpoint))
ray.get(self._controller.delete_endpoint.remote(endpoint))
@_ensure_connected
def list_endpoints(self) -> Dict[str, Dict[str, Any]]:
@@ -447,7 +447,7 @@ class Client:
traffic_policy_dictionary (dict): a dictionary maps backend names
to their traffic weights. The weights must sum to 1.
"""
self._get_result(
ray.get(
self._controller.set_traffic.remote(endpoint_name,
traffic_policy_dictionary))
@@ -473,7 +473,7 @@ class Client:
(float, int)) or not 0 <= proportion <= 1:
raise TypeError("proportion must be a float from 0 to 1.")
self._get_result(
ray.get(
self._controller.shadow_traffic.remote(endpoint_name, backend_tag,
proportion))
+6
View File
@@ -19,6 +19,12 @@ _RESOURCE_CHECK_ENABLED = True
class BackendState:
"""Manages all state for backends in the system.
This class is *not* thread safe, so any state-modifying methods should be
called with a lock held.
"""
def __init__(self,
controller_name: str,
detached: bool,
+61 -159
View File
@@ -4,17 +4,25 @@ import random
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Optional
from uuid import UUID, uuid4
from typing import Dict, Any, List, Optional
from uuid import uuid4, UUID
import ray
import ray.cloudpickle as pickle
from ray.actor import ActorHandle
from ray.serve.backend_state import BackendState
from ray.serve.backend_worker import create_backend_replica
from ray.serve.common import (BackendInfo, BackendTag, EndpointTag, GoalId,
NodeId, ReplicaTag, TrafficPolicy)
from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig)
from ray.serve.constants import LongPollKey
from ray.serve.common import (
BackendInfo,
BackendTag,
EndpointTag,
GoalId,
NodeId,
ReplicaTag,
TrafficPolicy,
)
from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig
from ray.serve.endpoint_state import EndpointState
from ray.serve.exceptions import RayServeException
from ray.serve.http_state import HTTPState
@@ -22,8 +30,6 @@ from ray.serve.kv_store import RayInternalKVStore
from ray.serve.long_poll import LongPollHost
from ray.serve.utils import logger
import ray
# Used for testing purposes only. If this is set, the controller will crash
# after writing each checkpoint with the specified probability.
_CRASH_AFTER_CHECKPOINT_PROBABILITY = 0
@@ -32,8 +38,6 @@ CHECKPOINT_KEY = "serve-controller-checkpoint"
# How often to call the control loop on the controller.
CONTROL_LOOP_PERIOD_S = 1.0
REPLICA_STARTUP_TIME_WARNING_S = 5
@dataclass
class FutureResult:
@@ -43,7 +47,6 @@ class FutureResult:
@dataclass
class Checkpoint:
endpoint_state_checkpoint: bytes
backend_state_checkpoint: bytes
# TODO(ilr) Rename reconciler to PendingState
inflight_reqs: Dict[uuid4, FutureResult]
@@ -94,28 +97,6 @@ class ServeController:
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
# HTTP state doesn't currently require a checkpoint.
self.http_state = HTTPState(controller_name, detached, http_config)
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
self.backend_state = BackendState(controller_name, detached)
self.endpoint_state = EndpointState()
else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
controller_name,
detached,
checkpoint=checkpoint.backend_state_checkpoint)
self.endpoint_state = EndpointState(
checkpoint=checkpoint.endpoint_state_checkpoint)
self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items(
):
self._create_event_with_result(fut_result.requested_goal, uuid)
# NOTE(simon): Currently we do all-to-all broadcast. This means
# any listeners will receive notification for all changes. This
# can be problem at scale, e.g. updating a single backend config
@@ -123,10 +104,27 @@ class ServeController:
# optimize the logic to support subscription by key.
self.long_poll_host = LongPollHost()
self.http_state = HTTPState(controller_name, detached, http_config)
self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
self.backend_state = BackendState(controller_name, detached)
else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
controller_name,
detached,
checkpoint=checkpoint.backend_state_checkpoint)
self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items(
):
self._create_event_with_result(fut_result.requested_goal, uuid)
self.notify_backend_configs_changed()
self.notify_replica_handles_changed()
self.notify_traffic_policies_changed()
self.notify_route_table_changed()
asyncio.get_event_loop().create_task(self.run_control_loop())
@@ -168,21 +166,11 @@ class ServeController:
self.backend_state.backend_replicas.items()
})
def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES,
self.endpoint_state.traffic_policies,
)
def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.BACKEND_CONFIGS,
self.backend_state.get_backend_configs())
def notify_route_table_changed(self):
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
self.endpoint_state.routes)
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy long pull client's listen request.
@@ -205,8 +193,7 @@ class ServeController:
start = time.time()
checkpoint = pickle.dumps(
Checkpoint(self.endpoint_state.checkpoint(),
self.backend_state.checkpoint(),
Checkpoint(self.backend_state.checkpoint(),
self._serializable_inflight_results))
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
@@ -254,155 +241,74 @@ class ServeController:
"""Returns a dictionary of backend tag to backend config."""
return self.endpoint_state.get_endpoints()
async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name))
assert isinstance(traffic_dict,
dict), "Traffic policy must be a dictionary."
def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
for backend in traffic_dict:
if self.backend_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))
traffic_policy = TrafficPolicy(traffic_dict)
self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy
self.endpoint_state.set_traffic_policy(endpoint_name,
TrafficPolicy(traffic_dict))
return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
self.set_goal_id(return_uuid)
return return_uuid
def _validate_traffic_dict(self, traffic_dict: Dict[str, float]):
for backend in traffic_dict:
if self.backend_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))
async def set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
traffic_dict: Dict[str, float]) -> None:
"""Sets the traffic policy for the specified endpoint."""
async with self.write_lock:
return_uuid = await self._set_traffic(endpoint_name, traffic_dict)
return return_uuid
self._validate_traffic_dict(traffic_dict)
self._set_traffic(endpoint_name, traffic_dict)
async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
proportion: float) -> UUID:
"""Shadow traffic from the endpoint to the backend."""
async with self.write_lock:
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered."
.format(endpoint_name))
if self.backend_state.get_backend(backend_tag) is None:
raise ValueError(
"Attempted to shadow traffic to a backend '{}' that "
"is not registered.".format(backend_tag))
self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
backend_tag, proportion)
logger.info(
"Shadowing '{}' of traffic to endpoint '{}' to backend '{}'.".
format(proportion, endpoint_name, backend_tag))
traffic_policy = self.endpoint_state.traffic_policies[
endpoint_name]
return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
self.set_goal_id(return_uuid)
return return_uuid
self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
proportion)
# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
async def create_endpoint(self, endpoint: str,
traffic_dict: Dict[str, float], route,
methods) -> UUID:
methods: List[str]) -> UUID:
"""Create a new endpoint with the specified route and methods.
If the route is None, this is a "headless" endpoint that will not
be exposed over HTTP and can only be accessed via a handle.
"""
async with self.write_lock:
# If this is a headless endpoint with no route, key the endpoint
# based on its name.
# TODO(edoakes): we should probably just store routes and endpoints
# separately.
if route is None:
route = endpoint
# TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint."
if route in self.endpoint_state.routes:
# Ensures this method is idempotent
if self.endpoint_state.routes[route] == (endpoint, methods):
return
else:
raise ValueError(
"{} Route '{}' is already registered.".format(
err_prefix, route))
if endpoint in self.endpoint_state.get_endpoints():
raise ValueError(
"{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
self._validate_traffic_dict(traffic_dict)
logger.info(
"Registering route '{}' to endpoint '{}' with methods '{}'.".
format(route, endpoint, methods))
self.endpoint_state.routes[route] = (endpoint, methods)
self.endpoint_state.create_endpoint(endpoint, route, methods,
TrafficPolicy(traffic_dict))
# NOTE(edoakes): checkpoint is written in self._set_traffic.
return_uuid = await self._set_traffic(endpoint, traffic_dict)
self.notify_route_table_changed()
return return_uuid
async def delete_endpoint(self, endpoint: str) -> UUID:
async def delete_endpoint(self, endpoint: str) -> None:
"""Delete the specified endpoint.
Does not modify any corresponding backends.
"""
logger.info("Deleting endpoint '{}'".format(endpoint))
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified endpoint exists on the client.
for route, (route_endpoint,
_) in self.endpoint_state.routes.items():
if route_endpoint == endpoint:
route_to_delete = route
break
else:
logger.info("Endpoint '{}' doesn't exist".format(endpoint))
return
# Remove the routing entry.
del self.endpoint_state.routes[route_to_delete]
# Remove the traffic policy entry if it exists.
if endpoint in self.endpoint_state.traffic_policies:
del self.endpoint_state.traffic_policies[endpoint]
return_uuid = self._create_event_with_result({
route_to_delete: None,
endpoint: None
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# updates to the proxies to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
self.notify_route_table_changed()
self.set_goal_id(return_uuid)
return return_uuid
self.endpoint_state.delete_endpoint(endpoint)
async def set_backend_goal(self, backend_tag: BackendTag,
backend_info: BackendInfo,
@@ -466,19 +372,15 @@ class ServeController:
return
# Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_policy in self.endpoint_state.\
traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict):
for endpoint, info in self.endpoint_state.get_endpoints().items():
if (backend_tag in info["traffic"]
or backend_tag in info["shadows"]):
raise ValueError("Backend '{}' is used by endpoint '{}' "
"and cannot be deleted. Please remove "
"the backend from all endpoints and try "
"again.".format(backend_tag, endpoint))
# Scale its replicas down to 0. This will also remove the backend
# from self.backend_state.backends and
# This should be a call to the control loop
# Scale its replicas down to 0.
self.backend_state.scale_backend_replicas(backend_tag, 0,
force_kill)
+108 -10
View File
@@ -1,25 +1,107 @@
from typing import Dict, Any, Tuple
from typing import Dict, Any, List, Optional, Tuple
import ray.cloudpickle as pickle
from ray.serve.common import BackendTag, EndpointTag, TrafficPolicy
from ray.serve.constants import LongPollKey
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.long_poll import LongPollHost
CHECKPOINT_KEY = "serve-endpoint-state-checkpoint"
class EndpointState:
def __init__(self, checkpoint: bytes = None):
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
"""Manages all state for endpoints in the system.
This class is *not* thread safe, so any state-modifying methods should be
called with a lock held.
"""
def __init__(self, kv_store: RayInternalKVStore,
long_poll_host: LongPollHost):
self._kv_store = kv_store
self._long_poll_host = long_poll_host
self._routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
self._traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
checkpoint = self._kv_store.get(CHECKPOINT_KEY)
if checkpoint is not None:
self.routes, self.traffic_policies = pickle.loads(checkpoint)
self._routes, self._traffic_policies = pickle.loads(checkpoint)
def checkpoint(self):
return pickle.dumps((self.routes, self.traffic_policies))
self._notify_route_table_changed()
self._notify_traffic_policies_changed()
def _checkpoint(self):
self._kv_store.put(
CHECKPOINT_KEY, pickle.dumps((self._routes,
self._traffic_policies)))
def _notify_route_table_changed(self):
self._long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
self._routes)
def _notify_traffic_policies_changed(self):
self._long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES,
self._traffic_policies,
)
def create_endpoint(self, endpoint: EndpointTag, route: Optional[str],
methods: List[str], traffic_policy: TrafficPolicy):
# If this is a headless endpoint with no route, key the endpoint
# based on its name.
# TODO(edoakes): we should probably just store routes and endpoints
# separately.
if route is None:
route = endpoint
err_prefix = "Cannot create endpoint."
if route in self._routes:
# Ensures this method is idempotent
if self._routes[route] == (endpoint, methods):
return
else:
raise ValueError("{} Route '{}' is already registered.".format(
err_prefix, route))
if endpoint in self._traffic_policies:
raise ValueError("{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
self._routes[route] = (endpoint, methods)
self._traffic_policies[endpoint] = traffic_policy
self._checkpoint()
self._notify_route_table_changed()
self._notify_traffic_policies_changed()
def set_traffic_policy(self, endpoint: EndpointTag,
traffic_policy: TrafficPolicy):
if endpoint not in self._traffic_policies:
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint))
self._traffic_policies[endpoint] = traffic_policy
self._checkpoint()
self._notify_traffic_policies_changed()
def shadow_traffic(self, endpoint: EndpointTag, backend: BackendTag,
proportion: float):
if endpoint not in self._traffic_policies:
raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered."
.format(endpoint))
self._traffic_policies[endpoint].set_shadow(backend, proportion)
self._checkpoint()
self._notify_traffic_policies_changed()
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
endpoints = {}
for route, (endpoint, methods) in self.routes.items():
if endpoint in self.traffic_policies:
traffic_policy = self.traffic_policies[endpoint]
for route, (endpoint, methods) in self._routes.items():
if endpoint in self._traffic_policies:
traffic_policy = self._traffic_policies[endpoint]
traffic_dict = traffic_policy.traffic_dict
shadow_dict = traffic_policy.shadow_dict
else:
@@ -33,3 +115,19 @@ class EndpointState:
"shadows": shadow_dict,
}
return endpoints
def delete_endpoint(self, endpoint: EndpointTag) -> None:
# This method must be idempotent. We should validate that the
# specified endpoint exists on the client.
for route, (route_endpoint, _) in self._routes.items():
if route_endpoint == endpoint:
route_to_delete = route
break
else:
return
del self._routes[route_to_delete]
del self._traffic_policies[endpoint]
self._checkpoint()
self._notify_route_table_changed()
+6
View File
@@ -11,6 +11,12 @@ from ray.serve.common import NodeId
class HTTPState:
"""Manages all state for HTTP proxies in the system.
This class is *not* thread safe, so any state-modifying methods should be
called with a lock held.
"""
def __init__(self, controller_name: str, detached: bool,
config: HTTPOptions):
self._controller_name = controller_name