diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index e0a2de7a0..dc386ee11 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -405,3 +405,21 @@ backend based on a class that is installed in the Python environment that the workers will run in. Example: .. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py + +Configuring HTTP Server Locations +================================= + +By default, Ray Serve starts only one HTTP on the head node of the Ray cluster. +You can configure this behavior using the ``http_options={"location": ...}`` flag +in :mod:`serve.start `: + +- "HeadOnly": start one HTTP server on the head node. Serve + assumes the head node is the node you executed serve.start + on. This is the default. +- "EveryNode": start one HTTP server per node. +- "NoServer" or ``None``: disable HTTP server. + +.. note:: + Using the "EveryNode" option, you can point a cloud load balancer to the + instance group of Ray cluster to achieve high availability of Serve's HTTP + proxies. \ No newline at end of file diff --git a/doc/source/serve/deployment.rst b/doc/source/serve/deployment.rst index da8068213..5ab65a7a3 100644 --- a/doc/source/serve/deployment.rst +++ b/doc/source/serve/deployment.rst @@ -69,7 +69,7 @@ a backend in serve for our model (and versioned it with a string). What serve does when we run this code is store the model as a Ray actor and route traffic to it as the endpoint is queried, in this case over HTTP. Note that in order for this endpoint to be accessible from other machines, we -need to specify ``http_host="0.0.0.0"`` in :mod:`serve.start ` like we did here. +need to specify ``http_options={"host": "0.0.0.0"}`` in :mod:`serve.start ` like we did here. Now let's query our endpoint to see the result. @@ -225,7 +225,7 @@ With the cluster now running, we can run a simple script to start Ray Serve and # Connect to the running Ray cluster. ray.init(address="auto") # Bind on 0.0.0.0 to expose the HTTP server on external IPs. - client = serve.start(http_host="0.0.0.0") + client = serve.start(http_options={"host": "0.0.0.0"}) def hello(): return "hello world" diff --git a/doc/source/serve/faq.rst b/doc/source/serve/faq.rst index 451307ce3..a9d66b610 100644 --- a/doc/source/serve/faq.rst +++ b/doc/source/serve/faq.rst @@ -96,10 +96,10 @@ You can follow the same pattern for other Starlette middlewares. from starlette.middleware.cors import CORSMiddleware client = serve.start( - http_middlewares=[ + http_options={"middlewares": [ Middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) - ]) + ]}) .. _serve-handle-explainer: diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 8252c2a4b..443607607 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -105,7 +105,7 @@ py_test( py_test( name = "test_standalone", - size = "small", + size = "medium", srcs = serve_tests_srcs, tags = ["exclusive"], deps = [":serve_lib"], diff --git a/python/ray/serve/__init__.py b/python/ray/serve/__init__.py index 61aa92acc..3b76286a4 100644 --- a/python/ray/serve/__init__.py +++ b/python/ray/serve/__init__.py @@ -1,7 +1,12 @@ -from ray.serve.api import (accept_batch, Client, connect, - get_current_backend_tag, get_current_replica_tag, - start) -from ray.serve.config import BackendConfig +from ray.serve.api import ( + accept_batch, + Client, + connect, + get_current_backend_tag, + get_current_replica_tag, + start, +) +from ray.serve.config import BackendConfig, HTTPOptions from ray.serve.env import CondaEnv # Mute the warning because Serve sometimes intentionally calls @@ -18,4 +23,5 @@ __all__ = [ "get_current_backend_tag", "get_current_replica_tag", "start", + "HTTPOptions", ] diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 564a29fc5..947af0ca5 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -7,6 +7,7 @@ from uuid import UUID import threading from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union from dataclasses import dataclass +from warnings import warn import ray from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, @@ -14,10 +15,11 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, from ray.serve.controller import ServeController, BackendTag, ReplicaTag from ray.serve.handle import RayServeHandle, RayServeSyncHandle from ray.serve.utils import (block_until_http_ready, format_actor_name, - get_random_letters, logger, get_conda_env_dir) + get_random_letters, logger, get_conda_env_dir, + get_current_node_resource_key) from ray.serve.exceptions import RayServeException from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata, - HTTPConfig) + HTTPOptions) from ray.serve.env import CondaEnv from ray.serve.router import RequestMetadata, Router from ray.actor import ActorHandle @@ -522,10 +524,13 @@ class Client: return handle -def start(detached: bool = False, - http_host: Optional[str] = DEFAULT_HTTP_HOST, - http_port: int = DEFAULT_HTTP_PORT, - http_middlewares: List[Any] = []) -> Client: +def start( + detached: bool = False, + http_host: Optional[str] = DEFAULT_HTTP_HOST, + http_port: int = DEFAULT_HTTP_PORT, + http_middlewares: List[Any] = [], + http_options: Optional[Union[dict, HTTPOptions]] = None, +) -> Client: """Initialize a serve instance. By default, the instance will be scoped to the lifetime of the returned @@ -536,15 +541,44 @@ def start(detached: bool = False, Args: detached (bool): Whether not the instance should be detached from this - script. - http_host (str, optional): Host for HTTP servers to listen on. Defaults - to "127.0.0.1". To expose Serve publicly, you probably want to set - this to "0.0.0.0". One HTTP server will be started on each node in - the Ray cluster. To not start HTTP servers, set this to None. - http_port (int): Port for HTTP server. Defaults to 8000. - http_middlewares (list): A list of Starlette middlewares that will be - applied to the HTTP servers in the cluster. + script. + http_host (Optional[str]): Deprecated, use http_options instead. + http_port (int): Deprecated, use http_options instead. + http_middlewares (list): Deprecated, use http_options instead. + http_options (Optional[Dict, serve.HTTPOptions]): Configuration options + for HTTP proxy. You can pass in a dictionary or HTTPOptions object + with fields: + + - host(str, None): Host for HTTP servers to listen on. Defaults to + "127.0.0.1". To expose Serve publicly, you probably want to set + this to "0.0.0.0". + - port(int): Port for HTTP server. Defaults to 8000. + - middlewares(list): A list of Starlette middlewares that will be + applied to the HTTP servers in the cluster. + - location(str, serve.config.DeploymentMode): The deployment + location of HTTP servers: + + - "HeadOnly": start one HTTP server on the head node. Serve + assumes the head node is the node you executed serve.start + on. This is the default. + - "EveryNode": start one HTTP server per node. + - "NoServer" or None: disable HTTP server. """ + if ((http_host != DEFAULT_HTTP_HOST) or (http_port != DEFAULT_HTTP_PORT) + or (len(http_middlewares) != 0)): + if http_options is not None: + raise ValueError( + "You cannot specify both `http_options` and any of the " + "`http_host`, `http_port`, and `http_middlewares` arguments. " + "`http_options` is preferred.") + else: + warn( + "`http_host`, `http_port`, `http_middlewares` are deprecated. " + "Please use serve.start(http_options={'host': ..., " + "'port': ..., middlewares': ...}) instead.", + DeprecationWarning, + ) + # Initialize ray if needed. if not ray.is_initialized(): ray.init() @@ -564,29 +598,35 @@ def start(detached: bool = False, controller_name = format_actor_name(SERVE_CONTROLLER_NAME, get_random_letters()) + if isinstance(http_options, dict): + http_options = HTTPOptions.parse_obj(http_options) + if http_options is None: + http_options = HTTPOptions( + host=http_host, port=http_port, middlewares=http_middlewares) + controller = ServeController.options( name=controller_name, lifetime="detached" if detached else None, max_restarts=-1, max_task_retries=-1, + # Pin Serve controller on the head node. + resources={ + get_current_node_resource_key(): 0.01 + }, ).remote( controller_name, - HTTPConfig(http_host, http_port, http_middlewares), - detached=detached) + http_options, + detached=detached, + ) - if http_host is not None: - futures = [] - for node_id in ray.state.node_ids(): - future = block_until_http_ready.options( - num_cpus=0, resources={ - node_id: 0.01 - }).remote( - "http://{}:{}/-/routes".format(http_host, http_port), - timeout=HTTP_PROXY_TIMEOUT) - futures.append(future) + proxy_handles = ray.get(controller.get_http_proxies.remote()) + if len(proxy_handles) > 0: try: - ray.get(futures) - except ray.exceptions.RayTaskError: + ray.get( + [handle.ready.remote() for handle in proxy_handles.values()], + timeout=HTTP_PROXY_TIMEOUT, + ) + except ray.exceptions.GetTimeoutError: raise TimeoutError( "HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s.") diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 7d40b8430..205af81b0 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,9 +1,12 @@ import inspect +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional -from pydantic import BaseModel, PositiveInt, validator, PositiveFloat -from ray.serve.constants import ASYNC_CONCURRENCY -from typing import Optional, Dict, Any, List -from dataclasses import dataclass, field +import pydantic +from pydantic import BaseModel, PositiveFloat, PositiveInt, validator +from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST, + DEFAULT_HTTP_PORT) def _callable_accepts_batch(func_or_class): @@ -198,8 +201,26 @@ class ReplicaConfig: self.resource_dict.update(custom_resources) -@dataclass -class HTTPConfig: - host: str = field(init=True) - port: int = field(init=True) - middlewares: List[Any] = field(init=True) +class DeploymentMode(str, Enum): + NoServer = "NoServer" + HeadOnly = "HeadOnly" + EveryNode = "EveryNode" + + +class HTTPOptions(pydantic.BaseModel): + # Documentation inside serve.start for user's convenience. + host: Optional[str] = DEFAULT_HTTP_HOST + port: int = DEFAULT_HTTP_PORT + middlewares: List[Any] = [] + location: Optional[DeploymentMode] = DeploymentMode.HeadOnly + + @validator("location", always=True) + def location_backfill_no_server(cls, v, values): + if values["host"] is None or v is None: + return DeploymentMode.NoServer + return v + + class Config: + validate_assignment = True + extra = "forbid" + arbitrary_types_allowed = True diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e523221a2..c9745545e 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -1,35 +1,28 @@ import asyncio -from collections import defaultdict import os import random import time +from collections import defaultdict from dataclasses import dataclass -from typing import Dict, Any, Optional -from uuid import uuid4, UUID +from typing import Any, Dict, Optional +from uuid import UUID, uuid4 + +import ray.cloudpickle as pickle +from ray.actor import ActorHandle +from ray.serve.backend_state import BackendState +from ray.serve.backend_worker import create_backend_replica +from ray.serve.common import (BackendInfo, BackendTag, EndpointTag, GoalId, + NodeId, ReplicaTag, TrafficPolicy) +from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig) +from ray.serve.constants import LongPollKey +from ray.serve.endpoint_state import EndpointState +from ray.serve.exceptions import RayServeException +from ray.serve.http_state import HTTPState +from ray.serve.kv_store import RayInternalKVStore +from ray.serve.long_poll import LongPollHost +from ray.serve.utils import logger import ray -import ray.cloudpickle as pickle -from ray.serve.backend_worker import create_backend_replica -from ray.serve.constants import ( - LongPollKey, ) -from ray.serve.kv_store import RayInternalKVStore -from ray.serve.exceptions import RayServeException -from ray.serve.utils import logger -from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig -from ray.serve.long_poll import LongPollHost -from ray.serve.backend_state import BackendState -from ray.serve.endpoint_state import EndpointState -from ray.serve.http_state import HTTPState -from ray.serve.common import ( - BackendInfo, - BackendTag, - EndpointTag, - GoalId, - ReplicaTag, - NodeId, - TrafficPolicy, -) -from ray.actor import ActorHandle # Used for testing purposes only. If this is set, the controller will crash # after writing each checkpoint with the specified probability. @@ -84,7 +77,7 @@ class ServeController: async def __init__(self, controller_name: str, - http_config: HTTPConfig, + http_config: HTTPOptions, detached: bool = False): # Used to read/write checkpoints. self.kv_store = RayInternalKVStore(namespace=controller_name) diff --git a/python/ray/serve/examples/doc/quickstart_class.py b/python/ray/serve/examples/doc/quickstart_class.py index fbdec6c7a..f2843c2cb 100644 --- a/python/ray/serve/examples/doc/quickstart_class.py +++ b/python/ray/serve/examples/doc/quickstart_class.py @@ -2,7 +2,7 @@ import ray from ray import serve import requests -ray.init() +ray.init(num_cpus=4) client = serve.start() diff --git a/python/ray/serve/examples/doc/quickstart_function.py b/python/ray/serve/examples/doc/quickstart_function.py index 4d14dd8b0..8ccf82c28 100644 --- a/python/ray/serve/examples/doc/quickstart_function.py +++ b/python/ray/serve/examples/doc/quickstart_function.py @@ -2,7 +2,7 @@ import ray from ray import serve import requests -ray.init() +ray.init(num_cpus=4) client = serve.start() diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index b6a551023..dad0c0034 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -164,6 +164,8 @@ class HTTPProxyActor: self.host = host self.port = port + self.setup_complete = asyncio.Event() + self.app = HTTPProxy(controller_name) await self.app.setup() @@ -173,10 +175,25 @@ class HTTPProxyActor: **middleware.options) # Start running the HTTP server on the event loop. - asyncio.get_event_loop().create_task(self.run()) + # This task should be running forever. We track it in case of failure. + self.running_task = asyncio.get_event_loop().create_task(self.run()) - def ready(self): - return True + async def ready(self): + """Returns when HTTP proxy is ready to serve traffic. + Or throw exception when it is not able to serve traffic. + """ + done_set, _ = await asyncio.wait( + [ + # Either the HTTP setup has completed. + # The event is set inside self.run. + self.setup_complete.wait(), + # Or self.run errored. + self.running_task, + ], + return_when=asyncio.FIRST_COMPLETED) + + # Return None, or re-throw the exception from self.running_task. + return await done_set.pop() async def run(self): sock = socket.socket() @@ -202,4 +219,6 @@ class HTTPProxyActor: # because the existing implementation fails if it isn't running in # the main thread and uvicorn doesn't expose a way to configure it. server.install_signal_handlers = lambda: None + + self.setup_complete.set() await server.serve(sockets=[sock]) diff --git a/python/ray/serve/http_state.py b/python/ray/serve/http_state.py index 76027aea6..7e2b0cf9c 100644 --- a/python/ray/serve/http_state.py +++ b/python/ray/serve/http_state.py @@ -1,17 +1,18 @@ -from typing import Dict +from typing import Dict, List, Tuple import ray from ray.actor import ActorHandle -from ray.serve.config import HTTPConfig +from ray.serve.config import HTTPOptions, DeploymentMode from ray.serve.constants import ASYNC_CONCURRENCY, SERVE_PROXY_NAME from ray.serve.http_proxy import HTTPProxyActor -from ray.serve.utils import format_actor_name, logger, get_all_node_ids +from ray.serve.utils import (format_actor_name, logger, get_all_node_ids, + get_current_node_resource_key) from ray.serve.common import NodeId class HTTPState: def __init__(self, controller_name: str, detached: bool, - config: HTTPConfig): + config: HTTPOptions): self._controller_name = controller_name self._detached = detached self._config = config @@ -30,12 +31,25 @@ class HTTPState: self._start_proxies_if_needed() self._stop_proxies_if_needed() + def _get_target_nodes(self) -> List[Tuple[str, str]]: + """Return the list of (id, resource_key) to deploy HTTP servers on.""" + location = self._config.location + target_nodes = get_all_node_ids() + + if location == DeploymentMode.NoServer: + return [] + + if location == DeploymentMode.HeadOnly: + head_node_resource_key = get_current_node_resource_key() + target_nodes = [(node_id, node_resource) + for node_id, node_resource in target_nodes + if node_resource == head_node_resource_key][:1] + + return target_nodes + def _start_proxies_if_needed(self) -> None: """Start a proxy on every node if it doesn't already exist.""" - if self._config.host is None: - return - - for node_id, node_resource in get_all_node_ids(): + for node_id, node_resource in self._get_target_nodes(): if node_id in self._proxy_actors: continue diff --git a/python/ray/serve/scripts.py b/python/ray/serve/scripts.py index 8a0baa7c0..849eeff86 100644 --- a/python/ray/serve/scripts.py +++ b/python/ray/serve/scripts.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import click +from ray.serve.config import DeploymentMode import ray from ray import serve @@ -36,8 +37,20 @@ def cli(address): type=int, help="Port for HTTP servers to listen on. " f"Defaults to {DEFAULT_HTTP_PORT}.") -def start(http_host, http_port): - serve.start(detached=True, http_host=http_host, http_port=http_port) +@click.option( + "--http-location", + default=DeploymentMode.HeadOnly, + required=False, + type=click.Choice(list(DeploymentMode)), + help="Location of the HTTP servers. Defaults to HeadOnly.") +def start(http_host, http_port, http_location): + serve.start( + detached=True, + http_options=dict( + host=http_host, + port=http_port, + location=http_location, + )) @cli.command(help="Shutdown the running Serve instance on the Ray cluster.") diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index 6c78cfabf..f5144705b 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -15,7 +15,13 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1: @pytest.fixture(scope="session") def _shared_serve_instance(): - os.environ["SERVE_LOG_DEBUG"] = "1" # Uncomment to turn on debug log + # Note(simon): + # This line should be not turned on on master because it leads to very + # spammy and not useful log in case of a failure in CI. + # To run locally, please use this instead. + # SERVE_LOG_DEBUG=1 pytest -v -s test_api.py + # os.environ["SERVE_LOG_DEBUG"] = "1" <- Do not uncomment this. + # Overriding task_retry_delay_ms to relaunch actors more quickly ray.init( num_cpus=36, diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 3b07148c9..40942ad76 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -1,7 +1,8 @@ import pytest from ray import serve -from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata +from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions, + ReplicaConfig, BackendMetadata) from ray.serve.constants import ASYNC_CONCURRENCY from pydantic import ValidationError @@ -142,6 +143,15 @@ def test_replica_config_validation(): ReplicaConfig(Class, ray_actor_options={"max_restarts": None}) +def test_http_options(): + HTTPOptions() + HTTPOptions(host="8.8.8.8", middlewares=[object()]) + assert HTTPOptions(host=None).location == "NoServer" + assert HTTPOptions(location=None).location == "NoServer" + assert HTTPOptions( + location=DeploymentMode.EveryNode).location == "EveryNode" + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py index 686571068..ac632fbd3 100644 --- a/python/ray/serve/tests/test_standalone.py +++ b/python/ray/serve/tests/test_standalone.py @@ -55,13 +55,13 @@ def test_detached_deployment(): "This test can only be ran when port sharing is supported.")) def test_multiple_routers(): cluster = Cluster() - head_node = cluster.add_node() - cluster.add_node() + head_node = cluster.add_node(num_cpus=4) + cluster.add_node(num_cpus=4) ray.init(head_node.address) node_ids = ray.state.node_ids() assert len(node_ids) == 2 - client = serve.start(http_port=8005) # noqa: F841 + client = serve.start(http_options=dict(port=8005, location="EveryNode")) def get_proxy_names(): proxy_names = [] @@ -135,11 +135,12 @@ def test_middleware(): port = new_port() serve.start( - http_port=port, - http_middlewares=[ - Middleware( - CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) - ]) + http_options=dict( + port=port, + middlewares=[ + Middleware( + CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) + ])) ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes")) # Snatched several test cases from Starlette @@ -159,5 +160,80 @@ def test_middleware(): ray.shutdown() +def test_http_proxy_fail_loudly(): + # Test that if the http server fail to start, serve.start should fail. + with pytest.raises(socket.gaierror): + serve.start(http_options={"host": "bad.ip.address"}) + + ray.shutdown() + + +def test_no_http(): + # The following should have the same effect. + options = [ + { + "http_host": None + }, + { + "http_options": { + "host": None + } + }, + { + "http_options": { + "location": None + } + }, + { + "http_options": { + "location": "NoServer" + } + }, + ] + + ray.init() + for option in options: + client = serve.start(**option) + + # Only controller actor should exist + live_actors = [ + actor for actor in ray.actors().values() + if actor["State"] == ray.gcs_utils.ActorTableData.ALIVE + ] + assert len(live_actors) == 1 + + client.shutdown() + ray.shutdown() + + +def test_http_head_only(): + cluster = Cluster() + head_node = cluster.add_node(num_cpus=4) + cluster.add_node(num_cpus=4) + + ray.init(head_node.address) + node_ids = ray.state.node_ids() + assert len(node_ids) == 2 + + client = serve.start(http_options={ + "port": new_port(), + "location": "HeadOnly" + }) + + # Only the controller and head node actor should be started + assert len(ray.actors()) == 2 + + # They should all be placed on the head node + cpu_per_nodes = { + r["CPU"] + for r in ray.state.state._available_resources_per_node().values() + } + assert cpu_per_nodes == {2, 4} + + client.shutdown() + ray.shutdown() + cluster.shutdown() + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 4d41880fa..7244ca4db 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -425,3 +425,19 @@ def compute_dict_delta(old_dict, new_dict) -> Tuple[dict, dict, dict]: {k: new_dict[k] for k in updated_keys}, ) + + +def get_current_node_resource_key() -> str: + """Get the Ray resource key for current node. + + It can be used for actor placement. + """ + current_node_id = ray.get_runtime_context().node_id.hex() + for node in ray.nodes(): + if node["NodeID"] == current_node_id: + # Found the node. + for key in node["Resources"].keys(): + if key.startswith("node:"): + return key + else: + raise ValueError("Cannot found the node dictionary for current node.")