mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 01:20:20 +08:00
Rate limit asyncio actor (#6242)
This commit is contained in:
+27
-5
@@ -130,11 +130,13 @@ MEMCOPY_THREADS = 12
|
||||
PY3 = cpython.PY_MAJOR_VERSION >= 3
|
||||
|
||||
|
||||
if cpython.PY_MAJOR_VERSION >= 3:
|
||||
if PY3:
|
||||
import pickle
|
||||
else:
|
||||
import cPickle as pickle
|
||||
|
||||
if PY3:
|
||||
from ray.async_compat import sync_to_async
|
||||
|
||||
cdef int check_status(const CRayStatus& status) nogil except -1:
|
||||
if status.ok():
|
||||
@@ -557,10 +559,17 @@ cdef execute_task(
|
||||
|
||||
def function_executor(*arguments, **kwarguments):
|
||||
function = execution_info.function
|
||||
result_or_coroutine = function(actor, *arguments, **kwarguments)
|
||||
|
||||
if PY3 and inspect.iscoroutine(result_or_coroutine):
|
||||
coroutine = result_or_coroutine
|
||||
if PY3 and core_worker.current_actor_is_asyncio():
|
||||
if inspect.iscoroutinefunction(function.method):
|
||||
async_function = function
|
||||
else:
|
||||
# Just execute the method if it's ray internal method.
|
||||
if function.name.startswith("__ray"):
|
||||
return function(actor, *arguments, **kwarguments)
|
||||
async_function = sync_to_async(function)
|
||||
|
||||
coroutine = async_function(actor, *arguments, **kwarguments)
|
||||
loop = core_worker.create_or_get_event_loop()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
@@ -573,7 +582,7 @@ cdef execute_task(
|
||||
|
||||
return future.result()
|
||||
|
||||
return result_or_coroutine
|
||||
return function(actor, *arguments, **kwarguments)
|
||||
|
||||
with core_worker.profile_event(b"task", extra_data=extra_data):
|
||||
try:
|
||||
@@ -749,6 +758,7 @@ cdef class CoreWorker:
|
||||
task_execution_handler, check_signals, exit_handler, True))
|
||||
|
||||
def disconnect(self):
|
||||
self.destory_event_loop_if_exists()
|
||||
with nogil:
|
||||
self.core_worker.get().Disconnect()
|
||||
|
||||
@@ -1099,6 +1109,18 @@ cdef class CoreWorker:
|
||||
self.async_thread = threading.Thread(
|
||||
target=lambda: self.async_event_loop.run_forever()
|
||||
)
|
||||
# Making the thread a daemon causes it to exit
|
||||
# when the main thread exits.
|
||||
self.async_thread.daemon = True
|
||||
self.async_thread.start()
|
||||
|
||||
return self.async_event_loop
|
||||
|
||||
def destory_event_loop_if_exists(self):
|
||||
if self.async_event_loop is not None:
|
||||
self.async_event_loop.stop()
|
||||
if self.async_thread is not None:
|
||||
self.async_thread.join()
|
||||
|
||||
def current_actor_is_asyncio(self):
|
||||
return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync()
|
||||
|
||||
+8
-4
@@ -378,7 +378,10 @@ class ActorClass(object):
|
||||
task.
|
||||
is_direct_call: Use direct actor calls.
|
||||
max_concurrency: The max number of concurrent calls to allow for
|
||||
this actor. This only works with direct actor calls.
|
||||
this actor. This only works with direct actor calls. The max
|
||||
concurrency defaults to 1 for threaded execution, and 100 for
|
||||
asyncio execution. Note that the execution order is not
|
||||
guaranteed when max_concurrency > 1.
|
||||
name: The globally unique name for the actor.
|
||||
detached: Whether the actor should be kept alive after driver
|
||||
exits.
|
||||
@@ -395,7 +398,10 @@ class ActorClass(object):
|
||||
if is_direct_call is None:
|
||||
is_direct_call = bool(os.environ.get("RAY_FORCE_DIRECT"))
|
||||
if max_concurrency is None:
|
||||
max_concurrency = 1
|
||||
if is_asyncio:
|
||||
max_concurrency = 100
|
||||
else:
|
||||
max_concurrency = 1
|
||||
|
||||
if max_concurrency > 1 and not is_direct_call:
|
||||
raise ValueError(
|
||||
@@ -406,8 +412,6 @@ class ActorClass(object):
|
||||
if is_asyncio and not is_direct_call:
|
||||
raise ValueError(
|
||||
"Setting is_asyncio requires is_direct_call=True.")
|
||||
if is_asyncio and max_concurrency != 1:
|
||||
raise ValueError("Setting is_asyncio requires max_concurrency=1.")
|
||||
|
||||
worker = ray.worker.get_global_worker()
|
||||
if worker.mode is None:
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
This file should only be imported from Python 3.
|
||||
It will raise SyntaxError when importing from Python 2.
|
||||
"""
|
||||
|
||||
|
||||
def sync_to_async(func):
|
||||
"""Convert a blocking function to async function"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -761,6 +761,14 @@ class FunctionActorManager(object):
|
||||
self._save_and_log_checkpoint(actor)
|
||||
return method_returns
|
||||
|
||||
# Set method_name and method as attributes to the executor clusore
|
||||
# so we can make decision based on these attributes in task executor.
|
||||
# Precisely, asyncio support requires to know whether:
|
||||
# - the method is a ray internal method: starts with __ray
|
||||
# - the method is a coroutine function: defined by async def
|
||||
actor_method_executor.name = method_name
|
||||
actor_method_executor.method = method
|
||||
|
||||
return actor_method_executor
|
||||
|
||||
def _save_and_log_checkpoint(self, actor):
|
||||
|
||||
@@ -54,6 +54,10 @@ cdef extern from "ray/core_worker/transport/direct_actor_transport.h" nogil:
|
||||
void Wait()
|
||||
void Notify()
|
||||
|
||||
cdef extern from "ray/core_worker/context.h" nogil:
|
||||
cdef cppclass CWorkerContext "ray::WorkerContext":
|
||||
c_bool CurrentActorIsAsync()
|
||||
|
||||
cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
cdef cppclass CCoreWorker "ray::CoreWorker":
|
||||
CCoreWorker(const CWorkerType worker_type, const CLanguage language,
|
||||
@@ -132,4 +136,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool local_only, c_bool delete_creating_tasks)
|
||||
c_string MemoryUsageString()
|
||||
|
||||
CWorkerContext &GetWorkerContext()
|
||||
void YieldCurrentFiber(CFiberEvent &coroutine_done)
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
@@ -102,13 +103,9 @@ def test_asyncio_actor(ray_start_regular):
|
||||
class AsyncBatcher(object):
|
||||
def __init__(self):
|
||||
self.batch = []
|
||||
# The event currently need to be created from the same thread.
|
||||
# We currently run async coroutines from a different thread.
|
||||
self.event = None
|
||||
self.event = asyncio.Event()
|
||||
|
||||
async def add(self, x):
|
||||
if self.event is None:
|
||||
self.event = asyncio.Event()
|
||||
self.batch.append(x)
|
||||
if len(self.batch) >= 3:
|
||||
self.event.set()
|
||||
@@ -125,3 +122,51 @@ def test_asyncio_actor(ray_start_regular):
|
||||
r3 = ray.get(x3)
|
||||
assert r1 == [1, 2, 3]
|
||||
assert r1 == r2 == r3
|
||||
|
||||
|
||||
def test_asyncio_actor_same_thread(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def sync_thread_id(self):
|
||||
return threading.current_thread().ident
|
||||
|
||||
async def async_thread_id(self):
|
||||
return threading.current_thread().ident
|
||||
|
||||
a = Actor.options(is_direct_call=True, is_asyncio=True).remote()
|
||||
sync_id, async_id = ray.get(
|
||||
[a.sync_thread_id.remote(),
|
||||
a.async_thread_id.remote()])
|
||||
assert sync_id == async_id
|
||||
|
||||
|
||||
def test_asyncio_actor_concurrency(ray_start_regular):
|
||||
@ray.remote
|
||||
class RecordOrder:
|
||||
def __init__(self):
|
||||
self.history = []
|
||||
|
||||
async def do_work(self):
|
||||
self.history.append("STARTED")
|
||||
# Force a context switch
|
||||
await asyncio.sleep(0)
|
||||
self.history.append("ENDED")
|
||||
|
||||
def get_history(self):
|
||||
return self.history
|
||||
|
||||
num_calls = 10
|
||||
|
||||
a = RecordOrder.options(
|
||||
is_direct_call=True, max_concurrency=1, is_asyncio=True).remote()
|
||||
ray.get([a.do_work.remote() for _ in range(num_calls)])
|
||||
history = ray.get(a.get_history.remote())
|
||||
|
||||
# We only care about ordered start-end-start-end sequence because
|
||||
# coroutines may be executed out of enqueued order.
|
||||
answer = []
|
||||
for _ in range(num_calls):
|
||||
for status in ["STARTED", "ENDED"]:
|
||||
answer.append(status)
|
||||
|
||||
assert history == answer
|
||||
|
||||
@@ -257,6 +257,7 @@ void CoreWorkerDirectTaskReceiver::SetActorAsAsync() {
|
||||
// immediately start working on any ready fibers.
|
||||
fiber_shutdown_event_.Wait();
|
||||
});
|
||||
fiber_rate_limiter_.reset(new FiberRateLimiter(max_concurrency_));
|
||||
is_asyncio_ = true;
|
||||
}
|
||||
};
|
||||
@@ -291,8 +292,9 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask(
|
||||
auto it = scheduling_queue_.find(task_spec.CallerId());
|
||||
if (it == scheduling_queue_.end()) {
|
||||
auto result = scheduling_queue_.emplace(
|
||||
task_spec.CallerId(), std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
|
||||
task_main_io_service_, *waiter_, pool_, is_asyncio_)));
|
||||
task_spec.CallerId(),
|
||||
std::unique_ptr<SchedulingQueue>(new SchedulingQueue(
|
||||
task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_rate_limiter_)));
|
||||
it = result.first;
|
||||
}
|
||||
|
||||
|
||||
@@ -265,6 +265,38 @@ class FiberEvent {
|
||||
bool ready_ = false;
|
||||
};
|
||||
|
||||
/// Used by async actor mode. The FiberRateLimiter is a barrier that
|
||||
/// allows at most num fibers running at once. It implements the
|
||||
/// semaphore data structure.
|
||||
class FiberRateLimiter {
|
||||
public:
|
||||
FiberRateLimiter(int num) : num_(num) {}
|
||||
|
||||
// Enter the semaphore. Wait fo the value to be > 0 and decrement the value.
|
||||
void Acquire() {
|
||||
std::unique_lock<boost::fibers::mutex> lock(mutex_);
|
||||
cond_.wait(lock, [this]() { return num_ > 0; });
|
||||
num_ -= 1;
|
||||
}
|
||||
|
||||
// Exit the semaphore. Increment the value and notify other waiter.
|
||||
void Release() {
|
||||
{
|
||||
std::unique_lock<boost::fibers::mutex> lock(mutex_);
|
||||
num_ += 1;
|
||||
}
|
||||
// TODO(simon): This not does guarantee to wake up the first queued fiber.
|
||||
// This could be a problem for certain workloads because there is no guarantee
|
||||
// on task ordering .
|
||||
cond_.notify_one();
|
||||
}
|
||||
|
||||
private:
|
||||
boost::fibers::condition_variable cond_;
|
||||
boost::fibers::mutex mutex_;
|
||||
int num_;
|
||||
};
|
||||
|
||||
/// Used to ensure serial order of task execution per actor handle.
|
||||
/// See direct_actor.proto for a description of the ordering protocol.
|
||||
class SchedulingQueue {
|
||||
@@ -272,13 +304,15 @@ class SchedulingQueue {
|
||||
SchedulingQueue(boost::asio::io_service &main_io_service, DependencyWaiter &waiter,
|
||||
std::shared_ptr<BoundedExecutor> pool = nullptr,
|
||||
bool use_asyncio = false,
|
||||
std::shared_ptr<FiberRateLimiter> fiber_rate_limiter = nullptr,
|
||||
int64_t reorder_wait_seconds = kMaxReorderWaitSeconds)
|
||||
: wait_timer_(main_io_service),
|
||||
waiter_(waiter),
|
||||
reorder_wait_seconds_(reorder_wait_seconds),
|
||||
main_thread_id_(boost::this_thread::get_id()),
|
||||
pool_(pool),
|
||||
use_asyncio_(use_asyncio) {}
|
||||
use_asyncio_(use_asyncio),
|
||||
fiber_rate_limiter_(fiber_rate_limiter) {}
|
||||
|
||||
void Add(int64_t seq_no, int64_t client_processed_up_to,
|
||||
std::function<void()> accept_request, std::function<void()> reject_request,
|
||||
@@ -327,7 +361,12 @@ class SchedulingQueue {
|
||||
auto request = head->second;
|
||||
|
||||
if (use_asyncio_) {
|
||||
boost::fibers::fiber([request]() mutable { request.Accept(); }).detach();
|
||||
boost::fibers::fiber([request, this]() mutable {
|
||||
fiber_rate_limiter_->Acquire();
|
||||
request.Accept();
|
||||
fiber_rate_limiter_->Release();
|
||||
})
|
||||
.detach();
|
||||
} else if (pool_ != nullptr) {
|
||||
pool_->PostBlocking([request]() mutable { request.Accept(); });
|
||||
} else {
|
||||
@@ -385,6 +424,9 @@ class SchedulingQueue {
|
||||
/// Whether we should enqueue requests into asyncio pool. Setting this to true
|
||||
/// will instantiate all tasks as fibers that can be yielded.
|
||||
bool use_asyncio_;
|
||||
/// If use_asyncio_ is true, fiber_rate_limiter_ limits the max number of async
|
||||
/// tasks running at once.
|
||||
std::shared_ptr<FiberRateLimiter> fiber_rate_limiter_;
|
||||
|
||||
friend class SchedulingQueueTest;
|
||||
};
|
||||
@@ -452,12 +494,16 @@ class CoreWorkerDirectTaskReceiver {
|
||||
/// If concurrent calls are allowed, holds the pool for executing these tasks.
|
||||
std::shared_ptr<BoundedExecutor> pool_;
|
||||
/// Whether this actor use asyncio for concurrency.
|
||||
/// TODO(simon) group all asyncio related fields into a separate struct.
|
||||
bool is_asyncio_ = false;
|
||||
/// The thread that runs all asyncio fibers. is_asyncio_ must be true.
|
||||
std::thread fiber_runner_thread_;
|
||||
/// The fiber event used to block fiber_runner_thread_ from shutdown.
|
||||
/// is_asyncio_ must be true.
|
||||
FiberEvent fiber_shutdown_event_;
|
||||
/// The fiber semaphore used to limit the number of concurrent fibers
|
||||
/// running at once.
|
||||
std::shared_ptr<FiberRateLimiter> fiber_rate_limiter_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
Reference in New Issue
Block a user