mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 22:42:17 +08:00
[serve] Create all other actors in master actor (#7791)
This commit is contained in:
+29
-81
@@ -9,10 +9,10 @@ import numpy as np
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_MASTER_NAME)
|
||||
from ray.serve.global_state import GlobalState, start_initial_state
|
||||
from ray.serve.global_state import GlobalState, ServeMaster
|
||||
from ray.serve.kv_store_service import SQLiteKVStore
|
||||
from ray.serve.task_runner import RayServeMixin, TaskRunnerActor
|
||||
from ray.serve.utils import block_until_http_ready, get_random_letters, expand
|
||||
from ray.serve.utils import block_until_http_ready, expand
|
||||
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
|
||||
from ray.serve.backend_config import BackendConfig
|
||||
from ray.serve.policy import RoutePolicy
|
||||
@@ -138,16 +138,15 @@ def init(
|
||||
def kv_store_connector(namespace):
|
||||
return SQLiteKVStore(namespace, db_path=kv_store_path)
|
||||
|
||||
master = start_initial_state(kv_store_connector)
|
||||
master = ServeMaster.options(
|
||||
detached=True, name=SERVE_MASTER_NAME).remote(kv_store_connector)
|
||||
|
||||
ray.get(master.start_router.remote(queueing_policy.value, policy_kwargs))
|
||||
|
||||
global_state = GlobalState(master)
|
||||
router = global_state.init_or_get_router(
|
||||
queueing_policy=queueing_policy, policy_kwargs=policy_kwargs)
|
||||
global_state.init_or_get_metric_monitor(
|
||||
gc_window_seconds=gc_window_seconds)
|
||||
ray.get(master.start_metric_monitor.remote(gc_window_seconds))
|
||||
if start_server:
|
||||
global_state.init_or_get_http_proxy(
|
||||
host=http_host, port=http_port).set_router_handle.remote(router)
|
||||
ray.get(master.start_http_proxy.remote(http_host, http_port))
|
||||
|
||||
if start_server and blocking:
|
||||
block_until_http_ready("http://{}:{}/-/routes".format(
|
||||
@@ -169,9 +168,11 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]):
|
||||
methods = [m.upper() for m in methods]
|
||||
global_state.route_table.register_service(
|
||||
route, endpoint_name, methods=methods)
|
||||
ray.get(global_state.init_or_get_http_proxy().set_route_table.remote(
|
||||
global_state.route_table.list_service(
|
||||
include_methods=True, include_headless=False)))
|
||||
http_proxy = global_state.get_http_proxy()
|
||||
ray.get(
|
||||
http_proxy.set_route_table.remote(
|
||||
global_state.route_table.list_service(
|
||||
include_methods=True, include_headless=False)))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -198,8 +199,8 @@ def set_backend_config(backend_tag, backend_config):
|
||||
|
||||
# inform the router about change in configuration
|
||||
# particularly for setting max_batch_size
|
||||
ray.get(global_state.init_or_get_router().set_backend_config.remote(
|
||||
backend_tag, backend_config_dict))
|
||||
router = global_state.get_router()
|
||||
ray.get(router.set_backend_config.remote(backend_tag, backend_config_dict))
|
||||
|
||||
# checking if replicas need to be restarted
|
||||
# Replicas are restarted if there is any change in the backend config
|
||||
@@ -281,7 +282,6 @@ def create_backend(func_or_class,
|
||||
class CustomActor(RayServeMixin, func_or_class):
|
||||
@wraps(func_or_class.__init__)
|
||||
def __init__(self, *args, **kwargs):
|
||||
init() # serve init
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
arg_list = actor_init_args
|
||||
@@ -305,68 +305,11 @@ def create_backend(func_or_class,
|
||||
|
||||
# set the backend config inside the router
|
||||
# particularly for max-batch-size
|
||||
ray.get(global_state.init_or_get_router().set_backend_config.remote(
|
||||
backend_tag, backend_config_dict))
|
||||
router = global_state.get_router()
|
||||
ray.get(router.set_backend_config.remote(backend_tag, backend_config_dict))
|
||||
_scale(backend_tag, backend_config_dict["num_replicas"])
|
||||
|
||||
|
||||
def _start_replica(backend_tag):
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
|
||||
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
|
||||
|
||||
# get the info which starts the replicas
|
||||
creator = global_state.backend_table.get_backend_creator(backend_tag)
|
||||
backend_config_dict = global_state.backend_table.get_info(backend_tag)
|
||||
backend_config = BackendConfig(**backend_config_dict)
|
||||
init_args = global_state.backend_table.get_init_args(backend_tag)
|
||||
|
||||
# get actor creation kwargs
|
||||
actor_kwargs = backend_config.get_actor_creation_args(init_args)
|
||||
|
||||
# Create the runner in the master actor
|
||||
[runner_handle] = ray.get(
|
||||
global_state.master_actor_handle.start_actor_with_creator.remote(
|
||||
creator, actor_kwargs, replica_tag))
|
||||
|
||||
# Setup the worker
|
||||
ray.get(
|
||||
runner_handle._ray_serve_setup.remote(
|
||||
backend_tag, global_state.init_or_get_router(), runner_handle))
|
||||
runner_handle._ray_serve_fetch.remote()
|
||||
|
||||
# Register the worker in config tables as well as metric monitor
|
||||
global_state.backend_table.add_replica(backend_tag, replica_tag)
|
||||
global_state.init_or_get_metric_monitor().add_target.remote(runner_handle)
|
||||
|
||||
|
||||
def _remove_replica(backend_tag):
|
||||
assert (backend_tag in global_state.backend_table.list_backends()
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert (
|
||||
len(global_state.backend_table.list_replicas(backend_tag)) >
|
||||
0), "Backend {} does not have enough replicas to be removed.".format(
|
||||
backend_tag)
|
||||
|
||||
replica_tag = global_state.backend_table.remove_replica(backend_tag)
|
||||
[replica_handle] = ray.get(
|
||||
global_state.master_actor_handle.get_handle.remote(replica_tag))
|
||||
|
||||
# Remove the replica from metric monitor.
|
||||
ray.get(global_state.init_or_get_metric_monitor().remove_target.remote(
|
||||
replica_handle))
|
||||
|
||||
# Remove the replica from master actor.
|
||||
ray.get(global_state.master_actor_handle.remove_handle.remote(replica_tag))
|
||||
|
||||
# Remove the replica from router.
|
||||
# This will also destory the actor handle.
|
||||
ray.get(
|
||||
global_state.init_or_get_router().remove_and_destory_replica.remote(
|
||||
backend_tag, replica_handle))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def _scale(backend_tag, num_replicas):
|
||||
"""Set the number of replicas for backend_tag.
|
||||
@@ -386,10 +329,14 @@ def _scale(backend_tag, num_replicas):
|
||||
|
||||
if delta_num_replicas > 0:
|
||||
for _ in range(delta_num_replicas):
|
||||
_start_replica(backend_tag)
|
||||
ray.get(
|
||||
global_state.master_actor.start_backend_replica.remote(
|
||||
backend_tag))
|
||||
elif delta_num_replicas < 0:
|
||||
for _ in range(-delta_num_replicas):
|
||||
_remove_replica(backend_tag)
|
||||
ray.get(
|
||||
global_state.master_actor.remove_backend_replica.remote(
|
||||
backend_tag))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -441,8 +388,9 @@ def split(endpoint_name, traffic_policy_dictionary):
|
||||
|
||||
global_state.policy_table.register_traffic_policy(
|
||||
endpoint_name, traffic_policy_dictionary)
|
||||
ray.get(global_state.init_or_get_router().set_traffic.remote(
|
||||
endpoint_name, traffic_policy_dictionary))
|
||||
router = global_state.get_router()
|
||||
ray.get(
|
||||
router.set_traffic.remote(endpoint_name, traffic_policy_dictionary))
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
@@ -473,7 +421,7 @@ def get_handle(endpoint_name,
|
||||
from ray.serve.handle import RayServeHandle
|
||||
|
||||
return RayServeHandle(
|
||||
global_state.init_or_get_router(),
|
||||
global_state.get_router(),
|
||||
endpoint_name,
|
||||
relative_slo_ms,
|
||||
absolute_slo_ms,
|
||||
@@ -492,8 +440,8 @@ def stat(percentiles=[50, 90, 95],
|
||||
The longest aggregation window must be shorter or equal to the
|
||||
gc_window_seconds.
|
||||
"""
|
||||
return ray.get(global_state.init_or_get_metric_monitor().collect.remote(
|
||||
percentiles, agg_windows_seconds))
|
||||
monitor = global_state.get_metric_monitor()
|
||||
return ray.get(monitor.collect.remote(percentiles, agg_windows_seconds))
|
||||
|
||||
|
||||
class route:
|
||||
|
||||
Reference in New Issue
Block a user