From 317d0da7d85fc926723b9a5e8fbe78973ec0dab9 Mon Sep 17 00:00:00 2001 From: Kunal Gosar Date: Fri, 1 Jun 2018 16:42:27 -0700 Subject: [PATCH] Add experimental API for ray.get and ray.wait with additional argument types (#2071) --- python/ray/experimental/__init__.py | 4 +- python/ray/experimental/api.py | 66 +++++++++++++++++++++++++++++ test/runtest.py | 47 ++++++++++++++++++++ 3 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 python/ray/experimental/api.py diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 58005f443..a10cc6ed3 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -8,10 +8,12 @@ from .features import ( flush_finished_tasks_unsafe, flush_evicted_objects_unsafe, _flush_finished_tasks_unsafe_shard, _flush_evicted_objects_unsafe_shard) from .named_actors import get_actor, register_actor +from .api import get, wait __all__ = [ "TensorFlowVariables", "flush_redis_unsafe", "flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe", "flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard", - "_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor" + "_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor", + "get", "wait" ] diff --git a/python/ray/experimental/api.py b/python/ray/experimental/api.py new file mode 100644 index 000000000..9891ecff7 --- /dev/null +++ b/python/ray/experimental/api.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ray +import numpy as np + + +def get(object_ids, worker=None): + """Get a single or a collection of remote objects from the object store. + + This method is identical to `ray.get` except it adds support for tuples, + ndarrays and dictionaries. + + Args: + object_ids: Object ID of the object to get, a list, tuple, ndarray of + object IDs to get or a dict of {key: object ID}. + + Returns: + A Python object, a list of Python objects or a dict of {key: object}. + """ + # There is a dependency on ray.worker which prevents importing + # global_worker at the top of this file + worker = ray.worker.global_worker if worker is None else worker + if isinstance(object_ids, (tuple, np.ndarray)): + return ray.get(list(object_ids), worker) + elif isinstance(object_ids, dict): + keys_to_get = [ + k for k, v in object_ids.items() if isinstance(v, ray.ObjectID) + ] + ids_to_get = [ + v for k, v in object_ids.items() if isinstance(v, ray.ObjectID) + ] + values = ray.get(ids_to_get) + + result = object_ids.copy() + for key, value in zip(keys_to_get, values): + result[key] = value + return result + else: + return ray.get(object_ids, worker) + + +def wait(object_ids, num_returns=1, timeout=None, worker=None): + """Return a list of IDs that are ready and a list of IDs that are not. + + This method is identical to `ray.wait` except it adds support for tuples + and ndarrays. + + Args: + object_ids (List[ObjectID], Tuple(ObjectID), np.array(ObjectID)): + List like of object IDs for objects that may or may not be ready. + Note that these IDs must be unique. + num_returns (int): The number of object IDs that should be returned. + timeout (int): The maximum amount of time in milliseconds to wait + before returning. + + Returns: + A list of object IDs that are ready and a list of the remaining object + IDs. + """ + worker = ray.worker.global_worker if worker is None else worker + if isinstance(object_ids, (tuple, np.ndarray)): + return ray.wait(list(object_ids), num_returns, timeout, worker) + + return ray.wait(object_ids, num_returns, timeout, worker) diff --git a/test/runtest.py b/test/runtest.py index 1286c7ad7..a0ff07d8f 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -758,6 +758,27 @@ class APITest(unittest.TestCase): results = ray.get([object_ids[i] for i in indices]) self.assertEqual(results, indices) + def testGetMultipleExperimental(self): + self.init_ray() + object_ids = [ray.put(i) for i in range(10)] + + object_ids_tuple = tuple(object_ids) + self.assertEqual( + ray.experimental.get(object_ids_tuple), list(range(10))) + + object_ids_nparray = np.array(object_ids) + self.assertEqual( + ray.experimental.get(object_ids_nparray), list(range(10))) + + def testGetDict(self): + self.init_ray() + d = {str(i): ray.put(i) for i in range(5)} + for i in range(5, 10): + d[str(i)] = i + result = ray.experimental.get(d) + expected = {str(i): i for i in range(10)} + self.assertEqual(result, expected) + @unittest.skipIf( os.environ.get("RAY_USE_XRAY") == "1", "This test does not work with xray yet.") @@ -826,6 +847,32 @@ class APITest(unittest.TestCase): with self.assertRaises(TypeError): ray.wait([1]) + @unittest.skipIf( + os.environ.get("RAY_USE_XRAY") == "1", + "This test does not work with xray yet.") + def testWaitIterables(self): + self.init_ray(num_cpus=1) + + @ray.remote + def f(delay): + time.sleep(delay) + return 1 + + objectids = (f.remote(1.0), f.remote(0.5), f.remote(0.5), + f.remote(0.5)) + ready_ids, remaining_ids = ray.experimental.wait(objectids) + self.assertEqual(len(ready_ids), 1) + self.assertEqual(len(remaining_ids), 3) + + objectids = np.array( + [f.remote(1.0), + f.remote(0.5), + f.remote(0.5), + f.remote(0.5)]) + ready_ids, remaining_ids = ray.experimental.wait(objectids) + self.assertEqual(len(ready_ids), 1) + self.assertEqual(len(remaining_ids), 3) + @unittest.skipIf( os.environ.get("RAY_USE_XRAY") == "1", "This test does not work with xray yet.")