Changed ray.select() to ray.wait() and its functionality (#426)

* Re-implemented select, changed name to wait

* Changed tests for select to tests for wait

* Updated the hyperopt example to match wait

* Small fixes and improve example readme.

* Make tests pass.
This commit is contained in:
Wapaul1
2016-09-14 17:14:11 -07:00
committed by Philipp Moritz
parent 8c6d3a88a9
commit d5815673a5
12 changed files with 132 additions and 69 deletions
+1 -1
View File
@@ -5,7 +5,7 @@ The Ray API
.. autofunction:: ray.put
.. autofunction:: ray.get
.. autofunction:: ray.remote
.. autofunction:: ray.select
.. autofunction:: ray.wait
.. autofunction:: ray.init
.. autofunction:: ray.kill_workers
.. autofunction:: ray.restart_workers_local
+42 -4
View File
@@ -63,8 +63,9 @@ def generate_random_params():
results = []
for _ in range(100):
randparams = generate_random_params()
results.append((randparams, train_cnn_and_compute_accuracy(randparams, train_images, train_labels, validation_images, validation_labels)))
params = generate_random_params()
accuracy = train_cnn_and_compute_accuracy(randparams, train_images, train_labels, validation_images, validation_labels)
results.append(accuracy)
```
Then we can inspect the contents of `results` and see which set of
@@ -101,16 +102,53 @@ computation. Instead, it simply submits a number of tasks to the scheduler.
```python
result_ids = []
# Launch 100 tasks.
for _ in range(100):
params = generate_random_params()
results.append((params, train_cnn_and_compute_accuracy.remote(params, train_images, train_labels, validation_images, validation_labels)))
accuracy_id = train_cnn_and_compute_accuracy.remote(randparams, train_images, train_labels, validation_images, validation_labels)
result_ids.append(accuracy_id)
```
If we wish to wait until the results have all been retrieved, we can retrieve
their values with `ray.get`.
```python
results = [(params, ray.get(result_id)) for (params, result_id) in result_ids]
results = ray.get(result_ids)
```
One drawback of the above approach is that nothing will be printed until all of
the experiments have finished. What we'd really like is to start processing
the results of certain experiments as soon as they finish (and possibly launch
more experiments based on the outcomes of the first ones). To do this, we can
use `ray.wait`, which takes a list of object IDs and returns two lists of object
IDs.
```python
ready_ids, remaining_ids = ray.wait(result_ids, num_returns=3, timeout=10)
```
In the above, `result_ids` is a list of object IDs. The command `ray.wait` will
return as soon as either three of the object IDs in `result_ids` are ready (that
is, the task that created the corresponding object finished executing and stored
the object in the object store) or ten seconds pass, whichever comes first. To
wait indefinitely, omit the timeout argument. Now, we can rewrite the script as
follows.
```python
remaining_ids = []
# Launch 100 tasks.
for _ in range(100):
params = generate_random_params()
accuracy_id = train_cnn_and_compute_accuracy.remote(randparams, train_images, train_labels, validation_images, validation_labels)
result_ids.append(accuracy_id)
# Process the tasks one at a time.
while len(remaining_ids) > 0:
# Process the next task that finishes.
ready_ids, remaining_ids = ray.wait(remaining_ids, num_returns=1)
# Get the accuracy corresponding to the ready object ID.
accuracy = ray.get(ready_ids[0])
print "Accuracy {}".format(accuracy)
```
## Additional notes
+24 -11
View File
@@ -39,26 +39,39 @@ if __name__ == "__main__":
validation_images = ray.put(mnist.validation.images)
validation_labels = ray.put(mnist.validation.labels)
# Store the best parameters, the best accuracy, and all of the results.
# Keep track of the best parameters and the best accuracy.
best_params = None
best_accuracy = 0
results = []
# This list holds the object IDs for all of the experiments that we have
# launched and that have not yet been processed.
remaining_ids = []
# This is a dictionary mapping the object ID of an experiment to the
# parameters used for that experiment.
params_mapping = {}
# Randomly generate some hyperparameters, and launch a task for each set.
for i in range(trials):
# A function for generating random hyperparameters.
def generate_random_params():
learning_rate = 10 ** np.random.uniform(-5, 5)
batch_size = np.random.randint(1, 100)
dropout = np.random.uniform(0, 1)
stddev = 10 ** np.random.uniform(-5, 5)
params = {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev}
results.append((params, hyperopt.train_cnn_and_compute_accuracy.remote(params, steps, train_images, train_labels, validation_images, validation_labels)))
return {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev}
# Fetch the results of the tasks and print the results.
# Randomly generate some hyperparameters, and launch a task for each set.
for i in range(trials):
# Get the index of the first task that completes.
index = ray.select([result_id for _, result_id in results], num_objects=1)[0]
# Process the output of this task and remove it from the list.
params, result_id = results.pop(index)
params = generate_random_params()
accuracy_id = hyperopt.train_cnn_and_compute_accuracy.remote(params, steps, train_images, train_labels, validation_images, validation_labels)
remaining_ids.append(accuracy_id)
# Keep track of which parameters correspond to this experiment.
params_mapping[accuracy_id] = params
# Fetch and print the results of the tasks in the order that they complete.
for i in range(trials):
# Use ray.wait to get the object ID of the first task that completes.
ready_ids, remaining_ids = ray.wait(remaining_ids)
# Process the output of this task.
result_id = ready_ids[0]
params = params_mapping[result_id]
accuracy = ray.get(result_id)
print """We achieve accuracy {:.3}% with
learning_rate: {:.2}
+1 -1
View File
@@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"):
import config
import serialization
from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local
from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, wait, remote, kill_workers, restart_workers_local
from worker import Reusable, reusables
from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
from libraylib import ObjectID
+28 -20
View File
@@ -838,35 +838,43 @@ def put(value, worker=global_worker):
worker.put_object(objectid, value)
return objectid
def select(objectids, num_objects=0, worker=global_worker):
"""Return a list of the indices of the objects that are ready.
def wait(objectids, num_returns=1, timeout=None, worker=global_worker):
"""Return a list of IDs that are ready and a list of IDs that are not ready.
If num_objects is 0, the function immediately returns the indices of all
objects that are ready. If it is set, the function waits until that number of
objects is ready and returns that exact number of objectids.
If timeout is set, the function returns either when the requested number of
IDs are ready or when the timeout is reached, whichever occurs first. If it is
not set, the function simply waits until that number of objects is ready and
returns that exact number of objectids.
This method returns two lists. The first list consists of object IDs that
correspond to objects that are stored in the object store. The second list
corresponds to the rest of the object IDs (which may or may not be ready).
Args:
objectids (List[ray.ObjectID]): List of objectids for objects that may or
may not be ready.
num_objects (int): The number of indices that should be returned.
objectids (List[raylib.ObjectID]): List of object IDs for objects that may
or may not be ready.
num_returns (int): The number of object IDs that should be returned.
timeout (float): The maximum amount of time in seconds that should be spent
polling the scheduler.
Returns:
List of indices in the original list of objects that are ready.
A list of object IDs that are ready and a list of the remaining object IDs.
"""
check_connected(worker)
if num_objects > len(objectids):
raise Exception("num_objects cannot be greater than len(objectids), num_objects is {}, and len(objectids) is {}.".format(num_objects, len(objectids)))
ready_ids = raylib.ray_select(worker.handle, objectids)
if num_returns < 0:
raise Exception("num_returns cannot be less than 0.")
if num_returns > len(objectids):
raise Exception("num_returns cannot be greater than the length of the input list: num_objects is {}, and the length is {}.".format(num_returns, len(objectids)))
start_time = time.time()
ready_indices = raylib.wait(worker.handle, objectids)
# Polls scheduler until enough objects are ready.
while len(ready_ids) < num_objects:
ready_ids = raylib.ray_select(worker.handle, objectids)
while len(ready_indices) < num_returns and (time.time() - start_time < timeout or timeout is None):
ready_indices = raylib.wait(worker.handle, objectids)
time.sleep(0.1)
if num_objects != 0:
# Return indices for exactly the requested number of objects.
return ready_ids[:num_objects]
else:
# Return indices for all objects that are ready.
return ready_ids
# Return indices for exactly the requested number of objects.
ready_ids = [objectids[i] for i in ready_indices[:num_returns]]
not_ready_ids = [objectids[i] for i in range(len(objectids)) if i not in ready_indices[:num_returns]]
return ready_ids, not_ready_ids
def kill_workers(worker=global_worker):
"""Kill all of the workers in the cluster. This does not kill drivers.
+3 -3
View File
@@ -61,7 +61,7 @@ service Scheduler {
// Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable.
rpc NotifyFailure(NotifyFailureRequest) returns (AckReply);
// Polls the scheduler to see what objectids can be retrieved in the input list.
rpc Select(SelectRequest) returns (SelectReply);
rpc Wait(WaitRequest) returns (WaitReply);
}
message AckReply {
@@ -173,11 +173,11 @@ message SchedulerInfoReply {
repeated ObjstoreData objstore = 7; // Information about the object stores
}
message SelectRequest {
message WaitRequest {
repeated uint64 objectids = 1; // List of objectids to be checked.
}
message SelectReply {
message WaitReply {
repeated uint64 indices = 1; // List of indices that correspond to objectids in the original list that are ready.
}
+3 -3
View File
@@ -892,7 +892,7 @@ static PyObject* request_object(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
static PyObject* ray_select(PyObject* self, PyObject* args) {
static PyObject* wait(PyObject* self, PyObject* args) {
Worker* worker;
PyObject* objectids;
if (!PyArg_ParseTuple(args, "O&O", &PyObjectToWorker, &worker, &objectids)) {
@@ -904,7 +904,7 @@ static PyObject* ray_select(PyObject* self, PyObject* args) {
PyObjectToObjectID(PyList_GetItem(objectids, i), &objectid);
objectids_vec.push_back(objectid);
}
std::vector<int> indices = worker->select(objectids_vec);
std::vector<int> indices = worker->wait(objectids_vec);
PyObject* result = PyList_New(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
PyList_SetItem(result, i, PyInt_FromLong(indices[i]));
@@ -1081,7 +1081,7 @@ static PyMethodDef RayLibMethods[] = {
{ "add_contained_objectids", add_contained_objectids, METH_VARARGS, "notify the scheduler about the object IDs contained in a remote object" },
{ "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" },
{ "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" },
{ "ray_select" , ray_select, METH_VARARGS, "checks the scheduler to see if a object can be gotten" },
{ "wait" , wait, METH_VARARGS, "checks the scheduler to see if a object can be gotten" },
{ "alias_objectids", alias_objectids, METH_VARARGS, "make two objectids refer to the same object" },
{ "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" },
{ "submit_task", submit_task, METH_VARARGS, "call a remote function" },
+1 -1
View File
@@ -594,7 +594,7 @@ Status SchedulerService::ExportReusableVariable(ServerContext* context, const Ex
return Status::OK;
}
Status SchedulerService::Select(ServerContext* context, const SelectRequest* request, SelectReply* reply) {
Status SchedulerService::Wait(ServerContext* context, const WaitRequest* request, WaitReply* reply) {
auto objtable = GET(objtable_);
for (int i = 0; i < request->objectids_size(); ++i) {
ObjectID objectid = request->objectids(i);
+1 -1
View File
@@ -79,7 +79,7 @@ public:
Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override;
Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override;
Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override;
Status Select(ServerContext*, const SelectRequest* request, SelectReply* reply) override;
Status Wait(ServerContext*, const WaitRequest* request, WaitReply* reply) override;
#ifdef NDEBUG
// If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking.
+4 -4
View File
@@ -409,15 +409,15 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
RAY_CHECK_GRPC(scheduler_stub_->TaskInfo(&context, request, &reply));
}
std::vector<int> Worker::select(std::vector<ObjectID>& objectids) {
std::vector<int> Worker::wait(std::vector<ObjectID>& objectids) {
RAY_CHECK(connected_, "Attempted to test if object was ready but failed.");
ClientContext context;
SelectRequest request;
SelectReply reply;
WaitRequest request;
WaitReply reply;
for (int i = 0; i < objectids.size(); ++i) {
request.add_objectids(objectids[i]);
}
RAY_CHECK_GRPC(scheduler_stub_->Select(&context, request, &reply));
RAY_CHECK_GRPC(scheduler_stub_->Wait(&context, request, &reply));
std::vector<int> result;
for (int i = 0; i < reply.indices_size(); ++i) {
result.push_back(reply.indices(i));
+1 -1
View File
@@ -102,7 +102,7 @@ class Worker {
// get task statuses from scheduler
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
// gets indices of available objects
std::vector<int> select(std::vector<ObjectID>& objectids);
std::vector<int> wait(std::vector<ObjectID>& objectids);
// Export a function to be run on all workers.
void run_function_on_all_workers(const std::string& function);
// export function to workers
+23 -19
View File
@@ -358,31 +358,35 @@ class APITest(unittest.TestCase):
self.assertEqual(ray.get(object_ids), range(10))
ray.worker.cleanup()
def testSelect(self):
ray.init(start_ray_local=True, num_workers=4)
def testWait(self):
ray.init(start_ray_local=True, num_workers=1)
@ray.remote
def f(delay):
time.sleep(delay)
return 1
objectids = [f.remote(1.5), f.remote(1.5), f.remote(1.0), f.remote(0.5)]
self.assertEqual(ray.select(objectids), [])
time.sleep(0.75)
self.assertEqual(ray.select(objectids), [3])
time.sleep(0.5)
self.assertEqual(ray.select(objectids), [2, 3])
time.sleep(0.5)
self.assertEqual(ray.select(objectids), [0, 1, 2, 3])
objectids = [f.remote(0.5), f.remote(0.75), f.remote(0.25), f.remote(1.0)]
values = ["a", "b", "c", "d"]
indices = []
while len(objectids) > 0:
index = ray.select(objectids, num_objects=1)[0]
indices.append(values[index])
objectids.pop(index)
values.pop(index)
self.assertEqual(indices, ["c", "a", "b", "d"])
objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
ready_ids, remaining_ids = ray.wait(objectids)
self.assertTrue(len(ready_ids) == 1)
self.assertTrue(len(remaining_ids) == 3)
ready_ids, remaining_ids = ray.wait(objectids, num_returns=4)
self.assertEqual(ready_ids, objectids)
self.assertEqual(remaining_ids, [])
objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
start_time = time.time()
ready_ids, remaining_ids = ray.wait(objectids, timeout=1.75, num_returns=4)
self.assertTrue(time.time() - start_time < 2)
self.assertEqual(len(ready_ids), 3)
self.assertEqual(len(remaining_ids), 1)
ray.wait(objectids)
objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
start_time = time.time()
ready_ids, remaining_ids = ray.wait(objectids, timeout=5)
self.assertTrue(time.time() - start_time < 5)
self.assertEqual(len(ready_ids), 1)
self.assertEqual(len(remaining_ids), 3)
ray.worker.cleanup()