diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index db6def77e..359cf208a 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -421,10 +421,6 @@ cdef class RayletClient: def job_id(self): return JobID(self.client.GetJobID().Binary()) - @property - def is_worker(self): - return self.client.IsWorker() - cdef deserialize_args( const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &arg_reference_ids): diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index d229d1767..401ebc0f5 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1205,15 +1205,25 @@ def test_get_with_timeout(ray_start_regular): assert ray.get(obj_id, timeout=2) == 3 -def test_direct_call_simple(ray_start_regular): +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_cpus": 1, + "num_nodes": 1, + }, { + "num_cpus": 1, + "num_nodes": 2, + }], + indirect=True) +def test_direct_call_simple(ray_start_cluster): @ray.remote def f(x): return x + 1 f_direct = f.options(is_direct_call=True) assert ray.get(f_direct.remote(2)) == 3 - assert ray.get([f_direct.remote(i) for i in range(100)]) == list( - range(1, 101)) + for _ in range(10): + assert ray.get([f_direct.remote(i) for i in range(100)]) == list( + range(1, 101)) def test_direct_call_refcount(ray_start_regular): @@ -1302,7 +1312,16 @@ def test_direct_call_matrix(shutdown_only): check(source_actor, dest_actor, is_large, out_of_band) -def test_direct_call_chain(ray_start_regular): +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_cpus": 1, + "num_nodes": 1, + }, { + "num_cpus": 1, + "num_nodes": 2, + }], + indirect=True) +def test_direct_call_chain(ray_start_cluster): @ray.remote def g(x): return x + 1 diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index 5095548ee..d725abdd8 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -6,6 +6,7 @@ #include #include + #include "status.h" namespace ray { diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h index 596cd49b5..9b6418f51 100644 --- a/src/ray/common/task/task.h +++ b/src/ray/common/task/task.h @@ -11,6 +11,10 @@ namespace ray { typedef std::function, const std::string &, int)> DispatchTaskCallback; +/// Arguments are the raylet ID to spill back to, the raylet's +/// address and the raylet's port. +typedef std::function + SpillbackTaskCallback; /// \class Task /// @@ -42,12 +46,15 @@ class Task { } /// Override dispatch behaviour. - void OnDispatchInstead( - std::function, const std::string &, int)> - callback) { + void OnDispatchInstead(const DispatchTaskCallback &callback) { on_dispatch_ = callback; } + /// Override spillback behaviour. + void OnSpillbackInstead(const SpillbackTaskCallback &callback) { + on_spillback_ = callback; + } + /// Get the mutable specification for the task. This specification may be /// updated at runtime. /// @@ -73,7 +80,10 @@ class Task { void CopyTaskExecutionSpec(const Task &task); /// Returns the override dispatch task callback, or nullptr. - DispatchTaskCallback &OnDispatch() const { return on_dispatch_; } + const DispatchTaskCallback &OnDispatch() const { return on_dispatch_; } + + /// Returns the override spillback task callback, or nullptr. + const SpillbackTaskCallback &OnSpillback() const { return on_spillback_; } std::string DebugString() const; @@ -95,6 +105,9 @@ class Task { /// For direct task calls, overrides the dispatch behaviour to send an RPC /// back to the submitting worker. mutable DispatchTaskCallback on_dispatch_ = nullptr; + /// For direct task calls, overrides the spillback behaviour to send an RPC + /// back to the submitting worker. + mutable SpillbackTaskCallback on_spillback_ = nullptr; }; } // namespace ray diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index dc16e1f77..5e424363a 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -158,7 +158,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, auto grpc_client = rpc::NodeManagerWorkerClient::make( node_ip_address, node_manager_port, *client_call_manager_); ClientID raylet_id; - raylet_client_ = std::unique_ptr(new RayletClient( + raylet_client_ = std::shared_ptr(new RayletClient( std::move(grpc_client), raylet_socket, WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), @@ -227,11 +227,17 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, direct_task_submitter_ = std::unique_ptr(new CoreWorkerDirectTaskSubmitter( - *raylet_client_, + raylet_client_, [this](WorkerAddress addr) { return std::shared_ptr(new rpc::CoreWorkerClient( addr.first, addr.second, *client_call_manager_)); }, + [this](const rpc::Address &address) { + auto grpc_client = rpc::NodeManagerWorkerClient::make( + address.ip_address(), address.port(), *client_call_manager_); + return std::shared_ptr( + new RayletClient(std::move(grpc_client))); + }, memory_store_provider_)); } @@ -930,14 +936,4 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete( }); } -void CoreWorker::HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request, - rpc::WorkerLeaseGrantedReply *reply, - rpc::SendReplyCallback send_reply_callback) { - // Run this directly since the main thread may be tied up processing a task and - // we need to still continue processing these scheduling operations in the backend. - direct_task_submitter_->HandleWorkerLeaseGranted( - std::make_pair(request.address(), request.port())); - send_reply_callback(Status::OK(), nullptr, nullptr); -} - } // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index bd9800b6b..67cae483f 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -25,11 +25,10 @@ /// 1) Add the rpc to the CoreWorkerService in core_worker.proto, e.g., "ExampleCall" /// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)" /// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall" -#define RAY_CORE_WORKER_RPC_HANDLERS \ - RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \ - RAY_CORE_WORKER_RPC_HANDLER(PushTask, 9999) \ - RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) \ - RAY_CORE_WORKER_RPC_HANDLER(WorkerLeaseGranted, 5) +#define RAY_CORE_WORKER_RPC_HANDLERS \ + RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \ + RAY_CORE_WORKER_RPC_HANDLER(PushTask, 9999) \ + RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) namespace ray { @@ -355,11 +354,6 @@ class CoreWorker { rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback); - /// Implements gRPC server handler. - void HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request, - rpc::WorkerLeaseGrantedReply *reply, - rpc::SendReplyCallback send_reply_callback); - private: /// Run the io_service_ event loop. This should be called in a background thread. void RunIOService(); @@ -485,8 +479,11 @@ class CoreWorker { // Client to the GCS shared by core worker interfaces. std::shared_ptr gcs_client_; - // Client to the raylet shared by core worker interfaces. - std::unique_ptr raylet_client_; + // Client to the raylet shared by core worker interfaces. This needs to be a + // shared_ptr for direct calls because we can lease multiple workers through + // one client, and we need to keep the connection alive until we return all + // of the workers. + std::shared_ptr raylet_client_; // Thread that runs a boost::asio service to process IO events. std::thread io_thread_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index dcceb2322..9df34f040 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -7,7 +7,7 @@ namespace ray { CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( - const std::string &store_socket, const std::unique_ptr &raylet_client, + const std::string &store_socket, const std::shared_ptr raylet_client, std::function check_signals) : raylet_client_(raylet_client) { check_signals_ = check_signals; @@ -27,7 +27,6 @@ Status CoreWorkerPlasmaStoreProvider::SetClientOptions(std::string name, Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object, const ObjectID &object_id) { - RAY_CHECK(!object_id.IsDirectCallType()); std::shared_ptr data; RAY_RETURN_NOT_OK(Create(object.GetMetadata(), object.HasData() ? object.GetData()->Size() : 0, object_id, @@ -47,6 +46,7 @@ Status CoreWorkerPlasmaStoreProvider::Create(const std::shared_ptr &meta const size_t data_size, const ObjectID &object_id, std::shared_ptr *data) { + RAY_CHECK(!object_id.IsDirectCallType()); auto plasma_id = object_id.ToPlasmaId(); std::shared_ptr arrow_buffer; { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 5b94cd536..c908da8cc 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -19,7 +19,7 @@ namespace ray { class CoreWorkerPlasmaStoreProvider { public: CoreWorkerPlasmaStoreProvider(const std::string &store_socket, - const std::unique_ptr &raylet_client, + const std::shared_ptr raylet_client, std::function check_signals); ~CoreWorkerPlasmaStoreProvider(); @@ -80,7 +80,7 @@ class CoreWorkerPlasmaStoreProvider { static void WarnIfAttemptedTooManyTimes(int num_attempts, const absl::flat_hash_set &remaining); - const std::unique_ptr &raylet_client_; + const std::shared_ptr raylet_client_; plasma::PlasmaClient store_client_; std::mutex store_client_mutex_; std::function check_signals_; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 34450bb4a..e3a8a7af6 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -227,6 +227,7 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res RayFunction func(ray::Language::PYTHON, {}); TaskOptions options; + options.is_direct_call = true; std::vector return_ids; RAY_CHECK_OK(driver.SubmitTask(func, args, options, &return_ids)); diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index ff03163a8..4ffcb93bf 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -24,18 +24,47 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { class MockRayletClient : public WorkerLeaseInterface { public: - ray::Status ReturnWorker(int worker_port) { + ray::Status ReturnWorker(int worker_port) override { num_workers_returned += 1; return Status::OK(); } - ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) { + ray::Status RequestWorkerLease( + const ray::TaskSpecification &resource_spec, + const rpc::ClientCallback &callback) override { num_workers_requested += 1; + callbacks.push_back(callback); return Status::OK(); } + // Trigger reply to RequestWorkerLease. + bool GrantWorkerLease(const std::string &address, int port, + const ClientID &retry_at_raylet_id) { + rpc::WorkerLeaseReply reply; + if (!retry_at_raylet_id.IsNil()) { + reply.mutable_retry_at_raylet_address()->set_ip_address(address); + reply.mutable_retry_at_raylet_address()->set_port(port); + reply.mutable_retry_at_raylet_address()->set_raylet_id(retry_at_raylet_id.Binary()); + } else { + reply.mutable_worker_address()->set_ip_address(address); + reply.mutable_worker_address()->set_port(port); + reply.mutable_worker_address()->set_raylet_id(retry_at_raylet_id.Binary()); + } + if (callbacks.size() == 0) { + return false; + } else { + auto callback = callbacks.front(); + callback(Status::OK(), reply); + callbacks.pop_front(); + return true; + } + } + + ~MockRayletClient() {} + int num_workers_requested = 0; int num_workers_returned = 0; + std::list> callbacks = {}; }; TEST(TestMemoryStore, TestPromoteToPlasma) { @@ -159,52 +188,52 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { } TEST(DirectTaskTransportTest, TestSubmitOneTask) { - MockRayletClient raylet_client; + auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); ASSERT_TRUE(submitter.SubmitTask(task).ok()); - ASSERT_EQ(raylet_client.num_workers_requested, 1); - ASSERT_EQ(raylet_client.num_workers_returned, 0); + ASSERT_EQ(raylet_client->num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(worker_client->callbacks.size(), 0); - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); - ASSERT_EQ(raylet_client.num_workers_returned, 1); + ASSERT_EQ(raylet_client->num_workers_returned, 1); } TEST(DirectTaskTransportTest, TestHandleTaskFailure) { - MockRayletClient raylet_client; + auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); ASSERT_TRUE(submitter.SubmitTask(task).ok()); - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, ClientID::Nil())); // Simulate a system failure, i.e., worker died unexpectedly. worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply()); ASSERT_EQ(worker_client->callbacks.size(), 1); - ASSERT_EQ(raylet_client.num_workers_returned, 1); + ASSERT_EQ(raylet_client->num_workers_returned, 1); } TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { - MockRayletClient raylet_client; + auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; TaskSpecification task2; TaskSpecification task3; @@ -215,37 +244,37 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task3).ok()); - ASSERT_EQ(raylet_client.num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_requested, 1); // Task 1 is pushed; worker 2 is requested. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); - ASSERT_EQ(raylet_client.num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 2 is pushed; worker 3 is requested. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 2); - ASSERT_EQ(raylet_client.num_workers_requested, 3); + ASSERT_EQ(raylet_client->num_workers_requested, 3); // Task 3 is pushed; no more workers requested. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1002)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 3); - ASSERT_EQ(raylet_client.num_workers_requested, 3); + ASSERT_EQ(raylet_client->num_workers_requested, 3); // All workers returned. for (const auto &cb : worker_client->callbacks) { cb(Status::OK(), rpc::PushTaskReply()); } - ASSERT_EQ(raylet_client.num_workers_returned, 3); + ASSERT_EQ(raylet_client->num_workers_returned, 3); } TEST(DirectTaskTransportTest, TestReuseWorkerLease) { - MockRayletClient raylet_client; + auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; TaskSpecification task2; TaskSpecification task3; @@ -256,39 +285,39 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task3).ok()); - ASSERT_EQ(raylet_client.num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_requested, 1); // Task 1 is pushed. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); - ASSERT_EQ(raylet_client.num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 1 finishes, Task 2 is scheduled on the same worker. worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); ASSERT_EQ(worker_client->callbacks.size(), 2); - ASSERT_EQ(raylet_client.num_workers_returned, 0); + ASSERT_EQ(raylet_client->num_workers_returned, 0); // Task 2 finishes, Task 3 is scheduled on the same worker. worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); ASSERT_EQ(worker_client->callbacks.size(), 3); - ASSERT_EQ(raylet_client.num_workers_returned, 0); + ASSERT_EQ(raylet_client->num_workers_returned, 0); // Task 3 finishes, the worker is returned. worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply()); - ASSERT_EQ(raylet_client.num_workers_returned, 1); + ASSERT_EQ(raylet_client->num_workers_returned, 1); // The second lease request is returned immediately. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001)); - ASSERT_EQ(raylet_client.num_workers_returned, 2); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil())); + ASSERT_EQ(raylet_client->num_workers_returned, 2); } TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { - MockRayletClient raylet_client; + auto raylet_client = std::make_shared(); auto worker_client = std::shared_ptr(new MockWorkerClient()); auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); auto store = std::make_shared(ptr); auto factory = [&](WorkerAddress addr) { return worker_client; }; - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store); TaskSpecification task1; TaskSpecification task2; task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); @@ -296,23 +325,67 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok()); - ASSERT_EQ(raylet_client.num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_requested, 1); // Task 1 is pushed. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); - ASSERT_EQ(raylet_client.num_workers_requested, 2); + ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 1 finishes with failure; the worker is returned. worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply()); ASSERT_EQ(worker_client->callbacks.size(), 1); - ASSERT_EQ(raylet_client.num_workers_returned, 1); + ASSERT_EQ(raylet_client->num_workers_returned, 1); // Task 2 runs successfully on the second worker. - submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001)); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, ClientID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 2); worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); - ASSERT_EQ(raylet_client.num_workers_returned, 2); + ASSERT_EQ(raylet_client->num_workers_returned, 2); +} + +TEST(DirectTaskTransportTest, TestSpillback) { + auto raylet_client = std::make_shared(); + auto worker_client = std::shared_ptr(new MockWorkerClient()); + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + auto store = std::make_shared(ptr); + auto factory = [&](WorkerAddress addr) { return worker_client; }; + + std::unordered_map> remote_lease_clients; + auto lease_client_factory = [&](const rpc::Address &addr) { + ClientID raylet_id = ClientID::FromBinary(addr.raylet_id()); + // We should not create a connection to the same raylet more than once. + RAY_CHECK(remote_lease_clients.count(raylet_id) == 0); + auto client = std::make_shared(); + remote_lease_clients[raylet_id] = client; + return client; + }; + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory, + store); + TaskSpecification task; + task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_EQ(raylet_client->num_workers_requested, 1); + ASSERT_EQ(raylet_client->num_workers_returned, 0); + ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(remote_lease_clients.size(), 0); + + // Spillback to a remote node. + auto remote_raylet_id = ClientID::FromRandom(); + ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1234, remote_raylet_id)); + ASSERT_EQ(remote_lease_clients.count(remote_raylet_id), 1); + // There should be no more callbacks on the local client. + ASSERT_FALSE(raylet_client->GrantWorkerLease("remote", 1234, ClientID::Nil())); + // Trigger retry at the remote node. + ASSERT_TRUE(remote_lease_clients[remote_raylet_id]->GrantWorkerLease("remote", 1234, + ClientID::Nil())); + ASSERT_EQ(worker_client->callbacks.size(), 1); + + // The worker is returned to the remote node, not the local one. + worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); + ASSERT_EQ(raylet_client->num_workers_returned, 0); + ASSERT_EQ(remote_lease_clients[remote_raylet_id]->num_workers_returned, 1); } } // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index e94d7dc8a..069c7154c 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -84,16 +84,14 @@ Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { resolver_.ResolveDependencies(task_spec, [this, task_spec]() { // TODO(ekl) should have a queue per distinct resource type required absl::MutexLock lock(&mu_); - RequestNewWorkerIfNeeded(task_spec); queued_tasks_.push_back(task_spec); - // The task is now queued and will be picked up by the next leased or newly - // idle worker. We are guaranteed a worker will show up since we called - // RequestNewWorkerIfNeeded() earlier while holding mu_. + RequestNewWorkerIfNeeded(task_spec); }); return Status::OK(); } -void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(const WorkerAddress addr) { +void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted( + const WorkerAddress &addr, std::shared_ptr lease_client) { // Setup client state for this worker. { absl::MutexLock lock(&mu_); @@ -105,6 +103,7 @@ void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(const WorkerAddress std::shared_ptr(client_factory_(addr)); RAY_LOG(INFO) << "Connected to " << addr.first << ":" << addr.second; } + worker_to_lease_client_[addr] = std::move(lease_client); } // Try to assign it work. @@ -115,24 +114,88 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr, bool was_error) { absl::MutexLock lock(&mu_); if (queued_tasks_.empty() || was_error) { - RAY_CHECK_OK(lease_client_.ReturnWorker(addr.second)); + auto lease_client = std::move(worker_to_lease_client_[addr]); + worker_to_lease_client_.erase(addr); + RAY_CHECK_OK(lease_client->ReturnWorker(addr.second)); } else { auto &client = *client_cache_[addr]; PushNormalTask(addr, client, queued_tasks_.front()); queued_tasks_.pop_front(); } - // We have a queue of tasks, try to request more workers. - if (!queued_tasks_.empty()) { - RequestNewWorkerIfNeeded(queued_tasks_.front()); + RequestNewWorkerIfNeeded(queued_tasks_.front()); +} + +std::shared_ptr +CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient( + const rpc::Address *raylet_address) { + std::shared_ptr lease_client; + if (raylet_address) { + // Connect to raylet. + ClientID raylet_id = ClientID::FromBinary(raylet_address->raylet_id()); + auto it = remote_lease_clients_.find(raylet_id); + if (it == remote_lease_clients_.end()) { + RAY_LOG(DEBUG) << "Connecting to raylet " << raylet_id; + it = + remote_lease_clients_.emplace(raylet_id, lease_client_factory_(*raylet_address)) + .first; + } + lease_client = it->second; + } else { + lease_client = local_lease_client_; } + + return lease_client; } void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( - const TaskSpecification &resource_spec) { + const TaskSpecification &resource_spec, const rpc::Address *raylet_address) { if (worker_request_pending_) { return; } - RAY_CHECK_OK(lease_client_.RequestWorkerLease(resource_spec)); + if (queued_tasks_.empty()) { + // We don't have any tasks to run, so no need to request a worker. + return; + } + + // NOTE(swang): We must copy the resource spec here because the resource spec + // may get swapped out by the time the callback fires. If we change this so + // that we associate the granted worker with the requested resource spec, + // then we can just pass the ref instead of copying. + TaskSpecification resource_spec_copy(resource_spec.GetMessage()); + auto lease_client = GetOrConnectLeaseClient(raylet_address); + RAY_CHECK_OK(lease_client->RequestWorkerLease( + resource_spec_copy, + [this, resource_spec_copy, lease_client]( + const Status &status, const rpc::WorkerLeaseReply &reply) mutable { + if (status.ok()) { + if (!reply.worker_address().raylet_id().empty()) { + RAY_LOG(DEBUG) << "Lease granted " << resource_spec_copy.TaskId(); + HandleWorkerLeaseGranted( + {reply.worker_address().ip_address(), reply.worker_address().port()}, + std::move(lease_client)); + } else { + absl::MutexLock lock(&mu_); + worker_request_pending_ = false; + RequestNewWorkerIfNeeded(resource_spec_copy, + &reply.retry_at_raylet_address()); + } + } else { + RAY_LOG(DEBUG) << "Retrying lease request " << resource_spec_copy.TaskId(); + absl::MutexLock lock(&mu_); + worker_request_pending_ = false; + if (lease_client != local_lease_client_) { + // A remote request failed. Retry the worker lease request locally + // if it's still in the queue. + // TODO(swang): Fail after some number of retries? + RAY_LOG(ERROR) << "Retrying attempt to schedule task at remote node. Error: " + << status.ToString(); + RequestNewWorkerIfNeeded(resource_spec_copy); + } else { + RAY_LOG(FATAL) << "Lost connection with local raylet. Error: " + << status.ToString(); + } + } + })); worker_request_pending_ = true; } diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 0bb708a23..fabedc903 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -54,15 +54,19 @@ class LocalDependencyResolver { typedef std::pair WorkerAddress; typedef std::function(WorkerAddress)> ClientFactoryFn; +typedef std::function(const rpc::Address &)> + LeaseClientFactoryFn; // This class is thread-safe. class CoreWorkerDirectTaskSubmitter { public: CoreWorkerDirectTaskSubmitter( - WorkerLeaseInterface &lease_client, ClientFactoryFn client_factory, + std::shared_ptr lease_client, ClientFactoryFn client_factory, + LeaseClientFactoryFn lease_client_factory, std::shared_ptr store_provider) - : lease_client_(lease_client), + : local_lease_client_(lease_client), client_factory_(client_factory), + lease_client_factory_(lease_client_factory), in_memory_store_(store_provider), resolver_(in_memory_store_) {} @@ -71,34 +75,49 @@ class CoreWorkerDirectTaskSubmitter { /// \param[in] task_spec The task to schedule. Status SubmitTask(TaskSpecification task_spec); - /// Callback for when the raylet grants us a worker lease. The worker is returned - /// to the raylet once it finishes its task and either the lease term has - /// expired, or there is no more work it can take on. - /// - /// \param[in] addr The (addr, port) pair identifying the worker. - void HandleWorkerLeaseGranted(const WorkerAddress addr); - private: /// Schedule more work onto an idle worker or return it back to the raylet if /// no more tasks are queued for submission. If an error was encountered /// processing the worker, we don't attempt to re-use the worker. void OnWorkerIdle(const WorkerAddress &addr, bool was_error); + /// Get an existing lease client or connect a new one. If a raylet_address is + /// provided, this connects to a remote raylet. Else, this connects to the + /// local raylet. + std::shared_ptr GetOrConnectLeaseClient( + const rpc::Address *raylet_address) EXCLUSIVE_LOCKS_REQUIRED(mu_); + /// Request a new worker from the raylet if no such requests are currently in - /// flight. - void RequestNewWorkerIfNeeded(const TaskSpecification &resource_spec) + /// flight and there are tasks queued. If a raylet address is provided, then + /// the worker should be requested from the raylet at that address. Else, the + /// worker should be requested from the local raylet. + void RequestNewWorkerIfNeeded(const TaskSpecification &resource_spec, + const rpc::Address *raylet_address = nullptr) EXCLUSIVE_LOCKS_REQUIRED(mu_); + /// Callback for when the raylet grants us a worker lease. The worker is returned + /// to the raylet via the given lease client once the task queue is empty. + /// TODO: Implement a lease term by which we need to return the worker. + void HandleWorkerLeaseGranted(const WorkerAddress &addr, + std::shared_ptr lease_client); + /// Push a task to a specific worker. void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, TaskSpecification &task_spec); - // Client that can be used to lease and return workers. - WorkerLeaseInterface &lease_client_; + // Client that can be used to lease and return workers from the local raylet. + std::shared_ptr local_lease_client_; + + /// Cache of gRPC clients to remote raylets. + absl::flat_hash_map> + remote_lease_clients_ GUARDED_BY(mu_); /// Factory for producing new core worker clients. ClientFactoryFn client_factory_; + /// Factory for producing new clients to request leases from remote nodes. + LeaseClientFactoryFn lease_client_factory_; + /// The store provider. std::shared_ptr in_memory_store_; @@ -112,6 +131,11 @@ class CoreWorkerDirectTaskSubmitter { absl::flat_hash_map> client_cache_ GUARDED_BY(mu_); + /// Map from worker address to the lease client through which it should be + /// returned. + absl::flat_hash_map> + worker_to_lease_client_ GUARDED_BY(mu_); + // Whether we have a request to the Raylet to acquire a new worker in flight. bool worker_request_pending_ GUARDED_BY(mu_) = false; diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 27b3f8bdb..d1709df12 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -6,7 +6,7 @@ namespace ray { CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver( - std::unique_ptr &raylet_client, const TaskHandler &task_handler, + std::shared_ptr &raylet_client, const TaskHandler &task_handler, const std::function &exit_handler) : raylet_client_(raylet_client), task_handler_(task_handler), diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 3d05abe53..b4de32b24 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -15,7 +15,7 @@ class CoreWorkerRayletTaskReceiver { const TaskSpecification &task_spec, const ResourceMappingType &resource_ids, std::vector> *return_objects)>; - CoreWorkerRayletTaskReceiver(std::unique_ptr &raylet_client, + CoreWorkerRayletTaskReceiver(std::shared_ptr &raylet_client, const TaskHandler &task_handler, const std::function &exit_handler); @@ -32,7 +32,7 @@ class CoreWorkerRayletTaskReceiver { private: /// Raylet client. - std::unique_ptr &raylet_client_; + std::shared_ptr &raylet_client_; /// The callback function to process a task. TaskHandler task_handler_; /// The callback function to exit the worker. diff --git a/src/ray/gcs/redis_gcs_client_test.cc b/src/ray/gcs/redis_gcs_client_test.cc index dc26e32ec..f4150d44b 100644 --- a/src/ray/gcs/redis_gcs_client_test.cc +++ b/src/ray/gcs/redis_gcs_client_test.cc @@ -1159,7 +1159,7 @@ void ClientTableNotification(gcs::RedisGcsClient *client, const ClientID &client ASSERT_EQ(data.state() == GcsNodeInfo::ALIVE, is_alive); GcsNodeInfo cached_client; - client->client_table().GetClient(added_id, cached_client); + ASSERT_TRUE(client->client_table().GetClient(added_id, &cached_client)); ASSERT_EQ(ClientID::FromBinary(cached_client.node_id()), added_id); ASSERT_EQ(cached_client.state() == GcsNodeInfo::ALIVE, is_alive); } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index cc7399d01..f6537840f 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -679,14 +679,14 @@ ray::Status ClientTable::MarkDisconnected(const ClientID &dead_node_id) { return Append(JobID::Nil(), client_log_key_, node_info, nullptr); } -void ClientTable::GetClient(const ClientID &node_id, GcsNodeInfo &node_info) const { +bool ClientTable::GetClient(const ClientID &node_id, GcsNodeInfo *node_info) const { RAY_CHECK(!node_id.IsNil()); auto entry = node_cache_.find(node_id); - if (entry != node_cache_.end()) { - node_info = entry->second; - } else { - node_info.set_node_id(ClientID::Nil().Binary()); + auto found = (entry != node_cache_.end()); + if (found) { + *node_info = entry->second; } + return found; } const std::unordered_map &ClientTable::GetAllClients() const { diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 919ff24ea..511ba4fd4 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -881,11 +881,11 @@ class ClientTable : public Log { /// information for clients that we've heard a notification for. /// /// \param client The client to get information about. - /// \param node_info A reference to the client information. If we have information - /// about the client in the cache, then the reference will be modified to - /// contain that information. Else, the reference will be updated to contain + /// \param node_info The client information will be copied here if + /// we have the client in the cache. /// a nil client ID. - void GetClient(const ClientID &client, GcsNodeInfo &node_info) const; + /// \return Whether teh client is in the cache. + bool GetClient(const ClientID &client, GcsNodeInfo *node_info) const; /// Get the local client's ID. /// diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 0a1376bc1..e08466876 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -107,9 +107,10 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { GcsNodeInfo node_info; - gcs_client_->client_table().GetClient(connection_info.client_id, node_info); + bool found = + gcs_client_->client_table().GetClient(connection_info.client_id, &node_info); ClientID result_client_id = ClientID::FromBinary(node_info.node_id()); - if (!result_client_id.IsNil()) { + if (found) { RAY_CHECK(result_client_id == connection_info.client_id); if (node_info.state() == GcsNodeInfo::ALIVE) { connection_info.ip = node_info.node_manager_address(); diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 1e564ba41..013c5172d 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -412,12 +412,12 @@ class StressTestObjectManager : public TestObjectManagerBase { << "All connected clients:" << "\n"; GcsNodeInfo data; - gcs_client_1->client_table().GetClient(client_id_1, data); + ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_1, &data)); RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.node_id()) << "\n" << "ClientIp=" << data.node_manager_address() << "\n" << "ClientPort=" << data.node_manager_port(); GcsNodeInfo data2; - gcs_client_1->client_table().GetClient(client_id_2, data2); + ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_2, &data2)); RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.node_id()) << "\n" << "ClientIp=" << data2.node_manager_address() << "\n" << "ClientPort=" << data2.node_manager_port(); diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 2d0bea8fa..3a1626c93 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -434,14 +434,14 @@ class TestObjectManager : public TestObjectManagerBase { << "Server client ids:" << "\n"; GcsNodeInfo data; - gcs_client_1->client_table().GetClient(client_id_1, data); + ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_1, &data)); RAY_LOG(DEBUG) << (ClientID::FromBinary(data.node_id()).IsNil()); RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.node_id()); RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); ASSERT_EQ(client_id_1, ClientID::FromBinary(data.node_id())); GcsNodeInfo data2; - gcs_client_1->client_table().GetClient(client_id_2, data2); + ASSERT_TRUE(gcs_client_1->client_table().GetClient(client_id_2, &data2)); RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.node_id()); RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 68daabc9a..7b68c88f8 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -27,6 +27,7 @@ enum TaskType { ACTOR_TASK = 2; } +// Address of a worker or node manager. message Address { bytes raylet_id = 1; string ip_address = 2; diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 2ad4b2623..323624feb 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -87,16 +87,6 @@ message DirectActorCallArgWaitCompleteRequest { message DirectActorCallArgWaitCompleteReply { } -message WorkerLeaseGrantedRequest { - // Address of the leased worker. - string address = 1; - // Port of the leased worker. - int32 port = 2; -} - -message WorkerLeaseGrantedReply { -} - service CoreWorkerService { // Push a task to a worker from the raylet. rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply); @@ -105,6 +95,4 @@ service CoreWorkerService { // Reply from raylet that wait for direct actor call args has completed. rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest) returns (DirectActorCallArgWaitCompleteReply); - // Reply from raylet to fulfill a worker lease request. - rpc WorkerLeaseGranted(WorkerLeaseGrantedRequest) returns (WorkerLeaseGrantedReply); } diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index d68a9afe1..6b33a1cc8 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -12,6 +12,29 @@ message SubmitTaskRequest { message SubmitTaskReply { } +// Request a worker from the raylet with the specified resources. +message WorkerLeaseRequest { + // Task containing the requested resources. + TaskSpec resource_spec = 1; +} + +message WorkerLeaseReply { + // Address of the leased worker. If this is empty, then the request should be + // retried at the provided raylet address. + Address worker_address = 1; + // Address of the raylet to spill back to, if any. + Address retry_at_raylet_address = 2; +} + +// Release a worker back to its raylet. +message ReturnWorkerRequest { + // Port of the leased worker that we are now returning. + int32 worker_port = 1; +} + +message ReturnWorkerReply { +} + message ForwardTaskRequest { // The ID of the task to be forwarded. bytes task_id = 1; @@ -66,6 +89,10 @@ message NodeStatsReply { service NodeManagerService { // Submit a task (from a local or remote worker) to the node manager. rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply); + // Request a worker from the raylet. + rpc RequestWorkerLease(WorkerLeaseRequest) returns (WorkerLeaseReply); + // Release a worker back to its raylet. + rpc ReturnWorker(ReturnWorkerRequest) returns (ReturnWorkerReply); // Forward a task and its uncommitted lineage to the remote node manager. rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); // Get the current node stats. diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 94bda2f89..91694b7e6 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -74,10 +74,6 @@ enum MessageType:int { SetResourceRequest, // Update the active set of object IDs in use on this worker. ReportActiveObjectIDs, - // Request a worker from the raylet with the specified resources. - RequestWorkerLease, - // Returns a worker to the raylet. - ReturnWorker, } table TaskExecutionSpecification { @@ -95,14 +91,6 @@ table Task { task_execution_spec: TaskExecutionSpecification; } -table WorkerLeaseRequest { - resource_spec: string; -} - -table ReturnWorkerRequest { - worker_port: int; -} - // This message describes a given resource that is reserved for a worker. table ResourceIdSetInfo { // The name of the resource. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index dbe6a9e30..f1511327a 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -896,12 +896,6 @@ void NodeManager::ProcessClientMessage( // because it's already disconnected. return; } break; - case protocol::MessageType::RequestWorkerLease: { - ProcessRequestWorkerLeaseMessage(client, message_data); - } break; - case protocol::MessageType::ReturnWorker: { - ProcessReturnWorkerMessage(message_data); - } break; case protocol::MessageType::SetResourceRequest: { ProcessSetResourceRequest(client, message_data); } break; @@ -1197,43 +1191,6 @@ void NodeManager::ProcessDisconnectClientMessage( // these can be leaked. } -void NodeManager::ProcessRequestWorkerLeaseMessage( - const std::shared_ptr &client, const uint8_t *message_data) { - // Read the resource spec submitted by the client. - auto fbs_message = flatbuffers::GetRoot(message_data); - rpc::Task task_message; - RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray( - fbs_message->resource_spec()->data(), fbs_message->resource_spec()->size())); - - // Override the task dispatch to call back to the client instead of executing the - // task directly on the worker. TODO(ekl) handle spilling case - Task task(task_message); - task.OnDispatchInstead([this, client](const std::shared_ptr granted, - const std::string &address, int port) { - std::shared_ptr client_worker = worker_pool_.GetRegisteredWorker(client); - if (client_worker == nullptr) { - client_worker = worker_pool_.GetRegisteredDriver(client); - } - if (client_worker == nullptr) { - RAY_LOG(FATAL) << "TODO: Lost worker for lease request " << client; - } else { - client_worker->WorkerLeaseGranted(address, port); - leased_workers_[port] = std::static_pointer_cast(granted); - } - }); - SubmitTask(task, Lineage()); -} - -void NodeManager::ProcessReturnWorkerMessage(const uint8_t *message_data) { - // Read the resource spec submitted by the client. - auto fbs_message = flatbuffers::GetRoot(message_data); - auto worker_port = fbs_message->worker_port(); - RAY_LOG(DEBUG) << "Return worker " << worker_port; - std::shared_ptr worker = leased_workers_[worker_port]; - leased_workers_.erase(worker_port); - HandleWorkerAvailable(worker); -} - void NodeManager::ProcessFetchOrReconstructMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); @@ -1442,6 +1399,11 @@ void NodeManager::HandleSubmitTask(const rpc::SubmitTaskRequest &request, rpc::SendReplyCallback send_reply_callback) { rpc::Task task; task.mutable_task_spec()->CopyFrom(request.task_spec()); + // Set the caller's node ID. + if (task.task_spec().caller_address().raylet_id() == "") { + task.mutable_task_spec()->mutable_caller_address()->set_raylet_id( + gcs_client_->client_table().GetLocalClientId().Binary()); + } // Submit the task to the raylet. Since the task was submitted // locally, there is no uncommitted lineage. @@ -1449,6 +1411,61 @@ void NodeManager::HandleSubmitTask(const rpc::SubmitTaskRequest &request, send_reply_callback(Status::OK(), nullptr, nullptr); } +void NodeManager::HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &request, + rpc::WorkerLeaseReply *reply, + rpc::SendReplyCallback send_reply_callback) { + rpc::Task task_message; + task_message.mutable_task_spec()->CopyFrom(request.resource_spec()); + + // Override the task dispatch to call back to the client instead of executing the + // task directly on the worker. + Task task(task_message); + RAY_LOG(DEBUG) << "Worker lease request " << task.GetTaskSpecification().TaskId(); + TaskID task_id = task.GetTaskSpecification().TaskId(); + task.OnDispatchInstead( + [this, task_id, reply, send_reply_callback](const std::shared_ptr granted, + const std::string &address, int port) { + RAY_LOG(DEBUG) << "Worker lease request DISPATCH " << task_id; + reply->mutable_worker_address()->set_ip_address(address); + reply->mutable_worker_address()->set_port(port); + reply->mutable_worker_address()->set_raylet_id( + gcs_client_->client_table().GetLocalClientId().Binary()); + send_reply_callback(Status::OK(), nullptr, nullptr); + + // TODO(swang): Kill worker if other end hangs up. + // TODO(swang): Implement a lease term by which the owner needs to return the + // worker. + leased_workers_[port] = std::static_pointer_cast(granted); + }); + task.OnSpillbackInstead( + [reply, task_id, send_reply_callback](const ClientID &spillback_to, + const std::string &address, int port) { + RAY_LOG(DEBUG) << "Worker lease request SPILLBACK " << task_id; + reply->mutable_retry_at_raylet_address()->set_ip_address(address); + reply->mutable_retry_at_raylet_address()->set_port(port); + reply->mutable_retry_at_raylet_address()->set_raylet_id(spillback_to.Binary()); + send_reply_callback(Status::OK(), nullptr, nullptr); + }); + SubmitTask(task, Lineage()); +} + +void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request, + rpc::ReturnWorkerReply *reply, + rpc::SendReplyCallback send_reply_callback) { + // Read the resource spec submitted by the client. + auto worker_port = request.worker_port(); + RAY_LOG(DEBUG) << "Return worker " << worker_port; + std::shared_ptr worker = std::move(leased_workers_[worker_port]); + leased_workers_.erase(worker_port); + Status status; + if (worker) { + HandleWorkerAvailable(worker); + } else { + status = Status::Invalid("Returned worker does not exist"); + } + send_reply_callback(status, nullptr, nullptr); +} + void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, rpc::ForwardTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { @@ -2480,6 +2497,17 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, void NodeManager::ForwardTask( const Task &task, const ClientID &node_id, const std::function &on_error) { + // Override spillback for direct tasks. + if (task.OnSpillback() != nullptr) { + GcsNodeInfo node_info; + bool found = gcs_client_->client_table().GetClient(node_id, &node_info); + RAY_CHECK(found) << "Spilling back to a node manager, but no GCS info found for node " + << node_id; + task.OnSpillback()(node_id, node_info.node_manager_address(), + node_info.node_manager_port()); + return; + } + // Lookup node manager client for this node_id and use it to send the request. auto client_entry = remote_node_manager_clients_.find(node_id); if (client_entry == remote_node_manager_clients_.end()) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 84f331552..efea4246d 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -412,20 +412,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { const std::shared_ptr &client, bool intentional_disconnect = false); - /// Process client message of RequestWorkerLease - /// - /// \param client The client that sent the message. - /// \param message_data A pointer to the message data. - /// \return Void. - void ProcessRequestWorkerLeaseMessage( - const std::shared_ptr &client, const uint8_t *message_data); - - /// Process client message of ReturnWorkerMessage - /// - /// \param message_data A pointer to the message data. - /// \return Void. - void ProcessReturnWorkerMessage(const uint8_t *message_data); - /// Process client message of FetchOrReconstruct /// /// \param client The client that sent the message. @@ -514,6 +500,16 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::SubmitTaskReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `WorkerLease` request. + void HandleWorkerLeaseRequest(const rpc::WorkerLeaseRequest &request, + rpc::WorkerLeaseReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + /// Handle a `ReturnWorker` request. + void HandleReturnWorker(const rpc::ReturnWorkerRequest &request, + rpc::ReturnWorkerReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `ForwardTask` request. void HandleForwardTask(const rpc::ForwardTaskRequest &request, rpc::ForwardTaskReply *reply, diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 80ef91517..9803a44ae 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -201,15 +201,14 @@ ray::Status RayletConnection::AtomicRequestReply( return ReadMessage(reply_type, reply_message); } +RayletClient::RayletClient(std::shared_ptr grpc_client) + : grpc_client_(std::move(grpc_client)) {} + RayletClient::RayletClient(std::shared_ptr grpc_client, const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, ClientID *raylet_id, int port) - : grpc_client_(std::move(grpc_client)), - worker_id_(worker_id), - is_worker_(is_worker), - job_id_(job_id), - language_(language) { + : grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) { // For C++14, we could use std::make_unique conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); @@ -381,17 +380,18 @@ ray::Status RayletClient::ReportActiveObjectIDs( } ray::Status RayletClient::RequestWorkerLease( - const ray::TaskSpecification &resource_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWorkerLeaseRequest( - fbb, fbb.CreateString(resource_spec.Serialize())); - fbb.Finish(message); - return conn_->WriteMessage(MessageType::RequestWorkerLease, &fbb); + const ray::TaskSpecification &resource_spec, + const ray::rpc::ClientCallback &callback) { + ray::rpc::WorkerLeaseRequest request; + request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage()); + return grpc_client_->RequestWorkerLease(request, callback); } ray::Status RayletClient::ReturnWorker(int worker_port) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateReturnWorkerRequest(fbb, worker_port); - fbb.Finish(message); - return conn_->WriteMessage(MessageType::ReturnWorker, &fbb); + ray::rpc::ReturnWorkerRequest request; + request.set_worker_port(worker_port); + return grpc_client_->ReturnWorker( + request, [](const ray::Status &status, const ray::rpc::ReturnWorkerReply &reply) { + RAY_CHECK_OK(status); + }); } diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index aa36d4c02..f796307ce 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -69,12 +69,16 @@ class WorkerLeaseInterface { /// Requests a worker from the raylet. The callback will be sent via gRPC. /// \param resource_spec Resources that should be allocated for the worker. /// \return ray::Status - virtual ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) = 0; + virtual ray::Status RequestWorkerLease( + const ray::TaskSpecification &resource_spec, + const ray::rpc::ClientCallback &callback) = 0; /// Returns a worker to the raylet. /// \param worker_port The local port of the worker on the raylet node. /// \return ray::Status virtual ray::Status ReturnWorker(int worker_port) = 0; + + virtual ~WorkerLeaseInterface(){}; }; class RayletClient : public WorkerLeaseInterface { @@ -96,6 +100,11 @@ class RayletClient : public WorkerLeaseInterface { bool is_worker, const JobID &job_id, const Language &language, ClientID *raylet_id, int port = -1); + /// Connect to the raylet via grpc only. + /// + /// \param grpc_client gRPC client to the raylet. + RayletClient(std::shared_ptr grpc_client); + ray::Status Disconnect() { return conn_->Disconnect(); }; /// Submit a task using the raylet code path. @@ -203,19 +212,17 @@ class RayletClient : public WorkerLeaseInterface { ray::Status ReportActiveObjectIDs(const std::unordered_set &object_ids); /// Implements WorkerLeaseInterface. - ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) override; + ray::Status RequestWorkerLease( + const ray::TaskSpecification &resource_spec, + const ray::rpc::ClientCallback &callback) override; /// Implements WorkerLeaseInterface. ray::Status ReturnWorker(int worker_port) override; - Language GetLanguage() const { return language_; } - WorkerID GetWorkerID() const { return worker_id_; } JobID GetJobID() const { return job_id_; } - bool IsWorker() const { return is_worker_; } - const ResourceMappingType &GetResourceIDs() const { return resource_ids_; } private: @@ -223,9 +230,7 @@ class RayletClient : public WorkerLeaseInterface { /// request types. std::shared_ptr grpc_client_; const WorkerID worker_id_; - const bool is_worker_; const JobID job_id_; - const Language language_; /// A map from resource name to the resource IDs that are currently reserved /// for this worker. Each pair consists of the resource ID and the fraction /// of that resource allocated for this worker. diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 78e6c4a04..9c46180a7 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -172,24 +172,6 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) { } } -void Worker::WorkerLeaseGranted(const std::string &address, int port) { - RAY_CHECK(!address.empty()); - RAY_CHECK(port_ > 0); - rpc::WorkerLeaseGrantedRequest request; - request.set_address(address); - request.set_port(port); - auto status = rpc_client_->WorkerLeaseGranted( - request, [address, port](Status status, const rpc::WorkerLeaseGrantedReply &reply) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString() - << " for " << address << ":" << port; - } - }); - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString(); - } -} - } // namespace raylet } // end namespace ray diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 4d02c52bc..686548ace 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -83,6 +83,24 @@ class NodeManagerWorkerClient return call->GetStatus(); } + /// Request a worker lease. + ray::Status RequestWorkerLease(const WorkerLeaseRequest &request, + const ClientCallback &callback) { + auto call = client_call_manager_ + .CreateCall( + *stub_, &NodeManagerService::Stub::PrepareAsyncRequestWorkerLease, + request, callback); + return call->GetStatus(); + } + + ray::Status ReturnWorker(const ReturnWorkerRequest &request, + const ClientCallback &callback) { + auto call = client_call_manager_.CreateCall( + *stub_, &NodeManagerService::Stub::PrepareAsyncReturnWorker, request, callback); + return call->GetStatus(); + } + private: /// Constructor. /// diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index 40a4d8022..7d89b320f 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -27,6 +27,14 @@ class NodeManagerServiceHandler { virtual void HandleSubmitTask(const SubmitTaskRequest &request, SubmitTaskReply *reply, SendReplyCallback send_reply_callback) = 0; + virtual void HandleWorkerLeaseRequest(const WorkerLeaseRequest &request, + WorkerLeaseReply *reply, + SendReplyCallback send_reply_callback) = 0; + + virtual void HandleReturnWorker(const ReturnWorkerRequest &request, + ReturnWorkerReply *reply, + SendReplyCallback send_reply_callback) = 0; + virtual void HandleForwardTask(const ForwardTaskRequest &request, ForwardTaskReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -62,6 +70,20 @@ class NodeManagerGrpcService : public GrpcService { service_handler_, &NodeManagerServiceHandler::HandleSubmitTask, cq, main_service_)); + std::unique_ptr request_worker_lease_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestRequestWorkerLease, + service_handler_, &NodeManagerServiceHandler::HandleWorkerLeaseRequest, cq, + main_service_)); + + std::unique_ptr release_worker_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestReturnWorker, + service_handler_, &NodeManagerServiceHandler::HandleReturnWorker, cq, + main_service_)); + std::unique_ptr forward_task_call_factory( new ServerCallFactoryImpl( @@ -79,6 +101,10 @@ class NodeManagerGrpcService : public GrpcService { // Set accept concurrency. server_call_factories_and_concurrencies->emplace_back( std::move(submit_task_call_factory), 100); + server_call_factories_and_concurrencies->emplace_back( + std::move(request_worker_lease_call_factory), 100); + server_call_factories_and_concurrencies->emplace_back( + std::move(release_worker_call_factory), 100); server_call_factories_and_concurrencies->emplace_back( std::move(forward_task_call_factory), 100); server_call_factories_and_concurrencies->emplace_back( diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index 763f76f9d..c58573371 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -74,17 +74,6 @@ class CoreWorkerClientInterface { return Status::NotImplemented(""); } - /// Grants a worker to the client. - /// - /// \param[in] request The request message. - /// \param[in] callback The callback function that handles reply. - /// \return if the rpc call succeeds - virtual ray::Status WorkerLeaseGranted( - const WorkerLeaseGrantedRequest &request, - const ClientCallback &callback) { - return Status::NotImplemented(""); - } - virtual ~CoreWorkerClientInterface(){}; }; @@ -152,17 +141,6 @@ class CoreWorkerClient : public std::enable_shared_from_this, return call->GetStatus(); } - ray::Status WorkerLeaseGranted( - const WorkerLeaseGrantedRequest &request, - const ClientCallback &callback) override { - auto call = - client_call_manager_.CreateCall( - *stub_, &CoreWorkerService::Stub::PrepareAsyncWorkerLeaseGranted, request, - callback); - return call->GetStatus(); - } - /// Send as many pending tasks as possible. This method is thread-safe. /// /// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being