From d5815673a504b6d71aba7a80a8a7626552bc7209 Mon Sep 17 00:00:00 2001 From: Wapaul1 Date: Wed, 14 Sep 2016 17:14:11 -0700 Subject: [PATCH] Changed ray.select() to ray.wait() and its functionality (#426) * Re-implemented select, changed name to wait * Changed tests for select to tests for wait * Updated the hyperopt example to match wait * Small fixes and improve example readme. * Make tests pass. --- doc/api.rst | 2 +- examples/hyperopt/README.md | 46 +++++++++++++++++++++++++++++++---- examples/hyperopt/driver.py | 35 ++++++++++++++++++--------- lib/python/ray/__init__.py | 2 +- lib/python/ray/worker.py | 48 +++++++++++++++++++++---------------- protos/ray.proto | 6 ++--- src/raylib.cc | 6 ++--- src/scheduler.cc | 2 +- src/scheduler.h | 2 +- src/worker.cc | 8 +++---- src/worker.h | 2 +- test/runtest.py | 42 +++++++++++++++++--------------- 12 files changed, 132 insertions(+), 69 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index a407b91bd..b6c092dcc 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -5,7 +5,7 @@ The Ray API .. autofunction:: ray.put .. autofunction:: ray.get .. autofunction:: ray.remote -.. autofunction:: ray.select +.. autofunction:: ray.wait .. autofunction:: ray.init .. autofunction:: ray.kill_workers .. 
autofunction:: ray.restart_workers_local diff --git a/examples/hyperopt/README.md b/examples/hyperopt/README.md index 67dd2847d..a94512f00 100644 --- a/examples/hyperopt/README.md +++ b/examples/hyperopt/README.md @@ -63,8 +63,9 @@ def generate_random_params(): results = [] for _ in range(100): - randparams = generate_random_params() - results.append((randparams, train_cnn_and_compute_accuracy(randparams, train_images, train_labels, validation_images, validation_labels))) + params = generate_random_params() + accuracy = train_cnn_and_compute_accuracy(params, train_images, train_labels, validation_images, validation_labels) + results.append(accuracy) ``` Then we can inspect the contents of `results` and see which set of @@ -101,16 +102,53 @@ computation. Instead, it simply submits a number of tasks to the scheduler. ```python result_ids = [] +# Launch 100 tasks. for _ in range(100): params = generate_random_params() - results.append((params, train_cnn_and_compute_accuracy.remote(params, train_images, train_labels, validation_images, validation_labels))) + accuracy_id = train_cnn_and_compute_accuracy.remote(params, train_images, train_labels, validation_images, validation_labels) + result_ids.append(accuracy_id) ``` If we wish to wait until the results have all been retrieved, we can retrieve their values with `ray.get`. ```python -results = [(params, ray.get(result_id)) for (params, result_id) in result_ids] +results = ray.get(result_ids) +``` + +One drawback of the above approach is that nothing will be printed until all of +the experiments have finished. What we'd really like is to start processing +the results of certain experiments as soon as they finish (and possibly launch +more experiments based on the outcomes of the first ones). To do this, we can +use `ray.wait`, which takes a list of object IDs and returns two lists of object +IDs: the IDs that are ready, followed by the remaining IDs. 
+ +```python +ready_ids, remaining_ids = ray.wait(result_ids, num_returns=3, timeout=10) +``` + +In the above, `result_ids` is a list of object IDs. The command `ray.wait` will +return as soon as either three of the object IDs in `result_ids` are ready (that +is, the task that created the corresponding object has finished executing and stored +the object in the object store) or ten seconds pass, whichever comes first. To +wait indefinitely, omit the timeout argument. Now, we can rewrite the script as +follows. + +```python +remaining_ids = [] +# Launch 100 tasks. +for _ in range(100): + params = generate_random_params() + accuracy_id = train_cnn_and_compute_accuracy.remote(params, train_images, train_labels, validation_images, validation_labels) + remaining_ids.append(accuracy_id) + +# Process the tasks one at a time. +while len(remaining_ids) > 0: + # Process the next task that finishes. + ready_ids, remaining_ids = ray.wait(remaining_ids, num_returns=1) + # Get the accuracy corresponding to the ready object ID. + accuracy = ray.get(ready_ids[0]) + print "Accuracy {}".format(accuracy) ``` ## Additional notes diff --git a/examples/hyperopt/driver.py b/examples/hyperopt/driver.py index 0efb67760..76794f532 100644 --- a/examples/hyperopt/driver.py +++ b/examples/hyperopt/driver.py @@ -39,26 +39,39 @@ if __name__ == "__main__": validation_images = ray.put(mnist.validation.images) validation_labels = ray.put(mnist.validation.labels) - # Store the best parameters, the best accuracy, and all of the results. + # Keep track of the best parameters and the best accuracy. best_params = None best_accuracy = 0 - results = [] + # This list holds the object IDs for all of the experiments that we have + # launched and that have not yet been processed. + remaining_ids = [] + # This is a dictionary mapping the object ID of an experiment to the + # parameters used for that experiment. + params_mapping = {} - # Randomly generate some hyperparameters, and launch a task for each set. 
- for i in range(trials): + # A function for generating random hyperparameters. + def generate_random_params(): learning_rate = 10 ** np.random.uniform(-5, 5) batch_size = np.random.randint(1, 100) dropout = np.random.uniform(0, 1) stddev = 10 ** np.random.uniform(-5, 5) - params = {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev} - results.append((params, hyperopt.train_cnn_and_compute_accuracy.remote(params, steps, train_images, train_labels, validation_images, validation_labels))) + return {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev} - # Fetch the results of the tasks and print the results. + # Randomly generate some hyperparameters, and launch a task for each set. for i in range(trials): - # Get the index of the first task that completes. - index = ray.select([result_id for _, result_id in results], num_objects=1)[0] - # Process the output of this task and remove it from the list. - params, result_id = results.pop(index) + params = generate_random_params() + accuracy_id = hyperopt.train_cnn_and_compute_accuracy.remote(params, steps, train_images, train_labels, validation_images, validation_labels) + remaining_ids.append(accuracy_id) + # Keep track of which parameters correspond to this experiment. + params_mapping[accuracy_id] = params + + # Fetch and print the results of the tasks in the order that they complete. + for i in range(trials): + # Use ray.wait to get the object ID of the first task that completes. + ready_ids, remaining_ids = ray.wait(remaining_ids) + # Process the output of this task. 
+ result_id = ready_ids[0] + params = params_mapping[result_id] accuracy = ray.get(result_id) print """We achieve accuracy {:.3}% with learning_rate: {:.2} diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index d00b54ef1..9847505c3 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"): import config import serialization -from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local +from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, wait, remote, kill_workers, restart_workers_local from worker import Reusable, reusables from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE from libraylib import ObjectID diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index f48e7882e..71183b395 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -838,35 +838,43 @@ def put(value, worker=global_worker): worker.put_object(objectid, value) return objectid -def select(objectids, num_objects=0, worker=global_worker): - """Return a list of the indices of the objects that are ready. +def wait(objectids, num_returns=1, timeout=None, worker=global_worker): + """Return a list of IDs that are ready and a list of IDs that are not ready. - If num_objects is 0, the function immediately returns the indices of all - objects that are ready. If it is set, the function waits until that number of - objects is ready and returns that exact number of objectids. + If timeout is set, the function returns either when the requested number of + IDs are ready or when the timeout is reached, whichever occurs first. If it is + not set, the function simply waits until that number of objects is ready and + returns that exact number of objectids. + + This method returns two lists. 
The first list consists of object IDs that + correspond to objects that are stored in the object store. The second list + corresponds to the rest of the object IDs (which may or may not be ready). Args: - objectids (List[ray.ObjectID]): List of objectids for objects that may or - may not be ready. - num_objects (int): The number of indices that should be returned. + objectids (List[raylib.ObjectID]): List of object IDs for objects that may + or may not be ready. + num_returns (int): The number of object IDs that should be returned. + timeout (float): The maximum amount of time in seconds that should be spent + polling the scheduler. Returns: - List of indices in the original list of objects that are ready. + A list of object IDs that are ready and a list of the remaining object IDs. """ check_connected(worker) - if num_objects > len(objectids): - raise Exception("num_objects cannot be greater than len(objectids), num_objects is {}, and len(objectids) is {}.".format(num_objects, len(objectids))) - ready_ids = raylib.ray_select(worker.handle, objectids) + if num_returns < 0: + raise Exception("num_returns cannot be less than 0.") + if num_returns > len(objectids): + raise Exception("num_returns cannot be greater than the length of the input list: num_objects is {}, and the length is {}.".format(num_returns, len(objectids))) + start_time = time.time() + ready_indices = raylib.wait(worker.handle, objectids) # Polls scheduler until enough objects are ready. - while len(ready_ids) < num_objects: - ready_ids = raylib.ray_select(worker.handle, objectids) + while len(ready_indices) < num_returns and (time.time() - start_time < timeout or timeout is None): + ready_indices = raylib.wait(worker.handle, objectids) time.sleep(0.1) - if num_objects != 0: - # Return indices for exactly the requested number of objects. - return ready_ids[:num_objects] - else: - # Return indices for all objects that are ready. 
- return ready_ids + # Return indices for exactly the requested number of objects. + ready_ids = [objectids[i] for i in ready_indices[:num_returns]] + not_ready_ids = [objectids[i] for i in range(len(objectids)) if i not in ready_indices[:num_returns]] + return ready_ids, not_ready_ids def kill_workers(worker=global_worker): """Kill all of the workers in the cluster. This does not kill drivers. diff --git a/protos/ray.proto b/protos/ray.proto index 2d44b09a5..9a2c22b41 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -61,7 +61,7 @@ service Scheduler { // Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable. rpc NotifyFailure(NotifyFailureRequest) returns (AckReply); // Polls the scheduler to see what objectids can be retrieved in the input list. - rpc Select(SelectRequest) returns (SelectReply); + rpc Wait(WaitRequest) returns (WaitReply); } message AckReply { @@ -173,11 +173,11 @@ message SchedulerInfoReply { repeated ObjstoreData objstore = 7; // Information about the object stores } -message SelectRequest { +message WaitRequest { repeated uint64 objectids = 1; // List of objectids to be checked. } -message SelectReply { +message WaitReply { repeated uint64 indices = 1; // List of indices that correspond to objectids in the original list that are ready. 
} diff --git a/src/raylib.cc b/src/raylib.cc index c38fc2ba1..ec9ff082c 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -892,7 +892,7 @@ static PyObject* request_object(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* ray_select(PyObject* self, PyObject* args) { +static PyObject* wait(PyObject* self, PyObject* args) { Worker* worker; PyObject* objectids; if (!PyArg_ParseTuple(args, "O&O", &PyObjectToWorker, &worker, &objectids)) { @@ -904,7 +904,7 @@ static PyObject* ray_select(PyObject* self, PyObject* args) { PyObjectToObjectID(PyList_GetItem(objectids, i), &objectid); objectids_vec.push_back(objectid); } - std::vector indices = worker->select(objectids_vec); + std::vector indices = worker->wait(objectids_vec); PyObject* result = PyList_New(indices.size()); for (size_t i = 0; i < indices.size(); ++i) { PyList_SetItem(result, i, PyInt_FromLong(indices[i])); @@ -1081,7 +1081,7 @@ static PyMethodDef RayLibMethods[] = { { "add_contained_objectids", add_contained_objectids, METH_VARARGS, "notify the scheduler about the object IDs contained in a remote object" }, { "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" }, { "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" }, - { "ray_select" , ray_select, METH_VARARGS, "checks the scheduler to see if a object can be gotten" }, + { "wait" , wait, METH_VARARGS, "checks the scheduler to see if a object can be gotten" }, { "alias_objectids", alias_objectids, METH_VARARGS, "make two objectids refer to the same object" }, { "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" }, { "submit_task", submit_task, METH_VARARGS, "call a remote function" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index d6f090cfd..e5e0a0ead 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -594,7 +594,7 @@ Status 
SchedulerService::ExportReusableVariable(ServerContext* context, const Ex return Status::OK; } -Status SchedulerService::Select(ServerContext* context, const SelectRequest* request, SelectReply* reply) { +Status SchedulerService::Wait(ServerContext* context, const WaitRequest* request, WaitReply* reply) { auto objtable = GET(objtable_); for (int i = 0; i < request->objectids_size(); ++i) { ObjectID objectid = request->objectids(i); diff --git a/src/scheduler.h b/src/scheduler.h index a9a13e93d..e4c1d7713 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -79,7 +79,7 @@ public: Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override; Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override; Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override; - Status Select(ServerContext*, const SelectRequest* request, SelectReply* reply) override; + Status Wait(ServerContext*, const WaitRequest* request, WaitReply* reply) override; #ifdef NDEBUG // If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking. 
diff --git a/src/worker.cc b/src/worker.cc index 9b05c0d31..8d024392f 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -409,15 +409,15 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf RAY_CHECK_GRPC(scheduler_stub_->TaskInfo(&context, request, &reply)); } -std::vector Worker::select(std::vector& objectids) { +std::vector Worker::wait(std::vector& objectids) { RAY_CHECK(connected_, "Attempted to test if object was ready but failed."); ClientContext context; - SelectRequest request; - SelectReply reply; + WaitRequest request; + WaitReply reply; for (int i = 0; i < objectids.size(); ++i) { request.add_objectids(objectids[i]); } - RAY_CHECK_GRPC(scheduler_stub_->Select(&context, request, &reply)); + RAY_CHECK_GRPC(scheduler_stub_->Wait(&context, request, &reply)); std::vector result; for (int i = 0; i < reply.indices_size(); ++i) { result.push_back(reply.indices(i)); diff --git a/src/worker.h b/src/worker.h index ec264d18d..18149972c 100644 --- a/src/worker.h +++ b/src/worker.h @@ -102,7 +102,7 @@ class Worker { // get task statuses from scheduler void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply); // gets indices of available objects - std::vector select(std::vector& objectids); + std::vector wait(std::vector& objectids); // Export a function to be run on all workers. 
void run_function_on_all_workers(const std::string& function); // export function to workers diff --git a/test/runtest.py b/test/runtest.py index d12415bc5..aad6db8d7 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -358,31 +358,35 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(object_ids), range(10)) ray.worker.cleanup() - def testSelect(self): - ray.init(start_ray_local=True, num_workers=4) + def testWait(self): + ray.init(start_ray_local=True, num_workers=1) @ray.remote def f(delay): time.sleep(delay) return 1 - objectids = [f.remote(1.5), f.remote(1.5), f.remote(1.0), f.remote(0.5)] - self.assertEqual(ray.select(objectids), []) - time.sleep(0.75) - self.assertEqual(ray.select(objectids), [3]) - time.sleep(0.5) - self.assertEqual(ray.select(objectids), [2, 3]) - time.sleep(0.5) - self.assertEqual(ray.select(objectids), [0, 1, 2, 3]) - objectids = [f.remote(0.5), f.remote(0.75), f.remote(0.25), f.remote(1.0)] - values = ["a", "b", "c", "d"] - indices = [] - while len(objectids) > 0: - index = ray.select(objectids, num_objects=1)[0] - indices.append(values[index]) - objectids.pop(index) - values.pop(index) - self.assertEqual(indices, ["c", "a", "b", "d"]) + objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)] + ready_ids, remaining_ids = ray.wait(objectids) + self.assertTrue(len(ready_ids) == 1) + self.assertTrue(len(remaining_ids) == 3) + ready_ids, remaining_ids = ray.wait(objectids, num_returns=4) + self.assertEqual(ready_ids, objectids) + self.assertEqual(remaining_ids, []) + + objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)] + start_time = time.time() + ready_ids, remaining_ids = ray.wait(objectids, timeout=1.75, num_returns=4) + self.assertTrue(time.time() - start_time < 2) + self.assertEqual(len(ready_ids), 3) + self.assertEqual(len(remaining_ids), 1) + ray.wait(objectids) + objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)] + start_time = time.time() + ready_ids, 
remaining_ids = ray.wait(objectids, timeout=5) + self.assertTrue(time.time() - start_time < 5) + self.assertEqual(len(ready_ids), 1) + self.assertEqual(len(remaining_ids), 3) ray.worker.cleanup()