From 59079a799cf1939935ba2c0a1509d504cacd70ea Mon Sep 17 00:00:00 2001 From: Ion Date: Thu, 21 Mar 2019 15:17:42 -0700 Subject: [PATCH] Signal actor failure (#4196) --- python/ray/experimental/signal.py | 31 +++++++++++++++++++++++++------ python/ray/tests/conftest.py | 19 +++++++++++++++++++ python/ray/tests/test_signal.py | 28 ++++++++++++++++++++++++++++ src/ray/raylet/node_manager.cc | 16 ++++++++++++++-- 4 files changed, 86 insertions(+), 8 deletions(-) diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index d1f0e2f48..f2a0d81ca 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -7,6 +7,12 @@ from collections import defaultdict import ray import ray.cloudpickle as cloudpickle +# This string should be identical to the name of the signal sent upon +# detecting that an actor died. +# This constant is also used in NodeManager::PublishActorStateTransition() +# in node_manager.cc +ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL" + class Signal(object): """Base class for Ray signals.""" @@ -20,6 +26,13 @@ class ErrorSignal(Signal): self.error = error +class ActorDiedSignal(Signal): + """Signal raised if an exception happens in a task or actor method.""" + + def __init__(self): + pass + + def _get_task_id(source): """Return the task id associated to the generic source of the signal. @@ -137,12 +150,18 @@ def receive(sources, timeout=None): # The list of results for source s is stored in answer[1] for r in answer[1]: for s in task_source_list: - # Now it gets tricky: r[0] is the redis internal sequence id - signal_counters[ray.utils.hex_to_binary(task_id)] = r[0] - # r[1] contains a list with elements (key, value), in our case - # we only have one key "signal" and the value is the signal. - signal = cloudpickle.loads(ray.utils.hex_to_binary(r[1][1])) - results.append((s, signal)) + if r[1][1].decode("ascii") == ACTOR_DIED_STR: + results.append((s, ActorDiedSignal())) + else: + # Now it gets tricky: r[0] is the redis internal sequence + # id + signal_counters[ray.utils.hex_to_binary(task_id)] = r[0] + # r[1] contains a list with elements (key, value), in our + # case we only have one key "signal" and the value is the + # signal. + signal = cloudpickle.loads( + ray.utils.hex_to_binary(r[1][1])) + results.append((s, signal)) return results diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index f9c34d223..2e670fb0a 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -161,3 +161,22 @@ def call_ray_start(request): ray.shutdown() # Kill the Ray cluster. subprocess.Popen(["ray", "stop"]).wait() + + +@pytest.fixture() +def two_node_cluster(): + internal_config = json.dumps({ + "initial_reconstruction_timeout_milliseconds": 200, + "num_heartbeats_timeout": 10, + }) + cluster = ray.tests.cluster_utils.Cluster( + head_node_args={"_internal_config": internal_config}) + for _ in range(2): + remote_node = cluster.add_node( + num_cpus=1, _internal_config=internal_config) + ray.init(redis_address=cluster.redis_address) + yield cluster, remote_node + + # The code after the yield will run as teardown code. + ray.shutdown() + cluster.shutdown() diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index cb1daa838..fe2e74379 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -274,6 +274,34 @@ def test_forget(ray_start_regular): assert len(result_list) == count +def test_signal_on_node_failure(two_node_cluster): + """Test actor checkpointing on a remote node.""" + + class ActorSignal(object): + def __init__(self): + pass + + def local_plasma(self): + return ray.worker.global_worker.plasma_client.store_socket_name + + # Place the actor on the remote node. + cluster, remote_node = two_node_cluster + actor_cls = ray.remote(max_reconstructions=0)(ActorSignal) + actor = actor_cls.remote() + # Try until we put an actor on a different node. + while (ray.get(actor.local_plasma.remote()) != + remote_node.plasma_store_socket_name): + actor = actor_cls.remote() + + # Kill actor process. + cluster.remove_node(remote_node) + + # Wait on signal from the actor on the failed node. + result_list = signal.receive([actor], timeout=10) + assert len(result_list) == 1 + assert type(result_list[0][1]) == signal.ActorDiedSignal + + def test_send_signal_from_two_tasks_to_driver(ray_start_regular): # Define a remote function that sends a user-defined signal. @ray.remote diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 72211ddd9..3566df39d 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -502,8 +502,20 @@ void NodeManager::PublishActorStateTransition( // RECONSTRUCTING or DEAD entries have an odd index. log_length += 1; } - RAY_CHECK_OK(gcs_client_->actor_table().AppendAt( - JobID::nil(), actor_id, actor_notification, nullptr, failure_callback, log_length)); + // If we successful appended a record to the GCS table of the actor that + // has died, signal this to anyone receiving signals from this actor. + auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, + const ActorTableDataT &data) { + auto redis_context = client->primary_context(); + if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { + std::vector args = {"XADD", id.hex(), "*", "signal", + "ACTOR_DIED_SIGNAL"}; + RAY_CHECK_OK(redis_context->RunArgvAsync(args)); + } + }; + RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(JobID::nil(), actor_id, + actor_notification, success_callback, + failure_callback, log_length)); } void NodeManager::HandleActorStateTransition(const ActorID &actor_id,