[Serve] HTTPOptions for deployment modes (#13142)

This commit is contained in:
Simon Mo
2021-01-05 16:41:52 -08:00
committed by GitHub
parent bd19ed31e7
commit 39813ff6b0
17 changed files with 329 additions and 97 deletions
+18
View File
@@ -405,3 +405,21 @@ backend based on a class that is installed in the Python environment that
the workers will run in. Example:
.. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py
Configuring HTTP Server Locations
=================================
By default, Ray Serve starts only one HTTP server, on the head node of the Ray cluster.
You can configure this behavior using the ``http_options={"location": ...}`` flag
in :mod:`serve.start <ray.serve.start>`:
- "HeadOnly": start one HTTP server on the head node. Serve
assumes the head node is the node you executed serve.start
on. This is the default.
- "EveryNode": start one HTTP server per node.
- "NoServer" or ``None``: disable HTTP server.
.. note::
Using the "EveryNode" option, you can point a cloud load balancer to the
instance group of the Ray cluster to achieve high availability of Serve's
HTTP proxies.
+2 -2
View File
@@ -69,7 +69,7 @@ a backend in serve for our model (and versioned it with a string).
What serve does when we run this code is store the model as a Ray actor
and route traffic to it as the endpoint is queried, in this case over HTTP.
Note that in order for this endpoint to be accessible from other machines, we
need to specify ``http_host="0.0.0.0"`` in :mod:`serve.start <ray.serve.start>` like we did here.
need to specify ``http_options={"host": "0.0.0.0"}`` in :mod:`serve.start <ray.serve.start>` like we did here.
Now let's query our endpoint to see the result.
@@ -225,7 +225,7 @@ With the cluster now running, we can run a simple script to start Ray Serve and
# Connect to the running Ray cluster.
ray.init(address="auto")
# Bind on 0.0.0.0 to expose the HTTP server on external IPs.
client = serve.start(http_host="0.0.0.0")
client = serve.start(http_options={"host": "0.0.0.0"})
def hello():
return "hello world"
+2 -2
View File
@@ -96,10 +96,10 @@ You can follow the same pattern for other Starlette middlewares.
from starlette.middleware.cors import CORSMiddleware
client = serve.start(
http_middlewares=[
http_options={"middlewares": [
Middleware(
CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
])
]})
.. _serve-handle-explainer:
+1 -1
View File
@@ -105,7 +105,7 @@ py_test(
py_test(
name = "test_standalone",
size = "small",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
+10 -4
View File
@@ -1,7 +1,12 @@
from ray.serve.api import (accept_batch, Client, connect,
get_current_backend_tag, get_current_replica_tag,
start)
from ray.serve.config import BackendConfig
from ray.serve.api import (
accept_batch,
Client,
connect,
get_current_backend_tag,
get_current_replica_tag,
start,
)
from ray.serve.config import BackendConfig, HTTPOptions
from ray.serve.env import CondaEnv
# Mute the warning because Serve sometimes intentionally calls
@@ -18,4 +23,5 @@ __all__ = [
"get_current_backend_tag",
"get_current_replica_tag",
"start",
"HTTPOptions",
]
+68 -28
View File
@@ -7,6 +7,7 @@ from uuid import UUID
import threading
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from dataclasses import dataclass
from warnings import warn
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
@@ -14,10 +15,11 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
from ray.serve.controller import ServeController, BackendTag, ReplicaTag
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters, logger, get_conda_env_dir)
get_random_letters, logger, get_conda_env_dir,
get_current_node_resource_key)
from ray.serve.exceptions import RayServeException
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
HTTPConfig)
HTTPOptions)
from ray.serve.env import CondaEnv
from ray.serve.router import RequestMetadata, Router
from ray.actor import ActorHandle
@@ -522,10 +524,13 @@ class Client:
return handle
def start(detached: bool = False,
http_host: Optional[str] = DEFAULT_HTTP_HOST,
http_port: int = DEFAULT_HTTP_PORT,
http_middlewares: List[Any] = []) -> Client:
def start(
detached: bool = False,
http_host: Optional[str] = DEFAULT_HTTP_HOST,
http_port: int = DEFAULT_HTTP_PORT,
http_middlewares: List[Any] = [],
http_options: Optional[Union[dict, HTTPOptions]] = None,
) -> Client:
"""Initialize a serve instance.
By default, the instance will be scoped to the lifetime of the returned
@@ -536,15 +541,44 @@ def start(detached: bool = False,
Args:
detached (bool): Whether or not the instance should be detached from this
script.
http_host (str, optional): Host for HTTP servers to listen on. Defaults
to "127.0.0.1". To expose Serve publicly, you probably want to set
this to "0.0.0.0". One HTTP server will be started on each node in
the Ray cluster. To not start HTTP servers, set this to None.
http_port (int): Port for HTTP server. Defaults to 8000.
http_middlewares (list): A list of Starlette middlewares that will be
applied to the HTTP servers in the cluster.
script.
http_host (Optional[str]): Deprecated, use http_options instead.
http_port (int): Deprecated, use http_options instead.
http_middlewares (list): Deprecated, use http_options instead.
http_options (Optional[Dict, serve.HTTPOptions]): Configuration options
for HTTP proxy. You can pass in a dictionary or HTTPOptions object
with fields:
- host(str, None): Host for HTTP servers to listen on. Defaults to
"127.0.0.1". To expose Serve publicly, you probably want to set
this to "0.0.0.0".
- port(int): Port for HTTP server. Defaults to 8000.
- middlewares(list): A list of Starlette middlewares that will be
applied to the HTTP servers in the cluster.
- location(str, serve.config.DeploymentMode): The deployment
location of HTTP servers:
- "HeadOnly": start one HTTP server on the head node. Serve
assumes the head node is the node you executed serve.start
on. This is the default.
- "EveryNode": start one HTTP server per node.
- "NoServer" or None: disable HTTP server.
"""
# The legacy keyword arguments and `http_options` are mutually exclusive:
# mixing them is an error, while using only the legacy ones still works but
# emits a deprecation warning.
legacy_http_args_used = ((http_host != DEFAULT_HTTP_HOST)
                         or (http_port != DEFAULT_HTTP_PORT)
                         or (len(http_middlewares) != 0))
