mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 23:16:23 +08:00
[Serve] Add benchmark for async handles (#12858)
This commit is contained in:
+13
-2
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import time
|
||||
from functools import wraps
|
||||
@@ -411,13 +412,17 @@ class Client:
|
||||
@_ensure_connected
|
||||
def get_handle(self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False) -> RayServeHandle:
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> RayServeHandle:
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
missing_ok (bool): If true, then Serve won't check the endpoint is
|
||||
registered. False by default.
|
||||
sync (bool): If true, then Serve will return a ServeHandle that
|
||||
works everywhere. Otherwise, Serve will return a ServeHandle
|
||||
that's only usable in asyncio loop.
|
||||
|
||||
Returns:
|
||||
RayServeHandle
|
||||
@@ -426,8 +431,14 @@ class Client:
|
||||
self._controller.get_all_endpoints.remote()):
|
||||
raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")
|
||||
|
||||
if asyncio.get_event_loop().is_running() and sync:
|
||||
logger.warning(
|
||||
"You are retrieving a ServeHandle inside an asyncio loop. "
|
||||
"Try getting client.get_handle(.., sync=False) to get better "
|
||||
"performance.")
|
||||
|
||||
if endpoint_name not in self._handle_cache:
|
||||
handle = RayServeHandle(self._controller, endpoint_name, sync=True)
|
||||
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
|
||||
self._handle_cache[endpoint_name] = handle
|
||||
return self._handle_cache[endpoint_name]
|
||||
|
||||
|
||||
@@ -23,64 +23,88 @@
|
||||
# 2 forwarders and 5 worker replicas: 620 requests/s
|
||||
# 2 forwarders and 10 worker replicas: 609 requests/s
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve import BackendConfig
|
||||
from ray.serve.utils import logger
|
||||
import time
|
||||
|
||||
num_queries = 2000
|
||||
num_queries = 10000
|
||||
max_concurrent_queries = 100000
|
||||
|
||||
ray.init(address="auto")
|
||||
|
||||
client = serve.start()
|
||||
|
||||
|
||||
def hello_world(_):
|
||||
def worker(_):
|
||||
return b"Hello World"
|
||||
|
||||
|
||||
class ForwardActor:
|
||||
def __init__(self):
|
||||
def __init__(self, sync: bool):
|
||||
client = serve.connect()
|
||||
self.handle = client.get_handle("hello_world")
|
||||
self.sync = sync
|
||||
self.handle = client.get_handle("worker", sync=sync)
|
||||
|
||||
async def __call__(self, _):
|
||||
await self.handle.remote()
|
||||
if self.sync:
|
||||
await self.handle.remote()
|
||||
else:
|
||||
await (await self.handle.remote_async())
|
||||
|
||||
|
||||
client.create_backend("hello_world", hello_world)
|
||||
client.create_endpoint("hello_world", backend="hello_world")
|
||||
async def run_test(num_replicas, num_forwarders, sync):
|
||||
client = serve.start()
|
||||
client.create_backend(
|
||||
"worker",
|
||||
worker,
|
||||
config=BackendConfig(
|
||||
num_replicas=num_replicas,
|
||||
max_concurrent_queries=max_concurrent_queries,
|
||||
))
|
||||
client.create_endpoint("worker", backend="worker")
|
||||
endpoint_name = "worker"
|
||||
|
||||
client.create_backend("ForwardActor", ForwardActor)
|
||||
client.create_endpoint("ForwardActor", backend="ForwardActor")
|
||||
if num_forwarders > 0:
|
||||
client.create_backend(
|
||||
"ForwardActor",
|
||||
ForwardActor,
|
||||
sync,
|
||||
config=BackendConfig(
|
||||
num_replicas=num_forwarders,
|
||||
max_concurrent_queries=max_concurrent_queries))
|
||||
client.create_endpoint("ForwardActor", backend="ForwardActor")
|
||||
endpoint_name = "ForwardActor"
|
||||
|
||||
|
||||
def run_test(num_replicas, num_forwarders):
|
||||
replicas_config = BackendConfig(num_replicas=num_replicas)
|
||||
client.update_backend_config("hello_world", replicas_config)
|
||||
|
||||
if (num_forwarders == 0):
|
||||
handle = client.get_handle("hello_world")
|
||||
else:
|
||||
forwarders_config = BackendConfig(num_replicas=num_forwarders)
|
||||
client.update_backend_config("ForwardActor", forwarders_config)
|
||||
handle = client.get_handle("ForwardActor")
|
||||
handle = client.get_handle(endpoint_name, sync=sync)
|
||||
|
||||
# warmup - helpful to wait for gc.collect() and actors to start
|
||||
start = time.time()
|
||||
while time.time() - start < 1:
|
||||
ray.get(handle.remote())
|
||||
if sync:
|
||||
ray.get(handle.remote())
|
||||
else:
|
||||
ray.get(await handle.remote_async())
|
||||
|
||||
# real test
|
||||
start = time.time()
|
||||
ray.get([handle.remote() for _ in range(num_queries)])
|
||||
if sync:
|
||||
ray.get([handle.remote() for _ in range(num_queries)])
|
||||
else:
|
||||
ray.get([(await handle.remote_async()) for _ in range(num_queries)])
|
||||
qps = num_queries / (time.time() - start)
|
||||
|
||||
logger.info("{} forwarders and {} worker replicas: {} requests/s".format(
|
||||
num_forwarders, num_replicas, int(qps)))
|
||||
print(
|
||||
f"Sync: {sync}, {num_forwarders} forwarders and {num_replicas} worker "
|
||||
f"replicas: {int(qps)} requests/s")
|
||||
client.shutdown()
|
||||
|
||||
|
||||
for num_forwarders in [0, 1, 2]:
|
||||
for num_replicas in [1, 5, 10]:
|
||||
run_test(num_replicas, num_forwarders)
|
||||
async def main():
|
||||
for sync in [True, False]:
|
||||
for num_forwarders in [0, 1, 2]:
|
||||
for num_replicas in [1, 5, 10]:
|
||||
await run_test(num_replicas, num_forwarders, sync)
|
||||
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(main())
|
||||
|
||||
@@ -7,6 +7,7 @@ import ray
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.serve.utils import get_random_letters
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
global_async_loop = None
|
||||
|
||||
@@ -109,16 +110,25 @@ class RayServeHandle:
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.args``.
|
||||
"""
|
||||
assert self.sync, "handle.remote() should be called from sync handle."
|
||||
if not self.sync:
|
||||
raise RayServeException(
|
||||
"You are trying to call handle.remote() with async handle. "
|
||||
"Please use `await handle.remote_async()` instead.")
|
||||
|
||||
coro = self._remote(request_data, kwargs)
|
||||
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
|
||||
coro, self.async_loop)
|
||||
|
||||
# Block until the result is ready.
|
||||
return future.result()
|
||||
|
||||
async def _remote_async(self, request_data, **kwargs) -> ray.ObjectRef:
|
||||
async def remote_async(self,
|
||||
request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs) -> ray.ObjectRef:
|
||||
"""Experimental API for enqueue a request in async context."""
|
||||
assert not self.sync, "_remote_async must be called inside async loop."
|
||||
if not asyncio.get_event_loop().is_running():
|
||||
raise RayServeException(
|
||||
"remote_async must be called from a running event loop.")
|
||||
return await self._remote(request_data, kwargs)
|
||||
|
||||
def options(self,
|
||||
|
||||
Reference in New Issue
Block a user