From 3b7788bf880d58b0fcd42daba1515f5fb05dfe09 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sat, 11 Mar 2017 12:09:28 -0800 Subject: [PATCH] Disallow calling ray.put on an object ID. (#353) --- python/ray/worker.py | 16 +++++++++++++--- test/runtest.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/ray/worker.py b/python/ray/worker.py index 379eb3510..82c55f4b8 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -441,6 +441,14 @@ class Worker(object): objectid (object_id.ObjectID): The object ID of the value to be put. value: The value to put in the object store. """ + # Make sure that the value is not an object ID. + if isinstance(value, ray.local_scheduler.ObjectID): + raise Exception("Calling `put` on an ObjectID is not allowed (similarly, " + "returning an ObjectID from a remote function is not " + "allowed). If you really want to do this, you can wrap " + "the ObjectID in a list and call `put` on it (or return " + "it).") + # Serialize and put the object in the object store. try: ray.numbuf.store_list(objectid.id(), self.plasma_client.conn, [value]) @@ -465,6 +473,11 @@ class Worker(object): object_ids (List[object_id.ObjectID]): A list of the object IDs whose values should be retrieved. """ + # Make sure that the values are object IDs. + for object_id in object_ids: + if not isinstance(object_id, ray.local_scheduler.ObjectID): + raise Exception("Attempting to call `get` on the value {}, which is " + "not an ObjectID.".format(object_id)) # Do an initial fetch for remote objects. self.plasma_client.fetch([object_id.id() for object_id in object_ids]) @@ -1980,8 +1993,5 @@ def store_outputs_in_objstore(objectids, outputs, worker=global_worker): wrapped in a tuple with one element prior to being passed into this function. """ - for i in range(len(objectids)): - if isinstance(outputs[i], ray.local_scheduler.ObjectID): - raise Exception("This remote function returned an ObjectID as its {}th return value. This is not allowed.".format(i)) for i in range(len(objectids)): worker.put_object(objectids[i], outputs[i]) diff --git a/test/runtest.py b/test/runtest.py index aeb5a0180..413558791 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -655,6 +655,19 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testIllegalAPICalls(self): + ray.init(num_workers=0) + + # Verify that we cannot call put on an ObjectID. + x = ray.put(1) + with self.assertRaises(Exception): + ray.put(x) + # Verify that we cannot call get on a regular value. + with self.assertRaises(Exception): + ray.get(3) + + ray.worker.cleanup() + class PythonModeTest(unittest.TestCase): def testPythonMode(self):