diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index cfb71ed62..1fbe39c23 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -136,11 +136,11 @@ def start_services_local(num_objstores=1, num_workers_per_objstore=0, worker_pat driver_workers = [] for i in range(num_objstores): driver_worker = worker.Worker() - ray.connect(scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port()), driver_worker) + ray.connect(scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port()), is_driver=True, worker=driver_worker) driver_workers.append(driver_worker) drivers.append(driver_worker) time.sleep(0.5) return driver_workers else: - ray.connect(scheduler_address, objstore_addresses[0], address(IP_ADDRESS, new_worker_port()), mode=driver_mode) + ray.connect(scheduler_address, objstore_addresses[0], address(IP_ADDRESS, new_worker_port()), is_driver=True, mode=driver_mode) time.sleep(0.5) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index b9649c23f..2294f8597 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -160,10 +160,10 @@ def register_module(module, recursive=False, worker=global_worker): # elif recursive and isinstance(val, ModuleType): # register_module(val, recursive, worker) -def connect(scheduler_addr, objstore_addr, worker_addr, worker=global_worker, mode=ray.WORKER_MODE): +def connect(scheduler_addr, objstore_addr, worker_addr, is_driver=False, worker=global_worker, mode=ray.WORKER_MODE): if hasattr(worker, "handle"): del worker.handle - worker.handle = ray.lib.create_worker(scheduler_addr, objstore_addr, worker_addr) + worker.handle = ray.lib.create_worker(scheduler_addr, objstore_addr, worker_addr, is_driver) FORMAT = "%(asctime)-15s %(message)s" log_basename = os.path.join(LOG_DIRECTORY, (LOG_TIMESTAMP + "-worker-{}").format(datetime.datetime.now(), worker_addr)) logging.basicConfig(level=logging.DEBUG, format=FORMAT, filename=log_basename + ".log") diff --git a/protos/ray.proto b/protos/ray.proto index 7b5bf862b..1c6be5a5b 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -58,6 +58,7 @@ message AckReply { message RegisterWorkerRequest { string worker_address = 1; // IP address of the worker being registered string objstore_address = 2; // IP address of the object store the worker is connected to + bool is_driver = 3; // True if the worker is a driver, and false otherwise } message RegisterWorkerReply { diff --git a/src/raylib.cc b/src/raylib.cc index 2c2ff279b..c6a9cddae 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -594,13 +594,15 @@ PyObject* create_worker(PyObject* self, PyObject* args) { const char* scheduler_addr; const char* objstore_addr; const char* worker_addr; - if (!PyArg_ParseTuple(args, "sss", &scheduler_addr, &objstore_addr, &worker_addr)) { + PyObject* is_driver_obj; + if (!PyArg_ParseTuple(args, "sssO", &scheduler_addr, &objstore_addr, &worker_addr, &is_driver_obj)) { return NULL; } + bool is_driver = PyObject_IsTrue(is_driver_obj); auto scheduler_channel = grpc::CreateChannel(scheduler_addr, grpc::InsecureChannelCredentials()); auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials()); Worker* worker = new Worker(std::string(worker_addr), scheduler_channel, objstore_channel); - worker->register_worker(std::string(worker_addr), std::string(objstore_addr)); + worker->register_worker(std::string(worker_addr), std::string(objstore_addr), is_driver); return PyCapsule_New(static_cast(worker), "worker", &WorkerCapsule_Destructor); } diff --git a/src/scheduler.cc b/src/scheduler.cc index 13898a26e..ba4548859 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -103,7 +103,7 @@ Status SchedulerService::RegisterObjStore(ServerContext* context, const Register } Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) { - std::pair info = register_worker(request->worker_address(), request->objstore_address()); + std::pair info = register_worker(request->worker_address(), request->objstore_address(), request->is_driver()); WorkerId workerid = info.first; ObjStoreId objstoreid = info.second; RAY_LOG(RAY_INFO, "registered worker with workerid " << workerid); @@ -139,31 +139,32 @@ Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest* } Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) { - RAY_LOG(RAY_INFO, "worker " << request->workerid() << " is ready for a new task"); + WorkerId workerid = request->workerid(); + OperationId operationid = (*workers_.get())[workerid].current_task; + RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task"); + RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask."); if (request->has_previous_task_info()) { - OperationId operationid; - operationid = (*workers_.get())[request->workerid()].current_task; + RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION."); std::string task_name; task_name = computation_graph_.get()->get_task(operationid).name(); TaskStatus info; { auto workers = workers_.get(); - operationid = (*workers)[request->workerid()].current_task; info.set_operationid(operationid); info.set_function_name(task_name); - info.set_worker_address((*workers)[request->workerid()].worker_address); + info.set_worker_address((*workers)[workerid].worker_address); info.set_error_message(request->previous_task_info().error_message()); - (*workers)[request->workerid()].current_task = NO_OPERATION; // clear operation ID + (*workers)[workerid].current_task = NO_OPERATION; // clear operation ID } if (!request->previous_task_info().task_succeeded()) { - RAY_LOG(RAY_INFO, "Error: Task " << info.operationid() << " executing function " << info.function_name() << " on worker " << request->workerid() << " failed with error message: " << info.error_message()); + RAY_LOG(RAY_INFO, "Error: Task " << info.operationid() << " executing function " << info.function_name() << " on worker " << workerid << " failed with error message: " << info.error_message()); failed_tasks_.get()->push_back(info); } else { successful_tasks_.get()->push_back(info.operationid()); } // TODO(rkn): Handle task failure } - avail_workers_.get()->push_back(request->workerid()); + avail_workers_.get()->push_back(workerid); schedule(); return Status::OK; } @@ -223,7 +224,7 @@ Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* } for (int i = 0; i < workers->size(); ++i) { OperationId operationid = (*workers)[i].current_task; - if (operationid != NO_OPERATION) { + if (operationid != NO_OPERATION && operationid != ROOT_OPERATION) { const Task& task = computation_graph->get_task(operationid); TaskStatus* info = reply->add_running_task(); info->set_operationid(operationid); @@ -337,7 +338,7 @@ bool SchedulerService::can_run(const Task& task) { return true; } -std::pair SchedulerService::register_worker(const std::string& worker_address, const std::string& objstore_address) { +std::pair SchedulerService::register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver) { RAY_LOG(RAY_INFO, "registering worker " << worker_address << " connected to object store " << objstore_address); ObjStoreId objstoreid = std::numeric_limits::max(); // TODO: HACK: num_attempts is a hack @@ -363,7 +364,11 @@ std::pair SchedulerService::register_worker(const std::str (*workers)[workerid].objstoreid = objstoreid; (*workers)[workerid].worker_stub = WorkerService::NewStub(channel); (*workers)[workerid].worker_address = worker_address; - (*workers)[workerid].current_task = NO_OPERATION; + if (is_driver) { + (*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers. + } else { + (*workers)[workerid].current_task = NO_OPERATION; + } } return std::make_pair(workerid, objstoreid); } diff --git a/src/scheduler.h b/src/scheduler.h index 5b907f805..aad0ff8d0 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -83,7 +83,7 @@ public: // checks if the dependencies of the task are met bool can_run(const Task& task); // register a worker and its object store (if it has not been registered yet) - std::pair register_worker(const std::string& worker_address, const std::string& objstore_address); + std::pair register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver); // register a new object with the scheduler and return its object reference ObjRef register_new_object(); // register the location of the object reference in the object table diff --git a/src/worker.cc b/src/worker.cc index d79060430..6423da92f 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -43,10 +43,11 @@ SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, return reply; } -void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address) { +void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver) { RegisterWorkerRequest request; request.set_worker_address(worker_address); request.set_objstore_address(objstore_address); + request.set_is_driver(is_driver); RegisterWorkerReply reply; ClientContext context; Status status = scheduler_stub_->RegisterWorker(&context, request, &reply); diff --git a/src/worker.h b/src/worker.h index d52332e6a..95697e78f 100644 --- a/src/worker.h +++ b/src/worker.h @@ -45,7 +45,7 @@ class Worker { // and try to resubmit the task to the scheduler up to max_retries more times. SubmitTaskReply submit_task(SubmitTaskRequest* request, int max_retries = 120, int retry_wait_milliseconds = 500); // send request to the scheduler to register this worker - void register_worker(const std::string& worker_address, const std::string& objstore_address); + void register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver); // get a new object reference that is registered with the scheduler ObjRef get_objref(); // request an object to be delivered to the local object store