diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index f2a0d81ca..25ec072d3 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -2,6 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import logging + from collections import defaultdict import ray @@ -13,6 +15,8 @@ import ray.cloudpickle as cloudpickle # in node_manager.cc ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL" +logger = logging.getLogger(__name__) + class Signal(object): """Base class for Ray signals.""" @@ -125,10 +129,16 @@ def receive(sources, timeout=None): for s in sources: task_id_to_sources[_get_task_id(s).hex()].append(s) + if timeout < 1e-3: + logger.warning("Timeout too small. Using 1ms minimum") + timeout = 1e-3 + + timeout_ms = int(1000 * timeout) + # Construct the redis query. query = "XREAD BLOCK " - # Multiply by 1000x since timeout is in sec and redis expects ms. - query += str(1000 * timeout) + # redis expects ms. + query += str(timeout_ms) query += " STREAMS " query += " ".join([task_id for task_id in task_id_to_sources]) query += " " diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index fe2e74379..176fbd45b 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -353,3 +353,36 @@ def test_serial_tasks_reading_same_signal(ray_start_regular): assert len(result_list) == 1 result_list = ray.get(f.remote([a])) assert len(result_list) == 1 + + +def test_non_integral_receive_timeout(ray_start_regular): + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=0.1) + + assert len(result_list) == 1 + + +def test_small_receive_timeout(ray_start_regular): + """ Test that receive handles timeout smaller than the 1ms min + """ + # 0.1 ms + small_timeout = 1e-4 + + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=small_timeout) + + assert len(result_list) == 1