Add timeout param to ray.get (#6107)

This commit is contained in:
Ujval Misra
2019-11-14 00:50:04 -08:00
committed by Richard Liaw
parent e4c0843f60
commit e3e3ad4b25
12 changed files with 100 additions and 20 deletions
+4 -1
View File
@@ -84,7 +84,8 @@ from ray.exceptions import (
RayError,
RayletError,
RayTaskError,
ObjectStoreFullError
ObjectStoreFullError,
RayTimeoutError,
)
from ray.experimental.no_return import NoReturn
from ray.function_manager import FunctionDescriptor
@@ -138,6 +139,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1:
raise ObjectStoreFullError(message)
elif status.IsInterrupted():
raise KeyboardInterrupt()
elif status.IsTimedOut():
raise RayTimeoutError(message)
else:
raise RayletError(message)
+6
View File
@@ -161,6 +161,11 @@ class UnreconstructableError(RayError):
"https://ray.readthedocs.io/en/latest/memory-management.html"))
class RayTimeoutError(RayError):
    """Signals that a call to the worker ran out of time before completing."""
RAY_EXCEPTION_TYPES = [
RayError,
RayTaskError,
@@ -168,4 +173,5 @@ RAY_EXCEPTION_TYPES = [
RayActorError,
ObjectStoreFullError,
UnreconstructableError,
RayTimeoutError,
]
+5
View File
@@ -73,6 +73,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil:
@staticmethod
CRayStatus RedisError(const c_string &msg)
@staticmethod
CRayStatus TimedOut(const c_string &msg)
@staticmethod
CRayStatus Interrupted(const c_string &msg)
@@ -89,7 +92,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil:
c_bool IsNotImplemented()
c_bool IsObjectStoreFull()
c_bool IsRedisError()
c_bool IsTimedOut()
c_bool IsInterrupted()
c_bool IsSystemExit()
c_string ToString()
c_string CodeAsString()
+15
View File
@@ -29,6 +29,7 @@ import pytest
import ray
from ray import signature
from ray.exceptions import RayTimeoutError
import ray.ray_constants as ray_constants
import ray.tests.cluster_utils
import ray.tests.utils
@@ -1190,6 +1191,20 @@ def test_get_dict(ray_start_regular):
assert result == expected
def test_get_with_timeout(ray_start_regular):
    """Exercise the ``timeout`` argument of ``ray.get``.

    Covers three cases against a remote task that sleeps for 3 seconds:
    a timeout longer than the task, a timeout shorter than the task
    (expects ``RayTimeoutError``), and a repeated get on the same object
    ID after the first attempt timed out.
    """

    @ray.remote
    def f(a):
        # Sleep for `a` seconds, then hand the argument back unchanged.
        time.sleep(a)
        return a

    # A 10-second budget comfortably covers the 3-second task.
    finished = ray.get(f.remote(3), timeout=10)
    assert finished == 3

    # A 2-second budget is shorter than the 3-second sleep, so the get
    # is expected to raise instead of returning a value.
    pending = f.remote(3)
    with pytest.raises(RayTimeoutError):
        ray.get(pending, timeout=2)

    # Retrying on the same object ID: the earlier wait plus this budget is
    # longer than the sleep, so the value is expected to come back now.
    late_result = ray.get(pending, timeout=2)
    assert late_result == 3
def test_direct_call_simple(ray_start_regular):
@ray.remote
def f(x):
+16 -4
View File
@@ -286,7 +286,7 @@ class Worker(object):
return context.deserialize_objects(data_metadata_pairs, object_ids,
error_timeout)
def get_objects(self, object_ids):
def get_objects(self, object_ids, timeout=None):
"""Get the values in the object store associated with the IDs.
Return the values from the local object store for object_ids. This will
@@ -296,6 +296,8 @@ class Worker(object):
Args:
object_ids (List[object_id.ObjectID]): A list of the object IDs
whose values should be retrieved.
timeout (float): The maximum amount of time in
seconds to wait before returning.
Raises:
Exception if running in LOCAL_MODE and any of the object IDs do not
@@ -309,10 +311,15 @@ class Worker(object):
"which is not an ray.ObjectID.".format(object_id))
if self.mode == LOCAL_MODE:
# TODO(ujvl): Remove check when local mode moved to core worker.
if timeout is not None:
raise ValueError(
"`get` must be called with timeout=None in local mode.")
return self.local_mode_manager.get_objects(object_ids)
timeout_ms = int(timeout * 1000) if timeout else -1
data_metadata_pairs = self.core_worker.get_objects(
object_ids, self.current_task_id)
object_ids, self.current_task_id, timeout_ms)
return self.deserialize_objects(data_metadata_pairs, object_ids)
def run_function_on_all_workers(self, function,
@@ -1388,7 +1395,7 @@ def register_custom_serializer(cls,
class_id=class_id)
def get(object_ids):
def get(object_ids, timeout=None):
"""Get a remote object or a list of remote objects from the object store.
This method blocks until the object corresponding to the object ID is
@@ -1400,11 +1407,15 @@ def get(object_ids):
Args:
object_ids: Object ID of the object to get or a list of object IDs to
get.
timeout (float): The maximum amount of time in seconds to wait before
returning.
Returns:
A Python object or a list of Python objects.
Raises:
RayTimeoutError: A RayTimeoutError is raised if a timeout is set and
the get takes longer than timeout to return.
Exception: An exception is raised if the task that created the object
or that created one of the objects raised an exception.
"""
@@ -1420,7 +1431,8 @@ def get(object_ids):
"or a list of object IDs.")
global last_task_error_raise_time
values = worker.get_objects(object_ids)
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values = worker.get_objects(object_ids, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()