diff --git a/.travis.yml b/.travis.yml index d29a78b3d..cacd12dd1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -140,7 +140,7 @@ matrix: install: - eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py` - - if [ $RAY_CI_TUNE_AFFECTED != "1" ] && [ $RAY_CI_RLLIB_AFFECTED != "1" ] && [ $RAY_CI_PYTHON_AFFECTED != "1" ]; then exit; fi + - if [ $RAY_CI_SERVE_AFFECTED != "1" ] && [ $RAY_CI_TUNE_AFFECTED != "1" ] && [ $RAY_CI_RLLIB_AFFECTED != "1" ] && [ $RAY_CI_PYTHON_AFFECTED != "1" ]; then exit; fi - ./ci/suppress_output ./ci/travis/install-bazel.sh - ./ci/suppress_output ./ci/travis/install-dependencies.sh diff --git a/python/ray/experimental/serve/__init__.py b/python/ray/experimental/serve/__init__.py index cd3a84c42..9b14959b5 100644 --- a/python/ray/experimental/serve/__init__.py +++ b/python/ray/experimental/serve/__init__.py @@ -4,9 +4,9 @@ if sys.version_info < (3, 0): from ray.experimental.serve.api import (init, create_backend, create_endpoint, link, split, rollback, get_handle, - global_state) # noqa: E402 + global_state, scale) # noqa: E402 __all__ = [ "init", "create_backend", "create_endpoint", "link", "split", "rollback", - "get_handle", "global_state" + "get_handle", "global_state", "scale" ] diff --git a/python/ray/experimental/serve/api.py b/python/ray/experimental/serve/api.py index dd89961b6..a1ae67b01 100644 --- a/python/ray/experimental/serve/api.py +++ b/python/ray/experimental/serve/api.py @@ -67,7 +67,8 @@ def create_backend(func_or_class, backend_tag, *actor_init_args): initialization method. """ if inspect.isfunction(func_or_class): - runner = TaskRunnerActor.remote(func_or_class) + # ignore lint on lambda expression + creator = lambda: TaskRunnerActor.remote(func_or_class) # noqa: E731 elif inspect.isclass(func_or_class): # Python inheritance order is right-to-left. We put RayServeMixin # on the left to make sure its methods are not overriden. @@ -75,20 +76,70 @@ def create_backend(func_or_class, backend_tag, *actor_init_args): class CustomActor(RayServeMixin, func_or_class): pass - runner = CustomActor.remote(*actor_init_args) + # ignore lint on lambda expression + creator = lambda: CustomActor.remote(*actor_init_args) # noqa: E731 else: raise TypeError( "Backend must be a function or class, it is {}.".format( type(func_or_class))) - global_state.backend_actor_handles.append(runner) - - runner._ray_serve_setup.remote(backend_tag, - global_state.router_actor_handle) - runner._ray_serve_main_loop.remote(runner) + global_state.backend_creators[backend_tag] = creator global_state.registered_backends.add(backend_tag) + scale(backend_tag, 1) + + +def _start_replica(backend_tag): + assert backend_tag in global_state.registered_backends, ( + "Backend {} is not registered.".format(backend_tag)) + + creator = global_state.backend_creators[backend_tag] + + runner = creator() + setup_done = runner._ray_serve_setup.remote( + backend_tag, global_state.router_actor_handle) + ray.get(setup_done) + runner._ray_serve_main_loop.remote(runner) + + global_state.backend_replicas[backend_tag].append(runner) + + +def _remove_replica(backend_tag): + assert backend_tag in global_state.registered_backends, ( + "Backend {} is not registered.".format(backend_tag)) + assert len(global_state.backend_replicas[backend_tag]) > 0, ( + "Backend {} does not have enough replicas to be removed.".format( + backend_tag)) + + replicas = global_state.backend_replicas[backend_tag] + oldest_replica_handle = replicas.popleft() + # explicitly terminate that actor + del oldest_replica_handle + + +def scale(backend_tag, num_replicas): + """Set the number of replicas for backend_tag. + + Args: + backend_tag (str): A registered backend. + num_replicas (int): Desired number of replicas + """ + assert backend_tag in global_state.registered_backends, ( + "Backend {} is not registered.".format(backend_tag)) + assert num_replicas > 0, "Number of replicas must be greater than 1." + + replicas = global_state.backend_replicas[backend_tag] + current_num_replicas = len(replicas) + delta_num_replicas = num_replicas - current_num_replicas + + if delta_num_replicas > 0: + for _ in range(delta_num_replicas): + _start_replica(backend_tag) + elif delta_num_replicas < 0: + for _ in range(-delta_num_replicas): + _remove_replica(backend_tag) + def link(endpoint_name, backend_tag): """Associate a service endpoint with backend tag. diff --git a/python/ray/experimental/serve/examples/echo_full.py b/python/ray/experimental/serve/examples/echo_full.py index 2edfff4c2..ffbfa9608 100644 --- a/python/ray/experimental/serve/examples/echo_full.py +++ b/python/ray/experimental/serve/examples/echo_full.py @@ -52,3 +52,7 @@ serve.split("my_endpoint", {"echo:v1": 0.5, "echo:v2": 0.5}) for _ in range(10): print(requests.get("http://127.0.0.1:8000/echo").json()) time.sleep(0.5) + +# You can also scale each backend independently. +serve.scale("echo:v1", 2) +serve.scale("echo:v2", 2) diff --git a/python/ray/experimental/serve/global_state.py b/python/ray/experimental/serve/global_state.py index e817ca185..780d93b51 100644 --- a/python/ray/experimental/serve/global_state.py +++ b/python/ray/experimental/serve/global_state.py @@ -27,9 +27,6 @@ class GlobalState: """ def __init__(self): - #: holds all actor handles. - self.backend_actor_handles = [] - #: actor handle to KV store actor self.kv_store_actor_handle = None #: actor handle to HTTP server @@ -45,6 +42,12 @@ class GlobalState: #: Mapping of endpoints -> a stack of traffic policy self.policy_action_history = defaultdict(deque) + #: Backend creaters. Mapping backend_tag -> callable creator + self.backend_creators = dict() + #: Number of replicas per backend. + # Mapping backend_tag -> deque(actor_handles) + self.backend_replicas = defaultdict(deque) + #: HTTP address. Currently it's hard coded to localhost with port 8000 # In future iteration, HTTP server will be started on every node and # use random/available port in a pre-defined port range. TODO(simon) diff --git a/python/ray/experimental/serve/tests/test_api.py b/python/ray/experimental/serve/tests/test_api.py index a45c37ba2..e11fa2bd3 100644 --- a/python/ray/experimental/serve/tests/test_api.py +++ b/python/ray/experimental/serve/tests/test_api.py @@ -32,3 +32,42 @@ def test_e2e(serve_instance): resp = requests.get("http://127.0.0.1:8000/api").json()["result"] assert resp == "OK" + + +def test_scaling_replicas(serve_instance): + class Counter: + def __init__(self): + self.count = 0 + + def __call__(self, _): + self.count += 1 + return self.count + + serve.create_endpoint("counter", "/increment") + + # Keep checking the routing table until /increment is populated + while "/increment" not in requests.get("http://127.0.0.1:8000/").json(): + time.sleep(0.2) + + serve.create_backend(Counter, "counter:v1") + serve.link("counter", "counter:v1") + + serve.scale("counter:v1", 2) + + counter_result = [] + for _ in range(10): + resp = requests.get("http://127.0.0.1:8000/increment").json()["result"] + counter_result.append(resp) + + # If the load is shared among two replicas. The max result cannot be 10. + assert max(counter_result) < 10 + + serve.scale("counter:v1", 1) + + counter_result = [] + for _ in range(10): + resp = requests.get("http://127.0.0.1:8000/increment").json()["result"] + counter_result.append(resp) + # Give some time for a replica to spin down. But majority of the request + # should be served by the only remaining replica. + assert max(counter_result) - min(counter_result) > 6