[Serve] Implement replica scaling (#5850)

* Implement replica scaling

* Lint

* Fix .travis.yml so it won't skip if only serve affected
This commit is contained in:
Simon Mo
2019-10-07 01:57:31 -07:00
committed by GitHub
parent 5834c56c64
commit 25dde48607
6 changed files with 110 additions and 13 deletions
+2 -2
View File
@@ -4,9 +4,9 @@ if sys.version_info < (3, 0):
from ray.experimental.serve.api import (init, create_backend, create_endpoint,
link, split, rollback, get_handle,
global_state) # noqa: E402
global_state, scale) # noqa: E402
__all__ = [
"init", "create_backend", "create_endpoint", "link", "split", "rollback",
"get_handle", "global_state"
"get_handle", "global_state", "scale"
]
+58 -7
View File
@@ -67,7 +67,8 @@ def create_backend(func_or_class, backend_tag, *actor_init_args):
initialization method.
"""
if inspect.isfunction(func_or_class):
runner = TaskRunnerActor.remote(func_or_class)
# ignore lint on lambda expression
creator = lambda: TaskRunnerActor.remote(func_or_class) # noqa: E731
elif inspect.isclass(func_or_class):
# Python inheritance order is right-to-left. We put RayServeMixin
# on the left to make sure its methods are not overriden.
@@ -75,20 +76,70 @@ def create_backend(func_or_class, backend_tag, *actor_init_args):
class CustomActor(RayServeMixin, func_or_class):
pass
runner = CustomActor.remote(*actor_init_args)
# ignore lint on lambda expression
creator = lambda: CustomActor.remote(*actor_init_args) # noqa: E731
else:
raise TypeError(
"Backend must be a function or class, it is {}.".format(
type(func_or_class)))
global_state.backend_actor_handles.append(runner)
runner._ray_serve_setup.remote(backend_tag,
global_state.router_actor_handle)
runner._ray_serve_main_loop.remote(runner)
global_state.backend_creators[backend_tag] = creator
global_state.registered_backends.add(backend_tag)
scale(backend_tag, 1)
def _start_replica(backend_tag):
assert backend_tag in global_state.registered_backends, (
"Backend {} is not registered.".format(backend_tag))
creator = global_state.backend_creators[backend_tag]
runner = creator()
setup_done = runner._ray_serve_setup.remote(
backend_tag, global_state.router_actor_handle)
ray.get(setup_done)
runner._ray_serve_main_loop.remote(runner)
global_state.backend_replicas[backend_tag].append(runner)
def _remove_replica(backend_tag):
assert backend_tag in global_state.registered_backends, (
"Backend {} is not registered.".format(backend_tag))
assert len(global_state.backend_replicas[backend_tag]) > 0, (
"Backend {} does not have enough replicas to be removed.".format(
backend_tag))
replicas = global_state.backend_replicas[backend_tag]
oldest_replica_handle = replicas.popleft()
# explicitly terminate that actor
del oldest_replica_handle
def scale(backend_tag, num_replicas):
"""Set the number of replicas for backend_tag.
Args:
backend_tag (str): A registered backend.
num_replicas (int): Desired number of replicas
"""
assert backend_tag in global_state.registered_backends, (
"Backend {} is not registered.".format(backend_tag))
assert num_replicas > 0, "Number of replicas must be greater than 1."
replicas = global_state.backend_replicas[backend_tag]
current_num_replicas = len(replicas)
delta_num_replicas = num_replicas - current_num_replicas
if delta_num_replicas > 0:
for _ in range(delta_num_replicas):
_start_replica(backend_tag)
elif delta_num_replicas < 0:
for _ in range(-delta_num_replicas):
_remove_replica(backend_tag)
def link(endpoint_name, backend_tag):
"""Associate a service endpoint with backend tag.
@@ -52,3 +52,7 @@ serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5})
for _ in range(10):
print(requests.get("http://127.0.0.1:8000/echo").json())
time.sleep(0.5)
# You can also scale each backend independently.
serve.scale("echo:v1", 2)
serve.scale("echo:v2", 2)
@@ -27,9 +27,6 @@ class GlobalState:
"""
def __init__(self):
#: holds all actor handles.
self.backend_actor_handles = []
#: actor handle to KV store actor
self.kv_store_actor_handle = None
#: actor handle to HTTP server
@@ -45,6 +42,12 @@ class GlobalState:
#: Mapping of endpoints -> a stack of traffic policy
self.policy_action_history = defaultdict(deque)
#: Backend creaters. Mapping backend_tag -> callable creator
self.backend_creators = dict()
#: Number of replicas per backend.
# Mapping backend_tag -> deque(actor_handles)
self.backend_replicas = defaultdict(deque)
#: HTTP address. Currently it's hard coded to localhost with port 8000
# In future iteration, HTTP server will be started on every node and
# use random/available port in a pre-defined port range. TODO(simon)
@@ -32,3 +32,42 @@ def test_e2e(serve_instance):
resp = requests.get("http://127.0.0.1:8000/api").json()["result"]
assert resp == "OK"
def test_scaling_replicas(serve_instance):
class Counter:
def __init__(self):
self.count = 0
def __call__(self, _):
self.count += 1
return self.count
serve.create_endpoint("counter", "/increment")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get("http://127.0.0.1:8000/").json():
time.sleep(0.2)
serve.create_backend(Counter, "counter:v1")
serve.link("counter", "counter:v1")
serve.scale("counter:v1", 2)
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
counter_result.append(resp)
# If the load is shared among two replicas. The max result cannot be 10.
assert max(counter_result) < 10
serve.scale("counter:v1", 1)
counter_result = []
for _ in range(10):
resp = requests.get("http://127.0.0.1:8000/increment").json()["result"]
counter_result.append(resp)
# Give some time for a replica to spin down. But majority of the request
# should be served by the only remaining replica.
assert max(counter_result) - min(counter_result) > 6