mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
Rate limit asyncio actor (#6242)
This commit is contained in:
+27
-5
@@ -130,11 +130,13 @@ MEMCOPY_THREADS = 12
|
||||
PY3 = cpython.PY_MAJOR_VERSION >= 3
|
||||
|
||||
|
||||
if cpython.PY_MAJOR_VERSION >= 3:
|
||||
if PY3:
|
||||
import pickle
|
||||
else:
|
||||
import cPickle as pickle
|
||||
|
||||
if PY3:
|
||||
from ray.async_compat import sync_to_async
|
||||
|
||||
cdef int check_status(const CRayStatus& status) nogil except -1:
|
||||
if status.ok():
|
||||
@@ -557,10 +559,17 @@ cdef execute_task(
|
||||
|
||||
def function_executor(*arguments, **kwarguments):
|
||||
function = execution_info.function
|
||||
result_or_coroutine = function(actor, *arguments, **kwarguments)
|
||||
|
||||
if PY3 and inspect.iscoroutine(result_or_coroutine):
|
||||
coroutine = result_or_coroutine
|
||||
if PY3 and core_worker.current_actor_is_asyncio():
|
||||
if inspect.iscoroutinefunction(function.method):
|
||||
async_function = function
|
||||
else:
|
||||
# Just execute the method if it's ray internal method.
|
||||
if function.name.startswith("__ray"):
|
||||
return function(actor, *arguments, **kwarguments)
|
||||
async_function = sync_to_async(function)
|
||||
|
||||
coroutine = async_function(actor, *arguments, **kwarguments)
|
||||
loop = core_worker.create_or_get_event_loop()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
@@ -573,7 +582,7 @@ cdef execute_task(
|
||||
|
||||
return future.result()
|
||||
|
||||
return result_or_coroutine
|
||||
return function(actor, *arguments, **kwarguments)
|
||||
|
||||
with core_worker.profile_event(b"task", extra_data=extra_data):
|
||||
try:
|
||||
@@ -749,6 +758,7 @@ cdef class CoreWorker:
|
||||
task_execution_handler, check_signals, exit_handler, True))
|
||||
|
||||
def disconnect(self):
|
||||
self.destory_event_loop_if_exists()
|
||||
with nogil:
|
||||
self.core_worker.get().Disconnect()
|
||||
|
||||
@@ -1099,6 +1109,18 @@ cdef class CoreWorker:
|
||||
self.async_thread = threading.Thread(
|
||||
target=lambda: self.async_event_loop.run_forever()
|
||||
)
|
||||
# Making the thread a daemon causes it to exit
|
||||
# when the main thread exits.
|
||||
self.async_thread.daemon = True
|
||||
self.async_thread.start()
|
||||
|
||||
return self.async_event_loop
|
||||
|
||||
def destory_event_loop_if_exists(self):
|
||||
if self.async_event_loop is not None:
|
||||
self.async_event_loop.stop()
|
||||
if self.async_thread is not None:
|
||||
self.async_thread.join()
|
||||
|
||||
def current_actor_is_asyncio(self):
|
||||
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
|
||||
|
||||
+8
-4
@@ -378,7 +378,10 @@ class ActorClass(object):
|
||||
task.
|
||||
is_direct_call: Use direct actor calls.
|
||||
max_concurrency: The max number of concurrent calls to allow for
|
||||
this actor. This only works with direct actor calls.
|
||||
this actor. This only works with direct actor calls. The max
|
||||
concurrency defaults to 1 for threaded execution, and 100 for
|
||||
asyncio execution. Note that the execution order is not
|
||||
guaranteed when max_concurrency > 1.
|
||||
name: The globally unique name for the actor.
|
||||
detached: Whether the actor should be kept alive after driver
|
||||
exits.
|
||||
@@ -395,7 +398,10 @@ class ActorClass(object):
|
||||
if is_direct_call is None:
|
||||
is_direct_call = bool(os.environ.get("RAY_FORCE_DIRECT"))
|
||||
if max_concurrency is None:
|
||||
max_concurrency = 1
|
||||
if is_asyncio:
|
||||
max_concurrency = 100
|
||||
else:
|
||||
max_concurrency = 1
|
||||
|
||||
if max_concurrency > 1 and not is_direct_call:
|
||||
raise ValueError(
|
||||
@@ -406,8 +412,6 @@ class ActorClass(object):
|
||||
if is_asyncio and not is_direct_call:
|
||||
raise ValueError(
|
||||
"Setting is_asyncio requires is_direct_call=True.")
|
||||
if is_asyncio and max_concurrency != 1:
|
||||
raise ValueError("Setting is_asyncio requires max_concurrency=1.")
|
||||
|
||||
worker = ray.worker.get_global_worker()
|
||||
if worker.mode is None:
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
This file should only be imported from Python 3.
|
||||
It will raise SyntaxError when importing from Python 2.
|
||||
"""
|
||||
|
||||
|
||||
def sync_to_async(func):
|
||||
"""Convert a blocking function to async function"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -761,6 +761,14 @@ class FunctionActorManager(object):
|
||||
self._save_and_log_checkpoint(actor)
|
||||
return method_returns
|
||||
|
||||
# Set method_name and method as attributes to the executor clusore
|
||||
# so we can make decision based on these attributes in task executor.
|
||||
# Precisely, asyncio support requires to know whether:
|
||||
# - the method is a ray internal method: starts with __ray
|
||||
# - the method is a coroutine function: defined by async def
|
||||
actor_method_executor.name = method_name
|
||||
actor_method_executor.method = method
|
||||
|
||||
return actor_method_executor
|
||||
|
||||
def _save_and_log_checkpoint(self, actor):
|
||||
|
||||
@@ -54,6 +54,10 @@ cdef extern from "ray/core_worker/transport/direct_actor_transport.h" nogil:
|
||||
void Wait()
|
||||
void Notify()
|
||||
|
||||
cdef extern from "ray/core_worker/context.h" nogil:
|
||||
cdef cppclass CWorkerContext "ray::WorkerContext":
|
||||
c_bool CurrentActorIsAsync()
|
||||
|
||||
cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
cdef cppclass CCoreWorker "ray::CoreWorker":
|
||||
CCoreWorker(const CWorkerType worker_type, const CLanguage language,
|
||||
@@ -132,4 +136,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool local_only, c_bool delete_creating_tasks)
|
||||
c_string MemoryUsageString()
|
||||
|
||||
CWorkerContext &GetWorkerContext()
|
||||
void YieldCurrentFiber(CFiberEvent &coroutine_done)
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
@@ -102,13 +103,9 @@ def test_asyncio_actor(ray_start_regular):
|
||||
class AsyncBatcher(object):
|
||||
def __init__(self):
|
||||
self.batch = []
|
||||
# The event currently need to be created from the same thread.
|
||||
# We currently run async coroutines from a different thread.
|
||||
self.event = None
|
||||
self.event = asyncio.Event()
|
||||
|
||||
async def add(self, x):
|
||||
if self.event is None:
|
||||
self.event = asyncio.Event()
|
||||
self.batch.append(x)
|
||||
if len(self.batch) >= 3:
|
||||
self.event.set()
|
||||
@@ -125,3 +122,51 @@ def test_asyncio_actor(ray_start_regular):
|
||||
r3 = ray.get(x3)
|
||||
assert r1 == [1, 2, 3]
|
||||
assert r1 == r2 == r3
|
||||
|
||||
|
||||
def test_asyncio_actor_same_thread(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def sync_thread_id(self):
|
||||
return threading.current_thread().ident
|
||||
|
||||
async def async_thread_id(self):
|
||||
return threading.current_thread().ident
|
||||
|
||||
a = Actor.options(is_direct_call=True, is_asyncio=True).remote()
|
||||
sync_id, async_id = ray.get(
|
||||
[a.sync_thread_id.remote(),
|
||||
a.async_thread_id.remote()])
|
||||
assert sync_id == async_id
|
||||
|
||||
|
||||
def test_asyncio_actor_concurrency(ray_start_regular):
|
||||
@ray.remote
|
||||
class RecordOrder:
|
||||
def __init__(self):
|
||||
self.history = []
|
||||
|
||||
async def do_work(self):
|
||||
self.history.append("STARTED")
|
||||
# Force a context switch
|
||||
await asyncio.sleep(0)
|
||||
self.history.append("ENDED")
|
||||
|
||||
def get_history(self):
|
||||
return self.history
|
||||
|
||||
num_calls = 10
|
||||
|
||||
a = RecordOrder.options(
|
||||
is_direct_call=True, max_concurrency=1, is_asyncio=True).remote()
|
||||
ray.get([a.do_work.remote() for _ in range(num_calls)])
|
||||
history = ray.get(a.get_history.remote())
|
||||
|
||||
# We only care about ordered start-end-start-end sequence because
|
||||
# coroutines may be executed out of enqueued order.
|
||||
answer = []
|
||||
for _ in range(num_calls):
|
||||
for status in ["STARTED", "ENDED"]:
|
||||
answer.append(status)
|
||||
|
||||
assert history == answer
|
||||
|
||||
Reference in New Issue
Block a user