diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index a893c75ac..d1f0e2f48 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -170,4 +170,5 @@ def reset(): If the worker calls receive() on a source next, it will get all the signals generated by that source starting with index = 1. """ - ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0") + if hasattr(ray.worker.global_worker, "signal_counters"): + ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0") diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index d7bef391a..9885cdde7 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -316,3 +316,22 @@ def test_receiving_on_two_returns(ray_start): assert ((x == results[0][0] and y == results[1][0]) or (x == results[1][0] and y == results[0][0])) + + +def test_serial_tasks_reading_same_signal(ray_start): + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + + @ray.remote + def f(sources): + return ray.experimental.signal.receive(sources, timeout=1) + + result_list = ray.get(f.remote([a])) + assert len(result_list) == 1 + result_list = ray.get(f.remote([a])) + assert len(result_list) == 1 + result_list = ray.get(f.remote([a])) + assert len(result_list) == 1 diff --git a/python/ray/worker.py b/python/ray/worker.py index 427356043..b72cd13dd 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -965,6 +965,9 @@ class Worker(object): # actor. Because the following tasks should all have the # same driver id. self.task_driver_id = DriverID.nil() + # Reset signal counters so that the next task can get + # all past signals. + ray_signal.reset() # Increase the task execution counter. self.function_actor_manager.increase_task_counter(