mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:12:15 +08:00
Reset signal counters when a task finishes (#4173)
This commit is contained in:
@@ -170,4 +170,5 @@ def reset():
|
||||
If the worker calls receive() on a source next, it will get all the
|
||||
signals generated by that source starting with index = 1.
|
||||
"""
|
||||
ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")
|
||||
if hasattr(ray.worker.global_worker, "signal_counters"):
|
||||
ray.worker.global_worker.signal_counters = defaultdict(lambda: b"0")
|
||||
|
||||
@@ -316,3 +316,22 @@ def test_receiving_on_two_returns(ray_start):
|
||||
|
||||
assert ((x == results[0][0] and y == results[1][0])
|
||||
or (x == results[1][0] and y == results[0][0]))
|
||||
|
||||
|
||||
def test_serial_tasks_reading_same_signal(ray_start):
|
||||
@ray.remote
|
||||
def send_signal(value):
|
||||
signal.send(UserSignal(value))
|
||||
|
||||
a = send_signal.remote(0)
|
||||
|
||||
@ray.remote
|
||||
def f(sources):
|
||||
return ray.experimental.signal.receive(sources, timeout=1)
|
||||
|
||||
result_list = ray.get(f.remote([a]))
|
||||
assert len(result_list) == 1
|
||||
result_list = ray.get(f.remote([a]))
|
||||
assert len(result_list) == 1
|
||||
result_list = ray.get(f.remote([a]))
|
||||
assert len(result_list) == 1
|
||||
|
||||
@@ -965,6 +965,9 @@ class Worker(object):
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
# Reset signal counters so that the next task can get
|
||||
# all past signals.
|
||||
ray_signal.reset()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
|
||||
Reference in New Issue
Block a user