From 2c5cb95b42496dc418141082548bfcdb8d9f2a86 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Tue, 20 Oct 2020 10:44:23 -0700
Subject: [PATCH] [Serve] Get ServeHandle on the same node (#11477)

---
 python/ray/serve/api.py                   | 28 +++++++++++---
 python/ray/serve/tests/test_standalone.py | 46 ++++++++++++++++++++++-
 python/ray/serve/utils.py                 |  6 +++
 3 files changed, 73 insertions(+), 7 deletions(-)

diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py
index f469e6b69..993a9d8f4 100644
--- a/python/ray/serve/api.py
+++ b/python/ray/serve/api.py
@@ -1,5 +1,6 @@
 import atexit
 from functools import wraps
+import random
 
 import ray
 from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
@@ -7,7 +8,7 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
 from ray.serve.controller import ServeController
 from ray.serve.handle import RayServeHandle
 from ray.serve.utils import (block_until_http_ready, format_actor_name,
-                             get_random_letters, logger)
+                             get_random_letters, logger, get_node_id_for_actor)
 from ray.serve.exceptions import RayServeException
 from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
 from ray.actor import ActorHandle
@@ -317,23 +318,38 @@ class Client:
                                                    proportion))
 
     @_ensure_connected
-    def get_handle(self, endpoint_name: str) -> RayServeHandle:
+    def get_handle(self,
+                   endpoint_name: str,
+                   missing_ok: Optional[bool] = False) -> RayServeHandle:
         """Retrieve RayServeHandle for service endpoint to invoke it from Python.
 
         Args:
             endpoint_name (str): A registered service endpoint.
+            missing_ok (bool): If true, then Serve won't check the endpoint is
+                registered. False by default.
 
         Returns:
             RayServeHandle
         """
-        if endpoint_name not in ray.get(
+        if not missing_ok and endpoint_name not in ray.get(
                 self._controller.get_all_endpoints.remote()):
             raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")
 
-        # TODO(edoakes): we should choose the router on the same node.
-        routers = ray.get(self._controller.get_routers.remote())
+        routers = list(ray.get(self._controller.get_routers.remote()).values())
+        current_node_id = ray.get_runtime_context().node_id.hex()
+
+        try:
+            router_chosen = next(
+                filter(lambda r: get_node_id_for_actor(r) == current_node_id,
+                       routers))
+        except StopIteration:
+            logger.warning(
+                f"When getting a handle for {endpoint_name}, Serve can't find "
+                "a router on the same node. Serve will use a random router.")
+            router_chosen = random.choice(routers)
+
         return RayServeHandle(
-            list(routers.values())[0],
+            router_chosen,
             endpoint_name,
         )
 
diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py
index 50f9b152f..6498169e6 100644
--- a/python/ray/serve/tests/test_standalone.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -2,6 +2,7 @@
 The test file for all standalone tests that doesn't requires a shared Serve
 instance.
 """
+from random import randint
 import sys
 import socket
 
@@ -13,7 +14,7 @@ from ray import serve
 from ray.cluster_utils import Cluster
 from ray.serve.constants import SERVE_PROXY_NAME
 from ray.serve.utils import (block_until_http_ready, get_all_node_ids,
-                             format_actor_name)
+                             format_actor_name, get_node_id_for_actor)
 from ray.test_utils import wait_for_condition
 from ray._private.services import new_port
 
@@ -128,5 +129,48 @@ def test_middleware():
     ray.shutdown()
 
 
+@pytest.mark.skipif(
+    not hasattr(socket, "SO_REUSEPORT"),
+    reason=("Port sharing only works on newer verion of Linux. "
+            "This test can only be ran when port sharing is supported."))
+def test_cluster_handle_affinity():
+    cluster = Cluster()
+    # HACK: using two different ip address so the placement constraint for
+    # resource check later will work.
+    head_node = cluster.add_node(node_ip_address="127.0.0.1", num_cpus=4)
+    cluster.add_node(node_ip_address="0.0.0.0", num_cpus=4)
+
+    ray.init(head_node.address)
+
+    # Make sure we have two nodes.
+    node_ids = [n["NodeID"] for n in ray.nodes()]
+    assert len(node_ids) == 2
+
+    # Start the backend.
+    client = serve.start(http_port=randint(10000, 30000), detached=True)
+    client.create_backend("hi:v0", lambda _: "hi")
+    client.create_endpoint("hi", backend="hi:v0")
+
+    # Try to retrieve the handle from both head and worker node, check the
+    # router's node id.
+    @ray.remote
+    def check_handle_router_id():
+        client = serve.connect()
+        handle = client.get_handle("hi")
+        return get_node_id_for_actor(handle.router_handle)
+
+    router_node_ids = ray.get([
+        check_handle_router_id.options(resources={
+            node_id: 0.01
+        }).remote() for node_id in ray.state.node_ids()
+    ])
+
+    assert set(router_node_ids) == set(node_ids)
+
+    # Clean up the nodes (otherwise Ray will segfault).
+    ray.shutdown()
+    cluster.shutdown()
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", "-s", __file__]))
diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py
index 0f2eb38f6..b227c049d 100644
--- a/python/ray/serve/utils.py
+++ b/python/ray/serve/utils.py
@@ -299,3 +299,9 @@ def get_all_node_ids():
         node_ids.append(("{}-{}".format(node_id, index), node_id))
 
     return node_ids
+
+
+def get_node_id_for_actor(actor_handle):
+    """Given an actor handle, return the node id it's placed on."""
+
+    return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"]