mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:21:15 +08:00
Signal actor failure (#4196)
This commit is contained in:
@@ -7,6 +7,12 @@ from collections import defaultdict
|
||||
import ray
|
||||
import ray.cloudpickle as cloudpickle
|
||||
|
||||
# This string should be identical to the name of the signal sent upon
|
||||
# detecting that an actor died.
|
||||
# This constant is also used in NodeManager::PublishActorStateTransition()
|
||||
# in node_manager.cc
|
||||
ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL"
|
||||
|
||||
|
||||
class Signal(object):
|
||||
"""Base class for Ray signals."""
|
||||
@@ -20,6 +26,13 @@ class ErrorSignal(Signal):
|
||||
self.error = error
|
||||
|
||||
|
||||
class ActorDiedSignal(Signal):
    """Signal reported when the actor that is the signal source has died.

    Unlike ErrorSignal, this signal carries no payload: the type of the
    object alone tells the receiver that the actor is gone.
    """

    def __init__(self):
        # Nothing to store; the signal has no associated data.
        pass
|
||||
|
||||
|
||||
def _get_task_id(source):
|
||||
"""Return the task id associated to the generic source of the signal.
|
||||
|
||||
@@ -137,12 +150,18 @@ def receive(sources, timeout=None):
|
||||
# The list of results for source s is stored in answer[1]
|
||||
for r in answer[1]:
|
||||
for s in task_source_list:
|
||||
# Now it gets tricky: r[0] is the redis internal sequence id
|
||||
signal_counters[ray.utils.hex_to_binary(task_id)] = r[0]
|
||||
# r[1] contains a list with elements (key, value), in our case
|
||||
# we only have one key "signal" and the value is the signal.
|
||||
signal = cloudpickle.loads(ray.utils.hex_to_binary(r[1][1]))
|
||||
results.append((s, signal))
|
||||
if r[1][1].decode("ascii") == ACTOR_DIED_STR:
|
||||
results.append((s, ActorDiedSignal()))
|
||||
else:
|
||||
# Now it gets tricky: r[0] is the redis internal sequence
|
||||
# id
|
||||
signal_counters[ray.utils.hex_to_binary(task_id)] = r[0]
|
||||
# r[1] contains a list with elements (key, value), in our
|
||||
# case we only have one key "signal" and the value is the
|
||||
# signal.
|
||||
signal = cloudpickle.loads(
|
||||
ray.utils.hex_to_binary(r[1][1]))
|
||||
results.append((s, signal))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -161,3 +161,22 @@ def call_ray_start(request):
|
||||
ray.shutdown()
|
||||
# Kill the Ray cluster.
|
||||
subprocess.Popen(["ray", "stop"]).wait()
|
||||
|
||||
|
||||
@pytest.fixture()
def two_node_cluster():
    """Start a test cluster, add two one-CPU nodes and connect the driver.

    Yields:
        A tuple ``(cluster, remote_node)`` where ``remote_node`` is the node
        returned by the last ``add_node`` call.
    """
    # Short timeouts, presumably so node failures are detected quickly in
    # tests -- confirm against the raylet configuration.
    config_json = json.dumps({
        "initial_reconstruction_timeout_milliseconds": 200,
        "num_heartbeats_timeout": 10,
    })
    test_cluster = ray.tests.cluster_utils.Cluster(
        head_node_args={"_internal_config": config_json})
    added_nodes = [
        test_cluster.add_node(num_cpus=1, _internal_config=config_json)
        for _ in range(2)
    ]
    ray.init(redis_address=test_cluster.redis_address)

    # Hand the most recently added node to the test as the "remote" node.
    yield test_cluster, added_nodes[-1]

    # Teardown: everything after the yield runs once the test is finished.
    ray.shutdown()
    test_cluster.shutdown()
|
||||
|
||||
@@ -274,6 +274,34 @@ def test_forget(ray_start_regular):
|
||||
assert len(result_list) == count
|
||||
|
||||
|
||||
def test_signal_on_node_failure(two_node_cluster):
    """Test that an ActorDiedSignal is received when the actor's node fails."""

    class ActorSignal(object):
        def __init__(self):
            pass

        def local_plasma(self):
            # Report the plasma store socket of the hosting worker so the
            # test can tell which node the actor was placed on.
            return ray.worker.global_worker.plasma_client.store_socket_name

    # Place the actor on the remote node.
    cluster, remote_node = two_node_cluster
    # max_reconstructions=0: presumably prevents the actor from being
    # restarted after the failure -- confirm against the ray.remote docs.
    actor_cls = ray.remote(max_reconstructions=0)(ActorSignal)
    actor = actor_cls.remote()
    # Keep creating actors until one lands on the remote node.
    while (ray.get(actor.local_plasma.remote()) !=
           remote_node.plasma_store_socket_name):
        actor = actor_cls.remote()

    # Remove the node hosting the actor, which takes the actor down with it.
    cluster.remove_node(remote_node)

    # Wait on signal from the actor on the failed node.
    result_list = signal.receive([actor], timeout=10)
    assert len(result_list) == 1
    assert isinstance(result_list[0][1], signal.ActorDiedSignal)
|
||||
|
||||
|
||||
def test_send_signal_from_two_tasks_to_driver(ray_start_regular):
|
||||
# Define a remote function that sends a user-defined signal.
|
||||
@ray.remote
|
||||
|
||||
Reference in New Issue
Block a user