if legacy_http_args_used:
    if http_options is not None:
        raise ValueError(
            "You cannot specify both `http_options` and any of the "
            "`http_host`, `http_port`, and `http_middlewares` arguments. "
            "`http_options` is preferred.")
    # stacklevel=2 attributes the warning to the caller of serve.start.
    # Python's default filters only display DeprecationWarning when it is
    # attributed to __main__, so without this the warning would be silently
    # dropped for most users.
    warn(
        "`http_host`, `http_port`, `http_middlewares` are deprecated. "
        "Please use serve.start(http_options={'host': ..., "
        "'port': ..., 'middlewares': ...}) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
# Initialize ray if needed.
if not ray.is_initialized():
ray.init()
@@ -564,29 +598,35 @@ def start(detached: bool = False,
controller_name = format_actor_name(SERVE_CONTROLLER_NAME,
get_random_letters())
if isinstance(http_options, dict):
http_options = HTTPOptions.parse_obj(http_options)
if http_options is None:
http_options = HTTPOptions(
host=http_host, port=http_port, middlewares=http_middlewares)
controller = ServeController.options(
name=controller_name,
lifetime="detached" if detached else None,
max_restarts=-1,
max_task_retries=-1,
# Pin Serve controller on the head node.
resources={
get_current_node_resource_key(): 0.01
},
).remote(
controller_name,
HTTPConfig(http_host, http_port, http_middlewares),
detached=detached)
http_options,
detached=detached,
)
if http_host is not None:
futures = []
for node_id in ray.state.node_ids():
future = block_until_http_ready.options(
num_cpus=0, resources={
node_id: 0.01
}).remote(
"http://{}:{}/-/routes".format(http_host, http_port),
timeout=HTTP_PROXY_TIMEOUT)
futures.append(future)
proxy_handles = ray.get(controller.get_http_proxies.remote())
if len(proxy_handles) > 0:
try:
ray.get(futures)
except ray.exceptions.RayTaskError:
ray.get(
[handle.ready.remote() for handle in proxy_handles.values()],
timeout=HTTP_PROXY_TIMEOUT,
)
except ray.exceptions.GetTimeoutError:
# Use an f-string so the configured timeout is interpolated into the
# message instead of emitting the literal "{HTTP_PROXY_TIMEOUT}" placeholder.
raise TimeoutError(
    f"HTTP proxies not available after {HTTP_PROXY_TIMEOUT}s.")
+30 -9
View File
@@ -1,9 +1,12 @@
import inspect
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, PositiveInt, validator, PositiveFloat
from ray.serve.constants import ASYNC_CONCURRENCY
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
import pydantic
from pydantic import BaseModel, PositiveFloat, PositiveInt, validator
from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST,
DEFAULT_HTTP_PORT)
def _callable_accepts_batch(func_or_class):
@@ -198,8 +201,26 @@ class ReplicaConfig:
self.resource_dict.update(custom_resources)
@dataclass
class HTTPConfig:
host: str = field(init=True)
port: int = field(init=True)
middlewares: List[Any] = field(init=True)
class DeploymentMode(str, Enum):
    """Where Serve's HTTP servers are deployed in the cluster.

    Mixing in ``str`` lets members compare equal to their plain string
    values, so callers may pass either the enum member or its name.
    """
    # Disable the HTTP server.
    NoServer = "NoServer"
    # Start one HTTP server on the head node.
    HeadOnly = "HeadOnly"
    # Start one HTTP server per node.
    EveryNode = "EveryNode"
