diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index dfb25f244..827673c35 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -14,10 +14,10 @@ # - PLASMA_SHARED_LIB set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/2522. We +# The PR for this commit is https://github.com/apache/arrow/pull/2664. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. -set(arrow_TAG 7104d64ff2cd6c20e29d3cf4ec5c58bc10798f66) +set(arrow_TAG 3545186d6997b943ffc3d79634f2d08eefbd7322) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index a29daa5a0..6e18cd439 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -6,6 +6,7 @@ import json import os import redis import subprocess +import sys import tempfile import time @@ -147,3 +148,40 @@ def run_and_get_output(command): with open(tmp.name, 'r') as f: result = f.readlines() return "\n".join(result) + + +def run_string_as_driver(driver_script): + """Run a driver as a separate process. + + Args: + driver_script: A string to run as a Python script. + + Returns: + The script's output. + """ + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = ray.utils.decode( + subprocess.check_output([sys.executable, f.name])) + return out + + +def run_string_as_driver_nonblocking(driver_script): + """Start a driver as a separate process and return immediately. + + Args: + driver_script: A string to run as a Python script. + + Returns: + A handle to the driver process. + """ + # Save the driver script as a file so we can call it using subprocess. We + # do not delete this file because if we do then it may get removed before + # the Python process tries to run it. + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(driver_script.encode("ascii")) + f.flush() + return subprocess.Popen( + [sys.executable, f.name], stdout=subprocess.PIPE) diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 39d3c084d..20b232c33 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -114,7 +114,7 @@ class ClientConnection : public ServerConnection, MessageHandler message_handler_; /// A label used for debug messages. const std::string debug_label_; - /// Buffers for the current message being read rom the client. + /// Buffers for the current message being read from the client. int64_t read_version_; int64_t read_type_; uint64_t read_length_; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index bed03ad90..fcc60e030 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -222,7 +222,8 @@ void NodeManager::KillWorker(std::shared_ptr worker) { retry_timer->expires_from_now(retry_duration); retry_timer->async_wait([retry_timer, worker](const boost::system::error_code &error) { RAY_LOG(DEBUG) << "Send SIGKILL to worker, pid=" << worker->Pid(); - // Force kill worker. + // Force kill worker. TODO(rkn): Is there some small danger that the worker + // has already died and the PID has been reassigned to a different process? kill(worker->Pid(), SIGKILL); }); } @@ -638,8 +639,25 @@ void NodeManager::ProcessGetTaskMessage( void NodeManager::ProcessDisconnectClientMessage( const std::shared_ptr &client) { - // Remove the dead worker from the pool and stop listening for messages. const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); + // This client can't be a worker and a driver. + RAY_CHECK(worker == nullptr || driver == nullptr); + + // If both worker and driver are null, then this method has already been + // called, so just return. + if (worker == nullptr && driver == nullptr) { + RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " + << "been disconnected."; + return; + } + + // If the client is blocked, we need to treat it as unblocked. In particular, + // we are no longer waiting for its dependencies. If the client is not + // blocked, this won't do anything. + HandleClientUnblocked(client); + + // Remove the dead client from the pool and stop listening for messages. if (worker) { // The client is a worker. Handle the case where the worker is killed @@ -651,17 +669,15 @@ void NodeManager::ProcessDisconnectClientMessage( // If the worker was killed intentionally, e.g., when the driver that created // the task that this worker is currently executing exits, the task for this // worker has already been removed from queue, so the following are skipped. - auto const &running_tasks = local_queues_.GetRunningTasks(); - // TODO(rkn): This is too heavyweight just to get the task's driver ID. - auto const it = std::find_if( - running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) { - return task.GetTaskSpecification().TaskId() == task_id; - }); - RAY_CHECK(running_tasks.size() != 0); - RAY_CHECK(it != running_tasks.end()); - const TaskSpecification &spec = it->GetTaskSpecification(); - const JobID job_id = spec.DriverId(); + task_dependency_manager_.TaskCanceled(task_id); + // task_dependency_manager_.UnsubscribeDependencies(current_task_id); + const Task &task = local_queues_.RemoveTask(task_id); + const TaskSpecification &spec = task.GetTaskSpecification(); + // Handle the task failure in order to raise an exception in the + // application. + TreatTaskAsFailed(spec); + const JobID &job_id = worker->GetAssignedDriverId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; @@ -669,18 +685,12 @@ void NodeManager::ProcessDisconnectClientMessage( << "."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( job_id, type, error_message.str(), current_time_ms())); - - // Handle the task failure in order to raise an exception in the - // application. - TreatTaskAsFailed(spec); - task_dependency_manager_.TaskCanceled(spec.TaskId()); - local_queues_.RemoveTask(spec.TaskId()); } worker_pool_.DisconnectWorker(worker); // If the worker was an actor, add it to the list of dead actors. - const ActorID actor_id = worker->GetActorId(); + const ActorID &actor_id = worker->GetActorId(); if (!actor_id.is_nil()) { // TODO(rkn): Consider broadcasting a message to all of the other // node managers so that they can mark the actor as dead. @@ -715,7 +725,6 @@ void NodeManager::ProcessDisconnectClientMessage( // The client is a driver. RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), /*is_dead=*/true)); - const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); RAY_CHECK(driver); auto driver_id = driver->GetAssignedTaskId(); RAY_CHECK(!driver_id.is_nil()); @@ -725,6 +734,10 @@ void NodeManager::ProcessDisconnectClientMessage( RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " << "driver_id: " << driver->GetAssignedDriverId(); } + + // TODO(rkn): Tell the object manager that this client has disconnected so + // that it can clean up the wait requests for this client. Currently I think + // these can be leaked. } void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { @@ -798,12 +811,21 @@ void NodeManager::ProcessWaitRequestMessage( flatbuffers::Offset wait_reply = protocol::CreateWaitReply( fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); fbb.Finish(wait_reply); - RAY_CHECK_OK( + + auto status = client->WriteMessage(static_cast(protocol::MessageType::WaitReply), - fbb.GetSize(), fbb.GetBufferPointer())); - // The client is unblocked now because the wait call has returned. - if (client_blocked) { - HandleClientUnblocked(client); + fbb.GetSize(), fbb.GetBufferPointer()); + if (status.ok()) { + // The client is unblocked now because the wait call has returned. + if (client_blocked) { + HandleClientUnblocked(client); + } + } else { + // We failed to write to the client, so disconnect the client. + RAY_LOG(WARNING) + << "Failed to send WaitReply to client, so disconnecting client"; + // We failed to send the reply to the client, so disconnect the worker. + ProcessDisconnectClientMessage(client); } }); RAY_CHECK_OK(status); @@ -1308,9 +1330,7 @@ void NodeManager::AssignTask(Task &task) { } else { RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; // We failed to send the task to the worker, so disconnect the worker. - ProcessClientMessage(worker->Connection(), - static_cast(protocol::MessageType::DisconnectClient), - nullptr); + ProcessDisconnectClientMessage(worker->Connection()); // Queue this task for future assignment. The task will be assigned to a // worker once one becomes available. // (See design_docs/task_states.rst for the state transition diagram.) diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index d8025a6c8..b6b23223c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -267,7 +267,7 @@ class NodeManager { bool CheckDependencyManagerInvariant() const; /// Process client message of RegisterClientRequest - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -275,26 +275,29 @@ class NodeManager { const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of GetTask - // + /// /// \param client The client that sent the message. /// \return Void. void ProcessGetTaskMessage(const std::shared_ptr &client); - /// Process client message of DisconnectClient - // + /// Handle a client that has disconnected. This can be called multiple times + /// on the same client because this is triggered both when a client + /// disconnects and when the node manager fails to write a message to the + /// client. + /// /// \param client The client that sent the message. /// \return Void. void ProcessDisconnectClientMessage( const std::shared_ptr &client); /// Process client message of SubmitTask - // + /// /// \param message_data A pointer to the message data. /// \return Void. void ProcessSubmitTaskMessage(const uint8_t *message_data); /// Process client message of ReconstructObjects - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -302,7 +305,7 @@ class NodeManager { const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of WaitRequest - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -310,7 +313,7 @@ class NodeManager { const uint8_t *message_data); /// Process client message of PushErrorRequest - // + /// /// \param message_data A pointer to the message data. /// \return Void. void ProcessPushErrorRequestMessage(const uint8_t *message_data); diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 64dd3712b..3a57452e6 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -2,11 +2,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pytest import os -import ray +import signal import time +import pytest + +import ray +from ray.test.test_utils import run_string_as_driver_nonblocking import pyarrow as pa @@ -23,6 +26,112 @@ def ray_start_workers_separate(): ray.shutdown() +@pytest.fixture +def shutdown_only(): + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +# This test checks that when a worker dies in the middle of a get, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_worker_get_raylet(shutdown_only): + # Start the Ray processes. + ray.init(num_cpus=2) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + @ray.remote + def get_worker_pid(): + return os.getpid() + + x_id = sleep_forever.remote() + time.sleep(0.01) # Try to wait for the sleep task to get scheduled. + # Get the PID of the other worker. + worker_pid = ray.get(get_worker_pid.remote()) + + @ray.remote + def f(id_in_a_list): + ray.get(id_in_a_list[0]) + + # Have the worker wait in a get call. + result_id = f.remote([x_id]) + time.sleep(1) + + # Make sure the task hasn't finished. + ready_ids, _ = ray.wait([result_id], timeout=0) + assert len(ready_ids) == 0 + + # Kill the worker. + os.kill(worker_pid, signal.SIGKILL) + time.sleep(0.1) + + # Make sure the sleep task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + +# This test checks that when a driver dies in the middle of a get, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_driver_get(shutdown_only): + # Start the Ray processes. + address_info = ray.init(num_cpus=1) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + x_id = sleep_forever.remote() + + driver = """ +import ray +ray.init("{}") +ray.get(ray.ObjectID({})) +""".format(address_info["redis_address"], x_id.id()) + + p = run_string_as_driver_nonblocking(driver) + # Make sure the driver is running. + time.sleep(1) + assert p.poll() is None + + # Kill the driver process. + p.kill() + p.wait() + time.sleep(0.1) + + # Make sure the original task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + # This test checks that when a worker dies in the middle of a get, the # plasma store and manager will not die. @pytest.mark.skipif( @@ -59,6 +168,97 @@ def test_dying_worker_get(ray_start_workers_separate): exclude=[ray.services.PROCESS_TYPE_WORKER]) +# This test checks that when a worker dies in the middle of a wait, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_worker_wait_raylet(shutdown_only): + ray.init(num_cpus=2) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + @ray.remote + def get_pid(): + return os.getpid() + + x_id = sleep_forever.remote() + # Get the PID of the worker that block_in_wait will run on (sleep a little + # to make sure that sleep_forever has already started). + time.sleep(0.1) + worker_pid = ray.get(get_pid.remote()) + + @ray.remote + def block_in_wait(object_id_in_list): + ray.wait(object_id_in_list) + + # Have the worker wait in a wait call. + block_in_wait.remote([x_id]) + time.sleep(0.1) + + # Kill the worker. + os.kill(worker_pid, signal.SIGKILL) + time.sleep(0.1) + + # Create the object. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + +# This test checks that when a driver dies in the middle of a wait, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_driver_wait(shutdown_only): + # Start the Ray processes. + address_info = ray.init(num_cpus=1) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + x_id = sleep_forever.remote() + + driver = """ +import ray +ray.init("{}") +ray.wait([ray.ObjectID({})]) +""".format(address_info["redis_address"], x_id.id()) + + p = run_string_as_driver_nonblocking(driver) + # Make sure the driver is running. + time.sleep(1) + assert p.poll() is None + + # Kill the driver process. + p.kill() + p.wait() + time.sleep(0.1) + + # Make sure the original task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # wait can return. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + # This test checks that when a worker dies in the middle of a wait, the # plasma store and manager will not die. @pytest.mark.skipif( diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 657c03710..a1f0bd87b 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -5,49 +5,11 @@ from __future__ import print_function import os import pytest import subprocess -import sys -import tempfile import time import ray -from ray.test.test_utils import run_and_get_output - - -def run_string_as_driver(driver_script): - """Run a driver as a separate process. - - Args: - driver_script: A string to run as a Python script. - - Returns: - The script's output. - """ - # Save the driver script as a file so we can call it using subprocess. - with tempfile.NamedTemporaryFile() as f: - f.write(driver_script.encode("ascii")) - f.flush() - out = ray.utils.decode( - subprocess.check_output([sys.executable, f.name])) - return out - - -def run_string_as_driver_nonblocking(driver_script): - """Start a driver as a separate process and return immediately. - - Args: - driver_script: A string to run as a Python script. - - Returns: - A handle to the driver process. - """ - # Save the driver script as a file so we can call it using subprocess. We - # do not delete this file because if we do then it may get removed before - # the Python process tries to run it. - with tempfile.NamedTemporaryFile(delete=False) as f: - f.write(driver_script.encode("ascii")) - f.flush() - return subprocess.Popen( - [sys.executable, f.name], stdout=subprocess.PIPE) +from ray.test.test_utils import (run_and_get_output, run_string_as_driver, + run_string_as_driver_nonblocking) @pytest.fixture