From 102e3682d98cf8227d88beb97ccafe69f3cc7cea Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 16 Jun 2016 16:04:52 -0700 Subject: [PATCH] make sure the scheduler sends tasks to the worker only after the worker is ready (#116) --- protos/ray.proto | 13 ++++++++----- src/scheduler.cc | 13 ++++++------- src/scheduler.h | 2 +- src/worker.cc | 18 +++++++++++++----- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/protos/ray.proto b/protos/ray.proto index 455703821..5409be385 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -43,8 +43,8 @@ service Scheduler { rpc DecrementRefCount(DecrementRefCountRequest) returns (AckReply); // Used by the worker to notify the scheduler about which objrefs a particular object contains rpc AddContainedObjRefs(AddContainedObjRefsRequest) returns (AckReply); - // Used by the worker to report back and ask for more work - rpc NotifyTaskCompleted(NotifyTaskCompletedRequest) returns (AckReply); + // Used by the worker to ask for work; the request also reports the status of the previous task, if there was one + rpc ReadyForNewTask(ReadyForNewTaskRequest) returns (AckReply); // Get information about the scheduler state rpc SchedulerInfo(SchedulerInfoRequest) returns (SchedulerInfoReply); } @@ -121,10 +121,13 @@ message DecrementRefCountRequest { repeated uint64 objref = 1; // Object references whose reference count should be decremented. Duplicates will be decremented multiple times. 
} -message NotifyTaskCompletedRequest { +message ReadyForNewTaskRequest { uint64 workerid = 1; // ID of the worker which executed the task - bool task_succeeded = 2; // True if the task succeeded, false if it threw an exception - string error_message = 3; // The contents of the exception, if the task threw an exception + message PreviousTaskInfo { + bool task_succeeded = 1; // True if the task succeeded, false if it threw an exception + string error_message = 2; // The contents of the exception, if the task threw an exception + } + PreviousTaskInfo previous_task_info = 2; // Information about the previous task; present only if there was a previous task } message ChangeCountRequest { diff --git a/src/scheduler.cc b/src/scheduler.cc index 5e0d3c122..05f6f2963 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -156,14 +156,16 @@ Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest* return Status::OK; } -Status SchedulerService::NotifyTaskCompleted(ServerContext* context, const NotifyTaskCompletedRequest* request, AckReply* reply) { - RAY_LOG(RAY_INFO, "worker " << request->workerid() << " reported back"); +Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) { + RAY_LOG(RAY_INFO, "worker " << request->workerid() << " is ready for a new task"); { std::lock_guard lock(avail_workers_lock_); avail_workers_.push_back(request->workerid()); } - if (!request->task_succeeded()) { - RAY_LOG(RAY_FATAL, "The task on worker " << request->workerid() << " threw an exception with the following error message: " << request->error_message()); + if (request->has_previous_task_info()) { + if (!request->previous_task_info().task_succeeded()) { + RAY_LOG(RAY_FATAL, "The task on worker " << request->workerid() << " threw an exception with the following error message: " << request->previous_task_info().error_message()); + } } schedule(); return Status::OK; @@ -317,9 +319,6 @@ std::pair 
SchedulerService::register_worker(const std::str workers_[workerid].objstoreid = objstoreid; workers_[workerid].worker_stub = WorkerService::NewStub(channel); workers_lock_.unlock(); - avail_workers_lock_.lock(); - avail_workers_.push_back(workerid); - avail_workers_lock_.unlock(); return std::make_pair(workerid, objstoreid); } diff --git a/src/scheduler.h b/src/scheduler.h index 07f17e9f8..76e7b8192 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -60,7 +60,7 @@ public: Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override; Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override; Status ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) override; - Status NotifyTaskCompleted(ServerContext* context, const NotifyTaskCompletedRequest* request, AckReply* reply) override; + Status ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) override; Status IncrementRefCount(ServerContext* context, const IncrementRefCountRequest* request, AckReply* reply) override; Status DecrementRefCount(ServerContext* context, const DecrementRefCountRequest* request, AckReply* reply) override; Status AddContainedObjRefs(ServerContext* context, const AddContainedObjRefsRequest* request, AckReply* reply) override; diff --git a/src/worker.cc b/src/worker.cc index 7654e647a..71e1a6500 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -257,12 +257,13 @@ Task* Worker::receive_next_task() { void Worker::notify_task_completed(bool task_succeeded, std::string error_message) { RAY_CHECK(connected_, "Attempted to perform notify_task_completed but failed."); ClientContext context; - NotifyTaskCompletedRequest request; + ReadyForNewTaskRequest request; request.set_workerid(workerid_); - request.set_task_succeeded(task_succeeded); - request.set_error_message(error_message); + 
ReadyForNewTaskRequest::PreviousTaskInfo* previous_task_info = request.mutable_previous_task_info(); + previous_task_info->set_task_succeeded(task_succeeded); + previous_task_info->set_error_message(error_message); AckReply reply; - scheduler_stub_->NotifyTaskCompleted(&context, request, &reply); + scheduler_stub_->ReadyForNewTask(&context, request, &reply); } void Worker::disconnect() { @@ -285,7 +286,7 @@ void Worker::scheduler_info(ClientContext &context, SchedulerInfoRequest &reques // run in a separate thread and potentially utilize multiple threads. void Worker::start_worker_service() { const char* service_addr = worker_address_.c_str(); - worker_server_thread_ = std::thread([service_addr]() { + worker_server_thread_ = std::thread([this, service_addr]() { std::string service_address(service_addr); std::string::iterator split_point = split_ip_address(service_address); std::string port; @@ -296,6 +297,13 @@ void Worker::start_worker_service() { builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); RAY_LOG(RAY_INFO, "worker server listening on " << service_address); + + ClientContext context; + ReadyForNewTaskRequest request; + request.set_workerid(workerid_); + AckReply reply; + scheduler_stub_->ReadyForNewTask(&context, request, &reply); + server->Wait(); }); }