From 55ccfb60895ea607457928457f5a0208b9a9581c Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 27 Feb 2020 10:16:04 -0800 Subject: [PATCH] Fix asyncio actor race condition (#7335) --- python/ray/tests/test_dynres.py | 118 +++++++++++------- .../core_worker/test/scheduling_queue_test.cc | 18 ++- .../transport/direct_actor_transport.cc | 31 +---- .../transport/direct_actor_transport.h | 57 +++++---- 4 files changed, 114 insertions(+), 110 deletions(-) diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py index 47984582b..a6cf2bda1 100644 --- a/python/ray/tests/test_dynres.py +++ b/python/ray/tests/test_dynres.py @@ -1,3 +1,4 @@ +import asyncio import logging import time @@ -300,12 +301,6 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster): num_nodes = 5 TIMEOUT_DURATION = 1 - # Create a object ID to have the task wait on - WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") - - # Create a object ID to signal that the task is running - TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") - for i in range(num_nodes): cluster.add_node() @@ -325,29 +320,42 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster): # Task to hold the resource till the driver signals to finish @ray.remote - def wait_func(running_oid, wait_oid): - # Signal that the task is running - ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid)) - # Make the task wait till signalled by driver - ray.get(ray.ObjectID(wait_oid)) + def wait_func(running_signal, finish_signal): + # Signal that the task is running. + ray.get(running_signal.send.remote()) + # Wait until signaled by driver. + ray.get(finish_signal.wait.remote()) @ray.remote def test_func(): return 1 + @ray.remote(num_cpus=0) + class Signal: + def __init__(self): + self.ready_event = asyncio.Event() + + def send(self): + self.ready_event.set() + + async def wait(self): + await self.ready_event.wait() + + running_signal = Signal.remote() + finish_signal = Signal.remote() + # Launch the task with resource requirement of 4, thus the new available # capacity becomes 1 task = wait_func._remote( - args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], - resources={res_name: 4}) - # Wait till wait_func is launched before updating resource - ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + args=[running_signal, finish_signal], resources={res_name: 4}) + # Wait until wait_func is launched before updating resource + ray.get(running_signal.wait.remote()) # Update the resource capacity ray.get(set_res.remote(res_name, updated_capacity, target_node_id)) # Signal task to complete - ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR)) + ray.get(finish_signal.send.remote()) ray.get(task) # Check if scheduler state is consistent by launching a task requiring @@ -379,12 +387,6 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): num_nodes = 5 TIMEOUT_DURATION = 1 - # Create a object ID to have the task wait on - WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") - - # Create a object ID to signal that the task is running - TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") - for i in range(num_nodes): cluster.add_node() @@ -404,29 +406,42 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): # Task to hold the resource till the driver signals to finish @ray.remote - def wait_func(running_oid, wait_oid): - # Signal that the task is running - ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid)) - # Make the task wait till signalled by driver - ray.get(ray.ObjectID(wait_oid)) + def wait_func(running_signal, finish_signal): + # Signal that the task is running. + ray.get(running_signal.send.remote()) + # Wait until signaled by driver. + ray.get(finish_signal.wait.remote()) @ray.remote def test_func(): return 1 + @ray.remote(num_cpus=0) + class Signal: + def __init__(self): + self.ready_event = asyncio.Event() + + def send(self): + self.ready_event.set() + + async def wait(self): + await self.ready_event.wait() + + running_signal = Signal.remote() + finish_signal = Signal.remote() + # Launch the task with resource requirement of 4, thus the new available # capacity becomes 1 task = wait_func._remote( - args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], - resources={res_name: 4}) - # Wait till wait_func is launched before updating resource - ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + args=[running_signal, finish_signal], resources={res_name: 4}) + # Wait until wait_func is launched before updating resource + ray.get(running_signal.wait.remote()) # Decrease the resource capacity ray.get(set_res.remote(res_name, updated_capacity, target_node_id)) # Signal task to complete - ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR)) + ray.get(finish_signal.send.remote()) ray.get(task) # Check if scheduler state is consistent by launching a task requiring @@ -456,12 +471,6 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster): num_nodes = 5 TIMEOUT_DURATION = 1 - # Create a object ID to have the task wait on - WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") - - # Create a object ID to signal that the task is running - TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") - for i in range(num_nodes): cluster.add_node() @@ -486,29 +495,42 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster): # Task to hold the resource till the driver signals to finish @ray.remote - def wait_func(running_oid, wait_oid): - # Signal that the task is running - ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid)) - # Make the task wait till signalled by driver - ray.get(ray.ObjectID(wait_oid)) + def wait_func(running_signal, finish_signal): + # Signal that the task is running. + ray.get(running_signal.send.remote()) + # Wait until signaled by driver. + ray.get(finish_signal.wait.remote()) @ray.remote def test_func(): return 1 + @ray.remote(num_cpus=0) + class Signal: + def __init__(self): + self.ready_event = asyncio.Event() + + def send(self): + self.ready_event.set() + + async def wait(self): + await self.ready_event.wait() + + running_signal = Signal.remote() + finish_signal = Signal.remote() + # Launch the task with resource requirement of 4, thus the new available # capacity becomes 1 task = wait_func._remote( - args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], - resources={res_name: 4}) - # Wait till wait_func is launched before updating resource - ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + args=[running_signal, finish_signal], resources={res_name: 4}) + # Wait until wait_func is launched before updating resource + ray.get(running_signal.wait.remote()) # Delete the resource ray.get(delete_res.remote(res_name, target_node_id)) # Signal task to complete - ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR)) + ray.get(finish_signal.send.remote()) ray.get(task) # Check if scheduler state is consistent by launching a task requiring diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index bc402de88..cbc608289 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -23,7 +23,8 @@ class MockWaiter : public DependencyWaiter { TEST(SchedulingQueueTest, TestInOrder) { boost::asio::io_service io_service; MockWaiter waiter; - SchedulingQueue queue(io_service, waiter, nullptr, 0); + WorkerContext context(WorkerType::WORKER, JobID::Nil()); + SchedulingQueue queue(io_service, waiter, context, 0); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -43,7 +44,8 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { ObjectID obj3 = ObjectID::FromRandom(); boost::asio::io_service io_service; MockWaiter waiter; - SchedulingQueue queue(io_service, waiter, nullptr, 0); + WorkerContext context(WorkerType::WORKER, JobID::Nil()); + SchedulingQueue queue(io_service, waiter, context, 0); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -68,7 +70,8 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { ObjectID obj1 = ObjectID::FromRandom(); boost::asio::io_service io_service; MockWaiter waiter; - SchedulingQueue queue(io_service, waiter, nullptr, 0); + WorkerContext context(WorkerType::WORKER, JobID::Nil()); + SchedulingQueue queue(io_service, waiter, context, 0); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -85,7 +88,8 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { TEST(SchedulingQueueTest, TestOutOfOrder) { boost::asio::io_service io_service; MockWaiter waiter; - SchedulingQueue queue(io_service, waiter, nullptr, 0); + WorkerContext context(WorkerType::WORKER, JobID::Nil()); + SchedulingQueue queue(io_service, waiter, context, 0); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -102,7 +106,8 @@ TEST(SchedulingQueueTest, TestOutOfOrder) { TEST(SchedulingQueueTest, TestSeqWaitTimeout) { boost::asio::io_service io_service; MockWaiter waiter; - SchedulingQueue queue(io_service, waiter, nullptr, 0); + WorkerContext context(WorkerType::WORKER, JobID::Nil()); + SchedulingQueue queue(io_service, waiter, context, 0); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; @@ -124,7 +129,8 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) { TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) { boost::asio::io_service io_service; MockWaiter waiter; - SchedulingQueue queue(io_service, waiter, nullptr, 0); + WorkerContext context(WorkerType::WORKER, JobID::Nil()); + SchedulingQueue queue(io_service, waiter, context, 0); int n_ok = 0; int n_rej = 0; auto fn_ok = [&n_ok]() { n_ok++; }; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index c5fbaebac..317213bdd 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -178,24 +178,6 @@ void CoreWorkerDirectTaskReceiver::Init(rpc::ClientFactoryFn client_factory, client_factory_ = client_factory; } -void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) { - if (max_concurrency != max_concurrency_) { - RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency; - RAY_CHECK(pool_ == nullptr) << "Cannot change max concurrency at runtime."; - pool_.reset(new BoundedExecutor(max_concurrency)); - max_concurrency_ = max_concurrency; - } -} - -void CoreWorkerDirectTaskReceiver::SetActorAsAsync(int max_concurrency) { - if (!is_asyncio_) { - RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread."; - fiber_state_.reset(new FiberState(max_concurrency)); - max_concurrency_ = max_concurrency; - is_asyncio_ = true; - } -}; - void CoreWorkerDirectTaskReceiver::HandlePushTask( const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { @@ -208,14 +190,6 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( return; } - // Only call SetMaxActorConcurrency to configure threadpool size when the - // actor is not async actor. Async actor is single threaded. - if (worker_context_.CurrentActorIsAsync()) { - SetActorAsAsync(worker_context_.CurrentActorMaxConcurrency()); - } else { - SetMaxActorConcurrency(worker_context_.CurrentActorMaxConcurrency()); - } - std::vector dependencies; for (size_t i = 0; i < task_spec.NumArgs(); ++i) { int count = task_spec.ArgIdCount(i); @@ -325,9 +299,8 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( auto it = scheduling_queue_.find(task_spec.CallerId()); if (it == scheduling_queue_.end()) { auto result = scheduling_queue_.emplace( - task_spec.CallerId(), - std::unique_ptr(new SchedulingQueue( - task_main_io_service_, *waiter_, pool_, is_asyncio_, fiber_state_))); + task_spec.CallerId(), std::unique_ptr(new SchedulingQueue( + task_main_io_service_, *waiter_, worker_context_))); it = result.first; } it->second->Add(request.sequence_number(), request.client_processed_up_to(), diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 8b889604a..cf50c023e 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -239,17 +239,13 @@ class BoundedExecutor { class SchedulingQueue { public: SchedulingQueue(boost::asio::io_service &main_io_service, DependencyWaiter &waiter, - std::shared_ptr pool = nullptr, - bool use_asyncio = false, - std::shared_ptr fiber_state = nullptr, + WorkerContext &worker_context, int64_t reorder_wait_seconds = kMaxReorderWaitSeconds) : wait_timer_(main_io_service), waiter_(waiter), reorder_wait_seconds_(reorder_wait_seconds), main_thread_id_(boost::this_thread::get_id()), - pool_(pool), - use_asyncio_(use_asyncio), - fiber_state_(fiber_state) {} + worker_context_(worker_context) {} void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, std::function reject_request, @@ -283,6 +279,24 @@ class SchedulingQueue { private: /// Schedules as many requests as possible in sequence. void ScheduleRequests() { + // Only call SetMaxActorConcurrency to configure threadpool size when the + // actor is not async actor. Async actor is single threaded. + int max_concurrency = worker_context_.CurrentActorMaxConcurrency(); + if (worker_context_.CurrentActorIsAsync()) { + // If this is an async actor, initialize the fiber state once. + if (!is_asyncio_) { + RAY_LOG(DEBUG) << "Setting direct actor as async, creating new fiber thread."; + fiber_state_.reset(new FiberState(max_concurrency)); + is_asyncio_ = true; + } + } else { + // If this is a concurrency actor (not async), initialize the thread pool once. + if (max_concurrency != 1 && !pool_) { + RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency; + pool_.reset(new BoundedExecutor(max_concurrency)); + } + } + // Cancel any stale requests that the client doesn't need any longer. while (!pending_tasks_.empty() && pending_tasks_.begin()->first < next_seq_no_) { auto head = pending_tasks_.begin(); @@ -298,11 +312,14 @@ class SchedulingQueue { auto head = pending_tasks_.begin(); auto request = head->second; - if (use_asyncio_) { + if (is_asyncio_) { + // Process async actor task. fiber_state_->EnqueueFiber([request]() mutable { request.Accept(); }); - } else if (pool_ != nullptr) { + } else if (pool_) { + // Process concurrent actor task. pool_->PostBlocking([request]() mutable { request.Accept(); }); } else { + // Process normal actor task. request.Accept(); } pending_tasks_.erase(head); @@ -339,6 +356,8 @@ class SchedulingQueue { } } + // Worker context. + WorkerContext &worker_context_; /// Max time in seconds to wait for dependencies to show up. const int64_t reorder_wait_seconds_ = 0; /// Sorted map of (accept, rej) task callbacks keyed by their sequence number. @@ -353,13 +372,13 @@ class SchedulingQueue { /// Reference to the waiter owned by the task receiver. DependencyWaiter &waiter_; /// If concurrent calls are allowed, holds the pool for executing these tasks. - std::shared_ptr pool_; + std::unique_ptr pool_; /// Whether we should enqueue requests into asyncio pool. Setting this to true /// will instantiate all tasks as fibers that can be yielded. - bool use_asyncio_; + bool is_asyncio_ = false; /// If use_asyncio_ is true, fiber_state_ contains the running state required /// to enable continuation and work together with python asyncio. - std::shared_ptr fiber_state_; + std::unique_ptr fiber_state_; friend class SchedulingQueueTest; }; @@ -403,12 +422,6 @@ class CoreWorkerDirectTaskReceiver { rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback); - /// Set the max concurrency at runtime. It cannot be changed once set. - void SetMaxActorConcurrency(int max_concurrency); - - /// Set the max concurrency and start async actor context. - void SetActorAsAsync(int max_concurrency); - private: // Worker context. WorkerContext &worker_context_; @@ -430,18 +443,8 @@ class CoreWorkerDirectTaskReceiver { /// Queue of pending requests per actor handle. /// TODO(ekl) GC these queues once the handle is no longer active. std::unordered_map> scheduling_queue_; - /// The max number of concurrent calls to allow. - int max_concurrency_ = 1; /// Whether we are shutting down and not running further tasks. bool exiting_ = false; - /// If concurrent calls are allowed, holds the pool for executing these tasks. - std::shared_ptr pool_; - /// Whether this actor use asyncio for concurrency. - /// TODO(simon) group all asyncio related fields into a separate struct. - bool is_asyncio_ = false; - /// If use_asyncio_ is true, fiber_state_ contains the running state required - /// to enable continuation and work together with python asyncio. - std::shared_ptr fiber_state_; }; } // namespace ray