[Serve] Add benchmark for async handles (#12858)

This commit is contained in:
Simon Mo
2020-12-15 11:21:51 -08:00
committed by GitHub
parent 0031723ace
commit fdd85e3af4
3 changed files with 81 additions and 36 deletions
+13 -2
View File
@@ -1,3 +1,4 @@
import asyncio
import atexit
import time
from functools import wraps
@@ -411,13 +412,17 @@ class Client:
@_ensure_connected
def get_handle(self,
endpoint_name: str,
missing_ok: Optional[bool] = False) -> RayServeHandle:
missing_ok: Optional[bool] = False,
sync: bool = True) -> RayServeHandle:
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
endpoint_name (str): A registered service endpoint.
missing_ok (bool): If true, then Serve won't check the endpoint is
registered. False by default.
sync (bool): If true, then Serve will return a ServeHandle that
works everywhere. Otherwise, Serve will return a ServeHandle
that's only usable in asyncio loop.
Returns:
RayServeHandle
@@ -426,8 +431,14 @@ class Client:
self._controller.get_all_endpoints.remote()):
raise KeyError(f"Endpoint '{endpoint_name}' does not exist.")
if asyncio.get_event_loop().is_running() and sync:
logger.warning(
"You are retrieving a ServeHandle inside an asyncio loop. "
"Try getting client.get_handle(.., sync=False) to get better "
"performance.")
if endpoint_name not in self._handle_cache:
handle = RayServeHandle(self._controller, endpoint_name, sync=True)
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
self._handle_cache[endpoint_name] = handle
return self._handle_cache[endpoint_name]
+55 -31
View File
@@ -23,64 +23,88 @@
# 2 forwarders and 5 worker replicas: 620 requests/s
# 2 forwarders and 10 worker replicas: 609 requests/s
import asyncio
import time
import ray
from ray import serve
from ray.serve import BackendConfig
from ray.serve.utils import logger
import time
# Number of requests issued in the timed phase of every configuration.
num_queries = 10000
# Passed as ``max_concurrent_queries`` in the BackendConfig of every backend
# created by run_test().
max_concurrent_queries = 100000

ray.init(address="auto")


def worker(_):
    """Backend under test: ignore the request and return a fixed payload."""
    return b"Hello World"


class ForwardActor:
    """Backend that relays each request to the "worker" endpoint."""

    def __init__(self, sync: bool):
        """Connect to Serve and fetch a handle to the "worker" endpoint.

        Args:
            sync (bool): Forwarded to ``client.get_handle``; also selects
                which call style ``__call__`` uses.
        """
        client = serve.connect()
        self.sync = sync
        self.handle = client.get_handle("worker", sync=sync)

    async def __call__(self, _):
        """Forward one request to the worker endpoint and await the result."""
        if self.sync:
            await self.handle.remote()
        else:
            # remote_async() is itself a coroutine; awaiting it yields the
            # object that is then awaited for the reply.
            await (await self.handle.remote_async())


async def run_test(num_replicas, num_forwarders, sync):
    """Measure and print request throughput for one configuration.

    Starts Serve, creates the "worker" backend/endpoint and, when
    ``num_forwarders`` is positive, a "ForwardActor" backend/endpoint in front
    of it. Requests are sent for one second as warmup, then ``num_queries``
    requests are timed and the resulting requests/s figure is printed.

    Args:
        num_replicas (int): ``num_replicas`` for the "worker" backend.
        num_forwarders (int): ``num_replicas`` for the "ForwardActor"
            backend; 0 means requests go to "worker" directly.
        sync (bool): Passed to ``client.get_handle`` and ``ForwardActor``;
            chooses between ``handle.remote()`` and
            ``await handle.remote_async()``.
    """
    client = serve.start()
    # Shut Serve down even if this configuration fails, so the instance
    # started above is not left running.
    try:
        client.create_backend(
            "worker",
            worker,
            config=BackendConfig(
                num_replicas=num_replicas,
                max_concurrent_queries=max_concurrent_queries,
            ))
        client.create_endpoint("worker", backend="worker")
        endpoint_name = "worker"

        if num_forwarders > 0:
            client.create_backend(
                "ForwardActor",
                ForwardActor,
                sync,
                config=BackendConfig(
                    num_replicas=num_forwarders,
                    max_concurrent_queries=max_concurrent_queries))
            client.create_endpoint("ForwardActor", backend="ForwardActor")
            endpoint_name = "ForwardActor"

        handle = client.get_handle(endpoint_name, sync=sync)

        # warmup - helpful to wait for gc.collect() and actors to start
        # perf_counter() is used for intervals: unlike time.time() it is not
        # affected by system clock adjustments.
        start = time.perf_counter()
        while time.perf_counter() - start < 1:
            if sync:
                ray.get(handle.remote())
            else:
                ray.get(await handle.remote_async())

        # real test
        start = time.perf_counter()
        if sync:
            ray.get([handle.remote() for _ in range(num_queries)])
        else:
            ray.get(
                [(await handle.remote_async()) for _ in range(num_queries)])
        qps = num_queries / (time.perf_counter() - start)

        print(
            f"Sync: {sync}, {num_forwarders} forwarders and {num_replicas} worker "
            f"replicas: {int(qps)} requests/s")
    finally:
        client.shutdown()


async def main():
    """Run run_test() for every sync/forwarder/replica combination in turn."""
    for sync in [True, False]:
        for num_forwarders in [0, 1, 2]:
            for num_replicas in [1, 5, 10]:
                await run_test(num_replicas, num_forwarders, sync)


asyncio.get_event_loop().run_until_complete(main())
+13 -3
View File
@@ -7,6 +7,7 @@ import ray
from ray.serve.context import TaskContext
from ray.serve.router import RequestMetadata, Router
from ray.serve.utils import get_random_letters
from ray.serve.exceptions import RayServeException
global_async_loop = None
@@ -109,16 +110,25 @@ class RayServeHandle:
``**kwargs``: All keyword arguments will be available in
``request.args``.
"""
assert self.sync, "handle.remote() should be called from sync handle."
if not self.sync:
raise RayServeException(
"You are trying to call handle.remote() with async handle. "
"Please use `await handle.remote_async()` instead.")
coro = self._remote(request_data, kwargs)
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
coro, self.async_loop)
# Block until the result is ready.
return future.result()
async def remote_async(self,
                       request_data: Optional[Union[Dict, Any]] = None,
                       **kwargs) -> ray.ObjectRef:
    """Experimental API for enqueuing a request from an async context.

    Args:
        request_data (dict, Any): Forwarded to ``self._remote`` as its
            first argument.
        ``**kwargs``: Collected into one dict and forwarded to
            ``self._remote`` as its second argument.

    Returns:
        The awaited result of ``self._remote`` (annotated as
        ``ray.ObjectRef``).

    Raises:
        RayServeException: If the current event loop is not running.
    """
    if not asyncio.get_event_loop().is_running():
        raise RayServeException(
            "remote_async must be called from a running event loop.")
    return await self._remote(request_data, kwargs)
def options(self,