[serve] Create all other actors in master actor (#7791)

This commit is contained in:
Edward Oakes
2020-04-01 10:15:04 -05:00
committed by GitHub
parent b011c604d7
commit f4239d27fa
5 changed files with 144 additions and 218 deletions
+29 -81
View File
@@ -9,10 +9,10 @@ import numpy as np
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_MASTER_NAME)
from ray.serve.global_state import GlobalState, start_initial_state
from ray.serve.global_state import GlobalState, ServeMaster
from ray.serve.kv_store_service import SQLiteKVStore
from ray.serve.task_runner import RayServeMixin, TaskRunnerActor
from ray.serve.utils import block_until_http_ready, get_random_letters, expand
from ray.serve.utils import block_until_http_ready, expand
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
from ray.serve.backend_config import BackendConfig
from ray.serve.policy import RoutePolicy
@@ -138,16 +138,15 @@ def init(
def kv_store_connector(namespace):
return SQLiteKVStore(namespace, db_path=kv_store_path)
master = start_initial_state(kv_store_connector)
master = ServeMaster.options(
detached=True, name=SERVE_MASTER_NAME).remote(kv_store_connector)
ray.get(master.start_router.remote(queueing_policy.value, policy_kwargs))
global_state = GlobalState(master)
router = global_state.init_or_get_router(
queueing_policy=queueing_policy, policy_kwargs=policy_kwargs)
global_state.init_or_get_metric_monitor(
gc_window_seconds=gc_window_seconds)
ray.get(master.start_metric_monitor.remote(gc_window_seconds))
if start_server:
global_state.init_or_get_http_proxy(
host=http_host, port=http_port).set_router_handle.remote(router)
ray.get(master.start_http_proxy.remote(http_host, http_port))
if start_server and blocking:
block_until_http_ready("http://{}:{}/-/routes".format(
@@ -169,9 +168,11 @@ def create_endpoint(endpoint_name, route=None, methods=["GET"]):
methods = [m.upper() for m in methods]
global_state.route_table.register_service(
route, endpoint_name, methods=methods)
ray.get(global_state.init_or_get_http_proxy().set_route_table.remote(
global_state.route_table.list_service(
include_methods=True, include_headless=False)))
http_proxy = global_state.get_http_proxy()
ray.get(
http_proxy.set_route_table.remote(
global_state.route_table.list_service(
include_methods=True, include_headless=False)))
@_ensure_connected
@@ -198,8 +199,8 @@ def set_backend_config(backend_tag, backend_config):
# inform the router about change in configuration
# particularly for setting max_batch_size
ray.get(global_state.init_or_get_router().set_backend_config.remote(
backend_tag, backend_config_dict))
router = global_state.get_router()
ray.get(router.set_backend_config.remote(backend_tag, backend_config_dict))
# checking if replicas need to be restarted
# Replicas are restarted if there is any change in the backend config
@@ -281,7 +282,6 @@ def create_backend(func_or_class,
class CustomActor(RayServeMixin, func_or_class):
@wraps(func_or_class.__init__)
def __init__(self, *args, **kwargs):
init() # serve init
super().__init__(*args, **kwargs)
arg_list = actor_init_args
@@ -305,68 +305,11 @@ def create_backend(func_or_class,
# set the backend config inside the router
# particularly for max-batch-size
ray.get(global_state.init_or_get_router().set_backend_config.remote(
backend_tag, backend_config_dict))
router = global_state.get_router()
ray.get(router.set_backend_config.remote(backend_tag, backend_config_dict))
_scale(backend_tag, backend_config_dict["num_replicas"])
def _start_replica(backend_tag):
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
# get the info which starts the replicas
creator = global_state.backend_table.get_backend_creator(backend_tag)
backend_config_dict = global_state.backend_table.get_info(backend_tag)
backend_config = BackendConfig(**backend_config_dict)
init_args = global_state.backend_table.get_init_args(backend_tag)
# get actor creation kwargs
actor_kwargs = backend_config.get_actor_creation_args(init_args)
# Create the runner in the master actor
[runner_handle] = ray.get(
global_state.master_actor_handle.start_actor_with_creator.remote(
creator, actor_kwargs, replica_tag))
# Setup the worker
ray.get(
runner_handle._ray_serve_setup.remote(
backend_tag, global_state.init_or_get_router(), runner_handle))
runner_handle._ray_serve_fetch.remote()
# Register the worker in config tables as well as metric monitor
global_state.backend_table.add_replica(backend_tag, replica_tag)
global_state.init_or_get_metric_monitor().add_target.remote(runner_handle)
def _remove_replica(backend_tag):
assert (backend_tag in global_state.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
assert (
len(global_state.backend_table.list_replicas(backend_tag)) >
0), "Backend {} does not have enough replicas to be removed.".format(
backend_tag)
replica_tag = global_state.backend_table.remove_replica(backend_tag)
[replica_handle] = ray.get(
global_state.master_actor_handle.get_handle.remote(replica_tag))
# Remove the replica from metric monitor.
ray.get(global_state.init_or_get_metric_monitor().remove_target.remote(
replica_handle))
# Remove the replica from master actor.
ray.get(global_state.master_actor_handle.remove_handle.remote(replica_tag))
# Remove the replica from router.
# This will also destory the actor handle.
ray.get(
global_state.init_or_get_router().remove_and_destory_replica.remote(
backend_tag, replica_handle))
@_ensure_connected
def _scale(backend_tag, num_replicas):
"""Set the number of replicas for backend_tag.
@@ -386,10 +329,14 @@ def _scale(backend_tag, num_replicas):
if delta_num_replicas > 0:
for _ in range(delta_num_replicas):
_start_replica(backend_tag)
ray.get(
global_state.master_actor.start_backend_replica.remote(
backend_tag))
elif delta_num_replicas < 0:
for _ in range(-delta_num_replicas):
_remove_replica(backend_tag)
ray.get(
global_state.master_actor.remove_backend_replica.remote(
backend_tag))
@_ensure_connected
@@ -441,8 +388,9 @@ def split(endpoint_name, traffic_policy_dictionary):
global_state.policy_table.register_traffic_policy(
endpoint_name, traffic_policy_dictionary)
ray.get(global_state.init_or_get_router().set_traffic.remote(
endpoint_name, traffic_policy_dictionary))
router = global_state.get_router()
ray.get(
router.set_traffic.remote(endpoint_name, traffic_policy_dictionary))
@_ensure_connected
@@ -473,7 +421,7 @@ def get_handle(endpoint_name,
from ray.serve.handle import RayServeHandle
return RayServeHandle(
global_state.init_or_get_router(),
global_state.get_router(),
endpoint_name,
relative_slo_ms,
absolute_slo_ms,
@@ -492,8 +440,8 @@ def stat(percentiles=[50, 90, 95],
The longest aggregation window must be shorter or equal to the
gc_window_seconds.
"""
return ray.get(global_state.init_or_get_metric_monitor().collect.remote(
percentiles, agg_windows_seconds))
monitor = global_state.get_metric_monitor()
return ray.get(monitor.collect.remote(percentiles, agg_windows_seconds))
class route: