Change get to take a timeout and multiple object IDs. (#212)

* Change plasma_get to take a timeout and an array of object IDs.

* Address comments.

* Bug fix related to computing object hashes.

* Add test.

* Fix file descriptor leak.

* Fix valgrind.

* Formatting.

* Remove call to plasma_contains from the plasma client. Use timeout internally in ray.get.

* small fixes
This commit is contained in:
Robert Nishihara
2017-01-19 12:21:12 -08:00
committed by Philipp Moritz
parent 4f6100b67f
commit b98a63fd3a
16 changed files with 715 additions and 1016 deletions
+26 -8
View File
@@ -142,29 +142,47 @@ class PlasmaClient(object):
buff = libplasma.create(self.conn, object_id, size, metadata)
return PlasmaBuffer(buff, object_id, self)
def get(self, object_ids, timeout_ms=-1):
    """Fetch the data buffers for a list of object IDs from the PlasmaStore.

    If an object has not been sealed yet, this call will block for it, up to
    the given timeout. The retrieved buffers are immutable.

    Args:
      object_ids (List[str]): A list of strings used to identify some objects.
      timeout_ms (int): The number of milliseconds that the get call should
        block before timing out and returning.
        NOTE(review): the default of -1 presumably means "no timeout" --
        confirm against libplasma.

    Returns:
      A list with one entry per requested ID, in the same order as
      object_ids. An entry is a PlasmaBuffer wrapping the object's data, or
      None when libplasma returned no result for that ID (e.g. the call timed
      out before the object became available).
    """
    results = libplasma.get(self.conn, object_ids, timeout_ms)
    # libplasma is expected to answer for every requested ID, in order.
    assert len(object_ids) == len(results)
    # Each non-None result is indexable; element 0 is the data buffer.
    return [
        None if result is None else PlasmaBuffer(result[0], object_id, self)
        for object_id, result in zip(object_ids, results)
    ]
def get_metadata(self, object_ids, timeout_ms=-1):
    """Fetch the metadata buffers for a list of object IDs from the PlasmaStore.

    If an object has not been sealed yet, this call will block until the
    object has been sealed, up to the given timeout. The retrieved buffers are
    immutable.

    Args:
      object_ids (List[str]): A list of strings used to identify some objects.
      timeout_ms (int): The number of milliseconds that the get call should
        block before timing out and returning.
        NOTE(review): the default of -1 presumably means "no timeout" --
        confirm against libplasma.

    Returns:
      A list with one entry per requested ID, in the same order as
      object_ids. An entry is a PlasmaBuffer wrapping the object's metadata,
      or None when libplasma returned no result for that ID (e.g. the call
      timed out before the object became available).
    """
    results = libplasma.get(self.conn, object_ids, timeout_ms)
    # libplasma is expected to answer for every requested ID, in order.
    assert len(object_ids) == len(results)
    # Each non-None result is indexable; element 1 is the metadata buffer.
    return [
        None if result is None else PlasmaBuffer(result[1], object_id, self)
        for object_id, result in zip(object_ids, results)
    ]
def contains(self, object_id):
"""Check if the object is present and has been sealed in the PlasmaStore.
+37 -8
View File
@@ -22,10 +22,10 @@ USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None):
client1_buff = client1.get(object_id)
client2_buff = client2.get(object_id)
client1_metadata = client1.get_metadata(object_id)
client2_metadata = client2.get_metadata(object_id)
client1_buff = client1.get([object_id])[0]
client2_buff = client2.get([object_id])[0]
client1_metadata = client1.get_metadata([object_id])[0]
client2_metadata = client2.get_metadata([object_id])[0]
unit_test.assertEqual(len(client1_buff), len(client2_buff))
unit_test.assertEqual(len(client1_metadata), len(client2_metadata))
# Check that the buffers from the two clients are the same.
@@ -72,7 +72,7 @@ class TestPlasmaClient(unittest.TestCase):
# Seal the object.
self.plasma_client.seal(object_id)
# Get the object.
memory_buffer = self.plasma_client.get(object_id)
memory_buffer = self.plasma_client.get([object_id])[0]
for i in range(length):
self.assertEqual(memory_buffer[i], chr(i % 256))
@@ -89,11 +89,11 @@ class TestPlasmaClient(unittest.TestCase):
# Seal the object.
self.plasma_client.seal(object_id)
# Get the object.
memory_buffer = self.plasma_client.get(object_id)
memory_buffer = self.plasma_client.get([object_id])[0]
for i in range(length):
self.assertEqual(memory_buffer[i], chr(i % 256))
# Get the metadata.
metadata_buffer = self.plasma_client.get_metadata(object_id)
metadata_buffer = self.plasma_client.get_metadata([object_id])[0]
self.assertEqual(len(metadata), len(metadata_buffer))
for i in range(len(metadata)):
self.assertEqual(chr(metadata[i]), metadata_buffer[i])
@@ -112,6 +112,35 @@ class TestPlasmaClient(unittest.TestCase):
else:
self.assertTrue(False)
def test_get(self):
    """Check that get and get_metadata time out correctly.

    First requests only IDs that were never created and expects None for all
    of them, then creates every other object of the last batch and expects
    buffers for those and None for the rest.
    """
    num_object_ids = 100
    # Test timing out of get with various timeouts.
    for timeout in [0, 10, 100, 1000]:
        object_ids = [random_object_id() for _ in range(num_object_ids)]
        results = self.plasma_client.get(object_ids, timeout_ms=timeout)
        self.assertEqual(results, num_object_ids * [None])

    # Create objects for the even-indexed IDs of the last batch; the
    # odd-indexed IDs stay absent from the store.
    data_buffers = []
    metadata_buffers = []
    for i in range(num_object_ids):
        if i % 2 == 0:
            data_buffer, metadata_buffer = create_object_with_id(
                self.plasma_client, object_ids[i], 2000, 2000)
            data_buffers.append(data_buffer)
            metadata_buffers.append(metadata_buffer)

    # Test timing out from some but not all get calls with various timeouts.
    for timeout in [0, 10, 100, 1000]:
        data_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
        metadata_results = self.plasma_client.get_metadata(
            object_ids, timeout_ms=timeout)
        for i in range(num_object_ids):
            if i % 2 == 0:
                self.assertTrue(
                    plasma.buffers_equal(data_buffers[i // 2], data_results[i]))
                # TODO(rkn): We should compare the metadata as well. But
                # currently the types are different (e.g., memoryview versus
                # bytearray).
                # self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], metadata_results[i]))
            else:
                # Check the results of this round, not the stale list left
                # over from the first loop.
                self.assertIsNone(data_results[i])
                self.assertIsNone(metadata_results[i])
def test_store_full(self):
# The store is started with 1GB, so make sure that create throws an
# exception when it is full.
@@ -336,7 +365,7 @@ class TestPlasmaClient(unittest.TestCase):
# memory_buffer[0] = chr(0)
# self.assertRaises(Exception, illegal_assignment)
# Get the object.
memory_buffer = self.plasma_client.get(object_id)
memory_buffer = self.plasma_client.get([object_id])[0]
# Make sure the object is read only.
def illegal_assignment():
memory_buffer[0] = chr(0)
+29 -20
View File
@@ -428,20 +428,29 @@ class Worker(object):
# Optionally do something with the contained_objectids here.
contained_objectids = []
def get_object(self, object_ids):
    """Get the values in the local object store associated with object_ids.

    Return the values from the local object store for object_ids. This will
    block until all the values for object_ids have been written to the local
    object store.

    Args:
      object_ids (List[object_id.ObjectID]): A list of the object IDs whose
        values should be retrieved.

    Returns:
      A list of the retrieved values, in the same order as object_ids. An
      empty object_ids yields an empty list without contacting the store.
    """
    # Nothing to retrieve. Returning here also keeps the retry loop below from
    # being skipped entirely, which would leave results unassigned.
    if not object_ids:
        return []

    # The plain IDs do not change between retries, so compute them once.
    plain_ids = [object_id.id() for object_id in object_ids]
    self.plasma_client.fetch(plain_ids)

    # We currently pass in a timeout of one second and keep retrying until no
    # entry comes back as None.
    results = []
    unready_ids = plain_ids
    while unready_ids:
        results = numbuf.retrieve_list(plain_ids, self.plasma_client.conn, 1000)
        unready_ids = [plain_id for (plain_id, val) in results if val is None]
        # This would be a natural place to issue a command to reconstruct some
        # of the objects.

    # Each result is an (id, value) pair, returned in request order.
    assert len(results) == len(plain_ids)
    for (returned_id, _), requested_id in zip(results, plain_ids):
        assert returned_id == requested_id
    # Unwrap the object from the list (it was wrapped put_object).
    return [val[0] for (_, val) in results]
def submit_task(self, function_id, func_name, args):
"""Submit a remote task to the scheduler.
@@ -1228,17 +1237,17 @@ def flush_log(worker=global_worker):
worker.photon_client.log_event(event_log_key, event_log_value)
worker.events = []
def get(objectid, worker=global_worker):
def get(object_ids, worker=global_worker):
"""Get a remote object or a list of remote objects from the object store.
This method blocks until the object corresponding to objectid is available in
This method blocks until the object corresponding to the object ID is available in
the local object store. If this object is not in the local object store, it
will be shipped from an object store that has it (once the object has been
created). If objectid is a list, then the objects corresponding to each object
created). If object_ids is a list, then the objects corresponding to each object
in the list will be returned.
Args:
objectid: Object ID of the object to get or a list of object IDs to get.
object_ids: Object ID of the object to get or a list of object IDs to get.
Returns:
A Python object or a list of Python objects.
@@ -1249,19 +1258,19 @@ def get(objectid, worker=global_worker):
if worker.mode == PYTHON_MODE:
# In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
return objectid
if isinstance(objectid, list):
values = [worker.get_object(x) for x in objectid]
return object_ids
if isinstance(object_ids, list):
values = worker.get_object(object_ids)
for i, value in enumerate(values):
if isinstance(value, RayTaskError):
raise RayGetError(objectid[i], value)
raise RayGetError(object_ids[i], value)
return values
else:
value = worker.get_object(objectid)
value = worker.get_object([object_ids])[0]
if isinstance(value, RayTaskError):
# If the result is a RayTaskError, then the task that created this object
# failed, and we should propagate the error message here.
raise RayGetError(objectid, value)
raise RayGetError(object_ids, value)
return value
def put(value, worker=global_worker):
@@ -1705,7 +1714,7 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker)
for (i, arg) in enumerate(serialized_args):
if isinstance(arg, photon.ObjectID):
# get the object from the local object store
argument = worker.get_object(arg)
argument = worker.get_object([arg])[0]
if isinstance(argument, RayTaskError):
# If the result is a RayTaskError, then the task that created this
# object failed, and we should propagate the error message here.