Allow manually writing to return ObjectIDs from tasks/actor methods (#3805)

This commit is contained in:
Peter Schafhalter
2019-03-19 03:24:57 +01:00
committed by Robert Nishihara
parent 7c3274e65b
commit c93eb126ec
3 changed files with 125 additions and 2 deletions
+16
View File
@@ -0,0 +1,16 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class NoReturn(object):
"""Do not store the return value in the object store.
If a task returns this object, then Ray will not store this object in the
object store. Calling `ray.get` on the task's return ObjectIDs may block
indefinitely unless the task manually stores an object for the
corresponding ObjectID.
"""
def __init__(self):
raise TypeError("The `NoReturn` object should not be instantiated")
+95
View File
@@ -0,0 +1,95 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pytest
import ray
import ray.exceptions
import ray.experimental.no_return
import ray.worker
@pytest.fixture
def ray_start():
# Start the Ray processes.
ray.init(num_cpus=1)
yield None
# The code after the yield will run as teardown code.
ray.shutdown()
def test_set_single_output(ray_start):
@ray.remote
def f():
return_object_ids = ray.worker.global_worker._current_task.returns()
ray.worker.global_worker.put_object(return_object_ids[0], 123)
return ray.experimental.no_return.NoReturn
assert ray.get(f.remote()) == 123
def test_set_multiple_outputs(ray_start):
@ray.remote(num_return_vals=3)
def f(set_out0, set_out1, set_out2):
returns = []
return_object_ids = ray.worker.global_worker._current_task.returns()
for i, set_out in enumerate([set_out0, set_out1, set_out2]):
if set_out:
ray.worker.global_worker.put_object(return_object_ids[i], True)
returns.append(ray.experimental.no_return.NoReturn)
else:
returns.append(False)
return tuple(returns)
for set_out0 in [True, False]:
for set_out1 in [True, False]:
for set_out2 in [True, False]:
result_object_ids = f.remote(set_out0, set_out1, set_out2)
assert ray.get(result_object_ids) == [
set_out0, set_out1, set_out2
]
def test_set_actor_method(ray_start):
@ray.remote
class Actor(object):
def __init__(self):
pass
def ping(self):
return_object_ids = ray.worker.global_worker._current_task.returns(
)
ray.worker.global_worker.put_object(return_object_ids[0], 123)
return ray.experimental.no_return.NoReturn
actor = Actor.remote()
assert ray.get(actor.ping.remote()) == 123
def test_exception(ray_start):
@ray.remote(num_return_vals=2)
def f():
return_object_ids = ray.worker.global_worker._current_task.returns()
# The first return value is successfully stored in the object store
ray.worker.global_worker.put_object(return_object_ids[0], 123)
raise Exception("Error")
# The exception is stored at the second return objcet ID.
return ray.experimental.no_return.NoReturn, 456
object_id, exception_id = f.remote()
assert ray.get(object_id) == 123
with pytest.raises(ray.exceptions.RayTaskError):
ray.get(exception_id)
def test_no_set_and_no_return(ray_start):
@ray.remote
def f():
return ray.experimental.no_return.NoReturn
object_id = f.remote()
with pytest.raises(ray.exceptions.RayTaskError) as e:
ray.get(object_id)
assert "Attempting to return 'ray.experimental.NoReturn'" in str(e.value)
+14 -2
View File
@@ -25,6 +25,7 @@ import pyarrow
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.experimental.signal as ray_signal
import ray.experimental.no_return
import ray.experimental.state as state
import ray.gcs_utils
import ray.memory_monitor as memory_monitor
@@ -164,6 +165,7 @@ class Worker(object):
# Index of the current session. This number will
# increment every time when `ray.shutdown` is called.
self._session_index = 0
self._current_task = None
@property
def task_context(self):
@@ -790,8 +792,15 @@ class Worker(object):
if isinstance(outputs[i], ray.actor.ActorHandle):
raise Exception("Returning an actor handle from a remote "
"function is not allowed).")
self.put_object(object_ids[i], outputs[i])
if outputs[i] is ray.experimental.no_return.NoReturn:
if not self.plasma_client.contains(
pyarrow.plasma.ObjectID(object_ids[i].binary())):
raise RuntimeError(
"Attempting to return 'ray.experimental.NoReturn' "
"from a remote function, but the corresponding "
"ObjectID does not exist in the local object store.")
else:
self.put_object(object_ids[i], outputs[i])
def _process_task(self, task, function_execution_info):
"""Execute a task assigned to this worker.
@@ -847,6 +856,7 @@ class Worker(object):
# Execute the task.
try:
self._current_task = task
with profiling.profile("task:execute"):
if (task.actor_id().is_nil()
and task.actor_creation_id().is_nil()):
@@ -867,6 +877,8 @@ class Worker(object):
self._handle_process_task_failure(
function_descriptor, return_object_ids, e, traceback_str)
return
finally:
self._current_task = None
# Store the outputs in the local object store.
try: