From 35487972028202dc8817701e2bd47875052c16dd Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 2 Sep 2016 18:02:44 -0700 Subject: [PATCH] [API] Implement get for multiple objects (#398) * [API] Implement get for multiple objects * Small fixes. --- README.md | 2 +- doc/tutorial.md | 15 ++++++++++++--- examples/alexnet/alexnet.py | 5 ++--- examples/alexnet/driver.py | 2 +- examples/lbfgs/README.md | 4 ++-- examples/lbfgs/driver.py | 4 ++-- lib/python/ray/worker.py | 16 ++++++++++++---- test/runtest.py | 6 ++++++ 8 files changed, 38 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 3606914c0..6c26c4a29 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ for _ in range(10): result_ids.append(estimate_pi.remote(100)) # Fetch the results of the tasks and print their average. -estimate = np.mean([ray.get(result_id) for result_id in result_ids]) +estimate = np.mean(ray.get(result_ids)) print "Pi is approximately {}.".format(estimate) ``` diff --git a/doc/tutorial.md b/doc/tutorial.md index 0d5f07e0a..cd4087df5 100644 --- a/doc/tutorial.md +++ b/doc/tutorial.md @@ -95,6 +95,15 @@ If the remote object corresponding to the object ID `x_id` has not been created yet, *the command `ray.get(x_id)` will wait until the remote object has been created.* +A very common use case of `ray.get` is to get a list of object IDs. In this +case, you can call `ray.get(object_ids)` where `object_ids` is a list of object +IDs. + +```python +result_ids = [ray.put(i) for i in range(10)] +ray.get(result_ids) # prints [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +``` + ## Computation graphs in Ray Ray represents computation with a directed acyclic graph of tasks. Tasks are @@ -202,7 +211,7 @@ for i in range(10): result_ids.append(sleep.remote(2)) # Wait for the results. If we have at least ten workers, this takes 2 seconds. 
-[ray.get(result_id) for result_id in result_ids] # prints [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +ray.get(result_ids) # prints [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ``` The for loop simply adds ten tasks to the computation graph, with no @@ -287,10 +296,10 @@ def run_experiment(i): for j in range(10): sub_results.append(sub_experiment.remote(i, j)) # Return the sum of the results of the sub-experiments. - return sum([ray.get(sub_result) for sub_result in sub_results]) + return sum(ray.get(sub_results)) results = [run_experiment.remote(i) for i in range(5)] -[ray.get(result) for result in results] # prints [45, 55, 65, 75, 85] +ray.get(results) # prints [45, 55, 65, 75, 85] ``` When the remote function `run_experiment` is executed on a worker, it calls the diff --git a/examples/alexnet/alexnet.py b/examples/alexnet/alexnet.py index c2211930e..5b0a707a8 100644 --- a/examples/alexnet/alexnet.py +++ b/examples/alexnet/alexnet.py @@ -241,7 +241,7 @@ def num_images(batches): int: The number of images """ shape_ids = [ra.shape.remote(batch) for batch in batches] - return sum([ray.get(shape_id)[0] for shape_id in shape_ids]) + return sum([shape[0] for shape in ray.get(shape_ids)]) @ray.remote def compute_mean_image(batches): @@ -256,9 +256,8 @@ def compute_mean_image(batches): if len(batches) == 0: raise Exception("No images were passed into `compute_mean_image`.") sum_image_ids = [ra.sum.remote(batch, axis=0) for batch in batches] - sum_images = [ray.get(sum_image_id) for sum_image_id in sum_image_ids] n_images = num_images.remote(batches) - return np.sum(sum_images, axis=0).astype("float64") / ray.get(n_images) + return np.sum(ray.get(sum_image_ids), axis=0).astype("float64") / ray.get(n_images) @ray.remote(num_return_vals=4) def shuffle_arrays(first_images, first_labels, second_images, second_labels): diff --git a/examples/alexnet/driver.py b/examples/alexnet/driver.py index 9fc80f513..a9cdd6e86 100644 --- a/examples/alexnet/driver.py +++ b/examples/alexnet/driver.py @@ -93,7 +93,7 
@@ if __name__ == "__main__": print "Iteration {}: accuracy = {:.3}%".format(iteration, 100 * ray.get(accuracy)) # Fetch the gradients. This blocks until the gradients have been computed. - gradient_sets = [ray.get(gradient_id) for gradient_id in gradient_ids] + gradient_sets = ray.get(gradient_ids) # Average the gradients over all of the tasks. mean_gradients = [np.mean([gradient_set[i] for gradient_set in gradient_sets], axis=0) for i in range(len(weights))] # Use the gradients to update the network. diff --git a/examples/lbfgs/README.md b/examples/lbfgs/README.md index 7703cf025..63d6eb2cf 100644 --- a/examples/lbfgs/README.md +++ b/examples/lbfgs/README.md @@ -111,12 +111,12 @@ gradient. def full_loss(theta): theta_id = ray.put(theta) loss_ids = [loss.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids] - return sum([ray.get(loss_id) for loss_id in loss_ids]) + return sum(ray.get(loss_ids)) def full_grad(theta): theta_id = ray.put(theta) grad_ids = [grad.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids] - return sum([ray.get(grad_id) for grad_id in grad_ids]).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b. + return sum(ray.get(grad_ids)).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b. ``` Note that we turn `theta` into a remote object with the line `theta_id = diff --git a/examples/lbfgs/driver.py b/examples/lbfgs/driver.py index 5d7fff6fd..cc040b6c6 100644 --- a/examples/lbfgs/driver.py +++ b/examples/lbfgs/driver.py @@ -92,13 +92,13 @@ if __name__ == "__main__": def full_loss(theta): theta_id = ray.put(theta) loss_ids = [loss.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids] - return sum([ray.get(loss_id) for loss_id in loss_ids]) + return sum(ray.get(loss_ids)) # Compute the gradient of the loss on the entire dataset. 
def full_grad(theta): theta_id = ray.put(theta) grad_ids = [grad.remote(theta_id, xs_id, ys_id) for (xs_id, ys_id) in batch_ids] - return sum([ray.get(grad_id) for grad_id in grad_ids]).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b. + return sum(ray.get(grad_ids)).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b. # From the perspective of scipy.optimize.fmin_l_bfgs_b, full_loss is simply a # function which takes some parameters theta, and computes a loss. Similarly, diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 259755e18..934429044 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -768,22 +768,30 @@ def disconnect(worker=global_worker): reusables._cached_reusables = [] def get(objectid, worker=global_worker): - """Get a remote object from an object store. + """Get a remote object or a list of remote objects from the object store. This method blocks until the object corresponding to objectid is available in the local object store. If this object is not in the local object store, it will be shipped from an object store that has it (once the object has been - created). + created). If objectid is a list, then the objects corresponding to each object + ID in the list will be returned. Args: - objectid (raylib.ObjectID): Object ID to the object to get. + objectid: Object ID of the object to get or a list of object IDs to get. Returns: - A Python object + A Python object or a list of Python objects. 
""" check_connected(worker) if worker.mode == raylib.PYTHON_MODE: return objectid # In raylib.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) + if isinstance(objectid, list): + [raylib.request_object(worker.handle, x) for x in objectid] + values = [worker.get_object(x) for x in objectid] + for i, value in enumerate(values): + if isinstance(value, RayTaskError): + raise RayGetError(objectid[i], value) + return values raylib.request_object(worker.handle, objectid) value = worker.get_object(objectid) if isinstance(value, RayTaskError): diff --git a/test/runtest.py b/test/runtest.py index 99f2c74d7..839b0d0c0 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -306,6 +306,12 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testGetMultiple(self): + ray.init(start_ray_local=True, num_workers=0) + object_ids = [ray.put(i) for i in range(10)] + self.assertEqual(ray.get(object_ids), range(10)) + ray.worker.cleanup() + def testSelect(self): ray.init(start_ray_local=True, num_workers=4)