class HTTPOptions(pydantic.BaseModel):
    """Validated configuration for Serve's HTTP servers.

    User-facing documentation of the fields lives in ``serve.start`` for
    the user's convenience.
    """
    # Host for HTTP servers to listen on. None disables the HTTP server
    # (see the `location` validator below).
    host: Optional[str] = DEFAULT_HTTP_HOST
    # Port for the HTTP server.
    port: int = DEFAULT_HTTP_PORT
    # Starlette middlewares applied to the HTTP servers.
    middlewares: List[Any] = []
    # Deployment location of the HTTP servers; None is normalized to
    # DeploymentMode.NoServer by the validator below.
    location: Optional[DeploymentMode] = DeploymentMode.HeadOnly

    @validator("location", always=True)
    def location_backfill_no_server(cls, v, values):
        """Normalize "no HTTP server" requests to DeploymentMode.NoServer.

        Either ``host=None`` or ``location=None`` disables the server.
        """
        # pydantic omits fields that failed their own validation from
        # `values`. Indexing values["host"] directly would then raise a bare
        # KeyError and mask the real ValidationError reported for `host`.
        host_disabled = "host" in values and values["host"] is None
        if host_disabled or v is None:
            return DeploymentMode.NoServer
        return v

    class Config:
        # Re-run validation when attributes are assigned after construction.
        validate_assignment = True
        # Reject unknown option names instead of silently ignoring typos.
        extra = "forbid"
        # Middlewares are arbitrary (non-pydantic) objects.
        arbitrary_types_allowed = True
