From 87bb7a8f67069d022104d51d5150b9f48d324203 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Mon, 15 Aug 2016 11:02:54 -0700 Subject: [PATCH] [WIP] Large changes to make the tests pass. (#376) * Revert "Make tests more informative (#372)" This reverts commit fd353250c8ff6214c2dfbca6872e2265e94f7e02. * fix bugs, in particular deactivate worker service on driver and remove condition variables * changes to minimize the changes in this PR * switch from faulty mutex synchronization to using atomics * Increase the default size of the message queues, to accommodate exporting large numbers of remote functions. This is a temporary fix, but not a long term solution. * Reorganize the scheduler export code to queue up exports. This does not solve the underlying problem yet, but sets up a solution. * Start a separate thread on driver to print error messages by constantly querying the scheduler. This is a temporary solution because the solution based on starting a worker service for the driver which the scheduler can push error messages to is buggy. * Fix segfault in taskcapsule destructor. * Move tests for catching errors into a separate test file. * Revert "roll back grpc (#368)" This reverts commit c01ef95d04c9088cac226e9715710ae38b3a610a. --- .travis.yml | 1 + lib/python/ray/services.py | 17 +- lib/python/ray/worker.py | 52 ++++- src/ipc.h | 2 +- src/objstore.cc | 74 +++---- src/objstore.h | 9 +- src/raylib.cc | 28 +-- src/scheduler.cc | 96 ++++++--- src/scheduler.h | 31 ++- src/worker.cc | 94 +++++---- src/worker.h | 41 ++-- test/array_test.py | 50 ++--- test/failure_test.py | 145 ++++++++++++++ test/microbenchmarks.py | 14 +- test/runtest.py | 386 ++++++++++++++++--------------------- thirdparty/grpc | 2 +- 16 files changed, 602 insertions(+), 440 deletions(-) create mode 100644 test/failure_test.py diff --git a/.travis.yml b/.travis.yml index 38888e977..18838f9a3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,6 +18,7 @@ install: script: - ./test/travis-ci/run_test.sh --docker-image=amplab/ray:test-base 'source setup-env.sh && cd test && python runtest.py' - ./test/travis-ci/run_test.sh --docker-image=amplab/ray:test-base 'source setup-env.sh && cd test && python array_test.py' + - ./test/travis-ci/run_test.sh --docker-image=amplab/ray:test-base 'source setup-env.sh && cd test && python failure_test.py' - ./test/travis-ci/run_test.sh --docker-image=amplab/ray:test-base 'source setup-env.sh && cd test && python microbenchmarks.py' - ./test/travis-ci/run_test.sh --docker-only --shm-size=500m --docker-image=amplab/ray:test-examples 'source setup-env.sh && cd examples/hyperopt && python driver.py' - ./test/travis-ci/run_test.sh --docker-only --shm-size=500m --docker-image=amplab/ray:test-examples 'source setup-env.sh && cd examples/lbfgs && python driver.py' diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index 841c398d9..7b21cfcc6 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -2,7 +2,8 @@ import os import sys import time import subprocess32 as subprocess -import numpy as np +import string +import random # Ray modules import config @@ -21,7 +22,7 @@ def address(host, port): return host + ":" + str(port) def new_scheduler_port(): - return np.random.randint(10000, 65536) + return random.randint(10000, 65535) def cleanup(): """When running in local mode, shutdown the Ray processes. @@ -72,16 +73,16 @@ def start_objstore(scheduler_address, node_ip_address, cleanup): scheduler_address (str): The ip address and port of the scheduler to connect to. node_ip_address (str): The ip address of the node running the object store. - The object store's port number will be chosen by the object store process. cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. """ - p = subprocess.Popen(["objstore", scheduler_address, node_ip_address, "--log-file-prefix", config.get_log_file_path("")], env=_services_env) + random_string = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + p = subprocess.Popen(["objstore", scheduler_address, node_ip_address, "--log-file-name", config.get_log_file_path("-".join(["objstore", random_string]) + ".log")], env=_services_env) if cleanup: all_processes.append(p) -def start_worker(node_ip_address, worker_path, scheduler_address, cleanup=True, user_source_directory=None): +def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True, user_source_directory=None): """This method starts a worker process. Args: @@ -90,6 +91,8 @@ def start_worker(node_ip_address, worker_path, scheduler_address, cleanup=True, run. scheduler_address (str): The ip address and port of the scheduler to connect to. + objstore_address (Optional[str]): The ip address and port of the object + store to connect to. cleanup (Optional[bool]): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. This is True by default. @@ -106,6 +109,8 @@ def start_worker(node_ip_address, worker_path, scheduler_address, cleanup=True, "--node-ip-address=" + node_ip_address, "--user-source-directory=" + user_source_directory, "--scheduler-address=" + scheduler_address] + if objstore_address is not None: + command.append("--objstore-address=" + objstore_address) p = subprocess.Popen(command) if cleanup: all_processes.append(p) @@ -155,7 +160,7 @@ def start_workers(scheduler_address, objstore_address, num_workers, worker_path) """ node_ip_address = objstore_address.split(":")[0] for _ in range(num_workers): - start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=objstore_address, cleanup=False) + start_worker(node_ip_address, worker_path, scheduler_address, cleanup=False) def start_ray_local(node_ip_address="127.0.0.1", num_objstores=1, num_workers=0, worker_path=None): """Start Ray in local mode. diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 3dbb00ac8..39dab8f88 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -9,6 +9,7 @@ import funcsigs import numpy as np import colorama import atexit +import threading # Ray modules import config @@ -368,9 +369,6 @@ class Worker(object): eventually does call connect, if it is a driver, it will export these functions to the scheduler. If cached_remote_functions is None, that means that connect has been called already. - num_failed_tasks (int): The number of tasks that have failed and whose error - messages have been displayed to the user. We use this value to know when - a failed task hasn't been seen by the user and should be displayed. """ def __init__(self): @@ -379,7 +377,6 @@ class Worker(object): self.handle = None self.mode = None self.cached_remote_functions = [] - self.num_failed_tasks = 0 def set_mode(self, mode): """Set the mode of the worker. @@ -538,6 +535,9 @@ made by one task do not affect other tasks. logger = logging.getLogger("ray") """Logger: The logging object for the Python worker code.""" +class RayConnectionError(Exception): + pass + def check_connected(worker=global_worker): """Check if the worker is connected. @@ -545,7 +545,7 @@ def check_connected(worker=global_worker): Exception: An exception is raised if the worker is not connected. """ if worker.handle is None and worker.mode != raylib.PYTHON_MODE: - raise Exception("This command cannot be called before a Ray cluster has been started. You can start one with 'ray.init(start_ray_local=True, num_workers=1)'.") + raise RayConnectionError("This command cannot be called before a Ray cluster has been started. You can start one with 'ray.init(start_ray_local=True, num_workers=1)'.") def print_failed_task(task_status): """Print information about failed tasks. @@ -678,6 +678,37 @@ def cleanup(worker=global_worker): atexit.register(cleanup) +def print_error_messages(worker=global_worker): + num_failed_tasks = 0 + num_failed_remote_function_imports = 0 + num_failed_reusable_variable_imports = 0 + num_failed_reusable_variable_reinitializations = 0 + while True: + try: + info = task_info(worker=worker) + # Print failed task errors. + for error in info["failed_tasks"][num_failed_tasks:]: + print error["error_message"] + num_failed_tasks = len(info["failed_tasks"]) + # Print remote function import errors. + for error in info["failed_remote_function_imports"][num_failed_remote_function_imports:]: + print error["error_message"] + num_failed_remote_function_imports = len(info["failed_remote_function_imports"]) + # Print reusable variable import errors. + for error in info["failed_reusable_variable_imports"][num_failed_reusable_variable_imports:]: + print error["error_message"] + num_failed_reusable_variable_imports = len(info["failed_reusable_variable_imports"]) + # Print reusable variable reinitialization errors. + for error in info["failed_reinitialize_reusable_variables"][num_failed_reusable_variable_reinitializations:]: + print error["error_message"] + num_failed_reusable_variable_reinitializations = len(info["failed_reinitialize_reusable_variables"]) + except RayConnectionError: + # When the driver is exiting, we set worker.handle to None, which will cause + # the check_connected call inside of task_info to raise an exception. We use + # the try block here to suppress that exception. + pass + time.sleep(0.2) + def connect(node_ip_address, scheduler_address, objstore_address=None, worker=global_worker, mode=raylib.WORKER_MODE): """Connect this worker to the scheduler and an object store. @@ -702,6 +733,17 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl # receive commands from the scheduler. This call also sets up a queue between # the worker and the worker service. worker.handle, worker.worker_address = raylib.create_worker(node_ip_address, scheduler_address, objstore_address if objstore_address is not None else "", mode) + # If this is a driver running in SCRIPT_MODE, start a thread to print error + # messages asynchronously in the background. Ideally the scheduler would push + # messages to the driver's worker service, but we ran into bugs when trying to + # properly shutdown the driver's worker service, so we are temporarily using + # this implementation which constantly queries the scheduler for new error + # messages. + if mode == raylib.SCRIPT_MODE: + t = threading.Thread(target=print_error_messages, args=(worker,)) + # Making the thread a daemon causes it to exit when the main thread exits. + t.daemon = True + t.start() worker.set_mode(mode) FORMAT = "%(asctime)-15s %(message)s" # Configure the Python logging module. Note that if we do not provide our own diff --git a/src/ipc.h b/src/ipc.h index 6ce42f4f9..03300f4ba 100644 --- a/src/ipc.h +++ b/src/ipc.h @@ -55,7 +55,7 @@ private: template class MessageQueue : public MessageQueue<> { public: - bool connect(const std::string& name, bool create, size_t capacity = 100) { return MessageQueue<>::connect(name, create, sizeof(T), capacity); } + bool connect(const std::string& name, bool create, size_t capacity = 1000) { return MessageQueue<>::connect(name, create, sizeof(T), capacity); } bool send(const T* object) { return MessageQueue<>::send(object, sizeof(*object)); }; bool receive(T* object) { return MessageQueue<>::receive(object, sizeof(*object)); } }; diff --git a/src/objstore.cc b/src/objstore.cc index 76721555d..c292d8258 100644 --- a/src/objstore.cc +++ b/src/objstore.cc @@ -39,24 +39,20 @@ void ObjStoreService::get_data_from(ObjectID objectid, ObjStore::Stub& stub) { RAY_LOG(RAY_DEBUG, "finished streaming data, objectid was " << objectid << " and size was " << num_bytes); } -ObjStoreService::ObjStoreService(const std::string& scheduler_address) - : scheduler_address_(scheduler_address) { +ObjStoreService::ObjStoreService(std::shared_ptr scheduler_channel) + : scheduler_stub_(Scheduler::NewStub(scheduler_channel)) { } -void ObjStoreService::register_objstore() { - RAY_CHECK(!objstore_address_.empty(), "The object store address must be set before register_objstore is called."); - // Create the scheduler stub. - auto scheduler_channel = grpc::CreateChannel(scheduler_address_, grpc::InsecureChannelCredentials()); - scheduler_stub_ = Scheduler::NewStub(scheduler_channel); - - // Create message queue to receive requests from workers. - std::string recv_queue_name = std::string("queue:") + objstore_address_ + std::string(":obj"); - RAY_LOG(RAY_INFO, "Object store creating queue with name " << recv_queue_name << " to receive requests from workers."); +void ObjStoreService::register_objstore(const std::string& objstore_address, const std::string& recv_queue_name) { + // Create the queue that will be used by workers to send requests to the + // object store. + RAY_LOG(RAY_INFO, "Object store is creating queue with name " << recv_queue_name); RAY_CHECK(recv_queue_.connect(recv_queue_name, true), "error connecting recv_queue_"); - // Register the objecet store with the scheduler. + objstore_address_ = objstore_address; + // Register the object store with the scheduler. ClientContext context; RegisterObjStoreRequest request; - request.set_objstore_address(objstore_address_); + request.set_objstore_address(objstore_address); RegisterObjStoreReply reply; RAY_CHECK_GRPC(scheduler_stub_->RegisterObjStore(&context, request, &reply)); objstoreid_ = reply.objstoreid(); @@ -331,40 +327,26 @@ void ObjStoreService::start_objstore_service() { }); } -void set_logfile(const char* log_file_prefix, const std::string& node_ip_address, int port) { - if (log_file_prefix) { - std::string log_file_name = std::string(log_file_prefix) + "objstore-" + node_ip_address + "-" + std::to_string(port) + ".log"; - create_log_dir_or_die(log_file_name.c_str()); - global_ray_config.log_to_file = true; - global_ray_config.logfile.open(log_file_name); - } else { - std::cout << "object store: writing logs to stdout; you can change this by passing --log-file-prefix to ./objstore" << std::endl; - global_ray_config.log_to_file = false; - } -} - -void start_objstore(const std::string& scheduler_address, const std::string& node_ip_address, const char* log_file_prefix) { - // Initialize the object store. - ObjStoreService service(scheduler_address); - int port; +void start_objstore(const char* scheduler_addr, const char* node_ip_address) { + RAY_LOG(RAY_INFO, "Starting an object store on node " << std::string(node_ip_address)); + auto scheduler_channel = grpc::CreateChannel(scheduler_addr, grpc::InsecureChannelCredentials()); + RAY_LOG(RAY_INFO, "Object store connected to scheduler " << scheduler_addr); + ObjStoreService service(scheduler_channel); ServerBuilder builder; // Get GRPC to assign an unused port. + int port; builder.AddListeningPort(std::string("0.0.0.0:0"), grpc::InsecureServerCredentials(), &port); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); if (server == nullptr) { - RAY_CHECK(false, "Failed to create the object store server.") + RAY_CHECK(false, "Failed to create the object store service."); } - // Set the object store address. - service.set_objstore_address(node_ip_address + ":" + std::to_string(port)); - // Set the logfile. - set_logfile(log_file_prefix, node_ip_address, port); - // Register the object store with the scheduler. - service.register_objstore(); - // Launch a thread to process incoming messages in the message queue from - // the workers. + std::string objstore_address = std::string(node_ip_address) + ":" + std::to_string(port); + RAY_LOG(RAY_INFO, "This object store has address " << objstore_address); + std::string recv_queue_name = std::string("queue:") + objstore_address + std::string(":obj"); + service.register_objstore(objstore_address, recv_queue_name); service.start_objstore_service(); - // Process incoming GRPC calls. These may come from the schedeler or from + // Process incoming GRPC calls. These may come from the scheduler or from // other object stores. This method does not return. server->Wait(); } @@ -374,12 +356,20 @@ RayConfig global_ray_config; int main(int argc, char** argv) { RAY_CHECK_GE(argc, 3, "object store: expected at least two arguments (scheduler ip address and object store ip address)"); - const char* log_file_prefix = nullptr; if (argc > 3) { - log_file_prefix = get_cmd_option(argv, argv + argc, "--log-file-prefix"); + const char* log_file_name = get_cmd_option(argv, argv + argc, "--log-file-name"); + if (log_file_name) { + std::cout << "object store: writing to log file " << log_file_name << std::endl; + create_log_dir_or_die(log_file_name); + global_ray_config.log_to_file = true; + global_ray_config.logfile.open(log_file_name); + } else { + std::cout << "object store: writing logs to stdout; you can change this by passing --log-file-name to ./scheduler" << std::endl; + global_ray_config.log_to_file = false; + } } - start_objstore(argv[1], argv[2], log_file_prefix); + start_objstore(argv[1], argv[2]); return 0; } diff --git a/src/objstore.h b/src/objstore.h index e35566c9c..351b09068 100644 --- a/src/objstore.h +++ b/src/objstore.h @@ -37,12 +37,7 @@ enum MemoryStatusType {READY = 0, NOT_READY = 1, DEALLOCATED = 2, NOT_PRESENT = class ObjStoreService final : public ObjStore::Service { public: - ObjStoreService(const std::string& scheduler_address); - // Create the scheduler stub, register the object store with the scheduler, - // and create a message queue for workers to connect to. - void register_objstore(); - // Set the object store address. - void set_objstore_address(const std::string& objstore_address) { objstore_address_ = objstore_address; } + ObjStoreService(std::shared_ptr scheduler_channel); Status StartDelivery(ServerContext* context, const StartDeliveryRequest* request, AckReply* reply) override; Status StreamObjTo(ServerContext* context, const StreamObjToRequest* request, ServerWriter* writer) override; @@ -50,6 +45,7 @@ public: Status DeallocateObject(ServerContext* context, const DeallocateObjectRequest* request, AckReply* reply) override; Status ObjStoreInfo(ServerContext* context, const ObjStoreInfoRequest* request, ObjStoreInfoReply* reply) override; void start_objstore_service(); + void register_objstore(const std::string& objstore_address, const std::string& recv_queue_name); private: void get_data_from(ObjectID objectid, ObjStore::Stub& stub); // check if we already connected to the other objstore, if yes, return reference to connection, otherwise connect @@ -62,7 +58,6 @@ private: void object_ready(ObjectID objectid, size_t metadata_offset); static const size_t CHUNK_SIZE; - std::string scheduler_address_; std::string objstore_address_; ObjStoreId objstoreid_; // id of this objectstore in the scheduler object store table std::shared_ptr segmentpool_; diff --git a/src/raylib.cc b/src/raylib.cc index a16786741..de7f1ab7c 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -288,7 +288,7 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vectormutable_string_data()->set_data(buffer, length); + obj->mutable_string_data()->set_data(std::string(buffer, length)); } else if (PyUnicode_Check(val)) { Py_ssize_t length; #if PY_MAJOR_VERSION >= 3 @@ -298,7 +298,7 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vectormutable_unicode_data()->set_data(data, length); + obj->mutable_unicode_data()->set_data(std::string(data, length)); Py_XDECREF(str); } else if (val == Py_None) { obj->mutable_empty_data(); // allocate an Empty object, this is a None @@ -584,7 +584,7 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) { if (!PyArg_ParseTuple(args, "Os#O", &worker_capsule, &name, &len, &arguments)) { return NULL; } - task->set_name(name, len); + task->set_name(std::string(name, len)); std::vector objectids; // This is a vector of all the objectids that are serialized in this task, including objectids that are contained in Python objects that are passed by value. if (PyList_Check(arguments)) { for (size_t i = 0, size = PyList_Size(arguments); i < size; ++i) { @@ -665,12 +665,12 @@ static PyObject* create_worker(PyObject* self, PyObject* args) { // The object store address can be the empty string, in which case the // scheduler will choose the object store address. const char* objstore_address; - Mode mode; + int mode; if (!PyArg_ParseTuple(args, "sssi", &node_ip_address, &scheduler_address, &objstore_address, &mode)) { return NULL; } bool is_driver = (mode != Mode::WORKER_MODE); - Worker* worker = new Worker(std::string(node_ip_address), std::string(scheduler_address), mode); + Worker* worker = new Worker(std::string(node_ip_address), std::string(scheduler_address), static_cast(mode)); worker->register_worker(std::string(node_ip_address), std::string(objstore_address), is_driver); PyObject* t = PyTuple_New(2); @@ -785,6 +785,7 @@ static PyObject* submit_task(PyObject* self, PyObject* args) { request.set_allocated_task(task); SubmitTaskReply reply = worker->submit_task(&request); if (!reply.function_registered()) { + request.release_task(); PyErr_SetString(RayError, "task: function not registered"); return NULL; } @@ -824,11 +825,11 @@ static PyObject* notify_failure(PyObject* self, PyObject* args) { Worker* worker; const char* name; const char* error_message; - FailedType type; + int type; if (!PyArg_ParseTuple(args, "O&ssi", &PyObjectToWorker, &worker, &name, &error_message, &type)) { return NULL; } - worker->notify_failure(type, std::string(name), std::string(error_message)); + worker->notify_failure(static_cast(type), std::string(name), std::string(error_message)); Py_RETURN_NONE; } @@ -900,16 +901,6 @@ static PyObject* alias_objectids(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* start_worker_service(PyObject* self, PyObject* args) { - Worker* worker; - Mode mode; - if (!PyArg_ParseTuple(args, "O&i", &PyObjectToWorker, &worker, &mode)) { - return NULL; - } - worker->start_worker_service(mode); - Py_RETURN_NONE; -} - static PyObject* scheduler_info(PyObject* self, PyObject* args) { Worker* worker; if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) { @@ -1077,8 +1068,7 @@ static PyMethodDef RayLibMethods[] = { { "alias_objectids", alias_objectids, METH_VARARGS, "make two objectids refer to the same object" }, { "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" }, { "submit_task", submit_task, METH_VARARGS, "call a remote function" }, - { "ready_for_new_task", ready_for_new_task, METH_VARARGS, "notify the scheduler that a task has been completed" }, - { "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" }, + { "ready_for_new_task", ready_for_new_task, METH_VARARGS, "notify the scheduler that the worker is ready for a new task" }, { "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" }, { "task_info", task_info, METH_VARARGS, "get information about task statuses and failures" }, { "export_remote_function", export_remote_function, METH_VARARGS, "export a remote function to workers" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index dd9093b88..d4bc238f1 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -328,7 +328,7 @@ Status SchedulerService::NotifyFailure(ServerContext* context, const NotifyFailu PrintErrorMessageRequest print_request; print_request.mutable_failure()->CopyFrom(request->failure()); AckReply print_reply; - RAY_CHECK_GRPC(worker->worker_stub->PrintErrorMessage(&client_context, print_request, &print_reply)); + // RAY_CHECK_GRPC(worker->worker_stub->PrintErrorMessage(&client_context, print_request, &print_reply)); } } } @@ -365,10 +365,10 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN // all of the exported functions and all of the exported reusable variables. if (!(*workers)[workerid].initialized) { // This should only happen once. - // Import all remote functions on the worker. - export_all_functions_to_worker(workerid, workers, GET(exported_functions_)); - // Import all reusable variables on the worker. - export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_)); + // Queue up all remote functions to be imported on the worker. + add_all_remote_functions_to_worker_export_queue(workerid); + // Queue up all reusable variables to be imported on the worker. + add_all_reusable_variables_to_worker_export_queue(workerid); // Mark the worker as initialized. (*workers)[workerid].initialized = true; } @@ -514,28 +514,38 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe } Status SchedulerService::ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) { - auto workers = GET(workers_); - auto exported_functions = GET(exported_functions_); - // TODO(rkn): Does this do a deep copy? - exported_functions->push_back(std::unique_ptr(new Function(request->function()))); - for (size_t i = 0; i < workers->size(); ++i) { - if ((*workers)[i].current_task != ROOT_OPERATION) { - export_function_to_worker(i, exported_functions->size() - 1, workers, exported_functions); + { + auto workers = GET(workers_); + auto remote_function_export_queue = GET(remote_function_export_queue_); + auto exported_functions = GET(exported_functions_); + // TODO(rkn): Does this do a deep copy? + exported_functions->push_back(std::unique_ptr(new Function(request->function()))); + for (WorkerId workerid = 0; workerid < workers->size(); ++workerid) { + if ((*workers)[workerid].current_task != ROOT_OPERATION) { + // Add this workerid and remote function pair to the export queue. + remote_function_export_queue->push(std::make_pair(workerid, exported_functions->size() - 1)); + } } } + schedule(); return Status::OK; } Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) { - auto workers = GET(workers_); - auto exported_reusable_variables = GET(exported_reusable_variables_); - // TODO(rkn): Does this do a deep copy? - exported_reusable_variables->push_back(std::unique_ptr(new ReusableVar(request->reusable_variable()))); - for (size_t i = 0; i < workers->size(); ++i) { - if ((*workers)[i].current_task != ROOT_OPERATION) { - export_reusable_variable_to_worker(i, exported_reusable_variables->size() - 1, workers, exported_reusable_variables); + { + auto workers = GET(workers_); + auto reusable_variable_export_queue = GET(reusable_variable_export_queue_); + auto exported_reusable_variables = GET(exported_reusable_variables_); + // TODO(rkn): Does this do a deep copy? + exported_reusable_variables->push_back(std::unique_ptr(new ReusableVar(request->reusable_variable()))); + for (WorkerId workerid = 0; workerid < workers->size(); ++workerid) { + if ((*workers)[workerid].current_task != ROOT_OPERATION) { + // Add this workerid and reusable variable pair to the export queue. + reusable_variable_export_queue->push(std::make_pair(workerid, exported_reusable_variables->size() - 1)); + } } } + schedule(); return Status::OK; } @@ -584,8 +594,16 @@ void SchedulerService::deliver_object_async(ObjectID canonical_objectid, ObjStor } void SchedulerService::schedule() { - // TODO(rkn): Do this more intelligently. - perform_gets(); // See what we can do in get_queue_ + // Export remote functions to the workers. This must happen before we schedule + // tasks in order to guarantee that remote function calls use the most up to + // date definitions. + perform_remote_function_exports(); + // Export reusable variables to the workers. This must happen before we + // schedule tasks in order to guarantee that the workers have the definitions + // they need. + perform_reusable_variable_exports(); + // See what we can do in get_queue_ + perform_gets(); if (scheduling_algorithm_ == SCHEDULING_ALGORITHM_NAIVE) { schedule_tasks_naively(); // See what we can do in task_queue_ } else if (scheduling_algorithm_ == SCHEDULING_ALGORITHM_LOCALITY_AWARE) { @@ -768,6 +786,28 @@ bool SchedulerService::is_canonical(ObjectID objectid) { return objectid == (*target_objectids)[objectid]; } +void SchedulerService::perform_remote_function_exports() { + auto workers = GET(workers_); + auto remote_function_export_queue = GET(remote_function_export_queue_); + auto exported_functions = GET(exported_functions_); + while (!remote_function_export_queue->empty()) { + std::pair workerid_functionid_pair = remote_function_export_queue->front(); + export_function_to_worker(workerid_functionid_pair.first, workerid_functionid_pair.second, workers, exported_functions); + remote_function_export_queue->pop(); + } +} + +void SchedulerService::perform_reusable_variable_exports() { + auto workers = GET(workers_); + auto reusable_variable_export_queue = GET(reusable_variable_export_queue_); + auto exported_reusable_variables = GET(exported_reusable_variables_); + while (!reusable_variable_export_queue->empty()) { + std::pair workerid_variableid_pair = reusable_variable_export_queue->front(); + export_reusable_variable_to_worker(workerid_variableid_pair.first, workerid_variableid_pair.second, workers, exported_reusable_variables); + reusable_variable_export_queue->pop(); + } +} + void SchedulerService::perform_gets() { auto get_queue = GET(get_queue_); // Complete all get tasks that can be completed. @@ -1047,15 +1087,19 @@ void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply)); } -void SchedulerService::export_all_functions_to_worker(WorkerId workerid, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_functions) { +void SchedulerService::add_all_remote_functions_to_worker_export_queue(WorkerId workerid) { + auto remote_function_export_queue = GET(remote_function_export_queue_); + auto exported_functions = GET(exported_functions_); for (int i = 0; i < exported_functions->size(); ++i) { - export_function_to_worker(workerid, i, workers, exported_functions); + remote_function_export_queue->push(std::make_pair(workerid, i)); } } -void SchedulerService::export_all_reusable_variables_to_worker(WorkerId workerid, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_reusable_variables) { +void SchedulerService::add_all_reusable_variables_to_worker_export_queue(WorkerId workerid) { + auto reusable_variable_export_queue = GET(reusable_variable_export_queue_); + auto exported_reusable_variables = GET(exported_reusable_variables_); for (int i = 0; i < exported_reusable_variables->size(); ++i) { - export_reusable_variable_to_worker(workerid, i, workers, exported_reusable_variables); + reusable_variable_export_queue->push(std::make_pair(workerid, i)); } } @@ -1070,7 +1114,7 @@ void start_scheduler_service(const char* service_addr, SchedulingAlgorithmType s builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); if (server == nullptr) { - RAY_CHECK(false, "Failed to create the scheduler server.") + RAY_CHECK(false, "Failed to create the scheduler service."); } server->Wait(); } diff --git a/src/scheduler.h b/src/scheduler.h index 45372463a..1a9cc6075 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -2,6 +2,7 @@ #define RAY_SCHEDULER_H +#include #include #include #include @@ -117,6 +118,10 @@ private: ObjStoreId pick_objstore(ObjectID objectid); // checks if objectid is a canonical objectid bool is_canonical(ObjectID objectid); + // Export all queued up remote functions. + void perform_remote_function_exports(); + // Export all queued up reusable variables. + void perform_reusable_variable_exports(); void perform_gets(); // schedule tasks using the naive algorithm void schedule_tasks_naively(); @@ -146,16 +151,12 @@ private: void export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_functions); // Export a reusable variable to a worker void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_reusable_variables); - // Export all reusable variables to a worker. This is used when a new worker - // registers and is protected by the workers lock (which is passed in) to - // ensure that no other reusable variables are exported to the worker while - // this method is being called. - void export_all_functions_to_worker(WorkerId workerid, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_functions); - // Export all remote functions to a worker. This is used when a new worker - // registers and is protected by the workers lock (which is passed in) to - // ensure that no other remote functions are exported to the worker while this - // method is being called. - void export_all_reusable_variables_to_worker(WorkerId workerid, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_reusable_variables); + // Add to the remote function export queue the job of exporting all remote + // functions to the given worker. This is used when a new worker registers. + void add_all_remote_functions_to_worker_export_queue(WorkerId workerid); + // Add to the reusable variable export queue the job of exporting all reusable + // variables to the given worker. This is used when a new worker registers. + void add_all_reusable_variables_to_worker_export_queue(WorkerId workerid); template MySynchronizedPtr get(Synchronized& my_field, const char* name,unsigned int line_number); @@ -224,6 +225,16 @@ private: // lock (objects_lock_). // TODO(rkn): Consider making this part of the // objtable data structure. std::vector > objects_in_transit_; + // List of pending remote function exports. These should be processed in a + // first in first out manner. The first element of each pair is the ID of the + // worker to export the remote function to, and the second element of each + // pair is the index of the function to export. + Synchronized > > remote_function_export_queue_; + // List of pending reusable variable exports. These should be processed in a + // first in first out manner. The first element of each pair is the ID of the + // worker to export the reusable variable to, and the second element of each + // pair is the index of the reusable variable to export. + Synchronized > > reusable_variable_export_queue_; // All of the remote functions that have been exported to the workers. Synchronized > > exported_functions_; // All of the reusable variables that have been exported to the workers. diff --git a/src/worker.cc b/src/worker.cc index c501db2c2..1049465e5 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -1,5 +1,7 @@ #include "worker.h" +#include +#include #include #include @@ -9,8 +11,11 @@ extern "C" { static PyObject *RayError; } -inline WorkerServiceImpl::WorkerServiceImpl(Mode mode) - : mode_(mode) {} +inline WorkerServiceImpl::WorkerServiceImpl(const std::string& send_queue_name, Mode mode) + : mode_(mode) { + RAY_LOG(RAY_DEBUG, "Worker service connecting to queue " << send_queue_name); + RAY_CHECK(send_queue_.connect(send_queue_name, false), "error connecting send_queue_"); +} Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) { RAY_CHECK(mode_ == Mode::WORKER_MODE, "ExecuteTask can only be called on workers."); @@ -84,23 +89,20 @@ Status WorkerServiceImpl::PrintErrorMessage(ServerContext* context, const PrintE return Status::OK; } -void WorkerServiceImpl::connect_to_queue() { - RAY_LOG(RAY_DEBUG, "Worker service creating queue with name " << worker_address_ << " to commmunicate with worker."); - RAY_CHECK(send_queue_.connect(worker_address_, true), "error connecting send_queue_"); -} - Worker::Worker(const std::string& node_ip_address, const std::string& scheduler_address, Mode mode) - : node_ip_address_(node_ip_address), - scheduler_address_(scheduler_address), + : scheduler_address_(scheduler_address), + node_ip_address_(node_ip_address), mode_(mode) { - // Connect to the scheduler service. - RAY_LOG(RAY_DEBUG, "Worker creating a scheduler stub.") - auto scheduler_channel = grpc::CreateChannel(scheduler_address_, grpc::InsecureChannelCredentials()); + auto scheduler_channel = grpc::CreateChannel(scheduler_address, grpc::InsecureChannelCredentials()); scheduler_stub_ = Scheduler::NewStub(scheduler_channel); - // Start the worker service. This will find an unused port which is stored in - // worker_port_. This also sets up a message queue between the worker and the - // worker service. - start_worker_service(mode_); + // Generate a random string to use for naming the message queue to avoid + // collisions with message queues created by other workers. + std::random_device rd; + std::mt19937 rng(rd()); + std::uniform_int_distribution queue_name_generator(0, 10000000); + receive_queue_name_ = "worker_receive_queue:" + std::to_string(queue_name_generator(rng)); + RAY_LOG(RAY_DEBUG, "Worker creating queue " << receive_queue_name_ << std::endl); + RAY_CHECK(receive_queue_.connect(receive_queue_name_, true), "error connecting receive_queue_"); } @@ -128,6 +130,10 @@ bool Worker::kill_workers(ClientContext &context) { } void Worker::register_worker(const std::string& node_ip_address, const std::string& objstore_address, bool is_driver) { + if (mode_ == Mode::WORKER_MODE) { + start_worker_service(mode_); + RAY_CHECK(!worker_address_.empty(), "The worker address is empty. This should be initialized by start_worker_service, so it is possible that the thread synchronization failed.") + } unsigned int retry_wait_milliseconds = 20; RegisterWorkerRequest request; request.set_node_ip_address(node_ip_address); @@ -390,7 +396,7 @@ std::unique_ptr Worker::receive_next_message() { } void Worker::ready_for_new_task() { - RAY_CHECK(connected_, "Attempted to perform notify_task_completed but failed."); + RAY_CHECK(connected_, "Attempted to perform ready_for_new_task but failed."); ClientContext context; ReadyForNewTaskRequest request; request.set_workerid(workerid_); @@ -402,9 +408,9 @@ void Worker::disconnect() { connected_ = false; // Shut down the worker service. This will cause the call to server->Wait() to // return. - server_ptr_->Shutdown(); + // server_ptr_->Shutdown(); // Wait for the thread that launched the worker service to return. - worker_server_thread_.join(); + // worker_server_thread_.join(); } // TODO(rkn): Should we be using pointers or references? And should they be const? @@ -445,49 +451,37 @@ void Worker::export_reusable_variable(const std::string& name, const std::string // (in our case running in the main thread), whereas the WorkerService will // run in a separate thread and potentially utilize multiple threads. void Worker::start_worker_service(Mode mode) { - RAY_LOG(RAY_DEBUG, "Worker is starting the worker service."); - // Signal when the worker service has started. - std::condition_variable worker_service_started; - // Lock for the above condition. - std::mutex worker_service_started_mutex; + // Use atomics so the worker service thread can signal the outside thread that + // the worker service has been started. + std::atomic_bool worker_service_started; + worker_service_started.store(false); // Launch a new thread for running the worker service. We store this as a // field so that we can clean it up when we disconnect the worker. worker_server_thread_ = std::thread([this, mode, &worker_service_started]() { + // Create the worker service. + WorkerServiceImpl service(receive_queue_name_, mode); ServerBuilder builder; - // Get GRPC to assign an unused port number. - builder.AddListeningPort(std::string("0.0.0.0:0"), grpc::InsecureServerCredentials(), &worker_port_); - // Create and start the worker service. - WorkerServiceImpl service(mode); + // Let GRPC choose an unused port. + int port; + builder.AddListeningPort(std::string("0.0.0.0:0"), grpc::InsecureServerCredentials(), &port); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); - server_ptr_ = server.get(); if (server == nullptr) { - RAY_CHECK(false, "Failed to create the worker server.") + RAY_CHECK(false, "Failed to create the worker service."); } - RAY_LOG(RAY_DEBUG, "Worker service listening on " << worker_address_); - worker_address_ = node_ip_address_ + ":" + std::to_string(worker_port_); - service.set_worker_address(worker_address_); - // Connect the worker service by a queue to the worker object. - service.connect_to_queue(); - // Use the condition variable to notify the outside thread that the worker - // service has been started. - // TODO(rkn): Once this has been called, the outside thread will notify the - // scheduler that the worker is ready to receive tasks. This can happen - // before server->Wait() is called below. What happens to messages sent from - // the scheduler before the call to server->Wait()? - worker_service_started.notify_all(); + worker_address_ = node_ip_address_ + ":" + std::to_string(port); + server_ptr_ = server.get(); + RAY_LOG(RAY_INFO, "worker server listening at " << worker_address_); + worker_service_started.store(true); // Wait for work and process work. This method does not return until // Shutdown is called from a different thread. server->Wait(); RAY_LOG(RAY_INFO, "Worker service thread returning.") }); - { - // Wait until we know the worker service has been started. - std::unique_lock lock(worker_service_started_mutex); - worker_service_started.wait(lock); + // Wait for the worker service to start. This essentially implements a + // condition variable using atomics, but that failed on Mac OS X on Travis. + while (!worker_service_started.load()) { + RAY_LOG(RAY_DEBUG, "Looping while waiting for the worker service to start."); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } - // Connect to the queue for receiving messages from the worker service. - std::string receive_queue_name = worker_address_; - RAY_LOG(RAY_DEBUG, "Worker connecting to queue with name " << receive_queue_name << " to commmunicate with worker service."); - RAY_CHECK(receive_queue_.connect(receive_queue_name, false), "error connecting receive_queue_"); } diff --git a/src/worker.h b/src/worker.h index af480bbd9..bafb7f27b 100644 --- a/src/worker.h +++ b/src/worker.h @@ -1,8 +1,6 @@ #ifndef RAY_WORKER_H #define RAY_WORKER_H -#include -#include #include #include #include @@ -32,18 +30,15 @@ enum Mode {SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE}; class WorkerServiceImpl final : public WorkerService::Service { public: - WorkerServiceImpl(Mode mode); + WorkerServiceImpl(const std::string& worker_address, Mode mode); Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) override; Status ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) override; Status Die(ServerContext* context, const DieRequest* request, AckReply* reply) override; Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override; Status PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) override; - // Set worker address. - void set_worker_address(const std::string& worker_address) { worker_address_ = worker_address; } - // Connect the worker service to the worker object via a queue. - void connect_to_queue(); private: - std::string worker_address_; + // The queue used to send commands from the worker service to the worker. This + // corresponds to the receive_queue_ in the worker. MessageQueue send_queue_; // This is true if the worker service is part of a driver process and false // if it is part of a worker process. @@ -52,10 +47,8 @@ private: class Worker { public: - // This constructor constructs a stub for the scheduler service. It also - // starts the worker service, which also sets up a message queue between the - // worker and the worker service. Worker(const std::string& node_ip_address, const std::string& scheduler_address, Mode mode); + // Submit a remote task to the scheduler. If the function in the task is not // registered with the scheduler, we will sleep for retry_wait_milliseconds // and try to resubmit the task to the scheduler up to max_retries more times. @@ -92,15 +85,14 @@ class Worker { void register_remote_function(const std::string& name, size_t num_return_vals); // Notify the scheduler that a failure has occurred. void notify_failure(FailedType type, const std::string& name, const std::string& error_message); - // Start the worker server which accepts commands from the scheduler. This - // also creates a message queue that worker service uses to send messages to - // the worker. The queue is read by the Python interpreter. For drivers, these - // commands are only for printing error messages. + // Start the worker server which accepts commands from the scheduler. For + // workers, these commands are stored in the message queue, which is read by + // the Python interpreter. For drivers, these commands are only for printing + // error messages. void start_worker_service(Mode mode); // wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down. std::unique_ptr receive_next_message(); - // tell the scheduler that we are done with the current task and request the - // next one. + // Tell the scheduler that the worker is ready for a new task. void ready_for_new_task(); // disconnect the worker void disconnect(); @@ -118,12 +110,12 @@ class Worker { const char* get_worker_address() { return worker_address_.c_str(); } private: + Mode mode_; bool connected_; const size_t CHUNK_SIZE = 8 * 1024; std::unique_ptr scheduler_stub_; Server* server_ptr_; std::thread worker_server_thread_; - MessageQueue receive_queue_; bip::managed_shared_memory segment_; WorkerId workerid_; ObjStoreId objstoreid_; @@ -131,9 +123,18 @@ class Worker { std::string objstore_address_; std::string worker_address_; std::string node_ip_address_; - int worker_port_; - Mode mode_; + // The queue used to send commands from the worker service to the worker. + // This queue is created by the worker. This corresponds to the send_queue_ in + // the worker service. + MessageQueue receive_queue_; + // The name of the receive queue. + std::string receive_queue_name_; + // The queue used to send requests to the object store. There is a single + // queue shared by all workers sending requests to the object store, and this + // queue is created by the object store. MessageQueue request_obj_queue_; + // The queue used to receive object addresses from the object store. This + // queue is created by this worker. MessageQueue receive_obj_queue_; std::shared_ptr segmentpool_; }; diff --git a/test/array_test.py b/test/array_test.py index f8d11e345..b4889e513 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -9,19 +9,13 @@ from numpy.testing import assert_equal, assert_almost_equal import ray.array.remote as ra import ray.array.distributed as da -class TestRemoteArrays(unittest.TestCase): +class RemoteArrayTest(unittest.TestCase): - @classmethod - def setUpClass(cls): + def testMethods(self): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) - ray.init(start_ray_local=True, num_workers=1) + ray.init(start_ray_local=True) - @classmethod - def tearDownClass(cls): - ray.worker.cleanup() - - def test_methods(self): # test eye object_id = ra.eye.remote(3) val = ray.get(object_id) @@ -47,32 +41,40 @@ class TestRemoteArrays(unittest.TestCase): r_val = ray.get(r_id) assert_almost_equal(np.dot(q_val, r_val), a_val) -class TestDistributedArrays(unittest.TestCase): - - @classmethod - def setUpClass(cls): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: - reload(module) - ray.init(start_ray_local=True, num_workers=10, num_objstores=2) - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_serialization(self): +class DistributedArrayTest(unittest.TestCase): + + def testSerialization(self): + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: + reload(module) + ray.init(start_ray_local=True, num_workers=0) + x = da.DistArray([2, 3, 4], np.array([[[ray.put(0)]]])) capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x) y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule) self.assertEqual(x.shape, y.shape) self.assertEqual(x.objectids[0, 0, 0].id, y.objectids[0, 0, 0].id) - def test_assemble(self): + ray.worker.cleanup() + + def testAssemble(self): + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: + reload(module) + ray.init(start_ray_local=True, num_workers=1) + a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]])) assert_equal(x.assemble(), np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])) - def test_methods(self): + ray.worker.cleanup() + + def testMethods(self): + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: + reload(module) + ray.init(start_ray_local=True, num_objstores=2, num_workers=10) + x = da.zeros.remote([9, 25, 51], "float") assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) @@ -205,5 +207,7 @@ class TestDistributedArrays(unittest.TestCase): d2 = np.random.randint(1, 35) test_dist_qr(d1, d2) + ray.worker.cleanup() + if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/failure_test.py b/test/failure_test.py new file mode 100644 index 000000000..809c33799 --- /dev/null +++ b/test/failure_test.py @@ -0,0 +1,145 @@ +import unittest +import ray +import time + +import test_functions + +class FailureTest(unittest.TestCase): + + def testNoArgs(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE) + + test_functions.no_op_fail.remote() + time.sleep(0.2) + task_info = ray.task_info() + self.assertEqual(len(task_info["failed_tasks"]), 1) + self.assertEqual(len(task_info["running_tasks"]), 0) + self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message")) + + ray.worker.cleanup() + + def testTypeChecking(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE) + + # Make sure that these functions throw exceptions because there return + # values do not type check. + test_functions.test_return1.remote() + test_functions.test_return2.remote() + time.sleep(0.2) + task_info = ray.task_info() + self.assertEqual(len(task_info["failed_tasks"]), 2) + self.assertEqual(len(task_info["running_tasks"]), 0) + + ray.worker.cleanup() + +class TaskStatusTest(unittest.TestCase): + def testFailedTask(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=3, driver_mode=ray.SILENT_MODE) + + test_functions.test_alias_f.remote() + test_functions.throw_exception_fct1.remote() + test_functions.throw_exception_fct1.remote() + for _ in range(100): # Retry if we need to wait longer. + if len(ray.task_info()["failed_tasks"]) >= 2: + break + time.sleep(0.1) + result = ray.task_info() + self.assertEqual(len(result["failed_tasks"]), 2) + task_ids = set() + for task in result["failed_tasks"]: + self.assertTrue(task.has_key("worker_address")) + self.assertTrue(task.has_key("operationid")) + self.assertTrue("Test function 1 intentionally failed." in task.get("error_message")) + self.assertTrue(task["operationid"] not in task_ids) + task_ids.add(task["operationid"]) + + x = test_functions.throw_exception_fct2.remote() + try: + ray.get(x) + except Exception as e: + self.assertTrue("Test function 2 intentionally failed."in str(e)) + else: + self.assertTrue(False) # ray.get should throw an exception + + x, y, z = test_functions.throw_exception_fct3.remote(1.0) + for ref in [x, y, z]: + try: + ray.get(ref) + except Exception as e: + self.assertTrue("Test function 3 intentionally failed."in str(e)) + else: + self.assertTrue(False) # ray.get should throw an exception + + ray.worker.cleanup() + + def testFailImportingRemoteFunction(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + # This example is somewhat contrived. It should be successfully pickled, and + # then it should throw an exception when it is unpickled. This may depend a + # bit on the specifics of our pickler. + def reducer(*args): + raise Exception("There is a problem here.") + class Foo(object): + def __init__(self): + self.__name__ = "Foo_object" + self.func_doc = "" + self.__globals__ = {} + def __reduce__(self): + return reducer, () + def __call__(self): + return + ray.remote([], [])(Foo()) + for _ in range(100): # Retry if we need to wait longer. + if len(ray.task_info()["failed_remote_function_imports"]) >= 1: + break + time.sleep(0.1) + self.assertTrue("There is a problem here." in ray.task_info()["failed_remote_function_imports"][0]["error_message"]) + + ray.worker.cleanup() + + def testFailImportingReusableVariable(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + # This will throw an exception when the reusable variable is imported on the + # workers. + def initializer(): + if ray.worker.global_worker.mode == ray.WORKER_MODE: + raise Exception("The initializer failed.") + return 0 + ray.reusables.foo = ray.Reusable(initializer) + for _ in range(100): # Retry if we need to wait longer. + if len(ray.task_info()["failed_reusable_variable_imports"]) >= 1: + break + time.sleep(0.1) + # Check that the error message is in the task info. + self.assertTrue("The initializer failed." in ray.task_info()["failed_reusable_variable_imports"][0]["error_message"]) + + ray.worker.cleanup() + + def testFailReinitializingVariable(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + def initializer(): + return 0 + def reinitializer(foo): + raise Exception("The reinitializer failed.") + ray.reusables.foo = ray.Reusable(initializer, reinitializer) + @ray.remote([], []) + def use_foo(): + ray.reusables.foo + use_foo.remote() + for _ in range(100): # Retry if we need to wait longer. + if len(ray.task_info()["failed_reinitialize_reusable_variables"]) >= 1: + break + time.sleep(0.1) + # Check that the error message is in the task info. + self.assertTrue("The reinitializer failed." in ray.task_info()["failed_reinitialize_reusable_variables"][0]["error_message"]) + + ray.worker.cleanup() + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/microbenchmarks.py b/test/microbenchmarks.py index 452cd2885..89cba5f6c 100644 --- a/test/microbenchmarks.py +++ b/test/microbenchmarks.py @@ -6,18 +6,12 @@ import numpy as np import test_functions -class TestMicroBenchmarks(unittest.TestCase): +class MicroBenchmarkTest(unittest.TestCase): - @classmethod - def setUpClass(cls): + def testTiming(self): reload(test_functions) ray.init(start_ray_local=True, num_workers=3) - @classmethod - def tearDownClass(cls): - ray.worker.cleanup() - - def test_timing(self): # measure the time required to submit a remote task to the scheduler elapsed_times = [] for _ in range(1000): @@ -83,5 +77,7 @@ class TestMicroBenchmarks(unittest.TestCase): print " worst: {}".format(elapsed_times[999]) # average_elapsed_time should be about 0.00087 + ray.worker.cleanup() + if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/runtest.py b/test/runtest.py index 7890223ab..f539d99ec 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -33,38 +33,32 @@ class UserDefinedType(object): class SerializationTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - ray.init(start_ray_local=True, num_workers=0) - - @classmethod - def tearDownClass(cls): - ray.worker.cleanup() - - def round_trip_test(self, data): + def roundTripTest(self, data): serialized, _ = ray.serialization.serialize(ray.worker.global_worker.handle, data) result = ray.serialization.deserialize(ray.worker.global_worker.handle, serialized) assert_equal(data, result) - def numpy_type_test(self, typ): - self.round_trip_test(np.random.randint(0, 10, size=(100, 100)).astype(typ)) - self.round_trip_test(np.array(0).astype(typ)) - self.round_trip_test(np.empty((0,)).astype(typ)) + def numpyTypeTest(self, typ): + self.roundTripTest(np.random.randint(0, 10, size=(100, 100)).astype(typ)) + self.roundTripTest(np.array(0).astype(typ)) + self.roundTripTest(np.empty((0,)).astype(typ)) + + def testSerialize(self): + ray.init(start_ray_local=True, num_workers=0) - def test_serialize(self): for val in RAY_TEST_OBJECTS: - self.round_trip_test(val) + self.roundTripTest(val) - self.round_trip_test(np.zeros((100, 100))) + self.roundTripTest(np.zeros((100, 100))) - self.numpy_type_test("int8") - self.numpy_type_test("uint8") - self.numpy_type_test("int16") - self.numpy_type_test("uint16") - self.numpy_type_test("int32") - self.numpy_type_test("uint32") - self.numpy_type_test("float32") - self.numpy_type_test("float64") + self.numpyTypeTest("int8") + self.numpyTypeTest("uint8") + self.numpyTypeTest("int16") + self.numpyTypeTest("uint16") + self.numpyTypeTest("int32") + self.numpyTypeTest("uint32") + self.numpyTypeTest("float32") + self.numpyTypeTest("float64") ref0 = ray.put(0) ref1 = ray.put(0) @@ -76,15 +70,17 @@ class SerializationTest(unittest.TestCase): result = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule) self.assertTrue((a == result).all()) - self.round_trip_test(ref0) - self.round_trip_test([ref0, ref1, ref2, ref3]) - self.round_trip_test({"0": ref0, "1": ref1, "2": ref2, "3": ref3}) - self.round_trip_test((ref0, 1)) + self.roundTripTest(ref0) + self.roundTripTest([ref0, ref1, ref2, ref3]) + self.roundTripTest({"0": ref0, "1": ref1, "2": ref2, "3": ref3}) + self.roundTripTest((ref0, 1)) + + ray.worker.cleanup() class ObjStoreTest(unittest.TestCase): # Test setting up object stores, transfering data between them and retrieving data to a client - def test_object_store(self): + def testObjStore(self): node_ip_address = "127.0.0.1" scheduler_address = ray.services.start_ray_local(num_objstores=2, num_workers=0, worker_path=None) ray.connect(node_ip_address, scheduler_address, mode=ray.SCRIPT_MODE) @@ -130,6 +126,15 @@ class ObjStoreTest(unittest.TestCase): result = ray.get(objectid, w1) assert_equal(result, data) + # This test fails. See https://github.com/amplab/ray/issues/159. + # getting multiple times shouldn't matter + # for data in [np.zeros([10, 20]), np.random.normal(size=[45, 25]), np.zeros([10, 20], dtype=np.dtype("float64")), np.zeros([10, 20], dtype=np.dtype("float32")), np.zeros([10, 20], dtype=np.dtype("int64")), np.zeros([10, 20], dtype=np.dtype("int32"))]: + # objectid = worker.put(data, w1) + # result = worker.get(objectid, w2) + # result = worker.get(objectid, w2) + # result = worker.get(objectid, w2) + # assert_equal(result, data) + # Getting a buffer after modifying it before it finishes should return updated buffer objectid = ray.libraylib.get_objectid(w1.handle) buf = ray.libraylib.allocate_buffer(w1.handle, objectid, 100) @@ -143,18 +148,11 @@ class ObjStoreTest(unittest.TestCase): ray.disconnect(worker=w2) ray.worker.cleanup() -class APITest(unittest.TestCase): +class WorkerTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - reload(test_functions) - ray.init(start_ray_local=True, num_workers=3, driver_mode=ray.SILENT_MODE) + def testPutGet(self): + ray.init(start_ray_local=True, num_workers=0) - @classmethod - def tearDownClass(cls): - ray.worker.cleanup() - - def test_put_get(self): for i in range(100): value_before = i * 10 ** 6 objectid = ray.put(value_before) @@ -179,18 +177,14 @@ class APITest(unittest.TestCase): value_after = ray.get(objectid) self.assertEqual(value_before, value_after) - @unittest.skip("This test is currently disabled.") - def test_multiple_get(self): - # This test fails. See https://github.com/amplab/ray/issues/159. getting - # multiple times shouldn't matter - for data in [np.zeros([10, 20]), np.random.normal(size=[45, 25]), np.zeros([10, 20], dtype=np.dtype("float64")), np.zeros([10, 20], dtype=np.dtype("float32")), np.zeros([10, 20], dtype=np.dtype("int64")), np.zeros([10, 20], dtype=np.dtype("int32"))]: - objectid = ray.put(data) - result = ray.get(objectid) - result = ray.get(objectid) - result = ray.get(objectid) - assert_equal(result, data) + ray.worker.cleanup() + +class APITest(unittest.TestCase): + + def testObjectIDAliasing(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=3) - def test_objectid_aliasing(self): ref = test_functions.test_alias_f.remote() assert_equal(ray.get(ref), np.ones([3, 4, 5])) ref = test_functions.test_alias_g.remote() @@ -198,7 +192,12 @@ class APITest(unittest.TestCase): ref = test_functions.test_alias_h.remote() assert_equal(ray.get(ref), np.ones([3, 4, 5])) - def test_keyword_args(self): + ray.worker.cleanup() + + def testKeywordArgs(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=1) + x = test_functions.keyword_fct1.remote(1) self.assertEqual(ray.get(x), "1 hello") x = test_functions.keyword_fct1.remote(1, "hi") @@ -230,7 +229,12 @@ class APITest(unittest.TestCase): x = test_functions.keyword_fct3.remote(0, 1) self.assertEqual(ray.get(x), "0 1 hello world") - def test_variable_number_of_args(self): + ray.worker.cleanup() + + def testVariableNumberOfArgs(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=1) + x = test_functions.varargs_fct1.remote(0, 1, 2) self.assertEqual(ray.get(x), "0 1 2") x = test_functions.varargs_fct2.remote(0, 1, 2) @@ -239,32 +243,23 @@ class APITest(unittest.TestCase): self.assertTrue(test_functions.kwargs_exception_thrown) self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown) - def test_no_args(self): + ray.worker.cleanup() + + def testNoArgs(self): + reload(test_functions) + ray.init(start_ray_local=True, num_workers=1) + test_functions.no_op.remote() time.sleep(0.2) task_info = ray.task_info() self.assertEqual(len(task_info["failed_tasks"]), 0) self.assertEqual(len(task_info["running_tasks"]), 0) - test_functions.no_op_fail.remote() - time.sleep(0.2) - task_info = ray.task_info() - self.assertEqual(len(task_info["failed_tasks"]), 1) - self.assertEqual(len(task_info["running_tasks"]), 0) - self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message")) + ray.worker.cleanup() - def test_type_checking(self): - # Make sure that these functions throw exceptions because there return - # values do not type check. - num_failed_tasks = len(ray.task_info()["failed_tasks"]) - test_functions.test_return1.remote() - test_functions.test_return2.remote() - time.sleep(0.2) - task_info = ray.task_info() - self.assertEqual(len(task_info["failed_tasks"]), num_failed_tasks + 2) - self.assertEqual(len(task_info["running_tasks"]), 0) + def testDefiningRemoteFunctions(self): + ray.init(start_ray_local=True, num_workers=2) - def test_defining_remote_functions(self): # Test that we can define a remote function in the shell. @ray.remote([int], [int]) def f(x): @@ -308,13 +303,9 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(l.remote(1)), 2) self.assertEqual(ray.get(m.remote(1)), 2) -class TestCachingBeforeInit(unittest.TestCase): - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_caching_reusables(self): + def testCachingReusables(self): # Test that we can define reusable variables before the driver is connected. def foo_initializer(): return 1 @@ -324,6 +315,7 @@ class TestCachingBeforeInit(unittest.TestCase): return [] ray.reusables.foo = ray.Reusable(foo_initializer) ray.reusables.bar = ray.Reusable(bar_initializer, bar_reinitializer) + @ray.remote([], [int]) def use_foo(): return ray.reusables.foo @@ -339,94 +331,8 @@ class TestCachingBeforeInit(unittest.TestCase): self.assertEqual(ray.get(use_bar.remote()), [1]) self.assertEqual(ray.get(use_bar.remote()), [1]) -class TestFailures(unittest.TestCase): - - @classmethod - def setUpClass(cls): - reload(test_functions) - ray.init(start_ray_local=True, num_workers=3, driver_mode=ray.SILENT_MODE) - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_failed_task(self): - test_functions.test_alias_f.remote() - test_functions.throw_exception_fct1.remote() - test_functions.throw_exception_fct1.remote() - time.sleep(1) - result = ray.task_info() - self.assertEqual(len(result["failed_tasks"]), 2) - task_ids = set() - for task in result["failed_tasks"]: - self.assertTrue(task.has_key("worker_address")) - self.assertTrue(task.has_key("operationid")) - self.assertTrue("Test function 1 intentionally failed." in task.get("error_message")) - self.assertTrue(task["operationid"] not in task_ids) - task_ids.add(task["operationid"]) - - x = test_functions.throw_exception_fct2.remote() - try: - ray.get(x) - except Exception as e: - self.assertTrue("Test function 2 intentionally failed."in str(e)) - else: - self.assertTrue(False) # ray.get should throw an exception - - x, y, z = test_functions.throw_exception_fct3.remote(1.0) - for ref in [x, y, z]: - try: - ray.get(ref) - except Exception as e: - self.assertTrue("Test function 3 intentionally failed."in str(e)) - else: - self.assertTrue(False) # ray.get should throw an exception - - def test_fail_importing_remote_function(self): - # This example is somewhat contrived. It should be successfully pickled, and - # then it should throw an exception when it is unpickled. This may depend a - # bit on the specifics of our pickler. - def reducer(*args): - raise Exception("There is a problem here.") - class Foo(object): - def __init__(self): - self.__name__ = "Foo_object" - self.func_doc = "" - self.__globals__ = {} - def __reduce__(self): - return reducer, () - def __call__(self): - return - ray.remote([], [])(Foo()) - time.sleep(0.1) - self.assertTrue("There is a problem here." in ray.task_info()["failed_remote_function_imports"][0]["error_message"]) - - def test_fail_importing_reusable_variable(self): - # This will throw an exception when the reusable variable is imported on the - # workers. - def initializer(): - if ray.worker.global_worker.mode == ray.WORKER_MODE: - raise Exception("The initializer failed.") - return 0 - ray.reusables.foo = ray.Reusable(initializer) - time.sleep(0.1) - # Check that the error message is in the task info. - self.assertTrue("The initializer failed." in ray.task_info()["failed_reusable_variable_imports"][0]["error_message"]) - - def test_fail_reinitializing_variable(self): - def initializer(): - return 0 - def reinitializer(foo): - raise Exception("The reinitializer failed.") - ray.reusables.foo = ray.Reusable(initializer, reinitializer) - @ray.remote([], []) - def use_foo(): - ray.reusables.foo - use_foo.remote() - time.sleep(0.1) - # Check that the error message is in the task info. - self.assertTrue("The reinitializer failed." in ray.task_info()["failed_reinitialize_reusable_variables"][0]["error_message"]) - def check_get_deallocated(data): x = ray.put(data) ray.get(x) @@ -437,25 +343,20 @@ def check_get_not_deallocated(data): y = ray.get(x) return y, x.id -class TestReferenceCounting(unittest.TestCase): +class ReferenceCountingTest(unittest.TestCase): - @classmethod - def setUpClass(cls): + def testDeallocation(self): reload(test_functions) for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) - ray.init(start_ray_local=True, num_workers=3) + ray.init(start_ray_local=True, num_workers=1) - @classmethod - def tearDownClass(cls): - ray.worker.cleanup() - - def test_deallocation(self): x = test_functions.test_alias_f.remote() ray.get(x) time.sleep(0.1) objectid_val = x.id self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], 1) + del x self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], -1) # -1 indicates deallocated @@ -464,6 +365,7 @@ class TestReferenceCounting(unittest.TestCase): time.sleep(0.1) objectid_val = y.id self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [1, 0, 0]) + del y self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, -1, -1]) @@ -471,6 +373,7 @@ class TestReferenceCounting(unittest.TestCase): time.sleep(0.1) objectid_val = z.id self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [1, 1, 1]) + del z time.sleep(0.1) self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, -1, -1]) @@ -481,6 +384,7 @@ class TestReferenceCounting(unittest.TestCase): objectid_val = x.id time.sleep(0.1) self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [1, 1, 1]) + del x time.sleep(0.1) self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, 1, 1]) @@ -491,7 +395,11 @@ class TestReferenceCounting(unittest.TestCase): time.sleep(0.1) self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val:(objectid_val + 3)], [-1, -1, -1]) - def test_get(self): + ray.worker.cleanup() + + def testGet(self): + ray.init(start_ray_local=True, num_workers=3) + for val in RAY_TEST_OBJECTS + [np.zeros((2, 2)), UserDefinedType()]: objectid_val = check_get_deallocated(val) self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], -1) @@ -500,35 +408,36 @@ class TestReferenceCounting(unittest.TestCase): x, objectid_val = check_get_not_deallocated(val) self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], 1) - @unittest.skip("This test is currently disabled.") - def test_get_twice(self): # The following currently segfaults: The second "result = " closes the # memory segment as soon as the assignment is done (and the first result # goes out of scope). - data = np.zeros([10, 20]) - objectid = ray.put(data) - result = ray.get(objectid) - result = ray.get(objectid) - assert_equal(result, data) + # data = np.zeros([10, 20]) + # objectid = ray.put(data) + # result = worker.get(objectid) + # result = worker.get(objectid) + # assert_equal(result, data) - @unittest.skip("This test is currently disabled.") - def test_get_bool_and_none(self): - # This fails, because for bool and None, we cannot track python reference - # counts and therefore cannot keep the refcount up (see - # 5281bd414f6b404f61e1fe25ec5f6651defee206). The resulting behavior is still - # correct however because True, False and None are returned by get "by - # value" and therefore can be reclaimed from the object store safely. - for val in [True, False, None]: - x, objectid_val = check_get_not_deallocated(val) - self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], 1) - -class TestPythonMode(unittest.TestCase): - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_python_mode(self): + # @unittest.expectedFailure + # def testGetFailing(self): + # ray.init(start_ray_local=True, num_workers=3) + + # # This is failing, because for bool and None, we cannot track python + # # refcounts and therefore cannot keep the refcount up + # # (see 5281bd414f6b404f61e1fe25ec5f6651defee206). + # # The resulting behavior is still correct however because True, False and + # # None are returned by get "by value" and therefore can be reclaimed from + # # the object store safely. + # for val in [True, False, None]: + # x, objectid_val = check_get_not_deallocated(val) + # self.assertEqual(ray.scheduler_info()["reference_counts"][objectid_val], 1) + + # ray.worker.cleanup() + +class PythonModeTest(unittest.TestCase): + + def testPythonMode(self): reload(test_functions) ray.init(start_ray_local=True, driver_mode=ray.PYTHON_MODE) @@ -546,40 +455,64 @@ class TestPythonMode(unittest.TestCase): assert_equal(aref, np.array([0, 0])) # python_mode_g should not mutate aref assert_equal(bref, np.array([1, 0])) -class TestPythonCExtensions(unittest.TestCase): - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_reference_counting_bools_and_none(self): +class PythonCExtensionTest(unittest.TestCase): + + def testReferenceCountNone(self): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - for obj in [None, True, False]: - @ray.remote([], [int]) - def f(): - return sys.getrefcount(obj) - first_count = ray.get(f.remote()) - second_count = ray.get(f.remote()) - self.assertEqual(first_count, second_count) + @ray.remote([], [int]) + def f(): + return sys.getrefcount(None) + first_count = ray.get(f.remote()) + second_count = ray.get(f.remote()) + self.assertEqual(first_count, second_count) -class TestReusableVariables(unittest.TestCase): - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_reusable_variables(self): + def testReferenceCountNone(self): + ray.init(start_ray_local=True, num_workers=1) + + # Make sure that we aren't accidentally messing up Python's reference counts. + @ray.remote([], [int]) + def f(): + return sys.getrefcount(True) + first_count = ray.get(f.remote()) + second_count = ray.get(f.remote()) + self.assertEqual(first_count, second_count) + + ray.worker.cleanup() + + def testReferenceCountNone(self): + ray.init(start_ray_local=True, num_workers=1) + + # Make sure that we aren't accidentally messing up Python's reference counts. + @ray.remote([], [int]) + def f(): + return sys.getrefcount(False) + first_count = ray.get(f.remote()) + second_count = ray.get(f.remote()) + self.assertEqual(first_count, second_count) + + ray.worker.cleanup() + +class ReusablesTest(unittest.TestCase): + + def testReusables(self): ray.init(start_ray_local=True, num_workers=1) # Test that we can add a variable to the key-value store. + def foo_initializer(): return 1 def foo_reinitializer(foo): return foo + ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer) self.assertEqual(ray.reusables.foo, 1) + @ray.remote([], [int]) def use_foo(): return ray.reusables.foo @@ -588,9 +521,12 @@ class TestReusableVariables(unittest.TestCase): self.assertEqual(ray.get(use_foo.remote()), 1) # Test that we can add a variable to the key-value store, mutate it, and reset it. + def bar_initializer(): return [1, 2, 3] + ray.reusables.bar = ray.Reusable(bar_initializer) + @ray.remote([], [list]) def use_bar(): ray.reusables.bar.append(4) @@ -600,13 +536,16 @@ class TestReusableVariables(unittest.TestCase): self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4]) # Test that we can use the reinitializer. + def baz_initializer(): return np.zeros([4]) def baz_reinitializer(baz): for i in range(len(baz)): baz[i] = 0 return baz + ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer) + @ray.remote([int], [np.ndarray]) def use_baz(i): baz = ray.reusables.baz @@ -620,11 +559,14 @@ class TestReusableVariables(unittest.TestCase): # Make sure the reinitializer is actually getting called. Note that this is # not the correct usage of a reinitializer because it does not reset qux to # its original state. This is just for testing. + def qux_initializer(): return 0 def qux_reinitializer(x): return x + 1 + ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer) + @ray.remote([], [int]) def use_qux(): return ray.reusables.qux @@ -632,43 +574,45 @@ class TestReusableVariables(unittest.TestCase): self.assertEqual(ray.get(use_qux.remote()), 1) self.assertEqual(ray.get(use_qux.remote()), 2) -class TestAttachingToCluster(unittest.TestCase): - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_attaching_to_cluster(self): +class ClusterAttachingTest(unittest.TestCase): + + def testAttachingToCluster(self): node_ip_address = "127.0.0.1" scheduler_port = np.random.randint(40000, 50000) scheduler_address = "{}:{}".format(node_ip_address, scheduler_port) ray.services.start_scheduler(scheduler_address, cleanup=True) + time.sleep(0.1) ray.services.start_node(scheduler_address, node_ip_address, num_workers=1, cleanup=True) + ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address) + @ray.remote([int], [int]) def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) -class TestAttachingToClusterWithMultipleObjectStores(unittest.TestCase): - - @classmethod - def tearDownClass(cls): ray.worker.cleanup() - def test_attaching_to_cluster_with_multiple_object_stores(self): + def testAttachingToClusterWithMultipleObjectStores(self): node_ip_address = "127.0.0.1" scheduler_port = np.random.randint(40000, 50000) scheduler_address = "{}:{}".format(node_ip_address, scheduler_port) ray.services.start_scheduler(scheduler_address, cleanup=True) + time.sleep(0.1) ray.services.start_node(scheduler_address, node_ip_address, num_workers=5, cleanup=True) ray.services.start_node(scheduler_address, node_ip_address, num_workers=5, cleanup=True) ray.services.start_node(scheduler_address, node_ip_address, num_workers=5, cleanup=True) + ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address) + @ray.remote([int], [int]) def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) + ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/thirdparty/grpc b/thirdparty/grpc index 15a10b7ea..77e8b714e 160000 --- a/thirdparty/grpc +++ b/thirdparty/grpc @@ -1 +1 @@ -Subproject commit 15a10b7ea262a95b5c10288f67932d5ba24ac47d +Subproject commit 77e8b714e510c6a3061d17aab8af769a7b45eed4