From da41180dc07f36c5602172d5a9f462f81ecccb2f Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 2 Dec 2019 10:20:57 -0800 Subject: [PATCH] [direct task] Retry tasks on failure and turn on RAY_FORCE_DIRECT for test_multinode_failures.py (#6306) * multinode failures direct * Add number of retries allowed for tasks * Retry tasks * Add failing test for object reconstruction * Handle return status and debug * update * Retry task unit test * update * update * todo * Fix max_retries decorator, fix test * Fix test that flaked * lint * comments --- python/ray/_raylet.pyx | 6 +- python/ray/includes/libcoreworker.pxd | 3 +- python/ray/remote_function.py | 15 ++- python/ray/tests/BUILD | 8 ++ python/ray/tests/test_failure.py | 2 +- python/ray/tests/test_multinode_failures.py | 91 ++++++++++++++++++- .../tests/test_multinode_failures_direct.py | 18 ++++ python/ray/worker.py | 7 +- rllib/tests/test_optimizers.py | 2 - src/ray/core_worker/core_worker.cc | 9 +- src/ray/core_worker/core_worker.h | 3 +- src/ray/core_worker/task_manager.cc | 43 +++++++-- src/ray/core_worker/task_manager.h | 40 ++++++-- src/ray/core_worker/test/core_worker_test.cc | 3 +- .../test/direct_actor_transport_test.cc | 6 +- .../test/direct_task_transport_test.cc | 2 +- src/ray/core_worker/test/task_manager_test.cc | 39 +++++++- .../transport/direct_actor_transport.cc | 21 ++--- .../transport/direct_task_transport.cc | 24 +++-- .../transport/direct_task_transport.h | 3 +- src/ray/protobuf/common.proto | 2 + 21 files changed, 284 insertions(+), 63 deletions(-) create mode 100644 python/ray/tests/test_multinode_failures_direct.py diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 399ae824d..abc7d1bbd 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -911,7 +911,8 @@ cdef class CoreWorker: args, int num_return_vals, c_bool is_direct_call, - resources): + resources, + int max_retries): cdef: unordered_map[c_string, double] c_resources CTaskOptions task_options @@ -929,7 +930,8 @@ cdef class CoreWorker: with nogil: check_status(self.core_worker.get().SubmitTask( - ray_function, args_vector, task_options, &return_ids)) + ray_function, args_vector, task_options, &return_ids, + max_retries)) return VectorToObjectIDs(return_ids) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 9bd3844dc..af3f73e76 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -86,7 +86,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus SubmitTask( const CRayFunction &function, const c_vector[CTaskArg] &args, - const CTaskOptions &options, c_vector[CObjectID] *return_ids) + const CTaskOptions &options, c_vector[CObjectID] *return_ids, + int max_retries) CRayStatus CreateActor( const CRayFunction &function, const c_vector[CTaskArg] &args, const CActorCreationOptions &options, CActorID *actor_id) diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 15ce2d532..ff6d9080c 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -14,6 +14,9 @@ import ray.signature DEFAULT_REMOTE_FUNCTION_CPUS = 1 DEFAULT_REMOTE_FUNCTION_NUM_RETURN_VALS = 1 DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0 +# Normal tasks may be retried on failure this many times. +# TODO(swang): Allow this to be set globally for an application. +DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES = 3 logger = logging.getLogger(__name__) @@ -59,7 +62,8 @@ class RemoteFunction(object): """ def __init__(self, function, num_cpus, num_gpus, memory, - object_store_memory, resources, num_return_vals, max_calls): + object_store_memory, resources, num_return_vals, max_calls, + max_retries): self._function = function self._function_name = ( self._function.__module__ + "." + self._function.__name__) @@ -76,6 +80,8 @@ class RemoteFunction(object): num_return_vals is None else num_return_vals) self._max_calls = (DEFAULT_REMOTE_FUNCTION_MAX_CALLS if max_calls is None else max_calls) + self._max_retries = (DEFAULT_REMOTE_FUNCTION_NUM_TASK_RETRIES + if max_retries is None else max_retries) self._decorator = getattr(function, "__ray_invocation_decorator__", None) @@ -142,7 +148,8 @@ class RemoteFunction(object): num_gpus=None, memory=None, object_store_memory=None, - resources=None): + resources=None, + max_retries=None): """Submit the remote function for execution.""" worker = ray.worker.get_global_worker() worker.check_connected() @@ -176,6 +183,8 @@ class RemoteFunction(object): num_return_vals = self._num_return_vals if is_direct_call is None: is_direct_call = self.direct_call_enabled + if max_retries is None: + max_retries = self._max_retries resources = ray.utils.resources_from_resource_arguments( self._num_cpus, self._num_gpus, self._memory, @@ -196,7 +205,7 @@ class RemoteFunction(object): else: object_ids = worker.core_worker.submit_task( self._function_descriptor_list, list_args, num_return_vals, - is_direct_call, resources) + is_direct_call, resources, max_retries) if len(object_ids) == 1: return object_ids[0] diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 627300847..40059cc77 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -70,6 +70,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_multinode_failures_direct", + size = "medium", + srcs = ["test_multinode_failures_direct.py", "test_multinode_failures.py"], + tags = ["exclusive", "manual"], + deps = ["//:ray_lib"], +) + py_test( name = "test_stress", size = "large", diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index c6c1f0a50..38d4ab1eb 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -308,7 +308,7 @@ def test_worker_raising_exception(ray_start_regular): def test_worker_dying(ray_start_regular): # Define a remote function that will kill the worker that runs it. - @ray.remote + @ray.remote(max_retries=0) def f(): eval("exit()") diff --git a/python/ray/tests/test_multinode_failures.py b/python/ray/tests/test_multinode_failures.py index 32d1c1a7b..81f4e6237 100644 --- a/python/ray/tests/test_multinode_failures.py +++ b/python/ray/tests/test_multinode_failures.py @@ -16,6 +16,8 @@ import ray.ray_constants as ray_constants from ray.cluster_utils import Cluster from ray.test_utils import RayTestTimeoutException +RAY_FORCE_DIRECT = bool(os.environ.get("RAY_FORCE_DIRECT")) + @pytest.fixture(params=[(1, 4), (4, 4)]) def ray_start_workers_separate_multinode(request): @@ -83,10 +85,20 @@ def _test_component_failed(cluster, component_type): # Submit many tasks with many dependencies. @ray.remote def f(x): + if RAY_FORCE_DIRECT: + # Sleep to make sure that tasks actually fail mid-execution. We + # only use it for direct calls because the test already takes a + # long time to run with the raylet codepath. + time.sleep(0.01) return x @ray.remote def g(*xs): + if RAY_FORCE_DIRECT: + # Sleep to make sure that tasks actually fail mid-execution. We + # only use it for direct calls because the test already takes a + # long time to run with the raylet codepath. + time.sleep(0.01) return 1 # Kill the component on all nodes except the head node as the tasks @@ -138,11 +150,13 @@ def check_components_alive(cluster, component_type, check_component_alive): @pytest.mark.parametrize( - "ray_start_cluster", [{ + "ray_start_cluster", + [{ "num_cpus": 8, "num_nodes": 4, "_internal_config": json.dumps({ - "num_heartbeats_timeout": 100 + # Raylet codepath is not stable with a shorter timeout. + "num_heartbeats_timeout": 10 if RAY_FORCE_DIRECT else 100 }), }], indirect=True) @@ -156,15 +170,83 @@ def test_raylet_failed(ray_start_cluster): True) +@pytest.mark.skipif( + RAY_FORCE_DIRECT, + reason="No reconstruction for objects placed in plasma yet") +@pytest.mark.parametrize( + "ray_start_cluster", + [{ + # Force at least one task per node. + "num_cpus": 1, + "num_nodes": 4, + "object_store_memory": 1000 * 1024 * 1024, + "_internal_config": json.dumps({ + # Raylet codepath is not stable with a shorter timeout. + "num_heartbeats_timeout": 10 if RAY_FORCE_DIRECT else 100, + "object_manager_pull_timeout_ms": 1000, + "object_manager_push_timeout_ms": 1000, + "object_manager_repeated_push_delay_ms": 1000, + }), + }], + indirect=True) +def test_object_reconstruction(ray_start_cluster): + cluster = ray_start_cluster + + # Submit tasks with dependencies in plasma. + @ray.remote + def large_value(): + # Sleep for a bit to force tasks onto different nodes. + time.sleep(0.1) + return np.zeros(10 * 1024 * 1024) + + @ray.remote + def g(x): + return + + # Kill the component on all nodes except the head node as the tasks + # execute. Do this in a loop while submitting tasks between each + # component failure. + time.sleep(0.1) + worker_nodes = cluster.list_all_nodes()[1:] + assert len(worker_nodes) > 0 + component_type = ray_constants.PROCESS_TYPE_RAYLET + for node in worker_nodes: + process = node.all_processes[component_type][0].process + # Submit a round of tasks with many dependencies. + num_tasks = len(worker_nodes) + xs = [large_value.remote() for _ in range(num_tasks)] + # Wait for the tasks to complete, then evict the objects from the local + # node. + for x in xs: + ray.get(x) + ray.internal.free([x], local_only=True) + + # Kill a component on one of the nodes. + process.terminate() + time.sleep(1) + process.kill() + process.wait() + assert not process.poll() is None + + # Make sure that we can still get the objects after the + # executing tasks died. + print("F", xs) + xs = [g.remote(x) for x in xs] + print("G", xs) + ray.get(xs) + + @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Hanging with new GCS API.") @pytest.mark.parametrize( - "ray_start_cluster", [{ + "ray_start_cluster", + [{ "num_cpus": 8, "num_nodes": 2, "_internal_config": json.dumps({ - "num_heartbeats_timeout": 100 + # Raylet codepath is not stable with a shorter timeout. + "num_heartbeats_timeout": 10 if RAY_FORCE_DIRECT else 100 }), }], indirect=True) @@ -179,6 +261,7 @@ def test_plasma_store_failed(ray_start_cluster): check_components_alive(cluster, ray_constants.PROCESS_TYPE_RAYLET, False) +@pytest.mark.skipif(RAY_FORCE_DIRECT, reason="no actor restart yet") @pytest.mark.parametrize( "ray_start_cluster", [{ "num_cpus": 4, diff --git a/python/ray/tests/test_multinode_failures_direct.py b/python/ray/tests/test_multinode_failures_direct.py new file mode 100644 index 000000000..ac15f5d82 --- /dev/null +++ b/python/ray/tests/test_multinode_failures_direct.py @@ -0,0 +1,18 @@ +"""Wrapper script that sets RAY_FORCE_DIRECT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest +import sys +import os + +if __name__ == "__main__": + os.environ["RAY_FORCE_DIRECT"] = "1" + sys.exit( + pytest.main([ + "-v", + os.path.join( + os.path.dirname(__file__), "test_multinode_failures.py") + ])) diff --git a/python/ray/worker.py b/python/ray/worker.py index bb11cd327..693944cf8 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1621,6 +1621,7 @@ def make_decorator(num_return_vals=None, object_store_memory=None, resources=None, max_calls=None, + max_retries=None, max_reconstructions=None, worker=None): def decorator(function_or_class): @@ -1633,7 +1634,8 @@ def make_decorator(num_return_vals=None, return ray.remote_function.RemoteFunction( function_or_class, num_cpus, num_gpus, memory, - object_store_memory, resources, num_return_vals, max_calls) + object_store_memory, resources, num_return_vals, max_calls, + max_retries) if inspect.isclass(function_or_class): if num_return_vals is not None: @@ -1732,6 +1734,7 @@ def remote(*args, **kwargs): "resources", "max_calls", "max_reconstructions", + "max_retries", ], error_string num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else None @@ -1751,6 +1754,7 @@ def remote(*args, **kwargs): max_reconstructions = kwargs.get("max_reconstructions") memory = kwargs.get("memory") object_store_memory = kwargs.get("object_store_memory") + max_retries = kwargs.get("max_retries") return make_decorator( num_return_vals=num_return_vals, @@ -1761,4 +1765,5 @@ def remote(*args, **kwargs): resources=resources, max_calls=max_calls, max_reconstructions=max_reconstructions, + max_retries=max_retries, worker=worker) diff --git a/rllib/tests/test_optimizers.py b/rllib/tests/test_optimizers.py index 395aaeda6..a5df9b3e6 100644 --- a/rllib/tests/test_optimizers.py +++ b/rllib/tests/test_optimizers.py @@ -183,9 +183,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase): print(stats) self.assertLess(stats["num_steps_sampled"], 5000) replay_ratio = stats["num_steps_replayed"] / stats["num_steps_sampled"] - train_ratio = stats["num_steps_sampled"] / stats["num_steps_trained"] self.assertGreater(replay_ratio, 0.7) - self.assertLess(train_ratio, 0.4) def testMultiTierAggregationBadConf(self): local, remotes = self._make_envs() diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index ee6c4a850..f76539f38 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -165,7 +165,10 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, }, ref_counting_enabled ? reference_counter_ : nullptr, raylet_client_)); - task_manager_.reset(new TaskManager(memory_store_)); + task_manager_.reset( + new TaskManager(memory_store_, [this](const TaskSpecification &spec) { + RAY_CHECK_OK(direct_task_submitter_->SubmitTask(spec)); + })); resolver_.reset(new LocalDependencyResolver(memory_store_)); // Create an entry for the driver task in the task table. This task is @@ -589,7 +592,7 @@ void CoreWorker::PinObjectReferences(const TaskSpecification &task_spec, Status CoreWorker::SubmitTask(const RayFunction &function, const std::vector &args, const TaskOptions &task_options, - std::vector *return_ids) { + std::vector *return_ids, int max_retries) { TaskSpecBuilder builder; const int next_task_index = worker_context_.GetNextTaskIndex(); const auto task_id = @@ -605,7 +608,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function, return_ids); TaskSpecification task_spec = builder.Build(); if (task_options.is_direct_call) { - task_manager_->AddPendingTask(task_spec); + task_manager_->AddPendingTask(task_spec, max_retries); PinObjectReferences(task_spec, TaskTransportType::DIRECT); return direct_task_submitter_->SubmitTask(task_spec); } else { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index b2590ab72..31553405d 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -280,7 +280,8 @@ class CoreWorker { /// \param[out] return_ids Ids of the return objects. /// \return Status error if task submission fails, likely due to raylet failure. Status SubmitTask(const RayFunction &function, const std::vector &args, - const TaskOptions &task_options, std::vector *return_ids); + const TaskOptions &task_options, std::vector *return_ids, + int max_retries); /// Create an actor. /// diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 7e3663f49..67ac09a52 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -2,10 +2,11 @@ namespace ray { -void TaskManager::AddPendingTask(const TaskSpecification &spec) { +void TaskManager::AddPendingTask(const TaskSpecification &spec, int max_retries) { RAY_LOG(DEBUG) << "Adding pending task " << spec.TaskId(); absl::MutexLock lock(&mu_); - RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), spec.NumReturns()).second); + std::pair entry = {spec, max_retries}; + RAY_CHECK(pending_tasks_.emplace(spec.TaskId(), std::move(entry)).second); } void TaskManager::CompletePendingTask(const TaskID &task_id, @@ -48,18 +49,48 @@ void TaskManager::CompletePendingTask(const TaskID &task_id, } } -void TaskManager::FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) { +void TaskManager::PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type) { + if (error_type == rpc::ErrorType::ACTOR_DIED) { + // Note that this might be the __ray_terminate__ task, so we don't log + // loudly with ERROR here. + RAY_LOG(INFO) << "Task " << task_id << " failed with error " + << rpc::ErrorType_Name(error_type); + } else { + RAY_LOG(ERROR) << "Task " << task_id << " failed with error " + << rpc::ErrorType_Name(error_type); + } + RAY_LOG(DEBUG) << "Failing task " << task_id; - int64_t num_returns; + int num_retries_left = 0; + TaskSpecification spec; { absl::MutexLock lock(&mu_); auto it = pending_tasks_.find(task_id); RAY_CHECK(it != pending_tasks_.end()) << "Tried to complete task that was not pending " << task_id; - num_returns = it->second; - pending_tasks_.erase(it); + spec = it->second.first; + num_retries_left = it->second.second; + if (num_retries_left == 0) { + pending_tasks_.erase(it); + } else { + RAY_CHECK(num_retries_left > 0); + it->second.second--; + } } + // We should not hold the lock during these calls because they may trigger + // callbacks in this or other classes. + if (num_retries_left > 0) { + RAY_LOG(ERROR) << num_retries_left << " retries left for task " << spec.TaskId() + << ", attempting to resubmit."; + retry_task_callback_(spec); + } else { + MarkPendingTaskFailed(task_id, spec.NumReturns(), error_type); + } +} + +void TaskManager::MarkPendingTaskFailed(const TaskID &task_id, int64_t num_returns, + rpc::ErrorType error_type) { RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id << ", error_type: " << ErrorType_Name(error_type); for (int i = 0; i < num_returns; i++) { diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 6309a3a48..52b7d3baf 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -18,21 +18,26 @@ class TaskFinisherInterface { virtual void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply) = 0; - virtual void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) = 0; + virtual void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type) = 0; virtual ~TaskFinisherInterface() {} }; +using RetryTaskCallback = std::function; + class TaskManager : public TaskFinisherInterface { public: - TaskManager(std::shared_ptr in_memory_store) - : in_memory_store_(in_memory_store) {} + TaskManager(std::shared_ptr in_memory_store, + RetryTaskCallback retry_task_callback) + : in_memory_store_(in_memory_store), retry_task_callback_(retry_task_callback) {} /// Add a task that is pending execution. /// /// \param[in] spec The spec of the pending task. + /// \param[in] max_retries Number of times this task may be retried + /// on failure. /// \return Void. - void AddPendingTask(const TaskSpecification &spec); + void AddPendingTask(const TaskSpecification &spec, int max_retries = 0); /// Return whether the task is pending. /// @@ -50,23 +55,38 @@ class TaskManager : public TaskFinisherInterface { void CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply) override; - /// Treat a pending task as failed. + /// A pending task failed. This will either retry the task or mark the task + /// as failed if there are no retries left. /// /// \param[in] task_id ID of the pending task. /// \param[in] error_type The type of the specific error. - /// \return Void. - void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override; + void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type) override; private: + /// Treat a pending task as failed. The lock should not be held when calling + /// this method because it may trigger callbacks in this or other classes. + void MarkPendingTaskFailed(const TaskID &task_id, int64_t num_returns, + rpc::ErrorType error_type) LOCKS_EXCLUDED(mu_); + /// Used to store task results. std::shared_ptr in_memory_store_; + /// Called when a task should be retried. + const RetryTaskCallback retry_task_callback_; + /// Protects below fields. absl::Mutex mu_; - /// Map from task ID to the task's number of return values. This map contains - /// one entry per pending task that we submitted. - absl::flat_hash_map pending_tasks_ GUARDED_BY(mu_); + /// Map from task ID to a pair of: + /// {task spec, number of allowed retries left} + /// This map contains one entry per pending task that we submitted. + /// TODO(swang): The TaskSpec protobuf must be copied into the + /// PushTaskRequest protobuf when sent to a worker so that we can retry it if + /// the worker fails. We could avoid this by either not caching the full + /// TaskSpec for tasks that cannot be retried (e.g., actor tasks), or by + /// storing a shared_ptr to a PushTaskRequest protobuf for all tasks. + absl::flat_hash_map> pending_tasks_ + GUARDED_BY(mu_); }; } // namespace ray diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index f9c1c1793..91b81b1ec 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -259,7 +259,8 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res options.is_direct_call = true; std::vector return_ids; - RAY_CHECK_OK(driver.SubmitTask(func, args, options, &return_ids)); + RAY_CHECK_OK( + driver.SubmitTask(func, args, options, &return_ids, /*max_retries=*/0)); ASSERT_EQ(return_ids.size(), 1); diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 769b69523..0f37fa4ab 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -42,7 +42,7 @@ class MockTaskFinisher : public TaskFinisherInterface { MockTaskFinisher() {} MOCK_METHOD2(CompletePendingTask, void(const TaskID &, const rpc::PushTaskReply &)); - MOCK_METHOD2(FailPendingTask, void(const TaskID &task_id, rpc::ErrorType error_type)); + MOCK_METHOD2(PendingTaskFailed, void(const TaskID &task_id, rpc::ErrorType error_type)); }; TaskSpecification CreateActorTaskHelper(ActorID actor_id, int64_t counter) { @@ -86,7 +86,7 @@ TEST_F(DirectActorTransportTest, TestSubmitTask) { EXPECT_CALL(*task_finisher_, CompletePendingTask(TaskID::Nil(), _)) .Times(worker_client_->callbacks.size()); - EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(0); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(_, _)).Times(0); while (!worker_client_->callbacks.empty()) { ASSERT_TRUE(worker_client_->ReplyPushTask()); } @@ -163,7 +163,7 @@ TEST_F(DirectActorTransportTest, TestActorFailure) { ASSERT_EQ(worker_client_->callbacks.size(), 2); // Simulate the actor dying. All submitted tasks should get failed. - EXPECT_CALL(*task_finisher_, FailPendingTask(_, _)).Times(2); + EXPECT_CALL(*task_finisher_, PendingTaskFailed(_, _)).Times(2); EXPECT_CALL(*task_finisher_, CompletePendingTask(_, _)).Times(0); while (!worker_client_->callbacks.empty()) { ASSERT_TRUE(worker_client_->ReplyPushTask(Status::IOError(""))); diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 941ea2012..04da447ad 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -44,7 +44,7 @@ class MockTaskFinisher : public TaskFinisherInterface { void CompletePendingTask(const TaskID &, const rpc::PushTaskReply &) override { num_tasks_complete++; } - void FailPendingTask(const TaskID &task_id, rpc::ErrorType error_type) override { + void PendingTaskFailed(const TaskID &task_id, rpc::ErrorType error_type) override { num_tasks_failed++; } diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index a276980cd..2f8bf6090 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -18,10 +18,14 @@ class TaskManagerTest : public ::testing::Test { public: TaskManagerTest() : store_(std::shared_ptr(new CoreWorkerMemoryStore())), - manager_(store_) {} + manager_(store_, [this](const TaskSpecification &spec) { + num_retries_++; + return Status::OK(); + }) {} std::shared_ptr store_; TaskManager manager_; + int num_retries_ = 0; }; TEST_F(TaskManagerTest, TestTaskSuccess) { @@ -47,6 +51,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { ASSERT_EQ(std::memcmp(results[0]->GetData()->Data(), return_object->data().data(), return_object->data().size()), 0); + ASSERT_EQ(num_retries_, 0); } TEST_F(TaskManagerTest, TestTaskFailure) { @@ -58,7 +63,7 @@ TEST_F(TaskManagerTest, TestTaskFailure) { WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); auto error = rpc::ErrorType::WORKER_DIED; - manager_.FailPendingTask(spec.TaskId(), error); + manager_.PendingTaskFailed(spec.TaskId(), error); ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); std::vector> results; @@ -67,6 +72,36 @@ TEST_F(TaskManagerTest, TestTaskFailure) { rpc::ErrorType stored_error; ASSERT_TRUE(results[0]->IsException(&stored_error)); ASSERT_EQ(stored_error, error); + ASSERT_EQ(num_retries_, 0); +} + +TEST_F(TaskManagerTest, TestTaskRetry) { + auto spec = CreateTaskHelper(1); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + int num_retries = 3; + manager_.AddPendingTask(spec, num_retries); + ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); + WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); + + auto error = rpc::ErrorType::WORKER_DIED; + for (int i = 0; i < num_retries; i++) { + manager_.PendingTaskFailed(spec.TaskId(), error); + ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); + std::vector> results; + ASSERT_FALSE(store_->Get({return_id}, 1, 0, ctx, false, &results).ok()); + ASSERT_EQ(num_retries_, i + 1); + } + + manager_.PendingTaskFailed(spec.TaskId(), error); + ASSERT_FALSE(manager_.IsTaskPending(spec.TaskId())); + + std::vector> results; + RAY_CHECK_OK(store_->Get({return_id}, 1, -0, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + rpc::ErrorType stored_error; + ASSERT_TRUE(results[0]->IsException(&stored_error)); + ASSERT_EQ(stored_error, error); } } // namespace ray diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index ac9b7b8db..1826f512f 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -20,7 +20,10 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe const auto task_id = task_spec.TaskId(); auto request = std::unique_ptr(new rpc::PushTaskRequest); - request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); + // NOTE(swang): CopyFrom is needed because if we use Swap here and the task + // fails, then the task data will be gone when the TaskManager attempts to + // access the task. + request->mutable_task_spec()->CopyFrom(task_spec.GetMessage()); std::unique_lock guard(mutex_); @@ -45,7 +48,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); - task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); + task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED); } }); @@ -85,7 +88,7 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( auto request = std::move(head->second); head = pending_it->second.erase(head); auto task_id = TaskID::FromBinary(request->task_spec().task_id()); - task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); + task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED); } pending_requests_.erase(pending_it); } @@ -123,21 +126,15 @@ void CoreWorkerDirectActorTaskSubmitter::PushActorTask( << "Counter was " << task_number << " expected " << next_sequence_number_[actor_id]; next_sequence_number_[actor_id]++; - auto status = client.PushActorTask( + RAY_CHECK_OK(client.PushActorTask( std::move(request), [this, task_id](Status status, const rpc::PushTaskReply &reply) { if (!status.ok()) { - // Note that this might be the __ray_terminate__ task, so we don't log - // loudly with ERROR here. - RAY_LOG(INFO) << "Task failed with error: " << status; - task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); + task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::ACTOR_DIED); } else { task_finisher_->CompletePendingTask(task_id, reply); } - }); - if (!status.ok()) { - task_finisher_->FailPendingTask(task_id, rpc::ErrorType::ACTOR_DIED); - } + })); } bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) const { diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index f275f5403..171c9bbb6 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -5,7 +5,9 @@ namespace ray { Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { + RAY_LOG(DEBUG) << "Submit task " << task_spec.TaskId(); resolver_.ResolveDependencies(task_spec, [this, task_spec]() { + RAY_LOG(DEBUG) << "Task dependencies resolved " << task_spec.TaskId(); absl::MutexLock lock(&mu_); // Note that the dependencies in the task spec are mutated to only contain // plasma dependencies after ResolveDependencies finishes. @@ -138,11 +140,15 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, const SchedulingKey &scheduling_key, - TaskSpecification &task_spec) { + const TaskSpecification &task_spec) { auto task_id = task_spec.TaskId(); auto request = std::unique_ptr(new rpc::PushTaskRequest); - request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); - auto status = client.PushNormalTask( + RAY_LOG(DEBUG) << "Pushing normal task " << task_spec.TaskId(); + // NOTE(swang): CopyFrom is needed because if we use Swap here and the task + // fails, then the task data will be gone when the TaskManager attempts to + // access the task. + request->mutable_task_spec()->CopyFrom(task_spec.GetMessage()); + RAY_CHECK_OK(client.PushNormalTask( std::move(request), [this, task_id, scheduling_key, addr]( Status status, const rpc::PushTaskReply &reply) { { @@ -150,14 +156,14 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(const rpc::WorkerAddress &add OnWorkerIdle(addr, scheduling_key, /*error=*/!status.ok()); } if (!status.ok()) { - task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED); + // TODO: It'd be nice to differentiate here between process vs node + // failure (e.g., by contacting the raylet). If it was a process + // failure, it may have been an application-level error and it may + // not make sense to retry the task. + task_finisher_->PendingTaskFailed(task_id, rpc::ErrorType::WORKER_DIED); } else { task_finisher_->CompletePendingTask(task_id, reply); } - }); - if (!status.ok()) { - // TODO(swang): add unit test for this. - task_finisher_->FailPendingTask(task_id, rpc::ErrorType::WORKER_DIED); - } + })); } }; // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index e269c26a4..6650b188f 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -76,7 +76,8 @@ class CoreWorkerDirectTaskSubmitter { /// Push a task to a specific worker. void PushNormalTask(const rpc::WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, - const SchedulingKey &task_queue_key, TaskSpecification &task_spec); + const SchedulingKey &task_queue_key, + const TaskSpecification &task_spec); // Client that can be used to lease and return workers from the local raylet. std::shared_ptr local_lease_client_; diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index ee52bb811..2c3e41cf2 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -77,6 +77,8 @@ message TaskSpec { ActorTaskSpec actor_task_spec = 15; // Whether this task is a direct call task. bool is_direct_call = 16; + // Number of times this task may be retried on worker failure. + int32 max_retries = 17; } // Argument in the task.