+19 -26
View File
@@ -1,35 +1,28 @@
import asyncio
from collections import defaultdict
import os
import random
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Any, Optional
from uuid import uuid4, UUID
from typing import Any, Dict, Optional
from uuid import UUID, uuid4
import ray.cloudpickle as pickle
from ray.actor import ActorHandle
from ray.serve.backend_state import BackendState
from ray.serve.backend_worker import create_backend_replica
from ray.serve.common import (BackendInfo, BackendTag, EndpointTag, GoalId,
NodeId, ReplicaTag, TrafficPolicy)
from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig)
from ray.serve.constants import LongPollKey
from ray.serve.endpoint_state import EndpointState
from ray.serve.exceptions import RayServeException
from ray.serve.http_state import HTTPState
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.long_poll import LongPollHost
from ray.serve.utils import logger
import ray
import ray.cloudpickle as pickle
from ray.serve.backend_worker import create_backend_replica
from ray.serve.constants import (
LongPollKey, )
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.exceptions import RayServeException
from ray.serve.utils import logger
from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig
from ray.serve.long_poll import LongPollHost
from ray.serve.backend_state import BackendState
from ray.serve.endpoint_state import EndpointState
from ray.serve.http_state import HTTPState
from ray.serve.common import (
BackendInfo,
BackendTag,
EndpointTag,
GoalId,
ReplicaTag,
NodeId,
TrafficPolicy,
)
from ray.actor import ActorHandle
# Used for testing purposes only. If this is set, the controller will crash
# after writing each checkpoint with the specified probability.
@@ -84,7 +77,7 @@ class ServeController:
async def __init__(self,
controller_name: str,
http_config: HTTPConfig,
http_config: HTTPOptions,
detached: bool = False):
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name)
@@ -2,7 +2,7 @@ import ray
from ray import serve
import requests
ray.init()
ray.init(num_cpus=4)
client = serve.start()
@@ -2,7 +2,7 @@ import ray
from ray import serve
import requests
ray.init()
ray.init(num_cpus=4)
client = serve.start()
+22 -3
View File
@@ -164,6 +164,8 @@ class HTTPProxyActor:
self.host = host
self.port = port
self.setup_complete = asyncio.Event()
self.app = HTTPProxy(controller_name)
await self.app.setup()
@@ -173,10 +175,25 @@ class HTTPProxyActor:
**middleware.options)
# Start running the HTTP server on the event loop.
asyncio.get_event_loop().create_task(self.run())
# This task should be running forever. We track it in case of failure.
self.running_task = asyncio.get_event_loop().create_task(self.run())
def ready(self):
return True
async def ready(self):
    """Return once the HTTP proxy is ready to serve traffic.

    Waits for whichever happens first: ``self.setup_complete`` being set
    (done inside ``self.run``), or ``self.running_task`` finishing.

    Returns:
        The result of ``self.setup_complete.wait()`` (``True``) when setup
        completed, or the result of ``self.running_task`` if that task
        finished first without error.

    Raises:
        Whatever exception ``self.running_task`` raised, if it failed
        before setup completed.
    """
    # Wrap the waiter coroutine in a task explicitly: passing bare
    # coroutines to asyncio.wait is deprecated since Python 3.8 and
    # rejected on Python 3.11+.
    setup_task = asyncio.get_event_loop().create_task(
        self.setup_complete.wait())
    done_set, pending_set = await asyncio.wait(
        [setup_task, self.running_task],
        return_when=asyncio.FIRST_COMPLETED)

    # The server task is expected to keep running, so only the helper
    # waiter is cancelled when it is the one left pending.
    if setup_task in pending_set:
        setup_task.cancel()

    # Check the server task first so that a startup failure is always
    # surfaced to the caller.
    if self.running_task in done_set:
        return await self.running_task
    return await setup_task
