From e3e3ad4b2525e464062818ee572f7e928dcb8ff3 Mon Sep 17 00:00:00 2001 From: Ujval Misra Date: Thu, 14 Nov 2019 00:50:04 -0800 Subject: [PATCH] Add timeout param to ray.get (#6107) --- doc/source/walkthrough.rst | 22 ++++++++++++++++--- python/ray/_raylet.pyx | 5 ++++- python/ray/exceptions.py | 6 +++++ python/ray/includes/common.pxd | 5 +++++ python/ray/tests/test_basic.py | 15 +++++++++++++ python/ray/worker.py | 20 +++++++++++++---- src/ray/common/status.cc | 3 +++ src/ray/common/status.h | 10 +++++++-- .../memory_store/memory_store.cc | 8 +++++-- .../store_provider/memory_store_provider.cc | 7 ++++-- .../store_provider/plasma_store_provider.cc | 10 +++++++-- src/ray/core_worker/test/core_worker_test.cc | 9 ++++---- 12 files changed, 100 insertions(+), 20 deletions(-) diff --git a/doc/source/walkthrough.rst b/doc/source/walkthrough.rst index 8881b200c..df0e4c2c0 100644 --- a/doc/source/walkthrough.rst +++ b/doc/source/walkthrough.rst @@ -40,7 +40,7 @@ Ray enables arbitrary Python functions to be executed asynchronously. These asyn def remote_function(): return 1 -This causes a few things changes in behavior: +This causes a few changes in behavior: 1. **Invocation:** The regular version is called with ``regular_function()``, whereas the remote version is called with ``remote_function.remote()``. 2. **Return values:** ``regular_function`` immediately executes and returns ``1``, whereas ``remote_function`` immediately returns an object ID (a future) and then creates a task that will be executed on a worker process. The result can be retrieved with ``ray.get``. @@ -145,7 +145,7 @@ Below are more examples of resource specifications: def f(): return 1 -Further, remote function can return multiple object IDs. +Further, remote functions can return multiple object IDs. .. code-block:: python @@ -188,7 +188,7 @@ Object IDs can be created in multiple ways. 
Fetching Results ---------------- -The command ``ray.get(x_id)`` takes an object ID and creates a Python object +The command ``ray.get(x_id, timeout=None)`` takes an object ID and creates a Python object from the corresponding remote object. First, if the current node's object store does not contain the object, the object is downloaded. Then, if the object is a `numpy array `__ or a collection of numpy arrays, the ``get`` call is zero-copy and returns arrays backed by shared object store memory. @@ -200,6 +200,22 @@ Otherwise, we deserialize the object data into a Python object. obj_id = ray.put(y) assert ray.get(obj_id) == 1 +You can also set a timeout to return early from a ``get`` that's blocking for too long. + +.. code-block:: python + + from ray.exceptions import RayTimeoutError + + @ray.remote + def long_running_function(): + time.sleep(8) + + obj_id = long_running_function.remote() + try: + ray.get(obj_id, timeout=4) + except RayTimeoutError: + print("`get` timed out.") + .. autofunction:: ray.get :noindex: diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index b1df36d56..9ea19f0d1 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -84,7 +84,8 @@ from ray.exceptions import ( RayError, RayletError, RayTaskError, - ObjectStoreFullError + ObjectStoreFullError, + RayTimeoutError, ) from ray.experimental.no_return import NoReturn from ray.function_manager import FunctionDescriptor @@ -138,6 +139,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() + elif status.IsTimedOut(): + raise RayTimeoutError(message) else: raise RayletError(message) diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 3e6ea37f1..9045c9ccc 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -161,6 +161,11 @@ class UnreconstructableError(RayError): "https://ray.readthedocs.io/en/latest/memory-management.html")) 
+class RayTimeoutError(RayError): + """Indicates that a call to the worker timed out.""" + pass + + RAY_EXCEPTION_TYPES = [ RayError, RayTaskError, @@ -168,4 +173,5 @@ RAY_EXCEPTION_TYPES = [ RayActorError, ObjectStoreFullError, UnreconstructableError, + RayTimeoutError, ] diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index e16f944c1..afdb8c43a 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -73,6 +73,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: @staticmethod CRayStatus RedisError(const c_string &msg) + @staticmethod + CRayStatus TimedOut(const c_string &msg) + @staticmethod CRayStatus Interrupted(const c_string &msg) @@ -89,7 +92,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: c_bool IsNotImplemented() c_bool IsObjectStoreFull() c_bool IsRedisError() + c_bool IsTimedOut() c_bool IsInterrupted() + c_bool IsSystemExit() c_string ToString() c_string CodeAsString() diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 77f424a8b..029dcc2a5 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -29,6 +29,7 @@ import pytest import ray from ray import signature +from ray.exceptions import RayTimeoutError import ray.ray_constants as ray_constants import ray.tests.cluster_utils import ray.tests.utils @@ -1190,6 +1191,20 @@ def test_get_dict(ray_start_regular): assert result == expected +def test_get_with_timeout(ray_start_regular): + @ray.remote + def f(a): + time.sleep(a) + return a + + assert ray.get(f.remote(3), timeout=10) == 3 + + obj_id = f.remote(3) + with pytest.raises(RayTimeoutError): + ray.get(obj_id, timeout=2) + assert ray.get(obj_id, timeout=2) == 3 + + def test_direct_call_simple(ray_start_regular): @ray.remote def f(x): diff --git a/python/ray/worker.py b/python/ray/worker.py index a4a2216db..3809fb2d2 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -286,7 +286,7 @@ class 
Worker(object): return context.deserialize_objects(data_metadata_pairs, object_ids, error_timeout) - def get_objects(self, object_ids): + def get_objects(self, object_ids, timeout=None): """Get the values in the object store associated with the IDs. Return the values from the local object store for object_ids. This will @@ -296,6 +296,8 @@ class Worker(object): Args: object_ids (List[object_id.ObjectID]): A list of the object IDs whose values should be retrieved. + timeout (float): The maximum amount of time in + seconds to wait before returning. Raises: Exception if running in LOCAL_MODE and any of the object IDs do not @@ -309,10 +311,15 @@ class Worker(object): "which is not an ray.ObjectID.".format(object_id)) if self.mode == LOCAL_MODE: + # TODO(ujvl): Remove check when local mode moved to core worker. + if timeout is not None: + raise ValueError( + "`get` must be called with timeout=None in local mode.") return self.local_mode_manager.get_objects(object_ids) + timeout_ms = int(timeout * 1000) if timeout is not None else -1 data_metadata_pairs = self.core_worker.get_objects( - object_ids, self.current_task_id) + object_ids, self.current_task_id, timeout_ms) return self.deserialize_objects(data_metadata_pairs, object_ids) def run_function_on_all_workers(self, function, @@ -1388,7 +1395,7 @@ def register_custom_serializer(cls, class_id=class_id) -def get(object_ids): +def get(object_ids, timeout=None): """Get a remote object or a list of remote objects from the object store. This method blocks until the object corresponding to the object ID is @@ -1400,11 +1407,15 @@ def get(object_ids): Args: object_ids: Object ID of the object to get or a list of object IDs to get. + timeout (float): The maximum amount of time in seconds to wait before + returning. Returns: A Python object or a list of Python objects. Raises: + RayTimeoutError: A RayTimeoutError is raised if a timeout is set and + the get takes longer than timeout to return. 
Exception: An exception is raised if the task that created the object or that created one of the objects raised an exception. """ @@ -1420,7 +1431,8 @@ def get(object_ids): "or a list of object IDs.") global last_task_error_raise_time - values = worker.get_objects(object_ids) + # TODO(ujvl): Consider how to allow user to retrieve the ready objects. + values = worker.get_objects(object_ids, timeout=timeout) for i, value in enumerate(values): if isinstance(value, RayError): last_task_error_raise_time = time.time() diff --git a/src/ray/common/status.cc b/src/ray/common/status.cc index f7345c35c..fc32c5196 100644 --- a/src/ray/common/status.cc +++ b/src/ray/common/status.cc @@ -74,6 +74,9 @@ std::string Status::CodeAsString() const { case StatusCode::RedisError: type = "RedisError"; break; + case StatusCode::TimedOut: + type = "TimedOut"; + break; case StatusCode::Interrupted: type = "Interrupted"; break; diff --git a/src/ray/common/status.h b/src/ray/common/status.h index e4e970727..1fb329f55 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -79,8 +79,9 @@ enum class StatusCode : char { UnknownError = 9, NotImplemented = 10, RedisError = 11, - Interrupted = 12, - SystemExit = 13, + TimedOut = 12, + Interrupted = 13, + SystemExit = 14, }; #if defined(__clang__) @@ -144,6 +145,10 @@ class RAY_EXPORT Status { return Status(StatusCode::RedisError, msg); } + static Status TimedOut(const std::string &msg) { + return Status(StatusCode::TimedOut, msg); + } + static Status Interrupted(const std::string &msg) { return Status(StatusCode::Interrupted, msg); } @@ -165,6 +170,7 @@ class RAY_EXPORT Status { bool IsUnknownError() const { return code() == StatusCode::UnknownError; } bool IsNotImplemented() const { return code() == StatusCode::NotImplemented; } bool IsRedisError() const { return code() == StatusCode::RedisError; } + bool IsTimedOut() const { return code() == StatusCode::TimedOut; } bool IsInterrupted() const { return code() == 
StatusCode::Interrupted; } bool IsSystemExit() const { return code() == StatusCode::SystemExit; } diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 0084778da..10afc3e58 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -223,7 +223,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } // Wait for remaining objects (or timeout). - get_request->Wait(timeout_ms); + bool done = get_request->Wait(timeout_ms); { absl::MutexLock lock(&mu_); @@ -253,7 +253,11 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } } - return Status::OK(); + if (done) { + return Status::OK(); + } else { + return Status::TimedOut("Get timed out: some object(s) not ready."); + } } void CoreWorkerMemoryStore::Delete(const std::vector &object_ids) { diff --git a/src/ray/core_worker/store_provider/memory_store_provider.cc b/src/ray/core_worker/store_provider/memory_store_provider.cc index 12b6d14bc..883c2949b 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.cc +++ b/src/ray/core_worker/store_provider/memory_store_provider.cc @@ -55,8 +55,11 @@ Status CoreWorkerMemoryStoreProvider::Wait( std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); - RAY_RETURN_NOT_OK( - store_->Get(id_vector, num_objects, timeout_ms, false, &result_objects)); + auto status = store_->Get(id_vector, num_objects, timeout_ms, false, &result_objects); + // Ignore TimedOut statuses since we return ready objects explicitly. 
+ if (!status.IsTimedOut()) { + RAY_RETURN_NOT_OK(status); + } for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 2c6c9f96f..011c37282 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -156,6 +156,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( // objects are all fetched if timeout is -1. int unsuccessful_attempts = 0; bool should_break = false; + bool timed_out = false; int64_t remaining_timeout = timeout_ms; while (!remaining.empty() && !should_break) { batch_ids.clear(); @@ -171,14 +172,14 @@ Status CoreWorkerPlasmaStoreProvider::Get( if (remaining_timeout >= 0) { batch_timeout = std::min(remaining_timeout, batch_timeout); remaining_timeout -= batch_timeout; - should_break = remaining_timeout <= 0; + timed_out = remaining_timeout <= 0; } size_t previous_size = remaining.size(); RAY_RETURN_NOT_OK(FetchAndGetFromPlasmaStore(remaining, batch_ids, batch_timeout, /*fetch_only=*/false, task_id, results, got_exception)); - should_break = should_break || *got_exception; + should_break = timed_out || *got_exception; if ((previous_size - remaining.size()) < batch_ids.size()) { unsuccessful_attempts++; @@ -194,6 +195,11 @@ Status CoreWorkerPlasmaStoreProvider::Get( } } + if (!remaining.empty() && timed_out) { + RAY_RETURN_NOT_OK(raylet_client_->NotifyUnblocked(task_id)); + return Status::TimedOut("Get timed out: some object(s) not ready."); + } + // Notify unblocked because we blocked when calling FetchOrReconstruct with // fetch_only=false. 
return raylet_client_->NotifyUnblocked(task_id); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 22350999a..ec0c95653 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -683,7 +683,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { RAY_CHECK_OK(provider.Delete(ids_set)); usleep(200 * 1000); - RAY_CHECK_OK(provider.Get(ids_set, 0, RandomTaskId(), &results, &got_exception)); + ASSERT_TRUE( + provider.Get(ids_set, 0, RandomTaskId(), &results, &got_exception).IsTimedOut()); ASSERT_TRUE(!got_exception); ASSERT_EQ(results.size(), 0); @@ -811,7 +812,7 @@ TEST_F(SingleNodeTest, TestObjectInterface) { // wait for objects being deleted, so wait a while for plasma store // to process the command. usleep(200 * 1000); - RAY_CHECK_OK(core_worker.Get(ids, 0, &results)); + ASSERT_TRUE(core_worker.Get(ids, 0, &results).IsTimedOut()); ASSERT_EQ(results.size(), 2); ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]); @@ -872,12 +873,12 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { // to process the command. usleep(1000 * 1000); // Verify objects are deleted from both machines. - RAY_CHECK_OK(worker2.Get(ids, 0, &results)); + ASSERT_TRUE(worker2.Get(ids, 0, &results).IsTimedOut()); ASSERT_EQ(results.size(), 2); ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]); - RAY_CHECK_OK(worker1.Get(ids, 0, &results)); + ASSERT_TRUE(worker1.Get(ids, 0, &results).IsTimedOut()); ASSERT_EQ(results.size(), 2); ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]);