From 3ce8eb2d4cb7377aca74865a4064b0327b363099 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 2 Oct 2018 00:08:47 -0700 Subject: [PATCH] Test dying_worker_get and dying_worker_wait for xray. (#2997) This tests the case in which a worker is blocked in a call to ray.get or ray.wait, and then the worker dies. Then later, the object that the worker was waiting for becomes available. We need to make sure not to try to send a message to the dead worker and then die. Related to #2790. --- cmake/Modules/ArrowExternalProject.cmake | 4 +- python/ray/test/test_utils.py | 38 +++++ src/ray/common/client_connection.h | 2 +- src/ray/raylet/node_manager.cc | 76 +++++---- src/ray/raylet/node_manager.h | 19 ++- test/component_failures_test.py | 204 ++++++++++++++++++++++- test/multi_node_test.py | 42 +---- 7 files changed, 304 insertions(+), 81 deletions(-) diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index dfb25f244..827673c35 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -14,10 +14,10 @@ # - PLASMA_SHARED_LIB set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/2522. We +# The PR for this commit is https://github.com/apache/arrow/pull/2664. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. -set(arrow_TAG 7104d64ff2cd6c20e29d3cf4ec5c58bc10798f66) +set(arrow_TAG 3545186d6997b943ffc3d79634f2d08eefbd7322) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index a29daa5a0..6e18cd439 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -6,6 +6,7 @@ import json import os import redis import subprocess +import sys import tempfile import time @@ -147,3 +148,40 @@ def run_and_get_output(command): with open(tmp.name, 'r') as f: result = f.readlines() return "\n".join(result) + + +def run_string_as_driver(driver_script): + """Run a driver as a separate process. + + Args: + driver_script: A string to run as a Python script. + + Returns: + The script's output. + """ + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = ray.utils.decode( + subprocess.check_output([sys.executable, f.name])) + return out + + +def run_string_as_driver_nonblocking(driver_script): + """Start a driver as a separate process and return immediately. + + Args: + driver_script: A string to run as a Python script. + + Returns: + A handle to the driver process. + """ + # Save the driver script as a file so we can call it using subprocess. We + # do not delete this file because if we do then it may get removed before + # the Python process tries to run it. + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(driver_script.encode("ascii")) + f.flush() + return subprocess.Popen( + [sys.executable, f.name], stdout=subprocess.PIPE) diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 39d3c084d..20b232c33 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -114,7 +114,7 @@ class ClientConnection : public ServerConnection, MessageHandler message_handler_; /// A label used for debug messages. const std::string debug_label_; - /// Buffers for the current message being read rom the client. + /// Buffers for the current message being read from the client. int64_t read_version_; int64_t read_type_; uint64_t read_length_; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index bed03ad90..fcc60e030 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -222,7 +222,8 @@ void NodeManager::KillWorker(std::shared_ptr worker) { retry_timer->expires_from_now(retry_duration); retry_timer->async_wait([retry_timer, worker](const boost::system::error_code &error) { RAY_LOG(DEBUG) << "Send SIGKILL to worker, pid=" << worker->Pid(); - // Force kill worker. + // Force kill worker. TODO(rkn): Is there some small danger that the worker + // has already died and the PID has been reassigned to a different process? kill(worker->Pid(), SIGKILL); }); } @@ -638,8 +639,25 @@ void NodeManager::ProcessGetTaskMessage( void NodeManager::ProcessDisconnectClientMessage( const std::shared_ptr &client) { - // Remove the dead worker from the pool and stop listening for messages. const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); + // This client can't be a worker and a driver. + RAY_CHECK(worker == nullptr || driver == nullptr); + + // If both worker and driver are null, then this method has already been + // called, so just return. + if (worker == nullptr && driver == nullptr) { + RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " + << "been disconnected."; + return; + } + + // If the client is blocked, we need to treat it as unblocked. In particular, + // we are no longer waiting for its dependencies. If the client is not + // blocked, this won't do anything. + HandleClientUnblocked(client); + + // Remove the dead client from the pool and stop listening for messages. if (worker) { // The client is a worker. Handle the case where the worker is killed @@ -651,17 +669,15 @@ void NodeManager::ProcessDisconnectClientMessage( // If the worker was killed intentionally, e.g., when the driver that created // the task that this worker is currently executing exits, the task for this // worker has already been removed from queue, so the following are skipped. - auto const &running_tasks = local_queues_.GetRunningTasks(); - // TODO(rkn): This is too heavyweight just to get the task's driver ID. - auto const it = std::find_if( - running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) { - return task.GetTaskSpecification().TaskId() == task_id; - }); - RAY_CHECK(running_tasks.size() != 0); - RAY_CHECK(it != running_tasks.end()); - const TaskSpecification &spec = it->GetTaskSpecification(); - const JobID job_id = spec.DriverId(); + task_dependency_manager_.TaskCanceled(task_id); + // task_dependency_manager_.UnsubscribeDependencies(current_task_id); + const Task &task = local_queues_.RemoveTask(task_id); + const TaskSpecification &spec = task.GetTaskSpecification(); + // Handle the task failure in order to raise an exception in the + // application. + TreatTaskAsFailed(spec); + const JobID &job_id = worker->GetAssignedDriverId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; @@ -669,18 +685,12 @@ void NodeManager::ProcessDisconnectClientMessage( << "."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( job_id, type, error_message.str(), current_time_ms())); - - // Handle the task failure in order to raise an exception in the - // application. - TreatTaskAsFailed(spec); - task_dependency_manager_.TaskCanceled(spec.TaskId()); - local_queues_.RemoveTask(spec.TaskId()); } worker_pool_.DisconnectWorker(worker); // If the worker was an actor, add it to the list of dead actors. - const ActorID actor_id = worker->GetActorId(); + const ActorID &actor_id = worker->GetActorId(); if (!actor_id.is_nil()) { // TODO(rkn): Consider broadcasting a message to all of the other // node managers so that they can mark the actor as dead. @@ -715,7 +725,6 @@ void NodeManager::ProcessDisconnectClientMessage( // The client is a driver. RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), /*is_dead=*/true)); - const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); RAY_CHECK(driver); auto driver_id = driver->GetAssignedTaskId(); RAY_CHECK(!driver_id.is_nil()); @@ -725,6 +734,10 @@ void NodeManager::ProcessDisconnectClientMessage( RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " << "driver_id: " << driver->GetAssignedDriverId(); } + + // TODO(rkn): Tell the object manager that this client has disconnected so + // that it can clean up the wait requests for this client. Currently I think + // these can be leaked. } void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { @@ -798,12 +811,21 @@ void NodeManager::ProcessWaitRequestMessage( flatbuffers::Offset wait_reply = protocol::CreateWaitReply( fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); fbb.Finish(wait_reply); - RAY_CHECK_OK( + + auto status = client->WriteMessage(static_cast(protocol::MessageType::WaitReply), - fbb.GetSize(), fbb.GetBufferPointer())); - // The client is unblocked now because the wait call has returned. - if (client_blocked) { - HandleClientUnblocked(client); + fbb.GetSize(), fbb.GetBufferPointer()); + if (status.ok()) { + // The client is unblocked now because the wait call has returned. + if (client_blocked) { + HandleClientUnblocked(client); + } + } else { + // We failed to write to the client, so disconnect the client. + RAY_LOG(WARNING) + << "Failed to send WaitReply to client, so disconnecting client"; + // We failed to send the reply to the client, so disconnect the worker. + ProcessDisconnectClientMessage(client); } }); RAY_CHECK_OK(status); @@ -1308,9 +1330,7 @@ void NodeManager::AssignTask(Task &task) { } else { RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; // We failed to send the task to the worker, so disconnect the worker. - ProcessClientMessage(worker->Connection(), - static_cast(protocol::MessageType::DisconnectClient), - nullptr); + ProcessDisconnectClientMessage(worker->Connection()); // Queue this task for future assignment. The task will be assigned to a // worker once one becomes available. // (See design_docs/task_states.rst for the state transition diagram.) diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index d8025a6c8..b6b23223c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -267,7 +267,7 @@ class NodeManager { bool CheckDependencyManagerInvariant() const; /// Process client message of RegisterClientRequest - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -275,26 +275,29 @@ class NodeManager { const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of GetTask - // + /// /// \param client The client that sent the message. /// \return Void. void ProcessGetTaskMessage(const std::shared_ptr &client); - /// Process client message of DisconnectClient - // + /// Handle a client that has disconnected. This can be called multiple times + /// on the same client because this is triggered both when a client + /// disconnects and when the node manager fails to write a message to the + /// client. + /// /// \param client The client that sent the message. /// \return Void. void ProcessDisconnectClientMessage( const std::shared_ptr &client); /// Process client message of SubmitTask - // + /// /// \param message_data A pointer to the message data. /// \return Void. void ProcessSubmitTaskMessage(const uint8_t *message_data); /// Process client message of ReconstructObjects - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -302,7 +305,7 @@ class NodeManager { const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of WaitRequest - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -310,7 +313,7 @@ class NodeManager { const uint8_t *message_data); /// Process client message of PushErrorRequest - // + /// /// \param message_data A pointer to the message data. /// \return Void. void ProcessPushErrorRequestMessage(const uint8_t *message_data); diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 64dd3712b..3a57452e6 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -2,11 +2,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pytest import os -import ray +import signal import time +import pytest + +import ray +from ray.test.test_utils import run_string_as_driver_nonblocking import pyarrow as pa @@ -23,6 +26,112 @@ def ray_start_workers_separate(): ray.shutdown() +@pytest.fixture +def shutdown_only(): + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +# This test checks that when a worker dies in the middle of a get, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_worker_get_raylet(shutdown_only): + # Start the Ray processes. + ray.init(num_cpus=2) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + @ray.remote + def get_worker_pid(): + return os.getpid() + + x_id = sleep_forever.remote() + time.sleep(0.01) # Try to wait for the sleep task to get scheduled. + # Get the PID of the other worker. + worker_pid = ray.get(get_worker_pid.remote()) + + @ray.remote + def f(id_in_a_list): + ray.get(id_in_a_list[0]) + + # Have the worker wait in a get call. + result_id = f.remote([x_id]) + time.sleep(1) + + # Make sure the task hasn't finished. + ready_ids, _ = ray.wait([result_id], timeout=0) + assert len(ready_ids) == 0 + + # Kill the worker. + os.kill(worker_pid, signal.SIGKILL) + time.sleep(0.1) + + # Make sure the sleep task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + +# This test checks that when a driver dies in the middle of a get, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_driver_get(shutdown_only): + # Start the Ray processes. + address_info = ray.init(num_cpus=1) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + x_id = sleep_forever.remote() + + driver = """ +import ray +ray.init("{}") +ray.get(ray.ObjectID({})) +""".format(address_info["redis_address"], x_id.id()) + + p = run_string_as_driver_nonblocking(driver) + # Make sure the driver is running. + time.sleep(1) + assert p.poll() is None + + # Kill the driver process. + p.kill() + p.wait() + time.sleep(0.1) + + # Make sure the original task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + # This test checks that when a worker dies in the middle of a get, the # plasma store and manager will not die. @pytest.mark.skipif( @@ -59,6 +168,97 @@ def test_dying_worker_get(ray_start_workers_separate): exclude=[ray.services.PROCESS_TYPE_WORKER]) +# This test checks that when a worker dies in the middle of a wait, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_worker_wait_raylet(shutdown_only): + ray.init(num_cpus=2) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + @ray.remote + def get_pid(): + return os.getpid() + + x_id = sleep_forever.remote() + # Get the PID of the worker that block_in_wait will run on (sleep a little + # to make sure that sleep_forever has already started). + time.sleep(0.1) + worker_pid = ray.get(get_pid.remote()) + + @ray.remote + def block_in_wait(object_id_in_list): + ray.wait(object_id_in_list) + + # Have the worker wait in a wait call. + block_in_wait.remote([x_id]) + time.sleep(0.1) + + # Kill the worker. + os.kill(worker_pid, signal.SIGKILL) + time.sleep(0.1) + + # Create the object. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + +# This test checks that when a driver dies in the middle of a wait, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_driver_wait(shutdown_only): + # Start the Ray processes. + address_info = ray.init(num_cpus=1) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + x_id = sleep_forever.remote() + + driver = """ +import ray +ray.init("{}") +ray.wait([ray.ObjectID({})]) +""".format(address_info["redis_address"], x_id.id()) + + p = run_string_as_driver_nonblocking(driver) + # Make sure the driver is running. + time.sleep(1) + assert p.poll() is None + + # Kill the driver process. + p.kill() + p.wait() + time.sleep(0.1) + + # Make sure the original task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # wait can return. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + # This test checks that when a worker dies in the middle of a wait, the # plasma store and manager will not die. @pytest.mark.skipif( diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 657c03710..a1f0bd87b 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -5,49 +5,11 @@ from __future__ import print_function import os import pytest import subprocess -import sys -import tempfile import time import ray -from ray.test.test_utils import run_and_get_output - - -def run_string_as_driver(driver_script): - """Run a driver as a separate process. - - Args: - driver_script: A string to run as a Python script. - - Returns: - The script's output. - """ - # Save the driver script as a file so we can call it using subprocess. - with tempfile.NamedTemporaryFile() as f: - f.write(driver_script.encode("ascii")) - f.flush() - out = ray.utils.decode( - subprocess.check_output([sys.executable, f.name])) - return out - - -def run_string_as_driver_nonblocking(driver_script): - """Start a driver as a separate process and return immediately. - - Args: - driver_script: A string to run as a Python script. - - Returns: - A handle to the driver process. - """ - # Save the driver script as a file so we can call it using subprocess. We - # do not delete this file because if we do then it may get removed before - # the Python process tries to run it. - with tempfile.NamedTemporaryFile(delete=False) as f: - f.write(driver_script.encode("ascii")) - f.flush() - return subprocess.Popen( - [sys.executable, f.name], stdout=subprocess.PIPE) +from ray.test.test_utils import (run_and_get_output, run_string_as_driver, + run_string_as_driver_nonblocking) @pytest.fixture