async def run(self):
sock = socket.socket()
@@ -202,4 +219,6 @@ class HTTPProxyActor:
# because the existing implementation fails if it isn't running in
# the main thread and uvicorn doesn't expose a way to configure it.
server.install_signal_handlers = lambda: None
self.setup_complete.set()
await server.serve(sockets=[sock])
+22 -8
View File
@@ -1,17 +1,18 @@
from typing import Dict
from typing import Dict, List, Tuple
import ray
from ray.actor import ActorHandle
from ray.serve.config import HTTPConfig
from ray.serve.config import HTTPOptions, DeploymentMode
from ray.serve.constants import ASYNC_CONCURRENCY, SERVE_PROXY_NAME
from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.utils import format_actor_name, logger, get_all_node_ids
from ray.serve.utils import (format_actor_name, logger, get_all_node_ids,
get_current_node_resource_key)
from ray.serve.common import NodeId
class HTTPState:
def __init__(self, controller_name: str, detached: bool,
config: HTTPConfig):
config: HTTPOptions):
self._controller_name = controller_name
self._detached = detached
self._config = config
@@ -30,12 +31,25 @@ class HTTPState:
self._start_proxies_if_needed()
self._stop_proxies_if_needed()
def _get_target_nodes(self) -> List[Tuple[str, str]]:
    """Return the list of (id, resource_key) to deploy HTTP servers on.

    The result depends on ``self._config.location``:
    - NoServer: no nodes at all.
    - HeadOnly: at most one node, the one whose resource key matches
      ``get_current_node_resource_key()``.
    - otherwise: every node reported by ``get_all_node_ids()``.
    """
    location = self._config.location

    # Bail out before querying the node list when no server is wanted.
    if location == DeploymentMode.NoServer:
        return []

    target_nodes = get_all_node_ids()
    if location == DeploymentMode.HeadOnly:
        # NOTE(review): this treats the node running this code as the head
        # node; presumably that holds because the controller is pinned
        # there. Confirm against the controller placement.
        head_node_resource_key = get_current_node_resource_key()
        target_nodes = [(node_id, node_resource)
                        for node_id, node_resource in target_nodes
                        if node_resource == head_node_resource_key][:1]
    return target_nodes
