diff --git a/doc/api.rst b/doc/api.rst index 2d4d74a29..a407b91bd 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -5,6 +5,7 @@ The Ray API .. autofunction:: ray.put .. autofunction:: ray.get .. autofunction:: ray.remote +.. autofunction:: ray.select .. autofunction:: ray.init .. autofunction:: ray.kill_workers .. autofunction:: ray.restart_workers_local diff --git a/examples/hyperopt/driver.py b/examples/hyperopt/driver.py index 34b9ee749..0efb67760 100644 --- a/examples/hyperopt/driver.py +++ b/examples/hyperopt/driver.py @@ -55,7 +55,10 @@ if __name__ == "__main__": # Fetch the results of the tasks and print the results. for i in range(trials): - params, result_id = results[i] + # Get the index of the first task that completes. + index = ray.select([result_id for _, result_id in results], num_objects=1)[0] + # Process the output of this task and remove it from the list. + params, result_id = results.pop(index) accuracy = ray.get(result_id) print """We achieve accuracy {:.3}% with learning_rate: {:.2} diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index e01643631..b2a843eea 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"): import config import serialization -from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local +from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local from worker import Reusable, reusables from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE from libraylib import ObjectID diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 4a136dad2..deeca503a 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -819,6 +819,36 @@ def put(value, worker=global_worker): worker.put_object(objectid, value) return objectid +def select(objectids, num_objects=0, worker=global_worker): + """Return a list of the indices of the objects that are ready. + + If num_objects is 0, the function immediately returns the indices of all + objects that are ready. If it is set, the function waits until that number of + objects is ready and returns that exact number of objectids. + + Args: + objectids (List[ray.ObjectID]): List of objectids for objects that may or + may not be ready. + num_objects (int): The number of indices that should be returned. + + Returns: + List of indices in the original list of objects that are ready. + """ + check_connected(worker) + if num_objects > len(objectids): + raise Exception("num_objects cannot be greater than len(objectids), num_objects is {}, and len(objectids) is {}.".format(num_objects, len(objectids))) + ready_ids = raylib.ray_select(worker.handle, objectids) + # Polls scheduler until enough objects are ready. + while len(ready_ids) < num_objects: + ready_ids = raylib.ray_select(worker.handle, objectids) + time.sleep(0.1) + if num_objects != 0: + # Return indices for exactly the requested number of objects. + return ready_ids[:num_objects] + else: + # Return indices for all objects that are ready. + return ready_ids + def kill_workers(worker=global_worker): """Kill all of the workers in the cluster. This does not kill drivers. diff --git a/protos/ray.proto b/protos/ray.proto index e1d29b775..cd2eec032 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -58,6 +58,8 @@ service Scheduler { rpc ExportReusableVariable(ExportReusableVariableRequest) returns (AckReply); // Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable. rpc NotifyFailure(NotifyFailureRequest) returns (AckReply); + // Polls the scheduler to see what objectids can be retrieved in the input list. + rpc Select(SelectRequest) returns (SelectReply); } message AckReply { @@ -168,6 +170,14 @@ message SchedulerInfoReply { repeated ObjstoreData objstore = 7; // Information about the object stores } +message SelectRequest { + repeated uint64 objectids = 1; // List of objectids to be checked. +} + +message SelectReply { + repeated uint64 indices = 1; // List of indices that correspond to objectids in the original list that are ready. +} + // Object stores service ObjStore { diff --git a/src/raylib.cc b/src/raylib.cc index f95b17fe2..95144ceed 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -897,6 +897,26 @@ static PyObject* request_object(PyObject* self, PyObject* args) { Py_RETURN_NONE; } +static PyObject* ray_select(PyObject* self, PyObject* args) { + Worker* worker; + PyObject* objectids; + if (!PyArg_ParseTuple(args, "O&O", &PyObjectToWorker, &worker, &objectids)) { + return NULL; + } + std::vector objectids_vec; + for (size_t i = 0; i < PyList_Size(objectids); ++i) { + ObjectID objectid; + PyObjectToObjectID(PyList_GetItem(objectids, i), &objectid); + objectids_vec.push_back(objectid); + } + std::vector indices = worker->select(objectids_vec); + PyObject* result = PyList_New(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + PyList_SetItem(result, i, PyInt_FromLong(indices[i])); + } + return result; +} + static PyObject* alias_objectids(PyObject* self, PyObject* args) { Worker* worker; ObjectID alias_objectid; @@ -1061,6 +1081,7 @@ static PyMethodDef RayLibMethods[] = { { "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" }, { "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" }, { "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" }, + { "ray_select" , ray_select, METH_VARARGS, "checks the scheduler to see if a object can be gotten" }, { "alias_objectids", alias_objectids, METH_VARARGS, "make two objectids refer to the same object" }, { "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" }, { "submit_task", submit_task, METH_VARARGS, "call a remote function" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index b179141da..888fb2581 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -549,6 +549,21 @@ Status SchedulerService::ExportReusableVariable(ServerContext* context, const Ex return Status::OK; } +Status SchedulerService::Select(ServerContext* context, const SelectRequest* request, SelectReply* reply) { + auto objtable = GET(objtable_); + for (int i = 0; i < request->objectids_size(); ++i) { + ObjectID objectid = request->objectids(i); + if (has_canonical_objectid(objectid)) { + ObjectID canonical_objectid = get_canonical_objectid(objectid); + RAY_CHECK_LT(canonical_objectid, objtable->size(), "Canonical_objectid is outside object table."); + if ((*objtable)[canonical_objectid].size() != 0) { + reply->add_indices(i); + } + } + } + return Status::OK; +} + void SchedulerService::deliver_object_async_if_necessary(ObjectID canonical_objectid, ObjStoreId from, ObjStoreId to) { bool object_present_or_in_transit; { diff --git a/src/scheduler.h b/src/scheduler.h index 1a9cc6075..aaac29044 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -78,6 +78,7 @@ public: Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override; Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override; Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override; + Status Select(ServerContext*, const SelectRequest* request, SelectReply* reply) override; #ifdef NDEBUG // If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking. diff --git a/src/worker.cc b/src/worker.cc index f0aa3cbbf..64d0b4062 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -424,6 +424,22 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf RAY_CHECK_GRPC(scheduler_stub_->TaskInfo(&context, request, &reply)); } +std::vector Worker::select(std::vector& objectids) { + RAY_CHECK(connected_, "Attempted to test if object was ready but failed."); + ClientContext context; + SelectRequest request; + SelectReply reply; + for (int i = 0; i < objectids.size(); ++i) { + request.add_objectids(objectids[i]); + } + RAY_CHECK_GRPC(scheduler_stub_->Select(&context, request, &reply)); + std::vector result; + for (int i = 0; i < reply.indices_size(); ++i) { + result.push_back(reply.indices(i)); + } + return result; +} + bool Worker::export_remote_function(const std::string& function_name, const std::string& function) { RAY_CHECK(connected_, "Attempted to export function but failed."); ClientContext context; diff --git a/src/worker.h b/src/worker.h index bafb7f27b..814cc1041 100644 --- a/src/worker.h +++ b/src/worker.h @@ -102,6 +102,8 @@ class Worker { void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply); // get task statuses from scheduler void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply); + // gets indices of available objects + std::vector select(std::vector& objectids); // export function to workers bool export_remote_function(const std::string& function_name, const std::string& function); // export reusable variable to workers diff --git a/test/runtest.py b/test/runtest.py index f539d99ec..13f7a5500 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -305,6 +305,24 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testSelect(self): + ray.init(start_ray_local=True, num_workers=4) + + @ray.remote([float], [int]) + def f(delay): + time.sleep(delay) + return 1 + objectids = [f.remote(1.5), f.remote(1.5), f.remote(1.0), f.remote(0.5)] + self.assertEqual(ray.select(objectids), []) + time.sleep(0.75) + self.assertEqual(ray.select(objectids), [3]) + time.sleep(0.5) + self.assertEqual(ray.select(objectids), [2, 3]) + time.sleep(0.5) + self.assertEqual(ray.select(objectids), [0, 1, 2, 3]) + + ray.worker.cleanup() + def testCachingReusables(self): # Test that we can define reusable variables before the driver is connected. def foo_initializer():