diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index a835801bd..9986cb41d 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -15,7 +15,7 @@ def env_integer(key, default): ID_SIZE = 20 -NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\x00") +NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\xff") # If a remote function or actor (or some other export) has serialized size # greater than this quantity, print an warning. @@ -43,6 +43,7 @@ WORKER_DIED_PUSH_ERROR = "worker_died" PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction" HASH_MISMATCH_PUSH_ERROR = "object_hash_mismatch" INFEASIBLE_TASK_ERROR = "infeasible_task" +REMOVED_NODE_ERROR = "node_removed" # Abort autoscaling if more than this number of errors are encountered. This # is a safety feature to prevent e.g. runaway node launches. diff --git a/python/ray/worker.py b/python/ray/worker.py index bcea4e258..83e127126 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1206,11 +1206,10 @@ def error_applies_to_driver(error_key, worker=global_worker): + ray_constants.ID_SIZE), error_key - # If the driver ID in the error message is a sequence of all zeros, then - # the message is intended for all drivers. + # If the driver ID in the error message is the nil job ID, then the + # message is intended for all drivers. 
- generic_driver_id = ray_constants.ID_SIZE * b"\x00" driver_id = error_key[len(ERROR_KEY_PREFIX):( len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)] return (driver_id == worker.task_driver_id.id() - or driver_id == generic_driver_id) + or driver_id == ray_constants.NIL_JOB_ID.id()) def error_info(worker=global_worker): @@ -1967,9 +1966,11 @@ def print_error_messages_raylet(worker): assert gcs_entry.EntriesLength() == 1 error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( gcs_entry.Entries(0), 0) - NIL_JOB_ID = ray_constants.ID_SIZE * b"\x00" job_id = error_data.JobId() - if job_id not in [worker.task_driver_id.id(), NIL_JOB_ID]: + if job_id not in [ + worker.task_driver_id.id(), + ray_constants.NIL_JOB_ID.id() + ]: continue error_message = ray.utils.decode(error_data.ErrorMessage()) diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index aa845035c..05cf79309 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -1,6 +1,7 @@ #include "ray/raylet/monitor.h" #include "ray/status.h" +#include "ray/util/util.h" namespace ray { @@ -43,6 +44,19 @@ void Monitor::Tick() { if (dead_clients_.count(it->first) == 0) { RAY_LOG(WARNING) << "Client timed out: " << it->first; RAY_CHECK_OK(gcs_client_.client_table().MarkDisconnected(it->first)); + + // Broadcast a warning to all of the drivers indicating that the node + // has been marked as dead. + // TODO(rkn): Define this constant somewhere else. + std::string type = "node_removed"; + std::ostringstream error_message; + error_message << "The node with client ID " << it->first << " has been marked " + << "dead because the monitor has missed too many heartbeats " + << "from it."; + // We use the nil JobID to broadcast the message to all drivers. 
+ RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver( + JobID::nil(), type, error_message.str(), current_time_ms())); + dead_clients_.insert(it->first); } it = heartbeats_.erase(it); diff --git a/test/failure_test.py b/test/failure_test.py index 89149ea4e..a8dde76b8 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -503,3 +503,45 @@ def test_warning_for_infeasible_tasks(ray_start_regular): # This actor placement task is infeasible. Foo.remote() wait_for_errors(ray_constants.INFEASIBLE_TASK_ERROR, 2) + + +@pytest.fixture +def ray_start_two_nodes(): + # Start the Ray processes. + ray.worker._init(start_ray_local=True, num_local_schedulers=2, num_cpus=0) + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +# Note that this test will take at least 10 seconds because it must wait for +# the monitor to detect enough missed heartbeats. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +def test_warning_for_dead_node(ray_start_two_nodes): + # Wait for the raylet to appear in the client table. + while len(ray.global_state.client_table()) < 2: + time.sleep(0.1) + + client_ids = {item["ClientID"] for item in ray.global_state.client_table()} + + # Try to make sure that the monitor has received at least one heartbeat + # from the node. + time.sleep(0.5) + + # Kill both raylets. + ray.services.all_processes[ray.services.PROCESS_TYPE_RAYLET][1].kill() + ray.services.all_processes[ray.services.PROCESS_TYPE_RAYLET][0].kill() + + # Check that we get warning messages for both raylets. + wait_for_errors(ray_constants.REMOVED_NODE_ERROR, 2, timeout=20) + + # Extract the client IDs from the error messages. This will need to be + # changed if the error message changes. + warning_client_ids = { + item['message'].split(' ')[5] + for item in ray.error_info() + } + + assert client_ids == warning_client_ids