def _start_proxies_if_needed(self) -> None:
"""Start a proxy on every node if it doesn't already exist."""
if self._config.host is None:
return
for node_id, node_resource in get_all_node_ids():
for node_id, node_resource in self._get_target_nodes():
if node_id in self._proxy_actors:
continue
+15 -2
View File
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import click
from ray.serve.config import DeploymentMode
import ray
from ray import serve
@@ -36,8 +37,20 @@ def cli(address):
type=int,
help="Port for HTTP servers to listen on. "
f"Defaults to {DEFAULT_HTTP_PORT}.")
def start(http_host, http_port):
serve.start(detached=True, http_host=http_host, http_port=http_port)
@click.option(
"--http-location",
default=DeploymentMode.HeadOnly,
required=False,
type=click.Choice(list(DeploymentMode)),
help="Location of the HTTP servers. Defaults to HeadOnly.")
def start(http_host, http_port, http_location):
serve.start(
detached=True,
http_options=dict(
host=http_host,
port=http_port,
location=http_location,
))
@cli.command(help="Shutdown the running Serve instance on the Ray cluster.")
+7 -1
View File
@@ -15,7 +15,13 @@ if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
@pytest.fixture(scope="session")
def _shared_serve_instance():
os.environ["SERVE_LOG_DEBUG"] = "1" # Uncomment to turn on debug log
# Note(simon):
# This line should not be enabled on master because it produces very
# spammy, unhelpful logs when a test fails in CI.
# To run locally, please use this instead.
# SERVE_LOG_DEBUG=1 pytest -v -s test_api.py
# os.environ["SERVE_LOG_DEBUG"] = "1" <- Do not uncomment this.
# Overriding task_retry_delay_ms to relaunch actors more quickly
ray.init(
num_cpus=36,
+11 -1
View File
@@ -1,7 +1,8 @@
import pytest
from ray import serve
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.serve.config import (BackendConfig, DeploymentMode, HTTPOptions,
ReplicaConfig, BackendMetadata)
from ray.serve.constants import ASYNC_CONCURRENCY
from pydantic import ValidationError
@@ -142,6 +143,15 @@ def test_replica_config_validation():
ReplicaConfig(Class, ray_actor_options={"max_restarts": None})
def test_http_options():
    """HTTPOptions accepts defaults and normalizes "no server" settings."""
    # Construction with defaults and with arbitrary middleware objects
    # must not raise.
    HTTPOptions()
    HTTPOptions(host="8.8.8.8", middlewares=[object()])

    # Disabling either the host or the location yields NoServer.
    without_host = HTTPOptions(host=None)
    assert without_host.location == "NoServer"
    without_location = HTTPOptions(location=None)
    assert without_location.location == "NoServer"

    # An explicit location is preserved.
    on_every_node = HTTPOptions(location=DeploymentMode.EveryNode)
    assert on_every_node.location == "EveryNode"
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))
+84 -8
View File
@@ -55,13 +55,13 @@ def test_detached_deployment():
"This test can only be ran when port sharing is supported."))
def test_multiple_routers():
cluster = Cluster()
head_node = cluster.add_node()
cluster.add_node()
head_node = cluster.add_node(num_cpus=4)
cluster.add_node(num_cpus=4)
ray.init(head_node.address)
node_ids = ray.state.node_ids()
assert len(node_ids) == 2
client = serve.start(http_port=8005) # noqa: F841
client = serve.start(http_options=dict(port=8005, location="EveryNode"))
def get_proxy_names():
proxy_names = []
@@ -135,11 +135,12 @@ def test_middleware():
port = new_port()
serve.start(
http_port=port,
http_middlewares=[
Middleware(
CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
])
http_options=dict(
port=port,
middlewares=[
Middleware(
CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
]))
ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes"))
# Snatched several test cases from Starlette
@@ -159,5 +160,80 @@ def test_middleware():
ray.shutdown()
def test_http_proxy_fail_loudly():
    """serve.start must fail if the HTTP server cannot start."""
    unresolvable_host_options = {"host": "bad.ip.address"}
    with pytest.raises(socket.gaierror):
        serve.start(http_options=unresolvable_host_options)
    ray.shutdown()
def test_no_http():
    """Every way of disabling HTTP leaves only the controller running."""
    # The following should have the same effect.
    equivalent_start_kwargs = [
        {"http_host": None},
        {"http_options": {"host": None}},
        {"http_options": {"location": None}},
        {"http_options": {"location": "NoServer"}},
    ]

    ray.init()
    for start_kwargs in equivalent_start_kwargs:
        client = serve.start(**start_kwargs)

        # Only controller actor should exist
        live_actors = []
        for actor in ray.actors().values():
            if actor["State"] == ray.gcs_utils.ActorTableData.ALIVE:
                live_actors.append(actor)
        assert len(live_actors) == 1

        client.shutdown()
    ray.shutdown()
def test_http_head_only():
    """With "HeadOnly", one HTTP proxy is started alongside the controller."""
    cluster = Cluster()
    head = cluster.add_node(num_cpus=4)
    cluster.add_node(num_cpus=4)

    ray.init(head.address)
    assert len(ray.state.node_ids()) == 2

    head_only_options = {"port": new_port(), "location": "HeadOnly"}
    client = serve.start(http_options=head_only_options)

    # Only the controller and head node actor should be started
    assert len(ray.actors()) == 2

    # They should all be placed on the head node
    # (one node keeps all 4 CPUs available, the other is down to 2).
    available_cpus = set()
    for resources in ray.state.state._available_resources_per_node().values():
        available_cpus.add(resources["CPU"])
    assert available_cpus == {2, 4}

    client.shutdown()
    ray.shutdown()
    cluster.shutdown()
if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
+16
View File
@@ -425,3 +425,19 @@ def compute_dict_delta(old_dict, new_dict) -> Tuple[dict, dict, dict]:
{k: new_dict[k]
for k in updated_keys},
)
def get_current_node_resource_key() -> str:
    """Get the Ray resource key for current node.

    It can be used for actor placement.

    Returns:
        The first resource name starting with ``"node:"`` on the node this
        code is running on.

    Raises:
        ValueError: If the current node is not listed by ``ray.nodes()``,
            or if it is listed but exposes no ``"node:"`` resource.
    """
    current_node_id = ray.get_runtime_context().node_id.hex()
    for node in ray.nodes():
        if node["NodeID"] != current_node_id:
            continue
        # Found the node.
        for key in node["Resources"].keys():
            if key.startswith("node:"):
                return key
        # The node exists but cannot be targeted; report that precisely
        # instead of claiming the node itself was not found.
        raise ValueError(
            f"Node {current_node_id} has no 'node:' resource key.")
    raise ValueError("Cannot find the node dictionary for current node.")