mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:54:27 +08:00
[Serve] Implement replica scaling (#5850)
* Implement replica scaling * Lint * Fix .travis.yml so it won't skip if only serve affected
This commit is contained in:
@@ -4,9 +4,9 @@ if sys.version_info < (3, 0):
|
||||
|
||||
from ray.experimental.serve.api import (init, create_backend, create_endpoint,
|
||||
link, split, rollback, get_handle,
|
||||
global_state) # noqa: E402
|
||||
global_state, scale) # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"init", "create_backend", "create_endpoint", "link", "split", "rollback",
|
||||
"get_handle", "global_state"
|
||||
"get_handle", "global_state", "scale"
|
||||
]
|
||||
|
||||
@@ -67,7 +67,8 @@ def create_backend(func_or_class, backend_tag, *actor_init_args):
|
||||
initialization method.
|
||||
"""
|
||||
if inspect.isfunction(func_or_class):
|
||||
runner = TaskRunnerActor.remote(func_or_class)
|
||||
# ignore lint on lambda expression
|
||||
creator = lambda: TaskRunnerActor.remote(func_or_class) # noqa: E731
|
||||
elif inspect.isclass(func_or_class):
|
||||
# Python resolves methods left-to-right across base classes. We put RayServeMixin
|
||||
# on the left to make sure its methods are not overridden.
|
||||
@@ -75,20 +76,70 @@ def create_backend(func_or_class, backend_tag, *actor_init_args):
|
||||
class CustomActor(RayServeMixin, func_or_class):
|
||||
pass
|
||||
|
||||
runner = CustomActor.remote(*actor_init_args)
|
||||
# ignore lint on lambda expression
|
||||
creator = lambda: CustomActor.remote(*actor_init_args) # noqa: E731
|
||||
else:
|
||||
raise TypeError(
|
||||
"Backend must be a function or class, it is {}.".format(
|
||||
type(func_or_class)))
|
||||
|
||||
global_state.backend_actor_handles.append(runner)
|
||||
|
||||
runner._ray_serve_setup.remote(backend_tag,
|
||||
global_state.router_actor_handle)
|
||||
runner._ray_serve_main_loop.remote(runner)
|
||||
global_state.backend_creators[backend_tag] = creator
|
||||
|
||||
global_state.registered_backends.add(backend_tag)
|
||||
|
||||
scale(backend_tag, 1)
|
||||
|
||||
|
||||
def _start_replica(backend_tag):
    """Launch one additional replica actor for ``backend_tag``.

    The replica is built by the creator callable stored for the backend,
    set up with the backend tag and the router actor handle, started on
    its main loop, and finally appended to the backend's replica deque.

    Args:
        backend_tag (str): A registered backend.
    """
    assert backend_tag in global_state.registered_backends, (
        "Backend {} is not registered.".format(backend_tag))

    make_replica = global_state.backend_creators[backend_tag]
    new_replica = make_replica()

    # Wait on the setup call before kicking off the main loop, so the
    # loop only ever starts on a replica whose setup has completed.
    setup_ref = new_replica._ray_serve_setup.remote(
        backend_tag, global_state.router_actor_handle)
    ray.get(setup_ref)
    new_replica._ray_serve_main_loop.remote(new_replica)

    # Newest replicas go on the right end of the deque.
    global_state.backend_replicas[backend_tag].append(new_replica)
|
||||
|
||||
|
||||
def _remove_replica(backend_tag):
    """Discard the oldest replica of ``backend_tag``.

    Pops the left-most handle (the earliest one appended) from the
    backend's replica deque and drops the local reference to it.

    Args:
        backend_tag (str): A registered backend with at least one replica.
    """
    assert backend_tag in global_state.registered_backends, (
        "Backend {} is not registered.".format(backend_tag))

    backend_queue = global_state.backend_replicas[backend_tag]
    assert len(backend_queue) > 0, (
        "Backend {} does not have enough replicas to be removed.".format(
            backend_tag))

    evicted_handle = backend_queue.popleft()
    # Release our reference to the handle right away.
    # NOTE(review): the intent is that this terminates the actor, which
    # relies on Ray reclaiming an actor once its handle is released;
    # confirm no other reference keeps it alive.
    del evicted_handle
|
||||
|
||||
|
||||
def scale(backend_tag, num_replicas):
    """Set the number of replicas for backend_tag.

    Compares the requested count with the number of replica handles
    currently tracked for the backend, then starts or removes replicas
    one at a time until the two match. Does nothing if they already match.

    Args:
        backend_tag (str): A registered backend.
        num_replicas (int): Desired number of replicas. Must be at least 1.

    Raises:
        AssertionError: If ``backend_tag`` is not registered or
            ``num_replicas`` is not positive.
    """
    assert backend_tag in global_state.registered_backends, (
        "Backend {} is not registered.".format(backend_tag))
    # A single replica is a valid target, so the lower bound is 0 exclusive.
    assert num_replicas > 0, "Number of replicas must be greater than 0."

    current_num_replicas = len(global_state.backend_replicas[backend_tag])
    delta_num_replicas = num_replicas - current_num_replicas

    if delta_num_replicas > 0:
        # Scale up: add the missing replicas.
        for _ in range(delta_num_replicas):
            _start_replica(backend_tag)
    elif delta_num_replicas < 0:
        # Scale down: drop the surplus replicas.
        for _ in range(-delta_num_replicas):
            _remove_replica(backend_tag)
|
||||
|
||||
|
||||
def link(endpoint_name, backend_tag):
|
||||
"""Associate a service endpoint with backend tag.
|
||||
|
||||
@@ -52,3 +52,7 @@ serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
|
||||
for _ in range(10):
|
||||
print(requests.get("http://127.0.0.1:8000/echo").json())
|
||||
time.sleep(0.5)
|
||||
|
||||
# You can also scale each backend independently.
|
||||
serve.scale("echo:v1", 2)
|
||||
serve.scale("echo:v2", 2)
|
||||
|
||||
@@ -27,9 +27,6 @@ class GlobalState:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
#: holds all actor handles.
|
||||
self.backend_actor_handles = []
|
||||
|
||||
#: actor handle to KV store actor
|
||||
self.kv_store_actor_handle = None
|
||||
#: actor handle to HTTP server
|
||||
@@ -45,6 +42,12 @@ class GlobalState:
|
||||
#: Mapping of endpoints -> a stack of traffic policy
|
||||
self.policy_action_history = defaultdict(deque)
|
||||
|
||||
#: Backend creators. Mapping backend_tag -> callable creator
|
||||
self.backend_creators = dict()
|
||||
#: Replica actor handles per backend.
|
||||
# Mapping backend_tag -> deque(actor_handles)
|
||||
self.backend_replicas = defaultdict(deque)
|
||||
|
||||
#: HTTP address. Currently it's hard coded to localhost with port 8000
|
||||
# In future iteration, HTTP server will be started on every node and
|
||||
# use random/available port in a pre-defined port range. TODO(simon)
|
||||
|
||||
@@ -32,3 +32,42 @@ def test_e2e(serve_instance):
|
||||
|
||||
resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
|
||||
assert resp == "OK"
|
||||
|
||||
|
||||
def test_scaling_replicas(serve_instance):
    """Check that scaling a backend up and back down changes load sharing."""

    class Counter:
        def __init__(self):
            self.count = 0

        def __call__(self, _):
            self.count += 1
            return self.count

    def collect_counts():
        # Issue ten requests in sequence and gather each reported count.
        return [
            requests.get("http://127.0.0.1:8000/increment").json()["result"]
            for _ in range(10)
        ]

    serve.create_endpoint("counter", "/increment")

    # Keep checking the routing table until /increment is populated
    while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
        time.sleep(0.2)

    serve.create_backend(Counter, "counter:v1")
    serve.link("counter", "counter:v1")

    serve.scale("counter:v1", 2)
    counter_result = collect_counts()

    # If the load is shared between two replicas, the max result cannot
    # be 10.
    assert max(counter_result) < 10

    serve.scale("counter:v1", 1)
    counter_result = collect_counts()

    # Give some time for a replica to spin down. But the majority of the
    # requests should be served by the only remaining replica.
    assert max(counter_result) - min(counter_result) > 6
|
||||
|
||||
Reference in New Issue
Block a user