mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 09:55:49 +08:00
[serve] Create FutureResults from ControllerAPI (#12577)
This commit is contained in:
@@ -17,6 +17,14 @@ py_test(
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_controller",
|
||||
size = "small",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_backend_worker",
|
||||
size = "small",
|
||||
|
||||
+15
-7
@@ -1,6 +1,7 @@
|
||||
import atexit
|
||||
from functools import wraps
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
@@ -83,6 +84,13 @@ class Client:
|
||||
ray.kill(self._controller, no_restart=True)
|
||||
self._shutdown = True
|
||||
|
||||
@_ensure_connected
|
||||
def _get_result(self, result_object_id: ray.ObjectRef) -> bool:
|
||||
result_id: UUID = ray.get(result_object_id)
|
||||
result = ray.get(self._controller.wait_for_event.remote(result_id))
|
||||
logger.debug(f"Getting result_id ({result_id}) with result: {result}")
|
||||
return result
|
||||
|
||||
@_ensure_connected
|
||||
def create_endpoint(self,
|
||||
endpoint_name: str,
|
||||
@@ -137,7 +145,7 @@ class Client:
|
||||
"an element of type {}".format(type(method)))
|
||||
upper_methods.append(method.upper())
|
||||
|
||||
ray.get(
|
||||
self._get_result(
|
||||
self._controller.create_endpoint.remote(
|
||||
endpoint_name, {backend: 1.0}, route, upper_methods))
|
||||
|
||||
@@ -149,7 +157,7 @@ class Client:
|
||||
"""
|
||||
if endpoint in self._handle_cache:
|
||||
del self._handle_cache[endpoint]
|
||||
ray.get(self._controller.delete_endpoint.remote(endpoint))
|
||||
self._get_result(self._controller.delete_endpoint.remote(endpoint))
|
||||
|
||||
@_ensure_connected
|
||||
def list_endpoints(self) -> Dict[str, Dict[str, Any]]:
|
||||
@@ -193,7 +201,7 @@ class Client:
|
||||
"config_options must be a BackendConfig or dictionary.")
|
||||
if isinstance(config_options, dict):
|
||||
config_options = BackendConfig.parse_obj(config_options)
|
||||
ray.get(
|
||||
self._get_result(
|
||||
self._controller.update_backend_config.remote(
|
||||
backend_tag, config_options))
|
||||
|
||||
@@ -290,7 +298,7 @@ class Client:
|
||||
raise TypeError("config must be a BackendConfig or a dictionary.")
|
||||
|
||||
backend_config._validate_complete()
|
||||
ray.get(
|
||||
self._get_result(
|
||||
self._controller.create_backend.remote(backend_tag, backend_config,
|
||||
replica_config))
|
||||
|
||||
@@ -308,7 +316,7 @@ class Client:
|
||||
|
||||
The backend must not currently be used by any endpoints.
|
||||
"""
|
||||
ray.get(self._controller.delete_backend.remote(backend_tag))
|
||||
self._get_result(self._controller.delete_backend.remote(backend_tag))
|
||||
|
||||
@_ensure_connected
|
||||
def set_traffic(self, endpoint_name: str,
|
||||
@@ -327,7 +335,7 @@ class Client:
|
||||
traffic_policy_dictionary (dict): a dictionary maps backend names
|
||||
to their traffic weights. The weights must sum to 1.
|
||||
"""
|
||||
ray.get(
|
||||
self._get_result(
|
||||
self._controller.set_traffic.remote(endpoint_name,
|
||||
traffic_policy_dictionary))
|
||||
|
||||
@@ -353,7 +361,7 @@ class Client:
|
||||
(float, int)) or not 0 <= proportion <= 1:
|
||||
raise TypeError("proportion must be a float from 0 to 1.")
|
||||
|
||||
ray.get(
|
||||
self._get_result(
|
||||
self._controller.shadow_traffic.remote(endpoint_name, backend_tag,
|
||||
proportion))
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from uuid import uuid4, UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
import ray
|
||||
@@ -420,12 +421,19 @@ class ActorStateReconciler:
|
||||
return autoscaling_policies
|
||||
|
||||
|
||||
@dataclass
|
||||
class FutureResult:
|
||||
# Goal requested when this future was created
|
||||
requested_goal: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
goal_state: SystemState
|
||||
current_state: SystemState
|
||||
reconciler: ActorStateReconciler
|
||||
# TODO(ilr) Rename reconciler to PendingState
|
||||
inflight_reqs: Dict[uuid4, FutureResult]
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -490,6 +498,11 @@ class ServeController:
|
||||
self.actor_reconciler._start_routers_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
|
||||
# Map of awaiting results
|
||||
# TODO(ilr): Checkpoint this once this becomes asynchronous
|
||||
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
|
||||
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
|
||||
|
||||
# NOTE(edoakes): unfortunately, we can't completely recover from a
|
||||
# checkpoint in the constructor because we block while waiting for
|
||||
# other actors to start up, and those actors fetch soft state from
|
||||
@@ -522,6 +535,34 @@ class ServeController:
|
||||
|
||||
asyncio.get_event_loop().create_task(self.run_control_loop())
|
||||
|
||||
async def wait_for_event(self, uuid: UUID) -> bool:
|
||||
if uuid not in self.inflight_results:
|
||||
return True
|
||||
event = self.inflight_results[uuid]
|
||||
await event.wait()
|
||||
self.inflight_results.pop(uuid)
|
||||
self._serializable_inflight_results.pop(uuid)
|
||||
async with self.write_lock:
|
||||
self._checkpoint()
|
||||
|
||||
return True
|
||||
|
||||
def _create_event_with_result(
|
||||
self,
|
||||
goal_state: Dict[str, any],
|
||||
recreation_uuid: Optional[UUID] = None) -> UUID:
|
||||
# NOTE(ilr) Must be called before checkpointing!
|
||||
event = asyncio.Event()
|
||||
event.result = FutureResult(goal_state)
|
||||
event.set()
|
||||
uuid_val = recreation_uuid or uuid4()
|
||||
self.inflight_results[uuid_val] = event
|
||||
self._serializable_inflight_results[uuid_val] = event.result
|
||||
return uuid_val
|
||||
|
||||
async def _num_inflight_results(self) -> int:
|
||||
return len(self.inflight_results)
|
||||
|
||||
def notify_replica_handles_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"worker_handles", {
|
||||
@@ -565,7 +606,8 @@ class ServeController:
|
||||
|
||||
checkpoint = pickle.dumps(
|
||||
Checkpoint(self.goal_state, self.current_state,
|
||||
self.actor_reconciler))
|
||||
self.actor_reconciler,
|
||||
self._serializable_inflight_results))
|
||||
|
||||
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
|
||||
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
|
||||
@@ -600,6 +642,11 @@ class ServeController:
|
||||
# Restore ActorStateReconciler
|
||||
self.actor_reconciler = restored_checkpoint.reconciler
|
||||
|
||||
# Recreate self.inflight_requests!
|
||||
self._serializable_inflight_results = restored_checkpoint.inflight_reqs
|
||||
for uuid, fut_result in self._serializable_inflight_results.items():
|
||||
self._create_event_with_result(fut_result.requested_goal, uuid)
|
||||
|
||||
self.autoscaling_policies = await self.actor_reconciler.\
|
||||
_recover_from_checkpoint(self.current_state, self)
|
||||
|
||||
@@ -660,7 +707,7 @@ class ServeController:
|
||||
return self.current_state.get_endpoints()
|
||||
|
||||
async def _set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> None:
|
||||
traffic_dict: Dict[str, float]) -> UUID:
|
||||
if endpoint_name not in self.current_state.get_endpoints():
|
||||
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
|
||||
" that is not registered.".format(endpoint_name))
|
||||
@@ -677,21 +724,26 @@ class ServeController:
|
||||
traffic_policy = TrafficPolicy(traffic_dict)
|
||||
self.current_state.traffic_policies[endpoint_name] = traffic_policy
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
endpoint_name: traffic_policy
|
||||
})
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
|
||||
self.notify_traffic_policies_changed()
|
||||
return return_uuid
|
||||
|
||||
async def set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> None:
|
||||
traffic_dict: Dict[str, float]) -> UUID:
|
||||
"""Sets the traffic policy for the specified endpoint."""
|
||||
async with self.write_lock:
|
||||
await self._set_traffic(endpoint_name, traffic_dict)
|
||||
return_uuid = await self._set_traffic(endpoint_name, traffic_dict)
|
||||
return return_uuid
|
||||
|
||||
async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
|
||||
proportion: float) -> None:
|
||||
proportion: float) -> UUID:
|
||||
"""Shadow traffic from the endpoint to the backend."""
|
||||
async with self.write_lock:
|
||||
if endpoint_name not in self.current_state.get_endpoints():
|
||||
@@ -707,16 +759,22 @@ class ServeController:
|
||||
self.current_state.traffic_policies[endpoint_name].set_shadow(
|
||||
backend_tag, proportion)
|
||||
|
||||
traffic_policy = self.current_state.traffic_policies[endpoint_name]
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
endpoint_name: traffic_policy
|
||||
})
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
self.notify_traffic_policies_changed()
|
||||
return return_uuid
|
||||
|
||||
# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
|
||||
async def create_endpoint(self, endpoint: str,
|
||||
traffic_dict: Dict[str, float], route,
|
||||
methods) -> None:
|
||||
methods) -> UUID:
|
||||
"""Create a new endpoint with the specified route and methods.
|
||||
|
||||
If the route is None, this is a "headless" endpoint that will not
|
||||
@@ -755,13 +813,14 @@ class ServeController:
|
||||
self.current_state.routes[route] = (endpoint, methods)
|
||||
|
||||
# NOTE(edoakes): checkpoint is written in self._set_traffic.
|
||||
await self._set_traffic(endpoint, traffic_dict)
|
||||
return_uuid = await self._set_traffic(endpoint, traffic_dict)
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(self.current_state.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
return return_uuid
|
||||
|
||||
async def delete_endpoint(self, endpoint: str) -> None:
|
||||
async def delete_endpoint(self, endpoint: str) -> UUID:
|
||||
"""Delete the specified endpoint.
|
||||
|
||||
Does not modify any corresponding backends.
|
||||
@@ -788,6 +847,10 @@ class ServeController:
|
||||
|
||||
self.actor_reconciler.endpoints_to_remove.append(endpoint)
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
route_to_delete: None,
|
||||
endpoint: None
|
||||
})
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# updates to the routers to avoid inconsistent state if we crash
|
||||
# after pushing the update.
|
||||
@@ -797,10 +860,11 @@ class ServeController:
|
||||
router.set_route_table.remote(self.current_state.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
return return_uuid
|
||||
|
||||
async def create_backend(self, backend_tag: BackendTag,
|
||||
backend_config: BackendConfig,
|
||||
replica_config: ReplicaConfig) -> None:
|
||||
replica_config: ReplicaConfig) -> UUID:
|
||||
"""Register a new backend under the specified tag."""
|
||||
async with self.write_lock:
|
||||
# Ensures this method is idempotent.
|
||||
@@ -815,12 +879,11 @@ class ServeController:
|
||||
|
||||
# Save creator that starts replicas, the arguments to be passed in,
|
||||
# and the configuration for the backends.
|
||||
self.current_state.add_backend(
|
||||
backend_tag,
|
||||
BackendInfo(
|
||||
worker_class=backend_replica,
|
||||
backend_config=backend_config,
|
||||
replica_config=replica_config))
|
||||
backend_info = BackendInfo(
|
||||
worker_class=backend_replica,
|
||||
backend_config=backend_config,
|
||||
replica_config=replica_config)
|
||||
self.current_state.add_backend(backend_tag, backend_info)
|
||||
metadata = backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
self.autoscaling_policies[
|
||||
@@ -835,6 +898,9 @@ class ServeController:
|
||||
del self.current_state.backends[backend_tag]
|
||||
raise e
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
backend_tag: backend_info
|
||||
})
|
||||
# NOTE(edoakes): we must write a checkpoint before starting new
|
||||
# or pushing the updated config to avoid inconsistent state if we
|
||||
# crash while making the change.
|
||||
@@ -847,8 +913,9 @@ class ServeController:
|
||||
# Set the backend config inside the router
|
||||
# (particularly for max_concurrent_queries).
|
||||
self.notify_backend_configs_changed()
|
||||
return return_uuid
|
||||
|
||||
async def delete_backend(self, backend_tag: BackendTag) -> None:
|
||||
async def delete_backend(self, backend_tag: BackendTag) -> UUID:
|
||||
async with self.write_lock:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified backend exists on the client.
|
||||
@@ -879,6 +946,7 @@ class ServeController:
|
||||
# Add the intention to remove the backend from the router.
|
||||
self.actor_reconciler.backends_to_remove.append(backend_tag)
|
||||
|
||||
return_uuid = self._create_event_with_result({backend_tag: None})
|
||||
# NOTE(edoakes): we must write a checkpoint before removing the
|
||||
# backend from the router to avoid inconsistent state if we crash
|
||||
# after pushing the update.
|
||||
@@ -886,9 +954,10 @@ class ServeController:
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
return return_uuid
|
||||
|
||||
async def update_backend_config(self, backend_tag: BackendTag,
|
||||
config_options: BackendConfig) -> None:
|
||||
config_options: BackendConfig) -> UUID:
|
||||
"""Set the config for the specified backend."""
|
||||
async with self.write_lock:
|
||||
assert (self.current_state.get_backend(backend_tag)
|
||||
@@ -902,12 +971,16 @@ class ServeController:
|
||||
backend_config._validate_complete()
|
||||
self.current_state.get_backend(
|
||||
backend_tag).backend_config = backend_config
|
||||
backend_info = self.current_state.get_backend(backend_tag)
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.current_state.backends, backend_tag,
|
||||
backend_config.num_replicas)
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
backend_tag: backend_info
|
||||
})
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
@@ -922,6 +995,7 @@ class ServeController:
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
self.notify_backend_configs_changed()
|
||||
return return_uuid
|
||||
|
||||
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
|
||||
"""Get the current config for the specified backend."""
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def test_controller_inflight_requests_clear(serve_instance):
|
||||
client = serve_instance
|
||||
initial_number_reqs = ray.get(
|
||||
client._controller._num_inflight_results.remote())
|
||||
|
||||
def function(_):
|
||||
return "hello"
|
||||
|
||||
client.create_backend("tst", function)
|
||||
client.create_endpoint("end_pt", backend="tst")
|
||||
|
||||
assert ray.get(client._controller._num_inflight_results.remote()
|
||||
) - initial_number_reqs == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
Reference in New Issue
Block a user