diff --git a/python/plasma/plasma.py b/python/plasma/plasma.py index e5e62d602..37b53fa0d 100644 --- a/python/plasma/plasma.py +++ b/python/plasma/plasma.py @@ -260,6 +260,8 @@ class PlasmaClient(object): def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1): """Wait until num_returns objects in object_ids are ready. + Currently, the object ID arguments to wait must be unique. + Args: object_ids (List[str]): List of object IDs to wait for. timeout (int): Return to the caller after timeout milliseconds. @@ -269,6 +271,10 @@ class PlasmaClient(object): ready_ids, waiting_ids (List[str], List[str]): List of object IDs that are ready and list of object IDs we might still wait on respectively. """ + # Check that the object ID arguments are unique. The plasma manager + # currently crashes if given duplicate object IDs. + if len(object_ids) != len(set(object_ids)): + raise Exception("Wait requires a list of unique object IDs.") ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, num_returns) return ready_ids, list(waiting_ids) diff --git a/python/ray/worker.py b/python/ray/worker.py index 8754f523d..8930e98ce 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1421,7 +1421,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): Args: object_ids (List[ObjectID]): List of object IDs for objects that may - or may not be ready. + or may not be ready. Note that these IDs must be unique. num_returns (int): The number of object IDs that should be returned. timeout (int): The maximum amount of time in milliseconds to wait before returning. diff --git a/test/runtest.py b/test/runtest.py index e357356bc..24aeff7a1 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -380,6 +380,10 @@ class APITest(unittest.TestCase): self.assertEqual(len(ready_ids), 1) self.assertEqual(len(remaining_ids), 3) + # Verify that calling wait with duplicate object IDs throws an exception. + x = ray.put(1) + self.assertRaises(Exception, lambda : ray.wait([x, x])) + ray.worker.cleanup() def testMultipleWaitsAndGets(self):