Rate limit asyncio actor (#6242)

This commit is contained in:
Simon Mo
2019-11-24 11:39:28 -08:00
committed by GitHub
parent 9f0d005ce6
commit aa8d5d2f6c
8 changed files with 163 additions and 18 deletions
+27 -5
View File
@@ -130,11 +130,13 @@ MEMCOPY_THREADS = 12
PY3 = cpython.PY_MAJOR_VERSION >= 3
if cpython.PY_MAJOR_VERSION >= 3:
if PY3:
import pickle
else:
import cPickle as pickle
if PY3:
from ray.async_compat import sync_to_async
cdef int check_status(const CRayStatus& status) nogil except -1:
if status.ok():
@@ -557,10 +559,17 @@ cdef execute_task(
def function_executor(*arguments, **kwarguments):
function = execution_info.function
result_or_coroutine = function(actor, *arguments, **kwarguments)
if PY3 and inspect.iscoroutine(result_or_coroutine):
coroutine = result_or_coroutine
if PY3 and core_worker.current_actor_is_asyncio():
if inspect.iscoroutinefunction(function.method):
async_function = function
else:
# Just execute the method if it's ray internal method.
if function.name.startswith("__ray"):
return function(actor, *arguments, **kwarguments)
async_function = sync_to_async(function)
coroutine = async_function(actor, *arguments, **kwarguments)
loop = core_worker.create_or_get_event_loop()
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
@@ -573,7 +582,7 @@ cdef execute_task(
return future.result()
return result_or_coroutine
return function(actor, *arguments, **kwarguments)
with core_worker.profile_event(b"task", extra_data=extra_data):
try:
@@ -749,6 +758,7 @@ cdef class CoreWorker:
task_execution_handler, check_signals, exit_handler, True))
def disconnect(self):
self.destory_event_loop_if_exists()
with nogil:
self.core_worker.get().Disconnect()
@@ -1099,6 +1109,18 @@ cdef class CoreWorker:
self.async_thread = threading.Thread(
target=lambda: self.async_event_loop.run_forever()
)
# Making the thread a daemon causes it to exit
# when the main thread exits.
self.async_thread.daemon = True
self.async_thread.start()
return self.async_event_loop
def destory_event_loop_if_exists(self):
    """Stop the asyncio event loop and wait for its thread, if they exist.

    The event loop is driven by ``run_forever()`` on ``self.async_thread``,
    while this method is invoked from another thread (it is called from
    ``disconnect``). Does nothing for whichever of the loop / thread is
    still ``None``.
    """
    loop = self.async_event_loop
    # Skip a loop that has already been closed: scheduling a callback on it
    # would raise RuntimeError, and there is nothing left to stop.
    if loop is not None and not loop.is_closed():
        # loop.stop() is not thread-safe. Called directly from a foreign
        # thread it only sets a flag and does not wake the selector, so an
        # idle loop would keep blocking and the join() below would never
        # return. call_soon_threadsafe() both schedules the stop on the
        # loop's own thread and wakes the loop up.
        loop.call_soon_threadsafe(loop.stop)
    if self.async_thread is not None:
        self.async_thread.join()
# Forwards to CurrentActorIsAsync() on the C++ worker context and returns its
# result. The task executor consults this to decide whether an actor call
# should be scheduled on the asyncio event loop.
def current_actor_is_asyncio(self):
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
+8 -4
View File
@@ -378,7 +378,10 @@ class ActorClass(object):
task.
is_direct_call: Use direct actor calls.
max_concurrency: The max number of concurrent calls to allow for
this actor. This only works with direct actor calls.
this actor. This only works with direct actor calls. The max
concurrency defaults to 1 for threaded execution, and 100 for
asyncio execution. Note that the execution order is not
guaranteed when max_concurrency > 1.
name: The globally unique name for the actor.
detached: Whether the actor should be kept alive after driver
exits.
@@ -395,7 +398,10 @@ class ActorClass(object):
if is_direct_call is None:
is_direct_call = bool(os.environ.get("RAY_FORCE_DIRECT"))
if max_concurrency is None:
max_concurrency = 1
if is_asyncio:
max_concurrency = 100
else:
max_concurrency = 1
if max_concurrency > 1 and not is_direct_call:
raise ValueError(
@@ -406,8 +412,6 @@ class ActorClass(object):
if is_asyncio and not is_direct_call:
raise ValueError(
"Setting is_asyncio requires is_direct_call=True.")
if is_asyncio and max_concurrency != 1:
raise ValueError("Setting is_asyncio requires max_concurrency=1.")
worker = ray.worker.get_global_worker()
if worker.mode is None:
+13
View File
@@ -0,0 +1,13 @@
"""
This file should only be imported from Python 3.
It will raise SyntaxError when importing from Python 2.
"""
import functools


def sync_to_async(func):
    """Wrap a blocking callable in a coroutine function.

    Args:
        func: Any callable. It is invoked with exactly the positional and
            keyword arguments given to the returned coroutine function.

    Returns:
        An ``async def`` function whose coroutine, when awaited, calls
        ``func`` and returns its result. The call is made directly inside
        the coroutine (no executor is used), so ``func`` still blocks the
        event loop for as long as it runs.
    """

    # Carry over __name__, __doc__ and friends so the wrapper is still
    # identifiable as the original callable in tracebacks and logs.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper
+8
View File
@@ -761,6 +761,14 @@ class FunctionActorManager(object):
self._save_and_log_checkpoint(actor)
return method_returns
# Set method_name and method as attributes on the executor closure
# so we can make decisions based on these attributes in the task executor.
# Specifically, asyncio support needs to know whether:
# - the method is a ray internal method: starts with __ray
# - the method is a coroutine function: defined by async def
actor_method_executor.name = method_name
actor_method_executor.method = method
return actor_method_executor
def _save_and_log_checkpoint(self, actor):
+5
View File
@@ -54,6 +54,10 @@ cdef extern from "ray/core_worker/transport/direct_actor_transport.h" nogil:
void Wait()
void Notify()
cdef extern from "ray/core_worker/context.h" nogil:
cdef cppclass CWorkerContext "ray::WorkerContext":
c_bool CurrentActorIsAsync()
cdef extern from "ray/core_worker/core_worker.h" nogil:
cdef cppclass CCoreWorker "ray::CoreWorker":
CCoreWorker(const CWorkerType worker_type, const CLanguage language,
@@ -132,4 +136,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool local_only, c_bool delete_creating_tasks)
c_string MemoryUsageString()
CWorkerContext &GetWorkerContext()
void YieldCurrentFiber(CFiberEvent &coroutine_done)
+50 -5
View File
@@ -4,6 +4,7 @@ from __future__ import division
from __future__ import print_function
import asyncio
import threading
import pytest
import ray
@@ -102,13 +103,9 @@ def test_asyncio_actor(ray_start_regular):
class AsyncBatcher(object):
def __init__(self):
self.batch = []
# The event currently need to be created from the same thread.
# We currently run async coroutines from a different thread.
self.event = None
self.event = asyncio.Event()
async def add(self, x):
if self.event is None:
self.event = asyncio.Event()
self.batch.append(x)
if len(self.batch) >= 3:
self.event.set()
@@ -125,3 +122,51 @@ def test_asyncio_actor(ray_start_regular):
r3 = ray.get(x3)
assert r1 == [1, 2, 3]
assert r1 == r2 == r3
def test_asyncio_actor_same_thread(ray_start_regular):
    """Sync and async methods of an asyncio actor report the same thread id."""

    @ray.remote
    class Actor:
        def sync_thread_id(self):
            return threading.current_thread().ident

        async def async_thread_id(self):
            return threading.current_thread().ident

    actor = Actor.options(is_direct_call=True, is_asyncio=True).remote()

    # Submit one plain method call and one coroutine call, then compare the
    # identifiers of the threads that served them.
    pending = [actor.sync_thread_id.remote(), actor.async_thread_id.remote()]
    sync_id, async_id = ray.get(pending)
    assert sync_id == async_id
def test_asyncio_actor_concurrency(ray_start_regular):
    """With max_concurrency=1, coroutine calls on an asyncio actor do not
    interleave: every STARTED is immediately followed by its ENDED."""

    @ray.remote
    class RecordOrder:
        def __init__(self):
            self.history = []

        async def do_work(self):
            self.history.append("STARTED")
            # Force a context switch
            await asyncio.sleep(0)
            self.history.append("ENDED")

        def get_history(self):
            return self.history

    num_calls = 10
    options = dict(is_direct_call=True, max_concurrency=1, is_asyncio=True)
    recorder = RecordOrder.options(**options).remote()

    ray.get([recorder.do_work.remote() for _ in range(num_calls)])
    history = ray.get(recorder.get_history.remote())

    # Only the strict start-end-start-end pattern is checked, because
    # coroutines may be executed out of enqueued order.
    answer = ["STARTED", "ENDED"] * num_calls
    assert history == answer