From bfcf254e528b76b631bbfc24c9a8dcf47dd2ddab Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 24 Jan 2019 15:28:44 +0800 Subject: [PATCH] Fix: do not treat actor task as failed if the actor will be reconstructed (#3736) --- src/ray/raylet/node_manager.cc | 65 +++++++++++++++++++--------------- test/actor_test.py | 28 ++++++++------- test/multi_node_test.py | 38 ++++++++++++++++++++ 3 files changed, 90 insertions(+), 41 deletions(-) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e20173cb8..952fadabe 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -809,33 +809,49 @@ void NodeManager::ProcessDisconnectClientMessage( // If the client has any blocked tasks, mark them as unblocked. In // particular, we are no longer waiting for their dependencies. if (worker) { - while (!worker->GetBlockedTaskIds().empty()) { - // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is - // not safe to pass in the iterator directly. - const TaskID task_id = *worker->GetBlockedTaskIds().begin(); - HandleTaskUnblocked(client, task_id); + if (is_worker && worker->IsDead()) { + // Don't need to unblock the client if it's a worker and is already dead. + // Because in this case, its task is already cleaned up. + RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; + } else { + while (!worker->GetBlockedTaskIds().empty()) { + // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is + // not safe to pass in the iterator directly. + const TaskID task_id = *worker->GetBlockedTaskIds().begin(); + HandleTaskUnblocked(client, task_id); + } } } - // Remove the dead client from the pool and stop listening for messages. if (is_worker) { - // The client is a worker. Handle the case where the worker is killed - // while executing a task. Clean up the assigned task's resources, push - // an error to the driver. - // (See design_docs/task_states.rst for the state transition diagram.) - const TaskID &task_id = worker->GetAssignedTaskId(); - if (!task_id.is_nil() && !worker->IsDead()) { - // If the worker was killed intentionally, e.g., when the driver that created - // the task that this worker is currently executing exits, the task for this - // worker has already been removed from queue, so the following are skipped. - const Task &task = local_queues_.RemoveTask(task_id); - // Handle the task failure in order to raise an exception in the - // application. - TreatTaskAsFailed(task); + // The client is a worker. + if (worker->IsDead()) { + // If the worker was killed by us because the driver exited, + // treat it as intentionally disconnected. + intentional_disconnect = true; + } - const JobID &job_id = worker->GetAssignedDriverId(); + const ActorID &actor_id = worker->GetActorId(); + if (!actor_id.is_nil()) { + // If the worker was an actor, update actor state, reconstruct the actor if needed, + // and clean up actor's tasks if the actor is permanently dead. + HandleDisconnectedActor(actor_id, true, intentional_disconnect); + } + + const TaskID &task_id = worker->GetAssignedTaskId(); + // If the worker was running a task, clean up the task and push an error to + // the driver, unless the worker is already dead. + if (!task_id.is_nil() && !worker->IsDead()) { + // If the worker was an actor, the task was already cleaned up in + // `HandleDisconnectedActor`. + if (actor_id.is_nil()) { + const Task &task = local_queues_.RemoveTask(task_id); + TreatTaskAsFailed(task); + } if (!intentional_disconnect) { + // Push the error to driver. + const JobID &job_id = worker->GetAssignedDriverId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; @@ -846,16 +862,9 @@ void NodeManager::ProcessDisconnectClientMessage( } } + // Remove the dead client from the pool and stop listening for messages. worker_pool_.DisconnectWorker(worker); - // If the worker was an actor, add it to the list of dead actors. - const ActorID &actor_id = worker->GetActorId(); - if (!actor_id.is_nil()) { - RAY_LOG(DEBUG) << "The actor with ID " << actor_id << " died on " - << gcs_client_->client_table().GetLocalClientId(); - HandleDisconnectedActor(actor_id, /*was_local=*/true, intentional_disconnect); - } - const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); // Return the resources that were being used by this worker. diff --git a/test/actor_test.py b/test/actor_test.py index da5c54091..3617f78e8 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -2215,30 +2215,32 @@ def test_actor_reconstruction(ray_start_regular): def __init__(self): self.value = 0 - def increase(self): + def increase(self, delay=0): + time.sleep(delay) self.value += 1 return self.value def get_pid(self): return os.getpid() - def kill_actor(actor): - """Kill actor process.""" - pid = ray.get(actor.get_pid.remote()) - os.kill(pid, signal.SIGKILL) - time.sleep(1) - actor = ReconstructableActor.remote() + pid = ray.get(actor.get_pid.remote()) # Call increase 3 times for _ in range(3): ray.get(actor.increase.remote()) - # kill actor process - kill_actor(actor) - # Call increase again. - # Check that actor is reconstructed and value is 4. - assert ray.get(actor.increase.remote()) == 4 + # Call increase again with some delay. + result = actor.increase.remote(delay=0.5) + # Sleep some time to wait for the above task to start execution. + time.sleep(0.2) + # Kill actor process, while the above task is still being executed. + os.kill(pid, signal.SIGKILL) + # Check that the above task didn't fail and the actor is reconstructed. + assert ray.get(result) == 4 + # Check that we can still call the actor. + assert ray.get(actor.increase.remote()) == 5 # kill actor process one more time. - kill_actor(actor) + pid = ray.get(actor.get_pid.remote()) + os.kill(pid, signal.SIGKILL) # The actor has exceeded max reconstructions, and this task should fail. with pytest.raises(ray.worker.RayTaskError): ray.get(actor.increase.remote()) diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 3986aa02b..239789d13 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -416,3 +416,41 @@ print("success") for i in range(2): out = run_string_as_driver(driver_script) assert "success" in out + + +def test_driver_exiting_when_worker_blocked(ray_start_head): + # This test will create some drivers that submit some tasks and then + # exit without waiting for the tasks to complete. + redis_address = ray_start_head + + ray.init(redis_address=redis_address) + + # Define a driver that creates an actor and exits. + driver_script = """ +import time +import ray +ray.init(redis_address="{}") +@ray.remote +def f(): + time.sleep(10**6) +@ray.remote +def g(): + ray.get(f.remote()) +g.remote() +time.sleep(1) +print("success") +""".format(redis_address) + + # Create some drivers and let them exit and make sure everything is + # still alive. + for _ in range(3): + out = run_string_as_driver(driver_script) + # Make sure the first driver ran to completion. + assert "success" in out + + @ray.remote + def f(): + return 1 + + # Make sure we can still talk with the raylet. + ray.get(f.remote())