From ac363bf4515b57e760c232611b85a175aa5b56ef Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 4 Aug 2016 17:47:08 -0700 Subject: [PATCH] Let worker get worker address and object store address from scheduler (#350) --- doc/using-ray-on-a-cluster.md | 2 +- lib/python/ray/services.py | 85 +++++++++++++----------------- lib/python/ray/worker.py | 42 +++++++-------- protos/ray.proto | 8 +-- scripts/cluster.py | 4 +- scripts/default_worker.py | 8 +-- src/ipc.cc | 2 +- src/raylib.cc | 23 ++++---- src/scheduler.cc | 98 +++++++++++++++++++++-------------- src/scheduler.h | 2 - src/worker.cc | 23 ++++---- src/worker.h | 8 ++- test/runtest.py | 7 +-- 13 files changed, 165 insertions(+), 147 deletions(-) diff --git a/doc/using-ray-on-a-cluster.md b/doc/using-ray-on-a-cluster.md index 2fb35e386..cd851f6a8 100644 --- a/doc/using-ray-on-a-cluster.md +++ b/doc/using-ray-on-a-cluster.md @@ -152,7 +152,7 @@ to the cluster's head node (as described by the output of the Then within a Python interpreter, run the following commands. import ray - ray.init(scheduler_address="98.76.54.321:10001", objstore_address="98.76.54.321:20001", driver_address="98.76.54.321:30001") + ray.init(node_ip_address="98.76.54.321", scheduler_address="98.76.54.321:10001") ``` 7. Note that there are several more commands that can be run from within diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index 03548e12e..95275cea1 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -14,7 +14,6 @@ _services_env["PATH"] = os.pathsep.join([os.path.dirname(os.path.abspath(__file_ # mode. all_processes = [] -IP_ADDRESS = "127.0.0.1" TIMEOUT_SECONDS = 5 def address(host, port): @@ -26,18 +25,6 @@ def new_scheduler_port(): scheduler_port_counter += 1 return 10000 + scheduler_port_counter -worker_port_counter = 0 -def new_worker_port(): - global worker_port_counter - worker_port_counter += 1 - return 40000 + worker_port_counter - -driver_port_counter = 0 -def new_driver_port(): - global driver_port_counter - driver_port_counter += 1 - return 30000 + driver_port_counter - objstore_port_counter = 0 def new_objstore_port(): global objstore_port_counter @@ -53,23 +40,23 @@ def cleanup(): started and disconnected by worker.py. """ global all_processes - for p, address in all_processes: + successfully_shut_down = True + for p in all_processes: if p.poll() is not None: # process has already terminated - print "Process at address " + address + " has already terminated." continue - print "Attempting to kill process at address " + address + "." p.kill() time.sleep(0.05) # is this necessary? if p.poll() is not None: - print "Successfully killed process at address " + address + "." continue - print "Kill attempt failed, attempting to terminate process at address " + address + "." p.terminate() time.sleep(0.05) # is this necessary? if p.poll is not None: - print "Successfully terminated process at address " + address + "." continue - print "Termination attempt failed, giving up." + successfully_shut_down = False + if successfully_shut_down: + print "Successfully shut down Ray." + else: + print "Ray did not shut down properly." all_processes = [] def start_scheduler(scheduler_address, local): @@ -83,7 +70,7 @@ def start_scheduler(scheduler_address, local): """ p = subprocess.Popen(["scheduler", scheduler_address, "--log-file-name", config.get_log_file_path("scheduler.log")], env=_services_env) if local: - all_processes.append((p, scheduler_address)) + all_processes.append(p) def start_objstore(scheduler_address, objstore_address, local): """This method starts an object store process. @@ -98,38 +85,40 @@ def start_objstore(scheduler_address, objstore_address, local): """ p = subprocess.Popen(["objstore", scheduler_address, objstore_address, "--log-file-name", config.get_log_file_path("-".join(["objstore", objstore_address]) + ".log")], env=_services_env) if local: - all_processes.append((p, objstore_address)) + all_processes.append(p) -def start_worker(worker_path, scheduler_address, objstore_address, worker_address, local, user_source_directory=None): +def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, local=True, user_source_directory=None): """This method starts a worker process. Args: + node_ip_address (str): The IP address of the node that the worker runs on. worker_path (str): The path of the source code which the worker process will run. scheduler_address (str): The ip address and port of the scheduler to connect to. - objstore_address (str): The ip address and port of the object store to - connect to. - worker_address (str): The ip address and port to use for the worker. - local (bool): True if using Ray in local mode. If local is true, then this - process will be killed by serices.cleanup() when the Python process that - imported services exits. - user_source_directory (str): The directory containing the application code. - This directory will be added to the path of each worker. If not provided, - the directory of the script currently being run is used. + objstore_address (Optional[str]): The ip address and port of the object + store to connect to. + local (Optional[bool]): True if using Ray in local mode. If local is true, + then this process will be killed by serices.cleanup() when the Python + process that imported services exits. This is True by default. + user_source_directory (Optional[str]): The directory containing the + application code. This directory will be added to the path of each worker. + If not provided, the directory of the script currently being run is used. """ if user_source_directory is None: # This extracts the directory of the script that is currently being run. # This will allow users to import modules contained in this directory. user_source_directory = os.path.dirname(os.path.abspath(os.path.join(os.path.curdir, sys.argv[0]))) - p = subprocess.Popen(["python", - worker_path, - "--user-source-directory=" + user_source_directory, - "--scheduler-address=" + scheduler_address, - "--objstore-address=" + objstore_address, - "--worker-address=" + worker_address]) + command = ["python", + worker_path, + "--node-ip-address=" + node_ip_address, + "--user-source-directory=" + user_source_directory, + "--scheduler-address=" + scheduler_address] + if objstore_address is not None: + command.append("--objstore-address=" + objstore_address) + p = subprocess.Popen(command) if local: - all_processes.append((p, worker_address)) + all_processes.append(p) def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, user_source_directory=None): """Start an object store and associated workers in the cluster setting. @@ -153,7 +142,7 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None if worker_path is None: worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py") for _ in range(num_workers): - start_worker(worker_path, scheduler_address, objstore_address, address(node_ip_address, new_worker_port()), user_source_directory=user_source_directory, local=False) + start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=objstore_address, user_source_directory=user_source_directory, local=False) time.sleep(0.5) def start_workers(scheduler_address, objstore_address, num_workers, worker_path): @@ -174,9 +163,9 @@ def start_workers(scheduler_address, objstore_address, num_workers, worker_path) """ node_ip_address = objstore_address.split(":")[0] for _ in range(num_workers): - start_worker(worker_path, scheduler_address, objstore_address, address(node_ip_address, new_worker_port()), local=False) + start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=objstore_address, local=False) -def start_ray_local(num_objstores=1, num_workers=0, worker_path=None): +def start_ray_local(node_ip_address="127.0.0.1", num_objstores=1, num_workers=0, worker_path=None): """Start Ray in local mode. This method starts Ray in local mode (as opposed to cluster mode, which is @@ -190,20 +179,19 @@ def start_ray_local(num_objstores=1, num_workers=0, worker_path=None): worker. Returns: - The address of the scheduler, the addresses of all of the object stores, and - the one new driver address for each object store. + The address of the scheduler and the addresses of all of the object stores. """ if worker_path is None: worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py") if num_workers > 0 and num_objstores < 1: raise Exception("Attempting to start a cluster with {} workers per object store, but `num_objstores` is {}.".format(num_objstores)) - scheduler_address = address(IP_ADDRESS, new_scheduler_port()) + scheduler_address = address(node_ip_address, new_scheduler_port()) start_scheduler(scheduler_address, local=True) time.sleep(0.1) objstore_addresses = [] # create objstores for i in range(num_objstores): - objstore_address = address(IP_ADDRESS, new_objstore_port()) + objstore_address = address(node_ip_address, new_objstore_port()) objstore_addresses.append(objstore_address) start_objstore(scheduler_address, objstore_address, local=True) time.sleep(0.2) @@ -214,8 +202,7 @@ def start_ray_local(num_objstores=1, num_workers=0, worker_path=None): # remaining number of workers. num_workers_to_start = num_workers - (num_objstores - 1) * (num_workers / num_objstores) for _ in range(num_workers_to_start): - start_worker(worker_path, scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port()), local=True) + start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=objstore_address, local=True) time.sleep(0.3) - driver_addresses = [address(IP_ADDRESS, new_driver_port()) for _ in range(num_objstores)] - return scheduler_address, objstore_addresses, driver_addresses + return scheduler_address, objstore_addresses diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 87ae3a76a..a089f4241 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -658,7 +658,7 @@ def register_module(module, worker=global_worker): _logger().info("registering {}.".format(val.func_name)) worker.register_function(val) -def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, objstore_address=None, driver_address=None, driver_mode=SCRIPT_MODE): +def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=SCRIPT_MODE): """Either connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -675,10 +675,9 @@ def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_ start_ray_local is True. scheduler_address (Optional[str]): The address of the scheduler to connect to if start_ray_local is False. - objstore_address (Optional[str]): The address of the object store to connect - to if start_ray_local is False. - driver_address (Optional[str]): The address of this driver if - start_ray_local is False. + node_ip_address (Optional[str]): The address of the node the worker is + running on. It is required if start_ray_local is False and it cannot be + provided otherwise. driver_mode (Optional[bool]): The mode in which to start the driver. This should be one of SCRIPT_MODE, PYTHON_MODE, and SILENT_MODE. @@ -689,28 +688,28 @@ def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_ if start_ray_local: # In this case, we launch a scheduler, a new object store, and some workers, # and we connect to them. - if (scheduler_address is not None) or (objstore_address is not None) or (driver_address is not None): - raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address, objstore_address, or worker_address.") + if (scheduler_address is not None) or (node_ip_address is not None): + raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address or a node_ip_address.") if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: raise Exception("If start_ray_local=True, then driver_mode must be in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE].") + # Use the address 127.0.0.1 in local mode. + node_ip_address = "127.0.0.1" num_workers = 1 if num_workers is None else num_workers num_objstores = 1 if num_objstores is None else num_objstores # Start the scheduler, object store, and some workers. These will be killed # by the call to cleanup(), which happens when the Python script exits. - scheduler_address, objstore_addresses, driver_addresses = services.start_ray_local(num_objstores=num_objstores, num_workers=num_workers, worker_path=None) - # It is possible for start_ray_local to return multiple object stores, but - # we will only connect the driver to one of them. - objstore_address = objstore_addresses[0] - driver_address = driver_addresses[0] + scheduler_address, _ = services.start_ray_local(num_objstores=num_objstores, num_workers=num_workers, worker_path=None) else: # In this case, there is an existing scheduler and object store, and we do # not need to start any processes. if (num_workers is not None) or (num_objstores is not None): raise Exception("The arguments num_workers and num_objstores must not be provided unless start_ray_local=True.") + if node_ip_address is None: + raise Exception("When start_ray_local=False, the node_ip_address of the current node must be provided.") # Connect this driver to the scheduler and object store. The corresponing call # to disconnect will happen in the call to cleanup() when the Python script # exits. - connect(scheduler_address, objstore_address, driver_address, is_driver=True, worker=global_worker, mode=driver_mode) + connect(node_ip_address, scheduler_address, is_driver=True, worker=global_worker, mode=driver_mode) def cleanup(worker=global_worker): """Disconnect the driver, and terminate any processes started in init. @@ -726,14 +725,15 @@ def cleanup(worker=global_worker): atexit.register(cleanup) -def connect(scheduler_address, objstore_address, worker_address, is_driver=False, worker=global_worker, mode=WORKER_MODE): +def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=WORKER_MODE): """Connect this worker to the scheduler and an object store. Args: + node_ip_address (str): The ip address of the node the worker runs on. scheduler_address (str): The ip address and port of the scheduler. - objstore_address (str): The ip address and port of the local object store. - worker_address (str): The ip address and port of this worker. The port can - be chosen arbitrarily. + objstore_address (Optional[str]): The ip address and port of the local + object store. Normally, this argument should be omitted and the scheduler + will tell the worker what object store to connect to. is_driver (bool): True if this worker is a driver and false otherwise. mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and SILENT_MODE. @@ -741,22 +741,20 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False if hasattr(worker, "handle"): del worker.handle worker.scheduler_address = scheduler_address - worker.objstore_address = objstore_address - worker.worker_address = worker_address - worker.handle = raylib.create_worker(worker.scheduler_address, worker.objstore_address, worker.worker_address, is_driver) + worker.handle, worker.worker_address = raylib.create_worker(node_ip_address, scheduler_address, objstore_address if objstore_address is not None else "", is_driver) worker.set_mode(mode) FORMAT = "%(asctime)-15s %(message)s" # Configure the Python logging module. Note that if we do not provide our own # logger, then our logging will interfere with other Python modules that also # use the logging module. - log_handler = logging.FileHandler(config.get_log_file_path("-".join(["worker", worker_address]) + ".log")) + log_handler = logging.FileHandler(config.get_log_file_path("-".join(["worker", worker.worker_address]) + ".log")) log_handler.setLevel(logging.DEBUG) log_handler.setFormatter(logging.Formatter(FORMAT)) _logger().addHandler(log_handler) _logger().setLevel(logging.DEBUG) _logger().propagate = False # Configure the logging from the worker C++ code. - raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker_address, "c++"]) + ".log")) + raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker.worker_address, "c++"]) + ".log")) if mode in [SCRIPT_MODE, SILENT_MODE]: for function_to_export in worker.cached_remote_functions: raylib.export_function(worker.handle, function_to_export) diff --git a/protos/ray.proto b/protos/ray.proto index a7379aec1..6906c5cf2 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -62,14 +62,16 @@ message AckReply { } message RegisterWorkerRequest { - string worker_address = 1; // IP address of the worker being registered - string objstore_address = 2; // IP address of the object store the worker is connected to - bool is_driver = 3; // True if the worker is a driver, and false otherwise + string node_ip_address = 1; // The IP address of the node the worker is running on. + string objstore_address = 2; // The address of the object store the worker should connect to. If omitted, this will be assigned by the scheduler. + bool is_driver = 3; // True if the worker is a driver, and false otherwise. } message RegisterWorkerReply { uint64 workerid = 1; // Worker ID assigned by the scheduler uint64 objstoreid = 2; // The Object store ID of the worker's local object store + string worker_address = 3; // IP address of the worker being registered + string objstore_address = 4; // IP address of the object store the worker should connect to } message RegisterObjStoreRequest { diff --git a/scripts/cluster.py b/scripts/cluster.py index 8d656d68a..ccd379b5b 100644 --- a/scripts/cluster.py +++ b/scripts/cluster.py @@ -185,8 +185,8 @@ class RayCluster(object): Then within a Python interpreter or script, run the following commands. import ray - ray.init(scheduler_address="{}:10001", objstore_address="{}:20001", driver_address="{}:30001") - """.format(self.key_file, self.username, self.node_ip_addresses[0], cd_location, setup_env_path, self.node_private_ip_addresses[0], self.node_private_ip_addresses[0], self.node_private_ip_addresses[0]) + ray.init(node_ip_address="{}", scheduler_address="{}:10001") + """.format(self.key_file, self.username, self.node_ip_addresses[0], cd_location, setup_env_path, self.node_private_ip_addresses[0], self.node_private_ip_addresses[0]) def stop_ray(self): """Kill all of the processes in the Ray cluster. diff --git a/scripts/default_worker.py b/scripts/default_worker.py index a24ce223b..3c2706f2e 100644 --- a/scripts/default_worker.py +++ b/scripts/default_worker.py @@ -6,9 +6,9 @@ import ray parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.") parser.add_argument("--user-source-directory", type=str, help="the directory containing the user's application code") -parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address") -parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") -parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, help="the worker's address") +parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node") +parser.add_argument("--scheduler-address", required=True, type=str, help="the scheduler's address") +parser.add_argument("--objstore-address", type=str, help="the objstore's address") if __name__ == "__main__": args = parser.parse_args() @@ -18,6 +18,6 @@ if __name__ == "__main__": # insert into the first position (as opposed to the zeroth) because the # zeroth position is reserved for the empty string. sys.path.insert(1, args.user_source_directory) - ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address) + ray.worker.connect(args.node_ip_address, args.scheduler_address) ray.worker.main_loop() diff --git a/src/ipc.cc b/src/ipc.cc index 12ec411d2..dc5bcbe1d 100644 --- a/src/ipc.cc +++ b/src/ipc.cc @@ -52,7 +52,7 @@ bool MessageQueue<>::connect(const std::string& name, bool create, size_t messag } } catch (bip::interprocess_exception &ex) { - RAY_CHECK(false, "boost::interprocess exception: " << ex.what()); + RAY_CHECK(false, "name = " << name_ << ", create = " << create << ", boost::interprocess exception: " << ex.what()); } return true; } diff --git a/src/raylib.cc b/src/raylib.cc index bdc7a6cf4..0b4849b6c 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -642,19 +642,24 @@ static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) { // Ray Python API static PyObject* create_worker(PyObject* self, PyObject* args) { - const char* scheduler_addr; - const char* objstore_addr; - const char* worker_addr; + const char* node_ip_address; + const char* scheduler_address; + // The object store address can be the empty string, in which case the + // scheduler will choose the object store address. + const char* objstore_address; PyObject* is_driver_obj; - if (!PyArg_ParseTuple(args, "sssO", &scheduler_addr, &objstore_addr, &worker_addr, &is_driver_obj)) { + if (!PyArg_ParseTuple(args, "sssO", &node_ip_address, &scheduler_address, &objstore_address, &is_driver_obj)) { return NULL; } bool is_driver = PyObject_IsTrue(is_driver_obj); - auto scheduler_channel = grpc::CreateChannel(scheduler_addr, grpc::InsecureChannelCredentials()); - auto objstore_channel = grpc::CreateChannel(objstore_addr, grpc::InsecureChannelCredentials()); - Worker* worker = new Worker(std::string(worker_addr), scheduler_channel, objstore_channel); - worker->register_worker(std::string(worker_addr), std::string(objstore_addr), is_driver); - return PyCapsule_New(static_cast(worker), "worker", &WorkerCapsule_Destructor); + Worker* worker = new Worker(std::string(scheduler_address)); + worker->register_worker(std::string(node_ip_address), std::string(objstore_address), is_driver); + + PyObject* t = PyTuple_New(2); + PyObject* worker_capsule = PyCapsule_New(static_cast(worker), "worker", &WorkerCapsule_Destructor); + PyTuple_SetItem(t, 0, worker_capsule); + PyTuple_SetItem(t, 1, PyString_FromString(worker->get_worker_address())); + return t; } static PyObject* disconnect(PyObject* self, PyObject* args) { diff --git a/src/scheduler.cc b/src/scheduler.cc index ac875c47a..085e94af2 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -215,12 +215,66 @@ Status SchedulerService::RegisterObjStore(ServerContext* context, const Register } Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) { - std::pair info = register_worker(request->worker_address(), request->objstore_address(), request->is_driver()); - WorkerId workerid = info.first; - ObjStoreId objstoreid = info.second; - RAY_LOG(RAY_INFO, "registered worker with workerid " << workerid); + std::string objstore_address = request->objstore_address(); + std::string node_ip_address = request->node_ip_address(); + bool is_driver = request->is_driver(); + RAY_LOG(RAY_INFO, "Registering a worker from node with IP address " << node_ip_address); + // Find the object store to connect to. We use the max size to indicate that + // the object store for this worker has not been found. + ObjStoreId objstoreid = std::numeric_limits::max(); + // TODO: HACK: num_attempts is a hack + for (int num_attempts = 0; num_attempts < 30; ++num_attempts) { + auto objstores = GET(objstores_); + for (size_t i = 0; i < objstores->size(); ++i) { + if (objstore_address != "" && (*objstores)[i].address == objstore_address) { + // This object store address is the same as the provided object store + // address. + objstoreid = i; + } + if ((*objstores)[i].address.compare(0, node_ip_address.size(), node_ip_address) == 0) { + // The object store address was not provided and this object store + // address has node_ip_address as a prefix, so it is on the same machine + // as the worker that is registering. + objstoreid = i; + objstore_address = (*objstores)[i].address; + } + } + if (objstoreid == std::numeric_limits::max()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } else { + break; + } + } + if (objstore_address.empty()) { + RAY_CHECK_NEQ(objstoreid, std::numeric_limits::max(), "No object store with IP address " << node_ip_address << " has registered."); + } else { + RAY_CHECK_NEQ(objstoreid, std::numeric_limits::max(), "Object store with address " << objstore_address << " not yet registered."); + } + // Populate the worker information and generate a worker address. + WorkerId workerid; + std::string worker_address; + { + auto workers = GET(workers_); + workerid = workers->size(); + worker_address = node_ip_address + ":" + std::to_string(40000 + workerid); + workers->push_back(WorkerHandle()); + auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials()); + (*workers)[workerid].channel = channel; + (*workers)[workerid].objstoreid = objstoreid; + (*workers)[workerid].worker_stub = WorkerService::NewStub(channel); + (*workers)[workerid].worker_address = worker_address; + (*workers)[workerid].initialized = false; + if (is_driver) { + (*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers. + } else { + (*workers)[workerid].current_task = NO_OPERATION; + } + } + RAY_LOG(RAY_INFO, "Finished registering worker with workerid " << workerid << ", worker address " << worker_address << " on node with IP address " << node_ip_address << ", is_driver = " << is_driver << ", assigned to object store with id " << objstoreid << " and address " << objstore_address); reply->set_workerid(workerid); reply->set_objstoreid(objstoreid); + reply->set_worker_address(worker_address); + reply->set_objstore_address(objstore_address); schedule(); return Status::OK; } @@ -540,42 +594,6 @@ bool SchedulerService::can_run(const Task& task) { return true; } -std::pair SchedulerService::register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver) { - RAY_LOG(RAY_INFO, "registering worker " << worker_address << " connected to object store " << objstore_address); - ObjStoreId objstoreid = std::numeric_limits::max(); - // TODO: HACK: num_attempts is a hack - for (int num_attempts = 0; num_attempts < 30; ++num_attempts) { - auto objstores = GET(objstores_); - for (size_t i = 0; i < objstores->size(); ++i) { - if ((*objstores)[i].address == objstore_address) { - objstoreid = i; - } - } - if (objstoreid == std::numeric_limits::max()) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - } - RAY_CHECK_NEQ(objstoreid, std::numeric_limits::max(), "object store with address " << objstore_address << " not yet registered"); - WorkerId workerid; - { - auto workers = GET(workers_); - workerid = workers->size(); - workers->push_back(WorkerHandle()); - auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials()); - (*workers)[workerid].channel = channel; - (*workers)[workerid].objstoreid = objstoreid; - (*workers)[workerid].worker_stub = WorkerService::NewStub(channel); - (*workers)[workerid].worker_address = worker_address; - (*workers)[workerid].initialized = false; - if (is_driver) { - (*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers. - } else { - (*workers)[workerid].current_task = NO_OPERATION; - } - } - return std::make_pair(workerid, objstoreid); -} - ObjectID SchedulerService::register_new_object() { // If we don't simultaneously lock objtable_ and target_objectids_, we will probably get errors. // TODO(rkn): increment/decrement_reference_count also acquire reference_counts_lock_ and target_objectids_lock_ (through has_canonical_objectid()), which caused deadlock in the past diff --git a/src/scheduler.h b/src/scheduler.h index f8136a5dd..7bec6410a 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -99,8 +99,6 @@ public: void assign_task(OperationId operationid, WorkerId workerid, const MySynchronizedPtr &computation_graph); // checks if the dependencies of the task are met bool can_run(const Task& task); - // register a worker and its object store (if it has not been registered yet) - std::pair register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver); // register a new object with the scheduler and return its object ID ObjectID register_new_object(); // register the location of the object ID in the object table diff --git a/src/worker.cc b/src/worker.cc index b0886c29b..4a4589274 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -56,11 +56,10 @@ Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, return Status::OK; } -Worker::Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel) - : worker_address_(worker_address), - scheduler_stub_(Scheduler::NewStub(scheduler_channel)) { - RAY_CHECK(receive_queue_.connect(worker_address_, true), "error connecting receive_queue_"); - connected_ = true; +Worker::Worker(const std::string& scheduler_address) + : scheduler_address_(scheduler_address) { + auto scheduler_channel = grpc::CreateChannel(scheduler_address, grpc::InsecureChannelCredentials()); + scheduler_stub_ = Scheduler::NewStub(scheduler_channel); } SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, int retry_wait_milliseconds) { @@ -87,10 +86,12 @@ bool Worker::kill_workers(ClientContext &context) { return reply.success(); } -void Worker::register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver) { +void Worker::register_worker(const std::string& node_ip_address, const std::string& objstore_address, bool is_driver) { unsigned int retry_wait_milliseconds = 20; RegisterWorkerRequest request; - request.set_worker_address(worker_address); + request.set_node_ip_address(node_ip_address); + // The object store address can be the empty string, in which case the + // scheduler will assign an object store address. request.set_objstore_address(objstore_address); request.set_is_driver(is_driver); RegisterWorkerReply reply; @@ -108,9 +109,13 @@ void Worker::register_worker(const std::string& worker_address, const std::strin } workerid_ = reply.workerid(); objstoreid_ = reply.objstoreid(); + objstore_address_ = reply.objstore_address(); + worker_address_ = reply.worker_address(); segmentpool_ = std::make_shared(objstoreid_, false); - RAY_CHECK(request_obj_queue_.connect(std::string("queue:") + objstore_address + std::string(":obj"), false), "error connecting request_obj_queue_"); - RAY_CHECK(receive_obj_queue_.connect(std::string("queue:") + objstore_address + std::string(":worker:") + std::to_string(workerid_) + std::string(":obj"), true), "error connecting receive_obj_queue_"); + RAY_CHECK(receive_queue_.connect(worker_address_, true), "error connecting receive_queue_"); + RAY_CHECK(request_obj_queue_.connect(std::string("queue:") + objstore_address_ + std::string(":obj"), false), "error connecting request_obj_queue_"); + RAY_CHECK(receive_obj_queue_.connect(std::string("queue:") + objstore_address_ + std::string(":worker:") + std::to_string(workerid_) + std::string(":obj"), true), "error connecting receive_obj_queue_"); + connected_ = true; return; } diff --git a/src/worker.h b/src/worker.h index d49f99045..ff68abee6 100644 --- a/src/worker.h +++ b/src/worker.h @@ -37,7 +37,7 @@ private: class Worker { public: - Worker(const std::string& worker_address, std::shared_ptr scheduler_channel, std::shared_ptr objstore_channel); + Worker(const std::string& scheduler_address); // Submit a remote task to the scheduler. If the function in the task is not // registered with the scheduler, we will sleep for retry_wait_milliseconds @@ -46,7 +46,7 @@ class Worker { // Requests the scheduler to kill workers bool kill_workers(ClientContext &context); // send request to the scheduler to register this worker - void register_worker(const std::string& worker_address, const std::string& objstore_address, bool is_driver); + void register_worker(const std::string& ip_address, const std::string& objstore_address, bool is_driver); // get a new object ID that is registered with the scheduler ObjectID get_objectid(); // request an object to be delivered to the local object store @@ -94,6 +94,8 @@ class Worker { bool export_function(const std::string& function); // export reusable variable to workers void export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer); + // return the worker address + const char* get_worker_address() { return worker_address_.c_str(); } private: bool connected_; @@ -104,6 +106,8 @@ class Worker { bip::managed_shared_memory segment_; WorkerId workerid_; ObjStoreId objstoreid_; + std::string scheduler_address_; + std::string objstore_address_; std::string worker_address_; MessageQueue request_obj_queue_; MessageQueue receive_obj_queue_; diff --git a/test/runtest.py b/test/runtest.py index 7dd707cfc..d29127b4a 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -81,12 +81,13 @@ class ObjStoreTest(unittest.TestCase): # Test setting up object stores, transfering data between them and retrieving data to a client def testObjStore(self): - scheduler_address, objstore_addresses, driver_addresses = ray.services.start_ray_local(num_objstores=2, num_workers=0, worker_path=None) + scheduler_address, objstore_addresses = ray.services.start_ray_local(num_objstores=2, num_workers=0, worker_path=None) w1 = ray.worker.Worker() w2 = ray.worker.Worker() - ray.connect(scheduler_address, objstore_addresses[0], driver_addresses[0], is_driver=True, mode=ray.SCRIPT_MODE, worker=w1) + node_ip_address = "127.0.0.1" + ray.connect(node_ip_address, scheduler_address, objstore_addresses[0], is_driver=True, mode=ray.SCRIPT_MODE, worker=w1) ray.reusables._cached_reusables = [] # This is a hack to make the test run. - ray.connect(scheduler_address, objstore_addresses[1], driver_addresses[1], is_driver=True, mode=ray.SCRIPT_MODE, worker=w2) + ray.connect(node_ip_address, scheduler_address, objstore_addresses[1], is_driver=True, mode=ray.SCRIPT_MODE, worker=w2) # putting and getting an object shouldn't change it for data in ["h", "h" * 10000, 0, 0.0]: