Fix max_task_retries for named actors (#12762)

This commit is contained in:
Edward Oakes
2020-12-10 18:24:55 -06:00
committed by GitHub
parent 0e90cbcd19
commit 62d6b0a558
7 changed files with 81 additions and 10 deletions
+61
View File
@@ -1,3 +1,4 @@
import asyncio
import collections
import numpy as np
import os
@@ -211,6 +212,66 @@ def test_actor_restart_with_retry(ray_init_with_task_retry_delay):
ray.get(actor.increase.remote())
def test_named_actor_max_task_retries(ray_init_with_task_retry_delay):
    """Exercise task retries on an actor reached through its registered name.

    An actor named "a" is created with unlimited restarts and unlimited task
    retries. A second actor looks it up with ``ray.get_actor`` and invokes a
    blocking method on it. The named actor is then killed (restart allowed)
    while that method is in flight; the test passes only if both the
    constructor and the method are observed to start a second time and the
    caller's pending call eventually completes.
    """

    @ray.remote(num_cpus=0)
    class Counter:
        """Counter that lets callers wait until a given count is reached."""

        def __init__(self):
            self.count = 0
            self.event = asyncio.Event()

        def increment(self):
            """Bump the count and wake anything parked in wait_for_count."""
            self.count += 1
            self.event.set()

        async def wait_for_count(self, count):
            """Return once ``self.count`` is at least ``count``."""
            # Re-check after every wake-up; the event is reset each time so
            # the next increment can signal again.
            while self.count < count:
                await self.event.wait()
                self.event.clear()

    @ray.remote
    class ActorToKill:
        """Actor that reports each construction and each start of run()."""

        def __init__(self, counter):
            counter.increment.remote()

        def run(self, counter, signal):
            # Record that the method has started, then block on the signal.
            counter.increment.remote()
            released = signal.wait.remote()
            ray.get(released)

    @ray.remote
    class CallingActor:
        """Actor that reaches ActorToKill via its name rather than a handle."""

        def __init__(self):
            self.actor = ray.get_actor("a")

        def call_other(self, counter, signal):
            pending = self.actor.run.remote(counter, signal)
            return ray.get(pending)

    wait_timeout_s = 30

    constructor_calls = Counter.remote()
    run_calls = Counter.remote()
    gate = SignalActor.remote()

    # Bring up the named actor first, then the caller that looks it up, and
    # block until the named actor's constructor has reported in once.
    victim_options = ActorToKill.options(
        name="a", max_restarts=-1, max_task_retries=-1)
    victim = victim_options.remote(constructor_calls)
    caller = CallingActor.remote()
    ray.get(
        constructor_calls.wait_for_count.remote(1), timeout=wait_timeout_s)

    # Have the caller invoke run() on the named actor and wait until that
    # invocation is actually executing before killing the named actor.
    result_ref = caller.call_other.remote(run_calls, gate)
    ray.get(run_calls.wait_for_count.remote(1), timeout=wait_timeout_s)
    ray.kill(victim, no_restart=False)

    # After the kill, the constructor must run again and run() must be
    # started a second time.
    ray.get(
        constructor_calls.wait_for_count.remote(2), timeout=wait_timeout_s)
    ray.get(run_calls.wait_for_count.remote(2), timeout=wait_timeout_s)

    # Release run() and confirm the caller's original call completes.
    gate.send.remote()
    ray.get(result_ref, timeout=wait_timeout_s)
def test_actor_restart_on_node_failure(ray_start_cluster):
config = {
"num_heartbeats_timeout": 10,