mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 11:27:06 +08:00
Add timeout param to ray.get (#6107)
This commit is contained in:
committed by
Richard Liaw
parent
e4c0843f60
commit
e3e3ad4b25
@@ -84,7 +84,8 @@ from ray.exceptions import (
|
||||
RayError,
|
||||
RayletError,
|
||||
RayTaskError,
|
||||
ObjectStoreFullError
|
||||
ObjectStoreFullError,
|
||||
RayTimeoutError,
|
||||
)
|
||||
from ray.experimental.no_return import NoReturn
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
@@ -138,6 +139,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1:
|
||||
raise ObjectStoreFullError(message)
|
||||
elif status.IsInterrupted():
|
||||
raise KeyboardInterrupt()
|
||||
elif status.IsTimedOut():
|
||||
raise RayTimeoutError(message)
|
||||
else:
|
||||
raise RayletError(message)
|
||||
|
||||
|
||||
@@ -161,6 +161,11 @@ class UnreconstructableError(RayError):
|
||||
"https://ray.readthedocs.io/en/latest/memory-management.html"))
|
||||
|
||||
|
||||
class RayTimeoutError(RayError):
|
||||
"""Indicates that a call to the worker timed out."""
|
||||
pass
|
||||
|
||||
|
||||
RAY_EXCEPTION_TYPES = [
|
||||
RayError,
|
||||
RayTaskError,
|
||||
@@ -168,4 +173,5 @@ RAY_EXCEPTION_TYPES = [
|
||||
RayActorError,
|
||||
ObjectStoreFullError,
|
||||
UnreconstructableError,
|
||||
RayTimeoutError,
|
||||
]
|
||||
|
||||
@@ -73,6 +73,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil:
|
||||
@staticmethod
|
||||
CRayStatus RedisError(const c_string &msg)
|
||||
|
||||
@staticmethod
|
||||
CRayStatus TimedOut(const c_string &msg)
|
||||
|
||||
@staticmethod
|
||||
CRayStatus Interrupted(const c_string &msg)
|
||||
|
||||
@@ -89,7 +92,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil:
|
||||
c_bool IsNotImplemented()
|
||||
c_bool IsObjectStoreFull()
|
||||
c_bool IsRedisError()
|
||||
c_bool IsTimedOut()
|
||||
c_bool IsInterrupted()
|
||||
c_bool IsSystemExit()
|
||||
|
||||
c_string ToString()
|
||||
c_string CodeAsString()
|
||||
|
||||
@@ -29,6 +29,7 @@ import pytest
|
||||
|
||||
import ray
|
||||
from ray import signature
|
||||
from ray.exceptions import RayTimeoutError
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.tests.cluster_utils
|
||||
import ray.tests.utils
|
||||
@@ -1190,6 +1191,20 @@ def test_get_dict(ray_start_regular):
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_get_with_timeout(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(a):
|
||||
time.sleep(a)
|
||||
return a
|
||||
|
||||
assert ray.get(f.remote(3), timeout=10) == 3
|
||||
|
||||
obj_id = f.remote(3)
|
||||
with pytest.raises(RayTimeoutError):
|
||||
ray.get(obj_id, timeout=2)
|
||||
assert ray.get(obj_id, timeout=2) == 3
|
||||
|
||||
|
||||
def test_direct_call_simple(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(x):
|
||||
|
||||
+16
-4
@@ -286,7 +286,7 @@ class Worker(object):
|
||||
return context.deserialize_objects(data_metadata_pairs, object_ids,
|
||||
error_timeout)
|
||||
|
||||
def get_objects(self, object_ids):
|
||||
def get_objects(self, object_ids, timeout=None):
|
||||
"""Get the values in the object store associated with the IDs.
|
||||
|
||||
Return the values from the local object store for object_ids. This will
|
||||
@@ -296,6 +296,8 @@ class Worker(object):
|
||||
Args:
|
||||
object_ids (List[object_id.ObjectID]): A list of the object IDs
|
||||
whose values should be retrieved.
|
||||
timeout (float): timeout (float): The maximum amount of time in
|
||||
seconds to wait before returning.
|
||||
|
||||
Raises:
|
||||
Exception if running in LOCAL_MODE and any of the object IDs do not
|
||||
@@ -309,10 +311,15 @@ class Worker(object):
|
||||
"which is not an ray.ObjectID.".format(object_id))
|
||||
|
||||
if self.mode == LOCAL_MODE:
|
||||
# TODO(ujvl): Remove check when local mode moved to core worker.
|
||||
if timeout is not None:
|
||||
raise ValueError(
|
||||
"`get` must be called with timeout=None in local mode.")
|
||||
return self.local_mode_manager.get_objects(object_ids)
|
||||
|
||||
timeout_ms = int(timeout * 1000) if timeout else -1
|
||||
data_metadata_pairs = self.core_worker.get_objects(
|
||||
object_ids, self.current_task_id)
|
||||
object_ids, self.current_task_id, timeout_ms)
|
||||
return self.deserialize_objects(data_metadata_pairs, object_ids)
|
||||
|
||||
def run_function_on_all_workers(self, function,
|
||||
@@ -1388,7 +1395,7 @@ def register_custom_serializer(cls,
|
||||
class_id=class_id)
|
||||
|
||||
|
||||
def get(object_ids):
|
||||
def get(object_ids, timeout=None):
|
||||
"""Get a remote object or a list of remote objects from the object store.
|
||||
|
||||
This method blocks until the object corresponding to the object ID is
|
||||
@@ -1400,11 +1407,15 @@ def get(object_ids):
|
||||
Args:
|
||||
object_ids: Object ID of the object to get or a list of object IDs to
|
||||
get.
|
||||
timeout (float): The maximum amount of time in seconds to wait before
|
||||
returning.
|
||||
|
||||
Returns:
|
||||
A Python object or a list of Python objects.
|
||||
|
||||
Raises:
|
||||
RayTimeoutError: A RayTimeoutError is raised if a timeout is set and
|
||||
the get takes longer than timeout to return.
|
||||
Exception: An exception is raised if the task that created the object
|
||||
or that created one of the objects raised an exception.
|
||||
"""
|
||||
@@ -1420,7 +1431,8 @@ def get(object_ids):
|
||||
"or a list of object IDs.")
|
||||
|
||||
global last_task_error_raise_time
|
||||
values = worker.get_objects(object_ids)
|
||||
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
|
||||
values = worker.get_objects(object_ids, timeout=timeout)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayError):
|
||||
last_task_error_raise_time = time.time()
|
||||
|
||||
Reference in New Issue
Block a user