mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
Change get to take a timeout and multiple object IDs. (#212)
* Change plasma_get to take a timeout and an array of object IDs. * Address comments. * Bug fix related to computing object hashes. * Add test. * Fix file descriptor leak. * Fix valgrind. * Formatting. * Remove call to plasma_contains from the plasma client. Use timeout internally in ray.get. * small fixes
This commit is contained in:
committed by
Philipp Moritz
parent
4f6100b67f
commit
b98a63fd3a
+26
-8
@@ -142,29 +142,47 @@ class PlasmaClient(object):
|
||||
buff = libplasma.create(self.conn, object_id, size, metadata)
|
||||
return PlasmaBuffer(buff, object_id, self)
|
||||
|
||||
def get(self, object_id):
|
||||
def get(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
|
||||
If the object has not been sealed yet, this call will block. The retrieved
|
||||
buffer is immutable.
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
object_ids (List[str]): A list of strings used to identify some objects.
|
||||
timeout_ms (int): The number of milliseconds that the get call should
|
||||
block before timing out and returning.
|
||||
"""
|
||||
buff = libplasma.get(self.conn, object_id)[0]
|
||||
return PlasmaBuffer(buff, object_id, self)
|
||||
results = libplasma.get(self.conn, object_ids, timeout_ms)
|
||||
assert len(object_ids) == len(results)
|
||||
returns = []
|
||||
for i in range(len(object_ids)):
|
||||
if results[i] is None:
|
||||
returns.append(None)
|
||||
else:
|
||||
returns.append(PlasmaBuffer(results[i][0], object_ids[i], self))
|
||||
return returns
|
||||
|
||||
def get_metadata(self, object_id):
|
||||
def get_metadata(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
|
||||
If the object has not been sealed yet, this call will block until the object
|
||||
has been sealed. The retrieved buffer is immutable.
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
object_ids (List[str]): A list of strings used to identify some objects.
|
||||
timeout_ms (int): The number of milliseconds that the get call should
|
||||
block before timing out and returning.
|
||||
"""
|
||||
buff = libplasma.get(self.conn, object_id)[1]
|
||||
return PlasmaBuffer(buff, object_id, self)
|
||||
results = libplasma.get(self.conn, object_ids, timeout_ms)
|
||||
assert len(object_ids) == len(results)
|
||||
returns = []
|
||||
for i in range(len(object_ids)):
|
||||
if results[i] is None:
|
||||
returns.append(None)
|
||||
else:
|
||||
returns.append(PlasmaBuffer(results[i][1], object_ids[i], self))
|
||||
return returns
|
||||
|
||||
def contains(self, object_id):
|
||||
"""Check if the object is present and has been sealed in the PlasmaStore.
|
||||
|
||||
@@ -22,10 +22,10 @@ USE_VALGRIND = False
|
||||
PLASMA_STORE_MEMORY = 1000000000
|
||||
|
||||
def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None):
|
||||
client1_buff = client1.get(object_id)
|
||||
client2_buff = client2.get(object_id)
|
||||
client1_metadata = client1.get_metadata(object_id)
|
||||
client2_metadata = client2.get_metadata(object_id)
|
||||
client1_buff = client1.get([object_id])[0]
|
||||
client2_buff = client2.get([object_id])[0]
|
||||
client1_metadata = client1.get_metadata([object_id])[0]
|
||||
client2_metadata = client2.get_metadata([object_id])[0]
|
||||
unit_test.assertEqual(len(client1_buff), len(client2_buff))
|
||||
unit_test.assertEqual(len(client1_metadata), len(client2_metadata))
|
||||
# Check that the buffers from the two clients are the same.
|
||||
@@ -72,7 +72,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
# Seal the object.
|
||||
self.plasma_client.seal(object_id)
|
||||
# Get the object.
|
||||
memory_buffer = self.plasma_client.get(object_id)
|
||||
memory_buffer = self.plasma_client.get([object_id])[0]
|
||||
for i in range(length):
|
||||
self.assertEqual(memory_buffer[i], chr(i % 256))
|
||||
|
||||
@@ -89,11 +89,11 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
# Seal the object.
|
||||
self.plasma_client.seal(object_id)
|
||||
# Get the object.
|
||||
memory_buffer = self.plasma_client.get(object_id)
|
||||
memory_buffer = self.plasma_client.get([object_id])[0]
|
||||
for i in range(length):
|
||||
self.assertEqual(memory_buffer[i], chr(i % 256))
|
||||
# Get the metadata.
|
||||
metadata_buffer = self.plasma_client.get_metadata(object_id)
|
||||
metadata_buffer = self.plasma_client.get_metadata([object_id])[0]
|
||||
self.assertEqual(len(metadata), len(metadata_buffer))
|
||||
for i in range(len(metadata)):
|
||||
self.assertEqual(chr(metadata[i]), metadata_buffer[i])
|
||||
@@ -112,6 +112,35 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
else:
|
||||
self.assertTrue(False)
|
||||
|
||||
def test_get(self):
|
||||
num_object_ids = 100
|
||||
# Test timing out of get with various timeouts.
|
||||
for timeout in [0, 10, 100, 1000]:
|
||||
object_ids = [random_object_id() for _ in range(num_object_ids)]
|
||||
results = self.plasma_client.get(object_ids, timeout_ms=timeout)
|
||||
self.assertEqual(results, num_object_ids * [None])
|
||||
|
||||
data_buffers = []
|
||||
metadata_buffers = []
|
||||
for i in range(num_object_ids):
|
||||
if i % 2 == 0:
|
||||
data_buffer, metadata_buffer = create_object_with_id(self.plasma_client, object_ids[i], 2000, 2000)
|
||||
data_buffers.append(data_buffer)
|
||||
metadata_buffers.append(metadata_buffer)
|
||||
|
||||
# Test timing out from some but not all get calls with various timeouts.
|
||||
for timeout in [0, 10, 100, 1000]:
|
||||
data_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
|
||||
metadata_results = self.plasma_client.get(object_ids, timeout_ms=timeout)
|
||||
for i in range(num_object_ids):
|
||||
if i % 2 == 0:
|
||||
self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], data_results[i]))
|
||||
# TODO(rkn): We should compare the metadata as well. But currently the
|
||||
# types are different (e.g., memoryview versus bytearray).
|
||||
# self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], metadata_results[i]))
|
||||
else:
|
||||
self.assertIsNone(results[i])
|
||||
|
||||
def test_store_full(self):
|
||||
# The store is started with 1GB, so make sure that create throws an
|
||||
# exception when it is full.
|
||||
@@ -336,7 +365,7 @@ class TestPlasmaClient(unittest.TestCase):
|
||||
# memory_buffer[0] = chr(0)
|
||||
# self.assertRaises(Exception, illegal_assignment)
|
||||
# Get the object.
|
||||
memory_buffer = self.plasma_client.get(object_id)
|
||||
memory_buffer = self.plasma_client.get([object_id])[0]
|
||||
# Make sure the object is read only.
|
||||
def illegal_assignment():
|
||||
memory_buffer[0] = chr(0)
|
||||
|
||||
+29
-20
@@ -428,20 +428,29 @@ class Worker(object):
|
||||
# Optionally do something with the contained_objectids here.
|
||||
contained_objectids = []
|
||||
|
||||
def get_object(self, objectid):
|
||||
"""Get the value in the local object store associated with objectid.
|
||||
def get_object(self, object_ids):
|
||||
"""Get the value or values in the local object store associated with object_ids.
|
||||
|
||||
Return the value from the local object store for objectid. This will block
|
||||
until the value for objectid has been written to the local object store.
|
||||
Return the values from the local object store for object_ids. This will block
|
||||
until all the values for object_ids have been written to the local object store.
|
||||
|
||||
Args:
|
||||
objectid (object_id.ObjectID): The object ID of the value to retrieve.
|
||||
object_ids (List[object_id.ObjectID]): A list of the object IDs whose
|
||||
values should be retrieved.
|
||||
"""
|
||||
self.plasma_client.fetch([objectid.id()])
|
||||
deserialized = numbuf.retrieve_list(objectid.id(), self.plasma_client.conn)
|
||||
self.plasma_client.fetch([object_id.id() for object_id in object_ids])
|
||||
# We currently pass in a timeout of one second.
|
||||
unready_ids = object_ids
|
||||
while len(unready_ids) > 0:
|
||||
results = numbuf.retrieve_list([object_id.id() for object_id in object_ids], self.plasma_client.conn, 1000)
|
||||
unready_ids = [object_id for (object_id, val) in results if val is None]
|
||||
# This would be a natural place to issue a command to reconstruct some of
|
||||
# the objects.
|
||||
# Unwrap the object from the list (it was wrapped put_object).
|
||||
assert len(deserialized) == 1
|
||||
return deserialized[0]
|
||||
assert len(results) == len(object_ids)
|
||||
for i in range(len(results)):
|
||||
assert results[i][0] == object_ids[i].id()
|
||||
return [result[1][0] for result in results]
|
||||
|
||||
def submit_task(self, function_id, func_name, args):
|
||||
"""Submit a remote task to the scheduler.
|
||||
@@ -1228,17 +1237,17 @@ def flush_log(worker=global_worker):
|
||||
worker.photon_client.log_event(event_log_key, event_log_value)
|
||||
worker.events = []
|
||||
|
||||
def get(objectid, worker=global_worker):
|
||||
def get(object_ids, worker=global_worker):
|
||||
"""Get a remote object or a list of remote objects from the object store.
|
||||
|
||||
This method blocks until the object corresponding to objectid is available in
|
||||
This method blocks until the object corresponding to the object ID is available in
|
||||
the local object store. If this object is not in the local object store, it
|
||||
will be shipped from an object store that has it (once the object has been
|
||||
created). If objectid is a list, then the objects corresponding to each object
|
||||
created). If object_ids is a list, then the objects corresponding to each object
|
||||
in the list will be returned.
|
||||
|
||||
Args:
|
||||
objectid: Object ID of the object to get or a list of object IDs to get.
|
||||
object_ids: Object ID of the object to get or a list of object IDs to get.
|
||||
|
||||
Returns:
|
||||
A Python object or a list of Python objects.
|
||||
@@ -1249,19 +1258,19 @@ def get(objectid, worker=global_worker):
|
||||
|
||||
if worker.mode == PYTHON_MODE:
|
||||
# In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
|
||||
return objectid
|
||||
if isinstance(objectid, list):
|
||||
values = [worker.get_object(x) for x in objectid]
|
||||
return object_ids
|
||||
if isinstance(object_ids, list):
|
||||
values = worker.get_object(object_ids)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayTaskError):
|
||||
raise RayGetError(objectid[i], value)
|
||||
raise RayGetError(object_ids[i], value)
|
||||
return values
|
||||
else:
|
||||
value = worker.get_object(objectid)
|
||||
value = worker.get_object([object_ids])[0]
|
||||
if isinstance(value, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created this object
|
||||
# failed, and we should propagate the error message here.
|
||||
raise RayGetError(objectid, value)
|
||||
raise RayGetError(object_ids, value)
|
||||
return value
|
||||
|
||||
def put(value, worker=global_worker):
|
||||
@@ -1705,7 +1714,7 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker)
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, photon.ObjectID):
|
||||
# get the object from the local object store
|
||||
argument = worker.get_object(arg)
|
||||
argument = worker.get_object([arg])[0]
|
||||
if isinstance(argument, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created this
|
||||
# object failed, and we should propagate the error message here.
|
||||
|
||||
Reference in New Issue
Block a user