diff --git a/BUILD.bazel b/BUILD.bazel index 0d4007a6f..03bcc36aa 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -774,6 +774,16 @@ cc_test( ], ) +cc_test( + name = "lease_policy_test", + srcs = ["src/ray/core_worker/test/lease_policy_test.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "cluster_resource_scheduler_test", srcs = [ diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index b1bc25fbb..2e60f40e9 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -1,4 +1,5 @@ # coding: utf-8 +import collections import glob import logging import os @@ -37,11 +38,10 @@ def attempt_to_load_balance(remote_function, while attempts < num_attempts: locations = ray.get( [remote_function.remote(*args) for _ in range(total_tasks)]) - names = set(locations) - counts = [locations.count(name) for name in names] - logger.info(f"Counts are {counts}.") - if (len(names) == num_nodes - and all(count >= minimum_count for count in counts)): + counts = collections.Counter(locations) + logger.info(f"Counts are {counts}") + if (len(counts) == num_nodes + and counts.most_common()[-1][1] >= minimum_count): break attempts += 1 assert attempts < num_attempts @@ -124,6 +124,38 @@ def test_load_balancing_with_dependencies(ray_start_cluster, fast): attempt_to_load_balance(f, [x], 100, num_nodes, 25) +def test_locality_aware_leasing(ray_start_cluster): + # This test ensures that a task will run where its task dependencies are + # located. We run an initial non_local() task that is pinned to a + # non-local node via a custom resource constraint, and then we run an + # unpinned task f() that depends on the output of non_local(), ensuring + # that f() runs on the same node as non_local(). + cluster = ray_start_cluster + + # Disable worker caching so worker leases are not reused, and disable + # inlining of return objects so return objects are always put into Plasma. + cluster.add_node( + num_cpus=1, + _system_config={ + "worker_lease_timeout_milliseconds": 0, + "max_direct_call_object_size": 0, + }) + # Use a custom resource for pinning tasks to a node. + non_local_node = cluster.add_node(num_cpus=1, resources={"pin": 1}) + ray.init(address=cluster.address) + + @ray.remote(resources={"pin": 1}) + def non_local(): + return ray.worker.global_worker.node.unique_id + + @ray.remote + def f(x): + return ray.worker.global_worker.node.unique_id + + # Test that task f() runs on the same node as non_local(). + assert ray.get(f.remote(non_local.remote())) == non_local_node.unique_id + + def wait_for_num_objects(num_objects, timeout=10): start_time = time.time() while time.time() - start_time < timeout: @@ -805,7 +837,7 @@ def test_override_environment_variables_task(ray_start_regular): assert (ray.get( get_env.options(override_environment_variables={ - "a": "b" + "a": "b", }).remote("a")) == "b") @@ -817,7 +849,7 @@ def test_override_environment_variables_actor(ray_start_regular): a = EnvGetter.options(override_environment_variables={ "a": "b", - "c": "d" + "c": "d", }).remote() assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get.remote("c")) == "d") @@ -834,7 +866,7 @@ def test_override_environment_variables_nested_task(ray_start_regular): assert (ray.get( get_env_wrapper.options(override_environment_variables={ - "a": "b" + "a": "b", }).remote("a")) == "b") @@ -842,7 +874,7 @@ def test_override_environment_variables_multitenancy(shutdown_only): ray.init( job_config=ray.job_config.JobConfig(worker_env={ "foo1": "bar1", - "foo2": "bar2" + "foo2": "bar2", })) @ray.remote @@ -853,11 +885,11 @@ def test_override_environment_variables_multitenancy(shutdown_only): assert ray.get(get_env.remote("foo2")) == "bar2" assert ray.get( get_env.options(override_environment_variables={ - "foo1": "baz1" + "foo1": "baz1", }).remote("foo1")) == "baz1" assert ray.get( get_env.options(override_environment_variables={ - "foo1": "baz1" + "foo1": "baz1", }).remote("foo2")) == "bar2" @@ -866,7 +898,7 @@ def test_override_environment_variables_complex(shutdown_only): job_config=ray.job_config.JobConfig(worker_env={ "a": "job_a", "b": "job_b", - "z": "job_z" + "z": "job_z", })) @ray.remote @@ -892,13 +924,13 @@ def test_override_environment_variables_complex(shutdown_only): def nested_get(self, key): aa = NestedEnvGetter.options(override_environment_variables={ "c": "e", - "d": "dd" + "d": "dd", }).remote() return ray.get(aa.get.remote(key)) a = EnvGetter.options(override_environment_variables={ "a": "b", - "c": "d" + "c": "d", }).remote() assert (ray.get(a.get.remote("a")) == "b") assert (ray.get(a.get_task.remote("a")) == "b") @@ -907,7 +939,7 @@ def test_override_environment_variables_complex(shutdown_only): assert (ray.get(a.nested_get.remote("d")) == "dd") assert (ray.get( get_env.options(override_environment_variables={ - "a": "b" + "a": "b", }).remote("a")) == "b") assert (ray.get(a.get.remote("z")) == "job_z") @@ -915,7 +947,7 @@ def test_override_environment_variables_complex(shutdown_only): assert (ray.get(a.nested_get.remote("z")) == "job_z") assert (ray.get( get_env.options(override_environment_variables={ - "a": "b" + "a": "b", }).remote("z")) == "job_z") diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index fe41477f7..aa75f0e2d 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -363,3 +363,8 @@ RAY_CONFIG(int64_t, min_spilling_size, 100 * 1024 * 1024) /// When it is true, manual (force) spilling is not available. /// TODO(sang): Fix it. RAY_CONFIG(bool, automatic_object_deletion_enabled, true) + +/* Configuration parameters for locality-aware scheduling. */ +/// Whether to enable locality-aware leasing. If enabled, then Ray will consider task +/// dependency locality when choosing a worker for leasing. +RAY_CONFIG(bool, locality_aware_leasing_enabled, true) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 219b5e062..da44b1453 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -489,11 +489,28 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ new CoreWorkerDirectActorTaskSubmitter(core_worker_client_pool_, memory_store_, task_manager_)); + auto node_addr_factory = [this](const NodeID &node_id) { + absl::optional addr; + if (auto node_info = gcs_client_->Nodes().Get(node_id)) { + rpc::Address address; + address.set_raylet_id(node_info->node_id()); + address.set_ip_address(node_info->node_manager_address()); + address.set_port(node_info->node_manager_port()); + addr = address; + } + return addr; + }; + auto lease_policy = RayConfig::instance().locality_aware_leasing_enabled() + ? std::shared_ptr( + std::make_shared( + reference_counter_, node_addr_factory, rpc_address_)) + : std::shared_ptr( + std::make_shared(rpc_address_)); direct_task_submitter_ = std::unique_ptr(new CoreWorkerDirectTaskSubmitter( rpc_address_, local_raylet_client_, core_worker_client_pool_, - raylet_client_factory, memory_store_, task_manager_, local_raylet_id, - RayConfig::instance().worker_lease_timeout_milliseconds(), + raylet_client_factory, std::move(lease_policy), memory_store_, task_manager_, + local_raylet_id, RayConfig::instance().worker_lease_timeout_milliseconds(), std::move(actor_creator), RayConfig::instance().max_tasks_in_flight_per_worker(), boost::asio::steady_timer(io_service_))); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 256f0c42d..1358af3de 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -23,6 +23,7 @@ #include "ray/core_worker/common.h" #include "ray/core_worker/context.h" #include "ray/core_worker/future_resolver.h" +#include "ray/core_worker/lease_policy.h" #include "ray/core_worker/object_recovery_manager.h" #include "ray/core_worker/profiling.h" #include "ray/core_worker/reference_count.h" diff --git a/src/ray/core_worker/lease_policy.cc b/src/ray/core_worker/lease_policy.cc new file mode 100644 index 000000000..de6acd11b --- /dev/null +++ b/src/ray/core_worker/lease_policy.cc @@ -0,0 +1,61 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/core_worker/lease_policy.h" + +namespace ray { + +rpc::Address LocalityAwareLeasePolicy::GetBestNodeForTask(const TaskSpecification &spec) { + if (auto node_id = GetBestNodeIdForTask(spec)) { + if (auto addr = node_addr_factory_(node_id.value())) { + return addr.value(); + } + } + return fallback_rpc_address_; +} + +/// Criteria for "best" node: The node with the most object bytes (from object_ids) local. +absl::optional LocalityAwareLeasePolicy::GetBestNodeIdForTask( + const TaskSpecification &spec) { + const auto object_ids = spec.GetDependencyIds(); + // Number of object bytes (from object_ids) that a given node has local. + absl::flat_hash_map bytes_local_table; + uint64_t max_bytes = 0; + absl::optional max_bytes_node; + // Finds the node with the maximum number of object bytes local. + for (const ObjectID &object_id : object_ids) { + if (auto locality_data = locality_data_provider_->GetLocalityData(object_id)) { + for (const NodeID &node_id : locality_data->nodes_containing_object) { + auto &bytes = bytes_local_table[node_id]; + bytes += locality_data->object_size; + // Update max, if needed. + if (bytes > max_bytes) { + max_bytes = bytes; + max_bytes_node = node_id; + } + } + } else { + RAY_LOG(WARNING) << "No locality data available for object " << object_id + << ", won't be included in locality cost"; + } + } + return max_bytes_node; +} + +rpc::Address LocalLeasePolicy::GetBestNodeForTask(const TaskSpecification &spec) { + // Always return the local node. + return local_node_rpc_address_; +} + +} // namespace ray diff --git a/src/ray/core_worker/lease_policy.h b/src/ray/core_worker/lease_policy.h new file mode 100644 index 000000000..e8762fe2a --- /dev/null +++ b/src/ray/core_worker/lease_policy.h @@ -0,0 +1,98 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "ray/common/id.h" +#include "ray/common/task/task_spec.h" +#include "src/ray/protobuf/common.pb.h" + +namespace ray { + +struct LocalityData { + uint64_t object_size; + absl::flat_hash_set nodes_containing_object; +}; + +/// Interface for providers of locality data to the lease policy. +class LocalityDataProviderInterface { + public: + virtual absl::optional GetLocalityData(const ObjectID &object_id) = 0; + + virtual ~LocalityDataProviderInterface() {} +}; + +/// Interface for mocking the lease policy. +class LeasePolicyInterface { + public: + /// Get the address of the best worker node for a lease request for the provided task. + virtual rpc::Address GetBestNodeForTask(const TaskSpecification &spec) = 0; + + virtual ~LeasePolicyInterface() {} +}; + +typedef std::function(const NodeID &node_id)> + NodeAddrFactory; + +/// Class used by the core worker to implement a locality-aware lease policy for +/// picking a worker node for a lease request. This class is not thread-safe. +class LocalityAwareLeasePolicy : public LeasePolicyInterface { + public: + LocalityAwareLeasePolicy( + std::shared_ptr locality_data_provider, + NodeAddrFactory node_addr_factory, const rpc::Address fallback_rpc_address) + : locality_data_provider_(locality_data_provider), + node_addr_factory_(node_addr_factory), + fallback_rpc_address_(fallback_rpc_address) {} + + ~LocalityAwareLeasePolicy() {} + + /// Get the address of the best worker node for a lease request for the provided task. + rpc::Address GetBestNodeForTask(const TaskSpecification &spec); + + private: + /// Get the best worker node for a lease request for the provided task. + absl::optional GetBestNodeIdForTask(const TaskSpecification &spec); + + /// Provider of locality data that will be used in choosing the best lessor. + std::shared_ptr locality_data_provider_; + + /// Factory for building node RPC addresses given a NodeID. + NodeAddrFactory node_addr_factory_; + + /// RPC address of fallback node (usually the local node). + const rpc::Address fallback_rpc_address_; +}; + +/// Class used by the core worker to implement a local-only lease policy for picking +/// a worker node for a lease request. This class is not thread-safe. +class LocalLeasePolicy : public LeasePolicyInterface { + public: + LocalLeasePolicy(const rpc::Address local_node_rpc_address) + : local_node_rpc_address_(local_node_rpc_address) {} + + ~LocalLeasePolicy() {} + + /// Get the address of the local node for a lease request for the provided task. + rpc::Address GetBestNodeForTask(const TaskSpecification &spec); + + private: + /// RPC address of the local node. + const rpc::Address local_node_rpc_address_; +}; + +} // namespace ray diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index ef0168af7..be89794ed 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -939,6 +939,39 @@ void ReferenceCounter::HandleObjectSpilled(const ObjectID &object_id) { ReleasePlasmaObject(it); } +absl::optional ReferenceCounter::GetLocalityData( + const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); + // Uses the reference table to return locality data for an object. + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + RAY_LOG(DEBUG) << "Object " << object_id + << " not in reference table, locality data not available"; + return absl::nullopt; + } + + const auto &node_id = it->second.pinned_at_raylet_id; + if (!node_id.has_value()) { + RAY_LOG(DEBUG) + << "Reference " << it->second.call_site << " for object " << object_id + << " doesn't have a defined pinned raylet ID, locality data not available"; + return absl::nullopt; + } + // The raylet ID to which this reference is pinned is defined. + + const auto object_size = it->second.object_size; + if (object_size < 0) { + RAY_LOG(DEBUG) << "Reference " << it->second.call_site << " for object " << object_id + << " has an unknown object size, locality data not available"; + return absl::nullopt; + } + // The object size of this reference is known. + + absl::optional locality_data( + {static_cast(object_size), {node_id.value()}}); + return locality_data; +} + ReferenceCounter::Reference ReferenceCounter::Reference::FromProto( const rpc::ObjectReferenceCount &ref_count) { Reference ref; diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 03bffa5fe..d18684e32 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -21,6 +21,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" #include "ray/common/id.h" +#include "ray/core_worker/lease_policy.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/worker/core_worker_client.h" #include "ray/rpc/worker/core_worker_client_pool.h" @@ -50,7 +51,8 @@ class ReferenceCounterInterface { /// Class used by the core worker to keep track of ObjectID reference counts for garbage /// collection. This class is thread safe. -class ReferenceCounter : public ReferenceCounterInterface { +class ReferenceCounter : public ReferenceCounterInterface, + public LocalityDataProviderInterface { public: using ReferenceTableProto = ::google::protobuf::RepeatedPtrField; @@ -386,6 +388,9 @@ class ReferenceCounter : public ReferenceCounterInterface { /// records that the object has been spilled to suppress reconstruction. void HandleObjectSpilled(const ObjectID &object_id); + /// Get locality data for object. + absl::optional GetLocalityData(const ObjectID &object_id); + private: struct Reference { /// Constructor for a reference whose origin is unknown. diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 8362b2d21..4d36851f6 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -321,6 +321,47 @@ TEST_F(ReferenceCountTest, TestReferenceStats) { ASSERT_EQ(stats2.object_refs(0).call_site(), "file2.py:43"); } +// Tests fetching of locality data from reference table. +TEST_F(ReferenceCountTest, TestGetLocalityData) { + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + NodeID node1 = NodeID::FromRandom(); + NodeID node2 = NodeID::FromRandom(); + rpc::Address address; + address.set_ip_address("1234"); + + // Owned object with defined object size and pinned node location should return valid + // locality data. + int64_t object_size = 100; + rc->AddOwnedObject(obj1, {}, address, "file2.py:42", object_size, false, + absl::optional(node1)); + auto locality_data_obj1 = rc->GetLocalityData(obj1); + ASSERT_TRUE(locality_data_obj1.has_value()); + ASSERT_EQ(locality_data_obj1->object_size, object_size); + ASSERT_EQ(locality_data_obj1->nodes_containing_object, + absl::flat_hash_set{node1}); + + // Fetching locality data for an object that doesn't have a reference in the table + // should return a null optional. + auto locality_data_obj2_not_exist = rc->GetLocalityData(obj2); + ASSERT_FALSE(locality_data_obj2_not_exist.has_value()); + + // Fetching locality data for an object that doesn't have a pinned node location + // defined should return a null optional. + rc->AddLocalReference(obj2, "file.py:43"); + rc->UpdateObjectSize(obj2, 200); + auto locality_data_obj2_no_pinned_raylet = rc->GetLocalityData(obj2); + ASSERT_FALSE(locality_data_obj2_no_pinned_raylet.has_value()); + rc->RemoveLocalReference(obj2, nullptr); + + // Fetching locality data for an object that doesn't have an object size defined + // should return a null optional. + rc->AddOwnedObject(obj2, {}, address, "file2.py:43", -1, false, + absl::optional(node2)); + auto locality_data_obj2_no_object_size = rc->GetLocalityData(obj2); + ASSERT_FALSE(locality_data_obj2_no_object_size.has_value()); +} + // Tests that we can get the owner address correctly for objects that we own, // objects that we borrowed via a serialized object ID, and objects whose // origin we do not know. diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 27af163b2..dcc85163a 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -183,6 +183,25 @@ class MockActorCreator : public ActorCreatorInterface { ~MockActorCreator() {} }; +class MockLeasePolicy : public LeasePolicyInterface { + public: + MockLeasePolicy(const NodeID &node_id = NodeID::Nil()) { + fallback_rpc_address_ = rpc::Address(); + fallback_rpc_address_.set_raylet_id(node_id.Binary()); + } + + rpc::Address GetBestNodeForTask(const TaskSpecification &spec) { + num_lease_policy_consults++; + return fallback_rpc_address_; + }; + + ~MockLeasePolicy() {} + + rpc::Address fallback_rpc_address_; + + int num_lease_policy_consults = 0; +}; + TEST(TestMemoryStore, TestPromoteToPlasma) { size_t num_plasma_puts = 0; auto mem = std::make_shared( @@ -341,9 +360,10 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = @@ -351,6 +371,7 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(worker_client->callbacks.size(), 0); @@ -382,9 +403,10 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -416,9 +438,10 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -429,21 +452,25 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1); // Task 1 is pushed; worker 2 is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); ASSERT_EQ(raylet_client->num_workers_requested, 2); // Task 2 is pushed; worker 3 is requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 2); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); ASSERT_EQ(raylet_client->num_workers_requested, 3); // Task 3 is pushed; no more workers requested. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1002, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 3); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 3); ASSERT_EQ(raylet_client->num_workers_requested, 3); // All workers returned. @@ -471,9 +498,10 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -484,11 +512,13 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { ASSERT_TRUE(submitter.SubmitTask(task1).ok()); ASSERT_TRUE(submitter.SubmitTask(task2).ok()); ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1); // Task 1 is pushed. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1000, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); ASSERT_EQ(raylet_client->num_workers_requested, 2); ASSERT_EQ(raylet_client->num_leases_canceled, 0); @@ -511,6 +541,7 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { // The second lease request is returned immediately. ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 1001, NodeID::Nil())); ASSERT_EQ(worker_client->callbacks.size(), 0); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 2); ASSERT_EQ(raylet_client->num_workers_returned, 2); ASSERT_EQ(raylet_client->num_workers_disconnected, 0); ASSERT_EQ(task_finisher->num_tasks_complete, 3); @@ -532,9 +563,10 @@ TEST(DirectTaskTransportTest, TestRetryLeaseCancellation) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -592,9 +624,10 @@ TEST(DirectTaskTransportTest, TestConcurrentCancellationAndSubmission) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -649,9 +682,10 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -697,9 +731,10 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -744,15 +779,17 @@ TEST(DirectTaskTransportTest, TestSpillback) { }; auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, - lease_client_factory, store, task_finisher, - NodeID::Nil(), kLongTimeout, actor_creator); + auto lease_policy = std::make_shared(); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, lease_client_factory, lease_policy, store, + task_finisher, NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_EQ(raylet_client->num_workers_requested, 1); ASSERT_EQ(raylet_client->num_workers_returned, 0); ASSERT_EQ(worker_client->callbacks.size(), 0); @@ -762,6 +799,8 @@ TEST(DirectTaskTransportTest, TestSpillback) { auto remote_raylet_id = NodeID::FromRandom(); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 7777, remote_raylet_id)); ASSERT_EQ(remote_lease_clients.count(7777), 1); + // Confirm that lease policy is not consulted on spillback. + ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); // There should be no more callbacks on the local client. ASSERT_FALSE(raylet_client->GrantWorkerLease("remote", 1234, NodeID::Nil())); // Trigger retry at the remote node. @@ -807,9 +846,10 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { auto task_finisher = std::make_shared(); auto local_raylet_id = NodeID::FromRandom(); auto actor_creator = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, - lease_client_factory, store, task_finisher, - local_raylet_id, kLongTimeout, actor_creator); + auto lease_policy = std::make_shared(local_raylet_id); + CoreWorkerDirectTaskSubmitter submitter( + address, raylet_client, client_pool, lease_client_factory, lease_policy, store, + task_finisher, local_raylet_id, kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -825,6 +865,8 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { auto remote_raylet_id = NodeID::FromRandom(); ASSERT_TRUE(raylet_client->GrantWorkerLease("localhost", 7777, remote_raylet_id)); ASSERT_EQ(remote_lease_clients.count(7777), 1); + // Confirm that lease policy is not consulted on spillback. + ASSERT_EQ(lease_policy->num_lease_policy_consults, 1); ASSERT_FALSE(raylet_client->GrantWorkerLease("remote", 1234, NodeID::Nil())); // Trigger a spillback back to the local node. ASSERT_TRUE( @@ -868,9 +910,10 @@ void TestSchedulingKey(const std::shared_ptr store, [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); ASSERT_TRUE(submitter.SubmitTask(same1).ok()); ASSERT_TRUE(submitter.SubmitTask(same2).ok()); @@ -994,8 +1037,10 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), + lease_policy, store, task_finisher, + NodeID::Nil(), /*lease_timeout_ms=*/5, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = @@ -1053,9 +1098,10 @@ TEST(DirectTaskTransportTest, TestKillExecutingTask) { auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -1105,9 +1151,10 @@ TEST(DirectTaskTransportTest, TestKillPendingTask) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -1141,9 +1188,10 @@ TEST(DirectTaskTransportTest, TestKillResolvingTask) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, client_pool, nullptr, - store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator); + lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator); std::unordered_map empty_resources; ray::FunctionDescriptor empty_descriptor = ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); @@ -1176,14 +1224,15 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); - // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the pipelining - // of task submissions. This is done by passing a max_tasks_in_flight_per_worker - // parameter to the CoreWorkerDirectTaskSubmitter. + // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the + // pipelining of task submissions. This is done by passing a + // max_tasks_in_flight_per_worker parameter to the CoreWorkerDirectTaskSubmitter. uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); // Prepare 20 tasks and save them in a vector. std::unordered_map empty_resources; @@ -1216,8 +1265,8 @@ TEST(DirectTaskTransportTest, TestPipeliningConcurrentWorkerLeases) { ASSERT_TRUE(worker_client->ReplyPushTask()); // No worker should be returned until all the tasks that were submitted to it have // been completed. In our case, the first worker should only be returned after the - // 10th task has been executed. The second worker should only be returned at the end, - // or after the 20th task has been executed. + // 10th task has been executed. The second worker should only be returned at the + // end, or after the 20th task has been executed. if (i < 10) { ASSERT_EQ(raylet_client->num_workers_returned, 0); } else if (i >= 10 && i < 20) { @@ -1250,14 +1299,15 @@ TEST(DirectTaskTransportTest, TestPipeliningReuseWorkerLease) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); - // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the pipelining - // of task submissions. This is done by passing a max_tasks_in_flight_per_worker - // parameter to the CoreWorkerDirectTaskSubmitter. + // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the + // pipelining of task submissions. This is done by passing a + // max_tasks_in_flight_per_worker parameter to the CoreWorkerDirectTaskSubmitter. uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); // prepare 30 tasks and save them in a vector std::unordered_map empty_resources; @@ -1329,14 +1379,15 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) { [&](const rpc::Address &addr) { return worker_client; }); auto task_finisher = std::make_shared(); auto actor_creator = std::make_shared(); + auto lease_policy = std::make_shared(); - // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the pipelining - // of task submissions. This is done by passing a max_tasks_in_flight_per_worker - // parameter to the CoreWorkerDirectTaskSubmitter. + // Set max_tasks_in_flight_per_worker to a value larger than 1 to enable the + // pipelining of task submissions. This is done by passing a + // max_tasks_in_flight_per_worker parameter to the CoreWorkerDirectTaskSubmitter. uint32_t max_tasks_in_flight_per_worker = 10; CoreWorkerDirectTaskSubmitter submitter( - address, raylet_client, client_pool, nullptr, store, task_finisher, NodeID::Nil(), - kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); + address, raylet_client, client_pool, nullptr, lease_policy, store, task_finisher, + NodeID::Nil(), kLongTimeout, actor_creator, max_tasks_in_flight_per_worker); // prepare 30 tasks and save them in a vector std::unordered_map empty_resources; @@ -1409,7 +1460,8 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) { ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(worker_client->callbacks.size(), 11); - // Submit 9 more tasks, and check that the total number of workers requested is still 2. + // Submit 9 more tasks, and check that the total number of workers requested is + // still 2. for (int i = 1; i <= 9; i++) { auto task = tasks.front(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -1424,8 +1476,8 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) { ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(worker_client->callbacks.size(), 20); - // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the total - // number of workers requested remains equal to 2. + // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the + // total number of workers requested remains equal to 2. for (int i = 1; i <= 5; i++) { ASSERT_TRUE(worker_client->ReplyPushTask()); } @@ -1452,8 +1504,8 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) { ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(worker_client->callbacks.size(), 20); - // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the total - // number of workers requested remains equal to 2. + // Call ReplyPushTask on a quarter of the submitted tasks (5), and check that the + // total number of workers requested remains equal to 2. for (int i = 1; i <= 5; i++) { ASSERT_TRUE(worker_client->ReplyPushTask()); } @@ -1465,7 +1517,8 @@ TEST(DirectTaskTransportTest, TestPipeliningNumberOfWorkersRequested) { ASSERT_EQ(raylet_client->num_leases_canceled, 0); ASSERT_EQ(worker_client->callbacks.size(), 15); - // Submit last 5 tasks, and check that the total number of workers requested is still 2 + // Submit last 5 tasks, and check that the total number of workers requested is still + // 2 for (int i = 1; i <= 5; i++) { auto task = tasks.front(); ASSERT_TRUE(submitter.SubmitTask(task).ok()); diff --git a/src/ray/core_worker/test/lease_policy_test.cc b/src/ray/core_worker/test/lease_policy_test.cc new file mode 100644 index 000000000..3e7cbf33c --- /dev/null +++ b/src/ray/core_worker/test/lease_policy_test.cc @@ -0,0 +1,211 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/core_worker/lease_policy.h" + +#include "gtest/gtest.h" +#include "ray/common/task/task_spec.h" + +namespace ray { + +TaskSpecification CreateFakeTask(std::vector deps) { + TaskSpecification spec; + spec.GetMutableMessage().set_task_id(TaskID::ForFakeTask().Binary()); + for (auto &dep : deps) { + spec.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id( + dep.Binary()); + } + return spec; +} + +class MockLocalityDataProvider : public LocalityDataProviderInterface { + public: + MockLocalityDataProvider() {} + + MockLocalityDataProvider(absl::flat_hash_map locality_data) + : locality_data_(locality_data) {} + + absl::optional GetLocalityData(const ObjectID &object_id) { + num_locality_data_fetches++; + return locality_data_[object_id]; + }; + + ~MockLocalityDataProvider() {} + + int num_locality_data_fetches = 0; + absl::flat_hash_map locality_data_; +}; + +absl::optional MockNodeAddrFactory(const NodeID &node_id) { + rpc::Address mock_rpc_address; + mock_rpc_address.set_raylet_id(node_id.Binary()); + absl::optional opt_mock_rpc_address = mock_rpc_address; + return opt_mock_rpc_address; +} + +absl::optional MockNodeAddrFactoryAlwaysNull(const NodeID &node_id) { + return absl::nullopt; +} + +TEST(LocalLeasePolicyTest, TestReturnFallback) { + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + LocalLeasePolicy local_lease_policy(fallback_rpc_address); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + std::vector deps{obj1, obj2}; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = local_lease_policy.GetBestNodeForTask(task_spec); + // Test that fallback node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), fallback_node); +} + +TEST(LocalityAwareLeasePolicyTest, TestBestLocalityDominatingNode) { + absl::flat_hash_map locality_data; + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + NodeID best_node = NodeID::FromRandom(); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + // Both objects are local on best_node. + locality_data.emplace(obj1, LocalityData{8, {best_node}}); + locality_data.emplace(obj2, LocalityData{16, {best_node}}); + auto mock_locality_data_provider = + std::make_shared(locality_data); + LocalityAwareLeasePolicy locality_lease_policy( + mock_locality_data_provider, MockNodeAddrFactory, fallback_rpc_address); + std::vector deps{obj1, obj2}; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = locality_lease_policy.GetBestNodeForTask(task_spec); + // Locality data provider should be called once for each dependency. + ASSERT_EQ(mock_locality_data_provider->num_locality_data_fetches, deps.size()); + // Test that best node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), best_node); +} + +TEST(LocalityAwareLeasePolicyTest, TestBestLocalityBiggerObject) { + absl::flat_hash_map locality_data; + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + NodeID best_node = NodeID::FromRandom(); + NodeID bad_node = NodeID::FromRandom(); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + // Larger object is local on best_node. + locality_data.emplace(obj1, LocalityData{8, {bad_node}}); + locality_data.emplace(obj2, LocalityData{16, {best_node}}); + auto mock_locality_data_provider = + std::make_shared(locality_data); + LocalityAwareLeasePolicy locality_lease_policy( + mock_locality_data_provider, MockNodeAddrFactory, fallback_rpc_address); + std::vector deps{obj1, obj2}; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = locality_lease_policy.GetBestNodeForTask(task_spec); + // Locality data provider should be called once for each dependency. + ASSERT_EQ(mock_locality_data_provider->num_locality_data_fetches, deps.size()); + // Test that best node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), best_node); +} + +TEST(LocalityAwareLeasePolicyTest, TestBestLocalityBetterNode) { + absl::flat_hash_map locality_data; + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + NodeID best_node = NodeID::FromRandom(); + NodeID bad_node = NodeID::FromRandom(); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + ObjectID obj3 = ObjectID::FromRandom(); + // fallback_node: 8 bytes local + // bad_node: 24 bytes local + // best_node: 28 bytes local + locality_data.emplace(obj1, LocalityData{8, {fallback_node, bad_node}}); + locality_data.emplace(obj2, LocalityData{16, {best_node, bad_node}}); + locality_data.emplace(obj3, LocalityData{12, {best_node}}); + auto mock_locality_data_provider = + std::make_shared(locality_data); + LocalityAwareLeasePolicy locality_lease_policy( + mock_locality_data_provider, MockNodeAddrFactory, fallback_rpc_address); + std::vector deps{obj1, obj2, obj3}; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = locality_lease_policy.GetBestNodeForTask(task_spec); + // Locality data provider should be called once for each dependency. + ASSERT_EQ(mock_locality_data_provider->num_locality_data_fetches, deps.size()); + // Test that best node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), best_node); +} + +TEST(LocalityAwareLeasePolicyTest, TestBestLocalityFallbackNoLocations) { + absl::flat_hash_map locality_data; + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + // No known object locations. + locality_data.emplace(obj1, LocalityData{8, {}}); + locality_data.emplace(obj2, LocalityData{16, {}}); + auto mock_locality_data_provider = + std::make_shared(locality_data); + LocalityAwareLeasePolicy locality_lease_policy( + mock_locality_data_provider, MockNodeAddrFactory, fallback_rpc_address); + std::vector deps{obj1, obj2}; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = locality_lease_policy.GetBestNodeForTask(task_spec); + // Locality data provider should be called once for each dependency. + ASSERT_EQ(mock_locality_data_provider->num_locality_data_fetches, deps.size()); + // Test that fallback node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), fallback_node); +} + +TEST(LocalityAwareLeasePolicyTest, TestBestLocalityFallbackNoDeps) { + absl::flat_hash_map locality_data; + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + auto mock_locality_data_provider = std::make_shared(); + LocalityAwareLeasePolicy locality_lease_policy( + mock_locality_data_provider, MockNodeAddrFactory, fallback_rpc_address); + // No task dependencies. + std::vector deps; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = locality_lease_policy.GetBestNodeForTask(task_spec); + // Locality data provider should be called once for each dependency. + ASSERT_EQ(mock_locality_data_provider->num_locality_data_fetches, deps.size()); + // Test that fallback node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), fallback_node); +} + +TEST(LocalityAwareLeasePolicyTest, TestBestLocalityFallbackAddrFetchFail) { + absl::flat_hash_map locality_data; + NodeID fallback_node = NodeID::FromRandom(); + rpc::Address fallback_rpc_address = MockNodeAddrFactory(fallback_node).value(); + NodeID best_node = NodeID::FromRandom(); + ObjectID obj1 = ObjectID::FromRandom(); + ObjectID obj2 = ObjectID::FromRandom(); + locality_data.emplace(obj1, LocalityData{8, {best_node}}); + locality_data.emplace(obj2, LocalityData{16, {best_node}}); + auto mock_locality_data_provider = + std::make_shared(locality_data); + // Provided node address factory always returns absl::nullopt. + LocalityAwareLeasePolicy locality_lease_policy( + mock_locality_data_provider, MockNodeAddrFactoryAlwaysNull, fallback_rpc_address); + std::vector deps{obj1, obj2}; + auto task_spec = CreateFakeTask(deps); + rpc::Address best_node_address = locality_lease_policy.GetBestNodeForTask(task_spec); + // Locality data provider should be called once for each dependency. + ASSERT_EQ(mock_locality_data_provider->num_locality_data_fetches, deps.size()); + // Test that fallback node was chosen. + ASSERT_EQ(NodeID::FromBinary(best_node_address.raylet_id()), fallback_node); +} + +} // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index c343c8d0f..4f7f2e2b7 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -232,8 +232,8 @@ std::shared_ptr CoreWorkerDirectTaskSubmitter::GetOrConnectLeaseClient( const rpc::Address *raylet_address) { std::shared_ptr lease_client; - if (raylet_address && - NodeID::FromBinary(raylet_address->raylet_id()) != local_raylet_id_) { + RAY_CHECK(raylet_address != nullptr); + if (NodeID::FromBinary(raylet_address->raylet_id()) != local_raylet_id_) { // A remote raylet was specified. Connect to the raylet if needed. NodeID raylet_id = NodeID::FromBinary(raylet_address->raylet_id()); auto it = remote_lease_clients_.find(raylet_id); @@ -281,8 +281,14 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( return; } - auto lease_client = GetOrConnectLeaseClient(raylet_address); TaskSpecification &resource_spec = task_queue.front(); + rpc::Address best_node_address; + if (raylet_address == nullptr) { + // If no raylet address is given, find the best worker for our next lease request. + best_node_address = lease_policy_->GetBestNodeForTask(resource_spec); + raylet_address = &best_node_address; + } + auto lease_client = GetOrConnectLeaseClient(raylet_address); TaskID task_id = resource_spec.TaskId(); // Subtract 1 so we don't double count the task we are requesting for. int64_t queue_size = task_queue.size() - 1; diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index 19a2a7080..eb36a8a57 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -22,6 +22,7 @@ #include "ray/common/ray_object.h" #include "ray/core_worker/actor_manager.h" #include "ray/core_worker/context.h" +#include "ray/core_worker/lease_policy.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/task_manager.h" #include "ray/core_worker/transport/dependency_resolver.h" @@ -54,6 +55,7 @@ class CoreWorkerDirectTaskSubmitter { rpc::Address rpc_address, std::shared_ptr lease_client, std::shared_ptr core_worker_client_pool, LeaseClientFactoryFn lease_client_factory, + std::shared_ptr lease_policy, std::shared_ptr store, std::shared_ptr task_finisher, NodeID local_raylet_id, int64_t lease_timeout_ms, std::shared_ptr actor_creator, @@ -63,6 +65,7 @@ class CoreWorkerDirectTaskSubmitter { : rpc_address_(rpc_address), local_lease_client_(lease_client), lease_client_factory_(lease_client_factory), + lease_policy_(std::move(lease_policy)), resolver_(store, task_finisher), task_finisher_(task_finisher), lease_timeout_ms_(lease_timeout_ms), @@ -159,6 +162,10 @@ class CoreWorkerDirectTaskSubmitter { /// Factory for producing new clients to request leases from remote nodes. LeaseClientFactoryFn lease_client_factory_; + /// Provider of worker leasing decisions for the first lease request (not on + /// spillback). + std::shared_ptr lease_policy_; + /// Resolve local and remote dependencies; LocalDependencyResolver resolver_;