mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 13:05:25 +08:00
Changed ray.select() to ray.wait() and its functionality (#426)
* Re-implemented select, changed name to wait
* Changed tests for select to tests for wait
* Updated the hyperopt example to match wait
* Small fixes and improve example readme.
* Make tests pass.
This commit is contained in:
+1
-1
@@ -5,7 +5,7 @@ The Ray API
|
||||
.. autofunction:: ray.put
|
||||
.. autofunction:: ray.get
|
||||
.. autofunction:: ray.remote
|
||||
.. autofunction:: ray.select
|
||||
.. autofunction:: ray.wait
|
||||
.. autofunction:: ray.init
|
||||
.. autofunction:: ray.kill_workers
|
||||
.. autofunction:: ray.restart_workers_local
|
||||
|
||||
@@ -63,8 +63,9 @@ def generate_random_params():
|
||||
|
||||
results = []
|
||||
for _ in range(100):
|
||||
randparams = generate_random_params()
|
||||
results.append((randparams, train_cnn_and_compute_accuracy(randparams, train_images, train_labels, validation_images, validation_labels)))
|
||||
params = generate_random_params()
|
||||
accuracy = train_cnn_and_compute_accuracy(randparams, train_images, train_labels, validation_images, validation_labels)
|
||||
results.append(accuracy)
|
||||
```
|
||||
|
||||
Then we can inspect the contents of `results` and see which set of
|
||||
@@ -101,16 +102,53 @@ computation. Instead, it simply submits a number of tasks to the scheduler.
|
||||
|
||||
```python
|
||||
result_ids = []
|
||||
# Launch 100 tasks.
|
||||
for _ in range(100):
|
||||
params = generate_random_params()
|
||||
results.append((params, train_cnn_and_compute_accuracy.remote(params, train_images, train_labels, validation_images, validation_labels)))
|
||||
accuracy_id = train_cnn_and_compute_accuracy.remote(randparams, train_images, train_labels, validation_images, validation_labels)
|
||||
result_ids.append(accuracy_id)
|
||||
```
|
||||
|
||||
If we wish to wait until all of the tasks have finished, we can retrieve
|
||||
their values with `ray.get`.
|
||||
|
||||
```python
|
||||
results = [(params, ray.get(result_id)) for (params, result_id) in result_ids]
|
||||
results = ray.get(result_ids)
|
||||
```
|
||||
|
||||
One drawback of the above approach is that nothing will be printed until all of
|
||||
the experiments have finished. What we'd really like is to start processing
|
||||
the results of certain experiments as soon as they finish (and possibly launch
|
||||
more experiments based on the outcomes of the first ones). To do this, we can
|
||||
use `ray.wait`, which takes a list of object IDs and returns two lists of object
|
||||
IDs: the IDs that are ready and the remaining IDs.
|
||||
|
||||
```python
|
||||
ready_ids, remaining_ids = ray.wait(result_ids, num_returns=3, timeout=10)
|
||||
```
|
||||
|
||||
In the above, `result_ids` is a list of object IDs. The command `ray.wait` will
|
||||
return as soon as either three of the object IDs in `result_ids` are ready (that
|
||||
is, the task that created the corresponding object finished executing and stored
|
||||
the object in the object store) or ten seconds pass, whichever comes first. To
|
||||
wait indefinitely, omit the timeout argument. Now, we can rewrite the script as
|
||||
follows.
|
||||
|
||||
```python
|
||||
remaining_ids = []
|
||||
# Launch 100 tasks.
|
||||
for _ in range(100):
|
||||
params = generate_random_params()
|
||||
accuracy_id = train_cnn_and_compute_accuracy.remote(randparams, train_images, train_labels, validation_images, validation_labels)
|
||||
result_ids.append(accuracy_id)
|
||||
|
||||
# Process the tasks one at a time.
|
||||
while len(remaining_ids) > 0:
|
||||
# Process the next task that finishes.
|
||||
ready_ids, remaining_ids = ray.wait(remaining_ids, num_returns=1)
|
||||
# Get the accuracy corresponding to the ready object ID.
|
||||
accuracy = ray.get(ready_ids[0])
|
||||
print "Accuracy {}".format(accuracy)
|
||||
```
|
||||
|
||||
## Additional notes
|
||||
|
||||
+24
-11
@@ -39,26 +39,39 @@ if __name__ == "__main__":
|
||||
validation_images = ray.put(mnist.validation.images)
|
||||
validation_labels = ray.put(mnist.validation.labels)
|
||||
|
||||
# Store the best parameters, the best accuracy, and all of the results.
|
||||
# Keep track of the best parameters and the best accuracy.
|
||||
best_params = None
|
||||
best_accuracy = 0
|
||||
results = []
|
||||
# This list holds the object IDs for all of the experiments that we have
|
||||
# launched and that have not yet been processed.
|
||||
remaining_ids = []
|
||||
# This is a dictionary mapping the object ID of an experiment to the
|
||||
# parameters used for that experiment.
|
||||
params_mapping = {}
|
||||
|
||||
# Randomly generate some hyperparameters, and launch a task for each set.
|
||||
for i in range(trials):
|
||||
# A function for generating random hyperparameters.
|
||||
def generate_random_params():
|
||||
learning_rate = 10 ** np.random.uniform(-5, 5)
|
||||
batch_size = np.random.randint(1, 100)
|
||||
dropout = np.random.uniform(0, 1)
|
||||
stddev = 10 ** np.random.uniform(-5, 5)
|
||||
params = {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev}
|
||||
results.append((params, hyperopt.train_cnn_and_compute_accuracy.remote(params, steps, train_images, train_labels, validation_images, validation_labels)))
|
||||
return {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev}
|
||||
|
||||
# Fetch the results of the tasks and print the results.
|
||||
# Randomly generate some hyperparameters, and launch a task for each set.
|
||||
for i in range(trials):
|
||||
# Get the index of the first task that completes.
|
||||
index = ray.select([result_id for _, result_id in results], num_objects=1)[0]
|
||||
# Process the output of this task and remove it from the list.
|
||||
params, result_id = results.pop(index)
|
||||
params = generate_random_params()
|
||||
accuracy_id = hyperopt.train_cnn_and_compute_accuracy.remote(params, steps, train_images, train_labels, validation_images, validation_labels)
|
||||
remaining_ids.append(accuracy_id)
|
||||
# Keep track of which parameters correspond to this experiment.
|
||||
params_mapping[accuracy_id] = params
|
||||
|
||||
# Fetch and print the results of the tasks in the order that they complete.
|
||||
for i in range(trials):
|
||||
# Use ray.wait to get the object ID of the first task that completes.
|
||||
ready_ids, remaining_ids = ray.wait(remaining_ids)
|
||||
# Process the output of this task.
|
||||
result_id = ready_ids[0]
|
||||
params = params_mapping[result_id]
|
||||
accuracy = ray.get(result_id)
|
||||
print """We achieve accuracy {:.3}% with
|
||||
learning_rate: {:.2}
|
||||
|
||||
@@ -11,7 +11,7 @@ if hasattr(ctypes, "windll"):
|
||||
|
||||
import config
|
||||
import serialization
|
||||
from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, select, remote, kill_workers, restart_workers_local
|
||||
from worker import scheduler_info, register_class, visualize_computation_graph, task_info, init, connect, disconnect, get, put, wait, remote, kill_workers, restart_workers_local
|
||||
from worker import Reusable, reusables
|
||||
from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
|
||||
from libraylib import ObjectID
|
||||
|
||||
+28
-20
@@ -838,35 +838,43 @@ def put(value, worker=global_worker):
|
||||
worker.put_object(objectid, value)
|
||||
return objectid
|
||||
|
||||
def select(objectids, num_objects=0, worker=global_worker):
|
||||
"""Return a list of the indices of the objects that are ready.
|
||||
def wait(objectids, num_returns=1, timeout=None, worker=global_worker):
|
||||
"""Return a list of IDs that are ready and a list of IDs that are not ready.
|
||||
|
||||
If num_objects is 0, the function immediately returns the indices of all
|
||||
objects that are ready. If it is set, the function waits until that number of
|
||||
objects is ready and returns that exact number of objectids.
|
||||
If timeout is set, the function returns either when the requested number of
|
||||
IDs are ready or when the timeout is reached, whichever occurs first. If it is
|
||||
not set, the function simply waits until that number of objects is ready and
|
||||
returns that exact number of objectids.
|
||||
|
||||
This method returns two lists. The first list consists of object IDs that
|
||||
correspond to objects that are stored in the object store. The second list
|
||||
corresponds to the rest of the object IDs (which may or may not be ready).
|
||||
|
||||
Args:
|
||||
objectids (List[ray.ObjectID]): List of objectids for objects that may or
|
||||
may not be ready.
|
||||
num_objects (int): The number of indices that should be returned.
|
||||
objectids (List[raylib.ObjectID]): List of object IDs for objects that may
|
||||
or may not be ready.
|
||||
num_returns (int): The number of object IDs that should be returned.
|
||||
timeout (float): The maximum amount of time in seconds that should be spent
|
||||
polling the scheduler.
|
||||
|
||||
Returns:
|
||||
List of indices in the original list of objects that are ready.
|
||||
A list of object IDs that are ready and a list of the remaining object IDs.
|
||||
"""
|
||||
check_connected(worker)
|
||||
if num_objects > len(objectids):
|
||||
raise Exception("num_objects cannot be greater than len(objectids), num_objects is {}, and len(objectids) is {}.".format(num_objects, len(objectids)))
|
||||
ready_ids = raylib.ray_select(worker.handle, objectids)
|
||||
if num_returns < 0:
|
||||
raise Exception("num_returns cannot be less than 0.")
|
||||
if num_returns > len(objectids):
|
||||
raise Exception("num_returns cannot be greater than the length of the input list: num_objects is {}, and the length is {}.".format(num_returns, len(objectids)))
|
||||
start_time = time.time()
|
||||
ready_indices = raylib.wait(worker.handle, objectids)
|
||||
# Polls scheduler until enough objects are ready.
|
||||
while len(ready_ids) < num_objects:
|
||||
ready_ids = raylib.ray_select(worker.handle, objectids)
|
||||
while len(ready_indices) < num_returns and (time.time() - start_time < timeout or timeout is None):
|
||||
ready_indices = raylib.wait(worker.handle, objectids)
|
||||
time.sleep(0.1)
|
||||
if num_objects != 0:
|
||||
# Return indices for exactly the requested number of objects.
|
||||
return ready_ids[:num_objects]
|
||||
else:
|
||||
# Return indices for all objects that are ready.
|
||||
return ready_ids
|
||||
# Return indices for exactly the requested number of objects.
|
||||
ready_ids = [objectids[i] for i in ready_indices[:num_returns]]
|
||||
not_ready_ids = [objectids[i] for i in range(len(objectids)) if i not in ready_indices[:num_returns]]
|
||||
return ready_ids, not_ready_ids
|
||||
|
||||
def kill_workers(worker=global_worker):
|
||||
"""Kill all of the workers in the cluster. This does not kill drivers.
|
||||
|
||||
+3
-3
@@ -61,7 +61,7 @@ service Scheduler {
|
||||
// Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable.
|
||||
rpc NotifyFailure(NotifyFailureRequest) returns (AckReply);
|
||||
// Polls the scheduler to see what objectids can be retrieved in the input list.
|
||||
rpc Select(SelectRequest) returns (SelectReply);
|
||||
rpc Wait(WaitRequest) returns (WaitReply);
|
||||
}
|
||||
|
||||
message AckReply {
|
||||
@@ -173,11 +173,11 @@ message SchedulerInfoReply {
|
||||
repeated ObjstoreData objstore = 7; // Information about the object stores
|
||||
}
|
||||
|
||||
message SelectRequest {
|
||||
message WaitRequest {
|
||||
repeated uint64 objectids = 1; // List of objectids to be checked.
|
||||
}
|
||||
|
||||
message SelectReply {
|
||||
message WaitReply {
|
||||
repeated uint64 indices = 1; // List of indices that correspond to objectids in the original list that are ready.
|
||||
}
|
||||
|
||||
|
||||
+3
-3
@@ -892,7 +892,7 @@ static PyObject* request_object(PyObject* self, PyObject* args) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* ray_select(PyObject* self, PyObject* args) {
|
||||
static PyObject* wait(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
PyObject* objectids;
|
||||
if (!PyArg_ParseTuple(args, "O&O", &PyObjectToWorker, &worker, &objectids)) {
|
||||
@@ -904,7 +904,7 @@ static PyObject* ray_select(PyObject* self, PyObject* args) {
|
||||
PyObjectToObjectID(PyList_GetItem(objectids, i), &objectid);
|
||||
objectids_vec.push_back(objectid);
|
||||
}
|
||||
std::vector<int> indices = worker->select(objectids_vec);
|
||||
std::vector<int> indices = worker->wait(objectids_vec);
|
||||
PyObject* result = PyList_New(indices.size());
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
PyList_SetItem(result, i, PyInt_FromLong(indices[i]));
|
||||
@@ -1081,7 +1081,7 @@ static PyMethodDef RayLibMethods[] = {
|
||||
{ "add_contained_objectids", add_contained_objectids, METH_VARARGS, "notify the scheduler about the object IDs contained in a remote object" },
|
||||
{ "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" },
|
||||
{ "request_object" , request_object, METH_VARARGS, "request an object to be delivered to the local object store" },
|
||||
{ "ray_select" , ray_select, METH_VARARGS, "checks the scheduler to see if a object can be gotten" },
|
||||
{ "wait" , wait, METH_VARARGS, "checks the scheduler to see if a object can be gotten" },
|
||||
{ "alias_objectids", alias_objectids, METH_VARARGS, "make two objectids refer to the same object" },
|
||||
{ "wait_for_next_message", wait_for_next_message, METH_VARARGS, "get next message from scheduler (blocking)" },
|
||||
{ "submit_task", submit_task, METH_VARARGS, "call a remote function" },
|
||||
|
||||
+1
-1
@@ -594,7 +594,7 @@ Status SchedulerService::ExportReusableVariable(ServerContext* context, const Ex
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::Select(ServerContext* context, const SelectRequest* request, SelectReply* reply) {
|
||||
Status SchedulerService::Wait(ServerContext* context, const WaitRequest* request, WaitReply* reply) {
|
||||
auto objtable = GET(objtable_);
|
||||
for (int i = 0; i < request->objectids_size(); ++i) {
|
||||
ObjectID objectid = request->objectids(i);
|
||||
|
||||
+1
-1
@@ -79,7 +79,7 @@ public:
|
||||
Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override;
|
||||
Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override;
|
||||
Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override;
|
||||
Status Select(ServerContext*, const SelectRequest* request, SelectReply* reply) override;
|
||||
Status Wait(ServerContext*, const WaitRequest* request, WaitReply* reply) override;
|
||||
|
||||
#ifdef NDEBUG
|
||||
// If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking.
|
||||
|
||||
+4
-4
@@ -409,15 +409,15 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
|
||||
RAY_CHECK_GRPC(scheduler_stub_->TaskInfo(&context, request, &reply));
|
||||
}
|
||||
|
||||
std::vector<int> Worker::select(std::vector<ObjectID>& objectids) {
|
||||
std::vector<int> Worker::wait(std::vector<ObjectID>& objectids) {
|
||||
RAY_CHECK(connected_, "Attempted to test if object was ready but failed.");
|
||||
ClientContext context;
|
||||
SelectRequest request;
|
||||
SelectReply reply;
|
||||
WaitRequest request;
|
||||
WaitReply reply;
|
||||
for (int i = 0; i < objectids.size(); ++i) {
|
||||
request.add_objectids(objectids[i]);
|
||||
}
|
||||
RAY_CHECK_GRPC(scheduler_stub_->Select(&context, request, &reply));
|
||||
RAY_CHECK_GRPC(scheduler_stub_->Wait(&context, request, &reply));
|
||||
std::vector<int> result;
|
||||
for (int i = 0; i < reply.indices_size(); ++i) {
|
||||
result.push_back(reply.indices(i));
|
||||
|
||||
+1
-1
@@ -102,7 +102,7 @@ class Worker {
|
||||
// get task statuses from scheduler
|
||||
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
|
||||
// gets indices of available objects
|
||||
std::vector<int> select(std::vector<ObjectID>& objectids);
|
||||
std::vector<int> wait(std::vector<ObjectID>& objectids);
|
||||
// Export a function to be run on all workers.
|
||||
void run_function_on_all_workers(const std::string& function);
|
||||
// export function to workers
|
||||
|
||||
+23
-19
@@ -358,31 +358,35 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(ray.get(object_ids), range(10))
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testSelect(self):
|
||||
ray.init(start_ray_local=True, num_workers=4)
|
||||
def testWait(self):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
@ray.remote
|
||||
def f(delay):
|
||||
time.sleep(delay)
|
||||
return 1
|
||||
objectids = [f.remote(1.5), f.remote(1.5), f.remote(1.0), f.remote(0.5)]
|
||||
self.assertEqual(ray.select(objectids), [])
|
||||
time.sleep(0.75)
|
||||
self.assertEqual(ray.select(objectids), [3])
|
||||
time.sleep(0.5)
|
||||
self.assertEqual(ray.select(objectids), [2, 3])
|
||||
time.sleep(0.5)
|
||||
self.assertEqual(ray.select(objectids), [0, 1, 2, 3])
|
||||
|
||||
objectids = [f.remote(0.5), f.remote(0.75), f.remote(0.25), f.remote(1.0)]
|
||||
values = ["a", "b", "c", "d"]
|
||||
indices = []
|
||||
while len(objectids) > 0:
|
||||
index = ray.select(objectids, num_objects=1)[0]
|
||||
indices.append(values[index])
|
||||
objectids.pop(index)
|
||||
values.pop(index)
|
||||
self.assertEqual(indices, ["c", "a", "b", "d"])
|
||||
objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
|
||||
ready_ids, remaining_ids = ray.wait(objectids)
|
||||
self.assertTrue(len(ready_ids) == 1)
|
||||
self.assertTrue(len(remaining_ids) == 3)
|
||||
ready_ids, remaining_ids = ray.wait(objectids, num_returns=4)
|
||||
self.assertEqual(ready_ids, objectids)
|
||||
self.assertEqual(remaining_ids, [])
|
||||
|
||||
objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
|
||||
start_time = time.time()
|
||||
ready_ids, remaining_ids = ray.wait(objectids, timeout=1.75, num_returns=4)
|
||||
self.assertTrue(time.time() - start_time < 2)
|
||||
self.assertEqual(len(ready_ids), 3)
|
||||
self.assertEqual(len(remaining_ids), 1)
|
||||
ray.wait(objectids)
|
||||
objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
|
||||
start_time = time.time()
|
||||
ready_ids, remaining_ids = ray.wait(objectids, timeout=5)
|
||||
self.assertTrue(time.time() - start_time < 5)
|
||||
self.assertEqual(len(ready_ids), 1)
|
||||
self.assertEqual(len(remaining_ids), 3)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user