diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index f800fc502..5009f4bf3 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -4,21 +4,24 @@ import io import json import logging import os +import pickle import re import string import sys +import tempfile import threading import time import pickle +import uuid import weakref import numpy as np import pytest import ray -from ray.exceptions import RayTimeoutError import ray.cluster_utils import ray.test_utils +from ray.exceptions import RayTimeoutError logger = logging.getLogger(__name__) @@ -1265,17 +1268,38 @@ def test_get_dict(ray_start_regular): def test_get_with_timeout(ray_start_regular): + def random_path(): + return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex) + + def touch(path): + with open(path, "w"): + pass + @ray.remote - def f(a): - time.sleep(a) - return a + def wait_for_file(path): + if path: + while True: + if os.path.exists(path): + break + time.sleep(0.1) - assert ray.get(f.remote(3), timeout=10) == 3 + # Check that get() returns early if object is ready. + start = time.time() + ray.get(wait_for_file.remote(None), timeout=30) + assert time.time() - start < 30 - obj_id = f.remote(3) + # Check that get() raises a TimeoutError after the timeout if the object + # is not ready yet. + path = random_path() + result_id = wait_for_file.remote(path) with pytest.raises(RayTimeoutError): - ray.get(obj_id, timeout=2) - assert ray.get(obj_id, timeout=2) == 3 + ray.get(result_id, timeout=0.1) + + # Check that a subsequent get() returns early. + touch(path) + start = time.time() + ray.get(result_id, timeout=30) + assert time.time() - start < 30 @pytest.mark.parametrize(