From fdb4c6eb1cfc40c0680f3b276394e8ba1d4c5c6d Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 16 Dec 2020 10:30:20 -0600 Subject: [PATCH 01/88] Better message for too little /dev/shm memory (#12896) --- python/ray/_private/services.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 62adfd072..c3512ab92 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1618,12 +1618,13 @@ def determine_plasma_store_config(object_store_memory, logger.warning( "WARNING: The object store is using {} instead of " "/dev/shm because /dev/shm has only {} bytes available. " - "This may slow down performance! You may be able to free " - "up space by deleting files in /dev/shm or terminating " - "any running plasma_store_server processes. If you are " - "inside a Docker container, you may need to pass an " - "argument with the flag '--shm-size' to 'docker run'.". - format(ray.utils.get_user_temp_dir(), shm_avail)) + "This will harm performance! You may be able to free up " + "space by deleting files in /dev/shm. If you are inside a " + "Docker container, you can increase /dev/shm size by " + "passing '--shm-size=Xgb' to 'docker run' (or add it to " + "the run_options list in a Ray cluster config). Make sure " + "to set this to more than 2gb.".format( + ray.utils.get_user_temp_dir(), shm_avail)) else: plasma_directory = ray.utils.get_user_temp_dir() From aedcf0c9d947669417586b2947e8f6ae81e3652b Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 16 Dec 2020 16:17:49 -0600 Subject: [PATCH 02/88] Disable test_distributions (#12919) --- rllib/BUILD | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index bd612d0ff..bde7d5855 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1089,12 +1089,13 @@ py_test( srcs = ["models/tests/test_convtranspose2d_stack.py"] ) -py_test( - name = "test_distributions", - tags = ["models"], - size = "medium", - srcs = ["models/tests/test_distributions.py"] -) +# Failing after the following PR: https://github.com/ray-project/ray/pull/12760. +#py_test( +# name = "test_distributions", +# tags = ["models"], +# size = "medium", +# srcs = ["models/tests/test_distributions.py"] +#) # -------------------------------------------------------------------- # Evaluation components From c677b9e201d9923ff6255a5aa5e74b0b65107471 Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Thu, 17 Dec 2020 00:18:27 +0200 Subject: [PATCH 03/88] [autoscaler] Fix flaky autoscaler test (#12918) --- python/ray/tests/test_autoscaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index fa9b9241e..7ef1e9c5b 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -1076,7 +1076,7 @@ class AutoscalingTest(unittest.TestCase): 2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED}) except AssertionError: # The failed nodes might have been already terminated by autoscaler - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len(self.provider.non_terminated_nodes({})) < 2 def testConfiguresOutdatedNodes(self): config_path = self.write_config(SMALL_CLUSTER) From 8b783ecafaee2b70849e5aa60000108b876b8dd7 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Wed, 16 Dec 2020 14:18:43 -0800 Subject: [PATCH 04/88] Fix pull manager retry (#12907) --- python/ray/tests/test_object_manager.py | 42 ++++++++++++++++++++++++ src/ray/object_manager/object_manager.cc | 21 +++++++----- src/ray/object_manager/object_manager.h | 3 +- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index c474fced4..b29b9caa2 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -254,6 +254,48 @@ def test_many_small_transfers(ray_start_cluster_with_resource): do_transfers() +# This is a basic test to ensure that the pull request retry timer is +# integrated properly. To test it, we create a 2 node cluster then do the +# following: +# (1) Fill up the driver's object store. +# (2) Fill up the remote node's object store. +# (3) Try to get the remote object. This should fail due to an OOM error caused +# by step 1. +# (4) Allow the local object to be evicted. +# (5) Try to get the object again. Now the retry timer should kick in and +# successfuly pull the remote object. +@pytest.mark.timeout(30) +def test_pull_request_retry(shutdown_only): + cluster = Cluster() + cluster.add_node(num_cpus=0, num_gpus=1, object_store_memory=100 * 2**20) + cluster.add_node(num_cpus=1, num_gpus=0, object_store_memory=100 * 2**20) + cluster.wait_for_nodes() + ray.init(address=cluster.address) + + @ray.remote + def put(): + return np.zeros(64 * 2**20, dtype=np.int8) + + @ray.remote(num_cpus=0, num_gpus=1) + def driver(): + local_ref = ray.put(np.zeros(64 * 2**20, dtype=np.int8)) + + remote_ref = put.remote() + + ready, _ = ray.wait([remote_ref], timeout=1) + assert len(ready) == 0 + + del local_ref + + # This should always complete within 10 seconds. + ready, _ = ray.wait([remote_ref], timeout=20) + assert len(ready) > 0 + + # Pretend the GPU node is the driver. We do this to force the placement of + # the driver and `put` task on different nodes. + ray.get(driver.remote()) + + if __name__ == "__main__": import pytest import sys diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 2eb641d04..3a31da864 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -89,14 +89,7 @@ ObjectManager::ObjectManager(asio::io_service &main_service, const NodeID &self_ static_cast(1L), static_cast(config_.max_bytes_in_flight / config_.object_chunk_size)))); - pull_retry_timer_.async_wait([this](const boost::system::error_code &e) { - RAY_CHECK(!e) << "The raylet's object manager has failed unexpectedly with error: " - << e - << ". Please file a bug report on here: " - "https://github.com/ray-project/ray/issues"; - - Tick(); - }); + pull_retry_timer_.async_wait([this](const boost::system::error_code &e) { Tick(e); }); if (plasma::plasma_store_runner) { store_notification_ = std::make_shared(main_service); @@ -814,6 +807,16 @@ void ObjectManager::RecordMetrics() const { stats::ObjectManagerPullRequests().Record(pull_manager_->NumActiveRequests()); } -void ObjectManager::Tick() { pull_manager_->Tick(); } +void ObjectManager::Tick(const boost::system::error_code &e) { + RAY_CHECK(!e) << "The raylet's object manager has failed unexpectedly with error: " << e + << ". Please file a bug report on here: " + "https://github.com/ray-project/ray/issues"; + + pull_manager_->Tick(); + + auto interval = boost::posix_time::milliseconds(config_.timer_freq_ms); + pull_retry_timer_.expires_from_now(interval); + pull_retry_timer_.async_wait([this](const boost::system::error_code &e) { Tick(e); }); +} } // namespace ray diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 75dbb5bd4..ff409eb18 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -284,7 +284,7 @@ class ObjectManager : public ObjectManagerInterface, /// Record metrics. void RecordMetrics() const; - void Tick(); + void Tick(const boost::system::error_code &e); private: friend class TestObjectManager; @@ -458,7 +458,6 @@ class ObjectManager : public ObjectManagerInterface, const RestoreSpilledObjectCallback restore_spilled_object_; /// Pull manager retry timer . - /* std::unique_ptr pull_retry_timer_; */ boost::asio::deadline_timer pull_retry_timer_; /// Object push manager. From dd522a71a13788b3e65bcdb2848bd65b2ce93daa Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Wed, 16 Dec 2020 15:37:44 -0800 Subject: [PATCH 05/88] [SGD] Disable Elastic Training by default when using with Tune (#12927) --- python/ray/util/sgd/torch/torch_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index d17c23191..50f20e39f 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -713,7 +713,7 @@ class BaseTorchTrainable(Trainable): "removed in " "a future version of Ray. Override Trainable.step instead.") - train_stats = self.trainer.train(max_retries=10, profile=True) + train_stats = self.trainer.train(max_retries=0, profile=True) validation_stats = self.trainer.validate(profile=True) stats = merge_dicts(train_stats, validation_stats) return stats From ad036fd5645a023b8888798c1506557f462e0a1b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 16 Dec 2020 16:09:13 -0800 Subject: [PATCH 06/88] Fix continue for debugger (#12862) --- python/ray/tests/test_ray_debugger.py | 27 +++++++++++++++++++++++++++ python/ray/util/rpdb.py | 7 +++++++ 2 files changed, 34 insertions(+) diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py index 8007f8a7a..adea19684 100644 --- a/python/ray/tests/test_ray_debugger.py +++ b/python/ray/tests/test_ray_debugger.py @@ -37,6 +37,33 @@ def test_ray_debugger_breakpoint(shutdown_only): ray.get(result) +@pytest.mark.skipif( + platform.system() == "Windows", reason="Failing on Windows.") +def test_ray_debugger_commands(shutdown_only): + ray.init(num_cpus=2) + + @ray.remote + def f(): + ray.util.pdb.set_trace() + + result1 = f.remote() + result2 = f.remote() + + # Make sure that calling "continue" in the debugger + # gives back control to the debugger loop: + p = pexpect.spawn("ray debug") + p.expect("Enter breakpoint index or press enter to refresh: ") + p.sendline("0") + p.expect("-> ray.util.pdb.set_trace()") + p.sendline("c") + p.expect("Enter breakpoint index or press enter to refresh: ") + p.sendline("0") + p.expect("-> ray.util.pdb.set_trace()") + p.sendline("c") + + ray.get([result1, result2]) + + @pytest.mark.skipif( platform.system() == "Windows", reason="Failing on Windows.") def test_ray_debugger_stepping(shutdown_only): diff --git a/python/ray/util/rpdb.py b/python/ray/util/rpdb.py index 33d430ea0..134e0b4ec 100644 --- a/python/ray/util/rpdb.py +++ b/python/ray/util/rpdb.py @@ -125,6 +125,13 @@ class RemotePdb(Pdb): do_q = do_exit = do_quit + def do_continue(self, arg): + self.__restore() + self.handle.connection.close() + return Pdb.do_continue(self, arg) + + do_c = do_continue + def set_trace(self, frame=None): if frame is None: frame = sys._getframe().f_back From 057687e53423bb5045eb98debfd655b61b987f73 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 16 Dec 2020 21:27:50 -0800 Subject: [PATCH 07/88] [New Scheduler] Fix test_failure.py by supporting infeasible tasks (#12738) * Fix the first issue. * ip * In Progress. * In progress. * done. * Remove unnecessary logs. * Addressed code review + fix some test failures. * Try fixing issues. * Fix issues. * Fix test issues. * Fix issues. * done. --- python/ray/tests/test_failure.py | 13 +- src/ray/object_manager/plasma/store.cc | 2 +- src/ray/raylet/node_manager.cc | 78 ++++---- src/ray/raylet/node_manager.h | 6 + .../scheduling/cluster_resource_scheduler.cc | 16 +- .../scheduling/cluster_resource_scheduler.h | 6 +- .../cluster_resource_scheduler_test.cc | 98 ++++++---- .../raylet/scheduling/cluster_task_manager.cc | 179 +++++++++++++++--- .../raylet/scheduling/cluster_task_manager.h | 19 +- .../scheduling/cluster_task_manager_test.cc | 108 ++++++++++- 10 files changed, 397 insertions(+), 128 deletions(-) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 89167f475..84904a5e1 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -16,14 +16,8 @@ import ray.utils import ray.ray_constants as ray_constants from ray.exceptions import RayTaskError from ray.cluster_utils import Cluster -from ray.test_utils import ( - wait_for_condition, - SignalActor, - init_error_pubsub, - get_error_message, - Semaphore, - new_scheduler_enabled, -) +from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub, + get_error_message, Semaphore) def test_failed_task(ray_start_regular, error_pubsub): @@ -663,7 +657,6 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only): assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR -@pytest.mark.skipif(new_scheduler_enabled(), reason="broken") def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub): p = error_pubsub # Check that we get warning messages for infeasible tasks. @@ -689,7 +682,6 @@ def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub): assert errors[0].type == ray_constants.INFEASIBLE_TASK_ERROR -@pytest.mark.skipif(new_scheduler_enabled(), reason="broken") def test_warning_for_infeasible_zero_cpu_actor(shutdown_only): # Check that we cannot place an actor on a 0 CPU machine and that we get an # infeasibility warning (even though the actor creation task itself @@ -956,7 +948,6 @@ def test_raylet_crash_when_get(ray_start_regular): thread.join() -@pytest.mark.skipif(new_scheduler_enabled(), reason="broken") def test_connect_with_disconnected_node(shutdown_only): config = { "num_heartbeats_timeout": 50, diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc index 1ff2dd29d..0e1fa3af2 100644 --- a/src/ray/object_manager/plasma/store.cc +++ b/src/ray/object_manager/plasma/store.cc @@ -224,7 +224,7 @@ uint8_t *PlasmaStore::AllocateMemory(size_t size, bool evict_if_full, MEMFD_TYPE // make room. if (space_needed > 0) { if (spill_objects_callback_) { - // If the space needed is too small, we'd like to bump up to the minimum spilling + // If the space needed is too small, we'd like to bump up to the minimum // size. Cap the max size to be lower than the plasma store limit. int64_t byte_to_spill = std::min(PlasmaAllocator::GetFootprintLimit(), diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 31ef907e0..e86975ba0 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -221,9 +221,12 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self return !(failed_workers_cache_.count(owner_worker_id) > 0 || failed_nodes_cache_.count(owner_node_id) > 0); }; + auto announce_infeasible_task = [this](const Task &task) { + PublishInfeasibleTaskError(task); + }; cluster_task_manager_ = std::shared_ptr(new ClusterTaskManager( self_node_id_, new_resource_scheduler_, fulfills_dependencies_func, - is_owner_alive, get_node_info_func)); + is_owner_alive, get_node_info_func, announce_infeasible_task)); } RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); @@ -930,6 +933,7 @@ void NodeManager::ResourceDeleted(const NodeID &node_id, void NodeManager::TryLocalInfeasibleTaskScheduling() { RAY_LOG(DEBUG) << "[LocalResourceUpdateRescheduler] The resource update is on the " "local node, check if we can reschedule tasks"; + SchedulingResources &new_local_resources = cluster_resource_map_[self_node_id_]; // SpillOver locally to figure out which infeasible tasks can be placed now @@ -2006,41 +2010,7 @@ void NodeManager::ScheduleTasks( for (const auto &task : local_queues_.GetTasks(TaskState::PLACEABLE)) { task_dependency_manager_.TaskPending(task); move_task_set.insert(task.GetTaskSpecification().TaskId()); - - // This block is used to suppress infeasible task warning. - bool suppress_warning = false; - const auto &required_resources = task.GetTaskSpecification().GetRequiredResources(); - const auto &resources_map = required_resources.GetResourceMap(); - const auto &it = resources_map.begin(); - // It is a hack to suppress infeasible task warning. - // If the first resource of a task requires this magic number, infeasible warning is - // suppressed. It is currently only used by placement group ready API. We don't want - // to have this in ray_config_def.h because the use case is very narrow, and we don't - // want to expose this anywhere. - double INFEASIBLE_TASK_SUPPRESS_MAGIC_NUMBER = 0.0101; - if (it != resources_map.end() && - it->second == INFEASIBLE_TASK_SUPPRESS_MAGIC_NUMBER) { - suppress_warning = true; - } - - // Push a warning to the task's driver that this task is currently infeasible. - if (!suppress_warning) { - // TODO(rkn): Define this constant somewhere else. - std::string type = "infeasible_task"; - std::ostringstream error_message; - error_message - << "The actor or task with ID " << task.GetTaskSpecification().TaskId() - << " cannot be scheduled right now. It requires " - << task.GetTaskSpecification().GetRequiredPlacementResources().ToString() - << " for placement, however the cluster currently cannot provide the requested " - "resources. The required resources may be added as autoscaling takes place " - "or placement groups are scheduled. Otherwise, consider reducing the " - "resource requirements of the task."; - auto error_data_ptr = - gcs::CreateErrorTableData(type, error_message.str(), current_time_ms(), - task.GetTaskSpecification().JobId()); - RAY_CHECK_OK(gcs_client_->Errors().AsyncReportJobError(error_data_ptr, nullptr)); - } + PublishInfeasibleTaskError(task); // Assert that this placeable task is not feasible locally (necessary but not // sufficient). RAY_CHECK(!task.GetTaskSpecification().GetRequiredPlacementResources().IsSubset( @@ -3198,6 +3168,42 @@ void NodeManager::RecordMetrics() { local_queues_.RecordMetrics(); } +void NodeManager::PublishInfeasibleTaskError(const Task &task) const { + // This block is used to suppress infeasible task warning. + bool suppress_warning = false; + const auto &required_resources = task.GetTaskSpecification().GetRequiredResources(); + const auto &resources_map = required_resources.GetResourceMap(); + const auto &it = resources_map.begin(); + // It is a hack to suppress infeasible task warning. + // If the first resource of a task requires this magic number, infeasible warning is + // suppressed. It is currently only used by placement group ready API. We don't want + // to have this in ray_config_def.h because the use case is very narrow, and we don't + // want to expose this anywhere. + double INFEASIBLE_TASK_SUPPRESS_MAGIC_NUMBER = 0.0101; + if (it != resources_map.end() && it->second == INFEASIBLE_TASK_SUPPRESS_MAGIC_NUMBER) { + suppress_warning = true; + } + + // Push a warning to the task's driver that this task is currently infeasible. + if (!suppress_warning) { + // TODO(rkn): Define this constant somewhere else. + std::string type = "infeasible_task"; + std::ostringstream error_message; + error_message + << "The actor or task with ID " << task.GetTaskSpecification().TaskId() + << " cannot be scheduled right now. It requires " + << task.GetTaskSpecification().GetRequiredPlacementResources().ToString() + << " for placement, however the cluster currently cannot provide the requested " + "resources. The required resources may be added as autoscaling takes place " + "or placement groups are scheduled. Otherwise, consider reducing the " + "resource requirements of the task."; + auto error_data_ptr = + gcs::CreateErrorTableData(type, error_message.str(), current_time_ms(), + task.GetTaskSpecification().JobId()); + RAY_CHECK_OK(gcs_client_->Errors().AsyncReportJobError(error_data_ptr, nullptr)); + } +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index f1e9ffad0..bdc8f5c47 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -649,6 +649,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Whether the resource is returned successfully. bool ReturnBundleResources(const BundleSpecification &bundle_spec); + /// Publish the infeasible task error to GCS so that drivers can subscribe to it and + /// print. + /// + /// \param task Task that is infeasible + void PublishInfeasibleTaskError(const Task &task) const; + /// ID of this node. NodeID self_node_id_; boost::asio::io_service &io_service_; diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 2245760c2..2590ed98f 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -172,7 +172,8 @@ int64_t ClusterResourceScheduler::IsSchedulable(const TaskRequest &task_req, int64_t ClusterResourceScheduler::GetBestSchedulableNode(const TaskRequest &task_req, bool actor_creation, - int64_t *total_violations) { + int64_t *total_violations, + bool *is_infeasible) { // Minimum number of soft violations across all nodes that can schedule the request. // We will pick the node with the smallest number of soft violations. int64_t min_violations = INT_MAX; @@ -248,20 +249,23 @@ int64_t ClusterResourceScheduler::GetBestSchedulableNode(const TaskRequest &task best_node = node.first; } if (violations == 0) { - *total_violations = 0; - return best_node; + // If violation is 0, we can schedule the task. So just break the loop + break; } } *total_violations = min_violations; + // If there's no best node, and the task is not feasible locally, + // it means the task is infeasible. + *is_infeasible = best_node == -1 && !local_node_feasible; return best_node; } std::string ClusterResourceScheduler::GetBestSchedulableNode( const std::unordered_map &task_resources, bool actor_creation, - int64_t *total_violations) { + int64_t *total_violations, bool *is_infeasible) { TaskRequest task_request = ResourceMapToTaskRequest(string_to_int_map_, task_resources); - int64_t node_id = - GetBestSchedulableNode(task_request, actor_creation, total_violations); + int64_t node_id = GetBestSchedulableNode(task_request, actor_creation, total_violations, + is_infeasible); std::string id_string; if (node_id == -1) { diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.h b/src/ray/raylet/scheduling/cluster_resource_scheduler.h index dfaa61fbd..c4058e586 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.h +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.h @@ -132,11 +132,13 @@ class ClusterResourceScheduler { /// \param violations: The number of soft constraint violations associated /// with the node returned by this function (assuming /// a node that can schedule task_req is found). + /// \param is_infeasible[in]: It is set true if the task is not schedulable because it + /// is infeasible. /// /// \return -1, if no node can schedule the current request; otherwise, /// return the ID of a node that can schedule the task request. int64_t GetBestSchedulableNode(const TaskRequest &task_request, bool actor_creation, - int64_t *violations); + int64_t *violations, bool *is_infeasible); /// Similar to /// int64_t GetBestSchedulableNode(const TaskRequest &task_request, int64_t @@ -147,7 +149,7 @@ class ClusterResourceScheduler { // task request. std::string GetBestSchedulableNode( const std::unordered_map &task_request, bool actor_creation, - int64_t *violations); + int64_t *violations, bool *is_infeasible); /// Return resources associated to the given node_id in ret_resources. /// If node_id not found, return false; otherwise return true. diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc index a43ed2e4c..37bf4b3da 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc @@ -346,8 +346,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingUpdateAvailableResourcesTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_EQ(node_id, 1); ASSERT_TRUE(violations > 0); @@ -446,8 +447,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, EmptyIntVector, EmptyFixedPointVector, EmptyBoolVector, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_EQ(node_id, -1); } // Predefined resources, soft constraint violation @@ -458,8 +460,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, EmptyIntVector, EmptyFixedPointVector, EmptyBoolVector, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations > 0); } @@ -472,8 +475,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, EmptyIntVector, EmptyFixedPointVector, EmptyBoolVector, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations == 0); } @@ -488,8 +492,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id == -1); } // Custom resources, soft constraint violation. @@ -503,8 +508,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations > 0); } @@ -519,8 +525,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations == 0); } @@ -535,8 +542,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id == -1); } // Custom resource missing, soft constraint violation. @@ -550,8 +558,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, EmptyIntVector); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations > 0); } @@ -567,8 +576,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, placement_hints); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations > 0); } @@ -584,8 +594,9 @@ TEST_F(ClusterResourceSchedulerTest, SchedulingTaskRequestTest) { initTaskRequest(task_req, pred_demands, pred_soft, cust_ids, cust_demands, cust_soft, placement_hints); int64_t violations; - int64_t node_id = - cluster_resources.GetBestSchedulableNode(task_req, false, &violations); + bool is_infeasible; + int64_t node_id = cluster_resources.GetBestSchedulableNode( + task_req, false, &violations, &is_infeasible); ASSERT_TRUE(node_id != -1); ASSERT_TRUE(violations == 0); } @@ -1007,23 +1018,24 @@ TEST_F(ClusterResourceSchedulerTest, TestAlwaysSpillInfeasibleTask) { // No feasible nodes. int64_t total_violations; - ASSERT_EQ( - cluster_resources.GetBestSchedulableNode(resource_spec, false, &total_violations), - ""); + bool is_infeasible; + ASSERT_EQ(cluster_resources.GetBestSchedulableNode(resource_spec, false, + &total_violations, &is_infeasible), + ""); // Feasible remote node, but doesn't currently have resources available. We // should spill there. cluster_resources.AddOrUpdateNode("remote_feasible", resource_spec, {{"CPU", 0.}}); - ASSERT_EQ( - cluster_resources.GetBestSchedulableNode(resource_spec, false, &total_violations), - "remote_feasible"); + ASSERT_EQ(cluster_resources.GetBestSchedulableNode(resource_spec, false, + &total_violations, &is_infeasible), + "remote_feasible"); // Feasible remote node, and it currently has resources available. We should // prefer to spill there. cluster_resources.AddOrUpdateNode("remote_available", resource_spec, resource_spec); - ASSERT_EQ( - cluster_resources.GetBestSchedulableNode(resource_spec, false, &total_violations), - "remote_available"); + ASSERT_EQ(cluster_resources.GetBestSchedulableNode(resource_spec, false, + &total_violations, &is_infeasible), + "remote_available"); } TEST_F(ClusterResourceSchedulerTest, HeartbeatTest) { @@ -1156,18 +1168,22 @@ TEST_F(ClusterResourceSchedulerTest, TestDirtyLocalView) { {{"CPU", num_slots_available}}); auto data = std::make_shared(); int64_t t; + bool is_infeasible; for (int i = 0; i < 3; i++) { // Resource usage report tick should reset the remote node's resources. cluster_resources.FillResourceUsage(true, data); for (int j = 0; j < num_slots_available; j++) { - ASSERT_EQ(cluster_resources.GetBestSchedulableNode(task_spec, false, &t), + ASSERT_EQ(cluster_resources.GetBestSchedulableNode(task_spec, false, &t, + &is_infeasible), "remote"); // Allocate remote resources. ASSERT_TRUE(cluster_resources.AllocateRemoteTaskResources("remote", task_spec)); } // Our local view says there are not enough resources on the remote node to // schedule another task. - ASSERT_EQ(cluster_resources.GetBestSchedulableNode(task_spec, false, &t), ""); + ASSERT_EQ( + cluster_resources.GetBestSchedulableNode(task_spec, false, &t, &is_infeasible), + ""); ASSERT_FALSE( cluster_resources.AllocateLocalTaskResources(task_spec, task_allocation)); ASSERT_FALSE(cluster_resources.AllocateRemoteTaskResources("remote", task_spec)); @@ -1180,25 +1196,31 @@ TEST_F(ClusterResourceSchedulerTest, DynamicResourceTest) { std::unordered_map task_request = {{"CPU", 1}, {"custom123", 2}}; int64_t t; + bool is_infeasible; - std::string result = cluster_resources.GetBestSchedulableNode(task_request, false, &t); + std::string result = + cluster_resources.GetBestSchedulableNode(task_request, false, &t, &is_infeasible); ASSERT_TRUE(result.empty()); cluster_resources.AddLocalResource("custom123", 5); - result = cluster_resources.GetBestSchedulableNode(task_request, false, &t); + result = + cluster_resources.GetBestSchedulableNode(task_request, false, &t, &is_infeasible); ASSERT_FALSE(result.empty()); task_request["custom123"] = 6; - result = cluster_resources.GetBestSchedulableNode(task_request, false, &t); + result = + cluster_resources.GetBestSchedulableNode(task_request, false, &t, &is_infeasible); ASSERT_TRUE(result.empty()); cluster_resources.AddLocalResource("custom123", 5); - result = cluster_resources.GetBestSchedulableNode(task_request, false, &t); + result = + cluster_resources.GetBestSchedulableNode(task_request, false, &t, &is_infeasible); ASSERT_FALSE(result.empty()); cluster_resources.DeleteLocalResource("custom123"); - result = cluster_resources.GetBestSchedulableNode(task_request, false, &t); + result = + cluster_resources.GetBestSchedulableNode(task_request, false, &t, &is_infeasible); ASSERT_TRUE(result.empty()); } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 4b6e5e0d7..bc86e280f 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -12,21 +12,26 @@ ClusterTaskManager::ClusterTaskManager( std::shared_ptr cluster_resource_scheduler, std::function fulfills_dependencies_func, std::function is_owner_alive, - NodeInfoGetter get_node_info) + NodeInfoGetter get_node_info, + std::function announce_infeasible_task) : self_node_id_(self_node_id), cluster_resource_scheduler_(cluster_resource_scheduler), fulfills_dependencies_func_(fulfills_dependencies_func), is_owner_alive_(is_owner_alive), get_node_info_(get_node_info), + announce_infeasible_task_(announce_infeasible_task), max_resource_shapes_per_load_report_( RayConfig::instance().max_resource_shapes_per_load_report()), report_worker_backlog_(RayConfig::instance().report_worker_backlog()) {} bool ClusterTaskManager::SchedulePendingTasks() { + // Always try to schedule infeasible tasks in case they are now feasible. + TryLocalInfeasibleTaskScheduling(); bool did_schedule = false; for (auto shapes_it = tasks_to_schedule_.begin(); shapes_it != tasks_to_schedule_.end();) { auto &work_queue = shapes_it->second; + bool is_infeasible = false; for (auto work_it = work_queue.begin(); work_it != work_queue.end();) { // Check every task in task_to_schedule queue to see // whether it can be scheduled. This avoids head-of-line @@ -39,33 +44,46 @@ bool ClusterTaskManager::SchedulePendingTasks() { << task.GetTaskSpecification().TaskId(); auto placement_resources = task.GetTaskSpecification().GetRequiredPlacementResources().GetResourceMap(); + // This argument is used to set violation, which is an unsupported feature now. int64_t _unused; - // TODO (Alex): We should distinguish between infeasible tasks and a fully - // utilized cluster. std::string node_id_string = cluster_resource_scheduler_->GetBestSchedulableNode( placement_resources, task.GetTaskSpecification().IsActorCreationTask(), - &_unused); + &_unused, &is_infeasible); + + // There is no node that has available resources to run the request. + // Move on to the next shape. if (node_id_string.empty()) { - // There is no node that has available resources to run the request. - // Move on to the next shape. - RAY_LOG(DEBUG) << "No feasible node found for task " - << task.GetTaskSpecification().TaskId(); + RAY_LOG(DEBUG) << "No node found to schedule a task " + << task.GetTaskSpecification().TaskId() << " is infeasible?" + << is_infeasible; break; - } else { - if (node_id_string == self_node_id_.Binary()) { - // Warning: WaitForTaskArgsRequests must execute (do not let it short - // circuit if did_schedule is true). - bool task_scheduled = WaitForTaskArgsRequests(work); - did_schedule = task_scheduled || did_schedule; - } else { - // Should spill over to a different node. - NodeID node_id = NodeID::FromBinary(node_id_string); - Spillback(node_id, work); - } - work_it = work_queue.erase(work_it); } + + if (node_id_string == self_node_id_.Binary()) { + // Warning: WaitForTaskArgsRequests must execute (do not let it short + // circuit if did_schedule is true). + bool task_scheduled = WaitForTaskArgsRequests(work); + did_schedule = task_scheduled || did_schedule; + } else { + // Should spill over to a different node. + NodeID node_id = NodeID::FromBinary(node_id_string); + Spillback(node_id, work); + } + work_it = work_queue.erase(work_it); } - if (work_queue.empty()) { + + if (is_infeasible) { + RAY_CHECK(!work_queue.empty()); + // Only announce the first item as infeasible. + auto &work_queue = shapes_it->second; + const auto &work = work_queue[0]; + const Task task = std::get<0>(work); + announce_infeasible_task_(task); + + // TODO(sang): Use a shared pointer deque to reduce copy overhead. + infeasible_tasks_[shapes_it->first] = shapes_it->second; + shapes_it = tasks_to_schedule_.erase(shapes_it); + } else if (work_queue.empty()) { shapes_it = tasks_to_schedule_.erase(shapes_it); } else { shapes_it++; @@ -172,9 +190,12 @@ bool ClusterTaskManager::AttemptDispatchWork(const Work &work, // Spill at most one task from this queue, then move on to the next // queue. int64_t _unused; + bool is_infeasible; auto placement_resources = spec.GetRequiredPlacementResources().GetResourceMap(); std::string node_id_string = cluster_resource_scheduler_->GetBestSchedulableNode( - placement_resources, spec.IsActorCreationTask(), &_unused); + placement_resources, spec.IsActorCreationTask(), &_unused, &is_infeasible); + RAY_CHECK(!is_infeasible) + << "Task cannot be infeasible when it is about to be dispatched"; if (node_id_string != self_node_id_.Binary() && !node_id_string.empty()) { NodeID node_id = NodeID::FromBinary(node_id_string); Spillback(node_id, work); @@ -201,7 +222,13 @@ void ClusterTaskManager::QueueTask(const Task &task, rpc::RequestWorkerLeaseRepl RAY_LOG(DEBUG) << "Queuing task " << task.GetTaskSpecification().TaskId(); Work work = std::make_tuple(task, reply, callback); const auto &scheduling_class = task.GetTaskSpecification().GetSchedulingClass(); - tasks_to_schedule_[scheduling_class].push_back(work); + // If the scheduling class is infeasible, just add the work to the infeasible queue + // directly. + if (infeasible_tasks_.count(scheduling_class) > 0) { + infeasible_tasks_[scheduling_class].push_back(work); + } else { + tasks_to_schedule_[scheduling_class].push_back(work); + } AddToBacklogTracker(task); } @@ -236,6 +263,8 @@ void ReplyCancelled(Work &work) { } bool ClusterTaskManager::CancelTask(const TaskID &task_id) { + // TODO(sang): There are lots of repetitive code around task backlogs. We should + // refactor them. for (auto shapes_it = tasks_to_schedule_.begin(); shapes_it != tasks_to_schedule_.end(); shapes_it++) { auto &work_queue = shapes_it->second; @@ -270,6 +299,23 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { } } + for (auto shapes_it = infeasible_tasks_.begin(); shapes_it != infeasible_tasks_.end(); + shapes_it++) { + auto &work_queue = shapes_it->second; + for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { + const auto &task = std::get<0>(*work_it); + if (task.GetTaskSpecification().TaskId() == task_id) { + RemoveFromBacklogTracker(task); + ReplyCancelled(*work_it); + work_queue.erase(work_it); + if (work_queue.empty()) { + infeasible_tasks_.erase(shapes_it); + } + return true; + } + } + } + auto iter = waiting_tasks_.find(task_id); if (iter != waiting_tasks_.end()) { const auto &task = std::get<0>(iter->second); @@ -369,13 +415,8 @@ void ClusterTaskManager::FillResourceUsage( // If a task is not feasible on the local node it will not be feasible on any other // node in the cluster. See the scheduling policy defined by // ClusterResourceScheduler::GetBestSchedulableNode for more details. - if (cluster_resource_scheduler_->IsLocallyFeasible(resources)) { - int num_ready = by_shape_entry->num_ready_requests_queued(); - by_shape_entry->set_num_ready_requests_queued(num_ready + count); - } else { - int num_infeasible = by_shape_entry->num_infeasible_requests_queued(); - by_shape_entry->set_num_infeasible_requests_queued(num_infeasible + count); - } + int num_ready = by_shape_entry->num_ready_requests_queued(); + by_shape_entry->set_num_ready_requests_queued(num_ready + count); auto backlog_it = backlog_tracker_.find(scheduling_class); if (backlog_it != backlog_tracker_.end()) { by_shape_entry->set_backlog_size(backlog_it->second); @@ -417,6 +458,45 @@ void ClusterTaskManager::FillResourceUsage( by_shape_entry->set_backlog_size(backlog_it->second); } } + + for (const auto &pair : infeasible_tasks_) { + const auto &scheduling_class = pair.first; + if (scheduling_class == one_cpu_scheduling_cls) { + continue; + } + if (num_reported++ >= max_resource_shapes_per_load_report_ && + max_resource_shapes_per_load_report_ >= 0) { + // TODO (Alex): It's possible that we skip a different scheduling key which contains + // the same resources. + break; + } + const auto &resources = + TaskSpecification::GetSchedulingClassDescriptor(scheduling_class) + .GetResourceMap(); + const auto &queue = pair.second; + const auto &count = queue.size(); + + auto by_shape_entry = resource_load_by_shape->Add(); + for (const auto &resource : resources) { + // Add to `resource_loads`. + const auto &label = resource.first; + const auto &quantity = resource.second; + (*resource_loads)[label] += quantity * count; + + // Add to `resource_load_by_shape`. + (*by_shape_entry->mutable_shape())[label] = quantity; + } + + // If a task is not feasible on the local node it will not be feasible on any other + // node in the cluster. See the scheduling policy defined by + // ClusterResourceScheduler::GetBestSchedulableNode for more details. + int num_infeasible = by_shape_entry->num_infeasible_requests_queued(); + by_shape_entry->set_num_infeasible_requests_queued(num_infeasible + count); + auto backlog_it = backlog_tracker_.find(scheduling_class); + if (backlog_it != backlog_tracker_.end()) { + by_shape_entry->set_backlog_size(backlog_it->second); + } + } } std::string ClusterTaskManager::DebugString() const { @@ -425,12 +505,50 @@ std::string ClusterTaskManager::DebugString() const { buffer << "Schedule queue length: " << tasks_to_schedule_.size() << "\n"; buffer << "Dispatch queue length: " << tasks_to_dispatch_.size() << "\n"; buffer << "Waiting tasks size: " << waiting_tasks_.size() << "\n"; + buffer << "infeasible queue length size: " << infeasible_tasks_.size() << "\n"; buffer << "cluster_resource_scheduler state: " << cluster_resource_scheduler_->DebugString() << "\n"; buffer << "=================================================="; return buffer.str(); } +void ClusterTaskManager::TryLocalInfeasibleTaskScheduling() { + for (auto shapes_it = infeasible_tasks_.begin(); + shapes_it != infeasible_tasks_.end();) { + auto &work_queue = shapes_it->second; + RAY_CHECK(!work_queue.empty()) + << "Empty work queue shouldn't have been added as a infeasible shape."; + // We only need to check the first item because every task has the same shape. + // If the first entry is infeasible, that means everything else is the same. + const auto work = work_queue[0]; + Task task = std::get<0>(work); + RAY_LOG(DEBUG) << "Check if the infeasible task is schedulable in any node. task_id:" + << task.GetTaskSpecification().TaskId(); + auto placement_resources = + task.GetTaskSpecification().GetRequiredPlacementResources().GetResourceMap(); + // This argument is used to set violation, which is an unsupported feature now. + int64_t _unused; + bool is_infeasible; + std::string node_id_string = cluster_resource_scheduler_->GetBestSchedulableNode( + placement_resources, task.GetTaskSpecification().IsActorCreationTask(), &_unused, + &is_infeasible); + + // There is no node that has available resources to run the request. + // Move on to the next shape. + if (is_infeasible) { + RAY_LOG(DEBUG) << "No feasible node found for task " + << task.GetTaskSpecification().TaskId(); + shapes_it++; + } else { + RAY_LOG(DEBUG) << "Infeasible task of task id " + << task.GetTaskSpecification().TaskId() + << " is now feasible. Move the entry back to tasks_to_schedule_"; + tasks_to_schedule_[shapes_it->first] = shapes_it->second; + shapes_it = infeasible_tasks_.erase(shapes_it); + } + } +} + void ClusterTaskManager::Dispatch( std::shared_ptr worker, std::unordered_map> &leased_workers, @@ -492,7 +610,6 @@ void ClusterTaskManager::Dispatch( } } } - // Send the result back. send_reply_callback(); } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 6ec2db994..995273ed5 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -49,12 +49,14 @@ class ClusterTaskManager { /// \param fulfills_dependencies_func: Returns true if all of a task's /// dependencies are fulfilled. /// \param is_owner_alive: A callback which returns if the owner process is alive - /// (according to our ownership model). \param gcs_client: A gcs client. + /// (according to our ownership model). + /// \param gcs_client: A gcs client. ClusterTaskManager(const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, std::function fulfills_dependencies_func, std::function is_owner_alive, - NodeInfoGetter get_node_info); + NodeInfoGetter get_node_info, + std::function announce_infeasible_task); /// (Step 2) For each task in tasks_to_schedule_, pick a node in the system /// (local or remote) that has enough resources available to run the task, if @@ -122,11 +124,20 @@ class ClusterTaskManager { bool AttemptDispatchWork(const Work &work, std::shared_ptr &worker, bool *worker_leased); + /// Reiterate all local infeasible tasks and register them to task_to_schedule_ if it + /// becomes feasible to schedule. + void TryLocalInfeasibleTaskScheduling(); + const NodeID &self_node_id_; std::shared_ptr cluster_resource_scheduler_; + /// Function to make task dependencies to be local. std::function fulfills_dependencies_func_; + /// Function to check if the owner is alive on a given node. std::function is_owner_alive_; + /// Function to get the node information of a given node id. NodeInfoGetter get_node_info_; + /// Function to announce infeasible task to GCS. + std::function announce_infeasible_task_; const int max_resource_shapes_per_load_report_; const bool report_worker_backlog_; @@ -143,6 +154,10 @@ class ClusterTaskManager { /// Tasks move from waiting -> dispatch. absl::flat_hash_map waiting_tasks_; + /// Queue of lease requests that are infeasible. + /// Tasks go between scheduling <-> infeasible. + std::unordered_map> infeasible_tasks_; + /// Track the cumulative backlog of all workers requesting a lease to this raylet. std::unordered_map backlog_tracker_; diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index ddda8fed5..023390632 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -101,6 +101,7 @@ class ClusterTaskManagerTest : public ::testing::Test { dependencies_fulfilled_(true), is_owner_alive_(true), node_info_calls_(0), + announce_infeasible_task_calls_(0), task_manager_(id_, scheduler_, [this](const Task &_task) { fulfills_dependencies_calls_++; @@ -112,7 +113,8 @@ class ClusterTaskManagerTest : public ::testing::Test { [this](const NodeID &node_id) { node_info_calls_++; return node_info_[node_id]; - }) {} + }, + [this](const Task &task) { announce_infeasible_task_calls_++; }) {} void SetUp() {} @@ -141,6 +143,7 @@ class ClusterTaskManagerTest : public ::testing::Test { bool is_owner_alive_; int node_info_calls_; + int announce_infeasible_task_calls_; std::unordered_map> node_info_; ClusterTaskManager task_manager_; @@ -371,6 +374,43 @@ TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { ASSERT_EQ(leased_workers_.size(), 1); } +TEST_F(ClusterTaskManagerTest, TaskCancelInfeasibleTask) { + /* Make sure cancelTask works for infeasible tasks */ + std::shared_ptr worker = + std::make_shared(WorkerID::FromRandom(), 1234); + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + + Task task = CreateTask({{ray::kCPU_ResourceLabel, 12}}); + rpc::RequestWorkerLeaseReply reply; + + bool callback_called = false; + bool *callback_called_ptr = &callback_called; + auto callback = [callback_called_ptr]() { *callback_called_ptr = true; }; + + task_manager_.QueueTask(task, &reply, callback); + + // Task is now queued so cancellation works. + ASSERT_TRUE(task_manager_.CancelTask(task.GetTaskSpecification().TaskId())); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + // Task will not execute. + ASSERT_TRUE(callback_called); + ASSERT_TRUE(reply.canceled()); + ASSERT_EQ(leased_workers_.size(), 0); + ASSERT_EQ(pool_.workers.size(), 1); + + // Althoug the feasible node is added, task shouldn't be executed because it is + // cancelled. + auto remote_node_id = NodeID::FromRandom(); + AddNode(remote_node_id, 12); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_TRUE(callback_called); + ASSERT_TRUE(reply.canceled()); + ASSERT_EQ(leased_workers_.size(), 0); + ASSERT_EQ(pool_.workers.size(), 1); +} + TEST_F(ClusterTaskManagerTest, HeartbeatTest) { std::shared_ptr worker = std::make_shared(WorkerID::FromRandom(), 1234); @@ -570,6 +610,72 @@ TEST_F(ClusterTaskManagerTest, OwnerDeadTest) { ASSERT_EQ(pool_.workers.size(), 1); } +TEST_F(ClusterTaskManagerTest, TestInfeasibleTaskWarning) { + /* + Test if infeasible tasks warnings are printed. + */ + // Create an infeasible task. + Task task = CreateTask({{ray::kCPU_ResourceLabel, 12}}); + rpc::RequestWorkerLeaseReply reply; + std::shared_ptr callback_occurred = std::make_shared(false); + auto callback = [callback_occurred]() { *callback_occurred = true; }; + task_manager_.QueueTask(task, &reply, callback); + task_manager_.SchedulePendingTasks(); + ASSERT_EQ(announce_infeasible_task_calls_, 1); + + // Infeasible warning shouldn't be reprinted when the previous task is still infeasible + // after adding a new node. + AddNode(NodeID::FromRandom(), 8); + task_manager_.SchedulePendingTasks(); + std::shared_ptr worker = + std::make_shared(WorkerID::FromRandom(), 1234); + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + // Task shouldn't be scheduled yet. + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(announce_infeasible_task_calls_, 1); + ASSERT_FALSE(*callback_occurred); + ASSERT_EQ(leased_workers_.size(), 0); + ASSERT_EQ(pool_.workers.size(), 1); + + // Now we have a node that is feasible to schedule the task. Make sure the infeasible + // task is spillbacked properly. + auto remote_node_id = NodeID::FromRandom(); + AddNode(remote_node_id, 12); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + // Make sure nothing happens locally. + ASSERT_EQ(announce_infeasible_task_calls_, 1); + ASSERT_TRUE(*callback_occurred); + ASSERT_EQ(leased_workers_.size(), 0); + ASSERT_EQ(pool_.workers.size(), 1); + // Make sure the spillback callback is called. + ASSERT_EQ(reply.retry_at_raylet_address().raylet_id(), remote_node_id.Binary()); +} + +TEST_F(ClusterTaskManagerTest, TestMultipleInfeasibleTasksWarnOnce) { + /* + Test infeasible warning is printed only once when the same shape is queued again. + */ + + // Make sure the first infeasible task announces warning. + Task task = CreateTask({{ray::kCPU_ResourceLabel, 12}}); + rpc::RequestWorkerLeaseReply reply; + std::shared_ptr callback_occurred = std::make_shared(false); + auto callback = [callback_occurred]() { *callback_occurred = true; }; + task_manager_.QueueTask(task, &reply, callback); + task_manager_.SchedulePendingTasks(); + ASSERT_EQ(announce_infeasible_task_calls_, 1); + + // Make sure the same shape infeasible task won't be announced. + Task task2 = CreateTask({{ray::kCPU_ResourceLabel, 12}}); + rpc::RequestWorkerLeaseReply reply2; + std::shared_ptr callback_occurred2 = std::make_shared(false); + auto callback2 = [callback_occurred2]() { *callback_occurred2 = true; }; + task_manager_.QueueTask(task2, &reply2, callback2); + task_manager_.SchedulePendingTasks(); + ASSERT_EQ(announce_infeasible_task_calls_, 1); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); From 12231ec2a6b623feadf9b695fdb08f987bc4ea45 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Thu, 17 Dec 2020 14:24:23 +0800 Subject: [PATCH 08/88] Optimize heartbeat manager initialization (#12911) --- src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc | 8 ++++++++ src/ray/gcs/gcs_server/gcs_heartbeat_manager.h | 6 ++++++ src/ray/gcs/gcs_server/gcs_server.cc | 5 ++--- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc index b16383097..64806982b 100644 --- a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.cc @@ -34,6 +34,14 @@ GcsHeartbeatManager::GcsHeartbeatManager( })); } +void GcsHeartbeatManager::Initialize(const GcsInitData &gcs_init_data) { + for (const auto &item : gcs_init_data.Nodes()) { + if (item.second.state() == rpc::GcsNodeInfo::ALIVE) { + heartbeats_.emplace(item.first, num_heartbeats_timeout_); + } + } +} + void GcsHeartbeatManager::Start() { io_service_.post([this] { if (!is_started_) { diff --git a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h index 580daa6f3..1febd3ee9 100644 --- a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h +++ b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h @@ -43,6 +43,12 @@ class GcsHeartbeatManager : public rpc::HeartbeatInfoHandler { rpc::ReportHeartbeatReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Initialize with the gcs tables data synchronously. + /// This should be called when GCS server restarts after a failure. + /// + /// \param gcs_init_data. + void Initialize(const GcsInitData &gcs_init_data); + /// Start node failure detect loop. void Start(); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index bf8ca289d..23a12f6ec 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -151,9 +151,8 @@ void GcsServer::InitGcsHeartbeatManager(const GcsInitData &gcs_init_data) { main_service_.post( [this, node_id] { return gcs_node_manager_->OnNodeFailure(node_id); }); }); - for (const auto &node : gcs_init_data.Nodes()) { - gcs_heartbeat_manager_->AddNode(node.first); - } + // Initialize by gcs tables data. + gcs_heartbeat_manager_->Initialize(gcs_init_data); // Register service. heartbeat_info_service_.reset(new rpc::HeartbeatInfoGrpcService( heartbeat_manager_io_service_, *gcs_heartbeat_manager_)); From 40032541dcbb920aea86ad797497a6aa4cda1827 Mon Sep 17 00:00:00 2001 From: Yi Cheng <74173148+ahbone@users.noreply.github.com> Date: Wed, 16 Dec 2020 23:44:28 -0800 Subject: [PATCH 09/88] [core] Introduce fetch_local to `ray.wait` (#12526) --- python/ray/_raylet.pyx | 4 +- python/ray/includes/libcoreworker.pxd | 3 +- python/ray/tests/test_basic.py | 36 ++++++++++++++++ python/ray/worker.py | 8 +++- src/ray/core_worker/core_worker.cc | 21 +++++---- src/ray/core_worker/core_worker.h | 2 +- ...io_ray_runtime_object_NativeObjectStore.cc | 5 ++- .../io_ray_runtime_object_NativeObjectStore.h | 2 +- .../store_provider/plasma_store_provider.cc | 8 ++-- src/ray/core_worker/test/core_worker_test.cc | 4 +- src/ray/object_manager/object_manager.cc | 15 +++---- src/ray/object_manager/object_manager.h | 11 ++--- .../test/object_manager_test.cc | 43 ++++++++++++++----- src/ray/raylet/format/node_manager.fbs | 2 - src/ray/raylet/node_manager.cc | 11 +++-- src/ray/raylet/node_manager.h | 2 +- src/ray/raylet_client/raylet_client.cc | 6 +-- src/ray/raylet_client/raylet_client.h | 6 +-- streaming/src/test/queue_tests_base.h | 2 +- 19 files changed, 122 insertions(+), 69 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 47215149a..8a216c7cf 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1004,7 +1004,7 @@ cdef class CoreWorker: return c_object_id.Binary() def wait(self, object_refs, int num_returns, int64_t timeout_ms, - TaskID current_task_id): + TaskID current_task_id, c_bool fetch_local): cdef: c_vector[CObjectID] wait_ids c_vector[c_bool] results @@ -1013,7 +1013,7 @@ cdef class CoreWorker: wait_ids = ObjectRefsToVector(object_refs) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().Wait( - wait_ids, num_returns, timeout_ms, &results)) + wait_ids, num_returns, timeout_ms, &results, fetch_local)) assert len(results) == len(object_refs) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index abf1290b9..7394f68b5 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -179,7 +179,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool plasma_objects_only) CRayStatus Contains(const CObjectID &object_id, c_bool *has_object) CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects, - int64_t timeout_ms, c_vector[c_bool] *results) + int64_t timeout_ms, c_vector[c_bool] *results, + c_bool fetch_local) CRayStatus Delete(const c_vector[CObjectID] &object_ids, c_bool local_only, c_bool delete_creating_tasks) CRayStatus TriggerGlobalGC() diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 551d51f7f..d0e98972a 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -371,6 +371,42 @@ def test_ray_options(shutdown_only): assert without_options != with_options +@pytest.mark.parametrize( + "ray_start_cluster_head", [{ + "num_cpus": 0, + "object_store_memory": 75 * 1024 * 1024, + }], + indirect=True) +def test_fetch_local(ray_start_cluster_head): + cluster = ray_start_cluster_head + cluster.add_node(num_cpus=2, object_store_memory=75 * 1024 * 1024) + + signal_actor = ray.test_utils.SignalActor.remote() + + @ray.remote + def put(): + ray.wait([signal_actor.wait.remote()]) + return np.random.rand(5 * 1024 * 1024) # 40 MB data + + local_ref = ray.put(np.random.rand(5 * 1024 * 1024)) + remote_ref = put.remote() + # Data is not ready in any node + (ready_ref, remaining_ref) = ray.wait( + [remote_ref], timeout=2, fetch_local=False) + assert (0, 1) == (len(ready_ref), len(remaining_ref)) + ray.wait([signal_actor.send.remote()]) + + # Data is ready in some node, but not local node. + (ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=False) + assert (1, 0) == (len(ready_ref), len(remaining_ref)) + (ready_ref, remaining_ref) = ray.wait( + [remote_ref], timeout=2, fetch_local=True) + assert (0, 1) == (len(ready_ref), len(remaining_ref)) + del local_ref + (ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=True) + assert (1, 0) == (len(ready_ref), len(remaining_ref)) + + def test_nested_functions(ray_start_shared_local_modes): # Make sure that remote functions can use other values that are defined # after the remote function but before the first function invocation. diff --git a/python/ray/worker.py b/python/ray/worker.py index cc231f7fa..495478ad7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1417,7 +1417,7 @@ def put(value): blocking_wait_inside_async_warned = False -def wait(object_refs, *, num_returns=1, timeout=None): +def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True): """Return a list of IDs that are ready and a list of IDs that are not. If timeout is set, the function returns either when the requested number of @@ -1445,6 +1445,11 @@ def wait(object_refs, *, num_returns=1, timeout=None): num_returns (int): The number of object refs that should be returned. timeout (float): The maximum amount of time in seconds to wait before returning. + fetch_local (bool): If True, wait for the object to be downloaded onto + the local node before returning it as ready. If False, ray.wait() + will not trigger fetching of objects to the local node and will + return immediately once the object is available anywhere in the + cluster. Returns: A list of object refs that are ready and a list of the remaining object @@ -1507,6 +1512,7 @@ def wait(object_refs, *, num_returns=1, timeout=None): num_returns, timeout_milliseconds, worker.current_task_id, + fetch_local, ) return ready_ids, remaining_ids diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 2aba250a5..9bd4bf1f4 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1061,7 +1061,8 @@ void RetryObjectInPlasmaErrors(std::shared_ptr &memory_st } Status CoreWorker::Wait(const std::vector &ids, int num_objects, - int64_t timeout_ms, std::vector *results) { + int64_t timeout_ms, std::vector *results, + bool fetch_local) { results->resize(ids.size(), false); if (num_objects <= 0 || num_objects > static_cast(ids.size())) { @@ -1082,19 +1083,21 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, memory_object_ids, std::min(static_cast(memory_object_ids.size()), num_objects), timeout_ms, worker_context_, &ready)); - RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids, - plasma_object_ids, ready); RAY_CHECK(static_cast(ready.size()) <= num_objects); if (timeout_ms > 0) { timeout_ms = std::max(0, static_cast(timeout_ms - (current_time_ms() - start_time))); } - if (static_cast(ready.size()) < num_objects && plasma_object_ids.size() > 0) { - RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( - plasma_object_ids, - std::min(static_cast(plasma_object_ids.size()), - num_objects - static_cast(ready.size())), - timeout_ms, worker_context_, &ready)); + if (fetch_local) { + RetryObjectInPlasmaErrors(memory_store_, worker_context_, memory_object_ids, + plasma_object_ids, ready); + if (static_cast(ready.size()) < num_objects && plasma_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( + plasma_object_ids, + std::min(static_cast(plasma_object_ids.size()), + num_objects - static_cast(ready.size())), + timeout_ms, worker_context_, &ready)); + } } RAY_CHECK(static_cast(ready.size()) <= num_objects); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 5e2770b71..4ecbe04d9 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -564,7 +564,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[out] results A bitset that indicates each object has appeared or not. /// \return Status. Status Wait(const std::vector &object_ids, const int num_objects, - const int64_t timeout_ms, std::vector *results); + const int64_t timeout_ms, std::vector *results, bool fetch_local); /// Delete a list of objects from the plasma object store. /// diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index f14853002..b62b19818 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -100,7 +100,8 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet } JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs) { + JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs, + jboolean fetch_local) { std::vector object_ids; JavaListToNativeVector( env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { @@ -108,7 +109,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai }); std::vector results; auto status = ray::CoreWorkerProcess::GetCoreWorker().Wait( - object_ids, (int)numObjects, (int64_t)timeoutMs, &results); + object_ids, (int)numObjects, (int64_t)timeoutMs, &results, (bool)fetch_local); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { jobject java_item = diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 0da1aba92..4e11c0456 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -55,7 +55,7 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeGet * Signature: (Ljava/util/List;IJ)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *, jclass, jobject, jint, jlong); + JNIEnv *, jclass, jobject, jint, jlong, jboolean); /* * Class: io_ray_runtime_object_NativeObjectStore diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 2faa8be51..5dca72612 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -335,10 +335,10 @@ Status CoreWorkerPlasmaStoreProvider::Wait( RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); } const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector); - RAY_RETURN_NOT_OK(raylet_client_->Wait( - id_vector, owner_addresses, num_objects, call_timeout, /*wait_local*/ true, - /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), ctx.GetCurrentTaskID(), - &result_pair)); + RAY_RETURN_NOT_OK( + raylet_client_->Wait(id_vector, owner_addresses, num_objects, call_timeout, + /*mark_worker_blocked*/ !ctx.CurrentTaskIsDirectCall(), + ctx.GetCurrentTaskID(), &result_pair)); if (result_pair.first.size() >= static_cast(num_objects)) { should_break = true; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 8591fa5df..f06e1a7f4 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -811,11 +811,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) { all_ids.push_back(non_existent_id); std::vector wait_results; - RAY_CHECK_OK(core_worker.Wait(all_ids, 2, -1, &wait_results)); + RAY_CHECK_OK(core_worker.Wait(all_ids, 2, -1, &wait_results, true)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); - RAY_CHECK_OK(core_worker.Wait(all_ids, 3, 100, &wait_results)); + RAY_CHECK_OK(core_worker.Wait(all_ids, 3, 100, &wait_results, true)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 3a31da864..3d777be12 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -432,11 +432,11 @@ void ObjectManager::CancelPull(const ObjectID &object_id) { ray::Status ObjectManager::Wait( const std::vector &object_ids, const std::unordered_map &owner_addresses, int64_t timeout_ms, - uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { + uint64_t num_required_objects, const WaitCallback &callback) { UniqueID wait_id = UniqueID::FromRandom(); RAY_LOG(DEBUG) << "Wait request " << wait_id << " on " << self_node_id_; RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, owner_addresses, timeout_ms, - num_required_objects, wait_local, callback)); + num_required_objects, callback)); RAY_RETURN_NOT_OK(LookupRemainingWaitObjects(wait_id)); // LookupRemainingWaitObjects invokes SubscribeRemainingWaitObjects once lookup has // been performed on all remaining objects. @@ -446,7 +446,7 @@ ray::Status ObjectManager::Wait( ray::Status ObjectManager::AddWaitRequest( const UniqueID &wait_id, const std::vector &object_ids, const std::unordered_map &owner_addresses, int64_t timeout_ms, - uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { + uint64_t num_required_objects, const WaitCallback &callback) { RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); RAY_CHECK(num_required_objects != 0); RAY_CHECK(num_required_objects <= object_ids.size()) @@ -462,7 +462,6 @@ ray::Status ObjectManager::AddWaitRequest( wait_state.owner_addresses = owner_addresses; wait_state.timeout_ms = timeout_ms; wait_state.num_required_objects = num_required_objects; - wait_state.wait_local = wait_local; for (const auto &object_id : object_ids) { if (local_objects_.count(object_id) > 0) { wait_state.found.insert(object_id); @@ -496,9 +495,7 @@ ray::Status ObjectManager::LookupRemainingWaitObjects(const UniqueID &wait_id) { auto &wait_state = active_wait_requests_.find(wait_id)->second; // Note that the object is guaranteed to be added to local_objects_ before // the notification is triggered. - bool remote_object_ready = !node_ids.empty() || !spilled_url.empty(); - if (local_objects_.count(lookup_object_id) > 0 || - (!wait_state.wait_local && remote_object_ready)) { + if (local_objects_.count(lookup_object_id) > 0) { wait_state.remaining.erase(lookup_object_id); wait_state.found.insert(lookup_object_id); } @@ -547,9 +544,7 @@ void ObjectManager::SubscribeRemainingWaitObjects(const UniqueID &wait_id) { auto &wait_state = object_id_wait_state->second; // Note that the object is guaranteed to be added to local_objects_ before // the notification is triggered. - bool remote_object_ready = !node_ids.empty() || !spilled_url.empty(); - if (local_objects_.count(subscribe_object_id) > 0 || - (!wait_state.wait_local && remote_object_ready)) { + if (local_objects_.count(subscribe_object_id) > 0) { RAY_LOG(DEBUG) << "Wait request " << wait_id << ": subscription notification received for object " << subscribe_object_id; diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index ff409eb18..9579df30e 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -254,13 +254,12 @@ class ObjectManager : public ObjectManagerInterface, /// \param timeout_ms The time in milliseconds to wait before invoking the callback. /// \param num_required_objects The minimum number of objects required before /// invoking the callback. - /// \param wait_local Whether to wait until objects arrive to this node's store. /// \param callback Invoked when either timeout_ms is satisfied OR num_ready_objects /// is satisfied. /// \return Status of whether the wait successfully initiated. ray::Status Wait(const std::vector &object_ids, const std::unordered_map &owner_addresses, - int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, + int64_t timeout_ms, uint64_t num_required_objects, const WaitCallback &callback); /// Free a list of objects from object store. @@ -299,8 +298,6 @@ class ObjectManager : public ObjectManagerInterface, callback(callback) {} /// The period of time to wait before invoking the callback. int64_t timeout_ms; - /// Whether to wait for objects to become local before returning. - bool wait_local; /// The timer used whenever wait_ms > 0. std::unique_ptr timeout_timer; /// The callback invoked when WaitCallback is complete. @@ -311,8 +308,7 @@ class ObjectManager : public ObjectManagerInterface, std::unordered_map owner_addresses; /// The objects that have not yet been found. std::unordered_set remaining; - /// The objects that have been found. Note that if wait_local is true, then - /// this will only contain objects that are in local_objects_ too. + /// The objects that have been found. std::unordered_set found; /// Objects that have been requested either by Lookup or Subscribe. std::unordered_set requested_objects; @@ -324,8 +320,7 @@ class ObjectManager : public ObjectManagerInterface, ray::Status AddWaitRequest( const UniqueID &wait_id, const std::vector &object_ids, const std::unordered_map &owner_addresses, - int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, - const WaitCallback &callback); + int64_t timeout_ms, uint64_t num_required_objects, const WaitCallback &callback); /// Lookup any remaining objects that are not local. This is invoked after /// the wait request is created and local objects are identified. diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 75f3b5c70..493127000 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -88,7 +88,7 @@ class TestObjectManagerBase : public ::testing::Test { socket_name_2 = TestSetupUtil::StartObjectStore(); unsigned int pull_timeout_ms = 1; - push_timeout_ms = 1000; + push_timeout_ms = 1500; // start first server gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", @@ -182,7 +182,9 @@ class TestObjectManagerBase : public ::testing::Test { class TestObjectManager : public TestObjectManagerBase { public: int current_wait_test = -1; - int num_connected_clients = 0; + int num_connected_clients_1 = 0; + int num_connected_clients_2 = 0; + std::atomic ready_cnt; NodeID node_id_1; NodeID node_id_2; @@ -197,10 +199,26 @@ class TestObjectManager : public TestObjectManagerBase { RAY_CHECK_OK(gcs_client_1->Nodes().AsyncSubscribeToNodeChange( [this](const NodeID &node_id, const GcsNodeInfo &data) { if (node_id == node_id_1 || node_id == node_id_2) { - num_connected_clients += 1; + num_connected_clients_1 += 1; } - if (num_connected_clients == 2) { - StartTests(); + if (num_connected_clients_1 == 2) { + ready_cnt += 1; + if (ready_cnt == 2) { + StartTests(); + } + } + }, + nullptr)); + RAY_CHECK_OK(gcs_client_2->Nodes().AsyncSubscribeToNodeChange( + [this](const NodeID &node_id, const GcsNodeInfo &data) { + if (node_id == node_id_1 || node_id == node_id_2) { + num_connected_clients_2 += 1; + } + if (num_connected_clients_2 == 2) { + ready_cnt += 1; + if (ready_cnt == 2) { + StartTests(); + } } }, nullptr)); @@ -261,8 +279,10 @@ class TestObjectManager : public TestObjectManagerBase { // object. ObjectID object_1 = WriteDataToClient(client2, data_size); ObjectID object_2 = WriteDataToClient(client2, data_size); - UniqueID sub_id = ray::UniqueID::FromRandom(); + server2->object_manager_.Push(object_1, gcs_client_1->Nodes().GetSelfId()); + server2->object_manager_.Push(object_2, gcs_client_1->Nodes().GetSelfId()); + UniqueID sub_id = ray::UniqueID::FromRandom(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, rpc::Address(), [this, sub_id, object_1, object_2](const ray::ObjectID &object_id, @@ -276,7 +296,7 @@ class TestObjectManager : public TestObjectManagerBase { void TestWaitWhileSubscribed(UniqueID sub_id, ObjectID object_1, ObjectID object_2) { int required_objects = 1; - int timeout_ms = 1000; + int timeout_ms = 1500; std::vector object_ids = {object_1, object_2}; boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); @@ -285,7 +305,7 @@ class TestObjectManager : public TestObjectManagerBase { RAY_CHECK_OK(server1->object_manager_.AddWaitRequest( wait_id, object_ids, std::unordered_map(), timeout_ms, - required_objects, false, + required_objects, [this, sub_id, object_1, object_ids, start_time]( const std::vector &found, const std::vector &remaining) { @@ -317,7 +337,7 @@ class TestObjectManager : public TestObjectManagerBase { TestWait(data_size, 5, 3, /*timeout_ms=*/0, false, false); } break; case 1: { - // Ensure timeout_ms = 1000 is handled correctly. + // Ensure timeout_ms = 1500 is handled correctly. // Out of 5 objects, we expect 3 ready objects and 2 remaining objects. TestWait(data_size, 5, 3, wait_timeout_ms, false, false); } break; @@ -348,6 +368,7 @@ class TestObjectManager : public TestObjectManagerBase { oid = WriteDataToClient(client1, data_size); } else { oid = WriteDataToClient(client2, data_size); + server2->object_manager_.Push(oid, gcs_client_1->Nodes().GetSelfId()); } object_ids.push_back(oid); } @@ -359,7 +380,7 @@ class TestObjectManager : public TestObjectManagerBase { boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); RAY_CHECK_OK(server1->object_manager_.Wait( object_ids, std::unordered_map(), timeout_ms, - required_objects, false, + required_objects, [this, object_ids, num_objects, timeout_ms, required_objects, start_time]( const std::vector &found, const std::vector &remaining) { @@ -398,7 +419,7 @@ class TestObjectManager : public TestObjectManagerBase { NextWaitTest(); } break; case 1: { - // Ensure lookup succeeds as expected when timeout_ms = 1000. + // Ensure lookup succeeds as expected when timeout_ms = 1500. ASSERT_TRUE(found.size() >= required_objects); ASSERT_TRUE(static_cast(found.size() + remaining.size()) == num_objects); NextWaitTest(); diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 162504eb5..c62754b75 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -208,8 +208,6 @@ table WaitRequest { num_ready_objects: int; // timeout timeout: long; - // Whether to wait until objects appear locally. - wait_local: bool; // False for direct call tasks. Blocking for those tasks is handled via the // NotifyDirectCallTaskBlocked/Unblocked IPCs. mark_worker_blocked: bool; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e86975ba0..a289900d4 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1523,9 +1523,6 @@ void NodeManager::ProcessWaitRequestMessage( // Read the data. auto message = flatbuffers::GetRoot(message_data); std::vector object_ids = from_flatbuf(*message->object_ids()); - int64_t wait_ms = message->timeout(); - uint64_t num_required_objects = static_cast(message->num_ready_objects()); - bool wait_local = message->wait_local(); const auto refs = FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); std::unordered_map owner_addresses; @@ -1551,9 +1548,11 @@ void NodeManager::ProcessWaitRequestMessage( AsyncResolveObjects(client, refs, current_task_id, /*ray_get=*/false, /*mark_worker_blocked*/ was_blocked); } - + int64_t wait_ms = message->timeout(); + uint64_t num_required_objects = static_cast(message->num_ready_objects()); + // TODO Remove in the future since it should have already be done in other place ray::Status status = object_manager_.Wait( - object_ids, owner_addresses, wait_ms, num_required_objects, wait_local, + object_ids, owner_addresses, wait_ms, num_required_objects, [this, resolve_objects, was_blocked, client, current_task_id]( std::vector found, std::vector remaining) { // Write the data. @@ -1600,7 +1599,7 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage( // has been found, so the object may still be on a remote node when the // client receives the reply. ray::Status status = object_manager_.Wait( - object_ids, owner_addresses, -1, object_ids.size(), false, + object_ids, owner_addresses, -1, object_ids.size(), [this, client, tag](std::vector found, std::vector remaining) { RAY_CHECK(remaining.empty()); std::shared_ptr worker = diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index bdc8f5c47..f2a43935a 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -115,7 +115,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { NodeManager(boost::asio::io_service &io_service, const NodeID &self_node_id, const NodeManagerConfig &config, ObjectManager &object_manager, std::shared_ptr gcs_client, - std::shared_ptr object_directory_); + std::shared_ptr object_directory); /// Process a new client connection. /// diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 3589fc840..9251c1020 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -206,13 +206,13 @@ Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() { Status raylet::RayletClient::Wait(const std::vector &object_ids, const std::vector &owner_addresses, int num_returns, int64_t timeout_milliseconds, - bool wait_local, bool mark_worker_blocked, - const TaskID ¤t_task_id, WaitResultPair *result) { + bool mark_worker_blocked, const TaskID ¤t_task_id, + WaitResultPair *result) { // Write request. flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateWaitRequest( fbb, to_flatbuf(fbb, object_ids), AddressesToFlatbuffer(fbb, owner_addresses), - num_returns, timeout_milliseconds, wait_local, mark_worker_blocked, + num_returns, timeout_milliseconds, mark_worker_blocked, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); std::vector reply; diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index a50b7c0e7..6f2821038 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -272,7 +272,6 @@ class RayletClient : public RayletClientInterface { /// \param owner_addresses The addresses of the workers that own the objects. /// \param num_returns The number of objects to wait for. /// \param timeout_milliseconds Duration, in milliseconds, to wait before returning. - /// \param wait_local Whether to wait for objects to appear on this node. /// \param mark_worker_blocked Set to false if current task is a direct call task. /// \param current_task_id The task that called wait. /// \param result A pair with the first element containing the object ids that were @@ -280,9 +279,8 @@ class RayletClient : public RayletClientInterface { /// \return ray::Status. ray::Status Wait(const std::vector &object_ids, const std::vector &owner_addresses, int num_returns, - int64_t timeout_milliseconds, bool wait_local, - bool mark_worker_blocked, const TaskID ¤t_task_id, - WaitResultPair *result); + int64_t timeout_milliseconds, bool mark_worker_blocked, + const TaskID ¤t_task_id, WaitResultPair *result); /// Wait for the given objects, asynchronously. The core worker is notified when /// the wait completes. diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index cb168e078..a842f51ef 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -128,7 +128,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { std::vector wait_results; std::vector> results; - Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results); + Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results, true); if (!wait_st.ok()) { STREAMING_LOG(ERROR) << "Wait fail."; return false; From e6cb4f4bd76e59363f2032e2c9c789e174723074 Mon Sep 17 00:00:00 2001 From: Allen Date: Thu, 17 Dec 2020 00:25:29 -0800 Subject: [PATCH 10/88] [Core] Add log of address and port (#12908) Co-authored-by: Allen Yin --- src/ray/rpc/metrics_agent_client.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ray/rpc/metrics_agent_client.h b/src/ray/rpc/metrics_agent_client.h index 9a78b6d9d..36a0cbb51 100644 --- a/src/ray/rpc/metrics_agent_client.h +++ b/src/ray/rpc/metrics_agent_client.h @@ -37,6 +37,8 @@ class MetricsAgentClient { /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. MetricsAgentClient(const std::string &address, const int port, ClientCallManager &client_call_manager) { + RAY_LOG(DEBUG) << "Initiate the metrics client of address:" << address + << " port:" << port; grpc_client_ = std::unique_ptr>( new GrpcClient(address, port, client_call_manager)); }; From 82f9c7014e2d0acd3e3869066f5dc3142ec9e7a7 Mon Sep 17 00:00:00 2001 From: Gekho457 <62982571+Gekho457@users.noreply.github.com> Date: Thu, 17 Dec 2020 09:41:48 -0800 Subject: [PATCH 11/88] [K8s] Retry getting home directory in command runner. (#12925) --- .../ray/autoscaler/_private/command_runner.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/python/ray/autoscaler/_private/command_runner.py b/python/ray/autoscaler/_private/command_runner.py index 075efa377..f350ff1f3 100644 --- a/python/ray/autoscaler/_private/command_runner.py +++ b/python/ray/autoscaler/_private/command_runner.py @@ -35,6 +35,8 @@ logger = logging.getLogger(__name__) HASH_MAX_LENGTH = 10 KUBECTL_RSYNC = os.path.join( os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh") +MAX_HOME_RETRIES = 3 +HOME_RETRY_DELAY_S = 5 _config = {"use_login_shells": True, "silent_rsync": True} @@ -248,16 +250,31 @@ class KubernetesCommandRunner(CommandRunnerInterface): @property def _home(self): + if self._home_cached is not None: + return self._home_cached + for _ in range(MAX_HOME_RETRIES - 1): + try: + self._home_cached = self._try_to_get_home() + return self._home_cached + except Exception: + # TODO (Dmitri): Identify the exception we're trying to avoid. + logger.info("Error reading container's home directory. " + f"Retrying in {HOME_RETRY_DELAY_S} seconds.") + time.sleep(HOME_RETRY_DELAY_S) + # Last try + self._home_cached = self._try_to_get_home() + return self._home_cached + + def _try_to_get_home(self): # TODO (Dmitri): Think about how to use the node's HOME variable # without making an extra kubectl exec call. - if self._home_cached is None: - cmd = self.kubectl + [ - "exec", "-it", self.node_id, "--", "printenv", "HOME" - ] - joined_cmd = " ".join(cmd) - raw_out = self.process_runner.check_output(joined_cmd, shell=True) - self._home_cached = raw_out.decode().strip("\n\r") - return self._home_cached + cmd = self.kubectl + [ + "exec", "-it", self.node_id, "--", "printenv", "HOME" + ] + joined_cmd = " ".join(cmd) + raw_out = self.process_runner.check_output(joined_cmd, shell=True) + home = raw_out.decode().strip("\n\r") + return home class SSHOptions: From c7a59b239f67b990c7235f0fd3a1eca0046d954d Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 17 Dec 2020 15:04:11 -0600 Subject: [PATCH 12/88] Remove unused endpoints_to_remove (#12946) --- python/ray/serve/controller.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 5237d453c..cba1ec64b 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -154,7 +154,6 @@ class ActorStateReconciler: backend_replicas_to_stop: Dict[BackendTag, List[ReplicaTag]] = field( default_factory=lambda: defaultdict(list)) backends_to_remove: List[BackendTag] = field(default_factory=list) - endpoints_to_remove: List[EndpointTag] = field(default_factory=list) # TODO(edoakes): consider removing this and just using the names. @@ -833,8 +832,6 @@ class ServeController: if endpoint in self.current_state.traffic_policies: del self.current_state.traffic_policies[endpoint] - self.actor_reconciler.endpoints_to_remove.append(endpoint) - return_uuid = self._create_event_with_result({ route_to_delete: None, endpoint: None From d747071dd9e883bd4ceefe80d0344f287630c4e5 Mon Sep 17 00:00:00 2001 From: dHannasch Date: Thu, 17 Dec 2020 15:26:30 -0700 Subject: [PATCH 13/88] Test shard_context on already-created boost::asio::io_service. (#12917) --- src/ray/gcs/test/asio_test.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/ray/gcs/test/asio_test.cc b/src/ray/gcs/test/asio_test.cc index 5883ca661..052a95ec6 100644 --- a/src/ray/gcs/test/asio_test.cc +++ b/src/ray/gcs/test/asio_test.cc @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "ray/common/test_util.h" +#include "ray/gcs/redis_context.h" #include "ray/util/logging.h" extern "C" { @@ -66,6 +67,14 @@ TEST_F(RedisAsioTest, TestRedisCommands) { redisAsyncCommand(ac, NULL, NULL, "SET key test"); redisAsyncCommand(ac, GetCallback, nullptr, "GET key"); + std::shared_ptr shard_context = + std::make_shared(io_service); + ASSERT_TRUE(shard_context + ->Connect(std::string("127.0.0.1"), TEST_REDIS_SERVER_PORTS.front(), + /*sharding=*/true, + /*password=*/std::string()) + .ok()); + io_service.run(); } From 124c8318a876194f25d0ab2382502a3662925b04 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 18 Dec 2020 00:44:26 +0100 Subject: [PATCH 14/88] [RLlib] Fix broken test_distributions.py (test_categorical) (#12915) --- rllib/BUILD | 13 ++++++------- rllib/models/tests/test_distributions.py | 5 +++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index bde7d5855..bd612d0ff 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1089,13 +1089,12 @@ py_test( srcs = ["models/tests/test_convtranspose2d_stack.py"] ) -# Failing after the following PR: https://github.com/ray-project/ray/pull/12760. -#py_test( -# name = "test_distributions", -# tags = ["models"], -# size = "medium", -# srcs = ["models/tests/test_distributions.py"] -#) +py_test( + name = "test_distributions", + tags = ["models"], + size = "medium", + srcs = ["models/tests/test_distributions.py"] +) # -------------------------------------------------------------------- # Evaluation components diff --git a/rllib/models/tests/test_distributions.py b/rllib/models/tests/test_distributions.py index e0317cf11..987f76a56 100644 --- a/rllib/models/tests/test_distributions.py +++ b/rllib/models/tests/test_distributions.py @@ -87,14 +87,15 @@ class TestDistributions(unittest.TestCase): batch_size = 10000 num_categories = 4 # Create categorical distribution with n categories. - inputs_space = Box(-1.0, 2.0, shape=(batch_size, num_categories)) + inputs_space = Box( + -1.0, 2.0, shape=(batch_size, num_categories), dtype=np.float32) values_space = Box( 0, num_categories - 1, shape=(batch_size, ), dtype=np.int32) inputs = inputs_space.sample() for fw, sess in framework_iterator( - session=True, frameworks=("jax", "tf", "tf2", "torch")): + session=True, frameworks=("tf", "tf2", "torch")): # Create the correct distribution object. cls = JAXCategorical if fw == "jax" else Categorical if \ fw != "torch" else TorchCategorical From 3d72000826f31e06ed136e50e038e21f3f7d975f Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 18 Dec 2020 04:16:03 +0100 Subject: [PATCH 15/88] [tune] Add `points_to_evaluate` to BasicVariantGenerator (#12916) Co-authored-by: Richard Liaw --- doc/source/tune/api_docs/search_space.rst | 9 +- doc/source/tune/api_docs/suggestion.rst | 19 +++ doc/source/tune/examples/index.rst | 2 +- .../tune/examples/tune_basic_example.rst | 6 + python/ray/tune/BUILD | 11 +- .../ray/tune/examples/tune_basic_example.py | 51 ++++++++ python/ray/tune/sample.py | 36 ++++++ python/ray/tune/suggest/basic_variant.py | 115 +++++++++++++++--- python/ray/tune/suggest/variant_generator.py | 81 ++++++++++-- python/ray/tune/tests/test_sample.py | 97 +++++++++++++++ python/ray/tune/tune.py | 3 +- 11 files changed, 396 insertions(+), 34 deletions(-) create mode 100644 doc/source/tune/examples/tune_basic_example.rst create mode 100644 python/ray/tune/examples/tune_basic_example.py diff --git a/doc/source/tune/api_docs/search_space.rst b/doc/source/tune/api_docs/search_space.rst index 005942fe9..3c069760f 100644 --- a/doc/source/tune/api_docs/search_space.rst +++ b/doc/source/tune/api_docs/search_space.rst @@ -263,10 +263,7 @@ Grid Search API .. autofunction:: ray.tune.grid_search -Internals ---------- +References +---------- -BasicVariantGenerator -~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: ray.tune.suggest.BasicVariantGenerator +See also :ref:`tune-basicvariant`. \ No newline at end of file diff --git a/doc/source/tune/api_docs/suggestion.rst b/doc/source/tune/api_docs/suggestion.rst index 9675f8537..05c3d466a 100644 --- a/doc/source/tune/api_docs/suggestion.rst +++ b/doc/source/tune/api_docs/suggestion.rst @@ -22,6 +22,10 @@ Summary - Summary - Website - Code Example + * - :ref:`Random search/grid search ` + - Random search/grid search + - + - :doc:`/tune/examples/tune_basic_example` * - :ref:`AxSearch ` - Bayesian/Bandit Optimization - [`Ax `__] @@ -123,6 +127,21 @@ identifier. .. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch. +.. _tune-basicvariant: + +Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator) +-------------------------------------------------------------------------------- + +The default and most basic way to do hyperparameter search is via random and grid search. +Ray Tune does this through the :class:`BasicVariantGenerator ` +class that generates trial variants given a search space definition. + +The :class:`BasicVariantGenerator ` is used per +default if no search algorithm is passed to +:func:`tune.run() `. + +.. autoclass:: ray.tune.suggest.basic_variant.BasicVariantGenerator + .. _tune-ax: Ax (tune.suggest.ax.AxSearch) diff --git a/doc/source/tune/examples/index.rst b/doc/source/tune/examples/index.rst index 89af9deb9..54852d550 100644 --- a/doc/source/tune/examples/index.rst +++ b/doc/source/tune/examples/index.rst @@ -13,7 +13,7 @@ If any example is broken, or if you'd like to add an example to this page, feel General Examples ---------------- - +- :doc:`/tune/examples/tune_basic_example`: Simple example for doing a basic random and grid search. - :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler. - :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler. - :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler. diff --git a/doc/source/tune/examples/tune_basic_example.rst b/doc/source/tune/examples/tune_basic_example.rst new file mode 100644 index 000000000..1be5ab3f1 --- /dev/null +++ b/doc/source/tune/examples/tune_basic_example.rst @@ -0,0 +1,6 @@ +:orphan: + +tune_basic_example +~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: /../../python/ray/tune/examples/tune_basic_example.py diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index f10df3ec9..8b3439853 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -157,7 +157,7 @@ py_test( py_test( name = "test_sample", - size = "small", + size = "medium", srcs = ["tests/test_sample.py"], deps = [":tune_lib"], tags = ["exclusive"], @@ -696,6 +696,15 @@ py_test( args = ["--smoke-test"] ) +py_test( + name = "tune_basic_example", + size = "small", + srcs = ["examples/tune_basic_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"], + args = ["--smoke-test"] +) + # Downloads too much data. # py_test( # name = "tune_cifar10_gluon", diff --git a/python/ray/tune/examples/tune_basic_example.py b/python/ray/tune/examples/tune_basic_example.py new file mode 100644 index 000000000..30677bc0c --- /dev/null +++ b/python/ray/tune/examples/tune_basic_example.py @@ -0,0 +1,51 @@ +"""This example demonstrates basic Ray Tune random search and grid search.""" +import time + +import ray +from ray import tune + + +def evaluation_fn(step, width, height): + time.sleep(0.1) + return (0.1 + width * step / 100)**(-1) + height * 0.1 + + +def easy_objective(config): + # Hyperparameters + width, height = config["width"], config["height"] + + for step in range(config["steps"]): + # Iterative training function - can be any arbitrary training procedure + intermediate_score = evaluation_fn(step, width, height) + # Feed the score back back to Tune. + tune.report(iterations=step, mean_loss=intermediate_score) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init(configure_logging=False) + + # This will do a grid search over the `activation` parameter. This means + # that each of the two values (`relu` and `tanh`) will be sampled once + # for each sample (`num_samples`). We end up with 2 * 50 = 100 samples. + # The `width` and `height` parameters are sampled randomly. + # `steps` is a constant parameter. + + analysis = tune.run( + easy_objective, + metric="mean_loss", + mode="min", + num_samples=5 if args.smoke_test else 50, + config={ + "steps": 5 if args.smoke_test else 100, + "width": tune.uniform(0, 20), + "height": tune.uniform(-100, 100), + "activation": tune.grid_search(["relu", "tanh"]) + }) + + print("Best hyperparameters found were: ", analysis.best_config) diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index a9d82331a..7190c69d2 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -53,6 +53,14 @@ class Domain: def is_function(self): return False + def is_valid(self, value: Any): + """Returns True if `value` is a valid value in this domain.""" + raise NotImplementedError + + @property + def domain_str(self): + return "(unknown)" + class Sampler: def sample(self, @@ -203,6 +211,13 @@ class Float(Domain): new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True) return new + def is_valid(self, value: float): + return self.lower <= value <= self.upper + + @property + def domain_str(self): + return f"({self.lower}, {self.upper})" + class Integer(Domain): class _Uniform(Uniform): @@ -232,6 +247,13 @@ class Integer(Domain): new.set_sampler(self._Uniform()) return new + def is_valid(self, value: int): + return self.lower <= value <= self.upper + + @property + def domain_str(self): + return f"({self.lower}, {self.upper})" + class Categorical(Domain): class _Uniform(Uniform): @@ -264,6 +286,13 @@ class Categorical(Domain): def __getitem__(self, item): return self.categories[item] + def is_valid(self, value: Any): + return value in self.categories + + @property + def domain_str(self): + return f"{self.categories}" + class Function(Domain): class _CallSampler(BaseSampler): @@ -295,6 +324,13 @@ class Function(Domain): def is_function(self): return True + def is_valid(self, value: Any): + return True # This is user-defined, so lets not assume anything + + @property + def domain_str(self): + return f"{self.func}()" + class Quantized(Sampler): def __init__(self, sampler: Sampler, q: Union[float, int]): diff --git a/python/ray/tune/suggest/basic_variant.py b/python/ray/tune/suggest/basic_variant.py index 435e6dd01..46f54888b 100644 --- a/python/ray/tune/suggest/basic_variant.py +++ b/python/ray/tune/suggest/basic_variant.py @@ -1,44 +1,95 @@ +import copy import itertools import os import uuid -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union from ray.tune.error import TuneError from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec from ray.tune.suggest.variant_generator import ( - count_variants, generate_variants, format_vars, flatten_resolved_vars) + count_variants, generate_variants, format_vars, flatten_resolved_vars, + get_preset_variants) from ray.tune.suggest.search import SearchAlgorithm class BasicVariantGenerator(SearchAlgorithm): """Uses Tune's variant generation for resolving variables. - See also: `ray.tune.suggest.variant_generator`. + This is the default search algorithm used if no other search algorithm + is specified. - User API: + + Args: + points_to_evaluate (list): Initial parameter suggestions to be run + first. This is for when you already have some good parameters + you want to run first to help the algorithm make better suggestions + for future parameters. Needs to be a list of dicts containing the + configurations. + + + Example: .. code-block:: python from ray import tune - from ray.tune.suggest import BasicVariantGenerator - searcher = BasicVariantGenerator() - tune.run(my_trainable_func, algo=searcher) + # This will automatically use the `BasicVariantGenerator` + tune.run( + lambda config: config["a"] + config["b"], + config={ + "a": tune.grid_search([1, 2]), + "b": tune.randint(0, 3) + }, + num_samples=4) - Internal API: + In the example above, 8 trials will be generated: For each sample + (``4``), each of the grid search variants for ``a`` will be sampled + once. The ``b`` parameter will be sampled randomly. + + The generator accepts a pre-set list of points that should be evaluated. + The points will replace the first samples of each experiment passed to + the ``BasicVariantGenerator``. + + Each point will replace one sample of the specified ``num_samples``. If + grid search variables are overwritten with the values specified in the + presets, the number of samples will thus be reduced. + + Example: .. code-block:: python - from ray.tune.suggest import BasicVariantGenerator + from ray import tune + from ray.tune.suggest.basic_variant import BasicVariantGenerator + + + tune.run( + lambda config: config["a"] + config["b"], + config={ + "a": tune.grid_search([1, 2]), + "b": tune.randint(0, 3) + }, + search_alg=BasicVariantGenerator(points_to_evaluate=[ + {"a": 2, "b": 2}, + {"a": 1}, + {"b": 2} + ]), + num_samples=4) + + The example above will produce six trials via four samples: + + - The first sample will produce one trial with ``a=2`` and ``b=2``. + - The second sample will produce one trial with ``a=1`` and ``b`` sampled + randomly + - The third sample will produce two trials, one for each grid search + value of ``a``. It will be ``b=2`` for both of these trials. + - The fourth sample will produce two trials, one for each grid search + value of ``a``. ``b`` will be sampled randomly and independently for + both of these trials. - searcher = BasicVariantGenerator() - searcher.add_configurations({"experiment": { ... }}) - trial = searcher.next_trial() - searcher.is_finished == True """ - def __init__(self): + def __init__(self, points_to_evaluate: Optional[List[Dict]] = None): """Initializes the Variant Generator. """ @@ -48,6 +99,8 @@ class BasicVariantGenerator(SearchAlgorithm): self._counter = 0 self._finished = False + self._points_to_evaluate = points_to_evaluate or [] + # Unique prefix for all trials generated, e.g., trial ids start as # 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing. force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID") @@ -72,12 +125,14 @@ class BasicVariantGenerator(SearchAlgorithm): """ experiment_list = convert_to_experiment_list(experiments) for experiment in experiment_list: - self._total_samples += count_variants(experiment.spec) + points_to_evaluate = copy.deepcopy(self._points_to_evaluate) + self._total_samples += count_variants(experiment.spec, + points_to_evaluate) self._trial_generator = itertools.chain( self._trial_generator, self._generate_trials( experiment.spec.get("num_samples", 1), experiment.spec, - experiment.dir_name)) + experiment.dir_name, points_to_evaluate)) def next_trial(self): """Provides one Trial object to be queued into the TrialRunner. @@ -95,7 +150,11 @@ class BasicVariantGenerator(SearchAlgorithm): self.set_finished() return None - def _generate_trials(self, num_samples, unresolved_spec, output_path=""): + def _generate_trials(self, + num_samples, + unresolved_spec, + output_path="", + points_to_evaluate=None): """Generates Trial objects with the variant generation process. Uses a fixed point iteration to resolve variants. All trials @@ -109,6 +168,28 @@ class BasicVariantGenerator(SearchAlgorithm): if "run" not in unresolved_spec: raise TuneError("Must specify `run` in {}".format(unresolved_spec)) + + points_to_evaluate = points_to_evaluate or [] + + while points_to_evaluate: + config = points_to_evaluate.pop(0) + for resolved_vars, spec in get_preset_variants( + unresolved_spec, config): + trial_id = self._uuid_prefix + ("%05d" % self._counter) + experiment_tag = str(self._counter) + self._counter += 1 + yield create_trial_from_spec( + spec, + output_path, + self._parser, + evaluated_params=flatten_resolved_vars(resolved_vars), + trial_id=trial_id, + experiment_tag=experiment_tag) + num_samples -= 1 + + if num_samples <= 0: + return + for _ in range(num_samples): for resolved_vars, spec in generate_variants(unresolved_spec): trial_id = self._uuid_prefix + ("%05d" % self._counter) diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 7048cd804..849b3b012 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -1,6 +1,7 @@ import copy import logging -from typing import Any, Dict, Generator, List, Tuple +from collections.abc import Mapping +from typing import Any, Dict, Generator, List, Optional, Tuple import numpy import random @@ -138,13 +139,38 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[ return resolved_vars, domain_vars, grid_vars -def count_variants(spec: Dict) -> int: - spec = copy.deepcopy(spec) - _, domain_vars, grid_vars = parse_spec_vars(spec) - grid_count = 1 - for path, domain in grid_vars: - grid_count *= len(domain.categories) - return spec.get("num_samples", 1) * grid_count +def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int: + # Helper function: Deep update dictionary + def deep_update(d, u): + for k, v in u.items(): + if isinstance(v, Mapping): + d[k] = deep_update(d.get(k, {}), v) + else: + d[k] = v + return d + + # Count samples for a specific spec + def spec_samples(spec, num_samples=1): + _, domain_vars, grid_vars = parse_spec_vars(spec) + grid_count = 1 + for path, domain in grid_vars: + grid_count *= len(domain.categories) + return num_samples * grid_count + + total_samples = 0 + total_num_samples = spec.get("num_samples", 1) + # For each preset, overwrite the spec and count the samples generated + # for this preset + for preset in presets: + preset_spec = copy.deepcopy(spec) + deep_update(preset_spec["config"], preset) + total_samples += spec_samples(preset_spec, 1) + total_num_samples -= 1 + + # Add the remaining samples + if total_num_samples > 0: + total_samples += spec_samples(spec, total_num_samples) + return total_samples def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]: @@ -172,6 +198,45 @@ def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]: yield resolved_vars, spec +def get_preset_variants(spec: Dict, config: Dict): + """Get variants according to a spec, initialized with a config. + + Variables from the spec are overwritten by the variables in the config. + Thus, we may end up with less sampled parameters. + + This function also checks if values used to overwrite search space + parameters are valid, and logs a warning if not. + """ + spec = copy.deepcopy(spec) + + resolved, _, _ = parse_spec_vars(config) + + for path, val in resolved: + try: + domain = _get_value(spec["config"], path) + if isinstance(domain, dict): + if "grid_search" in domain: + domain = Categorical(domain["grid_search"]) + else: + # If users want to overwrite an entire subdict, + # let them do it. + domain = None + except IndexError as exc: + raise ValueError( + f"Pre-set config key `{'/'.join(path)}` does not correspond " + f"to a valid key in the search space definition. Please add " + f"this path to the `config` variable passed to `tune.run()`." + ) from exc + + if domain and not domain.is_valid(val): + logger.warning( + f"Pre-set value `{val}` is not within valid values of " + f"parameter `{'/'.join(path)}`: {domain.domain_str}") + assign_value(spec["config"], path, val) + + return _generate_variants(spec) + + def assign_value(spec: Dict, path: Tuple, value: Any): for k in path[:-1]: spec = spec[k] diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index d40900a6c..921e0c9ca 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -2,6 +2,7 @@ import numpy as np import unittest from ray import tune +from ray.tune import Experiment from ray.tune.suggest.variant_generator import generate_variants @@ -871,6 +872,102 @@ class SearchSpaceTest(unittest.TestCase): return self._testPointsToEvaluate( ZOOptSearch, config, budget=10, parallel_num=8) + def testPointsToEvaluateBasicVariant(self): + config = { + "metric": tune.sample.Categorical([1, 2, 3, 4]).uniform(), + "a": tune.sample.Categorical(["t1", "t2", "t3", "t4"]).uniform(), + "b": tune.sample.Integer(0, 5), + "c": tune.sample.Float(1e-4, 1e-1).loguniform() + } + + from ray.tune.suggest.basic_variant import BasicVariantGenerator + return self._testPointsToEvaluate(BasicVariantGenerator, config) + + def testPointsToEvaluateBasicVariantAdvanced(self): + config = { + "grid_1": tune.grid_search(["a", "b", "c", "d"]), + "grid_2": tune.grid_search(["x", "y", "z"]), + "nested": { + "random": tune.uniform(2., 10.), + "dependent": tune.sample_from( + lambda spec: -1. * spec.config.nested.random) + } + } + + points = [ + { + "grid_1": "b" + }, + { + "grid_2": "z" + }, + { + "grid_1": "a", + "grid_2": "y" + }, + { + "nested": { + "random": 8.0 + } + }, + ] + + from ray.tune.suggest.basic_variant import BasicVariantGenerator + + # grid_1 * grid_2 are 3 * 4 = 12 variants per complete grid search + # However if one grid var is set by preset variables, that run + # is excluded from grid search. + + # Point 1 overwrites grid_1, so the first trial only grid searches + # over grid_2 (3 trials). + # The remaining 5 trials search over the whole space (5 * 12 trials) + searcher = BasicVariantGenerator(points_to_evaluate=[points[0]]) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 * 3 + 5 * 12) + + # Point 2 overwrites grid_2, so the first trial only grid searches + # over grid_1 (4 trials). + # The remaining 5 trials search over the whole space (5 * 12 trials) + searcher = BasicVariantGenerator(points_to_evaluate=[points[1]]) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 * 4 + 5 * 12) + + # Point 3 overwrites grid_1 and grid_2, so the first trial does not + # grid search. + # The remaining 5 trials search over the whole space (5 * 12 trials) + searcher = BasicVariantGenerator(points_to_evaluate=[points[2]]) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 + 5 * 12) + + # When initialized with all points, the first three trials are + # defined by the logic above. Only 3 trials are grid searched + # compeletely. + searcher = BasicVariantGenerator(points_to_evaluate=points) + exp = Experiment( + run=_mock_objective, name="test", config=config, num_samples=6) + searcher.add_configurations(exp) + self.assertEqual(searcher.total_samples, 1 * 3 + 1 * 4 + 1 + 3 * 12) + + # Run this and confirm results + analysis = tune.run(exp, search_alg=searcher) + configs = [trial.config for trial in analysis.trials] + + self.assertEqual(len(configs), searcher.total_samples) + self.assertTrue( + all(config["grid_1"] == "b" for config in configs[0:3])) + self.assertTrue( + all(config["grid_2"] == "z" for config in configs[3:7])) + self.assertTrue(configs[7]["grid_1"] == "a" + and configs[7]["grid_2"] == "y") + self.assertTrue(configs[8]["nested"]["random"] == 8.0) + self.assertTrue(configs[8]["nested"]["dependent"] == -8.0) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index fe26e12e5..ab3df8ba8 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -171,7 +171,8 @@ def run( samples are generated until a stopping condition is met. local_dir (str): Local dir to save training results to. Defaults to ``~/ray_results``. - search_alg (Searcher): Search algorithm for optimization. + search_alg (Searcher|SearchAlgorithm): Search algorithm for + optimization. scheduler (TrialScheduler): Scheduler for executing the experiment. Choose among FIFO (default), MedianStopping, AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to From 53378170e07441e0bb6c9dbb3581bf6be289936f Mon Sep 17 00:00:00 2001 From: Farzan Taj Date: Thu, 17 Dec 2020 22:17:08 -0500 Subject: [PATCH 16/88] [tune] Change pickle to ray.cloudpickle -- support large models (#12958) * Change pickle to ray.cloudpickle * Change pickle import to ray.cloudpickle --- python/ray/tune/trainable.py | 2 +- python/ray/tune/utils/trainable.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index aa08fd03b..c2915b19f 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -5,7 +5,7 @@ from datetime import datetime import copy import logging import os -import pickle +import ray.cloudpickle as pickle import platform from ray.tune.utils.trainable import TrainableUtil diff --git a/python/ray/tune/utils/trainable.py b/python/ray/tune/utils/trainable.py index bb9299793..cc61e3a83 100644 --- a/python/ray/tune/utils/trainable.py +++ b/python/ray/tune/utils/trainable.py @@ -5,7 +5,7 @@ import shutil from typing import Dict, Any import pandas as pd -import pickle +import ray.cloudpickle as pickle import os import ray From 17152c84a7e02581cccdfc337c9f540b9774a419 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Fri, 18 Dec 2020 11:22:13 +0800 Subject: [PATCH 17/88] [Tiny]Print raylet info after register (#12566) --- src/ray/raylet/raylet.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index d2ddead62..9add8e425 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -120,12 +120,13 @@ void Raylet::Stop() { ray::Status Raylet::RegisterGcs() { auto register_callback = [this](const Status &status) { RAY_CHECK_OK(status); - RAY_LOG(DEBUG) << "Node manager " << self_node_id_ << " started on " - << self_node_info_.node_manager_address() << ":" - << self_node_info_.node_manager_port() << " object manager at " - << self_node_info_.node_manager_address() << ":" - << self_node_info_.object_manager_port() << ", hostname " - << self_node_info_.node_manager_hostname(); + RAY_LOG(INFO) << "Raylet of id, " << self_node_id_ + << " started. Raylet consists of node_manager and object_manager." + << " node_manager address: " << self_node_info_.node_manager_address() + << ":" << self_node_info_.node_manager_port() + << " object_manager address: " << self_node_info_.node_manager_address() + << ":" << self_node_info_.object_manager_port() + << " hostname: " << self_node_info_.node_manager_address(); // Add resource information. const NodeManagerConfig &node_manager_config = node_manager_.GetInitialConfig(); From 6404f1e6095165e1192cb2f6fac806e96f973c30 Mon Sep 17 00:00:00 2001 From: "DK.Pino" Date: Fri, 18 Dec 2020 11:56:45 +0800 Subject: [PATCH 18/88] [Placement Group][New scheduler] New scheduler pg implementation (#12910) --- BUILD.bazel | 2 +- python/ray/tests/BUILD | 3 +- src/ray/raylet/node_manager.cc | 52 ++-- src/ray/raylet/node_manager.h | 22 -- .../placement_group_resource_manager.cc | 118 ++++++++- .../raylet/placement_group_resource_manager.h | 53 +++- .../placement_group_resource_manager_test.cc | 226 ++++++++++++++++++ .../scheduling/cluster_resource_data.cc | 3 +- .../scheduling/cluster_resource_scheduler.cc | 114 +++++++-- .../scheduling/cluster_resource_scheduler.h | 12 + .../cluster_resource_scheduler_test.cc | 15 +- .../raylet/scheduling/cluster_task_manager.cc | 1 + src/ray/raylet/test/util.h | 2 +- 13 files changed, 542 insertions(+), 81 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 6ffa8df54..16b9a315f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -880,7 +880,7 @@ cc_test( ) cc_test( - name = "local_placement_group_manager_test", + name = "placement_group_resource_manager_test", srcs = ["src/ray/raylet/placement_group_resource_manager_test.cc"], copts = COPTS, deps = [ diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 573131fec..55fac64e5 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -129,11 +129,10 @@ py_test_module_list( py_test_module_list( files = [ - "test_placement_group.py", # placement groups not implemented + "test_placement_group.py", ], size = "large", extra_srcs = SRCS, - tags = ["exclusive", "new_scheduler_broken"], deps = ["//:ray_lib"], ) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a289900d4..49bf6a6af 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -180,8 +180,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self last_local_gc_ns_(absl::GetCurrentTimeNanos()), local_gc_interval_ns_(RayConfig::instance().local_gc_interval_s() * 1e9), record_metrics_period_(config.record_metrics_period_ms) { - placement_group_resource_manager_ = std::make_shared( - local_available_resources_, cluster_resource_map_, self_node_id_); RAY_LOG(INFO) << "Initializing NodeManager with ID " << self_node_id_; RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -227,6 +225,12 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self cluster_task_manager_ = std::shared_ptr(new ClusterTaskManager( self_node_id_, new_resource_scheduler_, fulfills_dependencies_func, is_owner_alive, get_node_info_func, announce_infeasible_task)); + placement_group_resource_manager_ = + std::make_shared(new_resource_scheduler_); + } else { + placement_group_resource_manager_ = + std::make_shared( + local_available_resources_, cluster_resource_map_, self_node_id_); } RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); @@ -458,6 +462,10 @@ void NodeManager::ReportResourceUsage() { resources_data->set_node_id(self_node_id_.Binary()); if (new_scheduler_enabled_) { + // Update local chche from gcs remote cache, this is needed when gcs restart. + // We should always keep the cache view consistent. + new_resource_scheduler_->UpdateLastReportResourcesFromGcs( + gcs_client_->Nodes().GetLastResourceUsage()); new_resource_scheduler_->FillResourceUsage(light_report_resource_usage_enabled_, resources_data); cluster_task_manager_->FillResourceUsage(light_report_resource_usage_enabled_, @@ -600,7 +608,6 @@ void NodeManager::HandleRequestObjectSpillage( void NodeManager::HandleReleaseUnusedBundles( const rpc::ReleaseUnusedBundlesRequest &request, rpc::ReleaseUnusedBundlesReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_CHECK(!new_scheduler_enabled_) << "Not implemented"; RAY_LOG(DEBUG) << "Releasing unused bundles."; std::unordered_set in_use_bundles; for (int index = 0; index < request.bundles_in_use_size(); ++index) { @@ -1745,39 +1752,44 @@ void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest void NodeManager::HandlePrepareBundleResources( const rpc::PrepareBundleResourcesRequest &request, rpc::PrepareBundleResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) { - // TODO(sang): Port this onto the new scheduler. - RAY_CHECK(!new_scheduler_enabled_) << "Not implemented yet."; auto bundle_spec = BundleSpecification(request.bundle_spec()); RAY_LOG(DEBUG) << "Request to prepare bundle resources is received, " << bundle_spec.DebugString(); + auto prepared = placement_group_resource_manager_->PrepareBundle(bundle_spec); reply->set_success(prepared); send_reply_callback(Status::OK(), nullptr, nullptr); - // Call task dispatch to assign work to the new group. - TryLocalInfeasibleTaskScheduling(); - DispatchTasks(local_queues_.GetReadyTasksByClass()); + + if (!new_scheduler_enabled_) { + // Call task dispatch to assign work to the new group. + TryLocalInfeasibleTaskScheduling(); + DispatchTasks(local_queues_.GetReadyTasksByClass()); + } } void NodeManager::HandleCommitBundleResources( const rpc::CommitBundleResourcesRequest &request, rpc::CommitBundleResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_CHECK(!new_scheduler_enabled_) << "Not implemented yet."; - auto bundle_spec = BundleSpecification(request.bundle_spec()); RAY_LOG(DEBUG) << "Request to commit bundle resources is received, " << bundle_spec.DebugString(); placement_group_resource_manager_->CommitBundle(bundle_spec); send_reply_callback(Status::OK(), nullptr, nullptr); - // Call task dispatch to assign work to the new group. - TryLocalInfeasibleTaskScheduling(); - DispatchTasks(local_queues_.GetReadyTasksByClass()); + if (new_scheduler_enabled_) { + // Schedule in case a lease request for this placement group arrived before the commit + // message. + ScheduleAndDispatch(); + } else { + // Call task dispatch to assign work to the new group. + TryLocalInfeasibleTaskScheduling(); + DispatchTasks(local_queues_.GetReadyTasksByClass()); + } } void NodeManager::HandleCancelResourceReserve( const rpc::CancelResourceReserveRequest &request, rpc::CancelResourceReserveReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_CHECK(!new_scheduler_enabled_) << "Not implemented"; auto bundle_spec = BundleSpecification(request.bundle_spec()); RAY_LOG(INFO) << "Request to cancel reserved resource is received, " << bundle_spec.DebugString(); @@ -1806,8 +1818,16 @@ void NodeManager::HandleCancelResourceReserve( // Return bundle resources. placement_group_resource_manager_->ReturnBundle(bundle_spec); - TryLocalInfeasibleTaskScheduling(); - DispatchTasks(local_queues_.GetReadyTasksByClass()); + + if (new_scheduler_enabled_) { + // Schedule in case a lease request for this placement group arrived before the commit + // message. + ScheduleAndDispatch(); + } else { + // Call task dispatch to assign work to the new group. + TryLocalInfeasibleTaskScheduling(); + DispatchTasks(local_queues_.GetReadyTasksByClass()); + } send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index f2a43935a..1bc554b11 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -297,28 +297,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Void. void ScheduleTasks(std::unordered_map &resource_map); - /// Make a placement decision for the resource_map and subtract original resources so - /// that the node is ready to commit (create) placement group resources. - /// - /// \param resource_map A mapping from node manager ID to an estimate of the - /// resources available to that node manager. Scheduling decisions will only - /// consider the local node manager and the node managers in the keys of the - /// resource_map argument. - /// \param bundle_spec Specification of bundle that will be prepared. - /// \return True is resources were successfully prepared. False otherwise. - bool PrepareBundle(std::unordered_map &resource_map, - const BundleSpecification &bundle_spec); - - /// Make a placement decision for the resource_map. - /// - /// \param resource_map A mapping from node manager ID to an estimate of the - /// resources available to that node manager. Scheduling decisions will only - /// consider the local node manager and the node managers in the keys of the - /// resource_map argument. - /// \param bundle_spec Specification of bundle that will be prepared. - void CommitBundle(std::unordered_map &resource_map, - const BundleSpecification &bundle_spec); - /// Handle a task whose return value(s) must be reconstructed. /// /// \param task_id The relevant task ID. diff --git a/src/ray/raylet/placement_group_resource_manager.cc b/src/ray/raylet/placement_group_resource_manager.cc index 06a82663e..ff5b32719 100644 --- a/src/ray/raylet/placement_group_resource_manager.cc +++ b/src/ray/raylet/placement_group_resource_manager.cc @@ -22,6 +22,18 @@ namespace ray { namespace raylet { +void PlacementGroupResourceManager::ReturnUnusedBundle( + const std::unordered_set &in_use_bundles) { + for (auto iter = bundle_spec_map_.begin(); iter != bundle_spec_map_.end();) { + if (0 == in_use_bundles.count(iter->first)) { + ReturnBundle(*iter->second); + bundle_spec_map_.erase(iter++); + } else { + iter++; + } + } +} + OldPlacementGroupResourceManager::OldPlacementGroupResourceManager( ResourceIdSet &local_available_resources_, std::unordered_map &cluster_resource_map_, @@ -111,7 +123,7 @@ void OldPlacementGroupResourceManager::CommitBundle( void OldPlacementGroupResourceManager::ReturnBundle( const BundleSpecification &bundle_spec) { // We should commit resources if it weren't because - // ReturnBundleResources requires resources to be committed when it is called. + // ReturnBundle requires resources to be committed when it is called. auto it = bundle_state_map_.find(bundle_spec.BundleId()); if (it == bundle_state_map_.end()) { RAY_LOG(INFO) << "Duplicate cancel request, skip it directly."; @@ -136,16 +148,106 @@ void OldPlacementGroupResourceManager::ReturnBundle( ResourceSet(placement_group_resource_labels)); } -void OldPlacementGroupResourceManager::ReturnUnusedBundle( - const std::unordered_set &in_use_bundles) { - for (auto iter = bundle_spec_map_.begin(); iter != bundle_spec_map_.end();) { - if (0 == in_use_bundles.count(iter->first)) { - ReturnBundle(*iter->second); - bundle_spec_map_.erase(iter++); +NewPlacementGroupResourceManager::NewPlacementGroupResourceManager( + std::shared_ptr cluster_resource_scheduler_) + : cluster_resource_scheduler_(cluster_resource_scheduler_) {} + +bool NewPlacementGroupResourceManager::PrepareBundle( + const BundleSpecification &bundle_spec) { + auto iter = pg_bundles_.find(bundle_spec.BundleId()); + if (iter != pg_bundles_.end()) { + if (iter->second->state_ == CommitState::COMMITTED) { + // If the bundle state is already committed, it means that prepare request is just + // stale. + RAY_LOG(INFO) << "Duplicate prepare bundle request, skip it directly. This should " + "only happen when GCS restarts."; + return true; } else { - iter++; + // If there was a bundle in prepare state, it already locked resources, we will + // return bundle resources so that we can start from the prepare phase again. + ReturnBundle(bundle_spec); } } + + std::shared_ptr resource_instances = + std::make_shared(); + bool allocated = cluster_resource_scheduler_->AllocateLocalTaskResources( + bundle_spec.GetRequiredResources().GetResourceMap(), resource_instances); + + if (!allocated) { + return false; + } + + auto bundle_state = + std::make_shared(CommitState::PREPARED, resource_instances); + pg_bundles_[bundle_spec.BundleId()] = bundle_state; + bundle_spec_map_.emplace(bundle_spec.BundleId(), std::make_shared( + bundle_spec.GetMessage())); + + return true; +} + +void NewPlacementGroupResourceManager::CommitBundle( + const BundleSpecification &bundle_spec) { + auto it = pg_bundles_.find(bundle_spec.BundleId()); + if (it == pg_bundles_.end()) { + // We should only ever receive a commit for a non-existent placement group when a + // placement group is created and removed in quick succession. + RAY_LOG(DEBUG) + << "Received a commit message for an unknown bundle. The bundle info is " + << bundle_spec.DebugString(); + return; + } else { + // Ignore request If the bundle state is already committed. + if (it->second->state_ == CommitState::COMMITTED) { + RAY_LOG(INFO) << "Duplicate committ bundle request, skip it directly."; + return; + } + } + + const auto &bundle_state = it->second; + bundle_state->state_ = CommitState::COMMITTED; + + for (const auto &resource : bundle_spec.GetFormattedResources()) { + cluster_resource_scheduler_->AddLocalResource(resource.first, resource.second); + } +} + +void NewPlacementGroupResourceManager::ReturnBundle( + const BundleSpecification &bundle_spec) { + auto it = pg_bundles_.find(bundle_spec.BundleId()); + if (it == pg_bundles_.end()) { + RAY_LOG(INFO) << "Duplicate cancel request, skip it directly."; + return; + } + const auto &bundle_state = it->second; + if (bundle_state->state_ == CommitState::PREPARED) { + // Commit bundle first so that we can remove the bundle with consistent + // implementation. + CommitBundle(bundle_spec); + } + + // Return original resources to resource allocator `ClusterResourceScheduler`. + auto original_resources = it->second->resources_; + cluster_resource_scheduler_->FreeLocalTaskResources(original_resources); + + // Substract placement group resources from resource allocator + // `ClusterResourceScheduler`. + const auto &placement_group_resources = bundle_spec.GetFormattedResources(); + std::shared_ptr resource_instances = + std::make_shared(); + cluster_resource_scheduler_->AllocateLocalTaskResources(placement_group_resources, + resource_instances); + for (const auto &resource : placement_group_resources) { + if (cluster_resource_scheduler_->IsAvailableResourceEmpty(resource.first)) { + RAY_LOG(DEBUG) << "Available bundle resource:[" << resource.first + << "] is empty, Will delete it from local resource"; + // Delete local resource if available resource is empty when return bundle, or there + // will be resource leak. + cluster_resource_scheduler_->DeleteLocalResource(resource.first); + } + } + pg_bundles_.erase(it); } } // namespace raylet diff --git a/src/ray/raylet/placement_group_resource_manager.h b/src/ray/raylet/placement_group_resource_manager.h index 3b6e3a928..20ef5325a 100644 --- a/src/ray/raylet/placement_group_resource_manager.h +++ b/src/ray/raylet/placement_group_resource_manager.h @@ -18,6 +18,7 @@ #include "ray/common/bundle_spec.h" #include "ray/common/id.h" #include "ray/common/task/scheduling_resources.h" +#include "ray/raylet/scheduling/cluster_resource_scheduler.h" namespace ray { @@ -44,6 +45,14 @@ struct pair_hash { } }; +struct BundleTransactionState { + BundleTransactionState(CommitState state, + std::shared_ptr &resources) + : state_(state), resources_(resources) {} + CommitState state_; + std::shared_ptr resources_; +}; + /// `PlacementGroupResourceManager` responsible for managing the resources that /// about allocated for placement group bundles. class PlacementGroupResourceManager { @@ -68,16 +77,20 @@ class PlacementGroupResourceManager { /// Return back all the bundle(which is unused) resource. /// /// \param bundle_spec: A set of bundles which in use. - virtual void ReturnUnusedBundle( - const std::unordered_set &in_use_bundles) = 0; + void ReturnUnusedBundle(const std::unordered_set &in_use_bundles); virtual ~PlacementGroupResourceManager() {} + + protected: + /// Save `BundleSpecification` for cleaning leaked bundles after GCS restart. + absl::flat_hash_map, pair_hash> + bundle_spec_map_; }; /// Associated with old scheduler. class OldPlacementGroupResourceManager : public PlacementGroupResourceManager { public: - /// Create a local placement group manager. + /// Create a old placement group resource manager. /// /// \param local_available_resources_: The resources (IDs specificed) that are currently /// available. @@ -98,8 +111,6 @@ class OldPlacementGroupResourceManager : public PlacementGroupResourceManager { void ReturnBundle(const BundleSpecification &bundle_spec); - void ReturnUnusedBundle(const std::unordered_set &in_use_bundles); - /// Get all local available resource(IDs specificed). const ResourceIdSet &GetAllResourceIdSet() const { return local_available_resources_; }; @@ -121,10 +132,36 @@ class OldPlacementGroupResourceManager : public PlacementGroupResourceManager { /// creation. absl::flat_hash_map, pair_hash> bundle_state_map_; +}; - /// Save `BundleSpecification` for cleaning leaked bundles after GCS restart. - absl::flat_hash_map, pair_hash> - bundle_spec_map_; +/// Associated with new scheduler. +class NewPlacementGroupResourceManager : public PlacementGroupResourceManager { + public: + /// Create a new placement group resource manager. + /// + /// \param cluster_resource_scheduler_: The resource allocator of new scheduler. + NewPlacementGroupResourceManager( + std::shared_ptr cluster_resource_scheduler_); + + virtual ~NewPlacementGroupResourceManager() = default; + + bool PrepareBundle(const BundleSpecification &bundle_spec); + + void CommitBundle(const BundleSpecification &bundle_spec); + + void ReturnBundle(const BundleSpecification &bundle_spec); + + const std::shared_ptr GetResourceScheduler() const { + return cluster_resource_scheduler_; + } + + private: + std::shared_ptr cluster_resource_scheduler_; + + /// Tracking placement group bundles and their states. This mapping is the source of + /// truth for the new scheduler. + std::unordered_map, pair_hash> + pg_bundles_; }; } // namespace raylet diff --git a/src/ray/raylet/placement_group_resource_manager_test.cc b/src/ray/raylet/placement_group_resource_manager_test.cc index 10011aece..e78a21b40 100644 --- a/src/ray/raylet/placement_group_resource_manager_test.cc +++ b/src/ray/raylet/placement_group_resource_manager_test.cc @@ -262,6 +262,232 @@ TEST_F(OldPlacementGroupResourceManagerTest, TestIdempotencyWithRandomOrder) { CheckRemainingResourceCorrect(result_resource); } +class NewPlacementGroupResourceManagerTest : public ::testing::Test { + public: + std::unique_ptr + new_placement_group_resource_manager_; + + void InitLocalAvailableResource( + std::unordered_map &unit_resource) { + auto cluster_resource_scheduler_ = + std::make_shared("local", unit_resource); + new_placement_group_resource_manager_.reset( + new raylet::NewPlacementGroupResourceManager(cluster_resource_scheduler_)); + } + + void CheckAvailableResoueceEmpty(const std::string &resource) { + const auto cluster_resource_scheduler_ = + new_placement_group_resource_manager_->GetResourceScheduler(); + ASSERT_TRUE(cluster_resource_scheduler_->IsAvailableResourceEmpty(resource)); + } + + void CheckRemainingResourceCorrect(NodeResourceInstances &node_resource_instances) { + const auto cluster_resource_scheduler_ = + new_placement_group_resource_manager_->GetResourceScheduler(); + ASSERT_TRUE(cluster_resource_scheduler_->GetLocalResources() == + node_resource_instances); + } +}; + +TEST_F(NewPlacementGroupResourceManagerTest, TestNewPrepareBundleResource) { + // 1. create bundle spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 1.0}); + auto bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + /// 2. init local available resource. + InitLocalAvailableResource(unit_resource); + /// 3. prepare bundle resource. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + /// 4. check remaining resources is correct. + CheckAvailableResoueceEmpty("CPU"); +} + +TEST_F(NewPlacementGroupResourceManagerTest, + TestNewPrepareBundleWithInsufficientResource) { + // 1. create bundle spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 2.0}); + auto bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + /// 2. init local available resource. + std::unordered_map init_unit_resource; + init_unit_resource.insert({"CPU", 1.0}); + InitLocalAvailableResource(init_unit_resource); + /// 3. prepare bundle resource. + ASSERT_FALSE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); +} + +TEST_F(NewPlacementGroupResourceManagerTest, TestNewCommitBundleResource) { + // 1. create bundle spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 1.0}); + auto bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + /// 2. init local available resource. + InitLocalAvailableResource(unit_resource); + /// 3. prepare and commit bundle resource. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + new_placement_group_resource_manager_->CommitBundle(bundle_spec); + /// 4. check remaining resources is correct. + std::unordered_map remaining_resources = { + {"CPU_group_" + group_id.Hex(), 1.0}, + {"CPU_group_1_" + group_id.Hex(), 1.0}, + {"CPU", 1.0}}; + auto remaining_resource_scheduler = + std::make_shared("remaining", remaining_resources); + std::shared_ptr resource_instances = + std::make_shared(); + ASSERT_TRUE(remaining_resource_scheduler->AllocateLocalTaskResources( + unit_resource, resource_instances)); + auto remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); +} + +TEST_F(NewPlacementGroupResourceManagerTest, TestNewReturnBundleResource) { + // 1. create bundle spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 1.0}); + auto bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + /// 2. init local available resource. + InitLocalAvailableResource(unit_resource); + /// 3. prepare and commit bundle resource. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + new_placement_group_resource_manager_->CommitBundle(bundle_spec); + /// 4. return bundle resource. + new_placement_group_resource_manager_->ReturnBundle(bundle_spec); + /// 5. check remaining resources is correct. + auto remaining_resource_scheduler = + std::make_shared("remaining", unit_resource); + auto remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); +} + +TEST_F(NewPlacementGroupResourceManagerTest, TestNewMultipleBundlesCommitAndReturn) { + // 1. create two bundles spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 1.0}); + auto first_bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + auto second_bundle_spec = Mocker::GenBundleCreation(group_id, 2, unit_resource); + /// 2. init local available resource. + std::unordered_map init_unit_resource; + init_unit_resource.insert({"CPU", 2.0}); + InitLocalAvailableResource(init_unit_resource); + /// 3. prepare and commit two bundle resource. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(first_bundle_spec)); + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(second_bundle_spec)); + new_placement_group_resource_manager_->CommitBundle(first_bundle_spec); + new_placement_group_resource_manager_->CommitBundle(second_bundle_spec); + /// 4. check remaining resources is correct after commit phase. + std::unordered_map remaining_resources = { + {"CPU_group_" + group_id.Hex(), 2.0}, + {"CPU_group_1_" + group_id.Hex(), 1.0}, + {"CPU_group_2_" + group_id.Hex(), 1.0}, + {"CPU", 2.0}}; + auto remaining_resource_scheduler = + std::make_shared("remaining", remaining_resources); + std::shared_ptr resource_instances = + std::make_shared(); + ASSERT_TRUE(remaining_resource_scheduler->AllocateLocalTaskResources( + init_unit_resource, resource_instances)); + auto remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); + /// 5. return second bundle. + new_placement_group_resource_manager_->ReturnBundle(second_bundle_spec); + /// 6. check remaining resources is correct after return second bundle. + remaining_resources = {{"CPU_group_" + group_id.Hex(), 2.0}, + {"CPU_group_1_" + group_id.Hex(), 1.0}, + {"CPU", 2.0}}; + remaining_resource_scheduler = + std::make_shared("remaining", remaining_resources); + ASSERT_TRUE(remaining_resource_scheduler->AllocateLocalTaskResources( + {{"CPU_group_" + group_id.Hex(), 1.0}, {"CPU", 1.0}}, resource_instances)); + remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); + /// 7. return first bundel. + new_placement_group_resource_manager_->ReturnBundle(first_bundle_spec); + /// 8. check remaining resources is correct after all bundle returned. + remaining_resources = {{"CPU", 2.0}}; + remaining_resource_scheduler = + std::make_shared("remaining", remaining_resources); + remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); +} + +TEST_F(NewPlacementGroupResourceManagerTest, TestNewIdempotencyWithMultiPrepare) { + // 1. create one bundle spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 1.0}); + auto bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + /// 2. init local available resource. + std::unordered_map available_resource = { + std::make_pair("CPU", 3.0)}; + InitLocalAvailableResource(available_resource); + /// 3. prepare bundle resource 10 times. + for (int i = 0; i < 10; i++) { + new_placement_group_resource_manager_->PrepareBundle(bundle_spec); + } + /// 4. check remaining resources is correct. + std::unordered_map remaining_resources = {{"CPU", 3.0}}; + auto remaining_resource_scheduler = + std::make_shared("remaining", remaining_resources); + std::shared_ptr resource_instances = + std::make_shared(); + ASSERT_TRUE(remaining_resource_scheduler->AllocateLocalTaskResources( + unit_resource, resource_instances)); + auto remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); +} + +TEST_F(NewPlacementGroupResourceManagerTest, TestNewIdempotencyWithRandomOrder) { + // 1. create one bundle spec. + auto group_id = PlacementGroupID::FromRandom(); + std::unordered_map unit_resource; + unit_resource.insert({"CPU", 1.0}); + auto bundle_spec = Mocker::GenBundleCreation(group_id, 1, unit_resource); + /// 2. init local available resource. + std::unordered_map available_resource = { + std::make_pair("CPU", 3.0)}; + InitLocalAvailableResource(available_resource); + /// 3. prepare bundle -> commit bundle -> prepare bundle. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + new_placement_group_resource_manager_->CommitBundle(bundle_spec); + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + /// 4. check remaining resources is correct. + std::unordered_map remaining_resources = { + {"CPU_group_" + group_id.Hex(), 1.0}, + {"CPU_group_1_" + group_id.Hex(), 1.0}, + {"CPU", 3.0}}; + auto remaining_resource_scheduler = + std::make_shared("remaining", remaining_resources); + std::shared_ptr resource_instances = + std::make_shared(); + ASSERT_TRUE(remaining_resource_scheduler->AllocateLocalTaskResources( + unit_resource, resource_instances)); + auto remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); + new_placement_group_resource_manager_->ReturnBundle(bundle_spec); + // 5. prepare bundle -> commit bundle -> commit bundle. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + new_placement_group_resource_manager_->CommitBundle(bundle_spec); + new_placement_group_resource_manager_->CommitBundle(bundle_spec); + // 6. check remaining resources is correct. + CheckRemainingResourceCorrect(remaining_resouece_instance); + new_placement_group_resource_manager_->ReturnBundle(bundle_spec); + // 7. prepare bundle -> return bundle -> commit bundle. + ASSERT_TRUE(new_placement_group_resource_manager_->PrepareBundle(bundle_spec)); + new_placement_group_resource_manager_->ReturnBundle(bundle_spec); + new_placement_group_resource_manager_->CommitBundle(bundle_spec); + // 8. check remaining resources is correct. + remaining_resource_scheduler = + std::make_shared("remaining", available_resource); + remaining_resouece_instance = remaining_resource_scheduler->GetLocalResources(); + CheckRemainingResourceCorrect(remaining_resouece_instance); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/raylet/scheduling/cluster_resource_data.cc b/src/ray/raylet/scheduling/cluster_resource_data.cc index cb0214dab..551b5a980 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.cc +++ b/src/ray/raylet/scheduling/cluster_resource_data.cc @@ -291,8 +291,7 @@ std::string NodeResourceInstances::DebugString(StringIdMap string_to_int_map) co } for (auto it = this->custom_resources.begin(); it != this->custom_resources.end(); ++it) { - buffer << "\t" << string_to_int_map.Get(it->first) << ":(" - << VectorToString(it->second.total) << ":" + buffer << "\t" << it->first << ":(" << VectorToString(it->second.total) << ":" << VectorToString(it->second.available) << ")\n"; } buffer << "}" << std::endl; diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 2590ed98f..66047d258 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -352,6 +352,40 @@ void ClusterResourceScheduler::AddLocalResource(const std::string &resource_name } } +bool ClusterResourceScheduler::IsAvailableResourceEmpty( + const std::string &resource_name) { + auto it = nodes_.find(local_node_id_); + if (it == nodes_.end()) { + RAY_LOG(WARNING) << "Can't find local node:[" << local_node_id_ + << "] when check local available resource."; + return true; + } + + int idx = -1; + if (resource_name == ray::kCPU_ResourceLabel) { + idx = (int)CPU; + } else if (resource_name == ray::kGPU_ResourceLabel) { + idx = (int)GPU; + } else if (resource_name == ray::kTPU_ResourceLabel) { + idx = (int)TPU; + } else if (resource_name == ray::kMemory_ResourceLabel) { + idx = (int)MEM; + }; + + auto local_view = it->second.GetMutableLocalView(); + if (idx != -1) { + return local_view->predefined_resources[idx].available <= 0; + } + string_to_int_map_.Insert(resource_name); + int64_t resource_id = string_to_int_map_.Get(resource_name); + auto itr = local_view->custom_resources.find(resource_id); + if (itr != local_view->custom_resources.end()) { + return itr->second.available <= 0; + } else { + return true; + } +} + void ClusterResourceScheduler::UpdateResourceCapacity(const std::string &node_id_string, const std::string &resource_name, double resource_total) { @@ -448,9 +482,11 @@ void ClusterResourceScheduler::DeleteResource(const std::string &node_id_string, local_view->custom_resources.erase(itr); } - if (node_id == local_node_id_) { + auto c_itr = local_resources_.custom_resources.find(resource_id); + if (node_id == local_node_id_ && c_itr != local_resources_.custom_resources.end()) { local_resources_.custom_resources[resource_id].total.clear(); local_resources_.custom_resources[resource_id].available.clear(); + local_resources_.custom_resources.erase(c_itr); } } } @@ -835,6 +871,14 @@ void ClusterResourceScheduler::FreeLocalTaskResources( UpdateLocalAvailableResourcesFromResourceInstances(); } +void ClusterResourceScheduler::UpdateLastReportResourcesFromGcs( + std::shared_ptr gcs_resources) { + NodeResources node_resources = ResourceMapToNodeResources( + string_to_int_map_, gcs_resources->GetTotalResources().GetResourceMap(), + gcs_resources->GetAvailableResources().GetResourceMap()); + last_report_resources_.reset(new NodeResources(node_resources)); +} + void ClusterResourceScheduler::FillResourceUsage( bool light_report_resource_usage_enabled, std::shared_ptr resources_data) { @@ -844,8 +888,56 @@ void ClusterResourceScheduler::FillResourceUsage( << "Error: Populating heartbeat failed. Please file a bug report: " "https://github.com/ray-project/ray/issues/new."; - if (!light_report_resource_usage_enabled || !last_report_resources_ || - resources != *last_report_resources_.get()) { + // Initialize if last report resources is empty. + if (!last_report_resources_) { + NodeResources node_resources = + ResourceMapToNodeResources(string_to_int_map_, {{}}, {{}}); + last_report_resources_.reset(new NodeResources(node_resources)); + } + + if (light_report_resource_usage_enabled) { + // Reset all local views for remote nodes. This is needed in case tasks that + // we spilled back to a remote node were not actually scheduled on the + // node. Then, the remote node's resource availability may not change and + // so it may not send us another update. + for (auto &node : nodes_) { + if (node.first != local_node_id_) { + node.second.ResetLocalView(); + } + } + + for (int i = 0; i < PredefinedResources_MAX; i++) { + const auto &label = ResourceEnumToString((PredefinedResources)i); + const auto &capacity = resources.predefined_resources[i]; + const auto &last_capacity = last_report_resources_->predefined_resources[i]; + if (capacity.available != last_capacity.available) { + resources_data->set_resources_available_changed(true); + (*resources_data->mutable_resources_available())[label] = + capacity.available.Double(); + } + if (capacity.total != last_capacity.total) { + (*resources_data->mutable_resources_total())[label] = capacity.total.Double(); + } + } + for (auto it = resources.custom_resources.begin(); + it != resources.custom_resources.end(); it++) { + uint64_t custom_id = it->first; + const auto &capacity = it->second; + const auto &last_capacity = last_report_resources_->custom_resources[custom_id]; + const auto &label = string_to_int_map_.Get(custom_id); + if (capacity.available != last_capacity.available) { + resources_data->set_resources_available_changed(true); + (*resources_data->mutable_resources_available())[label] = + capacity.available.Double(); + } + if (capacity.total != last_capacity.total) { + (*resources_data->mutable_resources_total())[label] = capacity.total.Double(); + } + } + if (resources != *last_report_resources_.get()) { + last_report_resources_.reset(new NodeResources(resources)); + } + } else { for (int i = 0; i < PredefinedResources_MAX; i++) { const auto &label = ResourceEnumToString((PredefinedResources)i); const auto &capacity = resources.predefined_resources[i]; @@ -870,22 +962,6 @@ void ClusterResourceScheduler::FillResourceUsage( (*resources_data->mutable_resources_total())[label] = capacity.total.Double(); } } - resources_data->set_resources_available_changed(true); - if (light_report_resource_usage_enabled) { - last_report_resources_.reset(new NodeResources(resources)); - } - } - - if (light_report_resource_usage_enabled) { - // Reset all local views for remote nodes. This is needed in case tasks that - // we spilled back to a remote node were not actually scheduled on the - // node. Then, the remote node's resource availability may not change and - // so it may not send us another update. - for (auto &node : nodes_) { - if (node.first != local_node_id_) { - node.second.ResetLocalView(); - } - } } } diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.h b/src/ray/raylet/scheduling/cluster_resource_scheduler.h index c4058e586..470c97c38 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.h +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.h @@ -164,6 +164,11 @@ class ClusterResourceScheduler { /// \param resource_total: New capacity of the resource. void AddLocalResource(const std::string &resource_name, double resource_total); + /// Check whether the available resources are empty. + /// + /// \param resource_name: Resource which we want to check. + bool IsAvailableResourceEmpty(const std::string &resource_name); + /// Update total capacity of a given resource of a given node. /// /// \param node_name: Node whose resource we want to update. @@ -360,6 +365,13 @@ class ClusterResourceScheduler { void FillResourceUsage(bool light_report_resource_usage_enabled, std::shared_ptr resources_data); + /// Update last report resources local cache from gcs cache, + /// this is needed when gcs fo. + /// + /// \param gcs_resources: The remote cache from gcs. + void UpdateLastReportResourcesFromGcs( + std::shared_ptr gcs_resources); + /// Return human-readable string for this scheduler state. std::string DebugString() const; diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc index 37bf4b3da..db8fa44ed 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler_test.cc @@ -1131,10 +1131,10 @@ TEST_F(ClusterResourceSchedulerTest, TestLightResourceUsageReport) { } // Report resource usage if resource availability has changed. - cluster_resources.AddOrUpdateNode("local", {{"CPU", 1.}}, {{"CPU", 0.}}); + cluster_resources.AddOrUpdateNode("local", {{"CPU", 2.}}, {{"CPU", 0.}}); data->Clear(); cluster_resources.FillResourceUsage(true, data); - ASSERT_RESOURCES_EQ(data, 0, 1); + ASSERT_RESOURCES_EQ(data, 0, 2); // Don't report resource usage if resource availability hasn't changed. for (int i = 0; i < 3; i++) { @@ -1224,6 +1224,17 @@ TEST_F(ClusterResourceSchedulerTest, DynamicResourceTest) { ASSERT_TRUE(result.empty()); } +TEST_F(ClusterResourceSchedulerTest, AvailableResourceEmptyTest) { + ClusterResourceScheduler cluster_resources("local", {{"custom123", 5}}); + std::shared_ptr resource_instances = + std::make_shared(); + std::unordered_map task_request = {{"custom123", 5}}; + bool allocated = + cluster_resources.AllocateLocalTaskResources(task_request, resource_instances); + ASSERT_TRUE(allocated); + ASSERT_TRUE(cluster_resources.IsAvailableResourceEmpty("custom123")); +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index bc86e280f..74437a4a1 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -202,6 +202,7 @@ bool ClusterTaskManager::AttemptDispatchWork(const Work &work, dispatched = true; } } else { + worker->SetBundleId(spec.PlacementGroupBundleId()); worker->SetOwnerAddress(spec.CallerAddress()); if (spec.IsActorCreationTask()) { // The actor belongs to this worker now. diff --git a/src/ray/raylet/test/util.h b/src/ray/raylet/test/util.h index 4d64507b8..90fc9b158 100644 --- a/src/ray/raylet/test/util.h +++ b/src/ray/raylet/test/util.h @@ -170,7 +170,7 @@ class MockWorker : public WorkerInterface { return bundle_id_; } - void SetBundleId(const BundleID &bundle_id) { RAY_CHECK(false) << "Method unused"; } + void SetBundleId(const BundleID &bundle_id) { bundle_id_ = bundle_id; } std::vector &GetBorrowedCPUInstances() { return borrowed_cpu_instances_; } From cfefd7c70ece76f4b94386f9f85792c6fee751cf Mon Sep 17 00:00:00 2001 From: dHannasch Date: Thu, 17 Dec 2020 22:15:42 -0700 Subject: [PATCH 19/88] Test PingPort (#12954) Co-authored-by: Richard Liaw --- src/ray/gcs/test/asio_test.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/ray/gcs/test/asio_test.cc b/src/ray/gcs/test/asio_test.cc index 052a95ec6..976562a59 100644 --- a/src/ray/gcs/test/asio_test.cc +++ b/src/ray/gcs/test/asio_test.cc @@ -69,6 +69,13 @@ TEST_F(RedisAsioTest, TestRedisCommands) { std::shared_ptr shard_context = std::make_shared(io_service); + ASSERT_TRUE( + shard_context->PingPort(std::string("127.0.0.1"), TEST_REDIS_SERVER_PORTS.front()) + .ok()); + ASSERT_FALSE( + shard_context + ->PingPort(std::string("127.0.0.1"), TEST_REDIS_SERVER_PORTS.front() + 987) + .ok()); ASSERT_TRUE(shard_context ->Connect(std::string("127.0.0.1"), TEST_REDIS_SERVER_PORTS.front(), /*sharding=*/true, From a442cd17e0b3610fb4a1a580605767cc600d4c22 Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Fri, 18 Dec 2020 13:57:37 +0800 Subject: [PATCH 20/88] [GCS]Optimize gcs client reconnection (#12878) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [GCS]Optimize gcs client reconnection * fix review comment * fix review comment * add part code Co-authored-by: 灵洵 --- src/ray/common/ray_config_def.h | 3 ++ .../gcs/gcs_client/service_based_accessor.cc | 37 ++++++++++++++----- .../gcs_client/service_based_gcs_client.cc | 32 +++++++++++++--- .../gcs/gcs_client/service_based_gcs_client.h | 2 + 4 files changed, 59 insertions(+), 15 deletions(-) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index a1af82469..c702102ad 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -266,6 +266,9 @@ RAY_CONFIG(int64_t, ping_gcs_rpc_server_interval_milliseconds, 1000) /// Maximum number of times to retry ping gcs rpc server when gcs server restarts. RAY_CONFIG(int32_t, ping_gcs_rpc_server_max_retries, 1) +/// Minimum interval between reconnecting gcs rpc server when gcs server restarts. +RAY_CONFIG(int32_t, minimum_gcs_reconnect_interval_milliseconds, 5000) + /// Whether start the Plasma Store as a Raylet thread. RAY_CONFIG(bool, plasma_store_as_thread, false) diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index cc2907c63..2cf1d2caf 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -90,17 +90,23 @@ Status ServiceBasedJobInfoAccessor::AsyncSubscribeAll( void ServiceBasedJobInfoAccessor::AsyncResubscribe(bool is_pubsub_server_restarted) { RAY_LOG(DEBUG) << "Reestablishing subscription for job info."; + auto fetch_all_done = [](const Status &status) { + RAY_LOG(INFO) << "Finished fetching all job information from gcs server after gcs " + "server or pub-sub server is restarted."; + }; + // If only the GCS sever has restarted, we only need to fetch data from the GCS server. // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub // server first, then fetch data from the GCS server. if (is_pubsub_server_restarted) { if (subscribe_operation_ != nullptr) { - RAY_CHECK_OK(subscribe_operation_( - [this](const Status &status) { fetch_all_data_operation_(nullptr); })); + RAY_CHECK_OK(subscribe_operation_([this, fetch_all_done](const Status &status) { + fetch_all_data_operation_(fetch_all_done); + })); } } else { if (fetch_all_data_operation_ != nullptr) { - fetch_all_data_operation_(nullptr); + fetch_all_data_operation_(fetch_all_done); } } } @@ -301,14 +307,20 @@ Status ServiceBasedActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) void ServiceBasedActorInfoAccessor::AsyncResubscribe(bool is_pubsub_server_restarted) { RAY_LOG(DEBUG) << "Reestablishing subscription for actor info."; + auto fetch_all_done = [](const Status &status) { + RAY_LOG(INFO) << "Finished fetching all actor information from gcs server after gcs " + "server or pub-sub server is restarted."; + }; + // If only the GCS sever has restarted, we only need to fetch data from the GCS server. // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub // server first, then fetch data from the GCS server. absl::MutexLock lock(&mutex_); if (is_pubsub_server_restarted) { if (subscribe_all_operation_ != nullptr) { - RAY_CHECK_OK(subscribe_all_operation_( - [this](const Status &status) { fetch_all_data_operation_(nullptr); })); + RAY_CHECK_OK(subscribe_all_operation_([this, fetch_all_done](const Status &status) { + fetch_all_data_operation_(fetch_all_done); + })); } for (auto &item : subscribe_operations_) { auto &actor_id = item.first; @@ -325,7 +337,7 @@ void ServiceBasedActorInfoAccessor::AsyncResubscribe(bool is_pubsub_server_resta } } else { if (fetch_all_data_operation_ != nullptr) { - fetch_all_data_operation_(nullptr); + fetch_all_data_operation_(fetch_all_done); } for (auto &item : fetch_data_operations_) { item.second(nullptr); @@ -651,20 +663,27 @@ void ServiceBasedNodeInfoAccessor::HandleNotification(const GcsNodeInfo &node_in void ServiceBasedNodeInfoAccessor::AsyncResubscribe(bool is_pubsub_server_restarted) { RAY_LOG(DEBUG) << "Reestablishing subscription for node info."; + auto fetch_all_done = [](const Status &status) { + RAY_LOG(INFO) << "Finished fetching all node information from gcs server after gcs " + "server or pub-sub server is restarted."; + }; + // If only the GCS sever has restarted, we only need to fetch data from the GCS server. // If the pub-sub server has also restarted, we need to resubscribe to the pub-sub // server first, then fetch data from the GCS server. if (is_pubsub_server_restarted) { if (subscribe_node_operation_ != nullptr) { - RAY_CHECK_OK(subscribe_node_operation_( - [this](const Status &status) { fetch_node_data_operation_(nullptr); })); + RAY_CHECK_OK( + subscribe_node_operation_([this, fetch_all_done](const Status &status) { + fetch_node_data_operation_(fetch_all_done); + })); } if (subscribe_batch_resource_usage_operation_ != nullptr) { RAY_CHECK_OK(subscribe_batch_resource_usage_operation_(nullptr)); } } else { if (fetch_node_data_operation_ != nullptr) { - fetch_node_data_operation_(nullptr); + fetch_node_data_operation_(fetch_all_done); } } } diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc index 359f6cd81..884612106 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.cc +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -25,7 +25,9 @@ namespace ray { namespace gcs { ServiceBasedGcsClient::ServiceBasedGcsClient(const GcsClientOptions &options) - : GcsClient(options) {} + : GcsClient(options), + last_reconnect_timestamp_ms_(0), + last_reconnect_address_(std::make_pair("", -1)) {} Status ServiceBasedGcsClient::Connect(boost::asio::io_service &io_service) { RAY_CHECK(!is_connected_); @@ -175,7 +177,7 @@ void ServiceBasedGcsClient::GcsServiceFailureDetected(rpc::GcsServiceFailureType ReconnectGcsServer(); // NOTE(ffbin): Currently we don't support the case where the pub-sub server restarts, // because we use the same Redis server for both GCS storage and pub-sub. So the - // following flag is alway false. + // following flag is always false. resubscribe_func_(false); // Resend resource usage after reconnected, needed by resource view in GCS. node_accessor_->AsyncReReportResourceUsage(); @@ -191,11 +193,27 @@ void ServiceBasedGcsClient::ReconnectGcsServer() { int index = 0; for (; index < RayConfig::instance().ping_gcs_rpc_server_max_retries(); ++index) { if (get_server_address_func_(&address)) { - RAY_LOG(DEBUG) << "Attemptting to reconnect to GCS server: " << address.first << ":" - << address.second; + // After GCS is restarted, the gcs client will reestablish the connection. At + // present, every failed RPC request will trigger `ReconnectGcsServer`. In order to + // avoid repeated connections in a short period of time, we add a protection + // mechanism: if the address does not change (meaning gcs server doesn't restart), + // the connection can be made at most once in + // `minimum_gcs_reconnect_interval_milliseconds` milliseconds. + if (last_reconnect_address_ == address && + (current_sys_time_ms() - last_reconnect_timestamp_ms_) < + RayConfig::instance().minimum_gcs_reconnect_interval_milliseconds()) { + RAY_LOG(INFO) + << "Repeated reconnection in " + << RayConfig::instance().minimum_gcs_reconnect_interval_milliseconds() + << "milliseconds, return directly."; + return; + } + + RAY_LOG(INFO) << "Attemptting to reconnect to GCS server: " << address.first << ":" + << address.second; if (Ping(address.first, address.second, 100)) { - RAY_LOG(DEBUG) << "Reconnected to GCS server: " << address.first << ":" - << address.second; + RAY_LOG(INFO) << "Reconnected to GCS server: " << address.first << ":" + << address.second; break; } } @@ -205,6 +223,8 @@ void ServiceBasedGcsClient::ReconnectGcsServer() { if (index < RayConfig::instance().ping_gcs_rpc_server_max_retries()) { gcs_rpc_client_->Reset(address.first, address.second, *client_call_manager_); + last_reconnect_address_ = address; + last_reconnect_timestamp_ms_ = current_sys_time_ms(); } else { RAY_LOG(FATAL) << "Couldn't reconnect to GCS server. The last attempted GCS " "server address was " diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.h b/src/ray/gcs/gcs_client/service_based_gcs_client.h index 05fb4c467..906165099 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.h +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.h @@ -72,6 +72,8 @@ class RAY_EXPORT ServiceBasedGcsClient : public GcsClient { std::function *)> get_server_address_func_; std::function resubscribe_func_; std::pair current_gcs_server_address_; + int64_t last_reconnect_timestamp_ms_; + std::pair last_reconnect_address_; }; } // namespace gcs From 426f8a8d15449c5e78efc4604e2b0fd4a8577fb8 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 18 Dec 2020 10:31:40 +0100 Subject: [PATCH 21/88] [tune] Fix tutorial training on GPU (#12914) --- python/ray/tune/tests/tutorial.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/ray/tune/tests/tutorial.py b/python/ray/tune/tests/tutorial.py index 2aa442279..2a11f12a0 100644 --- a/python/ray/tune/tests/tutorial.py +++ b/python/ray/tune/tests/tutorial.py @@ -93,7 +93,11 @@ def train_mnist(config): batch_size=64, shuffle=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = ConvNet() + model.to(device) + optimizer = optim.SGD( model.parameters(), lr=config["lr"], momentum=config["momentum"]) for i in range(10): @@ -161,6 +165,11 @@ space = { hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max") analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search) + +# To enable GPUs, use this instead: +# analysis = tune.run( +# train_mnist, config=search_space, resources_per_trial={'gpu': 1}) + # __run_searchalg_end__ # __run_analysis_begin__ From bff50cfc37380140d53debfb140e262a38668db0 Mon Sep 17 00:00:00 2001 From: Gekho457 <62982571+Gekho457@users.noreply.github.com> Date: Fri, 18 Dec 2020 01:32:12 -0800 Subject: [PATCH 22/88] [k8s] Read gpu resources properly (#12942) * Read gpu resources properly * Comments and docstrings * Comment formatting --- python/ray/autoscaler/_private/commands.py | 7 ++-- .../autoscaler/_private/kubernetes/config.py | 42 +++++++++++++++++-- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/python/ray/autoscaler/_private/commands.py b/python/ray/autoscaler/_private/commands.py index 6f70c4c7e..0c7e3abbd 100644 --- a/python/ray/autoscaler/_private/commands.py +++ b/python/ray/autoscaler/_private/commands.py @@ -280,9 +280,10 @@ def _bootstrap_config(config: Dict[str, Any], f"Failed to autodetect node resources: {str(exc)}. " "You can see full stack trace with higher verbosity.") - # NOTE: if `resources` field is missing, validate_config for non-AWS will - # fail (the schema error will ask the user to manually fill the resources) - # as we currently support autofilling resources for AWS instances only. + # NOTE: if `resources` field is missing, validate_config for providers + # other than AWS and Kubernetes will fail (the schema error will ask the + # user to manually fill the resources) as we currently support autofilling + # resources for AWS and Kubernetes only. validate_config(config) resolved_config = provider_cls.bootstrap_config(config) diff --git a/python/ray/autoscaler/_private/kubernetes/config.py b/python/ray/autoscaler/_private/kubernetes/config.py index 20173119b..b285e7701 100644 --- a/python/ray/autoscaler/_private/kubernetes/config.py +++ b/python/ray/autoscaler/_private/kubernetes/config.py @@ -60,6 +60,13 @@ def bootstrap_kubernetes(config): def fillout_resources_kubernetes(config): + """Fills CPU and GPU resources by reading pod spec of each available node + type. + + For each node type and each of CPU/GPU, looks at container's resources + and limits, takes min of the two. The result is rounded up, as Ray does + not currently support fractional CPU. + """ if "available_node_types" not in config: return config["available_node_types"] node_types = copy.deepcopy(config["available_node_types"]) @@ -96,20 +103,47 @@ def get_resource(container_resources, resource_name): limit = _get_resource( container_resources, resource_name, field_name="limits") resource = min(request, limit) + # float("inf") value means the resource wasn't detected in either + # requests or limits return 0 if resource == float("inf") else int(resource) def _get_resource(container_resources, resource_name, field_name): - if (field_name in container_resources - and resource_name in container_resources[field_name]): - return _parse_resource(container_resources[field_name][resource_name]) - else: + """Returns the resource quantity. + + The amount of resource is rounded up to nearest integer. + Returns float("inf") if the resource is not present. + + Args: + container_resources (dict): Container's resource field. + resource_name (str): One of 'cpu' or 'gpu'. + field_name (str): One of 'requests' or 'limits'. + + Returns: + Union[int, float]: Detected resource quantity. + """ + if field_name not in container_resources: + # No limit/resource field. return float("inf") + resources = container_resources[field_name] + # Look for keys containing the resource_name. For example, + # the key 'nvidia.com/gpu' contains the key 'gpu'. + matching_keys = [key for key in resources if resource_name in key.lower()] + if len(matching_keys) == 0: + return float("inf") + if len(matching_keys) > 1: + # Should have only one match -- mostly relevant for gpu. + raise ValueError(f"Multiple {resource_name} types not supported.") + # E.g. 'nvidia.com/gpu' or 'cpu'. + resource_key = matching_keys.pop() + resource_quantity = resources[resource_key] + return _parse_resource(resource_quantity) def _parse_resource(resource): resource_str = str(resource) if resource_str[-1] == "m": + # For example, '500m' rounds up to 1. return math.ceil(int(resource_str[:-1]) / 1000) else: return int(resource_str) From 55ae567f7a4c6213543d002b42270569454b5378 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Fri, 18 Dec 2020 10:33:12 +0100 Subject: [PATCH 23/88] [tune] Fix and enable SigOpt tests (#12877) Co-authored-by: Richard Liaw --- .travis.yml | 1 - python/ray/tune/BUILD | 50 +++++++------- python/ray/tune/examples/sigopt_example.py | 18 +++-- .../sigopt_multi_objective_example.py | 21 ++++-- .../examples/sigopt_prior_beliefs_example.py | 20 ++++-- python/ray/tune/suggest/bayesopt.py | 8 ++- python/ray/tune/suggest/nevergrad.py | 2 +- python/ray/tune/suggest/sigopt.py | 26 ++++++-- python/ray/tune/suggest/suggestion.py | 6 ++ python/ray/tune/suggest/zoopt.py | 8 ++- python/ray/tune/tests/test_tune_restore.py | 65 +++++++++++++++---- 11 files changed, 158 insertions(+), 67 deletions(-) diff --git a/.travis.yml b/.travis.yml index d946fed22..a8ddb0d46 100644 --- a/.travis.yml +++ b/.travis.yml @@ -577,4 +577,3 @@ deploy: repo: ray-project/ray branch: master condition: $MULTIPLATFORM_JARS = 1 || $MAC_JARS = 1 || $LINUX_JARS = 1 - diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 8b3439853..70fd218d5 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -254,7 +254,7 @@ py_test( size = "large", srcs = ["tests/test_tune_restore.py"], deps = [":tune_lib"], - tags = ["jenkins_only", "exclusive"], + tags = ["exclusive"], ) py_test( @@ -648,35 +648,35 @@ py_test( # ) # Needs SigOpt API key. -# py_test( -# name = "sigopt_example", -# size = "medium", -# srcs = ["examples/sigopt_example.py"], -# deps = [":tune_lib"], -# tags = ["exclusive", "example"], -# args = ["--smoke-test"] -# ) +py_test( + name = "sigopt_example", + size = "medium", + srcs = ["examples/sigopt_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"], + args = ["--smoke-test"] +) # Needs SigOpt API key. -# py_test( -# name = "sigopt_multi_objective_example", -# size = "medium", -# srcs = ["examples/sigopt_multi_objective_example.py"], -# deps = [":tune_lib"], s -# tags = ["exclusive", "example"], -# args = ["--smoke-test"] -# ) +py_test( + name = "sigopt_multi_objective_example", + size = "medium", + srcs = ["examples/sigopt_multi_objective_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"], + args = ["--smoke-test"] +) # Needs SigOpt API key. -# py_test( -# name = "sigopt_prior_beliefs_example", -# size = "medium", -# srcs = ["examples/sigopt_prior_beliefs_example.py"], -# deps = [":tune_lib"], -# tags = ["exclusive", "example"], -# args = ["--smoke-test"] -# ) +py_test( + name = "sigopt_prior_beliefs_example", + size = "medium", + srcs = ["examples/sigopt_prior_beliefs_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"], + args = ["--smoke-test"] +) py_test( name = "skopt_example", diff --git a/python/ray/tune/examples/sigopt_example.py b/python/ray/tune/examples/sigopt_example.py index dabe38aa1..d765404e9 100644 --- a/python/ray/tune/examples/sigopt_example.py +++ b/python/ray/tune/examples/sigopt_example.py @@ -2,6 +2,7 @@ It also checks that it is usable with a separate scheduler. """ +import sys import time from ray import tune @@ -29,14 +30,20 @@ if __name__ == "__main__": import argparse import os - assert "SIGOPT_KEY" in os.environ, \ - "SigOpt API key must be stored as environment variable at SIGOPT_KEY" - parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() + if "SIGOPT_KEY" not in os.environ: + if args.smoke_test: + print("SigOpt API Key not found. Skipping smoke test.") + sys.exit(0) + else: + raise ValueError( + "SigOpt API Key not found. Please set the SIGOPT_KEY " + "environment variable.") + space = [ { "name": "width", @@ -67,7 +74,8 @@ if __name__ == "__main__": name="my_exp", search_alg=algo, scheduler=scheduler, - num_samples=10 if args.smoke_test else 1000, + num_samples=4 if args.smoke_test else 100, config={"steps": 10}) - print("Best hyperparameters found were: ", analysis.best_config) + print("Best hyperparameters found were: ", + analysis.get_best_config("mean_loss", "min")) diff --git a/python/ray/tune/examples/sigopt_multi_objective_example.py b/python/ray/tune/examples/sigopt_multi_objective_example.py index 1d34da8cc..b1ec64332 100644 --- a/python/ray/tune/examples/sigopt_multi_objective_example.py +++ b/python/ray/tune/examples/sigopt_multi_objective_example.py @@ -1,5 +1,5 @@ """Example using Sigopt's multi-objective functionality.""" - +import sys import time import numpy as np @@ -30,14 +30,20 @@ if __name__ == "__main__": import argparse import os - assert "SIGOPT_KEY" in os.environ, \ - "SigOpt API key must be stored as environment variable at SIGOPT_KEY" - parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() + if "SIGOPT_KEY" not in os.environ: + if args.smoke_test: + print("SigOpt API Key not found. Skipping smoke test.") + sys.exit(0) + else: + raise ValueError( + "SigOpt API Key not found. Please set the SIGOPT_KEY " + "environment variable.") + space = [ { "name": "w1", @@ -52,7 +58,7 @@ if __name__ == "__main__": algo = SigOptSearch( space, name="SigOpt Example Multi Objective Experiment", - observation_budget=10 if args.smoke_test else 1000, + observation_budget=4 if args.smoke_test else 100, max_concurrent=1, metric=["average", "std", "sharpe"], mode=["max", "min", "obs"]) @@ -61,6 +67,7 @@ if __name__ == "__main__": easy_objective, name="my_exp", search_alg=algo, - num_samples=10 if args.smoke_test else 1000, + num_samples=4 if args.smoke_test else 100, config={"total_weight": 1}) - print("Best hyperparameters found were: ", analysis.best_config) + print("Best hyperparameters found were: ", + analysis.get_best_config("average", "min")) diff --git a/python/ray/tune/examples/sigopt_prior_beliefs_example.py b/python/ray/tune/examples/sigopt_prior_beliefs_example.py index 420a342a8..a624a6a8f 100644 --- a/python/ray/tune/examples/sigopt_prior_beliefs_example.py +++ b/python/ray/tune/examples/sigopt_prior_beliefs_example.py @@ -1,4 +1,5 @@ """"Example using Sigopt's support for prior beliefs.""" +import sys import numpy as np from ray import tune @@ -37,14 +38,21 @@ if __name__ == "__main__": import os from sigopt import Connection - assert "SIGOPT_KEY" in os.environ, \ - "SigOpt API key must be stored as environment variable at SIGOPT_KEY" - parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() - samples = 10 if args.smoke_test else 1000 + + if "SIGOPT_KEY" not in os.environ: + if args.smoke_test: + print("SigOpt API Key not found. Skipping smoke test.") + sys.exit(0) + else: + raise ValueError( + "SigOpt API Key not found. Please set the SIGOPT_KEY " + "environment variable.") + + samples = 4 if args.smoke_test else 100 conn = Connection(client_token=os.environ["SIGOPT_KEY"]) experiment = conn.experiments().create( @@ -95,4 +103,6 @@ if __name__ == "__main__": search_alg=algo, num_samples=samples, config={}) - print("Best hyperparameters found were: ", analysis.best_config) + + print("Best hyperparameters found were: ", + analysis.get_best_config("average", "min")) diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py index df8f94546..489e92709 100644 --- a/python/ray/tune/suggest/bayesopt.py +++ b/python/ray/tune/suggest/bayesopt.py @@ -169,9 +169,7 @@ class BayesOptSearch(Searcher): self.utility = byo.UtilityFunction(**utility_kwargs) - # Registering the provided analysis, if given - if analysis is not None: - self.register_analysis(analysis) + self._analysis = analysis if isinstance(space, dict) and space: resolved_vars, domain_vars, grid_vars = parse_spec_vars(space) @@ -200,6 +198,10 @@ class BayesOptSearch(Searcher): verbose=self._verbose, random_state=self._random_state) + # Registering the provided analysis, if given + if self._analysis is not None: + self.register_analysis(self._analysis) + def set_search_properties(self, metric: Optional[str], mode: Optional[str], config: Dict) -> bool: if self.optimizer: diff --git a/python/ray/tune/suggest/nevergrad.py b/python/ray/tune/suggest/nevergrad.py index 669114d9b..f5da80b00 100644 --- a/python/ray/tune/suggest/nevergrad.py +++ b/python/ray/tune/suggest/nevergrad.py @@ -148,7 +148,7 @@ class NevergradSearch(Searcher): space = self.convert_search_space(space) if isinstance(optimizer, Optimizer): - if space is not None or isinstance(space, list): + if space is not None and not isinstance(space, list): raise ValueError( "If you pass a configured optimizer to Nevergrad, either " "pass a list of parameter names or None as the `space` " diff --git a/python/ray/tune/suggest/sigopt.py b/python/ray/tune/suggest/sigopt.py index 415e8aa3a..8bdcaaf1d 100644 --- a/python/ray/tune/suggest/sigopt.py +++ b/python/ray/tune/suggest/sigopt.py @@ -136,6 +136,7 @@ class SigOptSearch(Searcher): project: Optional[str] = None, metric: Union[None, str, List[str]] = "episode_reward_mean", mode: Union[None, str, List[str]] = "max", + points_to_evaluate: Optional[List[Dict]] = None, **kwargs): assert (experiment_id is None) ^ (space is None), "space xor experiment_id must be set" @@ -182,17 +183,25 @@ class SigOptSearch(Searcher): else: self.experiment = self.conn.experiments(experiment_id).fetch() + self._points_to_evaluate = points_to_evaluate + super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs) def suggest(self, trial_id: str): if self._max_concurrent: if len(self._live_trial_mapping) >= self._max_concurrent: return None + + suggestion_kwargs = {} + if self._points_to_evaluate: + config = self._points_to_evaluate.pop(0) + suggestion_kwargs = {"assignments": config} + # Get new suggestion from SigOpt suggestion = self.conn.experiments( - self.experiment.id).suggestions().create() + self.experiment.id).suggestions().create(**suggestion_kwargs) - self._live_trial_mapping[trial_id] = suggestion + self._live_trial_mapping[trial_id] = suggestion.id return copy.deepcopy(suggestion.assignments) @@ -210,7 +219,7 @@ class SigOptSearch(Searcher): """ if result: payload = dict( - suggestion=self._live_trial_mapping[trial_id].id, + suggestion=self._live_trial_mapping[trial_id], values=self.serialize_result(result)) self.conn.experiments( self.experiment.id).observations().create(**payload) @@ -219,7 +228,7 @@ class SigOptSearch(Searcher): elif error: # Reports a failed Observation self.conn.experiments(self.experiment.id).observations().create( - failed=True, suggestion=self._live_trial_mapping[trial_id].id) + failed=True, suggestion=self._live_trial_mapping[trial_id]) del self._live_trial_mapping[trial_id] @staticmethod @@ -254,12 +263,15 @@ class SigOptSearch(Searcher): return values def save(self, checkpoint_path: str): - trials_object = (self.conn, self.experiment) + trials_object = (self.experiment.id, self._live_trial_mapping, + self._points_to_evaluate) with open(checkpoint_path, "wb") as outputFile: pickle.dump(trials_object, outputFile) def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: trials_object = pickle.load(inputFile) - self.conn = trials_object[0] - self.experiment = trials_object[1] + experiment_id, self._live_trial_mapping, self._points_to_evaluate = \ + trials_object + + self.experiment = self.conn.experiments(experiment_id).fetch() diff --git a/python/ray/tune/suggest/suggestion.py b/python/ray/tune/suggest/suggestion.py index 99a6001d1..3612576d4 100644 --- a/python/ray/tune/suggest/suggestion.py +++ b/python/ray/tune/suggest/suggestion.py @@ -391,6 +391,12 @@ class ConcurrencyLimiter(Searcher): def set_state(self, state: Dict): self.__dict__.update(state) + def save(self, checkpoint_path: str): + self.searcher.save(checkpoint_path) + + def restore(self, checkpoint_path: str): + self.searcher.restore(checkpoint_path) + def on_pause(self, trial_id: str): self.searcher.on_pause(trial_id) diff --git a/python/ray/tune/suggest/zoopt.py b/python/ray/tune/suggest/zoopt.py index f9e3d04bb..c0c0ddb18 100644 --- a/python/ray/tune/suggest/zoopt.py +++ b/python/ray/tune/suggest/zoopt.py @@ -138,6 +138,7 @@ class ZOOptSearch(Searcher): metric: Optional[str] = None, mode: Optional[str] = None, points_to_evaluate: Optional[List[Dict]] = None, + parallel_num: int = 1, **kwargs): assert zoopt is not None, "ZOOpt not found - please install zoopt " \ "by `pip install -U zoopt`." @@ -178,6 +179,8 @@ class ZOOptSearch(Searcher): self.kwargs = kwargs + self.parallel_num = parallel_num + super(ZOOptSearch, self).__init__(metric=self._metric, mode=mode) if self._dim_dict: @@ -206,7 +209,10 @@ class ZOOptSearch(Searcher): if self._algo == "sracos" or self._algo == "asracos": from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune self.optimizer = SRacosTune( - dimension=dim, parameter=par, **self.kwargs) + dimension=dim, + parameter=par, + parallel_num=self.parallel_num, + **self.kwargs) if init_samples: self.optimizer.init_attribute() diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index 507dca1d7..fc61c81ff 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -105,11 +105,6 @@ class TuneExampleTest(unittest.TestCase): validate_save_restore(TrainMNIST) validate_save_restore(TrainMNIST, use_object_store=True) - def testLogging(self): - from ray.tune.examples.logging_example import MyTrainableClass - validate_save_restore(MyTrainableClass) - validate_save_restore(MyTrainableClass, use_object_store=True) - def testHyperbandExample(self): from ray.tune.examples.hyperband_example import MyTrainableClass validate_save_restore(MyTrainableClass) @@ -373,22 +368,72 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase): def cost(space, reporter): reporter(loss=(space["height"] - 14)**2 - abs(space["width"] - 3)) + # Unfortunately, SigOpt doesn't allow setting of random state. Thus, + # we always end up with different suggestions, which is unsuitable + # for the warm start test. Here we make do with points_to_evaluate, + # and ensure that state is preserved over checkpoints and restarts. + points = [ + { + "width": 5, + "height": 20 + }, + { + "width": 10, + "height": -20 + }, + { + "width": 15, + "height": 30 + }, + { + "width": 5, + "height": -30 + }, + { + "width": 10, + "height": 40 + }, + { + "width": 15, + "height": -40 + }, + { + "width": 5, + "height": 50 + }, + { + "width": 10, + "height": -50 + }, + { + "width": 15, + "height": 60 + }, + { + "width": 12, + "height": -60 + }, + ] + search_alg = SigOptSearch( space, name="SigOpt Example Experiment", max_concurrent=1, metric="loss", - mode="min") + mode="min", + points_to_evaluate=points) return search_alg, cost def testWarmStart(self): - if ("SIGOPT_KEY" not in os.environ): + if "SIGOPT_KEY" not in os.environ: + self.skipTest("No SigOpt API key found in environment.") return super().testWarmStart() def testRestore(self): - if ("SIGOPT_KEY" not in os.environ): + if "SIGOPT_KEY" not in os.environ: + self.skipTest("No SigOpt API key found in environment.") return super().testRestore() @@ -412,10 +457,6 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase): return search_alg, cost - @unittest.skip("Skip because this seems to have leaking state.") - def testRestore(self): - pass - class SearcherTest(unittest.TestCase): class MockSearcher(Searcher): From 5cfa1934e4a2157d07ef534e1033a3f0e5bc1082 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Fri, 18 Dec 2020 11:47:38 -0800 Subject: [PATCH 24/88] [ray_client]: Implement object retain/release and Data Streaming API (#12818) --- python/ray/experimental/client/__init__.py | 17 +- python/ray/experimental/client/api.py | 31 ++ .../ray/experimental/client/client_pickler.py | 123 ++++++++ python/ray/experimental/client/common.py | 146 +++------ python/ray/experimental/client/dataclient.py | 103 +++++++ .../client/server/core_ray_api.py | 41 +-- .../client/server/dataservicer.py | 54 ++++ .../ray/experimental/client/server/server.py | 286 +++++++++++------- .../client/server/server_pickler.py | 119 ++++++++ .../client/server/server_stubs.py | 29 ++ python/ray/experimental/client/worker.py | 111 ++++--- python/ray/tests/BUILD | 1 + python/ray/tests/test_experimental_client.py | 2 +- .../test_experimental_client_references.py | 152 ++++++++++ src/ray/protobuf/ray_client.proto | 95 +++++- 15 files changed, 1000 insertions(+), 310 deletions(-) create mode 100644 python/ray/experimental/client/client_pickler.py create mode 100644 python/ray/experimental/client/dataclient.py create mode 100644 python/ray/experimental/client/server/dataservicer.py create mode 100644 python/ray/experimental/client/server/server_pickler.py create mode 100644 python/ray/experimental/client/server/server_stubs.py create mode 100644 python/ray/tests/test_experimental_client_references.py diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index a6bba39ed..2af86d023 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -91,15 +91,22 @@ def _get_client_api() -> APIImpl: return api +def _get_server_instance(): + """Used inside tests to inspect the running server. + """ + global _server_api + if _server_api is not None: + return _server_api.server + + class RayAPIStub: def connect(self, conn_str: str, secure: bool = False, metadata: List[Tuple[str, str]] = None, - stub=None): + stub=None) -> None: from ray.experimental.client.worker import Worker - _client_worker = Worker( - conn_str, secure=secure, metadata=metadata, stub=stub) + _client_worker = Worker(conn_str, secure=secure, metadata=metadata) _set_client_api(ClientAPI(_client_worker)) def disconnect(self): @@ -113,6 +120,10 @@ class RayAPIStub: api = _get_client_api() return getattr(api, key) + def is_connected(self) -> bool: + global _client_api + return _client_api is not None + ray = RayAPIStub() diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 304cc4467..5167e5988 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -138,6 +138,31 @@ class APIImpl(ABC): """ pass + @abstractmethod + def call_release(self, id: bytes) -> None: + """ + Attempts to release an object reference. + + When client references are destructed, they release their reference, + which can opportunistically send a notification through the datachannel + to release the reference being held for that object on the server. + + Args: + id: The id of the reference to release on the server side. + """ + + @abstractmethod + def call_retain(self, id: bytes) -> None: + """ + Attempts to retain a client object reference. + + Increments the reference count on the client side, to prevent + the client worker from attempting to release the server reference. + + Args: + id: The id of the reference to retain on the client side. + """ + class ClientAPI(APIImpl): """ @@ -163,6 +188,12 @@ class ClientAPI(APIImpl): def call_remote(self, instance: "ClientStub", *args, **kwargs): return self.worker.call_remote(instance, *args, **kwargs) + def call_release(self, id: bytes) -> None: + return self.worker.call_release(id) + + def call_retain(self, id: bytes) -> None: + return self.worker.call_retain(id) + def close(self) -> None: return self.worker.close() diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py new file mode 100644 index 000000000..73df31c0e --- /dev/null +++ b/python/ray/experimental/client/client_pickler.py @@ -0,0 +1,123 @@ +""" +Implements the client side of the client/server pickling protocol. + +All ray client client/server data transfer happens through this pickling +protocol. The model is as follows: + + * All Client objects (eg ClientObjectRef) always live on the client and + are never represented in the server + * All Ray objects (eg, ray.ObjectRef) always live on the server and are + never returned to the client + * In order to translate between these two references, PickleStub tuples + are generated as persistent ids in the data blobs during the pickling + and unpickling of these objects. + +The PickleStubs have just enough information to find or generate their +associated partner object on either side. + +This also has the advantage of avoiding predefined pickle behavior for ray +objects, which may include ray internal reference counting. + +ClientPickler dumps things from the client into the appropriate stubs +ServerUnpickler loads stubs from the server into their client counterparts. +""" + +import cloudpickle +import io +import sys + +from typing import NamedTuple +from typing import Any + +from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.common import ClientActorHandle +from ray.experimental.client.common import ClientActorRef +from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.common import SelfReferenceSentinel +import ray.core.generated.ray_client_pb2 as ray_client_pb2 + +if sys.version_info < (3, 8): + try: + import pickle5 as pickle # noqa: F401 + except ImportError: + import pickle # noqa: F401 +else: + import pickle # noqa: F401 + +PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str), + ("ref_id", bytes)]) + + +class ClientPickler(cloudpickle.CloudPickler): + def __init__(self, client_id, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client_id = client_id + + def persistent_id(self, obj): + if isinstance(obj, ClientObjectRef): + return PickleStub( + type="Object", + client_id=self.client_id, + ref_id=obj.id, + ) + elif isinstance(obj, ClientActorHandle): + return PickleStub( + type="Actor", + client_id=self.client_id, + ref_id=obj._actor_id, + ) + elif isinstance(obj, ClientRemoteFunc): + # TODO(barakmich): This is going to have trouble with mutually + # recursive functions that haven't, as yet, been executed. It's + # relatively doable (keep track of intermediate refs in progress + # with ensure_ref and return appropriately) But punting for now. + if obj._ref is None: + obj._ensure_ref() + if type(obj._ref) == SelfReferenceSentinel: + return PickleStub( + type="RemoteFuncSelfReference", + client_id=self.client_id, + ref_id=b"") + return PickleStub( + type="RemoteFunc", + client_id=self.client_id, + ref_id=obj._ref.id) + return None + + +class ServerUnpickler(pickle.Unpickler): + def persistent_load(self, pid): + assert isinstance(pid, PickleStub) + if pid.type == "Object": + return ClientObjectRef(id=pid.ref_id) + elif pid.type == "Actor": + return ClientActorHandle(ClientActorRef(id=pid.ref_id)) + else: + raise NotImplementedError("Being passed back an unknown stub") + + +def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes: + with io.BytesIO() as file: + cp = ClientPickler(client_id, file, protocol=protocol) + cp.dump(obj) + return file.getvalue() + + +def loads_from_server(data: bytes, + *, + fix_imports=True, + encoding="ASCII", + errors="strict") -> Any: + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ServerUnpickler( + file, fix_imports=fix_imports, encoding=encoding, + errors=errors).load() + + +def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg: + out = ray_client_pb2.Arg() + out.local = ray_client_pb2.Arg.Locality.INTERNED + out.data = dumps_from_client(val, client_id) + return out diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 24b012790..74f11c2c2 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -1,16 +1,12 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray -from typing import Any from typing import Dict -from ray import cloudpickle - -import base64 class ClientBaseRef: - def __init__(self, id, handle=None): - self.id = id - self.handle = handle + def __init__(self, id: bytes): + self.id: bytes = id + ray.call_retain(id) def __repr__(self): return "%s(%s)" % ( @@ -24,14 +20,13 @@ class ClientBaseRef: def binary(self): return self.id - @classmethod - def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef): - return cls(id=ref.id, handle=ref.handle) + def __del__(self): + if ray.is_connected(): + ray.call_release(self.id) class ClientObjectRef(ClientBaseRef): - def _unpack_ref(self): - return cloudpickle.loads(self.handle) + pass class ClientActorRef(ClientBaseRef): @@ -53,50 +48,42 @@ class ClientRemoteFunc(ClientStub): _func: The actual function to execute remotely _name: The original name of the function _ref: The ClientObjectRef of the pickled code of the function, _func - _raylet_remote: The Raylet-side ray.remote_function.RemoteFunction - for this object """ def __init__(self, f): self._func = f self._name = f.__name__ - self.id = None - - # self._ref can be lazily instantiated. Rather than eagerly creating - # function data objects in the server we can put them just before we - # execute the function, especially in cases where many @ray.remote - # functions exist in a library and only a handful are ever executed by - # a user of the library. - # - # TODO(barakmich): This ref might actually be better as a serialized - # ObjectRef. This requires being able to serialize the ref without - # pinning it (as the lifetime of the ref is tied with the server, not - # the client) self._ref = None - self._raylet_remote = None def __call__(self, *args, **kwargs): raise TypeError(f"Remote function cannot be called directly. " "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - return ray.call_remote(self, *args, **kwargs) - - def _get_ray_remote_impl(self): - if self._raylet_remote is None: - self._raylet_remote = ray.remote(self._func) - return self._raylet_remote + return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) def __repr__(self): return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref) - def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + def _ensure_ref(self): if self._ref is None: + # While calling ray.put() on our function, if + # our function is recursive, it will attempt to + # encode the ClientRemoteFunc -- itself -- and + # infinitely recurse on _ensure_ref. + # + # So we set the state of the reference to be an + # in-progress self reference value, which + # the encoding can detect and handle correctly. + self._ref = SelfReferenceSentinel() self._ref = ray.put(self._func) + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + self._ensure_ref() task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.FUNCTION task.name = self._name - task.payload_id = self._ref.handle + task.payload_id = self._ref.id return task @@ -109,14 +96,12 @@ class ClientActorClass(ClientStub): actor_cls: The actual class to execute remotely _name: The original name of the class _ref: The ClientObjectRef of the pickled `actor_cls` - _raylet_remote: The Raylet-side ray.ActorClass for this object """ def __init__(self, actor_cls): self.actor_cls = actor_cls self._name = actor_cls.__name__ self._ref = None - self._raylet_remote = None def __call__(self, *args, **kwargs): raise TypeError(f"Remote actor cannot be instantiated directly. " @@ -135,10 +120,10 @@ class ClientActorClass(ClientStub): self._name = state["_name"] self._ref = state["_ref"] - def remote(self, *args, **kwargs): + def remote(self, *args, **kwargs) -> "ClientActorHandle": # Actually instantiate the actor - ref = ray.call_remote(self, *args, **kwargs) - return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self) + ref_id = ray.call_remote(self, *args, **kwargs) + return ClientActorHandle(ClientActorRef(ref_id), self) def __repr__(self): return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) @@ -154,7 +139,7 @@ class ClientActorClass(ClientStub): task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.ACTOR task.name = self._name - task.payload_id = self._ref.handle + task.payload_id = self._ref.id return task @@ -177,26 +162,9 @@ class ClientActorHandle(ClientStub): def __init__(self, actor_ref: ClientActorRef, actor_class: ClientActorClass): self.actor_ref = actor_ref - self.actor_class = actor_class - self._real_actor_handle = None - def _get_ray_remote_impl(self): - if self._real_actor_handle is None: - self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle) - return self._real_actor_handle - - def __getstate__(self) -> Dict: - state = { - "actor_ref": self.actor_ref, - "actor_class": self.actor_class, - "_real_actor_handle": self._real_actor_handle, - } - return state - - def __setstate__(self, state: Dict) -> None: - self.actor_ref = state["actor_ref"] - self.actor_class = state["actor_class"] - self._real_actor_handle = state["_real_actor_handle"] + def __del__(self) -> None: + ray.call_release(self.actor_ref.id) @property def _actor_id(self): @@ -226,65 +194,27 @@ class ClientRemoteMethod(ClientStub): def __call__(self, *args, **kwargs): raise TypeError(f"Remote method cannot be called directly. " - "Use {self._name}.remote() instead") - - def _get_ray_remote_impl(self): - return getattr(self.actor_handle._get_ray_remote_impl(), - self.method_name) - - def __getstate__(self) -> Dict: - state = { - "actor_handle": self.actor_handle, - "method_name": self.method_name, - } - return state - - def __setstate__(self, state: Dict) -> None: - self.actor_handle = state["actor_handle"] - self.method_name = state["method_name"] + f"Use {self._name}.remote() instead") def remote(self, *args, **kwargs): - return ray.call_remote(self, *args, **kwargs) + return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) def __repr__(self): - name = "%s.%s" % (self.actor_handle.actor_class._name, - self.method_name) - return "ClientRemoteMethod(%s, %s)" % (name, - self.actor_handle.actor_id) + return "ClientRemoteMethod(%s, %s)" % (self.method_name, + self.actor_handle) def _prepare_client_task(self) -> ray_client_pb2.ClientTask: task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.METHOD task.name = self.method_name - task.payload_id = self.actor_handle.actor_ref.handle + task.payload_id = self.actor_handle.actor_ref.id return task -def convert_from_arg(pb) -> Any: - if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: - return ClientObjectRef(pb.reference_id) - elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: - return cloudpickle.loads(pb.data) - - raise Exception("convert_from_arg: Uncovered locality enum") +class DataEncodingSentinel: + def __repr__(self) -> str: + return self.__class__.__name__ -def convert_to_arg(val): - out = ray_client_pb2.Arg() - if isinstance(val, ClientObjectRef): - out.local = ray_client_pb2.Arg.Locality.REFERENCE - out.reference_id = val.id - else: - out.local = ray_client_pb2.Arg.Locality.INTERNED - out.data = cloudpickle.dumps(val) - return out - - -def encode_exception(exception) -> str: - data = cloudpickle.dumps(exception) - return base64.standard_b64encode(data).decode() - - -def decode_exception(data) -> Exception: - data = base64.standard_b64decode(data) - return cloudpickle.loads(data) +class SelfReferenceSentinel(DataEncodingSentinel): + pass diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py new file mode 100644 index 000000000..7e16c015b --- /dev/null +++ b/python/ray/experimental/client/dataclient.py @@ -0,0 +1,103 @@ +""" +This file implements a threaded stream controller to abstract a data stream +back to the ray clientserver. +""" +import logging +import queue +import threading +import grpc + +from typing import Any +from typing import Dict + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +logger = logging.getLogger(__name__) + +# The maximum field value for request_id -- which is also the maximum +# number of simultaneous in-flight requests. +INT32_MAX = (2**31) - 1 + + +class DataClient: + def __init__(self, channel: "grpc._channel.Channel", client_id: str): + """Initializes a thread-safe datapath over a Ray Client gRPC channel. + + Args: + channel: connected gRPC channel + """ + self.channel = channel + self.request_queue = queue.Queue() + self.data_thread = self._start_datathread() + self.ready_data: Dict[int, Any] = {} + self.cv = threading.Condition() + self._req_id = 0 + self._client_id = client_id + self.data_thread.start() + + def _next_id(self) -> int: + self._req_id += 1 + if self._req_id > INT32_MAX: + self._req_id = 1 + # Responses that aren't tracked (like opportunistic releases) + # have req_id=0, so make sure we never mint such an id. + assert self._req_id != 0 + return self._req_id + + def _start_datathread(self) -> threading.Thread: + return threading.Thread(target=self._data_main, args=(), daemon=True) + + def _data_main(self) -> None: + stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel) + resp_stream = stub.Datapath( + iter(self.request_queue.get, None), + metadata=(("client_id", self._client_id), )) + for response in resp_stream: + if response.req_id == 0: + # This is not being waited for. + logger.debug(f"Got unawaited response {response}") + continue + with self.cv: + self.ready_data[response.req_id] = response + self.cv.notify_all() + + def close(self, close_channel: bool = False) -> None: + if self.request_queue is not None: + self.request_queue.put(None) + self.request_queue = None + if self.data_thread is not None: + self.data_thread.join() + self.data_thread = None + if close_channel: + self.channel.close() + + def _blocking_send(self, req: ray_client_pb2.DataRequest + ) -> ray_client_pb2.DataResponse: + req_id = self._next_id() + req.req_id = req_id + self.request_queue.put(req) + data = None + with self.cv: + self.cv.wait_for(lambda: req_id in self.ready_data) + data = self.ready_data[req_id] + del self.ready_data[req_id] + return data + + def GetObject(self, request: ray_client_pb2.GetRequest, + context=None) -> ray_client_pb2.GetResponse: + datareq = ray_client_pb2.DataRequest(get=request, ) + resp = self._blocking_send(datareq) + return resp.get + + def PutObject(self, request: ray_client_pb2.PutRequest, + context=None) -> ray_client_pb2.PutResponse: + datareq = ray_client_pb2.DataRequest(put=request, ) + resp = self._blocking_send(datareq) + return resp.put + + def ReleaseObject(self, + request: ray_client_pb2.ReleaseRequest, + context=None) -> None: + datareq = ray_client_pb2.DataRequest(release=request, ) + self.request_queue.put(datareq) diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 6513021a8..2d930f352 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -11,12 +11,15 @@ from typing import Any from typing import Optional from typing import Union +import logging import ray from ray.experimental.client.api import APIImpl from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientStub +logger = logging.getLogger(__name__) + class CoreRayAPI(APIImpl): """ @@ -26,12 +29,6 @@ class CoreRayAPI(APIImpl): """ def get(self, vals, *, timeout: Optional[float] = None) -> Any: - if isinstance(vals, list): - if isinstance(vals[0], ClientObjectRef): - return ray.get( - [val._unpack_ref() for val in vals], timeout=timeout) - elif isinstance(vals, ClientObjectRef): - return ray.get(vals._unpack_ref(), timeout=timeout) return ray.get(vals, timeout=timeout) def put(self, vals: Any, *args, @@ -45,7 +42,8 @@ class CoreRayAPI(APIImpl): return ray.remote(*args, **kwargs) def call_remote(self, instance: ClientStub, *args, **kwargs): - return instance._get_ray_remote_impl().remote(*args, **kwargs) + raise NotImplementedError( + "Should not attempt execution of a client stub inside the raylet") def close(self) -> None: return None @@ -59,6 +57,12 @@ class CoreRayAPI(APIImpl): def is_initialized(self) -> bool: return ray.is_initialized() + def call_release(self, id: bytes) -> None: + return None + + def call_retain(self, id: bytes) -> None: + return None + # Allow for generic fallback to ray.* in remote methods. This allows calls # like ray.nodes() to be run in remote functions even though the client # doesn't currently support them. @@ -76,26 +80,7 @@ class RayServerAPI(CoreRayAPI): def __init__(self, server_instance): self.server = server_instance - # Wrap single item into list if needed before calling server put. - def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef: - to_put = [] - single = False - if isinstance(vals, list): - to_put = vals - else: - single = True - to_put.append(vals) - - out = [self._put(x) for x in to_put] - if single: - out = out[0] - return out - - def _put(self, val: Any): - resp = self.server._put_and_retain_obj(val) - return ClientObjectRef(resp.id) - - def call_remote(self, instance: ClientStub, *args, **kwargs): + def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes: task = instance._prepare_client_task() ticket = self.server.Schedule(task, prepared_args=args) - return ClientObjectRef(ticket.return_id) + return ticket.return_id diff --git a/python/ray/experimental/client/server/dataservicer.py b/python/ray/experimental/client/server/dataservicer.py new file mode 100644 index 000000000..874e741d9 --- /dev/null +++ b/python/ray/experimental/client/server/dataservicer.py @@ -0,0 +1,54 @@ +import logging +import grpc + +from typing import TYPE_CHECKING + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +if TYPE_CHECKING: + from ray.experimental.client.server.server import RayletServicer + +logger = logging.getLogger(__name__) + + +class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): + def __init__(self, basic_service: "RayletServicer"): + self.basic_service = basic_service + + def Datapath(self, request_iterator, context): + metadata = {k: v for k, v in context.invocation_metadata()} + client_id = metadata["client_id"] + if client_id == "": + logger.error("Client connecting with no client_id") + return + logger.info(f"New data connection from client {client_id}") + try: + for req in request_iterator: + resp = None + req_type = req.WhichOneof("type") + if req_type == "get": + get_resp = self.basic_service._get_object( + req.get, client_id) + resp = ray_client_pb2.DataResponse(get=get_resp) + elif req_type == "put": + put_resp = self.basic_service._put_object( + req.put, client_id) + resp = ray_client_pb2.DataResponse(put=put_resp) + elif req_type == "release": + released = [] + for rel_id in req.release.ids: + rel = self.basic_service.release(client_id, rel_id) + released.append(rel) + resp = ray_client_pb2.DataResponse( + release=ray_client_pb2.ReleaseResponse(ok=released)) + else: + raise Exception(f"Unreachable code: Request type " + f"{req_type} not handled in Datapath") + resp.req_id = req.req_id + yield resp + except grpc.RpcError as e: + logger.debug(f"Closing channel: {e}") + finally: + logger.info(f"Lost data connection from client {client_id}") + self.basic_service.release_all(client_id) diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 616e6e60d..0fd34eda4 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -1,6 +1,12 @@ import logging from concurrent import futures import grpc +import base64 +from collections import defaultdict + +from typing import Dict +from typing import Set + from ray import cloudpickle import ray import ray.state @@ -10,21 +16,26 @@ import time import inspect import json from ray.experimental.client import stash_api_for_tests, _set_server_api -from ray.experimental.client.common import convert_from_arg -from ray.experimental.client.common import encode_exception -from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.server.server_pickler import convert_from_arg +from ray.experimental.client.server.server_pickler import dumps_from_server +from ray.experimental.client.server.server_pickler import loads_from_client from ray.experimental.client.server.core_ray_api import RayServerAPI +from ray.experimental.client.server.dataservicer import DataServicer +from ray.experimental.client.server.server_stubs import current_func logger = logging.getLogger(__name__) class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): def __init__(self, test_mode=False): - self.object_refs = {} + self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict( + dict) self.function_refs = {} - self.actor_refs = {} + self.actor_refs: Dict[bytes, ray.ActorHandle] = {} + self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set) self.registered_actor_classes = {} self._test_mode = test_mode + self._current_function_stub = None def ClusterInfo(self, request, context=None) -> ray_client_pb2.ClusterInfoResponse: @@ -61,20 +72,59 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): raise TypeError("Unsupported cluster info type") return json.dumps(data) - def Terminate(self, request, context=None): - if request.WhichOneof("terminate_type") == "task_object": + def release(self, client_id: str, id: bytes) -> bool: + if client_id in self.object_refs: + if id in self.object_refs[client_id]: + logger.debug(f"Releasing object {id.hex()} for {client_id}") + del self.object_refs[client_id][id] + return True + + if client_id in self.actor_owners: + if id in self.actor_owners[client_id]: + logger.debug(f"Releasing actor {id.hex()} for {client_id}") + del self.actor_refs[id] + self.actor_owners[client_id].remove(id) + return True + + return False + + def release_all(self, client_id): + self._release_objects(client_id) + self._release_actors(client_id) + + def _release_objects(self, client_id): + if client_id not in self.object_refs: + logger.debug(f"Releasing client with no references: {client_id}") + return + count = len(self.object_refs[client_id]) + del self.object_refs[client_id] + logger.debug(f"Released all {count} objects for client {client_id}") + + def _release_actors(self, client_id): + if client_id not in self.actor_owners: + logger.debug(f"Releasing client with no actors: {client_id}") + count = 0 + for id_bytes in self.actor_owners[client_id]: + count += 1 + del self.actor_refs[id_bytes] + del self.actor_owners[client_id] + logger.debug(f"Released all {count} actors for client: {client_id}") + + def Terminate(self, req, context=None): + if req.WhichOneof("terminate_type") == "task_object": try: - object_ref = cloudpickle.loads(request.task_object.handle) + object_ref = \ + self.object_refs[req.client_id][req.task_object.id] ray.cancel( object_ref, - force=request.task_object.force, - recursive=request.task_object.recursive) + force=req.task_object.force, + recursive=req.task_object.recursive) except Exception as e: return_exception_in_context(e, context) - elif request.WhichOneof("terminate_type") == "actor": + elif req.WhichOneof("terminate_type") == "actor": try: - actor_ref = cloudpickle.loads(request.actor.handle) - ray.kill(actor_ref, no_restart=request.actor.no_restart) + actor_ref = self.actor_refs[req.actor.id] + ray.kill(actor_ref, no_restart=req.actor.no_restart) except Exception as e: return_exception_in_context(e, context) else: @@ -84,61 +134,71 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): return ray_client_pb2.TerminateResponse(ok=True) def GetObject(self, request, context=None): - request_ref = cloudpickle.loads(request.handle) - if request_ref.binary() not in self.object_refs: + return self._get_object(request, "", context) + + def _get_object(self, request, client_id: str, context=None): + if request.id not in self.object_refs[client_id]: return ray_client_pb2.GetResponse(valid=False) - objectref = self.object_refs[request_ref.binary()] - logger.info("get: %s" % objectref) + objectref = self.object_refs[client_id][request.id] + logger.debug("get: %s" % objectref) try: item = ray.get(objectref, timeout=request.timeout) except Exception as e: - return_exception_in_context(e, context) - item_ser = cloudpickle.dumps(item) + return ray_client_pb2.GetResponse( + valid=False, error=cloudpickle.dumps(e)) + item_ser = dumps_from_server(item, client_id, self) return ray_client_pb2.GetResponse(valid=True, data=item_ser) - def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse: - obj = cloudpickle.loads(request.data) - objectref = self._put_and_retain_obj(obj) - pickled_ref = cloudpickle.dumps(objectref) - return ray_client_pb2.PutResponse( - ref=make_remote_ref(objectref.binary(), pickled_ref)) + def PutObject(self, request: ray_client_pb2.PutRequest, + context=None) -> ray_client_pb2.PutResponse: + """gRPC entrypoint for unary PutObject + """ + return self._put_object(request, "", context) - def _put_and_retain_obj(self, obj) -> ray.ObjectRef: + def _put_object(self, + request: ray_client_pb2.PutRequest, + client_id: str, + context=None): + """Put an object in the cluster with ray.put() via gRPC. + + Args: + request: PutRequest with pickled data. + client_id: The client who owns this data, for tracking when to + delete this reference. + context: gRPC context. + """ + obj = loads_from_client(request.data, self) objectref = ray.put(obj) - self.object_refs[objectref.binary()] = objectref - logger.info("put: %s" % objectref) - return objectref + self.object_refs[client_id][objectref.binary()] = objectref + logger.debug("put: %s" % objectref) + return ray_client_pb2.PutResponse(id=objectref.binary()) def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse: - object_refs = [cloudpickle.loads(o) for o in request.object_handles] + object_refs = [] + for id in request.object_ids: + if id not in self.object_refs[request.client_id]: + raise Exception( + "Asking for a ref not associated with this client: %s" % + str(id)) + object_refs.append(self.object_refs[request.client_id][id]) num_returns = request.num_returns timeout = request.timeout - object_refs_ids = [] - for object_ref in object_refs: - if object_ref.binary() not in self.object_refs: - return ray_client_pb2.WaitResponse(valid=False) - object_refs_ids.append(self.object_refs[object_ref.binary()]) try: ready_object_refs, remaining_object_refs = ray.wait( - object_refs_ids, + object_refs, num_returns=num_returns, timeout=timeout if timeout != -1 else None) except Exception: # TODO(ameer): improve exception messages. return ray_client_pb2.WaitResponse(valid=False) - logger.info("wait: %s %s" % (str(ready_object_refs), - str(remaining_object_refs))) + logger.debug("wait: %s %s" % (str(ready_object_refs), + str(remaining_object_refs))) ready_object_ids = [ - make_remote_ref( - id=ready_object_ref.binary(), - handle=cloudpickle.dumps(ready_object_ref), - ) for ready_object_ref in ready_object_refs + ready_object_ref.binary() for ready_object_ref in ready_object_refs ] remaining_object_ids = [ - make_remote_ref( - id=remaining_object_ref.binary(), - handle=cloudpickle.dumps(remaining_object_ref), - ) for remaining_object_ref in remaining_object_refs + remaining_object_ref.binary() + for remaining_object_ref in remaining_object_refs ] return ray_client_pb2.WaitResponse( valid=True, @@ -150,16 +210,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): logger.info("schedule: %s %s" % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) - if task.type == ray_client_pb2.ClientTask.FUNCTION: - return self._schedule_function(task, context, prepared_args) - elif task.type == ray_client_pb2.ClientTask.ACTOR: - return self._schedule_actor(task, context, prepared_args) - elif task.type == ray_client_pb2.ClientTask.METHOD: - return self._schedule_method(task, context, prepared_args) - else: - raise NotImplementedError( - "Unimplemented Schedule task type: %s" % - ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) + with stash_api_for_tests(self._test_mode): + if task.type == ray_client_pb2.ClientTask.FUNCTION: + return self._schedule_function(task, context, prepared_args) + elif task.type == ray_client_pb2.ClientTask.ACTOR: + return self._schedule_actor(task, context, prepared_args) + elif task.type == ray_client_pb2.ClientTask.METHOD: + return self._schedule_method(task, context, prepared_args) + else: + raise NotImplementedError( + "Unimplemented Schedule task type: %s" % + ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) def _schedule_method( self, @@ -170,80 +231,67 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): if actor_handle is None: raise Exception( "Can't run an actor the server doesn't have a handle for") - arglist = _convert_args(task.args, prepared_args) - with stash_api_for_tests(self._test_mode): - output = getattr(actor_handle, task.name).remote(*arglist) - self.object_refs[output.binary()] = output - pickled_ref = cloudpickle.dumps(output) - return ray_client_pb2.ClientTaskTicket( - return_ref=make_remote_ref(output.binary(), pickled_ref)) + arglist = self._convert_args(task.args, prepared_args) + output = getattr(actor_handle, task.name).remote(*arglist) + self.object_refs[task.client_id][output.binary()] = output + return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) def _schedule_actor(self, task: ray_client_pb2.ClientTask, context=None, prepared_args=None) -> ray_client_pb2.ClientTaskTicket: - with stash_api_for_tests(self._test_mode): - payload_ref = cloudpickle.loads(task.payload_id) - if payload_ref.binary() not in self.registered_actor_classes: - actor_class_ref = self.object_refs[payload_ref.binary()] - actor_class = ray.get(actor_class_ref) - if not inspect.isclass(actor_class): - raise Exception("Attempting to schedule actor that " - "isn't a class.") - reg_class = ray.remote(actor_class) - self.registered_actor_classes[payload_ref.binary()] = reg_class - remote_class = self.registered_actor_classes[payload_ref.binary()] - arglist = _convert_args(task.args, prepared_args) - actor = remote_class.remote(*arglist) - actorhandle = cloudpickle.dumps(actor) - self.actor_refs[actorhandle] = actor + if task.payload_id not in self.registered_actor_classes: + actor_class_ref = \ + self.object_refs[task.client_id][task.payload_id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule actor that " + "isn't a class.") + reg_class = ray.remote(actor_class) + self.registered_actor_classes[task.payload_id] = reg_class + remote_class = self.registered_actor_classes[task.payload_id] + arglist = self._convert_args(task.args, prepared_args) + actor = remote_class.remote(*arglist) + self.actor_refs[actor._actor_id.binary()] = actor + self.actor_owners[task.client_id].add(actor._actor_id.binary()) return ray_client_pb2.ClientTaskTicket( - return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle)) + return_id=actor._actor_id.binary()) def _schedule_function( self, task: ray_client_pb2.ClientTask, context=None, prepared_args=None) -> ray_client_pb2.ClientTaskTicket: - payload_ref = cloudpickle.loads(task.payload_id) - if payload_ref.binary() not in self.function_refs: - funcref = self.object_refs[payload_ref.binary()] + remote_func = self.lookup_or_register_func(task.payload_id, + task.client_id) + arglist = self._convert_args(task.args, prepared_args) + # Prepare call if we're in a test + with current_func(remote_func): + output = remote_func.remote(*arglist) + if output.binary() in self.object_refs[task.client_id]: + raise Exception("already found it") + self.object_refs[task.client_id][output.binary()] = output + return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + + def _convert_args(self, arg_list, prepared_args=None): + if prepared_args is not None: + return prepared_args + out = [] + for arg in arg_list: + t = convert_from_arg(arg, self) + out.append(t) + return out + + def lookup_or_register_func(self, id: bytes, client_id: str + ) -> ray.remote_function.RemoteFunction: + if id not in self.function_refs: + funcref = self.object_refs[client_id][id] func = ray.get(funcref) if not inspect.isfunction(func): - raise Exception("Attempting to schedule function that " + raise Exception("Attempting to register function that " "isn't a function.") - self.function_refs[payload_ref.binary()] = ray.remote(func) - remote_func = self.function_refs[payload_ref.binary()] - arglist = _convert_args(task.args, prepared_args) - # Prepare call if we're in a test - with stash_api_for_tests(self._test_mode): - output = remote_func.remote(*arglist) - if output.binary() in self.object_refs: - raise Exception("already found it") - self.object_refs[output.binary()] = output - pickled_output = cloudpickle.dumps(output) - return ray_client_pb2.ClientTaskTicket( - return_ref=make_remote_ref(output.binary(), pickled_output)) - - -def _convert_args(arg_list, prepared_args=None): - if prepared_args is not None: - return prepared_args - out = [] - for arg in arg_list: - t = convert_from_arg(arg) - if isinstance(t, ClientObjectRef): - out.append(t._unpack_ref()) - else: - out.append(t) - return out - - -def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef: - return ray_client_pb2.RemoteRef( - id=id, - handle=handle, - ) + self.function_refs[id] = ray.remote(func) + return self.function_refs[id] def return_exception_in_context(err, context): @@ -252,12 +300,20 @@ def return_exception_in_context(err, context): context.set_code(grpc.StatusCode.INTERNAL) +def encode_exception(exception) -> str: + data = cloudpickle.dumps(exception) + return base64.standard_b64encode(data).decode() + + def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) + data_servicer = DataServicer(task_servicer) _set_server_api(RayServerAPI(task_servicer)) ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) + ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( + data_servicer, server) server.add_insecure_port(connection_str) server.start() return server diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py new file mode 100644 index 000000000..ea6bd74d0 --- /dev/null +++ b/python/ray/experimental/client/server/server_pickler.py @@ -0,0 +1,119 @@ +""" +Implements the client side of the client/server pickling protocol. + +These picklers are aware of the server internals and can find the +references held for the client within the server. + +More discussion about the client/server pickling protocol can be found in: + + ray/experimental/client/client_pickler.py + +ServerPickler dumps ray objects from the server into the appropriate stubs. +ClientUnpickler loads stubs from the client and finds their associated handle +in the server instance. +""" +import cloudpickle +import io +import sys +import ray + +from typing import Any +from typing import TYPE_CHECKING + +from ray.experimental.client.client_pickler import PickleStub +from ray.experimental.client.server.server_stubs import ServerFunctionSentinel + +if TYPE_CHECKING: + from ray.experimental.client.server.server import RayletServicer + import ray.core.generated.ray_client_pb2 as ray_client_pb2 + +if sys.version_info < (3, 8): + try: + import pickle5 as pickle # noqa: F401 + except ImportError: + import pickle # noqa: F401 +else: + import pickle # noqa: F401 + + +class ServerPickler(cloudpickle.CloudPickler): + def __init__(self, client_id: str, server: "RayletServicer", *args, + **kwargs): + super().__init__(*args, **kwargs) + self.client_id = client_id + self.server = server + + def persistent_id(self, obj): + if isinstance(obj, ray.ObjectRef): + obj_id = obj.binary() + if obj_id not in self.server.object_refs[self.client_id]: + # We're passing back a reference, probably inside a reference. + # Let's hold onto it. + self.server.object_refs[self.client_id][obj_id] = obj + return PickleStub( + type="Object", + client_id=self.client_id, + ref_id=obj_id, + ) + elif isinstance(obj, ray.actor.ActorHandle): + actor_id = obj._actor_id.binary() + if actor_id not in self.server.actor_refs: + # We're passing back a handle, probably inside a reference. + self.actor_refs[actor_id] = obj + if actor_id not in self.actor_owners[self.client_id]: + self.actor_owners[self.client_id].add(actor_id) + return PickleStub( + type="Actor", + client_id=self.client_id, + ref_id=obj._actor_id.binary(), + ) + return None + + +class ClientUnpickler(pickle.Unpickler): + def __init__(self, server, *args, **kwargs): + super().__init__(*args, **kwargs) + self.server = server + + def persistent_load(self, pid): + assert isinstance(pid, PickleStub) + if pid.type == "Object": + return self.server.object_refs[pid.client_id][pid.ref_id] + elif pid.type == "Actor": + return self.server.actor_refs[pid.ref_id] + elif pid.type == "RemoteFuncSelfReference": + return ServerFunctionSentinel() + elif pid.type == "RemoteFunc": + return self.server.lookup_or_register_func(pid.ref_id, + pid.client_id) + else: + raise NotImplementedError("Uncovered client data type") + + +def dumps_from_server(obj: Any, + client_id: str, + server_instance: "RayletServicer", + protocol=None) -> bytes: + with io.BytesIO() as file: + sp = ServerPickler(client_id, server_instance, file, protocol=protocol) + sp.dump(obj) + return file.getvalue() + + +def loads_from_client(data: bytes, + server_instance: "RayletServicer", + *, + fix_imports=True, + encoding="ASCII", + errors="strict") -> Any: + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ClientUnpickler( + server_instance, file, fix_imports=fix_imports, + encoding=encoding).load() + + +def convert_from_arg(pb: "ray_client_pb2.Arg", + server: "RayletServicer") -> Any: + return loads_from_client(pb.data, server) diff --git a/python/ray/experimental/client/server/server_stubs.py b/python/ray/experimental/client/server/server_stubs.py new file mode 100644 index 000000000..f55f64f25 --- /dev/null +++ b/python/ray/experimental/client/server/server_stubs.py @@ -0,0 +1,29 @@ +from contextlib import contextmanager + +_current_remote_func = None + + +@contextmanager +def current_func(f): + global _current_remote_func + remote_func = _current_remote_func + _current_remote_func = f + try: + yield + finally: + _current_remote_func = remote_func + + +class ServerFunctionSentinel: + def __init__(self): + pass + + def __reduce__(self): + global _current_remote_func + if _current_remote_func is None: + return (ServerFunctionSentinel, tuple()) + return (identity, (_current_remote_func, )) + + +def identity(x): + return x diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 0a108e4f2..54ac71711 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -2,27 +2,32 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ +import base64 import inspect import json import logging +import uuid +from collections import defaultdict from typing import Any +from typing import Dict from typing import List from typing import Tuple from typing import Optional -import ray.cloudpickle as cloudpickle from ray.util.inspect import is_cython import grpc -from ray.exceptions import TaskCancelledError +import ray.cloudpickle as cloudpickle import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc -from ray.experimental.client.common import convert_to_arg -from ray.experimental.client.common import decode_exception +from ray.experimental.client.client_pickler import convert_to_arg +from ray.experimental.client.client_pickler import loads_from_server +from ray.experimental.client.client_pickler import dumps_from_client from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.dataclient import DataClient logger = logging.getLogger(__name__) @@ -31,34 +36,32 @@ class Worker: def __init__(self, conn_str: str = "", secure: bool = False, - metadata: List[Tuple[str, str]] = None, - stub=None): + metadata: List[Tuple[str, str]] = None): """Initializes the worker side grpc client. Args: - stub: custom grpc stub. secure: whether to use SSL secure channel or not. metadata: additional metadata passed in the grpc request headers. """ self.metadata = metadata self.channel = None - if stub is None: - if secure: - credentials = grpc.ssl_channel_credentials() - self.channel = grpc.secure_channel(conn_str, credentials) - else: - self.channel = grpc.insecure_channel(conn_str) - self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + self._client_id = make_client_id() + if secure: + credentials = grpc.ssl_channel_credentials() + self.channel = grpc.secure_channel(conn_str, credentials) else: - self.server = stub + self.channel = grpc.insecure_channel(conn_str) + self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + self.data_client = DataClient(self.channel, self._client_id) + self.reference_count: Dict[bytes, int] = defaultdict(int) def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False if isinstance(vals, list): - to_get = [x.handle for x in vals] + to_get = vals elif isinstance(vals, ClientObjectRef): - to_get = [vals.handle] + to_get = [vals] single = True else: raise Exception("Can't get something that's not a " @@ -70,15 +73,15 @@ class Worker: out = out[0] return out - def _get(self, handle: bytes, timeout: float): - req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout) + def _get(self, ref: ClientObjectRef, timeout: float): + req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout) try: - data = self.server.GetObject(req, metadata=self.metadata) + data = self.data_client.GetObject(req) except grpc.RpcError as e: - raise decode_exception(e.details()) + raise e.details() if not data.valid: - raise TaskCancelledError(handle) - return cloudpickle.loads(data.data) + raise cloudpickle.loads(data.error) + return loads_from_server(data.data) def put(self, vals): to_put = [] @@ -95,10 +98,10 @@ class Worker: return out def _put(self, val): - data = cloudpickle.dumps(val) + data = dumps_from_client(val, self._client_id) req = ray_client_pb2.PutRequest(data=data) - resp = self.server.PutObject(req, metadata=self.metadata) - return ClientObjectRef.from_remote_ref(resp.ref) + resp = self.data_client.PutObject(req) + return ClientObjectRef(resp.id) def wait(self, object_refs: List[ClientObjectRef], @@ -110,11 +113,10 @@ class Worker: for ref in object_refs: assert isinstance(ref, ClientObjectRef) data = { - "object_handles": [ - object_ref.handle for object_ref in object_refs - ], + "object_ids": [object_ref.id for object_ref in object_refs], "num_returns": num_returns, - "timeout": timeout if timeout else -1 + "timeout": timeout if timeout else -1, + "client_id": self._client_id, } req = ray_client_pb2.WaitRequest(**data) resp = self.server.WaitObject(req, metadata=self.metadata) @@ -122,12 +124,10 @@ class Worker: # TODO(ameer): improve error/exceptions messages. raise Exception("Client Wait request failed. Reference invalid?") client_ready_object_ids = [ - ClientObjectRef.from_remote_ref(ref) - for ref in resp.ready_object_ids + ClientObjectRef(ref) for ref in resp.ready_object_ids ] client_remaining_object_ids = [ - ClientObjectRef.from_remote_ref(ref) - for ref in resp.remaining_object_ids + ClientObjectRef(ref) for ref in resp.remaining_object_ids ] return (client_ready_object_ids, client_remaining_object_ids) @@ -144,19 +144,38 @@ class Worker: raise TypeError("The @ray.remote decorator must be applied to " "either a function or to a class.") - def call_remote(self, instance, *args, **kwargs): + def call_remote(self, instance, *args, **kwargs) -> bytes: task = instance._prepare_client_task() for arg in args: - pb_arg = convert_to_arg(arg) + pb_arg = convert_to_arg(arg, self._client_id) task.args.append(pb_arg) - logging.debug("Scheduling %s" % task) + task.client_id = self._client_id + logger.debug("Scheduling %s" % task) ticket = self.server.Schedule(task, metadata=self.metadata) - return ClientObjectRef.from_remote_ref(ticket.return_ref) + return ticket.return_id + + def call_release(self, id: bytes) -> None: + self.reference_count[id] -= 1 + if self.reference_count[id] == 0: + self._release_server(id) + del self.reference_count[id] + + def _release_server(self, id: bytes) -> None: + if self.data_client is not None: + logger.debug(f"Releasing {id}") + self.data_client.ReleaseObject( + ray_client_pb2.ReleaseRequest(ids=[id])) + + def call_retain(self, id: bytes) -> None: + logger.debug(f"Retaining {id}") + self.reference_count[id] += 1 def close(self): + self.data_client.close() self.server = None if self.channel: self.channel.close() + self.channel = None def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None: @@ -164,10 +183,11 @@ class Worker: raise ValueError("ray.kill() only supported for actors. " "Got: {}.".format(type(actor))) term_actor = ray_client_pb2.TerminateRequest.ActorTerminate() - term_actor.handle = actor.actor_ref.handle + term_actor.id = actor.actor_ref.id term_actor.no_restart = no_restart try: term = ray_client_pb2.TerminateRequest(actor=term_actor) + term.client_id = self._client_id self.server.Terminate(term) except grpc.RpcError as e: raise decode_exception(e.details()) @@ -179,11 +199,12 @@ class Worker: "ray.cancel() only supported for non-actor object refs. " f"Got: {type(obj)}.") term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate() - term_object.handle = obj.handle + term_object.id = obj.id term_object.force = force term_object.recursive = recursive try: term = ray_client_pb2.TerminateRequest(task_object=term_object) + term.client_id = self._client_id self.server.Terminate(term) except grpc.RpcError as e: raise decode_exception(e.details()) @@ -201,3 +222,13 @@ class Worker: return self.get_cluster_info( ray_client_pb2.ClusterInfoType.IS_INITIALIZED) return False + + +def make_client_id() -> str: + id = uuid.uuid4() + return id.hex + + +def decode_exception(data) -> Exception: + data = base64.standard_b64decode(data) + return loads_from_server(data) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 55fac64e5..588710e3a 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -96,6 +96,7 @@ py_test_module_list( "test_debug_tools.py", "test_experimental_client.py", "test_experimental_client_metadata.py", + "test_experimental_client_references.py", "test_experimental_client_terminate.py", "test_job.py", "test_memstat.py", diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index cbe52675f..1231b6730 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -142,7 +142,7 @@ def test_function_calling_function(ray_start_regular_shared): @ray.remote def f(): - print(f, f._name, g._name, g) + print(f, g) return ray.get(g.remote()) print(f, type(f)) diff --git a/python/ray/tests/test_experimental_client_references.py b/python/ray/tests/test_experimental_client_references.py new file mode 100644 index 000000000..9675b9c97 --- /dev/null +++ b/python/ray/tests/test_experimental_client_references.py @@ -0,0 +1,152 @@ +from ray.tests.test_experimental_client import ray_start_client_server +from ray.test_utils import wait_for_condition +import ray as real_ray +from ray.core.generated.gcs_pb2 import ActorTableData +from ray.experimental.client import _get_server_instance + + +def server_object_ref_count(n): + server = _get_server_instance() + assert server is not None + + def test_cond(): + if len(server.object_refs) == 0: + # No open clients + return n == 0 + client_id = list(server.object_refs.keys())[0] + return len(server.object_refs[client_id]) == n + + return test_cond + + +def server_actor_ref_count(n): + server = _get_server_instance() + assert server is not None + + def test_cond(): + if len(server.actor_refs) == 0: + # No running actors + return n == 0 + return len(server.actor_refs) == n + + return test_cond + + +def test_delete_refs_on_disconnect(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + def f(x): + return x + 2 + + thing1 = f.remote(6) # noqa + thing2 = ray.put("Hello World") # noqa + + # One put, one function -- the function result thing1 is + # in a different category, according to the raylet. + assert len(real_ray.objects()) == 2 + # But we're maintaining the reference + assert server_object_ref_count(3)() + # And can get the data + assert ray.get(thing1) == 8 + + # Close the client + ray.close() + + wait_for_condition(server_object_ref_count(0), timeout=5) + + def test_cond(): + return len(real_ray.objects()) == 0 + + wait_for_condition(test_cond, timeout=5) + + +def test_delete_ref_on_object_deletion(ray_start_regular): + with ray_start_client_server() as ray: + vals = { + "ref": ray.put("Hello World"), + "ref2": ray.put("This value stays"), + } + + del vals["ref"] + + wait_for_condition(server_object_ref_count(1), timeout=5) + + +def test_delete_actor_on_disconnect(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class Accumulator: + def __init__(self): + self.acc = 0 + + def inc(self): + self.acc += 1 + + def get(self): + return self.acc + + actor = Accumulator.remote() + actor.inc.remote() + + assert server_actor_ref_count(1)() + + assert ray.get(actor.get.remote()) == 1 + + ray.close() + + wait_for_condition(server_actor_ref_count(0), timeout=5) + + def test_cond(): + alive_actors = [ + v for v in real_ray.actors().values() + if v["State"] != ActorTableData.DEAD + ] + return len(alive_actors) == 0 + + wait_for_condition(test_cond, timeout=10) + + +def test_delete_actor(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class Accumulator: + def __init__(self): + self.acc = 0 + + def inc(self): + self.acc += 1 + + actor = Accumulator.remote() + actor.inc.remote() + actor2 = Accumulator.remote() + actor2.inc.remote() + + assert server_actor_ref_count(2)() + + del actor + + wait_for_condition(server_actor_ref_count(1), timeout=5) + + +def test_simple_multiple_references(ray_start_regular): + with ray_start_client_server() as ray: + + @ray.remote + class A: + def __init__(self): + self.x = ray.put("hi") + + def get(self): + return [self.x] + + a = A.remote() + ref1 = ray.get(a.get.remote())[0] + ref2 = ray.get(a.get.remote())[0] + del a + assert ray.get(ref1) == "hi" + del ref1 + assert ray.get(ref2) == "hi" + del ref2 diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index d4c392321..cdc3ee8aa 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -18,17 +18,24 @@ package ray.rpc; enum Type { DEFAULT = 0; } +// An argument to a ClientTask. message Arg { enum Locality { INTERNED = 0; REFERENCE = 1; } + + // The type of argument this is -- whether a data blob or a reference. Locality local = 1; + // The reference id, if a reference. bytes reference_id = 2; + // A data blob, if passed in-band. bytes data = 3; + // How to decode this data blob. Type type = 4; } +// Represents one unit of work to be executed by the server. message ClientTask { enum RemoteExecType { FUNCTION = 0; @@ -36,49 +43,69 @@ message ClientTask { METHOD = 2; STATIC_METHOD = 3; } + // Which type of work this request represents. RemoteExecType type = 1; + // A name parameter, if the payload can be called in more than one way (like a method on + // a payload object). string name = 2; + // A reference to the payload. bytes payload_id = 3; + // The parameters to pass to this call. repeated Arg args = 4; -} - -message RemoteRef { - bytes id = 1; - bytes handle = 2; + // The ID of the client namespace associated with the Datapath stream making this + // request. + string client_id = 5; } message ClientTaskTicket { - RemoteRef return_ref = 1; + // A reference to the returned value from the execution. + bytes return_id = 1; } +// Delivers data to the server message PutRequest { + // The data blob for the server to store. bytes data = 1; } message PutResponse { - RemoteRef ref = 1; + // The reference ID for the data that the server has stored. + bytes id = 1; } +// Requests data from the server. message GetRequest { - bytes handle = 1; + // The reference ID for the requested object data + bytes id = 1; + // Length of time to wait for data to be available, in seconds. Zero is no timeout. float timeout = 2; } message GetResponse { + // Whether or not the data was successfully retrieved bool valid = 1; + // The data blob, on success bytes data = 2; + // An error blob (for example, an exception) on failure. + bytes error = 3; } +// Waits for data to be ready on the server, with a timeout. message WaitRequest { - repeated bytes object_handles = 1; + // The IDs of the data to wait for ready status. + repeated bytes object_ids = 1; + // How many of the above ids to wait for before returning. int64 num_returns = 2; + // How long to wait for these IDs to become ready. double timeout = 3; + // The Client namespace associated with the Datapath stream that holds these IDs. + string client_id = 4; } message WaitResponse { bool valid = 1; - repeated RemoteRef ready_object_ids = 2; - repeated RemoteRef remaining_object_ids = 3; + repeated bytes ready_object_ids = 2; + repeated bytes remaining_object_ids = 3; } message ClusterInfoType { @@ -108,18 +135,19 @@ message ClusterInfoResponse { message TerminateRequest { message ActorTerminate { - bytes handle = 1; + bytes id = 1; bool no_restart = 2; } message TaskObjectTerminate { - bytes handle = 1; + bytes id = 1; bool force = 2; bool recursive = 3; } + string client_id = 1; oneof terminate_type { - ActorTerminate actor = 1; - TaskObjectTerminate task_object = 2; + ActorTerminate actor = 2; + TaskObjectTerminate task_object = 3; } } @@ -141,3 +169,40 @@ service RayletDriver { rpc ClusterInfo(ClusterInfoRequest) returns (ClusterInfoResponse) { } } + +message ReleaseRequest { + // The IDs to release from the server; the client connected on this stream no + // longer holds a reference to them. + repeated bytes ids = 1; +} + +message ReleaseResponse { + // For each requested ID, whether or not it was released. + repeated bool ok = 2; +} + +message DataRequest { + // An incrementing counter of request IDs on the Datapath, + // to match requests with responses asynchronously. + int32 req_id = 1; + oneof type { + GetRequest get = 2; + PutRequest put = 3; + ReleaseRequest release = 4; + } +} + +message DataResponse { + // The request id that this response matches with. + int32 req_id = 1; + oneof type { + GetResponse get = 2; + PutResponse put = 3; + ReleaseResponse release = 4; + } +} + +service RayletDataStreamer { + rpc Datapath(stream DataRequest) returns (stream DataResponse) { + } +} From 92812f2e8a11c67c4c730715051b7275007bff3c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Dec 2020 12:17:54 -0800 Subject: [PATCH 25/88] Implement resource deadlock detection for new scheduler (#12961) --- python/ray/test_utils.py | 2 +- python/ray/tests/test_failure.py | 56 +++++++++++++-- python/ray/tests/test_global_gc.py | 9 ++- src/ray/raylet/node_manager.cc | 68 +++++++++++-------- .../scheduling/cluster_resource_data.cc | 55 +++++++++++++++ .../raylet/scheduling/cluster_resource_data.h | 2 + .../scheduling/cluster_resource_scheduler.cc | 6 ++ .../scheduling/cluster_resource_scheduler.h | 3 + .../raylet/scheduling/cluster_task_manager.cc | 31 ++++++++- .../raylet/scheduling/cluster_task_manager.h | 14 +++- .../scheduling/cluster_task_manager_test.cc | 43 ++++++++++++ 11 files changed, 244 insertions(+), 45 deletions(-) diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index 648b4f4f9..7f6aaa360 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -414,7 +414,7 @@ def init_error_pubsub(): return p -def get_error_message(pub_sub, num, error_type=None, timeout=5): +def get_error_message(pub_sub, num, error_type=None, timeout=20): """Get errors through pub/sub.""" start_time = time.time() msgs = [] diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 84904a5e1..b3505bbf5 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -17,7 +17,8 @@ import ray.ray_constants as ray_constants from ray.exceptions import RayTaskError from ray.cluster_utils import Cluster from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub, - get_error_message, Semaphore) + get_error_message, Semaphore, + new_scheduler_enabled) def test_failed_task(ray_start_regular, error_pubsub): @@ -632,11 +633,12 @@ def test_export_large_objects(ray_start_regular, error_pubsub): assert errors[0].type == ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR -@pytest.mark.skip(reason="TODO detect resource deadlock") -def test_warning_for_resource_deadlock(error_pubsub, shutdown_only): - p = error_pubsub - # Check that we get warning messages for infeasible tasks. - ray.init(num_cpus=1) +@pytest.mark.skipif( + new_scheduler_enabled(), reason="Supposed to deadlock, but it doesn't") +def test_warning_all_tasks_blocked(shutdown_only): + ray.init( + num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500}) + p = init_error_pubsub() @ray.remote(num_cpus=1) class Foo: @@ -646,7 +648,7 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only): @ray.remote def f(): # Creating both actors is not possible. - actors = [Foo.remote() for _ in range(2)] + actors = [Foo.remote() for _ in range(3)] for a in actors: ray.get(a.f.remote()) @@ -657,6 +659,46 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only): assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR +def test_warning_actor_waiting_on_actor(shutdown_only): + ray.init( + num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500}) + p = init_error_pubsub() + + @ray.remote(num_cpus=1) + class Actor: + pass + + a = Actor.remote() # noqa + b = Actor.remote() # noqa + + errors = get_error_message(p, 1, ray_constants.RESOURCE_DEADLOCK_ERROR) + assert len(errors) == 1 + assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR + + +def test_warning_task_waiting_on_actor(shutdown_only): + ray.init( + num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500}) + p = init_error_pubsub() + + @ray.remote(num_cpus=1) + class Actor: + pass + + a = Actor.remote() # noqa + + @ray.remote(num_cpus=1) + def f(): + print("f running") + time.sleep(999) + + ids = [f.remote()] # noqa + + errors = get_error_message(p, 1, ray_constants.RESOURCE_DEADLOCK_ERROR) + assert len(errors) == 1 + assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR + + def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub): p = error_pubsub # Check that we get warning messages for infeasible tasks. diff --git a/python/ray/tests/test_global_gc.py b/python/ray/tests/test_global_gc.py index 247516450..a3039f14d 100644 --- a/python/ray/tests/test_global_gc.py +++ b/python/ray/tests/test_global_gc.py @@ -9,7 +9,7 @@ import pytest import ray import ray.cluster_utils -from ray.test_utils import wait_for_condition, new_scheduler_enabled +from ray.test_utils import wait_for_condition from ray.internal.internal_api import global_gc logger = logging.getLogger(__name__) @@ -166,9 +166,9 @@ def test_global_gc_when_full(shutdown_only): gc.enable() -@pytest.mark.skipif(new_scheduler_enabled(), reason="hangs") def test_global_gc_actors(shutdown_only): - ray.init(num_cpus=1) + ray.init( + num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500}) try: gc.disable() @@ -179,8 +179,7 @@ def test_global_gc_actors(shutdown_only): return "Ok" # Try creating 3 actors. Unless python GC is triggered to break - # reference cycles, this won't be possible. Note this test takes 20s - # to run due to the 10s delay before checking of infeasible tasks. + # reference cycles, this won't be possible. for i in range(3): a = A.remote() cycle = [a] diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 49bf6a6af..2e13bbb91 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -655,37 +655,47 @@ void NodeManager::HandleReleaseUnusedBundles( // debug_dump_period_ milliseconds. // See https://github.com/ray-project/ray/issues/5790 for details. void NodeManager::WarnResourceDeadlock() { - // Check if any progress is being made on this raylet. - for (const auto &task : local_queues_.GetTasks(TaskState::RUNNING)) { - // Ignore blocked tasks. - if (local_queues_.GetBlockedTaskIds().count(task.GetTaskSpecification().TaskId())) { - continue; - } - // Progress is being made, don't warn. - resource_deadlock_warned_ = 0; - return; - } - - // The node is full of actors and no progress has been made for some time. - // If there are any pending tasks, build a warning. - std::ostringstream error_message; ray::Task exemplar; bool any_pending = false; int pending_actor_creations = 0; int pending_tasks = 0; + std::string available_resources; - // See if any tasks are blocked trying to acquire resources. - for (const auto &task : local_queues_.GetTasks(TaskState::READY)) { - const TaskSpecification &spec = task.GetTaskSpecification(); - if (spec.IsActorCreationTask()) { - pending_actor_creations += 1; - } else { - pending_tasks += 1; + // Check if any progress is being made on this raylet. + for (const auto &worker : worker_pool_.GetAllRegisteredWorkers()) { + if (!worker->IsDead() && !worker->GetAssignedTaskId().IsNil() && + !worker->IsBlocked() && worker->GetActorId().IsNil()) { + // Progress is being made in a task, don't warn. + resource_deadlock_warned_ = 0; + return; } - if (!any_pending) { - exemplar = task; - any_pending = true; + } + + if (new_scheduler_enabled_) { + // Check if any tasks are blocked on resource acquisition. + if (!cluster_task_manager_->AnyPendingTasks( + &exemplar, &any_pending, &pending_actor_creations, &pending_tasks)) { + // No pending tasks, no need to warn. + resource_deadlock_warned_ = 0; + return; } + available_resources = new_resource_scheduler_->GetLocalResourceViewString(); + } else { + // See if any tasks are blocked trying to acquire resources. + for (const auto &task : local_queues_.GetTasks(TaskState::READY)) { + const TaskSpecification &spec = task.GetTaskSpecification(); + if (spec.IsActorCreationTask()) { + pending_actor_creations += 1; + } else { + pending_tasks += 1; + } + if (!any_pending) { + exemplar = task; + any_pending = true; + } + } + SchedulingResources &local_resources = cluster_resource_map_[self_node_id_]; + available_resources = local_resources.GetAvailableResources().ToString(); } // Push an warning to the driver that a task is blocked trying to acquire resources. @@ -693,6 +703,7 @@ void NodeManager::WarnResourceDeadlock() { // case resource_deadlock_warned_: 0 => first time, don't do anything yet // case resource_deadlock_warned_: 1 => second time, print a warning // case resource_deadlock_warned_: >1 => global gc but don't print any warnings + std::ostringstream error_message; if (any_pending && resource_deadlock_warned_++ > 0) { // Actor references may be caught in cycles, preventing them from being deleted. // Trigger global GC to hopefully free up resource slots. @@ -703,15 +714,13 @@ void NodeManager::WarnResourceDeadlock() { return; } - SchedulingResources &local_resources = cluster_resource_map_[self_node_id_]; error_message << "The actor or task with ID " << exemplar.GetTaskSpecification().TaskId() << " cannot be scheduled right now. It requires " << exemplar.GetTaskSpecification().GetRequiredPlacementResources().ToString() - << " for placement, but this node only has remaining " - << local_resources.GetAvailableResources().ToString() << ". In total there are " - << pending_tasks << " pending tasks and " << pending_actor_creations - << " pending actors on this node. " + << " for placement, but this node only has remaining " << available_resources + << ". In total there are " << pending_tasks << " pending tasks and " + << pending_actor_creations << " pending actors on this node. " << "This is likely due to all cluster resources being claimed by actors. " << "To resolve the issue, consider creating fewer actors or increase the " << "resources available to this Ray cluster. You can ignore this message " @@ -2164,6 +2173,7 @@ void NodeManager::HandleDirectCallTaskBlocked( auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); local_available_resources_.Release(cpu_resource_ids); cluster_resource_map_[self_node_id_].Release(cpu_resource_ids.ToResourceSet()); + worker->MarkBlocked(); DispatchTasks(local_queues_.GetReadyTasksByClass()); } diff --git a/src/ray/raylet/scheduling/cluster_resource_data.cc b/src/ray/raylet/scheduling/cluster_resource_data.cc index 551b5a980..a9fa3af39 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.cc +++ b/src/ray/raylet/scheduling/cluster_resource_data.cc @@ -232,6 +232,61 @@ std::string NodeResources::DebugString(StringIdMap string_to_in_map) const { return buffer.str(); } +const std::string format_resource(std::string resource_name, double quantity) { + if (resource_name == "object_store_memory" || resource_name == "memory") { + // Convert to 50MiB chunks and then to GiB + return std::to_string(quantity * (50 * 1024 * 1024) / (1024 * 1024 * 1024)) + " GiB"; + } + return std::to_string(quantity); +} + +std::string NodeResources::DictString(StringIdMap string_to_in_map) const { + std::stringstream buffer; + bool first = true; + buffer << "{"; + for (size_t i = 0; i < this->predefined_resources.size(); i++) { + if (this->predefined_resources[i].total <= 0) { + continue; + } + if (first) { + first = false; + } else { + buffer << ", "; + } + std::string name = ""; + switch (i) { + case CPU: + name = "CPU"; + break; + case MEM: + name = "memory"; + break; + case GPU: + name = "GPU"; + break; + case TPU: + name = "TPU"; + break; + default: + RAY_CHECK(false) << "This should never happen."; + break; + } + buffer << format_resource(name, this->predefined_resources[i].available.Double()) + << "/"; + buffer << format_resource(name, this->predefined_resources[i].total.Double()); + buffer << " " << name; + } + for (auto it = this->custom_resources.begin(); it != this->custom_resources.end(); + ++it) { + auto name = string_to_in_map.Get(it->first); + buffer << ", " << format_resource(name, it->second.available.Double()) << "/" + << format_resource(name, it->second.total.Double()); + buffer << " " << name; + } + buffer << "}" << std::endl; + return buffer.str(); +} + bool NodeResourceInstances::operator==(const NodeResourceInstances &other) { for (size_t i = 0; i < PredefinedResources_MAX; i++) { if (!EqualVectors(this->predefined_resources[i].total, diff --git a/src/ray/raylet/scheduling/cluster_resource_data.h b/src/ray/raylet/scheduling/cluster_resource_data.h index 96b4c4359..46a5cbe1b 100644 --- a/src/ray/raylet/scheduling/cluster_resource_data.h +++ b/src/ray/raylet/scheduling/cluster_resource_data.h @@ -164,6 +164,8 @@ class NodeResources { bool operator!=(const NodeResources &other); /// Returns human-readable string for these resources. std::string DebugString(StringIdMap string_to_int_map) const; + /// Returns compact dict-like string. + std::string DictString(StringIdMap string_to_int_map) const; }; /// Total and available capacities of each resource instance. diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index 66047d258..bcca8862a 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -518,6 +518,12 @@ void ClusterResourceScheduler::InitResourceInstances( } } +std::string ClusterResourceScheduler::GetLocalResourceViewString() const { + const auto &node_it = nodes_.find(local_node_id_); + RAY_CHECK(node_it != nodes_.end()); + return node_it->second.GetLocalView().DictString(string_to_int_map_); +} + void ClusterResourceScheduler::InitLocalResources(const NodeResources &node_resources) { local_resources_.predefined_resources.resize(PredefinedResources_MAX); diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.h b/src/ray/raylet/scheduling/cluster_resource_scheduler.h index 470c97c38..9e480b4c8 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.h +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.h @@ -191,6 +191,9 @@ class ClusterResourceScheduler { /// Return local resources. NodeResourceInstances GetLocalResources() { return local_resources_; }; + /// Return local resources in human-readable string form. + std::string GetLocalResourceViewString() const; + /// Create instances for each resource associated with the local node, given /// the node's resources. /// diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 74437a4a1..c11d818ef 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -242,7 +242,10 @@ void ClusterTaskManager::TasksUnblocked(const std::vector ready_ids) { const auto &scheduling_key = task.GetTaskSpecification().GetSchedulingClass(); RAY_LOG(DEBUG) << "Args ready, task can be dispatched " << task.GetTaskSpecification().TaskId(); - tasks_to_dispatch_[scheduling_key].push_back(work); + // Note: we transition tasks back to the scheduling queue instead of directly + // to dispatch. This allows AnyPendingTasks() to simply check the scheduling + // queue to see if any tasks are blocked on resource availability: see #12438 + tasks_to_schedule_[scheduling_key].push_back(work); waiting_tasks_.erase(it); } } @@ -500,6 +503,32 @@ void ClusterTaskManager::FillResourceUsage( } } +bool ClusterTaskManager::AnyPendingTasks(Task *exemplar, bool *any_pending, + int *num_pending_actor_creation, + int *num_pending_tasks) const { + // We are guaranteed that these tasks are blocked waiting for resources after a + // call to ScheduleAndDispatch(). Note that tasks that transition to waiting + // move back to the tasks_to_schedule_ queue after their deps are satisfied. + for (const auto &shapes_it : tasks_to_schedule_) { + auto &work_queue = shapes_it.second; + for (const auto &work_it : work_queue) { + const auto &task = std::get<0>(work_it); + if (task.GetTaskSpecification().IsActorCreationTask()) { + *num_pending_actor_creation += 1; + } else { + *num_pending_tasks += 1; + } + + if (!*any_pending) { + *exemplar = task; + *any_pending = true; + } + } + } + // If there's any pending task, at this point, there's no progress being made. + return *any_pending; +} + std::string ClusterTaskManager::DebugString() const { std::stringstream buffer; buffer << "========== Node: " << self_node_id_ << " =================\n"; diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 995273ed5..b71593f8a 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -115,6 +115,16 @@ class ClusterTaskManager { void FillResourceUsage(bool light_report_resource_usage_enabled, std::shared_ptr data) const; + /// Return if any tasks are pending resource acquisition. + /// + /// \param[in] exemplar An example task that is deadlocking. + /// \param[in] num_pending_actor_creation Number of pending actor creation tasks. + /// \param[in] num_pending_tasks Number of pending tasks. + /// \param[in] any_pending True if there's any pending exemplar. + /// \return True if any progress is any tasks are pending. + bool AnyPendingTasks(Task *exemplar, bool *any_pending, int *num_pending_actor_creation, + int *num_pending_tasks) const; + std::string DebugString() const; private: @@ -147,11 +157,11 @@ class ClusterTaskManager { std::unordered_map> tasks_to_schedule_; /// Queue of lease requests that should be scheduled onto workers. - /// Tasks move from scheduled | waiting -> dispatch. + /// Tasks move from scheduled -> dispatch. std::unordered_map> tasks_to_dispatch_; /// Tasks waiting for arguments to be transferred locally. - /// Tasks move from waiting -> dispatch. + /// Tasks move (back) from waiting -> scheduled. absl::flat_hash_map waiting_tasks_; /// Queue of lease requests that are infeasible. diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 023390632..24018dbc8 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -250,7 +250,9 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { /* First task is unblocked now, but resources are no longer available */ auto id = task.GetTaskSpecification().TaskId(); std::vector unblocked = {id}; + dependencies_fulfilled_ = true; task_manager_.TasksUnblocked(unblocked); + task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); ASSERT_EQ(num_callbacks, 1); @@ -261,6 +263,7 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { leased_workers_.clear(); task_manager_.HandleTaskFinished(worker); + task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); // Task2 is now done so task can run. @@ -676,6 +679,46 @@ TEST_F(ClusterTaskManagerTest, TestMultipleInfeasibleTasksWarnOnce) { ASSERT_EQ(announce_infeasible_task_calls_, 1); } +TEST_F(ClusterTaskManagerTest, TestAnyPendingTasks) { + /* + Check if the manager can correctly identify pending tasks. + */ + + // task1: running + Task task = CreateTask({{ray::kCPU_ResourceLabel, 6}}); + rpc::RequestWorkerLeaseReply reply; + std::shared_ptr callback_occurred = std::make_shared(false); + auto callback = [callback_occurred]() { *callback_occurred = true; }; + task_manager_.QueueTask(task, &reply, callback); + task_manager_.SchedulePendingTasks(); + std::shared_ptr worker = + std::make_shared(WorkerID::FromRandom(), 1234); + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_TRUE(*callback_occurred); + ASSERT_EQ(leased_workers_.size(), 1); + ASSERT_EQ(pool_.workers.size(), 0); + + // task1: running. Progress is made, and there's no deadlock. + ray::Task exemplar; + bool any_pending = false; + int pending_actor_creations = 0; + int pending_tasks = 0; + ASSERT_FALSE(task_manager_.AnyPendingTasks(&exemplar, &any_pending, + &pending_actor_creations, &pending_tasks)); + + // task1: running, task2: queued. + Task task2 = CreateTask({{ray::kCPU_ResourceLabel, 6}}); + rpc::RequestWorkerLeaseReply reply2; + std::shared_ptr callback_occurred2 = std::make_shared(false); + auto callback2 = [callback_occurred2]() { *callback_occurred2 = true; }; + task_manager_.QueueTask(task2, &reply2, callback2); + task_manager_.SchedulePendingTasks(); + ASSERT_FALSE(*callback_occurred2); + ASSERT_TRUE(task_manager_.AnyPendingTasks(&exemplar, &any_pending, + &pending_actor_creations, &pending_tasks)); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); From 3521e74f3acfe1f6b72b66c7d5281ba61b368e13 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 18 Dec 2020 15:49:24 -0600 Subject: [PATCH 26/88] [serve] Support for imported backends (#12923) --- doc/source/serve/advanced.rst | 8 ++++ doc/source/serve/package-ref.rst | 3 ++ python/ray/serve/BUILD | 8 ++++ python/ray/serve/backends.py | 33 +++++++++++++++ python/ray/serve/examples/doc/conda_env.py | 2 - .../serve/examples/doc/imported_backend.py | 12 ++++++ .../ray/serve/tests/test_imported_backend.py | 29 +++++++++++++ python/ray/serve/tests/test_util.py | 18 +++++++- python/ray/serve/utils.py | 41 +++++++++++++++++++ 9 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 python/ray/serve/backends.py create mode 100644 python/ray/serve/examples/doc/imported_backend.py create mode 100644 python/ray/serve/tests/test_imported_backend.py diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 01ef54fc9..3c3ca3940 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -327,3 +327,11 @@ as shown below. :mod:`client.create_backend ` by default. +The dependencies required in the backend may be different than +the dependencies installed in the driver program (the one running Serve API +calls). In this case, you can use an +:mod:`ImportedBackend ` to specify a +backend based on a class that is installed in the Python environment that +the workers will run in. Example: + +.. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 4c1ad2f7b..5a9e947ff 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -31,3 +31,6 @@ objects instead of Flask requests. Batching Requests ----------------- .. autofunction:: ray.serve.accept_batch + +Built-in Backends +.. autoclass:: ray.serve.backends.ImportedBackend diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index be8707d86..dd1b91359 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -119,6 +119,14 @@ py_test( deps = [":serve_lib"], ) +py_test( + name = "test_imported_backend", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + # Runs test_api and test_failure with injected failures in the controller. # TODO(simon): Tests are disabled until #11683 is fixed. diff --git a/python/ray/serve/backends.py b/python/ray/serve/backends.py new file mode 100644 index 000000000..086755500 --- /dev/null +++ b/python/ray/serve/backends.py @@ -0,0 +1,33 @@ +from ray.serve.utils import import_class + + +class ImportedBackend: + """Factory for a class that will dynamically import a backend class. + + This is intended to be used when the source code for a backend is + installed in the worker environment but not the driver. + + Intended usage: + >>> client = serve.connect() + >>> client.create_backend("b", ImportedBackend("module.Class"), *args) + + This will import module.Class on the worker and proxy all relevant methods + to it. + """ + + def __new__(cls, class_path): + class ImportedBackend: + def __init__(self, *args, **kwargs): + self.wrapped = import_class(class_path)(*args, **kwargs) + + def reconfigure(self, *args, **kwargs): + # NOTE(edoakes): we check that the reconfigure method is + # present if the user specifies a user_config, so we need to + # proxy it manually. + return self.wrapped.reconfigure(*args, **kwargs) + + def __getattr__(self, attr): + """Proxy all other methods to the wrapper class.""" + return getattr(self.wrapped, attr) + + return ImportedBackend diff --git a/python/ray/serve/examples/doc/conda_env.py b/python/ray/serve/examples/doc/conda_env.py index 4b8239d67..7eee2df33 100644 --- a/python/ray/serve/examples/doc/conda_env.py +++ b/python/ray/serve/examples/doc/conda_env.py @@ -1,10 +1,8 @@ import requests -import ray from ray import serve from ray.serve import CondaEnv import tensorflow as tf -ray.init() client = serve.start() diff --git a/python/ray/serve/examples/doc/imported_backend.py b/python/ray/serve/examples/doc/imported_backend.py new file mode 100644 index 000000000..ec864c211 --- /dev/null +++ b/python/ray/serve/examples/doc/imported_backend.py @@ -0,0 +1,12 @@ +import requests + +from ray import serve +from ray.serve.backends import ImportedBackend + +client = serve.start() + +backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend") +client.create_backend("imported", backend_class, "input_arg") +client.create_endpoint("imported", backend="imported", route="/imported") + +print(requests.get("http://127.0.0.1:8000/imported").text) diff --git a/python/ray/serve/tests/test_imported_backend.py b/python/ray/serve/tests/test_imported_backend.py new file mode 100644 index 000000000..cc575dd94 --- /dev/null +++ b/python/ray/serve/tests/test_imported_backend.py @@ -0,0 +1,29 @@ +import ray +from ray.serve.backends import ImportedBackend +from ray.serve.config import BackendConfig + + +def test_imported_backend(serve_instance): + client = serve_instance + + backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend") + config = BackendConfig(user_config="config") + client.create_backend( + "imported", backend_class, "input_arg", config=config) + client.create_endpoint("imported", backend="imported") + + # Basic sanity check. + handle = client.get_handle("imported") + assert ray.get(handle.remote()) == {"arg": "input_arg", "config": "config"} + + # Check that updating backend config works. + client.update_backend_config( + "imported", BackendConfig(user_config="new_config")) + assert ray.get(handle.remote()) == { + "arg": "input_arg", + "config": "new_config" + } + + # Check that other call methods work. + handle = handle.options(method_name="other_method") + assert ray.get(handle.remote("hello")) == "hello" diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index d492a3b1e..9893bc4ce 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -6,9 +6,10 @@ from copy import deepcopy import numpy as np import pytest +import ray from ray.serve.utils import (ServeEncoder, chain_future, unpack_future, try_schedule_resources_on_nodes, - get_conda_env_dir) + get_conda_env_dir, import_class) def test_bytes_encoder(): @@ -125,6 +126,21 @@ def test_get_conda_env_dir(tmp_path): os.environ["CONDA_PREFIX"] = "" +def test_import_class(): + assert import_class("ray.serve.Client") == ray.serve.api.Client + assert import_class("ray.serve.api.Client") == ray.serve.api.Client + + policy_cls = import_class("ray.serve.controller.TrafficPolicy") + assert policy_cls == ray.serve.controller.TrafficPolicy + + policy = policy_cls({"endpoint1": 0.5, "endpoint2": 0.5}) + with pytest.raises(ValueError): + policy.set_traffic_dict({"endpoint1": 0.5, "endpoint2": 0.6}) + policy.set_traffic_dict({"endpoint1": 0.4, "endpoint2": 0.6}) + + print(repr(policy)) + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index efa6f1b6b..e8c5a6d13 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -1,5 +1,6 @@ import asyncio from functools import singledispatch +import importlib from itertools import groupby import json import logging @@ -342,3 +343,43 @@ def get_node_id_for_actor(actor_handle): """Given an actor handle, return the node id it's placed on.""" return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"] + + +def import_class(full_path: str): + """Given a full import path to a class name, return the imported class. + + For example, the following are equivalent: + MyClass = import_class("module.submodule.MyClass") + from module.submodule import MyClass + + Returns: + Imported class + """ + + last_period_idx = full_path.rfind(".") + class_name = full_path[last_period_idx + 1:] + module_name = full_path[:last_period_idx] + module = importlib.import_module(module_name) + return getattr(module, class_name) + + +class MockImportedBackend: + """Used for testing backends.ImportedBackend. + + This is necessary because we need the class to be installed in the worker + processes. We could instead mock out importlib but doing so is messier and + reduces confidence in the test (it isn't truly end-to-end). + """ + + def __init__(self, arg): + self.arg = arg + self.config = None + + def reconfigure(self, config): + self.config = config + + def __call__(self, *args): + return {"arg": self.arg, "config": self.config} + + def other_method(self, request): + return request.data From 3e492a79ec7b67dd1137535cead690339effc2ac Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Dec 2020 15:59:03 -0800 Subject: [PATCH 27/88] Increase the number of unique bits for actors to avoid handle collisions (#12894) --- .../tests/test_stats_collector.py | 18 ++++++--------- dashboard/tests/test_memory_utils.py | 5 +++-- .../src/main/java/io/ray/api/id/ActorId.java | 2 +- .../src/main/java/io/ray/api/id/ObjectId.java | 2 +- .../src/main/java/io/ray/api/id/UniqueId.java | 2 +- .../java/io/ray/runtime/UniqueIdTest.java | 22 ++++++------------- python/ray/exceptions.py | 6 +++-- python/ray/includes/function_descriptor.pxi | 13 +++++++---- python/ray/includes/unique_ids.pxi | 2 +- python/ray/log_monitor.py | 2 +- python/ray/ray_constants.py | 2 +- python/ray/serialization.py | 5 +++-- python/ray/tests/test_advanced_3.py | 6 ++--- python/ray/tests/test_multi_node.py | 8 +++---- python/ray/utils.py | 4 ++-- python/ray/worker.py | 3 ++- src/ray/common/constants.h | 2 +- src/ray/common/id.h | 2 +- src/ray/core_worker/actor_manager.cc | 2 ++ 19 files changed, 54 insertions(+), 54 deletions(-) diff --git a/dashboard/modules/stats_collector/tests/test_stats_collector.py b/dashboard/modules/stats_collector/tests/test_stats_collector.py index f4246770a..bed6d650f 100644 --- a/dashboard/modules/stats_collector/tests/test_stats_collector.py +++ b/dashboard/modules/stats_collector/tests/test_stats_collector.py @@ -112,20 +112,16 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard): def check_mem_table(): resp = requests.get(f"{webui_url}/memory/memory_table") resp_data = resp.json() - if not resp_data["result"]: - return False + assert resp_data["result"] latest_memory_table = resp_data["data"]["memoryTable"] summary = latest_memory_table["summary"] - try: - # 1 ref per handle and per object the actor has a ref to - assert summary["totalActorHandles"] == len(actors) * 2 - # 1 ref for my_obj - assert summary["totalLocalRefCount"] == 1 - return True - except AssertionError: - return False + # 1 ref per handle and per object the actor has a ref to + assert summary["totalActorHandles"] == len(actors) * 2 + # 1 ref for my_obj + assert summary["totalLocalRefCount"] == 1 - wait_for_condition(check_mem_table, 10) + wait_until_succeeded_without_exception( + check_mem_table, (AssertionError, ), timeout_ms=1000) def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard): diff --git a/dashboard/tests/test_memory_utils.py b/dashboard/tests/test_memory_utils.py index f58ecd8ae..212eeefad 100644 --- a/dashboard/tests/test_memory_utils.py +++ b/dashboard/tests/test_memory_utils.py @@ -7,8 +7,9 @@ from ray.new_dashboard.memory_utils import ( NODE_ADDRESS = "127.0.0.1" IS_DRIVER = True PID = 1 -OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA=" -ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000" + +OBJECT_ID = "ZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZg==" +ACTOR_ID = "fffffffffffffffffffffffffffffffff66d17ba010000c801000000" DECODED_ID = decode_object_ref_if_needed(OBJECT_ID) OBJECT_SIZE = 100 diff --git a/java/api/src/main/java/io/ray/api/id/ActorId.java b/java/api/src/main/java/io/ray/api/id/ActorId.java index 65a0cf19a..a21d4e79f 100644 --- a/java/api/src/main/java/io/ray/api/id/ActorId.java +++ b/java/api/src/main/java/io/ray/api/id/ActorId.java @@ -7,7 +7,7 @@ import java.util.Random; public class ActorId extends BaseId implements Serializable { - private static final int UNIQUE_BYTES_LENGTH = 4; + private static final int UNIQUE_BYTES_LENGTH = 12; public static final int LENGTH = JobId.LENGTH + UNIQUE_BYTES_LENGTH; diff --git a/java/api/src/main/java/io/ray/api/id/ObjectId.java b/java/api/src/main/java/io/ray/api/id/ObjectId.java index 9b1fa246f..78b677ac8 100644 --- a/java/api/src/main/java/io/ray/api/id/ObjectId.java +++ b/java/api/src/main/java/io/ray/api/id/ObjectId.java @@ -10,7 +10,7 @@ import java.util.Random; */ public class ObjectId extends BaseId implements Serializable { - public static final int LENGTH = 20; + public static final int LENGTH = 28; /** * Create an ObjectId from a ByteBuffer. diff --git a/java/api/src/main/java/io/ray/api/id/UniqueId.java b/java/api/src/main/java/io/ray/api/id/UniqueId.java index 03de53943..44b19f6a7 100644 --- a/java/api/src/main/java/io/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/io/ray/api/id/UniqueId.java @@ -11,7 +11,7 @@ import java.util.Random; */ public class UniqueId extends BaseId implements Serializable { - public static final int LENGTH = 20; + public static final int LENGTH = 28; public static final UniqueId NIL = genNil(); /** diff --git a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java index 25704f321..7496f1baf 100644 --- a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java @@ -1,7 +1,6 @@ package io.ray.runtime; import io.ray.api.id.UniqueId; -import io.ray.runtime.util.IdUtil; import java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; @@ -13,12 +12,12 @@ public class UniqueIdTest { @Test public void testConstructUniqueId() { // Test `fromHexString()` - UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString()); + UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00"); + Assert.assertEquals("00000000123456789abcdef123456789abcdef0123456789abcdef00", id1.toString()); Assert.assertFalse(id1.isNil()); try { - UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF00"); + UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00"); // This shouldn't be happened. Assert.assertTrue(false); } catch (IllegalArgumentException e) { @@ -34,23 +33,16 @@ public class UniqueIdTest { } // Test `fromByteBuffer()` - byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567"); - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20); + byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF"); + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 28); UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer); Assert.assertTrue(Arrays.equals(bytes, id4.getBytes())); - Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString()); + Assert.assertEquals("0123456789abcdef0123456789abcdef012345670123456789abcdef", id4.toString()); // Test `genNil()` UniqueId id6 = UniqueId.NIL; - Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); + Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } - - @Test - void testMurmurHash() { - UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); - long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); - Assert.assertEquals(remainder, 787616861); - } } diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index b5a0b477c..56e943db6 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -142,7 +142,8 @@ class WorkerCrashedError(RayError): """Indicates that the worker died unexpectedly while executing a task.""" def __str__(self): - return "The worker died unexpectedly while executing this task." + return ("The worker died unexpectedly while executing this task. " + "Check python-core-worker-*.log files for more information.") class RayActorError(RayError): @@ -153,7 +154,8 @@ class RayActorError(RayError): """ def __str__(self): - return "The actor died unexpectedly before finishing this task." + return ("The actor died unexpectedly before finishing this task. " + "Check python-core-worker-*.log files for more information.") class RaySystemError(RayError): diff --git a/python/ray/includes/function_descriptor.pxi b/python/ray/includes/function_descriptor.pxi index a9ac11fdb..d2c4cbbf4 100644 --- a/python/ray/includes/function_descriptor.pxi +++ b/python/ray/includes/function_descriptor.pxi @@ -12,6 +12,7 @@ import hashlib import cython import inspect import uuid +import ray.ray_constants as ray_constants ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &) @@ -188,7 +189,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): function_name = function.__name__ class_name = "" - pickled_function_hash = hashlib.sha1(pickled_function).hexdigest() + pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest( + ray_constants.ID_SIZE) return cls(module_name, function_name, class_name, pickled_function_hash) @@ -208,7 +210,10 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): module_name = target_class.__module__ class_name = target_class.__name__ # Use a random uuid as function hash to solve actor name conflict. - return cls(module_name, "__init__", class_name, str(uuid.uuid4())) + return cls( + module_name, "__init__", class_name, + hashlib.shake_128( + uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE)) @property def module_name(self): @@ -268,14 +273,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): Returns: ray.ObjectRef to represent the function descriptor. """ - function_id_hash = hashlib.sha1() + function_id_hash = hashlib.shake_128() # Include the function module and name in the hash. function_id_hash.update(self.typed_descriptor.ModuleName()) function_id_hash.update(self.typed_descriptor.FunctionName()) function_id_hash.update(self.typed_descriptor.ClassName()) function_id_hash.update(self.typed_descriptor.FunctionHash()) # Compute the function ID. - function_id = function_id_hash.digest() + function_id = function_id_hash.digest(ray_constants.ID_SIZE) return ray.FunctionID(function_id) def is_actor_method(self): diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index bcf766829..52a6730e6 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize): raise TypeError("Unsupported type: " + str(type(b))) if len(b) != size: raise ValueError("ID string needs to have length " + - str(size)) + str(size) + ", got " + str(len(b))) cdef extern from "ray/common/constants.h" nogil: diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index ac5fa5296..d6b3a314e 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger logger = logging.getLogger(__name__) # The groups are worker id, job id, and pid. -JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)") +JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)") class LogFileInfo: diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index be717ca3c..30b3b5c7b 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -19,7 +19,7 @@ def env_bool(key, default): return default -ID_SIZE = 20 +ID_SIZE = 28 # The default maximum number of bytes to allocate to the object store unless # overridden by the user. diff --git a/python/ray/serialization.py b/python/ray/serialization.py index dc9a2c40e..9a24f3ccc 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -74,7 +74,8 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): new_class_id = pickle.dumps(pickle.loads(class_id)) if new_class_id == class_id: # We appear to have reached a fix point, so use this as the ID. - return hashlib.sha1(new_class_id).digest() + return hashlib.shake_128(new_class_id).digest( + ray_constants.ID_SIZE) class_id = new_class_id # We have not reached a fixed point, so we may end up with a different @@ -82,7 +83,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): # same class definition being exported many many times. logger.warning( f"WARNING: Could not produce a deterministic class ID for class {cls}") - return hashlib.sha1(new_class_id).digest() + return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE) def object_ref_deserializer(reduced_obj_ref, owner_address): diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 7f1e8e639..b1bc25fbb 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -284,14 +284,14 @@ def test_workers(shutdown_only): def test_object_ref_properties(): - id_bytes = b"00112233445566778899" + id_bytes = b"0011223344556677889900001111" object_ref = ray.ObjectRef(id_bytes) assert object_ref.binary() == id_bytes object_ref = ray.ObjectRef.nil() assert object_ref.is_nil() - with pytest.raises(ValueError, match=r".*needs to have length 20.*"): + with pytest.raises(ValueError, match=r".*needs to have length.*"): ray.ObjectRef(id_bytes + b"1234") - with pytest.raises(ValueError, match=r".*needs to have length 20.*"): + with pytest.raises(ValueError, match=r".*needs to have length.*"): ray.ObjectRef(b"0123456789") object_ref = ray.ObjectRef.from_random() assert not object_ref.is_nil() diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index cb206112d..fbce475c1 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -741,10 +741,10 @@ ray.get(main_wait.release.remote()) driver1_out_split = driver1_out.split("\n") driver2_out_split = driver2_out.split("\n") - assert driver1_out_split[0][-1] == "1" - assert driver1_out_split[1][-1] == "2" - assert driver2_out_split[0][-1] == "3" - assert driver2_out_split[1][-1] == "4" + assert driver1_out_split[0][-1] == "1", driver1_out_split + assert driver1_out_split[1][-1] == "2", driver1_out_split + assert driver2_out_split[0][-1] == "3", driver2_out_split + assert driver2_out_split[1][-1] == "4", driver2_out_split if __name__ == "__main__": diff --git a/python/ray/utils.py b/python/ray/utils.py index a3940d6e8..2704e07cc 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -50,9 +50,9 @@ def get_ray_temp_dir(): def _random_string(): - id_hash = hashlib.sha1() + id_hash = hashlib.shake_128() id_hash.update(uuid.uuid4().bytes) - id_bytes = id_hash.digest() + id_bytes = id_hash.digest(ray_constants.ID_SIZE) assert len(id_bytes) == ray_constants.ID_SIZE return id_bytes diff --git a/python/ray/worker.py b/python/ray/worker.py index 495478ad7..627037098 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -345,7 +345,8 @@ class Worker: # actually run the function locally. pickled_function = pickle.dumps(function) - function_to_run_id = hashlib.sha1(pickled_function).digest() + function_to_run_id = hashlib.shake_128(pickled_function).digest( + ray_constants.ID_SIZE) key = b"FunctionsToRun:" + function_to_run_id # First run the function on the driver. # We always run the task locally. diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 1636846f0..3a3461f2c 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -18,7 +18,7 @@ #include /// Length of Ray full-length IDs in bytes. -constexpr size_t kUniqueIDSize = 20; +constexpr size_t kUniqueIDSize = 28; /// An ObjectID's bytes are split into the task ID itself and the index of the /// object's creation. This is the maximum width of the object index in bits. diff --git a/src/ray/common/id.h b/src/ray/common/id.h index d12ba550d..bd55b27e5 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -124,7 +124,7 @@ class JobID : public BaseID { class ActorID : public BaseID { private: - static constexpr size_t kUniqueBytesLength = 4; + static constexpr size_t kUniqueBytesLength = 12; public: /// Length of `ActorID` in bytes. diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index e6ef4fc87..6b931082a 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -91,6 +91,8 @@ bool ActorManager::AddActorHandle(std::unique_ptr actor_handle, std::placeholders::_1, std::placeholders::_2); RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe( actor_id, actor_notification_callback, nullptr)); + } else { + RAY_LOG(ERROR) << "Actor handle already exists " << actor_id.Hex(); } return inserted; From 6ece291f352e1531f29fa8573e21aa47f290b485 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Dec 2020 16:00:54 -0800 Subject: [PATCH 28/88] Clean up block/unblock handling of resources in new scheduler (#12963) --- python/ray/tests/test_failure.py | 5 +-- python/ray/tests/test_gcs_fault_tolerance.py | 2 -- python/ray/tests/test_global_state.py | 3 -- python/ray/tests/test_reference_counting.py | 2 +- src/ray/raylet/node_manager.cc | 14 +++++--- .../scheduling/cluster_resource_scheduler.cc | 35 +++++++++++++------ .../scheduling/cluster_resource_scheduler.h | 8 +++-- .../raylet/scheduling/cluster_task_manager.cc | 12 +++---- .../raylet/scheduling/cluster_task_manager.h | 4 +-- src/ray/raylet/test/util.h | 8 ----- src/ray/raylet/worker.h | 22 ------------ src/ray/raylet/worker_pool.cc | 4 +-- 12 files changed, 51 insertions(+), 68 deletions(-) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index b3505bbf5..f01868989 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -17,8 +17,7 @@ import ray.ray_constants as ray_constants from ray.exceptions import RayTaskError from ray.cluster_utils import Cluster from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub, - get_error_message, Semaphore, - new_scheduler_enabled) + get_error_message, Semaphore) def test_failed_task(ray_start_regular, error_pubsub): @@ -633,8 +632,6 @@ def test_export_large_objects(ray_start_regular, error_pubsub): assert errors[0].type == ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR -@pytest.mark.skipif( - new_scheduler_enabled(), reason="Supposed to deadlock, but it doesn't") def test_warning_all_tasks_blocked(shutdown_only): ray.init( num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500}) diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index 32f20d42a..642cf4dd3 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -6,7 +6,6 @@ from ray.test_utils import ( generate_system_config_map, wait_for_condition, wait_for_pid_to_exit, - new_scheduler_enabled, ) @@ -21,7 +20,6 @@ def increase(x): return x + 1 -@pytest.mark.skipif(new_scheduler_enabled(), reason="notimpl") @pytest.mark.parametrize( "ray_start_regular", [ generate_system_config_map( diff --git a/python/ray/tests/test_global_state.py b/python/ray/tests/test_global_state.py index c201b6bc3..3dcd64c1e 100644 --- a/python/ray/tests/test_global_state.py +++ b/python/ray/tests/test_global_state.py @@ -8,7 +8,6 @@ import time import ray import ray.ray_constants import ray.test_utils -from ray.test_utils import new_scheduler_enabled from ray._raylet import GlobalStateAccessor @@ -217,8 +216,6 @@ def test_load_report(shutdown_only, max_shapes): global_state_accessor.disconnect() -@pytest.mark.skipif( - new_scheduler_enabled(), reason="requires placement groups") def test_placement_group_load_report(ray_start_cluster): cluster = ray_start_cluster # Add a head node that doesn't have gpu resource. diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py index b9f3b0906..b93ee4221 100644 --- a/python/ray/tests/test_reference_counting.py +++ b/python/ray/tests/test_reference_counting.py @@ -167,7 +167,7 @@ def test_dependency_refcounts(ray_start_regular): check_refcounts({}) -@pytest.mark.skipif(new_scheduler_enabled(), reason="hangs") +@pytest.mark.skipif(new_scheduler_enabled(), reason="dynamic res todo") def test_actor_creation_task(ray_start_regular): @ray.remote def large_object(): diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 2e13bbb91..fe975d79b 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2158,9 +2158,11 @@ void NodeManager::HandleDirectCallTaskBlocked( cpu_instances = worker->GetAllocatedInstances()->GetCPUInstancesDouble(); } if (cpu_instances.size() > 0) { - std::vector borrowed_cpu_instances = + std::vector overflow_cpu_instances = new_resource_scheduler_->AddCPUResourceInstances(cpu_instances); - worker->SetBorrowedCPUInstances(borrowed_cpu_instances); + for (unsigned int i = 0; i < overflow_cpu_instances.size(); i++) { + RAY_CHECK(overflow_cpu_instances[i] == 0) << "Should not be overflow"; + } worker->MarkBlocked(); } ScheduleAndDispatch(); @@ -2199,9 +2201,11 @@ void NodeManager::HandleDirectCallTaskUnblocked( cpu_instances = worker->GetAllocatedInstances()->GetCPUInstancesDouble(); } if (cpu_instances.size() > 0) { - new_resource_scheduler_->SubtractCPUResourceInstances(cpu_instances); - new_resource_scheduler_->AddCPUResourceInstances(worker->GetBorrowedCPUInstances()); - worker->ClearBorrowedCPUInstances(); + // Important: we allow going negative here, since otherwise you can use infinite + // CPU resources by repeatedly blocking / unblocking a task. By allowing it to go + // negative, at most one task can "borrow" this worker's resources. + new_resource_scheduler_->SubtractCPUResourceInstances( + cpu_instances, /*allow_going_negative=*/true); worker->MarkUnblocked(); } ScheduleAndDispatch(); diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc index bcca8862a..10eae694c 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.cc +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.cc @@ -565,15 +565,25 @@ std::vector ClusterResourceScheduler::AddAvailableResourceInstances( } std::vector ClusterResourceScheduler::SubtractAvailableResourceInstances( - std::vector available, ResourceInstanceCapacities *resource_instances) { + std::vector available, ResourceInstanceCapacities *resource_instances, + bool allow_going_negative) { RAY_CHECK(available.size() == resource_instances->available.size()); std::vector underflow(available.size(), 0.); for (size_t i = 0; i < available.size(); i++) { - resource_instances->available[i] = resource_instances->available[i] - available[i]; if (resource_instances->available[i] < 0) { - underflow[i] = -resource_instances->available[i]; - resource_instances->available[i] = 0; + if (allow_going_negative) { + resource_instances->available[i] = + resource_instances->available[i] - available[i]; + } else { + underflow[i] = available[i]; // No change in the value in this case. + } + } else { + resource_instances->available[i] = resource_instances->available[i] - available[i]; + if (resource_instances->available[i] < 0 && !allow_going_negative) { + underflow[i] = -resource_instances->available[i]; + resource_instances->available[i] = 0; + } } } return underflow; @@ -777,7 +787,7 @@ std::vector ClusterResourceScheduler::AddCPUResourceInstances( } std::vector ClusterResourceScheduler::SubtractCPUResourceInstances( - std::vector &cpu_instances) { + std::vector &cpu_instances, bool allow_going_negative) { std::vector cpu_instances_fp = VectorDoubleToVectorFixedPoint(cpu_instances); @@ -787,7 +797,8 @@ std::vector ClusterResourceScheduler::SubtractCPUResourceInstances( RAY_CHECK(nodes_.find(local_node_id_) != nodes_.end()); auto underflow = SubtractAvailableResourceInstances( - cpu_instances_fp, &local_resources_.predefined_resources[CPU]); + cpu_instances_fp, &local_resources_.predefined_resources[CPU], + allow_going_negative); UpdateLocalAvailableResourcesFromResourceInstances(); return VectorFixedPointToVectorDouble(underflow); @@ -916,7 +927,8 @@ void ClusterResourceScheduler::FillResourceUsage( const auto &label = ResourceEnumToString((PredefinedResources)i); const auto &capacity = resources.predefined_resources[i]; const auto &last_capacity = last_report_resources_->predefined_resources[i]; - if (capacity.available != last_capacity.available) { + // Note: available may be negative, but only report positive to GCS. + if (capacity.available != last_capacity.available && capacity.available > 0) { resources_data->set_resources_available_changed(true); (*resources_data->mutable_resources_available())[label] = capacity.available.Double(); @@ -931,7 +943,8 @@ void ClusterResourceScheduler::FillResourceUsage( const auto &capacity = it->second; const auto &last_capacity = last_report_resources_->custom_resources[custom_id]; const auto &label = string_to_int_map_.Get(custom_id); - if (capacity.available != last_capacity.available) { + // Note: available may be negative, but only report positive to GCS. + if (capacity.available != last_capacity.available && capacity.available > 0) { resources_data->set_resources_available_changed(true); (*resources_data->mutable_resources_available())[label] = capacity.available.Double(); @@ -947,7 +960,8 @@ void ClusterResourceScheduler::FillResourceUsage( for (int i = 0; i < PredefinedResources_MAX; i++) { const auto &label = ResourceEnumToString((PredefinedResources)i); const auto &capacity = resources.predefined_resources[i]; - if (capacity.available != 0) { + // Note: available may be negative, but only report positive to GCS. + if (capacity.available > 0) { (*resources_data->mutable_resources_available())[label] = capacity.available.Double(); } @@ -960,7 +974,8 @@ void ClusterResourceScheduler::FillResourceUsage( uint64_t custom_id = it->first; const auto &capacity = it->second; const auto &label = string_to_int_map_.Get(custom_id); - if (capacity.available != 0) { + // Note: available may be negative, but only report positive to GCS. + if (capacity.available > 0) { (*resources_data->mutable_resources_available())[label] = capacity.available.Double(); } diff --git a/src/ray/raylet/scheduling/cluster_resource_scheduler.h b/src/ray/raylet/scheduling/cluster_resource_scheduler.h index 9e480b4c8..7d1f5253c 100644 --- a/src/ray/raylet/scheduling/cluster_resource_scheduler.h +++ b/src/ray/raylet/scheduling/cluster_resource_scheduler.h @@ -279,11 +279,13 @@ class ClusterResourceScheduler { /// /// \param free A list of capacities for resource's instances to be freed. /// \param resource_instances List of the resource instances being updated. + /// \param allow_going_negative Allow the values to go negative (disable underflow). /// \return Underflow of "resource_instances" after subtracting instance /// capacities in "available", i.e.,. /// max(available - reasource_instances.available, 0) std::vector SubtractAvailableResourceInstances( - std::vector available, ResourceInstanceCapacities *resource_instances); + std::vector available, ResourceInstanceCapacities *resource_instances, + bool allow_going_negative = false); /// Increase the available CPU instances of this node. /// @@ -296,10 +298,12 @@ class ClusterResourceScheduler { /// Decrease the available CPU instances of this node. /// /// \param cpu_instances CPU instances to be removed from available cpus. + /// \param allow_going_negative Allow the values to go negative (disable underflow). /// /// \return Underflow capacities of CPU instances after subtracting CPU /// capacities in cpu_instances. - std::vector SubtractCPUResourceInstances(std::vector &cpu_instances); + std::vector SubtractCPUResourceInstances(std::vector &cpu_instances, + bool allow_going_negative = false); /// Increase the available GPU instances of this node. /// diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index c11d818ef..09db70f1b 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -1,6 +1,7 @@ #include "ray/raylet/scheduling/cluster_task_manager.h" #include +#include #include "ray/util/logging.h" @@ -242,10 +243,7 @@ void ClusterTaskManager::TasksUnblocked(const std::vector ready_ids) { const auto &scheduling_key = task.GetTaskSpecification().GetSchedulingClass(); RAY_LOG(DEBUG) << "Args ready, task can be dispatched " << task.GetTaskSpecification().TaskId(); - // Note: we transition tasks back to the scheduling queue instead of directly - // to dispatch. This allows AnyPendingTasks() to simply check the scheduling - // queue to see if any tasks are blocked on resource availability: see #12438 - tasks_to_schedule_[scheduling_key].push_back(work); + tasks_to_dispatch_[scheduling_key].push_back(work); waiting_tasks_.erase(it); } } @@ -507,9 +505,9 @@ bool ClusterTaskManager::AnyPendingTasks(Task *exemplar, bool *any_pending, int *num_pending_actor_creation, int *num_pending_tasks) const { // We are guaranteed that these tasks are blocked waiting for resources after a - // call to ScheduleAndDispatch(). Note that tasks that transition to waiting - // move back to the tasks_to_schedule_ queue after their deps are satisfied. - for (const auto &shapes_it : tasks_to_schedule_) { + // call to ScheduleAndDispatch(). They may be waiting for workers as well, but + // this should be a transient condition only. + for (const auto &shapes_it : boost::join(tasks_to_dispatch_, tasks_to_schedule_)) { auto &work_queue = shapes_it.second; for (const auto &work_it : work_queue) { const auto &task = std::get<0>(work_it); diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index b71593f8a..aabc0a6fb 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -157,11 +157,11 @@ class ClusterTaskManager { std::unordered_map> tasks_to_schedule_; /// Queue of lease requests that should be scheduled onto workers. - /// Tasks move from scheduled -> dispatch. + /// Tasks move from scheduled | waiting -> dispatch. std::unordered_map> tasks_to_dispatch_; /// Tasks waiting for arguments to be transferred locally. - /// Tasks move (back) from waiting -> scheduled. + /// Tasks move from waiting -> dispatch. absl::flat_hash_map waiting_tasks_; /// Queue of lease requests that are infeasible. diff --git a/src/ray/raylet/test/util.h b/src/ray/raylet/test/util.h index 90fc9b158..a39bdd1c0 100644 --- a/src/ray/raylet/test/util.h +++ b/src/ray/raylet/test/util.h @@ -161,10 +161,6 @@ class MockWorker : public WorkerInterface { void ClearLifetimeAllocatedInstances() { lifetime_allocated_instances_ = nullptr; } - void SetBorrowedCPUInstances(std::vector &cpu_instances) { - borrowed_cpu_instances_ = cpu_instances; - } - const BundleID &GetBundleId() const { RAY_CHECK(false) << "Method unused"; return bundle_id_; @@ -172,10 +168,6 @@ class MockWorker : public WorkerInterface { void SetBundleId(const BundleID &bundle_id) { bundle_id_ = bundle_id; } - std::vector &GetBorrowedCPUInstances() { return borrowed_cpu_instances_; } - - void ClearBorrowedCPUInstances() { RAY_CHECK(false) << "Method unused"; } - Task &GetAssignedTask() { RAY_CHECK(false) << "Method unused"; auto *t = new Task(); diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 7cd19868e..b4c0b34ce 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -98,12 +98,6 @@ class WorkerInterface { virtual void ClearLifetimeAllocatedInstances() = 0; - virtual void SetBorrowedCPUInstances(std::vector &cpu_instances) = 0; - - virtual std::vector &GetBorrowedCPUInstances() = 0; - - virtual void ClearBorrowedCPUInstances() = 0; - virtual Task &GetAssignedTask() = 0; virtual void SetAssignedTask(const Task &assigned_task) = 0; @@ -196,14 +190,6 @@ class Worker : public WorkerInterface { void ClearLifetimeAllocatedInstances() { lifetime_allocated_instances_ = nullptr; }; - void SetBorrowedCPUInstances(std::vector &cpu_instances) { - borrowed_cpu_instances_ = cpu_instances; - }; - - std::vector &GetBorrowedCPUInstances() { return borrowed_cpu_instances_; }; - - void ClearBorrowedCPUInstances() { return borrowed_cpu_instances_.clear(); }; - Task &GetAssignedTask() { return assigned_task_; }; void SetAssignedTask(const Task &assigned_task) { assigned_task_ = assigned_task; }; @@ -273,14 +259,6 @@ class Worker : public WorkerInterface { /// The capacity of each resource instance allocated to this worker /// when running as an actor. std::shared_ptr lifetime_allocated_instances_; - /// CPUs borrowed by the worker. This happens in the following scenario: - /// 1) Worker A is blocked, so it donates its CPUs back to the node. - /// 2) Other workers are scheduled and are allocated some of the CPUs donated by A. - /// 3) Task A is unblocked, but it cannot get all CPUs back. At this point, - /// the node is oversubscribed. borrowed_cpu_instances_ represents the number - /// of CPUs this node is oversubscribed by. - /// TODO (Ion): Investigate a more intuitive alternative to track these Cpus. - std::vector borrowed_cpu_instances_; /// Task being assigned to this worker. Task assigned_task_; }; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index a7ad5c245..93a568748 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -956,8 +956,8 @@ const std::vector> WorkerPool::GetAllRegistered } void WorkerPool::WarnAboutSize() { - for (const auto &entry : states_by_lang_) { - auto state = entry.second; + for (auto &entry : states_by_lang_) { + auto &state = entry.second; int64_t num_workers_started_or_registered = 0; num_workers_started_or_registered += static_cast(state.registered_workers.size()); From ac5ea2c13d239ab952db74c33dcb69636e1b9ab8 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sat, 19 Dec 2020 10:22:12 +0800 Subject: [PATCH 29/88] [Java] Fix output parsing in RunManager (#12968) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix output parsing in RunManager * change log level Co-authored-by: 灵洵 --- .../main/java/io/ray/runtime/runner/RunManager.java | 8 ++++++-- python/ray/_private/services.py | 12 +++++------- src/ray/gcs/gcs_client/service_based_gcs_client.cc | 12 ++++++++---- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java index b6ed41f7f..4bd49deb4 100644 --- a/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/io/ray/runtime/runner/RunManager.java @@ -75,14 +75,18 @@ public class RunManager { // address info of the local node. String script = String.format("import ray;" + " print(ray._private.services.get_address_info_from_redis(" - + "'%s', '%s', redis_password='%s', no_warning=True))", + + "'%s', '%s', redis_password='%s'))", rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword); List command = Arrays.asList("python", "-c", script); String output = null; try { output = runCommand(command); - JsonObject addressInfo = new JsonParser().parse(output).getAsJsonObject(); + // NOTE(kfstorm): We only parse the last line here in case there are some warning + // messages appear at the beginning. + String[] lines = output.split(System.lineSeparator()); + String lastLine = lines[lines.length - 1]; + JsonObject addressInfo = new JsonParser().parse(lastLine).getAsJsonObject(); rayConfig.rayletSocketName = addressInfo.get("raylet_socket_name").getAsString(); rayConfig.objectStoreSocketName = addressInfo.get("object_store_address").getAsString(); rayConfig.nodeManagerPort = addressInfo.get("node_manager_port").getAsInt(); diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index c3512ab92..6d8bbb97c 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -279,8 +279,7 @@ def get_address_info_from_redis_helper(redis_address, def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5, - redis_password=None, - no_warning=False): + redis_password=None): counter = 0 while True: try: @@ -291,11 +290,10 @@ def get_address_info_from_redis(redis_address, raise # Some of the information may not be in Redis yet, so wait a little # bit. - if not no_warning: - logger.warning( - "Some processes that the driver needs to connect to have " - "not registered with Redis, so retrying. Have you run " - "'ray start' on this node?") + logger.warning( + "Some processes that the driver needs to connect to have " + "not registered with Redis, so retrying. Have you run " + "'ray start' on this node?") time.sleep(1) counter += 1 diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc index 884612106..f643496b8 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.cc +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -209,11 +209,15 @@ void ServiceBasedGcsClient::ReconnectGcsServer() { return; } - RAY_LOG(INFO) << "Attemptting to reconnect to GCS server: " << address.first << ":" - << address.second; + RAY_LOG(DEBUG) << "Attemptting to reconnect to GCS server: " << address.first << ":" + << address.second; if (Ping(address.first, address.second, 100)) { - RAY_LOG(INFO) << "Reconnected to GCS server: " << address.first << ":" - << address.second; + // If `last_reconnect_address_` port is -1, it means that this is the first + // connection and no log will be printed. + if (last_reconnect_address_.second != -1) { + RAY_LOG(INFO) << "Reconnected to GCS server: " << address.first << ":" + << address.second; + } break; } } From 404161a3ffd2d5937be98ade718d8ea48ab4c5a3 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Fri, 18 Dec 2020 18:22:45 -0800 Subject: [PATCH 30/88] [Autoscaler/Core] Remove autoscaler spam (#12952) --- .../ray/autoscaler/_private/load_metrics.py | 24 +++++++------- src/ray/gcs/gcs_server/gcs_node_manager.cc | 1 + .../gcs_server/test/gcs_node_manager_test.cc | 32 +++++++++++++++++++ 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/python/ray/autoscaler/_private/load_metrics.py b/python/ray/autoscaler/_private/load_metrics.py index 1fadeae3b..b688fe617 100644 --- a/python/ray/autoscaler/_private/load_metrics.py +++ b/python/ray/autoscaler/_private/load_metrics.py @@ -79,27 +79,27 @@ class LoadMetrics: active_ips = set(active_ips) active_ips.add(self.local_ip) - def prune(mapping): + def prune(mapping, should_log): unwanted = set(mapping) - active_ips for unwanted_key in unwanted: - # TODO (Alex): Change this back to info after #12138. - logger.debug("LoadMetrics: " - "Removed mapping: {} - {}".format( - unwanted_key, mapping[unwanted_key])) + if should_log: + logger.info("LoadMetrics: " + "Removed mapping: {} - {}".format( + unwanted_key, mapping[unwanted_key])) del mapping[unwanted_key] - if unwanted: + if unwanted and should_log: # TODO (Alex): Change this back to info after #12138. - logger.debug( + logger.info( "LoadMetrics: " "Removed {} stale ip mappings: {} not in {}".format( len(unwanted), unwanted, active_ips)) assert not (unwanted & set(mapping)) - prune(self.last_used_time_by_ip) - prune(self.static_resources_by_ip) - prune(self.dynamic_resources_by_ip) - prune(self.resource_load_by_ip) - prune(self.last_heartbeat_time_by_ip) + prune(self.last_used_time_by_ip, should_log=True) + prune(self.static_resources_by_ip, should_log=False) + prune(self.dynamic_resources_by_ip, should_log=False) + prune(self.resource_load_by_ip, should_log=False) + prune(self.last_heartbeat_time_by_ip, should_log=False) def get_node_resources(self): """Return a list of node resources (static resource sizes). diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 820d3a723..499abc90f 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -258,6 +258,7 @@ std::shared_ptr GcsNodeManager::RemoveNode( // Remove from cluster resources. gcs_resource_manager_->OnNodeDead(node_id); resources_buffer_.erase(node_id); + node_resource_usages_.erase(node_id); if (!is_intended) { // Broadcast a warning to all of the drivers indicating that the node // has been marked as dead. diff --git a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc index a904512ac..74c4b8fd1 100644 --- a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc @@ -41,11 +41,43 @@ TEST_F(GcsNodeManagerTest, TestManagement) { auto node = Mocker::GenNodeInfo(); auto node_id = NodeID::FromBinary(node->node_id()); + { + rpc::GetAllResourceUsageRequest request; + rpc::GetAllResourceUsageReply reply; + auto send_reply_callback = [](ray::Status status, std::function f1, + std::function f2) {}; + node_manager.HandleGetAllResourceUsage(request, &reply, send_reply_callback); + ASSERT_EQ(reply.resource_usage_data().batch().size(), 0); + } + node_manager.AddNode(node); ASSERT_EQ(node, node_manager.GetAliveNode(node_id).value()); + rpc::ReportResourceUsageRequest report_request; + (*report_request.mutable_resources()->mutable_resources_available())["CPU"] = 2; + (*report_request.mutable_resources()->mutable_resources_total())["CPU"] = 2; + node_manager.UpdateNodeResourceUsage(node_id, report_request); + + { + rpc::GetAllResourceUsageRequest request; + rpc::GetAllResourceUsageReply reply; + auto send_reply_callback = [](ray::Status status, std::function f1, + std::function f2) {}; + node_manager.HandleGetAllResourceUsage(request, &reply, send_reply_callback); + ASSERT_EQ(reply.resource_usage_data().batch().size(), 1); + } + node_manager.RemoveNode(node_id); ASSERT_TRUE(!node_manager.GetAliveNode(node_id).has_value()); + + { + rpc::GetAllResourceUsageRequest request; + rpc::GetAllResourceUsageReply reply; + auto send_reply_callback = [](ray::Status status, std::function f1, + std::function f2) {}; + node_manager.HandleGetAllResourceUsage(request, &reply, send_reply_callback); + ASSERT_EQ(reply.resource_usage_data().batch().size(), 0); + } } TEST_F(GcsNodeManagerTest, TestListener) { From 9d939e66742556004a40f8898e7b80424be4965f Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 18 Dec 2020 19:31:14 -0800 Subject: [PATCH 31/88] [Object Spilling] Implement level triggered logic to make streaming shuffle work + additional cleanup (#12773) --- python/ray/_raylet.pyx | 22 +- python/ray/experimental/__init__.py | 2 - python/ray/experimental/object_spilling.py | 18 -- python/ray/external_storage.py | 15 +- python/ray/includes/libcoreworker.pxd | 2 +- python/ray/tests/conftest.py | 2 +- python/ray/tests/test_object_spilling.py | 211 ++++-------------- python/ray/tests/test_reference_counting.py | 1 - src/ray/common/ray_config_def.h | 8 +- src/ray/core_worker/core_worker.cc | 7 +- src/ray/core_worker/core_worker.h | 2 +- .../memory_store/memory_store.cc | 20 +- .../memory_store/memory_store.h | 10 +- .../store_provider/plasma_store_provider.cc | 12 +- .../store_provider/plasma_store_provider.h | 3 +- src/ray/gcs/gcs_server/gcs_object_manager.cc | 7 +- src/ray/object_manager/common.h | 10 +- src/ray/object_manager/object_manager.cc | 7 + src/ray/object_manager/object_manager.h | 7 + .../plasma/create_request_queue.cc | 70 +++--- .../plasma/create_request_queue.h | 13 +- .../object_manager/plasma/eviction_policy.cc | 8 +- src/ray/object_manager/plasma/plasma.fbs | 7 +- src/ray/object_manager/plasma/protocol.cc | 3 - src/ray/object_manager/plasma/store.cc | 55 ++--- src/ray/object_manager/plasma/store.h | 25 ++- src/ray/object_manager/plasma/store_runner.cc | 12 +- src/ray/object_manager/plasma/store_runner.h | 1 + src/ray/object_manager/pull_manager.cc | 2 + .../test/create_request_queue_test.cc | 75 ++++--- src/ray/protobuf/core_worker.proto | 1 + src/ray/raylet/format/node_manager.fbs | 1 + src/ray/raylet/local_object_manager.cc | 132 +++++++---- src/ray/raylet/local_object_manager.h | 127 ++++++++--- src/ray/raylet/main.cc | 2 + src/ray/raylet/node_manager.cc | 35 ++- src/ray/raylet/node_manager.h | 25 ++- src/ray/raylet/raylet.cc | 43 ++-- .../raylet/test/local_object_manager_test.cc | 185 ++++++++++----- src/ray/raylet_client/raylet_client.cc | 4 +- src/ray/raylet_client/raylet_client.h | 5 +- 41 files changed, 654 insertions(+), 543 deletions(-) delete mode 100644 python/ray/experimental/object_spilling.py diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 8a216c7cf..4d5bb8ff9 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -638,9 +638,11 @@ cdef c_vector[c_string] spill_objects_handler( return return_urls -cdef void restore_spilled_objects_handler( +cdef int64_t restore_spilled_objects_handler( const c_vector[CObjectID]& object_ids_to_restore, const c_vector[c_string]& object_urls) nogil: + cdef: + int64_t bytes_restored = 0 with gil: urls = [] size = object_urls.size() @@ -651,7 +653,8 @@ cdef void restore_spilled_objects_handler( with ray.worker._changeproctitle( ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER, ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE): - external_storage.restore_spilled_objects(object_refs, urls) + bytes_restored = external_storage.restore_spilled_objects( + object_refs, urls) except Exception: exception_str = ( "An unexpected internal error occurred while the IO worker " @@ -662,6 +665,7 @@ cdef void restore_spilled_objects_handler( "restore_spilled_objects_error", traceback.format_exc() + exception_str, job_id=None) + return bytes_restored cdef void delete_spilled_objects_handler( @@ -873,7 +877,8 @@ cdef class CoreWorker: return self.plasma_event_handler def get_objects(self, object_refs, TaskID current_task_id, - int64_t timeout_ms=-1, plasma_objects_only=False): + int64_t timeout_ms=-1, + plasma_objects_only=False): cdef: c_vector[shared_ptr[CRayObject]] results CTaskID c_task_id = current_task_id.native() @@ -1573,17 +1578,6 @@ cdef class CoreWorker: resource_name.encode("ascii"), capacity, CNodeID.FromBinary(client_id.binary())) - def force_spill_objects(self, object_refs): - cdef c_vector[CObjectID] object_ids - object_ids = ObjectRefsToVector(object_refs) - assert not RayConfig.instance().automatic_object_deletion_enabled(), ( - "Automatic object deletion is not supported for" - "force_spill_objects yet. Please set" - "automatic_object_deletion_enabled: False in Ray's system config.") - with nogil: - check_status(CCoreWorkerProcess.GetCoreWorker() - .SpillObjects(object_ids)) - cdef void async_set_result(shared_ptr[CRayObject] obj, CObjectID object_ref, void *future) with gil: diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index c59ef2702..d8cd30f3d 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -1,6 +1,4 @@ from .dynamic_resources import set_resource -from .object_spilling import force_spill_objects __all__ = [ "set_resource", - "force_spill_objects", ] diff --git a/python/ray/experimental/object_spilling.py b/python/ray/experimental/object_spilling.py deleted file mode 100644 index def5d2353..000000000 --- a/python/ray/experimental/object_spilling.py +++ /dev/null @@ -1,18 +0,0 @@ -import ray - - -def force_spill_objects(object_refs): - """Force spilling objects to external storage. - - Args: - object_refs: Object refs of the objects to be - spilled. - """ - core_worker = ray.worker.global_worker.core_worker - # Make sure that the values are object refs. - for object_ref in object_refs: - if not isinstance(object_ref, ray.ObjectRef): - raise TypeError( - f"Attempting to call `force_spill_objects` on the " - f"value {object_ref}, which is not an ray.ObjectRef.") - return core_worker.force_spill_objects(object_refs) diff --git a/python/ray/external_storage.py b/python/ray/external_storage.py index 726065277..1b4f6fec8 100644 --- a/python/ray/external_storage.py +++ b/python/ray/external_storage.py @@ -157,12 +157,15 @@ class ExternalStorage(metaclass=abc.ABCMeta): @abc.abstractmethod def restore_spilled_objects(self, object_refs: List[ObjectRef], - url_with_offset_list: List[str]): + url_with_offset_list: List[str]) -> int: """Restore objects from the external storage. Args: object_refs: List of object IDs (note that it is not ref). url_with_offset_list: List of url_with_offset. + + Returns: + The total number of bytes restored. """ @abc.abstractmethod @@ -215,6 +218,7 @@ class FileSystemStorage(ExternalStorage): def restore_spilled_objects(self, object_refs: List[ObjectRef], url_with_offset_list: List[str]): + total = 0 for i in range(len(object_refs)): object_ref = object_refs[i] url_with_offset = url_with_offset_list[i].decode() @@ -228,9 +232,11 @@ class FileSystemStorage(ExternalStorage): metadata_len = int.from_bytes(f.read(8), byteorder="little") buf_len = int.from_bytes(f.read(8), byteorder="little") self._size_check(metadata_len, buf_len, parsed_result.size) + total += buf_len metadata = f.read(metadata_len) # read remaining data to our buffer self._put_object_to_store(metadata, buf_len, f, object_ref) + return total def delete_spilled_objects(self, urls: List[str]): for url in urls: @@ -297,6 +303,7 @@ class ExternalStorageSmartOpenImpl(ExternalStorage): def restore_spilled_objects(self, object_refs: List[ObjectRef], url_with_offset_list: List[str]): from smart_open import open + total = 0 for i in range(len(object_refs)): object_ref = object_refs[i] url_with_offset = url_with_offset_list[i].decode() @@ -315,9 +322,11 @@ class ExternalStorageSmartOpenImpl(ExternalStorage): metadata_len = int.from_bytes(f.read(8), byteorder="little") buf_len = int.from_bytes(f.read(8), byteorder="little") self._size_check(metadata_len, buf_len, parsed_result.size) + total += buf_len metadata = f.read(metadata_len) # read remaining data to our buffer self._put_object_to_store(metadata, buf_len, f, object_ref) + return total def delete_spilled_objects(self, urls: List[str]): pass @@ -367,8 +376,8 @@ def restore_spilled_objects(object_refs: List[ObjectRef], object_refs: List of object IDs (note that it is not ref). url_with_offset_list: List of url_with_offset. """ - _external_storage.restore_spilled_objects(object_refs, - url_with_offset_list) + return _external_storage.restore_spilled_objects(object_refs, + url_with_offset_list) def delete_spilled_objects(urls: List[str]): diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 7394f68b5..68c1a95b3 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -233,7 +233,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: (CRayStatus() nogil) check_signals (void() nogil) gc_collect (c_vector[c_string](const c_vector[CObjectID] &) nogil) spill_objects - (void( + (int64_t( const c_vector[CObjectID] &, const c_vector[c_string] &) nogil) restore_spilled_objects (void( diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 1c48a28d3..26d2bbdc0 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -23,7 +23,7 @@ def get_default_fixure_system_config(): "object_timeout_milliseconds": 200, "num_heartbeats_timeout": 10, "object_store_full_max_retries": 3, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, } return system_config diff --git a/python/ray/tests/test_object_spilling.py b/python/ray/tests/test_object_spilling.py index 9004fd030..624fcb85d 100644 --- a/python/ray/tests/test_object_spilling.py +++ b/python/ray/tests/test_object_spilling.py @@ -4,11 +4,9 @@ import os import random import platform import sys -import time import numpy as np import pytest -import psutil import ray from ray.external_storage import (create_url_with_offset, parse_url_with_offset) @@ -43,57 +41,6 @@ def object_spilling_config(request, tmpdir): yield json.dumps(request.param) -@pytest.mark.skip("This test is for local benchmark.") -def test_sample_benchmark(object_spilling_config, shutdown_only): - # --Config values-- - max_io_workers = 10 - object_store_limit = 500 * 1024 * 1024 - eight_mb = 1024 * 1024 - object_size = 12 * eight_mb - spill_cnt = 50 - - # Limit our object store to 200 MiB of memory. - ray.init( - object_store_memory=object_store_limit, - _system_config={ - "object_store_full_max_retries": 0, - "max_io_workers": max_io_workers, - "object_spilling_config": object_spilling_config, - "automatic_object_deletion_enabled": False, - }) - arr = np.random.rand(object_size) - replay_buffer = [] - pinned_objects = set() - - # Create objects of more than 200 MiB. - spill_start = time.perf_counter() - for _ in range(spill_cnt): - ref = None - while ref is None: - try: - ref = ray.put(arr) - replay_buffer.append(ref) - pinned_objects.add(ref) - except ray.exceptions.ObjectStoreFullError: - ref_to_spill = pinned_objects.pop() - ray.experimental.force_spill_objects([ref_to_spill]) - spill_end = time.perf_counter() - - # Make sure to remove unpinned objects. - del pinned_objects - restore_start = time.perf_counter() - while replay_buffer: - ref = replay_buffer.pop() - sample = ray.get(ref) # noqa - restore_end = time.perf_counter() - - print(f"Object spilling benchmark for the config {object_spilling_config}") - print(f"Spilling {spill_cnt} number of objects of size {object_size}B " - f"takes {spill_end - spill_start} seconds with {max_io_workers} " - "number of io workers.") - print(f"Getting all objects takes {restore_end - restore_start} seconds.") - - def test_invalid_config_raises_exception(shutdown_only): # Make sure ray.init raises an exception before # it starts processes when invalid object spilling @@ -127,123 +74,38 @@ def test_url_generation_and_parse(): @pytest.mark.skipif( platform.system() == "Windows", reason="Failing on Windows.") -def test_spill_objects_manually(object_spilling_config, shutdown_only): +def test_spilling_not_done_for_pinned_object(tmp_path, shutdown_only): # Limit our object store to 75 MiB of memory. + temp_folder = tmp_path / "spill" + temp_folder.mkdir() ray.init( object_store_memory=75 * 1024 * 1024, _system_config={ - "object_store_full_max_retries": 0, - "automatic_object_spilling_enabled": False, "max_io_workers": 4, - "object_spilling_config": object_spilling_config, + "automatic_object_spilling_enabled": True, + "object_store_full_max_retries": 4, + "object_store_full_delay_ms": 100, + "object_spilling_config": json.dumps({ + "type": "filesystem", + "params": { + "directory_path": str(temp_folder) + } + }), "min_spilling_size": 0, - "automatic_object_deletion_enabled": False, }) - arr = np.random.rand(1024 * 1024) # 8 MB data - replay_buffer = [] - pinned_objects = set() + arr = np.random.rand(5 * 1024 * 1024) # 40 MB + ref = ray.get(ray.put(arr)) # noqa + # Since the ref exists, it should raise OOM. + with pytest.raises(ray.exceptions.ObjectStoreFullError): + ref2 = ray.put(arr) # noqa - # Create objects of more than 200 MiB. - for _ in range(25): - ref = None - while ref is None: - try: - ref = ray.put(arr) - replay_buffer.append(ref) - pinned_objects.add(ref) - except ray.exceptions.ObjectStoreFullError: - ref_to_spill = pinned_objects.pop() - ray.experimental.force_spill_objects([ref_to_spill]) + def is_dir_empty(): + num_files = 0 + for path in temp_folder.iterdir(): + num_files += 1 + return num_files == 0 - def is_worker(cmdline): - return cmdline and cmdline[0].startswith("ray::") - - # Make sure io workers are spawned with proper name. - processes = [ - x.cmdline()[0] for x in psutil.process_iter(attrs=["cmdline"]) - if is_worker(x.info["cmdline"]) - ] - assert ( - ray.ray_constants.WORKER_PROCESS_TYPE_SPILL_WORKER_IDLE in processes) - - # Spill 2 more objects so we will always have enough space for - # restoring objects back. - refs_to_spill = (pinned_objects.pop(), pinned_objects.pop()) - ray.experimental.force_spill_objects(refs_to_spill) - - # randomly sample objects - for _ in range(100): - ref = random.choice(replay_buffer) - sample = ray.get(ref) - assert np.array_equal(sample, arr) - - # Make sure io workers are spawned with proper name. - processes = [ - x.cmdline()[0] for x in psutil.process_iter(attrs=["cmdline"]) - if is_worker(x.info["cmdline"]) - ] - assert ( - ray.ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE in processes) - - -@pytest.mark.skipif( - platform.system() == "Windows", reason="Failing on Windows.") -def test_spill_objects_manually_from_workers(object_spilling_config, - shutdown_only): - # Limit our object store to 100 MiB of memory. - ray.init( - object_store_memory=100 * 1024 * 1024, - _system_config={ - "object_store_full_max_retries": 0, - "automatic_object_spilling_enabled": False, - "max_io_workers": 4, - "object_spilling_config": object_spilling_config, - "min_spilling_size": 0, - "automatic_object_deletion_enabled": False, - }) - - @ray.remote - def _worker(): - arr = np.random.rand(1024 * 1024) # 8 MB data - ref = ray.put(arr) - ray.experimental.force_spill_objects([ref]) - return ref - - # Create objects of more than 200 MiB. - replay_buffer = [ray.get(_worker.remote()) for _ in range(25)] - values = {ref: np.copy(ray.get(ref)) for ref in replay_buffer} - # Randomly sample objects. - for _ in range(100): - ref = random.choice(replay_buffer) - sample = ray.get(ref) - assert np.array_equal(sample, values[ref]) - - -@pytest.mark.skip(reason="Not implemented yet.") -def test_spill_objects_manually_with_workers(object_spilling_config, - shutdown_only): - # Limit our object store to 75 MiB of memory. - ray.init( - object_store_memory=100 * 1024 * 1024, - _system_config={ - "object_store_full_max_retries": 0, - "automatic_object_spilling_enabled": False, - "max_io_workers": 4, - "object_spilling_config": object_spilling_config, - "min_spilling_size": 0, - "automatic_object_deletion_enabled": False, - }) - arrays = [np.random.rand(100 * 1024) for _ in range(50)] - objects = [ray.put(arr) for arr in arrays] - - @ray.remote - def _worker(object_refs): - ray.experimental.force_spill_objects(object_refs) - - ray.get([_worker.remote([o]) for o in objects]) - - for restored, arr in zip(ray.get(objects), arrays): - assert np.array_equal(restored, arr) + wait_for_condition(is_dir_empty) @pytest.mark.skipif( @@ -255,7 +117,7 @@ def test_spill_objects_manually_with_workers(object_spilling_config, "_system_config": { "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "max_io_workers": 4, "object_spilling_config": json.dumps({ "type": "filesystem", @@ -308,7 +170,7 @@ def test_spill_objects_automatically(object_spilling_config, shutdown_only): "max_io_workers": 4, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": object_spilling_config, "min_spilling_size": 0 }) @@ -344,7 +206,7 @@ def test_spill_during_get(object_spilling_config, shutdown_only): object_store_memory=100 * 1024 * 1024, _system_config={ "automatic_object_spilling_enabled": True, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, # NOTE(swang): Use infinite retries because the OOM timer can still # get accidentally triggered when objects are released too slowly # (see github.com/ray-project/ray/issues/12040). @@ -381,7 +243,7 @@ def test_spill_deadlock(object_spilling_config, shutdown_only): "max_io_workers": 1, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": object_spilling_config, "min_spilling_size": 0, }) @@ -411,10 +273,11 @@ def test_delete_objects(tmp_path, shutdown_only): ray.init( object_store_memory=75 * 1024 * 1024, _system_config={ - "max_io_workers": 4, + "max_io_workers": 1, + "min_spilling_size": 0, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": json.dumps({ "type": "filesystem", "params": { @@ -454,9 +317,10 @@ def test_delete_objects_delete_while_creating(tmp_path, shutdown_only): object_store_memory=75 * 1024 * 1024, _system_config={ "max_io_workers": 4, + "min_spilling_size": 0, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": json.dumps({ "type": "filesystem", "params": { @@ -506,7 +370,7 @@ def test_delete_objects_on_worker_failure(tmp_path, shutdown_only): "max_io_workers": 4, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": json.dumps({ "type": "filesystem", "params": { @@ -579,9 +443,10 @@ def test_delete_objects_multi_node(tmp_path, ray_start_cluster): object_store_memory=75 * 1024 * 1024, _system_config={ "max_io_workers": 2, + "min_spilling_size": 20 * 1024 * 1024, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": json.dumps({ "type": "filesystem", "params": { @@ -648,14 +513,14 @@ def test_fusion_objects(tmp_path, shutdown_only): # Limit our object store to 75 MiB of memory. temp_folder = tmp_path / "spill" temp_folder.mkdir() - min_spilling_size = 30 * 1024 * 1024 + min_spilling_size = 10 * 1024 * 1024 ray.init( object_store_memory=75 * 1024 * 1024, _system_config={ - "max_io_workers": 4, + "max_io_workers": 3, "automatic_object_spilling_enabled": True, "object_store_full_max_retries": 4, - "object_store_full_initial_delay_ms": 100, + "object_store_full_delay_ms": 100, "object_spilling_config": json.dumps({ "type": "filesystem", "params": { diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py index b93ee4221..a47a9a828 100644 --- a/python/ray/tests/test_reference_counting.py +++ b/python/ray/tests/test_reference_counting.py @@ -19,7 +19,6 @@ logger = logging.getLogger(__name__) @pytest.fixture def one_worker_100MiB(request): config = { - "object_store_full_max_retries": 2, "task_retry_delay_ms": 0, } yield ray.init( diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index c702102ad..f5f420463 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -243,10 +243,9 @@ RAY_CONFIG(int64_t, gcs_dump_debug_log_interval_minutes, 1) /// Maximum number of times to retry putting an object when the plasma store is full. /// Can be set to -1 to enable unlimited retries. -RAY_CONFIG(int32_t, object_store_full_max_retries, 5) +RAY_CONFIG(int32_t, object_store_full_max_retries, 1000) /// Duration to sleep after failing to put an object in plasma because it is full. -/// This will be exponentially increased for each retry. -RAY_CONFIG(uint32_t, object_store_full_initial_delay_ms, 1000) +RAY_CONFIG(uint32_t, object_store_full_delay_ms, 10) /// The amount of time to wait between logging plasma space usage debug messages. RAY_CONFIG(uint64_t, object_store_usage_log_interval_s, 10 * 60) @@ -254,6 +253,9 @@ RAY_CONFIG(uint64_t, object_store_usage_log_interval_s, 10 * 60) /// The amount of time between automatic local Python GC triggers. RAY_CONFIG(uint64_t, local_gc_interval_s, 10 * 60) +/// The min amount of time between local GCs (whether auto or mem pressure triggered). +RAY_CONFIG(uint64_t, local_gc_min_interval_s, 10) + /// Duration to wait between retries for failed tasks. RAY_CONFIG(uint32_t, task_retry_delay_ms, 5000) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9bd4bf1f4..9d7099303 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -575,7 +575,8 @@ void CoreWorker::Exit(bool intentional) { << " received, this process will exit after all outstanding tasks have finished"; exiting_ = true; // Release the resources early in case draining takes a long time. - RAY_CHECK_OK(local_raylet_client_->NotifyDirectCallTaskBlocked()); + RAY_CHECK_OK( + local_raylet_client_->NotifyDirectCallTaskBlocked(/*release_resources*/ true)); // Callback to shutdown. auto shutdown = [this, intentional]() { @@ -2369,7 +2370,9 @@ void CoreWorker::HandleRestoreSpilledObjects( for (const auto &url : request.spilled_objects_url()) { spilled_objects_url.push_back(url); } - options_.restore_spilled_objects(object_ids_to_restore, spilled_objects_url); + auto total = + options_.restore_spilled_objects(object_ids_to_restore, spilled_objects_url); + reply->set_bytes_restored_total(total); send_reply_callback(Status::OK(), nullptr, nullptr); } else { send_reply_callback( diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 4ecbe04d9..171a42d76 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -139,7 +139,7 @@ struct CoreWorkerOptions { /// Application-language callback to spill objects to external storage. std::function(const std::vector &)> spill_objects; /// Application-language callback to restore objects from external storage. - std::function &, const std::vector &)> + std::function &, const std::vector &)> restore_spilled_objects; /// Application-language callback to delete objects from external storage. std::function &, rpc::WorkerType)> diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 45eb13952..0391b7a1d 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -232,16 +232,18 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, - std::vector> *results) { + std::vector> *results, + bool release_resources) { return GetImpl(object_ids, num_objects, timeout_ms, ctx, remove_after_get, results, - /*abort_if_any_object_is_exception=*/true); + /*abort_if_any_object_is_exception=*/true, release_resources); } Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, std::vector> *results, - bool abort_if_any_object_is_exception) { + bool abort_if_any_object_is_exception, + bool release_resources) { (*results).resize(object_ids.size(), nullptr); std::shared_ptr get_request; @@ -299,7 +301,8 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, // Wait for remaining objects (or timeout). if (should_notify_raylet) { - RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked()); + // SANG-TODO Implement memory store get + RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked(release_resources)); } bool done = false; @@ -374,11 +377,11 @@ Status CoreWorkerMemoryStore::Get( const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception) { + bool *got_exception, bool release_resources) { const std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_RETURN_NOT_OK(Get(id_vector, id_vector.size(), timeout_ms, ctx, - /*remove_after_get=*/false, &result_objects)); + /*remove_after_get=*/false, &result_objects, release_resources)); for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { @@ -401,8 +404,9 @@ Status CoreWorkerMemoryStore::Wait(const absl::flat_hash_set &object_i std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); - auto status = GetImpl(id_vector, num_objects, timeout_ms, ctx, false, &result_objects, - /*abort_if_any_object_is_exception=*/false); + auto status = + GetImpl(id_vector, num_objects, timeout_ms, ctx, false, &result_objects, + /*abort_if_any_object_is_exception=*/false, /*release_resources=*/true); // Ignore TimedOut statuses since we return ready objects explicitly. if (!status.IsTimedOut()) { RAY_RETURN_NOT_OK(status); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 8d7bdfe65..faadafaff 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -58,13 +58,14 @@ class CoreWorkerMemoryStore { /// \return Status. Status Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, - std::vector> *results); + std::vector> *results, + bool release_resources = true); /// Convenience wrapper around Get() that stores results in a given result map. Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception); + bool *got_exception, bool release_resources = true); /// Convenience wrapper around Get() that stores ready objects in a given result set. Status Wait(const absl::flat_hash_set &object_ids, int num_objects, @@ -137,11 +138,12 @@ class CoreWorkerMemoryStore { private: /// See the public version of `Get` for meaning of the other arguments. /// \param[in] abort_if_any_object_is_exception Whether we should abort if any object - /// is an exception. + /// \param[in] release_resources true if memory store blocking get needs to release + /// resources. is an exception. Status GetImpl(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, std::vector> *results, - bool abort_if_any_object_is_exception); + bool abort_if_any_object_is_exception, bool release_resources); /// Optional callback for putting objects into the plasma store. std::function store_in_plasma_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 5dca72612..25007a863 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -35,6 +35,7 @@ CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( } else { get_current_call_site_ = []() { return ""; }; } + object_store_full_delay_ms_ = RayConfig::instance().object_store_full_delay_ms(); buffer_tracker_ = std::make_shared(); RAY_CHECK_OK(store_client_.Connect(store_socket)); if (warmup) { @@ -95,7 +96,8 @@ Status CoreWorkerPlasmaStoreProvider::Create(const std::shared_ptr &meta } while (retry_with_request_id > 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); + // TODO(sang): Use exponential backoff instead. + std::this_thread::sleep_for(std::chrono::milliseconds(object_store_full_delay_ms_)); { std::lock_guard guard(store_client_mutex_); RAY_LOG(DEBUG) << "Retrying request for object " << object_id << " with request ID " @@ -224,7 +226,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception) { + bool *got_exception, bool release_resources) { int64_t batch_size = RayConfig::instance().worker_fetch_request_size(); std::vector batch_ids; absl::flat_hash_set remaining(object_ids.begin(), object_ids.end()); @@ -275,7 +277,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( size_t previous_size = remaining.size(); // This is a separate IPC from the FetchAndGet in direct call mode. if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { - RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); + RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked(release_resources)); } RAY_RETURN_NOT_OK( FetchAndGetFromPlasmaStore(remaining, batch_ids, batch_timeout, @@ -332,7 +334,9 @@ Status CoreWorkerPlasmaStoreProvider::Wait( // This is a separate IPC from the Wait in direct call mode. if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { - RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked()); + // SANG-TODO Implement wait + RAY_RETURN_NOT_OK( + raylet_client_->NotifyDirectCallTaskBlocked(/*release_resources*/ true)); } const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector); RAY_RETURN_NOT_OK( diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index ef9f4f850..88bed0428 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -90,7 +90,7 @@ class CoreWorkerPlasmaStoreProvider { Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception); + bool *got_exception, bool release_resources = true); Status Contains(const ObjectID &object_id, bool *has_object); @@ -154,6 +154,7 @@ class CoreWorkerPlasmaStoreProvider { std::mutex store_client_mutex_; std::function check_signals_; std::function get_current_call_site_; + uint32_t object_store_full_delay_ms_; // Active buffers tracker. This must be allocated as a separate structure since its // lifetime can exceed that of the store provider due to callback references. diff --git a/src/ray/gcs/gcs_server/gcs_object_manager.cc b/src/ray/gcs/gcs_server/gcs_object_manager.cc index 471bc896b..b5cc8f765 100644 --- a/src/ray/gcs/gcs_server/gcs_object_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_object_manager.cc @@ -72,7 +72,9 @@ void GcsObjectManager::HandleAddObjectLocation( AddObjectLocationInCache(object_id, node_id); } else { absl::MutexLock lock(&mutex_); - object_to_locations_[object_id].spilled_url = request.spilled_url(); + RAY_CHECK(!request.spilled_url().empty()); + spilled_url = request.spilled_url(); + object_to_locations_[object_id].spilled_url = spilled_url; RAY_LOG(DEBUG) << "Adding object spilled location, object id = " << object_id; } @@ -91,7 +93,8 @@ void GcsObjectManager::HandleAddObjectLocation( notification.SerializeAsString(), nullptr)); RAY_LOG(DEBUG) << "Finished adding object location, job id = " << object_id.TaskId().JobId() << ", object id = " << object_id - << ", node id = " << node_id << ", task id = " << object_id.TaskId(); + << ", node id = " << node_id << ", task id = " << object_id.TaskId() + << ", spilled_url = " << spilled_url; } else { RAY_LOG(ERROR) << "Failed to add object location: " << status.ToString() << ", job id = " << object_id.TaskId().JobId() diff --git a/src/ray/object_manager/common.h b/src/ray/object_manager/common.h index c7c531ffc..9c71e2c2b 100644 --- a/src/ray/object_manager/common.h +++ b/src/ray/object_manager/common.h @@ -10,14 +10,8 @@ namespace ray { /// A callback to asynchronously spill objects when space is needed. -/// The callback tries to spill objects as much as num_bytes_to_spill and returns -/// the amount of space needed after the spilling is complete. -/// The returned value is calculated based off of min_bytes_to_spill. That says, -/// although it fails to spill num_bytes_to_spill, as long as it spills more than -/// min_bytes_to_spill, it will return the value that is less than 0 (meaning we -/// don't need any more additional space). -using SpillObjectsCallback = - std::function; +/// It spills enough objects to saturate all spill IO workers. +using SpillObjectsCallback = std::function; /// A callback to call when space has been released. using SpaceReleasedCallback = std::function; diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 3d777be12..760909dc0 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -118,6 +118,13 @@ void ObjectManager::Stop() { } } +bool ObjectManager::IsPlasmaObjectSpillable(const ObjectID &object_id) { + if (plasma::plasma_store_runner != nullptr) { + return plasma::plasma_store_runner->IsPlasmaObjectSpillable(object_id); + } + return false; +} + void ObjectManager::RunRpcService() { rpc_service_.run(); } void ObjectManager::StartRpcService() { diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 9579df30e..fdca6a190 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -206,6 +206,13 @@ class ObjectManager : public ObjectManagerInterface, /// signals from Raylet. void Stop(); + /// This methods call the plasma store which runs in a separate thread. + /// Check if the given object id is evictable by directly calling plasma store. + /// Plasma store will return true if the object is spillable, meaning it is only + /// pinned by the raylet, so we can comfotable evict after spilling the object from + /// local object manager. False otherwise. + bool IsPlasmaObjectSpillable(const ObjectID &object_id); + /// Subscribe to notifications of objects added to local store. /// Upon subscribing, the callback will be invoked for all objects that /// diff --git a/src/ray/object_manager/plasma/create_request_queue.cc b/src/ray/object_manager/plasma/create_request_queue.cc index 500fb1659..917a72102 100644 --- a/src/ray/object_manager/plasma/create_request_queue.cc +++ b/src/ray/object_manager/plasma/create_request_queue.cc @@ -69,17 +69,21 @@ std::pair CreateRequestQueue::TryRequestImmediately( auto req_id = AddRequest(object_id, client, create_callback); if (!ProcessRequests().ok()) { // If the request was not immediately fulfillable, finish it. - RAY_CHECK(!queue_.empty()); - FinishRequest(queue_.begin()); + if (!queue_.empty()) { + // Some errors such as a transient OOM error doesn't finish the request, so we + // should finish it here. + FinishRequest(queue_.begin()); + } } PlasmaError error; RAY_CHECK(GetRequestResult(req_id, &result, &error)); return {result, error}; } -Status CreateRequestQueue::ProcessRequest(std::unique_ptr &request) { +bool CreateRequestQueue::ProcessRequest(std::unique_ptr &request) { // Return an OOM error to the client if we have hit the maximum number of // retries. + // TODO(sang): Delete this logic? bool evict_if_full = evict_if_full_; if (max_retries_ == 0) { // If we cannot retry, then always evict on the first attempt. @@ -88,50 +92,36 @@ Status CreateRequestQueue::ProcessRequest(std::unique_ptr &reques // Always try to evict after the first attempt. evict_if_full = true; } - request->error = request->create_callback(evict_if_full, &request->result); - Status status; - auto should_retry_on_oom = max_retries_ == -1 || num_retries_ < max_retries_; - if (request->error == PlasmaError::TransientOutOfMemory) { - // The object store is full, but we should wait for space to be made - // through spilling, so do nothing. The caller must guarantee that - // ProcessRequests is called again so that we can try this request again. - // NOTE(swang): There could be other requests behind this one that are - // actually serviceable. This may be inefficient, but eventually this - // request will get served and unblock the following requests, once - // enough objects have been spilled. - // TODO(swang): Ask the raylet to spill enough space for multiple requests - // at once, instead of just the head of the queue. - num_retries_ = 0; - status = - Status::TransientObjectStoreFull("Object store full, queueing creation request"); - } else if (request->error == PlasmaError::OutOfMemory && should_retry_on_oom) { - num_retries_++; - RAY_LOG(DEBUG) << "Not enough memory to create the object, after " << num_retries_ - << " tries"; - - if (trigger_global_gc_) { - trigger_global_gc_(); - } - - status = Status::ObjectStoreFull("Object store full, should retry on timeout"); - } else if (request->error == PlasmaError::OutOfMemory) { - RAY_LOG(ERROR) << "Not enough memory to create object " << request->object_id - << " after " << num_retries_ - << " tries, will return OutOfMemory to the client"; - } - - return status; + return request->error != PlasmaError::OutOfMemory; } Status CreateRequestQueue::ProcessRequests() { while (!queue_.empty()) { auto request_it = queue_.begin(); - auto status = ProcessRequest(*request_it); - if (status.IsTransientObjectStoreFull() || status.IsObjectStoreFull()) { - return status; + auto create_ok = ProcessRequest(*request_it); + if (create_ok) { + FinishRequest(request_it); + } else { + if (trigger_global_gc_) { + trigger_global_gc_(); + } + + if (spill_objects_callback_()) { + return Status::TransientObjectStoreFull("Waiting for spilling."); + } else if (num_retries_ < max_retries_ || max_retries_ == -1) { + // We need a grace period since (1) global GC takes a bit of time to + // kick in, and (2) there is a race between spilling finishing and space + // actually freeing up in the object store. + // If max_retries == -1, we retry infinitely. + num_retries_ += 1; + return Status::ObjectStoreFull("Waiting for grace period."); + } else { + // Raise OOM. In this case, the request will be marked as OOM. + // We don't return so that we can process the next entry right away. + FinishRequest(request_it); + } } - FinishRequest(request_it); } return Status::OK(); } diff --git a/src/ray/object_manager/plasma/create_request_queue.h b/src/ray/object_manager/plasma/create_request_queue.h index f7d21fb97..212ce69e6 100644 --- a/src/ray/object_manager/plasma/create_request_queue.h +++ b/src/ray/object_manager/plasma/create_request_queue.h @@ -21,6 +21,7 @@ #include "absl/container/flat_hash_map.h" #include "ray/common/status.h" +#include "ray/object_manager/common.h" #include "ray/object_manager/plasma/common.h" #include "ray/object_manager/plasma/connection.h" #include "ray/object_manager/plasma/plasma.h" @@ -34,9 +35,11 @@ class CreateRequestQueue { std::function; CreateRequestQueue(int32_t max_retries, bool evict_if_full, + ray::SpillObjectsCallback spill_objects_callback, std::function trigger_global_gc) : max_retries_(max_retries), evict_if_full_(evict_if_full), + spill_objects_callback_(spill_objects_callback), trigger_global_gc_(trigger_global_gc) { RAY_LOG(DEBUG) << "Starting plasma::CreateRequestQueue with " << max_retries_ << " retries on OOM, evict if full? " << (evict_if_full_ ? 1 : 0); @@ -136,7 +139,7 @@ class CreateRequestQueue { /// Process a single request. Sets the request's error result to the error /// returned by the request handler inside. Returns OK if the request can be /// finished. - Status ProcessRequest(std::unique_ptr &request); + bool ProcessRequest(std::unique_ptr &request); /// Finish a queued request and remove it from the queue. void FinishRequest(std::list>::iterator request_it); @@ -156,6 +159,11 @@ class CreateRequestQueue { /// always try to evict. const bool evict_if_full_; + /// A callback to trigger object spilling. It tries to spill objects upto max + /// throughput. It returns true if space is made by object spilling, and false if + /// there's no more space to be made. + ray::SpillObjectsCallback spill_objects_callback_; + /// A callback to trigger global GC in the cluster if the object store is /// full. const std::function trigger_global_gc_; @@ -178,6 +186,9 @@ class CreateRequestQueue { /// finished. absl::flat_hash_map> fulfilled_requests_; + /// Last time global gc was invoked in ms. + uint64_t last_global_gc_ms_; + friend class CreateRequestQueueTest; }; diff --git a/src/ray/object_manager/plasma/eviction_policy.cc b/src/ray/object_manager/plasma/eviction_policy.cc index 2889a0760..0920a386f 100644 --- a/src/ray/object_manager/plasma/eviction_policy.cc +++ b/src/ray/object_manager/plasma/eviction_policy.cc @@ -132,10 +132,10 @@ int64_t EvictionPolicy::RequireSpace(int64_t size, RAY_LOG(DEBUG) << "not enough space to create this object, so evicting objects"; // Choose some objects to evict, and update the return pointers. int64_t num_bytes_evicted = ChooseObjectsToEvict(space_to_free, objects_to_evict); - RAY_LOG(INFO) << "There is not enough space to create this object, so evicting " - << objects_to_evict->size() << " objects to free up " << num_bytes_evicted - << " bytes. The number of bytes in use (before " - << "this eviction) is " << PlasmaAllocator::Allocated() << "."; + RAY_LOG(DEBUG) << "There is not enough space to create this object, so evicting " + << objects_to_evict->size() << " objects to free up " + << num_bytes_evicted << " bytes. The number of bytes in use (before " + << "this eviction) is " << PlasmaAllocator::Allocated() << "."; return required_space - num_bytes_evicted; } diff --git a/src/ray/object_manager/plasma/plasma.fbs b/src/ray/object_manager/plasma/plasma.fbs index ff8099ea7..3816de79e 100644 --- a/src/ray/object_manager/plasma/plasma.fbs +++ b/src/ray/object_manager/plasma/plasma.fbs @@ -82,11 +82,6 @@ enum PlasmaError:int { ObjectNonexistent, // Trying to create an object but there isn't enough space in the store. OutOfMemory, - // Trying to create an object but there isn't enough space in the store. - // However, objects are currently being spilled to make enough space. The - // client should try again soon, and there will be enough space (assuming the - // space is not taken by another client). - TransientOutOfMemory, // Trying to delete an object but it's not sealed. ObjectNotSealed, // Trying to delete an object but it's in use. @@ -162,7 +157,7 @@ table PlasmaCreateRetryRequest { object_id: string; // The ID of the request to retry. request_id: uint64; - } +} table CudaHandle { handle: [ubyte]; diff --git a/src/ray/object_manager/plasma/protocol.cc b/src/ray/object_manager/plasma/protocol.cc index e80d0fc05..497bb6907 100644 --- a/src/ray/object_manager/plasma/protocol.cc +++ b/src/ray/object_manager/plasma/protocol.cc @@ -131,9 +131,6 @@ Status PlasmaErrorStatus(fb::PlasmaError plasma_error) { return Status::ObjectNotFound("object does not exist in the plasma store"); case fb::PlasmaError::OutOfMemory: return Status::ObjectStoreFull("object does not fit in the plasma store"); - case fb::PlasmaError::TransientOutOfMemory: - return Status::ObjectStoreFull( - "object does not fit in the plasma store, spilling objects to make room"); case fb::PlasmaError::UnexpectedError: return Status::UnknownError( "an unexpected error occurred, likely due to a bug in the system or caller"); diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc index 0e1fa3af2..a3e5fc019 100644 --- a/src/ray/object_manager/plasma/store.cc +++ b/src/ray/object_manager/plasma/store.cc @@ -138,7 +138,7 @@ PlasmaStore::PlasmaStore(boost::asio::io_service &main_service, std::string dire create_request_queue_( RayConfig::instance().object_store_full_max_retries(), /*evict_if_full=*/RayConfig::instance().object_pinning_enabled(), - object_store_full_callback) { + spill_objects_callback, object_store_full_callback) { store_info_.directory = directory; store_info_.hugepages_enabled = hugepages_enabled; #ifdef PLASMA_CUDA @@ -223,34 +223,7 @@ uint8_t *PlasmaStore::AllocateMemory(size_t size, bool evict_if_full, MEMFD_TYPE // More space is still needed. Try to spill objects to external storage to // make room. if (space_needed > 0) { - if (spill_objects_callback_) { - // If the space needed is too small, we'd like to bump up to the minimum - // size. Cap the max size to be lower than the plasma store limit. - int64_t byte_to_spill = - std::min(PlasmaAllocator::GetFootprintLimit(), - std::max(space_needed, RayConfig::instance().min_spilling_size())); - // Object spilling is asynchronous so that we do not block the plasma - // store thread. Therefore the client must try again, even if enough - // space will be made after the spill is complete. - // TODO(swang): Only respond to the client with OutOfMemory if we could not - // make enough space through spilling. If we could make enough space, - // respond to the plasma client once spilling is complete. - space_needed = spill_objects_callback_(byte_to_spill, space_needed); - } - if (space_needed > 0) { - // There is still not enough space, even once all evictable objects - // were evicted and all pending object spills have finished. The - // client may choose to try again, or throw an OutOfMemory error to - // the application immediately. - *error = PlasmaError::OutOfMemory; - } else { - // Once all pending object spills have finished, there should be - // enough space for this allocation. Return a transient error to the - // client so that they try again soon. - *error = PlasmaError::TransientOutOfMemory; - } - // Return an error to the client if not enough space could be freed to - // create the object. + *error = PlasmaError::OutOfMemory; break; } } @@ -311,9 +284,8 @@ PlasmaError PlasmaStore::HandleCreateObjectRequest(const std::shared_ptr owner_worker_id, evict_if_full, data_size, metadata_size, device_num, client, object); if (error == PlasmaError::OutOfMemory) { - RAY_LOG(WARNING) << "Not enough memory to create the object " << object_id - << ", data_size=" << data_size - << ", metadata_size=" << metadata_size; + RAY_LOG(DEBUG) << "Not enough memory to create the object " << object_id + << ", data_size=" << data_size << ", metadata_size=" << metadata_size; } return error; } @@ -551,8 +523,8 @@ void PlasmaStore::ProcessGetRequest(const std::shared_ptr &client, std::vector evicted_ids; std::vector evicted_entries; for (auto object_id : object_ids) { - // Check if this object is already present locally. If so, record that the - // object is being used and mark it as accounted for. + // Check if this object is already present + // locally. If so, record that the object is being used and mark it as accounted for. auto entry = GetObjectTableEntry(&store_info_, object_id); if (entry && entry->state == ObjectState::PLASMA_SEALED) { // Update the get request to take into account the present object. @@ -972,6 +944,9 @@ void PlasmaStore::SubscribeToUpdates(const std::shared_ptr &client) { Status PlasmaStore::ProcessMessage(const std::shared_ptr &client, fb::MessageType type, const std::vector &message) { + // Global lock is used here so that we allow raylet to access some of methods + // that are required for object spilling directly without releasing a lock. + std::lock_guard guard(mutex_); // TODO(suquark): We should convert these interfaces to const later. uint8_t *input = (uint8_t *)message.data(); size_t input_size = message.size(); @@ -1116,9 +1091,7 @@ void PlasmaStore::ProcessCreateRequests() { auto status = create_request_queue_.ProcessRequests(); uint32_t retry_after_ms = 0; - if (status.IsTransientObjectStoreFull()) { - retry_after_ms = delay_on_transient_oom_ms_; - } else if (status.IsObjectStoreFull()) { + if (!status.ok()) { retry_after_ms = delay_on_oom_ms_; } @@ -1151,4 +1124,12 @@ void PlasmaStore::ReplyToCreateClient(const std::shared_ptr &client, } } +bool PlasmaStore::IsObjectSpillable(const ObjectID &object_id) { + // The lock is acquired when a request is received to the plasma store. + // recursive mutex is used here to allow + std::lock_guard guard(mutex_); + auto entry = GetObjectTableEntry(&store_info_, object_id); + return entry->ref_count == 1; +} + } // namespace plasma diff --git a/src/ray/object_manager/plasma/store.h b/src/ray/object_manager/plasma/store.h index 5ef2cd654..b6494e6bc 100644 --- a/src/ray/object_manager/plasma/store.h +++ b/src/ray/object_manager/plasma/store.h @@ -99,9 +99,6 @@ class PlasmaStore { /// - PlasmaError::OutOfMemory, if the store is out of memory and /// cannot create the object. In this case, the client should not call /// plasma_release. - /// - PlasmaError::TransientOutOfMemory, if the store is temporarily out of - /// memory but there may be space soon to create the object. In this - /// case, the client should not call plasma_release. PlasmaError CreateObject(const ObjectID &object_id, const NodeID &owner_raylet_id, const std::string &owner_ip_address, int owner_port, const WorkerID &owner_worker_id, bool evict_if_full, @@ -186,6 +183,14 @@ class PlasmaStore { plasma::flatbuf::MessageType type, const std::vector &message); + /// Return true if the given object id has only one reference. + /// Only one reference means there's only a raylet that pins the object + /// so it is safe to spill the object. + /// NOTE: Avoid using this method outside object spilling context (e.g., unless you + /// absolutely know what's going on). This method won't work correctly if it is used + /// before the object is pinned by raylet for the first time. + bool IsObjectSpillable(const ObjectID &object_id); + void SetNotificationListener( const std::shared_ptr ¬ification_listener) { notification_listener_ = notification_listener; @@ -286,16 +291,14 @@ class PlasmaStore { /// A callback to asynchronously spill objects when space is needed. The /// callback returns the amount of space still needed after the spilling is /// complete. + /// NOTE: This function should guarantee the thread-safety because the callback is + /// shared with the main raylet thread. ray::SpillObjectsCallback spill_objects_callback_; /// The amount of time to wait before retrying a creation request after an /// OOM error. const uint32_t delay_on_oom_ms_; - /// The amount of time to wait before retrying a creation request after a - /// transient OOM error. - const uint32_t delay_on_transient_oom_ms_ = 10; - /// The amount of time to wait between logging space usage debug messages. const uint64_t usage_log_interval_ns_; @@ -309,6 +312,14 @@ class PlasmaStore { /// Queue of object creation requests. CreateRequestQueue create_request_queue_; + + /// This mutex is used in order to make plasma store threas-safe with raylet. + /// Raylet's local_object_manager needs to ping access plasma store's method in order to + /// figure out the correct view of the object store. recursive_mutex is used to avoid + /// deadlock while we keep the simplest possible change. NOTE(sang): Avoid adding more + /// interface that node manager or object manager can access the plasma store with this + /// mutex if it is not absolutely necessary. + std::recursive_mutex mutex_; }; } // namespace plasma diff --git a/src/ray/object_manager/plasma/store_runner.cc b/src/ray/object_manager/plasma/store_runner.cc index 152e386aa..1fc6a0662 100644 --- a/src/ray/object_manager/plasma/store_runner.cc +++ b/src/ray/object_manager/plasma/store_runner.cc @@ -94,10 +94,10 @@ void PlasmaStoreRunner::Start(ray::SpillObjectsCallback spill_objects_callback, { absl::MutexLock lock(&store_runner_mutex_); - store_.reset(new PlasmaStore( - main_service_, plasma_directory_, hugepages_enabled_, socket_name_, - external_store, RayConfig::instance().object_store_full_initial_delay_ms(), - spill_objects_callback, object_store_full_callback)); + store_.reset(new PlasmaStore(main_service_, plasma_directory_, hugepages_enabled_, + socket_name_, external_store, + RayConfig::instance().object_store_full_delay_ms(), + spill_objects_callback, object_store_full_callback)); plasma_config = store_->GetPlasmaStoreInfo(); // We are using a single memory-mapped file by mallocing and freeing a single @@ -134,6 +134,10 @@ void PlasmaStoreRunner::Shutdown() { } } +bool PlasmaStoreRunner::IsPlasmaObjectSpillable(const ObjectID &object_id) { + return store_->IsObjectSpillable(object_id); +} + std::unique_ptr plasma_store_runner; } // namespace plasma diff --git a/src/ray/object_manager/plasma/store_runner.h b/src/ray/object_manager/plasma/store_runner.h index 2f6a61cd5..07317c25d 100644 --- a/src/ray/object_manager/plasma/store_runner.h +++ b/src/ray/object_manager/plasma/store_runner.h @@ -22,6 +22,7 @@ class PlasmaStoreRunner { const std::shared_ptr ¬ification_listener) { store_->SetNotificationListener(notification_listener); } + bool IsPlasmaObjectSpillable(const ObjectID &object_id); private: void Shutdown(); diff --git a/src/ray/object_manager/pull_manager.cc b/src/ray/object_manager/pull_manager.cc index 082426cc1..7632c5c7b 100644 --- a/src/ray/object_manager/pull_manager.cc +++ b/src/ray/object_manager/pull_manager.cc @@ -46,6 +46,8 @@ void PullManager::OnLocationChange(const ObjectID &object_id, // before. it->second.client_locations = std::vector(client_ids.begin(), client_ids.end()); if (!spilled_url.empty()) { + RAY_LOG(DEBUG) << "OnLocationChange " << spilled_url << " num clients " + << client_ids.size(); // Try to restore the spilled object. restore_spilled_object_(object_id, spilled_url, [this, object_id](const ray::Status &status) { diff --git a/src/ray/object_manager/test/create_request_queue_test.cc b/src/ray/object_manager/test/create_request_queue_test.cc index 7d16d0b80..de60807a6 100644 --- a/src/ray/object_manager/test/create_request_queue_test.cc +++ b/src/ray/object_manager/test/create_request_queue_test.cc @@ -49,6 +49,7 @@ class CreateRequestQueueTest : public ::testing::Test { : queue_( /*max_retries=*/2, /*evict_if_full=*/true, + /*spill_object_callback=*/[&]() { return false; }, /*on_global_gc=*/[&]() { num_global_gc_++; }) {} void AssertNoLeaks() { @@ -117,7 +118,7 @@ TEST_F(CreateRequestQueueTest, TestOom) { // Retries used up. The first request should reply with OOM and the second // request should also be served. ASSERT_TRUE(queue_.ProcessRequests().ok()); - ASSERT_EQ(num_global_gc_, 2); + ASSERT_EQ(num_global_gc_, 3); // Both requests fulfilled. ASSERT_REQUEST_FINISHED(queue_, req_id1, PlasmaError::OutOfMemory); @@ -131,6 +132,8 @@ TEST(CreateRequestQueueParameterTest, TestOomInfiniteRetry) { CreateRequestQueue queue( /*max_retries=*/-1, /*evict_if_full=*/true, + // Spilling is failing. + /*spill_object_callback=*/[&]() { return false; }, /*on_global_gc=*/[&]() { num_global_gc_++; }); auto oom_request = [&](bool evict_if_full, PlasmaObject *result) { @@ -156,7 +159,13 @@ TEST(CreateRequestQueueParameterTest, TestOomInfiniteRetry) { } TEST_F(CreateRequestQueueTest, TestTransientOom) { - auto return_status = PlasmaError::TransientOutOfMemory; + CreateRequestQueue queue( + /*max_retries=*/2, + /*evict_if_full=*/false, + /*spill_object_callback=*/[&]() { return true; }, + /*on_global_gc=*/[&]() { num_global_gc_++; }); + + auto return_status = PlasmaError::OutOfMemory; auto oom_request = [&](bool evict_if_full, PlasmaObject *result) { if (return_status == PlasmaError::OK) { result->data_size = 1234; @@ -169,28 +178,35 @@ TEST_F(CreateRequestQueueTest, TestTransientOom) { }; auto client = std::make_shared(); - auto req_id1 = queue_.AddRequest(ObjectID::Nil(), client, oom_request); - auto req_id2 = queue_.AddRequest(ObjectID::Nil(), client, blocked_request); + auto req_id1 = queue.AddRequest(ObjectID::Nil(), client, oom_request); + auto req_id2 = queue.AddRequest(ObjectID::Nil(), client, blocked_request); // Transient OOM should not use up any retries. for (int i = 0; i < 3; i++) { - ASSERT_TRUE(queue_.ProcessRequests().IsTransientObjectStoreFull()); - ASSERT_REQUEST_UNFINISHED(queue_, req_id1); - ASSERT_REQUEST_UNFINISHED(queue_, req_id2); - ASSERT_EQ(num_global_gc_, 0); + ASSERT_TRUE(queue.ProcessRequests().IsTransientObjectStoreFull()); + ASSERT_REQUEST_UNFINISHED(queue, req_id1); + ASSERT_REQUEST_UNFINISHED(queue, req_id2); + ASSERT_EQ(num_global_gc_, i + 1); } // Return OK for the first request. The second request should also be served. return_status = PlasmaError::OK; - ASSERT_TRUE(queue_.ProcessRequests().ok()); - ASSERT_REQUEST_FINISHED(queue_, req_id1, PlasmaError::OK); - ASSERT_REQUEST_FINISHED(queue_, req_id2, PlasmaError::OK); + ASSERT_TRUE(queue.ProcessRequests().ok()); + ASSERT_REQUEST_FINISHED(queue, req_id1, PlasmaError::OK); + ASSERT_REQUEST_FINISHED(queue, req_id2, PlasmaError::OK); AssertNoLeaks(); } TEST_F(CreateRequestQueueTest, TestTransientOomThenOom) { - auto return_status = PlasmaError::TransientOutOfMemory; + bool is_spilling_possible = true; + CreateRequestQueue queue( + /*max_retries=*/2, + /*evict_if_full=*/false, + /*spill_object_callback=*/[&]() { return is_spilling_possible; }, + /*on_global_gc=*/[&]() { num_global_gc_++; }); + + auto return_status = PlasmaError::OutOfMemory; auto oom_request = [&](bool evict_if_full, PlasmaObject *result) { if (return_status == PlasmaError::OK) { result->data_size = 1234; @@ -203,31 +219,31 @@ TEST_F(CreateRequestQueueTest, TestTransientOomThenOom) { }; auto client = std::make_shared(); - auto req_id1 = queue_.AddRequest(ObjectID::Nil(), client, oom_request); - auto req_id2 = queue_.AddRequest(ObjectID::Nil(), client, blocked_request); + auto req_id1 = queue.AddRequest(ObjectID::Nil(), client, oom_request); + auto req_id2 = queue.AddRequest(ObjectID::Nil(), client, blocked_request); // Transient OOM should not use up any retries. for (int i = 0; i < 3; i++) { - ASSERT_TRUE(queue_.ProcessRequests().IsTransientObjectStoreFull()); - ASSERT_REQUEST_UNFINISHED(queue_, req_id1); - ASSERT_REQUEST_UNFINISHED(queue_, req_id2); - ASSERT_EQ(num_global_gc_, 0); + ASSERT_TRUE(queue.ProcessRequests().IsTransientObjectStoreFull()); + ASSERT_REQUEST_UNFINISHED(queue, req_id1); + ASSERT_REQUEST_UNFINISHED(queue, req_id2); + ASSERT_EQ(num_global_gc_, i + 1); } - // Now we are actually OOM. - return_status = PlasmaError::OutOfMemory; - ASSERT_TRUE(queue_.ProcessRequests().IsObjectStoreFull()); - ASSERT_TRUE(queue_.ProcessRequests().IsObjectStoreFull()); - ASSERT_REQUEST_UNFINISHED(queue_, req_id1); - ASSERT_REQUEST_UNFINISHED(queue_, req_id2); - ASSERT_EQ(num_global_gc_, 2); + // Now spilling is not possible. We should start raising OOM with retry. + is_spilling_possible = false; + ASSERT_TRUE(queue.ProcessRequests().IsObjectStoreFull()); + ASSERT_TRUE(queue.ProcessRequests().IsObjectStoreFull()); + ASSERT_REQUEST_UNFINISHED(queue, req_id1); + ASSERT_REQUEST_UNFINISHED(queue, req_id2); + ASSERT_EQ(num_global_gc_, 5); // Retries used up. The first request should reply with OOM and the second // request should also be served. - ASSERT_TRUE(queue_.ProcessRequests().ok()); - ASSERT_REQUEST_FINISHED(queue_, req_id1, PlasmaError::OutOfMemory); - ASSERT_REQUEST_FINISHED(queue_, req_id2, PlasmaError::OK); - ASSERT_EQ(num_global_gc_, 2); + ASSERT_TRUE(queue.ProcessRequests().ok()); + ASSERT_REQUEST_FINISHED(queue, req_id1, PlasmaError::OutOfMemory); + ASSERT_REQUEST_FINISHED(queue, req_id2, PlasmaError::OK); + ASSERT_EQ(num_global_gc_, 6); AssertNoLeaks(); } @@ -248,6 +264,7 @@ TEST(CreateRequestQueueParameterTest, TestNoEvictIfFull) { CreateRequestQueue queue( /*max_retries=*/2, /*evict_if_full=*/false, + /*spill_object_callback=*/[&]() { return false; }, /*on_global_gc=*/[&]() {}); bool first_try = true; diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index c95a58106..799530d27 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -315,6 +315,7 @@ message RestoreSpilledObjectsRequest { } message RestoreSpilledObjectsReply { + int64 bytes_restored_total = 1; } message DeleteSpilledObjectsRequest { diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index c62754b75..fb95bbc61 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -194,6 +194,7 @@ table NotifyUnblocked { } table NotifyDirectCallTaskBlocked { + release_resources: bool; } table NotifyDirectCallTaskUnblocked { diff --git a/src/ray/raylet/local_object_manager.cc b/src/ray/raylet/local_object_manager.cc index b42641a1e..87446f84f 100644 --- a/src/ray/raylet/local_object_manager.cc +++ b/src/ray/raylet/local_object_manager.cc @@ -22,7 +22,6 @@ namespace raylet { void LocalObjectManager::PinObjects(const std::vector &object_ids, std::vector> &&objects) { - absl::MutexLock lock(&mutex_); RAY_CHECK(object_pinning_enabled_); for (size_t i = 0; i < object_ids.size(); i++) { const auto &object_id = object_ids[i]; @@ -62,7 +61,6 @@ void LocalObjectManager::WaitForObjectFree(const rpc::Address &owner_address, void LocalObjectManager::ReleaseFreedObject(const ObjectID &object_id) { // object_pinning_enabled_ flag is off when the --lru-evict flag is on. if (object_pinning_enabled_) { - absl::MutexLock lock(&mutex_); RAY_LOG(DEBUG) << "Unpinning object " << object_id; // The object should be in one of these stats. pinned, spilling, or spilled. RAY_CHECK((pinned_objects_.count(object_id) > 0) || @@ -104,50 +102,85 @@ void LocalObjectManager::FlushFreeObjectsIfNeeded(int64_t now_ms) { } } -int64_t LocalObjectManager::SpillObjectsOfSize(int64_t num_bytes_to_spill, - int64_t min_bytes_to_spill) { - RAY_CHECK(num_bytes_to_spill >= min_bytes_to_spill); - +void LocalObjectManager::SpillObjectUptoMaxThroughput() { if (RayConfig::instance().object_spilling_config().empty() || !RayConfig::instance().automatic_object_spilling_enabled()) { - return min_bytes_to_spill; + return; } - absl::MutexLock lock(&mutex_); + // Spill as fast as we can using all our spill workers. + bool can_spill_more = true; + while (can_spill_more) { + if (!SpillObjectsOfSize(min_spilling_size_)) { + break; + } + { + absl::MutexLock lock(&mutex_); + num_active_workers_ += 1; + can_spill_more = num_active_workers_ < max_active_workers_; + } + } +} - RAY_LOG(INFO) << "Choosing objects to spill of total size " << num_bytes_to_spill; +bool LocalObjectManager::IsSpillingInProgress() { + absl::MutexLock lock(&mutex_); + return num_active_workers_ > 0; +} + +bool LocalObjectManager::SpillObjectsOfSize(int64_t num_bytes_to_spill) { + if (RayConfig::instance().object_spilling_config().empty() || + !RayConfig::instance().automatic_object_spilling_enabled()) { + return false; + } + + RAY_LOG(DEBUG) << "Choosing objects to spill of total size " << num_bytes_to_spill; int64_t bytes_to_spill = 0; auto it = pinned_objects_.begin(); std::vector objects_to_spill; - while (bytes_to_spill < num_bytes_to_spill && it != pinned_objects_.end()) { - bytes_to_spill += it->second->GetSize(); - objects_to_spill.push_back(it->first); + while (bytes_to_spill <= num_bytes_to_spill && it != pinned_objects_.end()) { + if (is_plasma_object_spillable_(it->first)) { + bytes_to_spill += it->second->GetSize(); + objects_to_spill.push_back(it->first); + } it++; } if (!objects_to_spill.empty()) { - RAY_LOG(INFO) << "Spilling objects of total size " << bytes_to_spill; - auto start_time = current_time_ms(); - SpillObjectsInternal( - objects_to_spill, [bytes_to_spill, start_time](const Status &status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Error spilling objects " << status.ToString(); - } else { - RAY_LOG(INFO) << "Spilled " << bytes_to_spill << " in " - << (current_time_ms() - start_time) << "ms"; - } - }); + RAY_LOG(DEBUG) << "Spilling objects of total size " << bytes_to_spill + << " num objects " << objects_to_spill.size(); + auto start_time = absl::GetCurrentTimeNanos(); + SpillObjectsInternal(objects_to_spill, [this, bytes_to_spill, objects_to_spill, + start_time](const Status &status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Error spilling objects " << status.ToString(); + } else { + auto now = absl::GetCurrentTimeNanos(); + RAY_LOG(DEBUG) << "Spilled " << bytes_to_spill << " bytes in " + << (now - start_time) / 1e6 << "ms"; + spilled_bytes_total_ += bytes_to_spill; + spilled_objects_total_ += objects_to_spill.size(); + // Adjust throughput timing to account for concurrent spill operations. + spill_time_total_s_ += (now - std::max(start_time, last_spill_finish_ns_)) / 1e9; + if (now - last_spill_log_ns_ > 1e9) { + last_spill_log_ns_ = now; + // TODO(ekl) logging at error level until we add a better UX indicator. + RAY_LOG(ERROR) << "Spilled " + << static_cast(spilled_bytes_total_ / (1024 * 1024)) + << " MiB, " << spilled_objects_total_ + << " objects, write throughput " + << static_cast(spilled_bytes_total_ / (1024 * 1024) / + spill_time_total_s_) + << " MiB/s"; + } + last_spill_finish_ns_ = now; + } + }); + return true; } - // We do not track a mapping between objects that need to be created to - // objects that are being spilled, so we just subtract the total number of - // bytes that are currently being spilled from the amount of space - // requested. If the space is claimed by another client, this client may - // need to request space again. - return min_bytes_to_spill - num_bytes_pending_spill_; + return false; } void LocalObjectManager::SpillObjects(const std::vector &object_ids, std::function callback) { - absl::MutexLock lock(&mutex_); SpillObjectsInternal(object_ids, callback); } @@ -196,7 +229,10 @@ void LocalObjectManager::SpillObjectsInternal( io_worker->rpc_client()->SpillObjects( request, [this, objects_to_spill, callback, io_worker]( const ray::Status &status, const rpc::SpillObjectsReply &r) { - absl::MutexLock lock(&mutex_); + { + absl::MutexLock lock(&mutex_); + num_active_workers_ -= 1; + } io_worker_pool_.PushSpillWorker(io_worker); if (!status.ok()) { for (const auto &object_id : objects_to_spill) { @@ -222,7 +258,6 @@ void LocalObjectManager::AddSpilledUrls( const std::vector &object_ids, const rpc::SpillObjectsReply &worker_reply, std::function callback) { auto num_remaining = std::make_shared(object_ids.size()); - auto num_bytes_spilled = std::make_shared(0); for (size_t i = 0; i < object_ids.size(); ++i) { const ObjectID &object_id = object_ids[i]; const std::string &object_url = worker_reply.spilled_objects_url(i); @@ -232,15 +267,12 @@ void LocalObjectManager::AddSpilledUrls( // be retrieved by other raylets. RAY_CHECK_OK(object_info_accessor_.AsyncAddSpilledUrl( object_id, object_url, - [this, object_id, object_url, callback, num_remaining, - num_bytes_spilled](Status status) { + [this, object_id, object_url, callback, num_remaining](Status status) { RAY_CHECK_OK(status); - absl::MutexLock lock(&mutex_); // Unpin the object. auto it = objects_pending_spill_.find(object_id); RAY_CHECK(it != objects_pending_spill_.end()); num_bytes_pending_spill_ -= it->second->GetSize(); - *num_bytes_spilled += it->second->GetSize(); objects_pending_spill_.erase(it); // Update the object_id -> url_ref_count to use it for deletion later. @@ -273,20 +305,41 @@ void LocalObjectManager::AsyncRestoreSpilledObject( << object_url; io_worker_pool_.PopRestoreWorker([this, object_id, object_url, callback]( std::shared_ptr io_worker) { + auto start_time = absl::GetCurrentTimeNanos(); RAY_LOG(DEBUG) << "Sending restore spilled object request"; rpc::RestoreSpilledObjectsRequest request; request.add_spilled_objects_url(std::move(object_url)); request.add_object_ids_to_restore(object_id.Binary()); io_worker->rpc_client()->RestoreSpilledObjects( request, - [this, object_id, callback, io_worker](const ray::Status &status, - const rpc::RestoreSpilledObjectsReply &r) { + [this, start_time, object_id, callback, io_worker]( + const ray::Status &status, const rpc::RestoreSpilledObjectsReply &r) { io_worker_pool_.PushRestoreWorker(io_worker); if (!status.ok()) { RAY_LOG(ERROR) << "Failed to send restore spilled object request: " << status.ToString(); } else { - RAY_LOG(DEBUG) << "Restored object " << object_id; + auto now = absl::GetCurrentTimeNanos(); + auto restored_bytes = r.bytes_restored_total(); + RAY_LOG(DEBUG) << "Restored " << restored_bytes << " in " + << (now - start_time) / 1e6 << "ms. Object id:" << object_id; + restored_bytes_total_ += restored_bytes; + restored_objects_total_ += 1; + // Adjust throughput timing to account for concurrent restore operations. + restore_time_total_s_ += + (now - std::max(start_time, last_restore_finish_ns_)) / 1e9; + if (now - last_restore_log_ns_ > 1e9) { + last_restore_log_ns_ = now; + // TODO(ekl) logging at error level until we add a better UX indicator. + RAY_LOG(ERROR) << "Restored " + << static_cast(restored_bytes_total_ / (1024 * 1024)) + << " MiB, " << restored_objects_total_ + << " objects, read throughput " + << static_cast(restored_bytes_total_ / (1024 * 1024) / + restore_time_total_s_) + << " MiB/s"; + } + last_restore_finish_ns_ = now; } if (callback) { callback(status); @@ -296,7 +349,6 @@ void LocalObjectManager::AsyncRestoreSpilledObject( } void LocalObjectManager::ProcessSpilledObjectsDeleteQueue(uint32_t max_batch_size) { - absl::MutexLock lock(&mutex_); std::vector object_urls_to_delete; // Process upto batch size of objects to delete. diff --git a/src/ray/raylet/local_object_manager.h b/src/ray/raylet/local_object_manager.h index 31adada2c..3cf2a2ac5 100644 --- a/src/ray/raylet/local_object_manager.h +++ b/src/ray/raylet/local_object_manager.h @@ -33,13 +33,15 @@ namespace raylet { /// have been freed, and objects that have been spilled. class LocalObjectManager { public: - LocalObjectManager(boost::asio::io_service &io_context, size_t free_objects_batch_size, - int64_t free_objects_period_ms, - IOWorkerPoolInterface &io_worker_pool, - gcs::ObjectInfoAccessor &object_info_accessor, - rpc::CoreWorkerClientPool &owner_client_pool, - bool object_pinning_enabled, bool automatic_object_deletion_enabled, - std::function &)> on_objects_freed) + LocalObjectManager( + boost::asio::io_service &io_context, size_t free_objects_batch_size, + int64_t free_objects_period_ms, IOWorkerPoolInterface &io_worker_pool, + gcs::ObjectInfoAccessor &object_info_accessor, + rpc::CoreWorkerClientPool &owner_client_pool, bool object_pinning_enabled, + bool automatic_object_deletion_enabled, int max_io_workers, + int64_t min_spilling_size, + std::function &)> on_objects_freed, + std::function is_plasma_object_spillable) : free_objects_period_ms_(free_objects_period_ms), free_objects_batch_size_(free_objects_batch_size), io_worker_pool_(io_worker_pool), @@ -48,7 +50,11 @@ class LocalObjectManager { object_pinning_enabled_(object_pinning_enabled), automatic_object_deletion_enabled_(automatic_object_deletion_enabled), on_objects_freed_(on_objects_freed), - last_free_objects_at_ms_(current_time_ms()) {} + last_free_objects_at_ms_(current_time_ms()), + min_spilling_size_(min_spilling_size), + num_active_workers_(0), + max_active_workers_(max_io_workers), + is_plasma_object_spillable_(is_plasma_object_spillable) {} /// Pin objects. /// @@ -67,22 +73,10 @@ class LocalObjectManager { void WaitForObjectFree(const rpc::Address &owner_address, const std::vector &object_ids); - /// Asynchronously spill objects when space is needed. - /// The callback tries to spill objects as much as num_bytes_to_spill and returns - /// the amount of space needed after the spilling is complete. - /// The returned value is calculated based off of min_bytes_to_spill. That says, - /// although it fails to spill num_bytes_to_spill, as long as it spills more than - /// min_bytes_to_spill, it will return the value that is less than 0 (meaning we - /// don't need any more additional space). + /// Spill objects as much as possible as fast as possible up to the max throughput. /// - /// \param num_bytes_to_spill The total number of bytes to spill. The method tries to - /// spill bytes as much as this value. - /// \param min_bytes_to_spill The minimum bytes that - /// need to be spilled. - /// \return The number of bytes of space still required after the - /// spill is complete. This return the value is less than 0 if it satifies the - /// min_bytes_to_spill. - int64_t SpillObjectsOfSize(int64_t num_bytes_to_spill, int64_t min_bytes_to_spill); + /// \return True if spilling is in progress. + void SpillObjectUptoMaxThroughput(); /// Spill objects to external storage. /// @@ -114,11 +108,33 @@ class LocalObjectManager { /// invocation. void ProcessSpilledObjectsDeleteQueue(uint32_t max_batch_size); + /// Return True if spilling is in progress. + /// This is a narrow interface that is accessed by plasma store. + /// We are using the narrow interface here because plasma store is running in a + /// different thread, and we'd like to avoid making this component thread-safe, + /// which is against the general raylet design. + /// + /// \return True if spilling is still in progress. False otherwise. + bool IsSpillingInProgress(); + private: + FRIEND_TEST(LocalObjectManagerTest, TestSpillObjectsOfSize); + FRIEND_TEST(LocalObjectManagerTest, + TestSpillObjectsOfSizeNumBytesToSpillHigherThanMinBytesToSpill); + FRIEND_TEST(LocalObjectManagerTest, TestSpillObjectNotEvictable); + + /// Asynchronously spill objects when space is needed. + /// The callback tries to spill objects as much as num_bytes_to_spill and returns + /// true if we could spill the corresponding bytes. + /// NOTE(sang): If 0 is given, this method spills a single object. + /// + /// \param num_bytes_to_spill The total number of bytes to spill. + /// \return True if it can spill num_bytes_to_spill. False otherwise. + bool SpillObjectsOfSize(int64_t num_bytes_to_spill); + /// Internal helper method for spilling objects. void SpillObjectsInternal(const std::vector &objects_ids, - std::function callback) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + std::function callback); /// Release an object that has been freed by its owner. void ReleaseFreedObject(const ObjectID &object_id); @@ -164,14 +180,12 @@ class LocalObjectManager { std::function &)> on_objects_freed_; // Objects that are pinned on this node. - absl::flat_hash_map> pinned_objects_ - GUARDED_BY(mutex_); + absl::flat_hash_map> pinned_objects_; // Objects that were pinned on this node but that are being spilled. // These objects will be released once spilling is complete and the URL is // written to the object directory. - absl::flat_hash_map> objects_pending_spill_ - GUARDED_BY(mutex_); + absl::flat_hash_map> objects_pending_spill_; /// The time that we last sent a FreeObjects request to other nodes for /// objects that have gone out of scope in the application. @@ -185,7 +199,7 @@ class LocalObjectManager { /// The total size of the objects that are currently being /// spilled from this node, in bytes. - size_t num_bytes_pending_spill_ GUARDED_BY(mutex_) = 0; + size_t num_bytes_pending_spill_; /// This class is accessed by both the raylet and plasma store threads. The /// mutex protects private members that relate to object spilling. @@ -198,16 +212,63 @@ class LocalObjectManager { /// A list of object id and url pairs that need to be deleted. /// We don't instantly delete objects when it goes out of scope from external storages /// because those objects could be still in progress of spilling. - std::queue spilled_object_pending_delete_ GUARDED_BY(mutex_); + std::queue spilled_object_pending_delete_; /// Mapping from object id to url_with_offsets. We cannot reuse pinned_objects_ because /// pinned_objects_ entries are deleted when spilling happens. - absl::flat_hash_map spilled_objects_url_ GUARDED_BY(mutex_); + absl::flat_hash_map spilled_objects_url_; /// Base URL -> ref_count. It is used because there could be multiple objects /// within a single spilled file. We need to ref count to avoid deleting the file /// before all objects within that file are out of scope. - absl::flat_hash_map url_ref_count_ GUARDED_BY(mutex_); + absl::flat_hash_map url_ref_count_; + + /// Minimum bytes to spill to a single IO spill worker. + int64_t min_spilling_size_; + + /// The current number of active spill workers. + int64_t num_active_workers_ GUARDED_BY(mutex_); + + /// The max number of active spill workers. + const int64_t max_active_workers_; + + /// Callback to check if a plasma object is pinned in workers. + /// Return true if unpinned, meaning we can safely spill the object. False otherwise. + std::function is_plasma_object_spillable_; + + /// + /// Stats + /// + + /// The last time a spill operation finished. + int64_t last_spill_finish_ns_ = 0; + + /// The total wall time in seconds spent in spilling. + double spill_time_total_s_ = 0; + + /// The total number of bytes spilled. + int64_t spilled_bytes_total_ = 0; + + /// The total number of objects spilled. + int64_t spilled_objects_total_ = 0; + + /// The last time a restore operation finished. + int64_t last_restore_finish_ns_ = 0; + + /// The total wall time in seconds spent in restoring. + double restore_time_total_s_ = 0; + + /// The total number of bytes restored. + int64_t restored_bytes_total_ = 0; + + /// The total number of objects restored. + int64_t restored_objects_total_ = 0; + + /// The last time a spill log finished. + int64_t last_spill_log_ns_ = 0; + + /// The last time a restore log finished. + int64_t last_restore_log_ns_ = 0; }; }; // namespace raylet diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 4b92d7163..6352fea3e 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -229,6 +229,8 @@ int main(int argc, char *argv[]) { node_manager_config.store_socket_name = store_socket_name; node_manager_config.temp_dir = temp_dir; node_manager_config.session_dir = session_dir; + node_manager_config.max_io_workers = RayConfig::instance().max_io_workers(); + node_manager_config.min_spilling_size = RayConfig::instance().min_spilling_size(); // Configuration for the object manager. ray::ObjectManagerConfig object_manager_config; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index fe975d79b..9dd1f25b8 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -116,7 +116,8 @@ std::string WorkerOwnerString(std::shared_ptr &worker) { NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self_node_id, const NodeManagerConfig &config, ObjectManager &object_manager, std::shared_ptr gcs_client, - std::shared_ptr object_directory) + std::shared_ptr object_directory, + std::function is_plasma_object_spillable) : self_node_id_(self_node_id), io_service_(io_service), object_manager_(object_manager), @@ -171,14 +172,18 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self /* object_pinning_enabled */ config.object_pinning_enabled, /* automatic_object_deletion_enabled */ config.automatic_object_deletion_enabled, + /*max_io_workers*/ config.max_io_workers, + /*min_spilling_size*/ config.min_spilling_size, [this](const std::vector &object_ids) { object_manager_.FreeObjects(object_ids, /*local_only=*/false); - }), + }, + is_plasma_object_spillable), new_scheduler_enabled_(RayConfig::instance().new_scheduler_enabled()), report_worker_backlog_(RayConfig::instance().report_worker_backlog()), last_local_gc_ns_(absl::GetCurrentTimeNanos()), local_gc_interval_ns_(RayConfig::instance().local_gc_interval_s() * 1e9), + local_gc_min_interval_ns_(RayConfig::instance().local_gc_min_interval_s() * 1e9), record_metrics_period_(config.record_metrics_period_ms) { RAY_LOG(INFO) << "Initializing NodeManager with ID " << self_node_id_; RAY_CHECK(heartbeat_period_.count() > 0); @@ -553,7 +558,8 @@ void NodeManager::ReportResourceUsage() { // Trigger local GC if needed. This throttles the frequency of local GC calls // to at most once per heartbeat interval. auto now = absl::GetCurrentTimeNanos(); - if (should_local_gc_ || now - last_local_gc_ns_ > local_gc_interval_ns_) { + if ((should_local_gc_ || now - last_local_gc_ns_ > local_gc_interval_ns_) && + now - last_local_gc_ns_ > local_gc_min_interval_ns_) { DoLocalGC(); should_local_gc_ = false; last_local_gc_ns_ = now; @@ -1186,8 +1192,7 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr & ProcessFetchOrReconstructMessage(client, message_data); } break; case protocol::MessageType::NotifyDirectCallTaskBlocked: { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - HandleDirectCallTaskBlocked(worker); + ProcessDirectCallTaskBlocked(client, message_data); } break; case protocol::MessageType::NotifyDirectCallTaskUnblocked: { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); @@ -1534,6 +1539,15 @@ void NodeManager::ProcessFetchOrReconstructMessage( } } +void NodeManager::ProcessDirectCallTaskBlocked( + const std::shared_ptr &client, const uint8_t *message_data) { + auto message = + flatbuffers::GetRoot(message_data); + bool release_resources = message->release_resources(); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + HandleDirectCallTaskBlocked(worker, release_resources); +} + void NodeManager::ProcessWaitRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { // Read the data. @@ -2148,9 +2162,9 @@ void NodeManager::SubmitTask(const Task &task) { } void NodeManager::HandleDirectCallTaskBlocked( - const std::shared_ptr &worker) { + const std::shared_ptr &worker, bool release_resources) { if (new_scheduler_enabled_) { - if (!worker || worker->IsBlocked()) { + if (!worker || worker->IsBlocked() || !release_resources) { return; } std::vector cpu_instances; @@ -2169,7 +2183,8 @@ void NodeManager::HandleDirectCallTaskBlocked( return; } - if (!worker || worker->GetAssignedTaskId().IsNil() || worker->IsBlocked()) { + if (!worker || worker->GetAssignedTaskId().IsNil() || worker->IsBlocked() || + !release_resources) { return; // The worker may have died or is no longer processing the task. } auto const cpu_resource_ids = worker->ReleaseTaskCpuResources(); @@ -2297,7 +2312,6 @@ void NodeManager::AsyncResolveObjectsFinish( const std::shared_ptr &client, const TaskID ¤t_task_id, bool was_blocked) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - // TODO(swang): Because the object dependencies are tracked in the task // dependency manager, we could actually remove this message entirely and // instead unblock the worker once all the objects become available. @@ -3154,9 +3168,6 @@ void NodeManager::HandleGlobalGC(const rpc::GlobalGCRequest &request, } void NodeManager::TriggerGlobalGC() { - RAY_LOG(INFO) << "Broadcasting Python GC request to all raylets since the cluster " - << "is low on resources. This removes Ray actor and object refs " - << "that are stuck in Python reference cycles."; should_global_gc_ = true; // We won't see our own request, so trigger local GC in the next heartbeat. should_local_gc_ = true; diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 1bc554b11..a734ebaec 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -104,6 +104,10 @@ struct NodeManagerConfig { std::unordered_map raylet_config; // The time between record metrics in milliseconds, or -1 to disable. uint64_t record_metrics_period_ms; + // The number if max io workers. + int max_io_workers; + // The minimum object size that can be spilled by each spill operation. + int64_t min_spilling_size; }; class NodeManager : public rpc::NodeManagerServiceHandler { @@ -115,7 +119,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { NodeManager(boost::asio::io_service &io_service, const NodeID &self_node_id, const NodeManagerConfig &config, ObjectManager &object_manager, std::shared_ptr gcs_client, - std::shared_ptr object_directory); + std::shared_ptr object_directory_, + std::function is_plasma_object_spillable); /// Process a new client connection. /// @@ -375,7 +380,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// arrive after the worker lease has been returned to the node manager. /// /// \param worker Shared ptr to the worker, or nullptr if lost. - void HandleDirectCallTaskBlocked(const std::shared_ptr &worker); + void HandleDirectCallTaskBlocked(const std::shared_ptr &worker, + bool release_resources); /// Handle a direct call task that is unblocked. Note that this callback may /// arrive after the worker lease has been returned to the node manager. @@ -437,6 +443,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Void. void ProcessSubmitTaskMessage(const uint8_t *message_data); + /// Process client message of NotifyDirectCallTaskBlocked + /// + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessDirectCallTaskBlocked(const std::shared_ptr &client, + const uint8_t *message_data); + /// Process client message of RegisterClientRequest /// /// \param client The client that sent the message. @@ -745,11 +758,15 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// on all local workers of this raylet. bool should_local_gc_ = false; - /// The last time local GC was triggered. + /// The last time local gc was run. int64_t last_local_gc_ns_ = 0; /// The interval in nanoseconds between local GC automatic triggers. - const int64_t local_gc_interval_ns_ = 10 * 60 * 1e9; + const int64_t local_gc_interval_ns_; + + /// The min interval in nanoseconds between local GC runs (auto + memory pressure + /// triggered). + const int64_t local_gc_min_interval_ns_; /// These two classes make up the new scheduler. ClusterResourceScheduler is /// responsible for maintaining a view of the cluster state w.r.t resource diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 9add8e425..6336f3160 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -70,24 +70,33 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ gcs_client_)) : std::dynamic_pointer_cast( std::make_shared(main_service, gcs_client_))), - object_manager_(main_service, self_node_id_, object_manager_config, - object_directory_, - [this](const ObjectID &object_id, const std::string &spilled_url, - std::function callback) { - node_manager_.GetLocalObjectManager().AsyncRestoreSpilledObject( - object_id, spilled_url, callback); - }, - [this](int64_t num_bytes_to_spill, int64_t min_bytes_to_spill) { - return node_manager_.GetLocalObjectManager().SpillObjectsOfSize( - num_bytes_to_spill, min_bytes_to_spill); - }, - [this]() { - // Post on the node manager's event loop since this - // will be called from the plasma store thread. - main_service_.post([this]() { node_manager_.TriggerGlobalGC(); }); - }), + object_manager_( + main_service, self_node_id_, object_manager_config, object_directory_, + [this](const ObjectID &object_id, const std::string &spilled_url, + std::function callback) { + node_manager_.GetLocalObjectManager().AsyncRestoreSpilledObject( + object_id, spilled_url, callback); + }, + [this]() { + // This callback is called from the plasma store thread. + // NOTE: It means the local object manager should be thread-safe. + main_service_.post([this]() { + node_manager_.GetLocalObjectManager().SpillObjectUptoMaxThroughput(); + }); + return node_manager_.GetLocalObjectManager().IsSpillingInProgress(); + }, + [this]() { + // Post on the node manager's event loop since this + // callback is called from the plasma store thread. + // This will help keep node manager lock-less. + main_service_.post([this]() { node_manager_.TriggerGlobalGC(); }); + }), node_manager_(main_service, self_node_id_, node_manager_config, object_manager_, - gcs_client_, object_directory_), + gcs_client_, object_directory_, + [this](const ObjectID &object_id) { + // It is used by local_object_store. + return object_manager_.IsPlasmaObjectSpillable(object_id); + }), socket_name_(socket_name), acceptor_(main_service, ParseUrlEndpoint(socket_name)), socket_(main_service) { diff --git a/src/ray/raylet/test/local_object_manager_test.cc b/src/ray/raylet/test/local_object_manager_test.cc index 96ee66638..e9f34baa7 100644 --- a/src/ray/raylet/test/local_object_manager_test.cc +++ b/src/ray/raylet/test/local_object_manager_test.cc @@ -236,15 +236,23 @@ class LocalObjectManagerTest : public ::testing::Test { /*free_objects_period_ms=*/1000, worker_pool, object_table, client_pool, /*object_pinning_enabled=*/true, /*automatic_object_delete_enabled=*/true, + /*max_io_workers=*/2, + /*min_spilling_size=*/0, [&](const std::vector &object_ids) { for (const auto &object_id : object_ids) { freed.insert(object_id); } + }, + /*is_plasma_object_spillable=*/ + [&](const ray::ObjectID &object_id) { + return unevictable_objects_.count(object_id) == 0; }), unpins(std::make_shared>()) { RayConfig::instance().initialize({{"object_spilling_config", "mock_config"}}); } + void TearDown() { unevictable_objects_.clear(); } + std::string BuildURL(const std::string url, int offset = 0, int num_objects = 1) { return url + "?" + "num_objects=" + std::to_string(num_objects) + "&offset=" + std::to_string(offset); @@ -262,6 +270,8 @@ class LocalObjectManagerTest : public ::testing::Test { // This hashmap is incremented when objects are unpinned by destroying their // unique_ptr. std::shared_ptr> unpins; + // Object ids in this field won't be evictable. + std::unordered_set unevictable_objects_; }; TEST_F(LocalObjectManagerTest, TestPin) { @@ -416,17 +426,11 @@ TEST_F(LocalObjectManagerTest, TestSpillObjectsOfSize) { objects.push_back(std::move(object)); } manager.PinObjects(object_ids, std::move(objects)); - - int64_t num_bytes_required = manager.SpillObjectsOfSize(total_size / 2, total_size / 2); - ASSERT_EQ(num_bytes_required, -object_size / 2); + ASSERT_TRUE(manager.SpillObjectsOfSize(total_size / 2)); for (const auto &id : object_ids) { ASSERT_EQ((*unpins)[id], 0); } - // Check that this returns the total number of bytes currently being spilled. - num_bytes_required = manager.SpillObjectsOfSize(0, 0); - ASSERT_EQ(num_bytes_required, -2 * object_size); - // Check that half the objects get spilled and the URLs get added to the // global object directory. std::vector urls; @@ -447,9 +451,124 @@ TEST_F(LocalObjectManagerTest, TestSpillObjectsOfSize) { ASSERT_EQ((*unpins)[object_url.first], 1); } - // Check that this returns the total number of bytes currently being spilled. - num_bytes_required = manager.SpillObjectsOfSize(0, 0); - ASSERT_EQ(num_bytes_required, 0); + // Make sure providing 0 bytes to SpillObjectsOfSize will spill one object. + // This is important to cover min_spilling_size_== 0. + ASSERT_TRUE(manager.SpillObjectsOfSize(0)); + EXPECT_CALL(worker_pool, PushSpillWorker(_)); + const std::string url = BuildURL("url" + std::to_string(object_ids.size())); + ASSERT_TRUE(worker_pool.io_worker_client->ReplySpillObjects({url})); + ASSERT_TRUE(object_table.ReplyAsyncAddSpilledUrl()); + ASSERT_EQ(object_table.object_urls.size(), 3); + urls.push_back(url); + for (auto &object_url : object_table.object_urls) { + auto it = std::find(urls.begin(), urls.end(), object_url.second); + ASSERT_TRUE(it != urls.end()); + ASSERT_EQ((*unpins)[object_url.first], 1); + } + + // Since there's no more object to spill, this should fail. + ASSERT_FALSE(manager.SpillObjectsOfSize(0)); +} + +TEST_F(LocalObjectManagerTest, TestSpillObjectNotEvictable) { + rpc::Address owner_address; + owner_address.set_worker_id(WorkerID::FromRandom().Binary()); + + std::vector object_ids; + std::vector> objects; + int64_t total_size = 0; + int64_t object_size = 1000; + + const ObjectID object_id = ObjectID::FromRandom(); + object_ids.push_back(object_id); + unevictable_objects_.emplace(object_id); + auto data_buffer = std::make_shared(object_size, object_id, unpins); + total_size += object_size; + std::unique_ptr object( + new RayObject(data_buffer, nullptr, std::vector())); + objects.push_back(std::move(object)); + + manager.PinObjects(object_ids, std::move(objects)); + ASSERT_FALSE(manager.SpillObjectsOfSize(1000)); + for (const auto &id : object_ids) { + ASSERT_EQ((*unpins)[id], 0); + } + + // Now object is evictable. Spill should succeed. + unevictable_objects_.erase(object_id); + ASSERT_TRUE(manager.SpillObjectsOfSize(1000)); +} + +TEST_F(LocalObjectManagerTest, TestSpillUptoMaxThroughput) { + rpc::Address owner_address; + owner_address.set_worker_id(WorkerID::FromRandom().Binary()); + + std::vector object_ids; + std::vector> objects; + int64_t object_size = 1000; + size_t total_objects = 3; + + // Pin 3 objects. + for (size_t i = 0; i < total_objects; i++) { + ObjectID object_id = ObjectID::FromRandom(); + object_ids.push_back(object_id); + auto data_buffer = std::make_shared(object_size, object_id, unpins); + std::unique_ptr object( + new RayObject(data_buffer, nullptr, std::vector())); + objects.push_back(std::move(object)); + } + manager.PinObjects(object_ids, std::move(objects)); + + // This will spill until 2 workers are occupied. + manager.SpillObjectUptoMaxThroughput(); + ASSERT_TRUE(manager.IsSpillingInProgress()); + // Spilling is still going on, meaning we can make the pace. So it should return true. + manager.SpillObjectUptoMaxThroughput(); + ASSERT_TRUE(manager.IsSpillingInProgress()); + // No object ids are spilled yet. + for (const auto &id : object_ids) { + ASSERT_EQ((*unpins)[id], 0); + } + + // Spill one object. + std::vector urls; + urls.push_back(BuildURL("url" + std::to_string(0))); + ASSERT_TRUE(worker_pool.io_worker_client->ReplySpillObjects({urls[0]})); + ASSERT_TRUE(object_table.ReplyAsyncAddSpilledUrl()); + // Make sure object is spilled. + ASSERT_EQ(object_table.object_urls.size(), 1); + for (auto &object_url : object_table.object_urls) { + if (urls[0] == object_url.second) { + ASSERT_EQ((*unpins)[object_url.first], 1); + } + } + + // Now, there's only one object that is current spilling. + // SpillObjectUptoMaxThroughput will spill one more object (since one worker is + // availlable). + manager.SpillObjectUptoMaxThroughput(); + ASSERT_TRUE(manager.IsSpillingInProgress()); + manager.SpillObjectUptoMaxThroughput(); + ASSERT_TRUE(manager.IsSpillingInProgress()); + + // Spilling is done for all objects. + for (size_t i = 1; i < object_ids.size(); i++) { + urls.push_back(BuildURL("url" + std::to_string(i))); + } + for (size_t i = 1; i < urls.size(); i++) { + ASSERT_TRUE(worker_pool.io_worker_client->ReplySpillObjects({urls[i]})); + ASSERT_TRUE(object_table.ReplyAsyncAddSpilledUrl()); + } + ASSERT_EQ(object_table.object_urls.size(), 3); + for (auto &object_url : object_table.object_urls) { + auto it = std::find(urls.begin(), urls.end(), object_url.second); + ASSERT_TRUE(it != urls.end()); + ASSERT_EQ((*unpins)[object_url.first], 1); + } + + // We cannot spill anymore as there is no more pinned object. + manager.SpillObjectUptoMaxThroughput(); + ASSERT_FALSE(manager.IsSpillingInProgress()); } TEST_F(LocalObjectManagerTest, TestSpillError) { @@ -739,52 +858,6 @@ TEST_F(LocalObjectManagerTest, TestDeleteMaxObjects) { ASSERT_EQ(deleted_urls_size, free_objects_batch_size); } -TEST_F(LocalObjectManagerTest, - TestSpillObjectsOfSizeNumBytesToSpillHigherThanMinBytesToSpill) { - /// Test the case SpillObjectsOfSize(num_bytes_to_spill, min_bytes_to_spill - /// where num_bytes_to_spill > min_bytes_to_spill. - rpc::Address owner_address; - owner_address.set_worker_id(WorkerID::FromRandom().Binary()); - - std::vector object_ids; - std::vector> objects; - int64_t total_size = 0; - int64_t object_size = 1000; - size_t object_len = 3; - - for (size_t i = 0; i < object_len; i++) { - ObjectID object_id = ObjectID::FromRandom(); - object_ids.push_back(object_id); - auto data_buffer = std::make_shared(object_size, object_id, unpins); - total_size += object_size; - std::unique_ptr object( - new RayObject(data_buffer, nullptr, std::vector())); - objects.push_back(std::move(object)); - } - manager.PinObjects(object_ids, std::move(objects)); - - // First test when num_bytes_to_spill > min_bytes to spill. - // It means that we cannot spill the num_bytes_required, but we at least spilled the - // required amount, which is the min_bytes_to_spill. - int64_t num_bytes_required = manager.SpillObjectsOfSize(8000, object_size); - // only min bytes to spill is considered. - ASSERT_TRUE(num_bytes_required <= 0); - - // Make sure the spilling is done properly. - std::vector urls; - for (size_t i = 0; i < object_ids.size(); i++) { - urls.push_back(BuildURL("url" + std::to_string(i))); - } - EXPECT_CALL(worker_pool, PushSpillWorker(_)); - ASSERT_TRUE(worker_pool.io_worker_client->ReplySpillObjects(urls)); - for (size_t i = 0; i < object_ids.size(); i++) { - ASSERT_TRUE(object_table.ReplyAsyncAddSpilledUrl()); - } - for (size_t i = 0; i < object_ids.size(); i++) { - ASSERT_EQ((*unpins).size(), object_len); - } -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 9251c1020..1c6365796 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -189,9 +189,9 @@ Status raylet::RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb); } -Status raylet::RayletClient::NotifyDirectCallTaskBlocked() { +Status raylet::RayletClient::NotifyDirectCallTaskBlocked(bool release_resources) { flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateNotifyDirectCallTaskBlocked(fbb); + auto message = protocol::CreateNotifyDirectCallTaskBlocked(fbb, release_resources); fbb.Finish(message); return conn_->WriteMessage(MessageType::NotifyDirectCallTaskBlocked, &fbb); } diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index 6f2821038..9fa1b7982 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -256,8 +256,9 @@ class RayletClient : public RayletClientInterface { /// Notify the raylet that this client is blocked. This is only used for direct task /// calls. Note that ordering of this with respect to Unblock calls is important. /// - /// \return ray::Status. - ray::Status NotifyDirectCallTaskBlocked(); + /// \param release_resources: true if the dirct call blocking needs to release + /// resources. \return ray::Status. + ray::Status NotifyDirectCallTaskBlocked(bool release_resources); /// Notify the raylet that this client is unblocked. This is only used for direct task /// calls. Note that ordering of this with respect to Block calls is important. From a092433bc87b1e9decb335119ba96ab48aee7c1c Mon Sep 17 00:00:00 2001 From: dHannasch Date: Fri, 18 Dec 2020 23:34:34 -0700 Subject: [PATCH 32/88] [core] Use the ConnectWithoutRetries error message (#12732) --- src/ray/gcs/redis_context.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index c9f32c711..c4edbb688 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -333,7 +333,7 @@ Status ConnectWithRetries(const std::string &address, int port, break; } if (*context == nullptr) { - RAY_LOG(WARNING) << "Could not allocate Redis context, will retry in " + RAY_LOG(WARNING) << errorMessage << " Will retry in " << RayConfig::instance().redis_db_connect_wait_milliseconds() << " milliseconds."; } From 5d987f5988b5ff2333156b05e26a2f965b5da167 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Dec 2020 23:51:44 -0800 Subject: [PATCH 33/88] Revert "Increase the number of unique bits for actors to avoid handle collisions (#12894)" (#12988) This reverts commit 3e492a79ec7b67dd1137535cead690339effc2ac. --- .../tests/test_stats_collector.py | 18 +++++++++------ dashboard/tests/test_memory_utils.py | 5 ++--- .../src/main/java/io/ray/api/id/ActorId.java | 2 +- .../src/main/java/io/ray/api/id/ObjectId.java | 2 +- .../src/main/java/io/ray/api/id/UniqueId.java | 2 +- .../java/io/ray/runtime/UniqueIdTest.java | 22 +++++++++++++------ python/ray/exceptions.py | 6 ++--- python/ray/includes/function_descriptor.pxi | 13 ++++------- python/ray/includes/unique_ids.pxi | 2 +- python/ray/log_monitor.py | 2 +- python/ray/ray_constants.py | 2 +- python/ray/serialization.py | 5 ++--- python/ray/tests/test_advanced_3.py | 6 ++--- python/ray/tests/test_multi_node.py | 8 +++---- python/ray/utils.py | 4 ++-- python/ray/worker.py | 3 +-- src/ray/common/constants.h | 2 +- src/ray/common/id.h | 2 +- src/ray/core_worker/actor_manager.cc | 2 -- 19 files changed, 54 insertions(+), 54 deletions(-) diff --git a/dashboard/modules/stats_collector/tests/test_stats_collector.py b/dashboard/modules/stats_collector/tests/test_stats_collector.py index bed6d650f..f4246770a 100644 --- a/dashboard/modules/stats_collector/tests/test_stats_collector.py +++ b/dashboard/modules/stats_collector/tests/test_stats_collector.py @@ -112,16 +112,20 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard): def check_mem_table(): resp = requests.get(f"{webui_url}/memory/memory_table") resp_data = resp.json() - assert resp_data["result"] + if not resp_data["result"]: + return False latest_memory_table = resp_data["data"]["memoryTable"] summary = latest_memory_table["summary"] - # 1 ref per handle and per object the actor has a ref to - assert summary["totalActorHandles"] == len(actors) * 2 - # 1 ref for my_obj - assert summary["totalLocalRefCount"] == 1 + try: + # 1 ref per handle and per object the actor has a ref to + assert summary["totalActorHandles"] == len(actors) * 2 + # 1 ref for my_obj + assert summary["totalLocalRefCount"] == 1 + return True + except AssertionError: + return False - wait_until_succeeded_without_exception( - check_mem_table, (AssertionError, ), timeout_ms=1000) + wait_for_condition(check_mem_table, 10) def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard): diff --git a/dashboard/tests/test_memory_utils.py b/dashboard/tests/test_memory_utils.py index 212eeefad..f58ecd8ae 100644 --- a/dashboard/tests/test_memory_utils.py +++ b/dashboard/tests/test_memory_utils.py @@ -7,9 +7,8 @@ from ray.new_dashboard.memory_utils import ( NODE_ADDRESS = "127.0.0.1" IS_DRIVER = True PID = 1 - -OBJECT_ID = "ZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZg==" -ACTOR_ID = "fffffffffffffffffffffffffffffffff66d17ba010000c801000000" +OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA=" +ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000" DECODED_ID = decode_object_ref_if_needed(OBJECT_ID) OBJECT_SIZE = 100 diff --git a/java/api/src/main/java/io/ray/api/id/ActorId.java b/java/api/src/main/java/io/ray/api/id/ActorId.java index a21d4e79f..65a0cf19a 100644 --- a/java/api/src/main/java/io/ray/api/id/ActorId.java +++ b/java/api/src/main/java/io/ray/api/id/ActorId.java @@ -7,7 +7,7 @@ import java.util.Random; public class ActorId extends BaseId implements Serializable { - private static final int UNIQUE_BYTES_LENGTH = 12; + private static final int UNIQUE_BYTES_LENGTH = 4; public static final int LENGTH = JobId.LENGTH + UNIQUE_BYTES_LENGTH; diff --git a/java/api/src/main/java/io/ray/api/id/ObjectId.java b/java/api/src/main/java/io/ray/api/id/ObjectId.java index 78b677ac8..9b1fa246f 100644 --- a/java/api/src/main/java/io/ray/api/id/ObjectId.java +++ b/java/api/src/main/java/io/ray/api/id/ObjectId.java @@ -10,7 +10,7 @@ import java.util.Random; */ public class ObjectId extends BaseId implements Serializable { - public static final int LENGTH = 28; + public static final int LENGTH = 20; /** * Create an ObjectId from a ByteBuffer. diff --git a/java/api/src/main/java/io/ray/api/id/UniqueId.java b/java/api/src/main/java/io/ray/api/id/UniqueId.java index 44b19f6a7..03de53943 100644 --- a/java/api/src/main/java/io/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/io/ray/api/id/UniqueId.java @@ -11,7 +11,7 @@ import java.util.Random; */ public class UniqueId extends BaseId implements Serializable { - public static final int LENGTH = 28; + public static final int LENGTH = 20; public static final UniqueId NIL = genNil(); /** diff --git a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java index 7496f1baf..25704f321 100644 --- a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java @@ -1,6 +1,7 @@ package io.ray.runtime; import io.ray.api.id.UniqueId; +import io.ray.runtime.util.IdUtil; import java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; @@ -12,12 +13,12 @@ public class UniqueIdTest { @Test public void testConstructUniqueId() { // Test `fromHexString()` - UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00"); - Assert.assertEquals("00000000123456789abcdef123456789abcdef0123456789abcdef00", id1.toString()); + UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString()); Assert.assertFalse(id1.isNil()); try { - UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00"); + UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF00"); // This shouldn't be happened. Assert.assertTrue(false); } catch (IllegalArgumentException e) { @@ -33,16 +34,23 @@ public class UniqueIdTest { } // Test `fromByteBuffer()` - byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF"); - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 28); + byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567"); + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20); UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer); Assert.assertTrue(Arrays.equals(bytes, id4.getBytes())); - Assert.assertEquals("0123456789abcdef0123456789abcdef012345670123456789abcdef", id4.toString()); + Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString()); // Test `genNil()` UniqueId id6 = UniqueId.NIL; - Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); + Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } + + @Test + void testMurmurHash() { + UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); + long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); + Assert.assertEquals(remainder, 787616861); + } } diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 56e943db6..b5a0b477c 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -142,8 +142,7 @@ class WorkerCrashedError(RayError): """Indicates that the worker died unexpectedly while executing a task.""" def __str__(self): - return ("The worker died unexpectedly while executing this task. " - "Check python-core-worker-*.log files for more information.") + return "The worker died unexpectedly while executing this task." class RayActorError(RayError): @@ -154,8 +153,7 @@ class RayActorError(RayError): """ def __str__(self): - return ("The actor died unexpectedly before finishing this task. " - "Check python-core-worker-*.log files for more information.") + return "The actor died unexpectedly before finishing this task." class RaySystemError(RayError): diff --git a/python/ray/includes/function_descriptor.pxi b/python/ray/includes/function_descriptor.pxi index d2c4cbbf4..a9ac11fdb 100644 --- a/python/ray/includes/function_descriptor.pxi +++ b/python/ray/includes/function_descriptor.pxi @@ -12,7 +12,6 @@ import hashlib import cython import inspect import uuid -import ray.ray_constants as ray_constants ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &) @@ -189,8 +188,7 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): function_name = function.__name__ class_name = "" - pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest( - ray_constants.ID_SIZE) + pickled_function_hash = hashlib.sha1(pickled_function).hexdigest() return cls(module_name, function_name, class_name, pickled_function_hash) @@ -210,10 +208,7 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): module_name = target_class.__module__ class_name = target_class.__name__ # Use a random uuid as function hash to solve actor name conflict. - return cls( - module_name, "__init__", class_name, - hashlib.shake_128( - uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE)) + return cls(module_name, "__init__", class_name, str(uuid.uuid4())) @property def module_name(self): @@ -273,14 +268,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): Returns: ray.ObjectRef to represent the function descriptor. """ - function_id_hash = hashlib.shake_128() + function_id_hash = hashlib.sha1() # Include the function module and name in the hash. function_id_hash.update(self.typed_descriptor.ModuleName()) function_id_hash.update(self.typed_descriptor.FunctionName()) function_id_hash.update(self.typed_descriptor.ClassName()) function_id_hash.update(self.typed_descriptor.FunctionHash()) # Compute the function ID. - function_id = function_id_hash.digest(ray_constants.ID_SIZE) + function_id = function_id_hash.digest() return ray.FunctionID(function_id) def is_actor_method(self): diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 52a6730e6..bcf766829 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize): raise TypeError("Unsupported type: " + str(type(b))) if len(b) != size: raise ValueError("ID string needs to have length " + - str(size) + ", got " + str(len(b))) + str(size)) cdef extern from "ray/common/constants.h" nogil: diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index d6b3a314e..ac5fa5296 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger logger = logging.getLogger(__name__) # The groups are worker id, job id, and pid. -JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)") +JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)") class LogFileInfo: diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 30b3b5c7b..be717ca3c 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -19,7 +19,7 @@ def env_bool(key, default): return default -ID_SIZE = 28 +ID_SIZE = 20 # The default maximum number of bytes to allocate to the object store unless # overridden by the user. diff --git a/python/ray/serialization.py b/python/ray/serialization.py index 9a24f3ccc..dc9a2c40e 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -74,8 +74,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): new_class_id = pickle.dumps(pickle.loads(class_id)) if new_class_id == class_id: # We appear to have reached a fix point, so use this as the ID. - return hashlib.shake_128(new_class_id).digest( - ray_constants.ID_SIZE) + return hashlib.sha1(new_class_id).digest() class_id = new_class_id # We have not reached a fixed point, so we may end up with a different @@ -83,7 +82,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): # same class definition being exported many many times. logger.warning( f"WARNING: Could not produce a deterministic class ID for class {cls}") - return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE) + return hashlib.sha1(new_class_id).digest() def object_ref_deserializer(reduced_obj_ref, owner_address): diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index b1bc25fbb..7f1e8e639 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -284,14 +284,14 @@ def test_workers(shutdown_only): def test_object_ref_properties(): - id_bytes = b"0011223344556677889900001111" + id_bytes = b"00112233445566778899" object_ref = ray.ObjectRef(id_bytes) assert object_ref.binary() == id_bytes object_ref = ray.ObjectRef.nil() assert object_ref.is_nil() - with pytest.raises(ValueError, match=r".*needs to have length.*"): + with pytest.raises(ValueError, match=r".*needs to have length 20.*"): ray.ObjectRef(id_bytes + b"1234") - with pytest.raises(ValueError, match=r".*needs to have length.*"): + with pytest.raises(ValueError, match=r".*needs to have length 20.*"): ray.ObjectRef(b"0123456789") object_ref = ray.ObjectRef.from_random() assert not object_ref.is_nil() diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index fbce475c1..cb206112d 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -741,10 +741,10 @@ ray.get(main_wait.release.remote()) driver1_out_split = driver1_out.split("\n") driver2_out_split = driver2_out.split("\n") - assert driver1_out_split[0][-1] == "1", driver1_out_split - assert driver1_out_split[1][-1] == "2", driver1_out_split - assert driver2_out_split[0][-1] == "3", driver2_out_split - assert driver2_out_split[1][-1] == "4", driver2_out_split + assert driver1_out_split[0][-1] == "1" + assert driver1_out_split[1][-1] == "2" + assert driver2_out_split[0][-1] == "3" + assert driver2_out_split[1][-1] == "4" if __name__ == "__main__": diff --git a/python/ray/utils.py b/python/ray/utils.py index 2704e07cc..a3940d6e8 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -50,9 +50,9 @@ def get_ray_temp_dir(): def _random_string(): - id_hash = hashlib.shake_128() + id_hash = hashlib.sha1() id_hash.update(uuid.uuid4().bytes) - id_bytes = id_hash.digest(ray_constants.ID_SIZE) + id_bytes = id_hash.digest() assert len(id_bytes) == ray_constants.ID_SIZE return id_bytes diff --git a/python/ray/worker.py b/python/ray/worker.py index 627037098..495478ad7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -345,8 +345,7 @@ class Worker: # actually run the function locally. pickled_function = pickle.dumps(function) - function_to_run_id = hashlib.shake_128(pickled_function).digest( - ray_constants.ID_SIZE) + function_to_run_id = hashlib.sha1(pickled_function).digest() key = b"FunctionsToRun:" + function_to_run_id # First run the function on the driver. # We always run the task locally. diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 3a3461f2c..1636846f0 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -18,7 +18,7 @@ #include /// Length of Ray full-length IDs in bytes. -constexpr size_t kUniqueIDSize = 28; +constexpr size_t kUniqueIDSize = 20; /// An ObjectID's bytes are split into the task ID itself and the index of the /// object's creation. This is the maximum width of the object index in bits. diff --git a/src/ray/common/id.h b/src/ray/common/id.h index bd55b27e5..d12ba550d 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -124,7 +124,7 @@ class JobID : public BaseID { class ActorID : public BaseID { private: - static constexpr size_t kUniqueBytesLength = 12; + static constexpr size_t kUniqueBytesLength = 4; public: /// Length of `ActorID` in bytes. diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index 6b931082a..e6ef4fc87 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -91,8 +91,6 @@ bool ActorManager::AddActorHandle(std::unique_ptr actor_handle, std::placeholders::_1, std::placeholders::_2); RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe( actor_id, actor_notification_callback, nullptr)); - } else { - RAY_LOG(ERROR) << "Actor handle already exists " << actor_id.Hex(); } return inserted; From 5d3c9c8861ee658b7931bb374603f86987c17433 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Sat, 19 Dec 2020 00:40:02 -0800 Subject: [PATCH 34/88] [Tune] Mlflow Integration (#12840) Co-authored-by: Kai Fricke Co-authored-by: Richard Liaw --- doc/source/conf.py | 1 + doc/source/images/mlflow.png | Bin 0 -> 3851 bytes doc/source/tune/_tutorials/overview.rst | 12 +- doc/source/tune/_tutorials/tune-mlflow.rst | 47 +++ doc/source/tune/api_docs/logging.rst | 17 +- doc/source/tune/examples/index.rst | 1 + .../tune/examples/mlflow_ptl_example.rst | 6 + python/ray/tune/BUILD | 32 +- python/ray/tune/examples/mlflow_example.py | 96 ++++- python/ray/tune/examples/mlflow_ptl.py | 93 +++++ python/ray/tune/function_runner.py | 9 +- python/ray/tune/integration/mlflow.py | 366 ++++++++++++++++++ python/ray/tune/integration/wandb.py | 5 +- python/ray/tune/logger.py | 38 +- python/ray/tune/tests/test_dependency.py | 1 + .../ray/tune/tests/test_integration_mlflow.py | 306 +++++++++++++++ python/ray/tune/tests/test_sample.py | 2 +- python/requirements_tune.txt | 1 + 18 files changed, 958 insertions(+), 75 deletions(-) create mode 100644 doc/source/images/mlflow.png create mode 100644 doc/source/tune/_tutorials/tune-mlflow.rst create mode 100644 doc/source/tune/examples/mlflow_ptl_example.rst create mode 100644 python/ray/tune/examples/mlflow_ptl.py create mode 100644 python/ray/tune/integration/mlflow.py create mode 100644 python/ray/tune/tests/test_integration_mlflow.py diff --git a/doc/source/conf.py b/doc/source/conf.py index c69f73760..13f075b29 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -41,6 +41,7 @@ MOCK_MODULES = [ "horovod", "horovod.ray", "kubernetes", + "mlflow", "mxnet", "mxnet.model", "psutil", diff --git a/doc/source/images/mlflow.png b/doc/source/images/mlflow.png new file mode 100644 index 0000000000000000000000000000000000000000..03b96d5daf39ad485a8d4cd614abbc9af21b8518 GIT binary patch literal 3851 zcmV+m5A^VfP)gW07J;^0FK}Qkl+A}-}d(Q0E^$~=;%SH=Y)iW+}zxKeSH9w;p6oDsn_$Vsj2w= z{&l_V?fCuj`u;J+>$KYR)9U%m=lG@5^7Q=v@$vCDH#cIl=}xZbIi%)|#O{%ikwryC zz~1#)&h8_g0;Wv z5(R3zE=?h230s>$LjaSuX%gCWyY2hG)nyBwGa5(7HD42 z&dW)qSga(Y^7)`u-S>nJ6xG=%j>53GMi_=sT$${9RtJdoc@%7D-k?(3_q+}c)$>T@ zc`n%(`<~drp<9uNE*69CzGrq|sEs3$=($w4_dU16LbI%6x>(uw+zt!PzN|bJLVdpP z$sHETTBg0K-Y(7qEz`xwd(~USX{=}3tLkl{A$rqM92b*H9FO-sy+gveLUkA=gJEO& z;n2H!+e1fj^P-E+z9)5AbOHinj#=d8C zuozAt6P zz-W%Vr}vLA0sq4mzKr4&zmEK|yev^<_`)>}V8XI_aFco~mEKln{A}S&HCn~Gd2lDp z{#Xtk&Q!zlZDd)n?UGmE=q;!#)eXos1(pR{x>^&owlm{r3s@AITYbrsE7fQfE$WwgE5U3#6MnvMq#CVaXt9d)x%jx! zk?J<2^(Vpnu+9!;N1 zbh@n>KU)~79+x?afWGp1nO5J>=VG&9eE-Iggqx9QV2w;xhDw?S_6;aEe zDw!5h`t5aQsweJL)1YSH(|Q0C)kEzTF1=I}JwelYkuOy>Qr$J|chN9SJZ)HntQ>n< zuVAKnpx?x$qnhXmx(FHe|LKp)CrnhU+c0d1Nur}gt>V$A^$KRHN6u7Nh*wS1t+s`+ zqkxHOcU4VYI-1t=o6ZYHsvEx|1QYT|K}>Sekk<3n?gb;&wJO6YhUp{>OIA`Y5<(s#!#1SXAZ!SNrM%jijA@JQL0uRzcH+js0sC$C{Z>Ej3ac`V>UHFzZ9Ap2lL08@Lon^XK^gZUp< zHzu~*-BSY1g&W=09K&I!(`mQcn*C>;)A^X|#i03^MuGBhyBzT4eM!MBy~4?NbwKG2 ze;ki+p&Cm-%<*3jswj17z240?LUT&@-DmBETQM~beR514^I1$oF%gx(O)}OIEZ>l}(}<^zcF;dd2qLZ_Rf)V4O#yexGF7 ztz>% zm&U5lcXZKK1F8O^<*`)yvH1GSd&7t~(hVJ`MyqJairVWEZVWsx`=$+ZL3AS<@A4!XA>lCJElb~mf5LpfK{{I)aEI5O2Gd`70vt{#^9;_vFOpJZe^i4^w%x(TTBn5y|mRW4xT$l4VZ z*!sRUZtwnR6~>2kM04_ie{#uhnw@Hxy=ASiNhOc#njRziF*009HH{OyQCJMP`e_@} z15IJ+IaEr!sn`VT_eNpyo~ASYh(v!UT0#kFcz~XnYP5>6IVm(W)g^$NN1VF5g)>dA z+q_ubl3Ttiu*h}8SX$2D-d(P&>yuJSC#Gv;r5ci=a_S)k-3G=D7h(F& zc#+g=V~aST$|C9ul4v`a5qp99t}dr%fCw=S`UZvwSq^p)_E}b{u{V}30WJI(R#;7Z z>NB$!Q>$oh2fN&=XsMu~0SrqB zrTPP674epld<|F}pvr}01ELuWqX{n>d# z8L^4BGQQsCmKc6lAqkRNjV4Jwj+#*3O`cUlBhEkzFK?{K7ATxi(_)wo(0oA)q7T>?hV zsjAlJ{x?XiD|#_Bah?i1S)K`(i4)?Eq^}`Oq@x(*g=(~lBki~)sIp;h;X_Kbeo_1p zX)XY@P+5#c>%{$7iaIgQskb*HZyX49Zm1rBRkZZh1G9y0VVXWYvv&theUKQ^dk}&= zzBS~%TdzfsZ0klT*T@^ebX#7icCm_8Y89DVNCFsp*;sAk$;4nnXVPjs$!n}r0n`O~ z%3|p@v5sZ@p zfxr?nr7I0Z*iSt}>Vm9%LW5A}g=*LpW3h@gcr5Hg;(^#JPKjB>|HsN5NJyS0r^npe zYZ-((FI2l&MViP|M2ApqTozvlbSN|pNn_69#52|ZeQDnA? zU3~yUv(?rV2)PnpiGU)GD$XxT-G*Ywp&3KgsHt!OD zVsMK##3j1ze2e(8ZxZU9Pz~#fwm!=A4NzBmuhMjrEnIAD?#jDi3@{vRQkRG33{XC( zW)t4bd{8}gu!@i^QS;iQ+Qz&xK(n1Od$+r0J3@5x+}kS~g*qoxqg6ETL9{e{ zBTI{C`VzU>b`(y|&GzDXk< zV{<0ct&oOyMfHIxy3Ndl;uL9|Zo^a_sCJ1JJLUjp4e((ei+hX7^t_+2 zR(^q8@~R%rEB72}tK+n?k_W2623h*mK$V4M!~gmlDY-(+M9opK4RvpTC<8ThoOTyr z4ybmriWimuhNDq7wbBW-JtqaJ77c^qg6I`s4yblH-Uvf@`#Q@|kz;tQeQjxvLSZ1c z98tVN{npkTP~CE{iV(oq53Q>k+fzmcv%vQ@>(T|hf1P1nGzV0>#EPc61O?T^s60Yb z8)m+{NM{#LnWi!9*UkaeXcf(EsGSWc$D!Vx>|c$PJ$p6b^?AR41+3v|gW1yPaybW_z`uSlOW1*AJYA*~S$LuPE$85X2x+jfRk19cnYd6#~XP(1(1e3`ZHH|3a6uRJ)Qz|mMq5` zc{~A44CcI5+#N7olN|e*tGfhv)!r)Z5ZIr`A6c|RU|QJHoYUh8;0wD)Qz99#29ij^YI(sk$rjk>VCHHPp|AmUc}ZPLTKVtf}@VzHbo|1JgeHu~it( zcQ%x?c>D(_b+jAh5$2})<1hdlhJfiLL-oIh1FkijgoAUR`q(nWqbI^>bI||* literal 0 HcmV?d00001 diff --git a/doc/source/tune/_tutorials/overview.rst b/doc/source/tune/_tutorials/overview.rst index dfbd986b0..0517c2f0a 100644 --- a/doc/source/tune/_tutorials/overview.rst +++ b/doc/source/tune/_tutorials/overview.rst @@ -70,6 +70,11 @@ Take a look at any of the below tutorials to get started with Tune. :figure: /images/wandb_logo.png :description: :doc:`Track your experiment process with the Weights & Biases tools ` +.. customgalleryitem:: + :tooltip: Use MLFlow with Ray Tune. + :figure: /images/mlflow.png + :description: :doc:`Log and track your hyperparameter sweep with MLFlow Tracking & AutoLogging ` + .. raw:: html @@ -81,12 +86,13 @@ Take a look at any of the below tutorials to get started with Tune. tune-tutorial.rst tune-advanced-tutorial.rst - tune-lifecycle.rst tune-distributed.rst - tune-sklearn.rst + tune-lifecycle.rst + tune-mlflow.rst tune-pytorch-cifar.rst tune-pytorch-lightning.rst tune-serve-integration-mnist.rst + tune-sklearn.rst tune-xgboost.rst tune-wandb.rst @@ -156,4 +162,4 @@ Check out: .. _tune-faq: -.. include:: _faq.rst \ No newline at end of file +.. include:: _faq.rst diff --git a/doc/source/tune/_tutorials/tune-mlflow.rst b/doc/source/tune/_tutorials/tune-mlflow.rst new file mode 100644 index 000000000..6b5519e3f --- /dev/null +++ b/doc/source/tune/_tutorials/tune-mlflow.rst @@ -0,0 +1,47 @@ +.. _tune-mlflow: + +Using MLFlow with Tune +====================== + +`MLFlow `_ is an open source platform to manage the ML lifecycle, including experimentation, +reproducibility, deployment, and a central model registry. It currently offers four components, including +MLFlow Tracking to record and query experiments, including code, data, config, and results. + +.. image:: /images/mlflow.png + :height: 80px + :alt: MLflow + :align: center + :target: https://www.mlflow.org/ + +Ray Tune currently offers two lightweight integrations for MLFlow Tracking. +One is the :ref:`MLFlowLoggerCallback `, which automatically logs +metrics reported to Tune to the MLFlow Tracking API. + +The other one is the :ref:`@mlflow_mixin ` decorator, which can be +used with the function API. It automatically +initializes the MLFlow API with Tune's training information and creates a run for each Tune trial. +Then within your training function, you can just use the +MLFlow like you would normally do, e.g. using ``mlflow.log_metrics()`` or even ``mlflow.autolog()`` +to log to your training process. + +Please :doc:`see here ` for a full example on how you can use either the +MLFlowLoggerCallback or the mlflow_mixin. + +MLFlow AutoLogging +------------------ +You can also check out :doc:`here ` for an example on how you can leverage MLflow +autologging, in this case with Pytorch Lightning + +MLFlow Logger API +----------------- +.. _tune-mlflow-logger: + +.. autoclass:: ray.tune.integration.mlflow.MLFlowLoggerCallback + :noindex: + +MLFlow Mixin API +---------------- +.. _tune-mlflow-mixin: + +.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin + :noindex: diff --git a/doc/source/tune/api_docs/logging.rst b/doc/source/tune/api_docs/logging.rst index 240ec97af..7e8784829 100644 --- a/doc/source/tune/api_docs/logging.rst +++ b/doc/source/tune/api_docs/logging.rst @@ -70,11 +70,7 @@ An example of creating a custom logger can be found in :doc:`/tune/examples/logg Trainable Logging ----------------- -By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if: - - * you're using `Weights and Biases `_ - * you're using `MLFlow `__ - * you're trying to log images to Tensorboard. +By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if you're trying to log images to Tensorboard. You can do this in the trainable, as shown below: @@ -163,12 +159,17 @@ CSVLogger .. autoclass:: ray.tune.logger.CSVLoggerCallback -MLFLowLogger +MLFlowLogger ------------ -Tune also provides a default logger for `MLFlow `_. You can install MLFlow via ``pip install mlflow``. An example can be found in :doc:`/tune/examples/mlflow_example`. Note that this currently does not include artifact logging support. For this, you can use the native MLFlow APIs inside your Trainable definition. +Tune also provides a default logger for `MLFlow `_. You can install MLFlow via ``pip install mlflow``. +You can see the :doc:`tutorial here `. -.. autoclass:: ray.tune.logger.MLFLowLogger +WandbLogger +----------- + +Tune also provides a default logger for `Weights & Biases `_. You can install Wandb via ``pip install wandb``. +You can see the :doc:`tutorial here ` .. _logger-interface: diff --git a/doc/source/tune/examples/index.rst b/doc/source/tune/examples/index.rst index 54852d550..4f3594cf6 100644 --- a/doc/source/tune/examples/index.rst +++ b/doc/source/tune/examples/index.rst @@ -88,6 +88,7 @@ Wandb, MLFlow - :ref:`Tutorial ` for using `wandb `__ with Ray Tune - :doc:`/tune/examples/wandb_example`: Example for using `Weights and Biases `__ with Ray Tune. - :doc:`/tune/examples/mlflow_example`: Example for using `MLFlow `__ with Ray Tune. +- :doc:`/tune/examples/mlflow_ptl_example`: Example for using `MLFlow `__ and `Pytorch Lightning `_ with Ray Tune. Tensorflow/Keras ~~~~~~~~~~~~~~~~ diff --git a/doc/source/tune/examples/mlflow_ptl_example.rst b/doc/source/tune/examples/mlflow_ptl_example.rst new file mode 100644 index 000000000..73c07499b --- /dev/null +++ b/doc/source/tune/examples/mlflow_ptl_example.rst @@ -0,0 +1,6 @@ +:orphan: + +mlflow_ptl_example +~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: /../../python/ray/tune/examples/mlflow_ptl.py diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 70fd218d5..3b5757eb3 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -125,6 +125,14 @@ py_test( tags = ["exclusive"], ) +py_test( + name = "test_integration_mlflow", + size = "small", + srcs = ["tests/test_integration_mlflow.py"], + deps = [":tune_lib"], + tags = ["exclusive"] +) + py_test( name = "test_logger", size = "small", @@ -473,15 +481,23 @@ py_test( args = ["--smoke-test"] ) -# Commenting out for now because it is not idempotent -# py_test( -# name = "mlflow_example", -# size = "medium", -# srcs = ["examples/mlflow_example.py"], -# deps = [":tune_lib"], -# tags = ["exclusive", "example"] -# ) +py_test( + name = "mlflow_example", + size = "medium", + srcs = ["examples/mlflow_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"] +) +# Comment out for now until we sort out our dependencies. +#py_test( +# name = "mlflow_ptl", +# size = "medium", +# srcs = ["examples/mlflow_ptl.py"], +# deps = [":tune_lib"], +# tags = ["exclusive", "example", "py37", "pytorch"], +# args = ["--smoke-test"] +#) py_test( name = "mnist_pytorch", size = "small", diff --git a/python/ray/tune/examples/mlflow_example.py b/python/ray/tune/examples/mlflow_example.py index 875c7837b..e0f290b29 100644 --- a/python/ray/tune/examples/mlflow_example.py +++ b/python/ray/tune/examples/mlflow_example.py @@ -1,17 +1,14 @@ #!/usr/bin/env python -"""Simple MLFLow Logger example. - -This uses a simple MLFlow logger. One limitation of this is that there is -no artifact support; to save artifacts with Tune and MLFlow, you will need to -start a MLFlow run inside the Trainable function/class. - +"""Examples using MLFlowLoggerCallback and mlflow_mixin. """ -import mlflow -from mlflow.tracking import MlflowClient +import os +import tempfile import time +import mlflow + from ray import tune -from ray.tune.logger import MLFLowLogger, DEFAULT_LOGGERS +from ray.tune.integration.mlflow import MLFlowLoggerCallback, mlflow_mixin def evaluation_fn(step, width, height): @@ -25,27 +22,84 @@ def easy_objective(config): for step in range(config.get("steps", 100)): # Iterative training function - can be any arbitrary training procedure intermediate_score = evaluation_fn(step, width, height) - # Feed the score back back to Tune. + # Feed the score back to Tune. tune.report(iterations=step, mean_loss=intermediate_score) time.sleep(0.1) -if __name__ == "__main__": - client = MlflowClient() - experiment_id = client.create_experiment("test") - - trials = tune.run( +def tune_function(mlflow_tracking_uri, finish_fast=False): + tune.run( easy_objective, name="mlflow", num_samples=5, - loggers=DEFAULT_LOGGERS + (MLFLowLogger, ), + callbacks=[ + MLFlowLoggerCallback( + tracking_uri=mlflow_tracking_uri, + experiment_name="example", + save_artifact=True) + ], config={ - "logger_config": { - "mlflow_experiment_id": experiment_id, - }, "width": tune.randint(10, 100), "height": tune.randint(0, 100), + "steps": 5 if finish_fast else 100, }) - df = mlflow.search_runs([experiment_id]) - print(df) + +@mlflow_mixin +def decorated_easy_objective(config): + # Hyperparameters + width, height = config["width"], config["height"] + + for step in range(config.get("steps", 100)): + # Iterative training function - can be any arbitrary training procedure + intermediate_score = evaluation_fn(step, width, height) + # Log the metrics to mlflow + mlflow.log_metrics(dict(mean_loss=intermediate_score), step=step) + # Feed the score back to Tune. + tune.report(iterations=step, mean_loss=intermediate_score) + time.sleep(0.1) + + +def tune_decorated(mlflow_tracking_uri, finish_fast=False): + # Set the experiment, or create a new one if does not exist yet. + mlflow.set_tracking_uri(mlflow_tracking_uri) + mlflow.set_experiment(experiment_name="mixin_example") + tune.run( + decorated_easy_objective, + name="mlflow", + num_samples=5, + config={ + "width": tune.randint(10, 100), + "height": tune.randint(0, 100), + "steps": 5 if finish_fast else 100, + "mlflow": { + "experiment_name": "mixin_example", + "tracking_uri": mlflow.get_tracking_uri() + } + }) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + + if args.smoke_test: + mlflow_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns") + else: + mlflow_tracking_uri = None + + tune_function(mlflow_tracking_uri, finish_fast=args.smoke_test) + if not args.smoke_test: + df = mlflow.search_runs( + [mlflow.get_experiment_by_name("example").experiment_id]) + print(df) + + tune_decorated(mlflow_tracking_uri, finish_fast=args.smoke_test) + if not args.smoke_test: + df = mlflow.search_runs( + [mlflow.get_experiment_by_name("mixin_example").experiment_id]) + print(df) diff --git a/python/ray/tune/examples/mlflow_ptl.py b/python/ray/tune/examples/mlflow_ptl.py new file mode 100644 index 000000000..5957a7e34 --- /dev/null +++ b/python/ray/tune/examples/mlflow_ptl.py @@ -0,0 +1,93 @@ +"""An example showing how to use Pytorch Lightning training, Ray Tune +HPO, and MLFlow autologging all together.""" +import os +import tempfile + +import pytorch_lightning as pl +from pl_bolts.datamodules import MNISTDataModule + +import mlflow + +from ray import tune +from ray.tune.integration.mlflow import mlflow_mixin +from ray.tune.integration.pytorch_lightning import TuneReportCallback +from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier + + +@mlflow_mixin +def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): + model = LightningMNISTClassifier(config, data_dir) + dm = MNISTDataModule( + data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]) + metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} + mlflow.pytorch.autolog() + trainer = pl.Trainer( + max_epochs=num_epochs, + gpus=num_gpus, + progress_bar_refresh_rate=0, + callbacks=[TuneReportCallback(metrics, on="validation_end")]) + trainer.fit(model, dm) + + +def tune_mnist(num_samples=10, + num_epochs=10, + gpus_per_trial=0, + tracking_uri=None): + data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") + # Download data + MNISTDataModule(data_dir=data_dir).prepare_data() + + # Set the MLFlow experiment, or create it if it does not exist. + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment("ptl_autologging_test") + + config = { + "layer_1": tune.choice([32, 64, 128]), + "layer_2": tune.choice([64, 128, 256]), + "lr": tune.loguniform(1e-4, 1e-1), + "batch_size": tune.choice([32, 64, 128]), + "mlflow": { + "experiment_name": "ptl_autologging_test", + "tracking_uri": mlflow.get_tracking_uri() + }, + "data_dir": os.path.join(tempfile.gettempdir(), "mnist_data_"), + "num_epochs": num_epochs + } + + trainable = tune.with_parameters( + train_mnist_tune, + data_dir=data_dir, + num_epochs=num_epochs, + num_gpus=gpus_per_trial) + + analysis = tune.run( + trainable, + resources_per_trial={ + "cpu": 1, + "gpu": gpus_per_trial + }, + metric="loss", + mode="min", + config=config, + num_samples=num_samples, + name="tune_mnist") + + print("Best hyperparameters found were: ", analysis.best_config) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + + if args.smoke_test: + tune_mnist( + num_samples=1, + num_epochs=1, + gpus_per_trial=0, + tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns")) + else: + tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 79e0f5da9..9da6b2601 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -509,8 +509,9 @@ class FunctionRunner(Trainable): try: err_tb_str = self._error_queue.get( block=block, timeout=ERROR_FETCH_TIMEOUT) - raise TuneError(("Trial raised an exception. Traceback:\n{}" - .format(err_tb_str))) + raise TuneError( + ("Trial raised an exception. Traceback:\n{}".format(err_tb_str) + )) except queue.Empty: pass @@ -649,6 +650,10 @@ def with_parameters(fn, **kwargs): def _inner(config): inner(config, checkpoint_dir=None) + if hasattr(fn, "__mixins__"): + _inner.__mixins__ = fn.__mixins__ return _inner + if hasattr(fn, "__mixins__"): + inner.__mixins__ = fn.__mixins__ return inner diff --git a/python/ray/tune/integration/mlflow.py b/python/ray/tune/integration/mlflow.py new file mode 100644 index 000000000..1d1e01d48 --- /dev/null +++ b/python/ray/tune/integration/mlflow.py @@ -0,0 +1,366 @@ +import os +from typing import Dict, Callable, Optional +import logging + +from ray.tune.trainable import Trainable +from ray.tune.logger import Logger, LoggerCallback +from ray.tune.result import TRAINING_ITERATION +from ray.tune.trial import Trial + +logger = logging.getLogger(__name__) + + +def _import_mlflow(): + try: + import mlflow + except ImportError: + mlflow = None + return mlflow + + +class MLFlowLoggerCallback(LoggerCallback): + """MLFlow Logger to automatically log Tune results and config to MLFlow. + + MLFlow (https://mlflow.org) Tracking is an open source library for + recording and querying experiments. This Ray Tune ``LoggerCallback`` + sends information (config parameters, training results & metrics, + and artifacts) to MLFlow for automatic experiment tracking. + + Args: + tracking_uri (str): The tracking URI for where to manage experiments + and runs. This can either be a local file path or a remote server. + This arg gets passed directly to mlflow.tracking.MlflowClient + initialization. When using Tune in a multi-node setting, make sure + to set this to a remote server and not a local file path. + registry_uri (str): The registry URI that gets passed directly to + mlflow.tracking.MlflowClient initialization. + experiment_name (str): The experiment name to use for this Tune run. + If None is passed in here, the Logger will automatically then + check the MLFLOW_EXPERIMENT_NAME and then the MLFLOW_EXPERIMENT_ID + environment variables to determine the experiment name. + If the experiment with the name already exists with MlFlow, + it will be reused. If not, a new experiment will be created with + that name. + save_artifact (bool): If set to True, automatically save the entire + contents of the Tune local_dir as an artifact to the + corresponding run in MlFlow. + + Example: + + .. code-block:: python + + from ray.tune.integration.mlflow import MLFlowLoggerCallback + tune.run( + train_fn, + config={ + # define search space here + "parameter_1": tune.choice([1, 2, 3]), + "parameter_2": tune.choice([4, 5, 6]), + }, + callbacks=[MLFlowLoggerCallback( + experiment_name="experiment1", + save_artifact=True)]) + + """ + + def __init__(self, + tracking_uri: Optional[str] = None, + registry_uri: Optional[str] = None, + experiment_name: Optional[str] = None, + save_artifact: bool = False): + + mlflow = _import_mlflow() + if mlflow is None: + raise RuntimeError("MLFlow has not been installed. Please `pip " + "install mlflow` to use the MLFlowLogger.") + + from mlflow.tracking import MlflowClient + self.client = MlflowClient( + tracking_uri=tracking_uri, registry_uri=registry_uri) + + if experiment_name is None: + # If no name is passed in, then check env vars. + # First check if experiment_name env var is set. + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") + + if experiment_name is not None: + # First check if experiment with name exists. + experiment = self.client.get_experiment_by_name(experiment_name) + if experiment is not None: + # If it already exists then get the id. + experiment_id = experiment.experiment_id + else: + # If it does not exist, create the experiment. + experiment_id = self.client.create_experiment( + name=experiment_name) + else: + # No experiment_name is passed in and name env var is not set. + # Now check the experiment id env var. + experiment_id = os.environ.get("MLFLOW_EXPERIMENT_ID") + # Confirm that an experiment with this id exists. + if experiment_id is None or self.client.get_experiment( + experiment_id) is None: + raise ValueError("No experiment_name passed, " + "MLFLOW_EXPERIMENT_NAME env var is not " + "set, and MLFLOW_EXPERIMENT_ID either " + "is not set or does not exist. Please " + "set one of these to use the " + "MLFlowLoggerCallback.") + + # At this point, experiment_id should be set. + self.experiment_id = experiment_id + self.save_artifact = save_artifact + + self._trial_runs = {} + + def log_trial_start(self, trial: "Trial"): + # Create run if not already exists. + if trial not in self._trial_runs: + run = self.client.create_run( + experiment_id=self.experiment_id, + tags={"trial_name": str(trial)}) + self._trial_runs[trial] = run.info.run_id + + run_id = self._trial_runs[trial] + + # Log the config parameters. + config = trial.config + + for key, value in config.items(): + self.client.log_param(run_id=run_id, key=key, value=value) + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + run_id = self._trial_runs[trial] + for key, value in result.items(): + try: + value = float(value) + except (ValueError, TypeError): + logger.debug("Cannot log key {} with value {} since the " + "value cannot be converted to float.".format( + key, value)) + continue + self.client.log_metric( + run_id=run_id, key=key, value=value, step=iteration) + + def log_trial_end(self, trial: "Trial", failed: bool = False): + run_id = self._trial_runs[trial] + + # Log the artifact if set_artifact is set to True. + if self.save_artifact: + self.client.log_artifacts(run_id, local_dir=trial.logdir) + + # Stop the run once trial finishes. + status = "FINISHED" if not failed else "FAILED" + self.client.set_terminated(run_id=run_id, status=status) + + +class MLFlowLogger(Logger): + """MLFlow logger using the deprecated Logger API. + + Requires the experiment configuration to have a MLFlow Experiment ID + or manually set the proper environment variables. + """ + + _experiment_logger_cls = MLFlowLoggerCallback + + def _init(self): + mlflow = _import_mlflow() + logger_config = self.config.pop("logger_config", {}) + tracking_uri = logger_config.get("mlflow_tracking_uri") + registry_uri = logger_config.get("mlflow_registry_uri") + + experiment_id = logger_config.get("mlflow_experiment_id") + if experiment_id is None or not mlflow.get_experiment(experiment_id): + raise ValueError( + "You must provide a valid `mlflow_experiment_id` " + "in your `logger_config` dict in the `config` " + "dict passed to `tune.run`. " + "Are you sure you passed in a `experiment_id` and " + "the experiment exists?") + else: + experiment_name = mlflow.get_experiment(experiment_id).name + + self._trial_experiment_logger = self._experiment_logger_cls( + tracking_uri, registry_uri, experiment_name) + + self._trial_experiment_logger.log_trial_start(self.trial) + + def on_result(self, result: Dict): + self._trial_experiment_logger.log_trial_result( + iteration=result.get(TRAINING_ITERATION), + trial=self.trial, + result=result) + + def close(self): + self._trial_experiment_logger.log_trial_end( + trial=self.trial, failed=False) + del self._trial_experiment_logger + + +def mlflow_mixin(func: Callable): + """mlflow_mixin + + MLFlow (https://mlflow.org) Tracking is an open source library for + recording and querying experiments. This Ray Tune Trainable mixin helps + initialize the MLflow API for use with the ``Trainable`` class or the + ``@mlflow_mixin`` function API. This mixin automatically configures MLFlow + and creates a run in the same process as each Tune trial. You can then + use the mlflow API inside the your training function and it will + automatically get reported to the correct run. + + For basic usage, just prepend your training function with the + ``@mlflow_mixin`` decorator: + + .. code-block:: python + + from ray.tune.integration.mlflow import mlflow_mixin + + @mlflow_mixin + def train_fn(config): + ... + mlflow.log_metric(...) + + You can also use MlFlow's autologging feature if using a training + framework like Pytorch Lightning, XGBoost, etc. More information can be + found here (https://mlflow.org/docs/latest/tracking.html#automatic + -logging). + + .. code-block:: python + + from ray.tune.integration.mlflow import mlflow_mixin + + @mlflow_mixin + def train_fn(config): + mlflow.autolog() + xgboost_results = xgb.train(config, ...) + + The MlFlow configuration is done by passing a ``mlflow`` key to + the ``config`` parameter of ``tune.run()`` (see example below). + + The content of the ``mlflow`` config entry is used to + configure MlFlow. Here are the keys you can pass in to this config entry: + + Args: + tracking_uri (str): The tracking URI for MLflow tracking. If using + Tune in a multi-node setting, make sure to use a remote server for + tracking. + experiment_id (str): The id of an already created MLflow experiment. + All logs from all trials in ``tune.run`` will be reported to this + experiment. If this is not provided or the experiment with this + id does not exist, you must provide an``experiment_name``. This + parameter takes precedence over ``experiment_name``. + experiment_name (str): The name of an already existing MLflow + experiment. All logs from all trials in ``tune.run`` will be + reported to this experiment. If this is not provided, you must + provide a valid ``experiment_id``. + + Example: + + .. code-block:: python + + from ray import tune + from ray.tune.integration.mlflow import mlflow_mixin + + import mlflow + + # Create the MlFlow expriment. + mlflow.create_experiment("my_experiment") + + @mlflow_mixin + def train_fn(config): + for i in range(10): + loss = self.config["a"] + self.config["b"] + mlflow.log_metric(key="loss", value=loss}) + tune.report(loss=loss, done=True) + + tune.run( + train_fn, + config={ + # define search space here + "a": tune.choice([1, 2, 3]), + "b": tune.choice([4, 5, 6]), + # mlflow configuration + "mlflow": { + "experiment_name": "my_experiment", + "tracking_uri": mlflow.get_tracking_uri() + } + }) + """ + if _import_mlflow() is None: + raise RuntimeError("MLFlow has not been installed. Please `pip " + "install mlflow` to use the mlflow_mixin.") + if hasattr(func, "__mixins__"): + func.__mixins__ = func.__mixins__ + (MLFlowTrainableMixin, ) + else: + func.__mixins__ = (MLFlowTrainableMixin, ) + return func + + +class MLFlowTrainableMixin: + def __init__(self, config: Dict, *args, **kwargs): + self._mlflow = _import_mlflow() + + if not isinstance(self, Trainable): + raise ValueError( + "The `MLFlowTrainableMixin` can only be used as a mixin " + "for `tune.Trainable` classes. Please make sure your " + "class inherits from both. For example: " + "`class YourTrainable(MLFlowTrainableMixin)`.") + + super().__init__(config, *args, **kwargs) + _config = config.copy() + try: + mlflow_config = _config.pop("mlflow").copy() + except KeyError as e: + raise ValueError( + "MLFlow mixin specified but no configuration has been passed. " + "Make sure to include a `mlflow` key in your `config` dict " + "containing at least a `tracking_uri` and either " + "`experiment_name` or `experiment_id` specification.") from e + + tracking_uri = mlflow_config.pop("tracking_uri", None) + if tracking_uri is None: + raise ValueError("MLFlow mixin specified but no " + "tracking_uri has been " + "passed in. Make sure to include a `mlflow` " + "key in your `config` dict containing at " + "least a `tracking_uri`") + self._mlflow.set_tracking_uri(tracking_uri) + + # First see if experiment_id is passed in. + experiment_id = mlflow_config.pop("experiment_id", None) + if experiment_id is None or self._mlflow.get_experiment( + experiment_id) is None: + logger.debug("Either no experiment_id is passed in, or the " + "experiment with the given id does not exist. " + "Checking experiment_name") + # Check for name. + experiment_name = mlflow_config.pop("experiment_name", None) + if experiment_name is None: + raise ValueError( + "MLFlow mixin specified but no " + "experiment_name or experiment_id has been " + "passed in. Make sure to include a `mlflow` " + "key in your `config` dict containing at " + "least a `experiment_name` or `experiment_id` " + "specification.") + experiment = self._mlflow.get_experiment_by_name(experiment_name) + if experiment is not None: + # Experiment with this name exists. + experiment_id = experiment.experiment_id + else: + raise ValueError("No experiment with the given " + "name: {} or id: {} currently exists. Make " + "sure to first start the MLFlow experiment " + "before calling tune.run.".format( + experiment_name, experiment_id)) + + self.experiment_id = experiment_id + + run_name = self.trial_name + "_" + self.trial_id + run_name = run_name.replace("/", "_") + self._mlflow.start_run( + experiment_id=self.experiment_id, run_name=run_name) + + def stop(self): + self._mlflow.end_run() diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 82327fffc..4cd2cdee8 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -139,7 +139,10 @@ def wandb_mixin(func: Callable): }) """ - func.__mixins__ = (WandbTrainableMixin, ) + if hasattr(func, "__mixins__"): + func.__mixins__ = func.__mixins__ + (WandbTrainableMixin, ) + else: + func.__mixins__ = (WandbTrainableMixin, ) return func diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index b4ff76bae..3029f2000 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -77,37 +77,6 @@ class NoopLogger(Logger): pass -class MLFLowLogger(Logger): - """MLFlow logger. - - Requires the experiment configuration to have a MLFlow Experiment ID - or manually set the proper environment variables. - - """ - - def _init(self): - logger_config = self.config.get("logger_config", {}) - from mlflow.tracking import MlflowClient - client = MlflowClient( - tracking_uri=logger_config.get("mlflow_tracking_uri"), - registry_uri=logger_config.get("mlflow_registry_uri")) - run = client.create_run(logger_config.get("mlflow_experiment_id")) - self._run_id = run.info.run_id - for key, value in self.config.items(): - client.log_param(self._run_id, key, value) - self.client = client - - def on_result(self, result: Dict): - for key, value in result.items(): - if not isinstance(value, float): - continue - self.client.log_metric( - self._run_id, key, value, step=result.get(TRAINING_ITERATION)) - - def close(self): - self.client.set_terminated(self._run_id) - - class JsonLogger(Logger): """Logs trial results in json format. @@ -734,6 +703,13 @@ class TBXLoggerCallback(LoggerCallback): "in the hyperparameter values.") +# Maintain backwards compatibility. +from ray.tune.integration.mlflow import MLFlowLogger as _MLFlowLogger # noqa: E402, E501 +MLFlowLogger = _MLFlowLogger +# The capital L is a typo, but needs to remain for backwards compatibility. +MLFLowLogger = _MLFlowLogger + + def pretty_print(result): result = result.copy() result.update(config=None) # drop config from pretty print diff --git a/python/ray/tune/tests/test_dependency.py b/python/ray/tune/tests/test_dependency.py index c2e2f2c7c..d626b0e36 100644 --- a/python/ray/tune/tests/test_dependency.py +++ b/python/ray/tune/tests/test_dependency.py @@ -23,3 +23,4 @@ if __name__ == "__main__": } }) assert "ray.rllib" not in sys.modules, "RLlib should not be imported" + assert "mlflow" not in sys.modules, "MLFlow should not be imported" diff --git a/python/ray/tune/tests/test_integration_mlflow.py b/python/ray/tune/tests/test_integration_mlflow.py new file mode 100644 index 000000000..6613e0229 --- /dev/null +++ b/python/ray/tune/tests/test_integration_mlflow.py @@ -0,0 +1,306 @@ +import os +import unittest +from collections import namedtuple +from unittest.mock import patch + +from ray.tune.function_runner import wrap_function +from ray.tune.integration.mlflow import MLFlowLoggerCallback, MLFlowLogger, \ + mlflow_mixin, MLFlowTrainableMixin + + +class MockTrial( + namedtuple("MockTrial", + ["config", "trial_name", "trial_id", "logdir"])): + def __hash__(self): + return hash(self.trial_id) + + def __str__(self): + return self.trial_name + + +MockRunInfo = namedtuple("MockRunInfo", ["run_id"]) + + +class MockRun: + def __init__(self, run_id, tags=None): + self.run_id = run_id + self.tags = tags + self.info = MockRunInfo(run_id) + self.params = [] + self.metrics = [] + self.artifacts = [] + + def log_param(self, key, value): + self.params.append({key: value}) + + def log_metric(self, key, value): + self.metrics.append({key: value}) + + def log_artifact(self, artifact): + self.artifacts.append(artifact) + + def set_terminated(self, status): + self.terminated = True + self.status = status + + +MockExperiment = namedtuple("MockExperiment", ["name", "experiment_id"]) + + +class MockMlflowClient: + def __init__(self, tracking_uri=None, registry_uri=None): + self.tracking_uri = tracking_uri + self.registry_uri = registry_uri + self.experiments = [MockExperiment("existing_experiment", 0)] + self.runs = {0: []} + self.active_run = None + + def set_tracking_uri(self, tracking_uri): + self.tracking_uri = tracking_uri + + def get_experiment_by_name(self, name): + try: + index = self.experiment_names.index(name) + return self.experiments[index] + except ValueError: + return None + + def get_experiment(self, experiment_id): + experiment_id = int(experiment_id) + try: + return self.experiments[experiment_id] + except IndexError: + return None + + def create_experiment(self, name): + experiment_id = len(self.experiments) + self.experiments.append(MockExperiment(name, experiment_id)) + self.runs[experiment_id] = [] + return experiment_id + + def create_run(self, experiment_id, tags=None): + experiment_runs = self.runs[experiment_id] + run_id = (experiment_id, len(experiment_runs)) + run = MockRun(run_id=run_id, tags=tags) + experiment_runs.append(run) + return run + + def start_run(self, experiment_id, run_name): + # Creates new run and sets it as active. + run = self.create_run(experiment_id) + self.active_run = run + + def get_mock_run(self, run_id): + return self.runs[run_id[0]][run_id[1]] + + def log_param(self, run_id, key, value): + run = self.get_mock_run(run_id) + run.log_param(key, value) + + def log_metric(self, run_id, key, value, step): + run = self.get_mock_run(run_id) + run.log_metric(key, value) + + def log_artifacts(self, run_id, local_dir): + run = self.get_mock_run(run_id) + run.log_artifact(local_dir) + + def set_terminated(self, run_id, status): + run = self.get_mock_run(run_id) + run.set_terminated(status) + + @property + def experiment_names(self): + return [e.name for e in self.experiments] + + +def clear_env_vars(): + if "MLFLOW_EXPERIMENT_NAME" in os.environ: + del os.environ["MLFLOW_EXPERIMENT_NAME"] + if "MLFLOW_EXPERIMENT_ID" in os.environ: + del os.environ["MLFLOW_EXPERIMENT_ID"] + + +class MLFlowTest(unittest.TestCase): + @patch("mlflow.tracking.MlflowClient", MockMlflowClient) + def testMlFlowLoggerCallbackConfig(self): + # Explicitly pass in all args. + logger = MLFlowLoggerCallback( + tracking_uri="test1", + registry_uri="test2", + experiment_name="test_exp") + self.assertEqual(logger.client.tracking_uri, "test1") + self.assertEqual(logger.client.registry_uri, "test2") + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment", "test_exp"]) + self.assertEqual(logger.experiment_id, 1) + + # Check if client recognizes already existing experiment. + logger = MLFlowLoggerCallback(experiment_name="existing_experiment") + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment"]) + self.assertEqual(logger.experiment_id, 0) + + # Pass in experiment name as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment", "test_exp"]) + self.assertEqual(logger.experiment_id, 1) + + # Pass in existing experiment name as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment"]) + self.assertEqual(logger.experiment_id, 0) + + # Pass in existing experiment id as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_ID"] = "0" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment"]) + self.assertEqual(logger.experiment_id, "0") + + # Pass in non existing experiment id as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_ID"] = "500" + with self.assertRaises(ValueError): + logger = MLFlowLoggerCallback() + + # Experiment name env var should take precedence over id env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp" + os.environ["MLFLOW_EXPERIMENT_ID"] = "0" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment", "test_exp"]) + self.assertEqual(logger.experiment_id, 1) + + @patch("mlflow.tracking.MlflowClient", MockMlflowClient) + def testMlFlowLoggerLogging(self): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.} + trial = MockTrial(trial_config, "trial1", 0, "artifact") + + logger = MLFlowLoggerCallback( + experiment_name="test1", save_artifact=True) + + # Check if run is created. + logger.on_trial_start(iteration=0, trials=[], trial=trial) + # New run should be created for this trial with correct tag. + mock_run = logger.client.runs[1][0] + self.assertDictEqual(mock_run.tags, {"trial_name": "trial1"}) + self.assertTupleEqual(mock_run.run_id, (1, 0)) + self.assertTupleEqual(logger._trial_runs[trial], mock_run.run_id) + # Params should be logged. + self.assertListEqual(mock_run.params, [{"par1": 4}, {"par2": 9}]) + + # When same trial is started again, new run should not be created. + logger.on_trial_start(iteration=0, trials=[], trial=trial) + self.assertEqual(len(logger.client.runs[1]), 1) + + # Check metrics are logged properly. + result = {"metric1": 0.8, "metric2": 1, "metric3": None} + logger.on_trial_result(0, [], trial, result) + mock_run = logger.client.runs[1][0] + # metric3 is not logged since it cannot be converted to float. + self.assertListEqual(mock_run.metrics, [{ + "metric1": 0.8 + }, { + "metric2": 1.0 + }]) + + # Check that artifact is logged on termination. + logger.on_trial_complete(0, [], trial) + mock_run = logger.client.runs[1][0] + self.assertListEqual(mock_run.artifacts, ["artifact"]) + self.assertTrue(mock_run.terminated) + self.assertEqual(mock_run.status, "FINISHED") + + @patch("mlflow.tracking.MlflowClient", MockMlflowClient) + def testMlFlowLegacyLoggerConfig(self): + mlflow = MockMlflowClient() + with patch.dict("sys.modules", mlflow=mlflow): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.} + trial = MockTrial(trial_config, "trial1", 0, "artifact") + + # No experiment_id is passed in config, should raise an error. + with self.assertRaises(ValueError): + logger = MLFlowLogger(trial_config, "/tmp", trial) + + trial_config.update({ + "logger_config": { + "mlflow_tracking_uri": "test_tracking_uri", + "mlflow_experiment_id": 0 + } + }) + trial = MockTrial(trial_config, "trial2", 1, "artifact") + logger = MLFlowLogger(trial_config, "/tmp", trial) + experiment_logger = logger._trial_experiment_logger + client = experiment_logger.client + self.assertEqual(client.tracking_uri, "test_tracking_uri") + # Check to make sure that a run was created on experiment_id 0. + self.assertEqual(len(client.runs[0]), 1) + mock_run = client.runs[0][0] + self.assertDictEqual(mock_run.tags, {"trial_name": "trial2"}) + self.assertListEqual(mock_run.params, [{"par1": 4}, {"par2": 9}]) + + @patch("ray.tune.integration.mlflow._import_mlflow", + lambda: MockMlflowClient()) + def testMlFlowMixinConfig(self): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.} + + @mlflow_mixin + def train_fn(config): + return 1 + + train_fn.__mixins__ = (MLFlowTrainableMixin, ) + + # No MLFlow config passed in. + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + trial_config.update({"mlflow": {}}) + # No tracking uri or experiment_id/name passed in. + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + # Invalid experiment-id + trial_config["mlflow"].update({"experiment_id": "500"}) + # No tracking uri or experiment_id/name passed in. + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + trial_config["mlflow"].update({ + "tracking_uri": "test_tracking_uri", + "experiment_name": "existing_experiment" + }) + wrapped = wrap_function(train_fn)(trial_config) + client = wrapped._mlflow + self.assertEqual(client.tracking_uri, "test_tracking_uri") + self.assertTupleEqual(client.active_run.run_id, (0, 0)) + + with patch("ray.tune.integration.mlflow._import_mlflow", + lambda: client): + train_fn.__mixins__ = (MLFlowTrainableMixin, ) + wrapped = wrap_function(train_fn)(trial_config) + client = wrapped._mlflow + self.assertTupleEqual(client.active_run.run_id, (0, 1)) + + # Set to experiment that does not already exist. + # New experiment should be created. + trial_config["mlflow"]["experiment_name"] = "new_experiment" + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index 921e0c9ca..8a06be5d0 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -972,4 +972,4 @@ class SearchSpaceTest(unittest.TestCase): if __name__ == "__main__": import pytest import sys - sys.exit(pytest.main(["-v", __file__])) + sys.exit(pytest.main(["-v", __file__] + sys.argv[1:])) diff --git a/python/requirements_tune.txt b/python/requirements_tune.txt index d68d3b3d3..9be5ee118 100644 --- a/python/requirements_tune.txt +++ b/python/requirements_tune.txt @@ -3,6 +3,7 @@ bayesian-optimization ConfigSpace==0.4.10 dragonfly-opt gluoncv +gorilla # Need this because bug in mlflow. Should be fixed in v1.12.2 gym[atari] GPy h5py From 64c97d25d397457a4c873ca4a057eb1fbfd7ecd0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 19 Dec 2020 13:22:24 -0800 Subject: [PATCH 35/88] Enable by default new scheduler (#12735) --- .travis.yml | 44 ------------------- .../tests/test_logical_view_head.py | 2 +- .../java/io/ray/test/DynamicResourceTest.java | 3 +- python/ray/test_utils.py | 2 +- python/ray/tests/BUILD | 11 +---- src/ray/common/ray_config_def.h | 2 +- src/ray/raylet/node_manager.cc | 3 ++ .../raylet/scheduling/cluster_task_manager.cc | 36 +++++++++++++++ .../raylet/scheduling/cluster_task_manager.h | 5 +++ 9 files changed, 50 insertions(+), 58 deletions(-) diff --git a/.travis.yml b/.travis.yml index a8ddb0d46..e2d3822a2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,50 +20,6 @@ before_install: matrix: include: - - os: linux - env: - - PYTHON=3.6 SMALL_AND_LARGE_TESTS=1 RAY_ENABLE_NEW_SCHEDULER=1 - - PYTHONWARNINGS=ignore - - RAY_DEFAULT_BUILD=1 - - RAY_CYTHON_EXAMPLES=1 - - RAY_USE_RANDOM_PORTS=1 - install: - - . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED - before_script: - - . ./ci/travis/ci.sh build - script: - # bazel python tests. This should be run last to keep its logs at the end of travis logs. - - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-new_scheduler_broken python/ray/tests/...; fi - - - os: linux - env: - - PYTHON=3.6 MEDIUM_TESTS_A_TO_J=1 RAY_ENABLE_NEW_SCHEDULER=1 - - PYTHONWARNINGS=ignore - - RAY_DEFAULT_BUILD=1 - - RAY_CYTHON_EXAMPLES=1 - - RAY_USE_RANDOM_PORTS=1 - install: - - . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED - before_script: - - . ./ci/travis/ci.sh build - script: - # bazel python tests for medium size tests. Used for parallelization. - - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j,-new_scheduler_broken python/ray/tests/...; fi - - - os: linux - env: - - PYTHON=3.6 MEDIUM_TESTS_K_TO_Z=1 RAY_ENABLE_NEW_SCHEDULER=1 - - PYTHONWARNINGS=ignore - - RAY_DEFAULT_BUILD=1 - - RAY_CYTHON_EXAMPLES=1 - - RAY_USE_RANDOM_PORTS=1 - install: - - . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED - before_script: - - . ./ci/travis/ci.sh build - script: - # bazel python tests for medium size tests. Used for parallelization. - - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_k_to_z,-new_scheduler_broken python/ray/tests/...; fi - os: linux env: - PYTHON=3.6 SMALL_AND_LARGE_TESTS=1 diff --git a/dashboard/modules/logical_view/tests/test_logical_view_head.py b/dashboard/modules/logical_view/tests/test_logical_view_head.py index 2144918a4..5e4a8bb6c 100644 --- a/dashboard/modules/logical_view/tests/test_logical_view_head.py +++ b/dashboard/modules/logical_view/tests/test_logical_view_head.py @@ -35,7 +35,7 @@ def test_actor_groups(ray_start_with_dashboard): assert wait_until_server_available(webui_url) webui_url = format_web_url(webui_url) - timeout_seconds = 5 + timeout_seconds = 10 start_time = time.time() last_ex = None while True: diff --git a/java/test/src/main/java/io/ray/test/DynamicResourceTest.java b/java/test/src/main/java/io/ray/test/DynamicResourceTest.java index a103d6943..eeaa55b2c 100644 --- a/java/test/src/main/java/io/ray/test/DynamicResourceTest.java +++ b/java/test/src/main/java/io/ray/test/DynamicResourceTest.java @@ -15,7 +15,8 @@ public class DynamicResourceTest extends BaseTest { return "hi"; } - @Test(groups = {"cluster"}) + // Dynamic resources not supported yet. + @Test(groups = {"cluster"}, enabled = false) public void testSetResource() { // Call a task in advance to warm up the cluster to avoid being too slow to start workers. TestUtils.warmUpCluster(); diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index 7f6aaa360..594431e2f 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -442,4 +442,4 @@ def format_web_url(url): def new_scheduler_enabled(): - return os.environ.get("RAY_ENABLE_NEW_SCHEDULER") == "1" + return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1" diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 588710e3a..c5837a158 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -10,6 +10,7 @@ SRCS = [] + select({ py_test_module_list( files = [ +# "test_dynres.py", # dyn res not implemented "test_async.py", "test_actor.py", "test_actor_advanced.py", @@ -40,16 +41,6 @@ py_test_module_list( deps = ["//:ray_lib"], ) -py_test_module_list( - files = [ - "test_dynres.py", # dyn res not implemented - ], - size = "medium", - extra_srcs = SRCS, - tags = ["exclusive", "medium_size_python_tests_a_to_j", "new_scheduler_broken"], - deps = ["//:ray_lib"], -) - py_test_module_list( files = [ "test_memory_limits.py", diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index f5f420463..9f9392bf7 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -113,7 +113,7 @@ RAY_CONFIG(bool, lineage_pinning_enabled, false) /// only to work with direct calls. Once direct calls are becoming /// the default, this scheduler will also become the default. RAY_CONFIG(bool, new_scheduler_enabled, - getenv("RAY_ENABLE_NEW_SCHEDULER") != nullptr && + getenv("RAY_ENABLE_NEW_SCHEDULER") == nullptr || getenv("RAY_ENABLE_NEW_SCHEDULER") == std::string("1")) // The max allowed size in bytes of a return object from direct actor calls. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 9dd1f25b8..e78820d42 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2951,6 +2951,9 @@ void NodeManager::HandlePinObjectIDs(const rpc::PinObjectIDsRequest &request, void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_request, rpc::GetNodeStatsReply *reply, rpc::SendReplyCallback send_reply_callback) { + if (new_scheduler_enabled_) { + cluster_task_manager_->FillPendingActorInfo(reply); + } for (const auto &task : local_queues_.GetTasks(TaskState::INFEASIBLE)) { if (task.GetTaskSpecification().IsActorCreationTask()) { auto infeasible_task = reply->add_infeasible_tasks(); diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 09db70f1b..12715430e 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -8,6 +8,9 @@ namespace ray { namespace raylet { +// The max number of pending actors to report in node stats. +const int kMaxPendingActorsToReport = 20; + ClusterTaskManager::ClusterTaskManager( const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, @@ -330,6 +333,39 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { return false; } +void ClusterTaskManager::FillPendingActorInfo(rpc::GetNodeStatsReply *reply) const { + // Report infeasible actors. + int num_reported = 0; + for (const auto &shapes_it : infeasible_tasks_) { + auto &work_queue = shapes_it.second; + for (const auto &work_it : work_queue) { + Task task = std::get<0>(work_it); + if (task.GetTaskSpecification().IsActorCreationTask()) { + if (num_reported++ > kMaxPendingActorsToReport) { + break; // Protect the raylet from reporting too much data. + } + auto infeasible_task = reply->add_infeasible_tasks(); + infeasible_task->CopyFrom(task.GetTaskSpecification().GetMessage()); + } + } + } + // Report actors blocked on resources. + num_reported = 0; + for (const auto &shapes_it : boost::join(tasks_to_dispatch_, tasks_to_schedule_)) { + auto &work_queue = shapes_it.second; + for (const auto &work_it : work_queue) { + Task task = std::get<0>(work_it); + if (task.GetTaskSpecification().IsActorCreationTask()) { + if (num_reported++ > kMaxPendingActorsToReport) { + break; // Protect the raylet from reporting too much data. + } + auto ready_task = reply->add_infeasible_tasks(); + ready_task->CopyFrom(task.GetTaskSpecification().GetMessage()); + } + } + } +} + void ClusterTaskManager::FillResourceUsage( bool light_report_resource_usage_enabled, std::shared_ptr data) const { diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index aabc0a6fb..3e3ff2e44 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -105,6 +105,11 @@ class ClusterTaskManager { /// false if the task is already running. bool CancelTask(const TaskID &task_id); + /// Populate the list of pending or infeasible actor tasks for node stats. + /// + /// \param Output parameter. + void FillPendingActorInfo(rpc::GetNodeStatsReply *reply) const; + /// Populate the relevant parts of the heartbeat table. This is intended for /// sending raylet <-> gcs heartbeats. In particular, this should fill in /// resource_load and resource_load_by_shape. From 4832b3906611a92b027747c74bec6051dfb3fe72 Mon Sep 17 00:00:00 2001 From: Dmitri Gekhtman <62982571+DmitriGekhtman@users.noreply.github.com> Date: Sat, 19 Dec 2020 16:09:24 -0800 Subject: [PATCH 36/88] Suggest mounting into home. Note non-root user. (#12987) --- python/ray/autoscaler/kubernetes/defaults.yaml | 6 ++++-- python/ray/autoscaler/kubernetes/example-full.yaml | 6 ++++-- python/ray/autoscaler/kubernetes/example-ingress.yaml | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/ray/autoscaler/kubernetes/defaults.yaml b/python/ray/autoscaler/kubernetes/defaults.yaml index beadf1668..31b3301ea 100644 --- a/python/ray/autoscaler/kubernetes/defaults.yaml +++ b/python/ray/autoscaler/kubernetes/defaults.yaml @@ -250,9 +250,11 @@ worker_nodes: # Files or directories to copy to the head and worker nodes. The format is a # dictionary from REMOTE_PATH: LOCAL_PATH, e.g. file_mounts: { -# "/path1/on/remote/machine": "/path1/on/local/machine", -# "/path2/on/remote/machine": "/path2/on/local/machine", +# "~/path1/on/remote/machine": "/path1/on/local/machine", +# "~/path2/on/remote/machine": "/path2/on/local/machine", } +# Note that the container images in this example have a non-root user. +# To avoid permissions issues, we recommend mounting into a subdirectory of home (~). # Files or directories to copy from the head node to the worker nodes. The format is a # list of paths. The same path on the head node will be copied to the worker node. diff --git a/python/ray/autoscaler/kubernetes/example-full.yaml b/python/ray/autoscaler/kubernetes/example-full.yaml index 764048e1c..80ada3b27 100644 --- a/python/ray/autoscaler/kubernetes/example-full.yaml +++ b/python/ray/autoscaler/kubernetes/example-full.yaml @@ -250,9 +250,11 @@ worker_nodes: # Files or directories to copy to the head and worker nodes. The format is a # dictionary from REMOTE_PATH: LOCAL_PATH, e.g. file_mounts: { -# "/path1/on/remote/machine": "/path1/on/local/machine", -# "/path2/on/remote/machine": "/path2/on/local/machine", +# "~/path1/on/remote/machine": "/path1/on/local/machine", +# "~/path2/on/remote/machine": "/path2/on/local/machine", } +# Note that the container images in this example have a non-root user. +# To avoid permissions issues, we recommend mounting into a subdirectory of home (~). # Files or directories to copy from the head node to the worker nodes. The format is a # list of paths. The same path on the head node will be copied to the worker node. diff --git a/python/ray/autoscaler/kubernetes/example-ingress.yaml b/python/ray/autoscaler/kubernetes/example-ingress.yaml index 47afaeff1..b0ded43f2 100644 --- a/python/ray/autoscaler/kubernetes/example-ingress.yaml +++ b/python/ray/autoscaler/kubernetes/example-ingress.yaml @@ -286,9 +286,11 @@ worker_nodes: # Files or directories to copy to the head and worker nodes. The format is a # dictionary from REMOTE_PATH: LOCAL_PATH, e.g. file_mounts: { -# "/path1/on/remote/machine": "/path1/on/local/machine", -# "/path2/on/remote/machine": "/path2/on/local/machine", +# "~/path1/on/remote/machine": "/path1/on/local/machine", +# "~/path2/on/remote/machine": "/path2/on/local/machine", } +# Note that the container images in this example have a non-root user. +# To avoid permissions issues, we recommend mounting into a subdirectory of home (~). # List of commands that will be run before `setup_commands`. If docker is # enabled, these commands will run outside the container and before docker From 51139ed37c5a64a87c916e6a6675385f9b700535 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Sat, 19 Dec 2020 21:46:33 -0800 Subject: [PATCH 37/88] [SGD] Fix process group timeout units (#12477) --- python/ray/util/sgd/torch/constants.py | 2 +- python/ray/util/sgd/torch/torch_trainer.py | 9 +++++++-- python/ray/util/sgd/torch/worker_group.py | 4 ++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/ray/util/sgd/torch/constants.py b/python/ray/util/sgd/torch/constants.py index cf3a7dc8f..1fd37c8d6 100644 --- a/python/ray/util/sgd/torch/constants.py +++ b/python/ray/util/sgd/torch/constants.py @@ -6,7 +6,7 @@ SCHEDULER_STEP = "scheduler_step" SCHEDULER_STEP_BATCH = "batch" SCHEDULER_STEP_EPOCH = "epoch" SCHEDULER_STEP_MANUAL = "manual" -NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10) +NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 1800) VALID_SCHEDULER_STEP = { SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 50f20e39f..79217bcc0 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -105,7 +105,9 @@ class TorchTrainer: wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself. timeout_s (float): Seconds before the torch process group - times out. Useful when machines are unreliable. + times out. Useful when machines are unreliable. If not set, default + to 30 min, which is the same default as + ``torch.init_process_group(...)``. add_dist_sampler (bool): Whether to automatically add a DistributedSampler to all created dataloaders. Only applicable if num_workers > 1. @@ -143,7 +145,7 @@ class TorchTrainer: use_gpu="auto", backend="auto", wrap_ddp=True, - timeout_s=NCCL_TIMEOUT_S, + timeout_s=1800, use_fp16=False, use_tqdm=False, add_dist_sampler=True, @@ -230,6 +232,9 @@ class TorchTrainer: if backend == "auto": backend = "nccl" if use_gpu else "gloo" + if backend == "nccl": + timeout_s = NCCL_TIMEOUT_S + logger.debug(f"Using {backend} as backend.") self.backend = backend self.num_cpus_per_worker = num_cpus_per_worker diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py index f1a82082a..390059c87 100644 --- a/python/ray/util/sgd/torch/worker_group.py +++ b/python/ray/util/sgd/torch/worker_group.py @@ -175,7 +175,7 @@ class RemoteWorkerGroup(WorkerGroupInterface): url=address, world_rank=i + starting_rank, world_size=world_size, - timeout=timedelta(self._timeout_s)) + timeout=timedelta(seconds=self._timeout_s)) for i, worker in enumerate(self.remote_workers) ] return remote_pgroup_setups @@ -467,7 +467,7 @@ class LocalWorkerGroup(WorkerGroupInterface): url=address, world_rank=0, world_size=num_workers, - timeout=timedelta(self._timeout_s)) + timeout=timedelta(seconds=self._timeout_s)) ray.get(remote_pgs) local_node_ip = ray.services.get_node_ip_address() From 4c63917439899f2526907471c82f76a924e2a4c0 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Sun, 20 Dec 2020 00:42:21 -0800 Subject: [PATCH 38/88] [Queue] Add options and shutdown to Queue (#12932) Co-authored-by: Richard Liaw --- python/ray/tests/test_queue.py | 29 +++++++++++++++++++++++- python/ray/util/queue.py | 40 ++++++++++++++++++++++++++++++---- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/python/ray/tests/test_queue.py b/python/ray/tests/test_queue.py index 11b2a9a4c..df87a93f8 100644 --- a/python/ray/tests/test_queue.py +++ b/python/ray/tests/test_queue.py @@ -1,7 +1,9 @@ +import time + import pytest import ray -from ray.exceptions import GetTimeoutError +from ray.exceptions import GetTimeoutError, RayActorError from ray.util.queue import Queue, Empty, Full @@ -184,6 +186,31 @@ def test_qsize(ray_start_regular_shared): assert q.qsize() == size +def test_shutdown(ray_start_regular_shared): + q = Queue() + actor = q.actor + q.shutdown() + assert q.actor is None + with pytest.raises(RayActorError): + ray.get(actor.empty.remote()) + + +def test_custom_resources(ray_start_regular_shared): + current_resources = ray.available_resources() + assert current_resources["CPU"] == 1.0 + + # By default an actor should not reserve any resources. + Queue() + current_resources = ray.available_resources() + assert current_resources["CPU"] == 1.0 + + # Specify resource requirement. The queue should now reserve 1 CPU. + Queue(actor_options={"num_cpus": 1}) + time.sleep(1) + current_resources = ray.available_resources() + assert "CPU" not in current_resources, current_resources + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/queue.py b/python/ray/util/queue.py index 59be761de..627d64a7a 100644 --- a/python/ray/util/queue.py +++ b/python/ray/util/queue.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict from collections.abc import Iterable import ray @@ -28,6 +28,10 @@ class Queue: Args: maxsize (optional, int): maximum size of the queue. If zero, size is unbounded. + actor_options (optional, Dict): Dictionary of options to pass into + the QueueActor during creation. These are directly passed into + QueueActor.options(...). This could be useful if you + need to pass in custom resource requirements, for example. Examples: >>> q = Queue() @@ -36,11 +40,16 @@ class Queue: >>> q.put(item) >>> for item in items: >>> assert item == q.get() + >>> # Create Queue with the underlying actor reserving 1 CPU. + >>> q = Queue(actor_options={"num_cpus": 1}) """ - def __init__(self, maxsize: int = 0) -> None: + def __init__(self, maxsize: int = 0, + actor_options: Optional[Dict] = None) -> None: + actor_options = actor_options or {} self.maxsize = maxsize - self.actor = _QueueActor.remote(self.maxsize) + self.actor = ray.remote(_QueueActor).options(**actor_options).remote( + self.maxsize) def __len__(self) -> int: return self.size() @@ -212,8 +221,31 @@ class Queue: return ray.get(self.actor.get_nowait_batch.remote(num_items)) + def shutdown(self, force: bool = False, grace_period_s: int = 5) -> None: + """Terminates the underlying QueueActor. + + All of the resources reserved by the queue will be released. + + Args: + force (bool): If True, forcefully kill the actor, causing an + immediate failure. If False, graceful + actor termination will be attempted first, before falling back + to a forceful kill. + grace_period_s (int): If force is False, how long in seconds to + wait for graceful termination before falling back to + forceful kill. + """ + if self.actor: + if force: + ray.kill(self.actor, no_restart=True) + else: + done_ref = self.actor.__ray_terminate__.remote() + done, not_done = ray.wait([done_ref], timeout=grace_period_s) + if not_done: + ray.kill(self.actor, no_restart=True) + self.actor = None + -@ray.remote class _QueueActor: def __init__(self, maxsize): self.maxsize = maxsize From ec9ad4a56b9873a41aed306034579e67b51a34c1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 20 Dec 2020 00:43:27 -0800 Subject: [PATCH 39/88] Documentation for Ray debugger stepping (#12845) --- doc/source/ray-debugging.rst | 106 +++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/doc/source/ray-debugging.rst b/doc/source/ray-debugging.rst index f16270f2b..4dba30cf2 100644 --- a/doc/source/ray-debugging.rst +++ b/doc/source/ray-debugging.rst @@ -89,6 +89,112 @@ The Ray debugger supports the `same commands as PDB `_. +Stepping between Ray tasks +-------------------------- + +You can use the debugger to step between Ray tasks. Let's take the +following recursive function as an example: + +.. code-block:: python + + import ray + + ray.init() + + @ray.remote + def fact(n): + if n == 1: + return n + else: + n_id = fact.remote(n - 1) + return n * ray.get(n_id) + + ray.util.pdb.set_trace() + result_ref = fact.remote(5) + result = ray.get(result_ref) + + +After running the program by executing the Python file and calling +``ray debug``, you can select the breakpoint by pressing ``0`` and +enter. This will result in the following output: + +.. code-block:: python + + Enter breakpoint index or press enter to refresh: 0 + > /Users/pcmoritz/tmp/stepping.py(14)() + -> result_ref = fact.remote(5) + (Pdb) + +You can jump into the call with the ``remote`` command in Ray's debugger. +Inside the function, print the value of `n` with ``p(n)``, resulting in +the following output: + +.. code-block:: python + + -> result_ref = fact.remote(5) + (Pdb) remote + *** Connection closed by remote host *** + Continuing pdb session in different process... + --Call-- + > /Users/pcmoritz/tmp/stepping.py(5)fact() + -> @ray.remote + (Pdb) ll + 5 -> @ray.remote + 6 def fact(n): + 7 if n == 1: + 8 return n + 9 else: + 10 n_id = fact.remote(n - 1) + 11 return n * ray.get(n_id) + (Pdb) p(n) + 5 + (Pdb) + +Now step into the next remote call again with +``remote`` and print `n`. You an now either continue recursing into +the function by calling ``remote`` a few more times, or you can jump +to the location where ``ray.get`` is called on the result by using the +``get`` debugger comand. Use ``get`` again to jump back to the original +call site and use ``p(result)`` to print the result: + +.. code-block:: python + + Enter breakpoint index or press enter to refresh: 0 + > /Users/pcmoritz/tmp/stepping.py(14)() + -> result_ref = fact.remote(5) + (Pdb) remote + *** Connection closed by remote host *** + Continuing pdb session in different process... + --Call-- + > /Users/pcmoritz/tmp/stepping.py(5)fact() + -> @ray.remote + (Pdb) p(n) + 5 + (Pdb) remote + *** Connection closed by remote host *** + Continuing pdb session in different process... + --Call-- + > /Users/pcmoritz/tmp/stepping.py(5)fact() + -> @ray.remote + (Pdb) p(n) + 4 + (Pdb) get + *** Connection closed by remote host *** + Continuing pdb session in different process... + --Return-- + > /Users/pcmoritz/tmp/stepping.py(5)fact()->120 + -> @ray.remote + (Pdb) get + *** Connection closed by remote host *** + Continuing pdb session in different process... + --Return-- + > /Users/pcmoritz/tmp/stepping.py(14)()->None + -> result_ref = fact.remote(5) + (Pdb) p(result) + 120 + (Pdb) + + Post Mortem Debugging --------------------- From 038a50af5296855a50764afd9f007cd552ff5c74 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 20 Dec 2020 01:01:09 -0800 Subject: [PATCH 40/88] [tune] skopt fix-extra-import (#12970) Signed-off-by: Richard Liaw --- python/ray/tune/suggest/skopt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/tune/suggest/skopt.py b/python/ray/tune/suggest/skopt.py index c329b94d5..574be4f35 100644 --- a/python/ray/tune/suggest/skopt.py +++ b/python/ray/tune/suggest/skopt.py @@ -144,7 +144,7 @@ class SkOptSearch(Searcher): """ def __init__(self, - optimizer: Optional[sko.optimizer.Optimizer] = None, + optimizer: Optional["sko.optimizer.Optimizer"] = None, space: Union[List[str], Dict[str, Union[Tuple, List]]] = None, metric: Optional[str] = None, mode: Optional[str] = None, @@ -152,9 +152,9 @@ class SkOptSearch(Searcher): evaluated_rewards: Optional[List] = None, max_concurrent: Optional[int] = None, use_early_stopped_trials: Optional[bool] = None): - assert sko is not None, """skopt must be installed! - You can install Skopt with the command: - `pip install scikit-optimize`.""" + assert sko is not None, ("skopt must be installed! " + "You can install Skopt with the command: " + "`pip install scikit-optimize`.") if mode: assert mode in ["min", "max"], "`mode` must be 'min' or 'max'." From 3fab93b61b4b88dfe5a5d931605d18ce756fe1a8 Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Sun, 20 Dec 2020 20:20:07 +0800 Subject: [PATCH 41/88] Fix scheduling_resources comment errors (#12991) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix scheduling_resources comment error * add part code Co-authored-by: 灵洵 --- src/ray/common/task/scheduling_resources.cc | 4 ++-- src/ray/common/task/scheduling_resources.h | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ray/common/task/scheduling_resources.cc b/src/ray/common/task/scheduling_resources.cc index db7be28b6..a90e13667 100644 --- a/src/ray/common/task/scheduling_resources.cc +++ b/src/ray/common/task/scheduling_resources.cc @@ -200,8 +200,8 @@ void ResourceSet::AddResourcesCapacityConstrained(const ResourceSet &other, const FractionalResourceQuantity &to_add_resource_capacity = resource_pair.second; if (total_resource_map.count(to_add_resource_label) != 0) { // If resource exists in total map, add to the local capacity map. - // If the new capacity will be greater the total capacity, set the new capacity to - // total capacity (capping to the total) + // If the new capacity is less than the total capacity, set the new capacity to + // the local capacity (capping to the total). const FractionalResourceQuantity &total_capacity = total_resource_map.at(to_add_resource_label); resource_capacity_[to_add_resource_label] = diff --git a/src/ray/common/task/scheduling_resources.h b/src/ray/common/task/scheduling_resources.h index 41ed07a1c..6093d8f13 100644 --- a/src/ray/common/task/scheduling_resources.h +++ b/src/ray/common/task/scheduling_resources.h @@ -136,8 +136,7 @@ class ResourceSet { /// /// \param other: The other resource set to add. /// \param total_resources: Total resource set which sets upper limits on capacity for - /// each label. \return True if the resource set was added successfully. False - /// otherwise. + /// each label. void AddResourcesCapacityConstrained(const ResourceSet &other, const ResourceSet &total_resources); From 407a3523f367a1e2f124b4bdfb1aef3a2d4340a7 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sun, 20 Dec 2020 15:37:31 +0100 Subject: [PATCH 42/88] [RLlib] eval_workers after restore not generated in Trainer due to unintuitive config handling. (#12844) --- rllib/agents/trainer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index c57ff0b67..8d2f886b5 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -1110,6 +1110,17 @@ class Trainer(Trainable): "`count_steps_by` must be one of [env_steps|agent_steps]! " "Got {}".format(config["multiagent"]["count_steps_by"])) + # If evaluation_num_workers > 0, warn if evaluation_interval is None + # (also set it to 1). + if config["evaluation_num_workers"] > 0 and \ + not config["evaluation_interval"]: + logger.warning( + "You have specified {} evaluation workers, but no evaluation " + "interval! Will set the interval to 1 (each `train()` call). " + "If this is too frequent, set `evaluation_interval` to some " + "larger value.".format(config["evaluation_num_workers"])) + config["evaluation_interval"] = 1 + def _try_recover(self): """Try to identify and remove any unhealthy workers. From d6e243ad4651967ef7650f3a7a05f5bec9d9737b Mon Sep 17 00:00:00 2001 From: Ian Rodney Date: Sun, 20 Dec 2020 11:03:57 -0800 Subject: [PATCH 43/88] [serve] Refactor to full control loop design (#12537) --- python/ray/serve/controller.py | 195 ++++++++++++++++++++++----------- 1 file changed, 132 insertions(+), 63 deletions(-) diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index cba1ec64b..17a543048 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -155,6 +155,24 @@ class ActorStateReconciler: default_factory=lambda: defaultdict(list)) backends_to_remove: List[BackendTag] = field(default_factory=list) + # NOTE(ilr): These are not checkpointed, but will be recreated by + # `_enqueue_pending_scale_changes_loop`. + currently_starting_replicas: Dict[asyncio.Future, Tuple[ + BackendTag, ReplicaTag, ActorHandle]] = field(default_factory=dict) + currently_stopping_replicas: Dict[asyncio.Future, Tuple[ + BackendTag, ReplicaTag]] = field(default_factory=dict) + + def __getstate__(self): + state = self.__dict__.copy() + del state["currently_stopping_replicas"] + del state["currently_starting_replicas"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.currently_stopping_replicas = {} + self.currently_starting_replicas = {} + # TODO(edoakes): consider removing this and just using the names. def http_proxy_handles(self) -> List[ActorHandle]: @@ -174,42 +192,6 @@ class ActorStateReconciler: for replica_dict in self.backend_replicas.values() ])) - async def _start_pending_backend_replicas( - self, current_state: SystemState) -> None: - """Starts the pending backend replicas in self.backend_replicas_to_start. - - Waits for replicas to start up, then removes them from - self.backend_replicas_to_start. - """ - fut_to_replica_info = {} - for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ - items(): - for replica_tag in replicas_to_create: - replica_handle = await self._start_backend_replica( - current_state, backend_tag, replica_tag) - ready_future = replica_handle.ready.remote().as_future() - fut_to_replica_info[ready_future] = (backend_tag, replica_tag, - replica_handle) - - start = time.time() - prev_warning = start - while fut_to_replica_info: - if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S: - prev_warning = time.time() - logger.warning("Waited {:.2f}s for replicas to start up. Make " - "sure there are enough resources to create the " - "replicas.".format(time.time() - start)) - - done, pending = await asyncio.wait( - list(fut_to_replica_info.keys()), timeout=1) - for fut in done: - (backend_tag, replica_tag, - replica_handle) = fut_to_replica_info.pop(fut) - self.backend_replicas[backend_tag][ - replica_tag] = replica_handle - - self.backend_replicas_to_start.clear() - async def _start_backend_replica(self, current_state: SystemState, backend_tag: BackendTag, replica_tag: ReplicaTag) -> ActorHandle: @@ -254,6 +236,7 @@ class ActorStateReconciler: intended replicas. This avoids inconsistencies with starting/stopping a replica and then crashing before writing a checkpoint. """ + logger.debug("Scaling backend '{}' to {} replicas".format( backend_tag, num_replicas)) assert (backend_tag in backends @@ -300,32 +283,102 @@ class ActorStateReconciler: self.backend_replicas_to_stop[backend_tag].append(replica_tag) - async def _stop_pending_backend_replicas(self) -> None: - """Stops the pending backend replicas in self.backend_replicas_to_stop. + async def _enqueue_pending_scale_changes_loop(self, + current_state: SystemState): + for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ + items(): + for replica_tag in replicas_to_create: + replica_handle = await self._start_backend_replica( + current_state, backend_tag, replica_tag) + ready_future = replica_handle.ready.remote().as_future() + self.currently_starting_replicas[ready_future] = ( + backend_tag, replica_tag, replica_handle) - Removes backend_replicas from the http_proxy, kills them, and clears - self.backend_replicas_to_stop. - """ - for backend_tag, replicas_list in self.backend_replicas_to_stop.items( - ): - for replica_tag in replicas_list: - # NOTE(edoakes): the replicas may already be stopped if we - # failed after stopping them but before writing a checkpoint. + for backend_tag, replicas_to_stop in self.backend_replicas_to_stop.\ + items(): + for replica_tag in replicas_to_stop: replica_name = format_actor_name(replica_tag, self.controller_name) - try: - replica = ray.get_actor(replica_name) - except ValueError: - continue - # TODO(edoakes): this logic isn't ideal because there may be - # pending tasks still executing on the replica. However, if we - # use replica.__ray_terminate__, we may send it while the - # replica is being restarted and there's no way to tell if it - # successfully killed the worker or not. - ray.kill(replica, no_restart=True) + async def kill_actor(replica_name_to_use): + # NOTE: the replicas may already be stopped if we failed + # after stopping them but before writing a checkpoint. + try: + replica = ray.get_actor(replica_name_to_use) + except ValueError: + return - self.backend_replicas_to_stop.clear() + # TODO(edoakes): this logic isn't ideal because there may + # be pending tasks still executing on the replica. However, + # if we use replica.__ray_terminate__, we may send it while + # the replica is being restarted and there's no way to tell + # if it successfully killed the worker or not. + ray.kill(replica, no_restart=True) + + self.currently_stopping_replicas[asyncio.ensure_future( + kill_actor(replica_name))] = (backend_tag, replica_tag) + + async def _check_currently_starting_replicas(self) -> bool: + """Returns a boolean specifying if there are more replicas to start""" + in_flight = list() + + if self.currently_starting_replicas: + done, in_flight = await asyncio.wait( + list(self.currently_starting_replicas.keys()), timeout=0) + for fut in done: + (backend_tag, replica_tag, + replica_handle) = self.currently_starting_replicas.pop(fut) + self.backend_replicas[backend_tag][ + replica_tag] = replica_handle + + backend = self.backend_replicas_to_start.get(backend_tag) + if backend: + try: + backend.remove(replica_tag) + except ValueError: + pass + if len(backend) == 0: + del self.backend_replicas_to_start[backend_tag] + return len(in_flight) > 0 + + async def _check_currently_stopping_replicas(self) -> bool: + """Returns a boolean specifying if there are more replicas to stop""" + in_flight = list() + if self.currently_stopping_replicas: + done_stoppping, in_flight = await asyncio.wait( + list(self.currently_stopping_replicas.keys()), timeout=0) + for fut in done_stoppping: + (backend_tag, + replica_tag) = self.currently_stopping_replicas.pop(fut) + + backend = self.backend_replicas_to_stop.get(backend_tag) + + if backend: + try: + backend.remove(replica_tag) + except ValueError: + pass + if len(backend) == 0: + del self.backend_replicas_to_stop[backend_tag] + + return len(in_flight) > 0 + + async def backend_control_loop(self): + start = time.time() + prev_warning = start + need_to_continue = True + while need_to_continue: + if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S: + prev_warning = time.time() + logger.warning("Waited {:.2f}s for replicas to start up. Make " + "sure there are enough resources to create the " + "replicas.".format(time.time() - start)) + + need_to_continue = ( + await self._check_currently_starting_replicas() + or await self._check_currently_stopping_replicas()) + + asyncio.sleep(1) def _start_http_proxies_if_needed(self, http_host: str, http_port: str, http_middlewares: List[Any]) -> None: @@ -415,8 +468,8 @@ class ActorStateReconciler: backend, metadata.autoscaling_config) # Start/stop any pending backend replicas. - await self._start_pending_backend_replicas(current_state) - await self._stop_pending_backend_replicas() + await self._enqueue_pending_scale_changes_loop(current_state) + await self.backend_control_loop() return autoscaling_policies @@ -671,6 +724,12 @@ class ServeController: await self.update_backend_config( backend, BackendConfig(num_replicas=new_num_replicas)) + async def reconcile_current_and_goal_backends(self): + pass + # backends_to_delete = set( + # self.current_state.backends.keys()).difference( + # self.goal_state.backends.keys()) + async def run_control_loop(self) -> None: while True: await self.do_autoscale() @@ -872,6 +931,7 @@ class ServeController: backend_tag, metadata.autoscaling_config) try: + # This call should be to run control loop self.actor_reconciler._scale_backend_replicas( self.current_state.backends, backend_tag, backend_config.num_replicas) @@ -886,8 +946,9 @@ class ServeController: # or pushing the updated config to avoid inconsistent state if we # crash while making the change. self._checkpoint() - await self.actor_reconciler._start_pending_backend_replicas( + await self.actor_reconciler._enqueue_pending_scale_changes_loop( self.current_state) + await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -916,6 +977,10 @@ class ServeController: # Scale its replicas down to 0. This will also remove the backend # from self.current_state.backends and # self.actor_reconciler.backend_replicas. + + self.goal_state.backends[backend_tag] = None + + # This should be a call to the control loop self.actor_reconciler._scale_backend_replicas( self.current_state.backends, backend_tag, 0) @@ -932,7 +997,9 @@ class ServeController: # backend from the routers to avoid inconsistent state if we crash # after pushing the update. self._checkpoint() - await self.actor_reconciler._stop_pending_backend_replicas() + await self.actor_reconciler._enqueue_pending_scale_changes_loop( + self.current_state) + await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() return return_uuid @@ -955,6 +1022,8 @@ class ServeController: backend_info = self.current_state.get_backend(backend_tag) # Scale the replicas with the new configuration. + + # This should be to run the control loop self.actor_reconciler._scale_backend_replicas( self.current_state.backends, backend_tag, backend_config.num_replicas) @@ -970,9 +1039,9 @@ class ServeController: # Inform the routers about change in configuration # (particularly for setting max_batch_size). - await self.actor_reconciler._start_pending_backend_replicas( + await self.actor_reconciler._enqueue_pending_scale_changes_loop( self.current_state) - await self.actor_reconciler._stop_pending_backend_replicas() + await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() self.notify_backend_configs_changed() From bf6577c8f4587c3f4de36658b2ae7a2f21a8bd3c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 20 Dec 2020 12:10:28 -0800 Subject: [PATCH 44/88] Switch debugger to sockets and support unicode (#13004) --- python/ray/scripts/scripts.py | 7 ++----- python/ray/tests/test_ray_debugger.py | 5 +++++ python/ray/util/rpdb.py | 25 +++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 7d72914ee..4a1dd6e28 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -6,7 +6,6 @@ import logging import os import subprocess import sys -from telnetlib import Telnet import time import urllib import urllib.parse @@ -172,8 +171,7 @@ def continue_debug_session(): ray.experimental.internal_kv._internal_kv_del(key) return host, port = session["pdb_address"].split(":") - with Telnet(host, int(port)) as tn: - tn.interact() + ray.util.rpdb.connect_pdb_client(host, int(port)) ray.experimental.internal_kv._internal_kv_del(key) continue_debug_session() return @@ -215,8 +213,7 @@ def debug(address): ray.experimental.internal_kv._internal_kv_get( active_sessions[index])) host, port = session["pdb_address"].split(":") - with Telnet(host, int(port)) as tn: - tn.interact() + ray.util.rpdb.connect_pdb_client(host, int(port)) @cli.command() diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py index adea19684..e271dd3f6 100644 --- a/python/ray/tests/test_ray_debugger.py +++ b/python/ray/tests/test_ray_debugger.py @@ -44,6 +44,7 @@ def test_ray_debugger_commands(shutdown_only): @ray.remote def f(): + """We support unicode too: 🐛""" ray.util.pdb.set_trace() result1 = f.remote() @@ -55,6 +56,10 @@ def test_ray_debugger_commands(shutdown_only): p.expect("Enter breakpoint index or press enter to refresh: ") p.sendline("0") p.expect("-> ray.util.pdb.set_trace()") + p.sendline("ll") + # Cannot use the 🐛 symbol here because pexpect doesn't support + # unicode, but this test also does nicely: + p.expect("unicode") p.sendline("c") p.expect("Enter breakpoint index or press enter to refresh: ") p.sendline("0") diff --git a/python/ray/util/rpdb.py b/python/ray/util/rpdb.py index 134e0b4ec..251dc25a4 100644 --- a/python/ray/util/rpdb.py +++ b/python/ray/util/rpdb.py @@ -8,6 +8,7 @@ import json import logging import os import re +import select import socket import sys import uuid @@ -234,3 +235,27 @@ def set_trace(breakpoint_uuid=None): def post_mortem(): rdb = connect_ray_pdb(None, None, False, None) rdb.post_mortem() + + +def connect_pdb_client(host, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((host, port)) + + while True: + # Get the list of sockets which are readable. + read_sockets, write_sockets, error_sockets = select.select( + [sys.stdin, s], [], []) + + for sock in read_sockets: + if sock == s: + # Incoming message from remote debugger. + data = sock.recv(4096) + if not data: + return + else: + sys.stdout.write(data.decode()) + sys.stdout.flush() + else: + # User entered a message. + msg = sys.stdin.readline() + s.send(msg.encode()) From 7ab9164f1b556b8e3145463a17c694b7949fc31c Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Sun, 20 Dec 2020 14:54:18 -0800 Subject: [PATCH 45/88] [ray_client] Integrate with test_basic, test_basic_2 and test_actor (#12964) --- .travis.yml | 1 + bazel/python.bzl | 6 +- python/ray/experimental/client/__init__.py | 46 ++++--- .../ray/experimental/client/client_pickler.py | 54 ++++++++- python/ray/experimental/client/common.py | 38 +++--- python/ray/experimental/client/dataclient.py | 29 +++-- .../client/server/core_ray_api.py | 5 - .../ray/experimental/client/server/server.py | 112 ++++++++++-------- .../client/server/server_pickler.py | 19 ++- .../client/server/server_stubs.py | 22 ++-- python/ray/experimental/client/worker.py | 20 +++- python/ray/test_utils.py | 4 + python/ray/tests/BUILD | 17 +++ python/ray/tests/client_test_utils.py | 20 ++++ python/ray/tests/conftest.py | 25 +++- python/ray/tests/test_actor.py | 36 ++++-- python/ray/tests/test_basic.py | 32 ++++- python/ray/tests/test_basic_2.py | 27 ++++- python/ray/tests/test_experimental_client.py | 6 +- .../test_experimental_client_terminate.py | 18 +-- src/ray/protobuf/ray_client.proto | 12 +- 21 files changed, 375 insertions(+), 174 deletions(-) create mode 100644 python/ray/tests/client_test_utils.py diff --git a/.travis.yml b/.travis.yml index e2d3822a2..dc133de49 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,6 +46,7 @@ matrix: script: # bazel python tests for medium size tests. Used for parallelization. - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi + - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_TEST_CLIENT_MODE=1 python/ray/tests/...; fi - os: linux env: diff --git a/bazel/python.bzl b/bazel/python.bzl index 523792fdc..5afe570e7 100644 --- a/bazel/python.bzl +++ b/bazel/python.bzl @@ -1,12 +1,14 @@ # py_test_module_list creates a py_test target for each # Python file in `files` -def py_test_module_list(files, size, deps, extra_srcs, **kwargs): +def py_test_module_list(files, size, deps, extra_srcs, name_suffix="", **kwargs): for file in files: # remove .py - name = file[:-3] + name = file[:-3] + name_suffix + main = file native.py_test( name = name, size = size, + main = file, srcs = extra_srcs + [file], **kwargs ) diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index 2af86d023..ed1983528 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -4,6 +4,7 @@ from typing import Optional, List, Tuple from contextlib import contextmanager import logging +import os logger = logging.getLogger(__name__) @@ -43,9 +44,11 @@ def stash_api_for_tests(in_test: bool): is_server = _is_server if in_test: _is_server = True - yield _server_api - if in_test: - _is_server = is_server + try: + yield _server_api + finally: + if in_test: + _is_server = is_server def _set_client_api(val: Optional[APIImpl]): @@ -77,18 +80,7 @@ def reset_api(): def _get_client_api() -> APIImpl: global _client_api - global _server_api - global _is_server - api = None - if _is_server: - api = _server_api - else: - api = _client_api - if api is None: - # We're inside a raylet worker - from ray.experimental.client.server.core_ray_api import CoreRayAPI - return CoreRayAPI() - return api + return _client_api def _get_server_instance(): @@ -124,9 +116,33 @@ class RayAPIStub: global _client_api return _client_api is not None + def init(self, *args, **kwargs): + if _is_client_test_env(): + global _test_server + import ray.experimental.client.server.server as ray_client_server + _test_server, address_info = ray_client_server.init_and_serve( + "localhost:50051", test_mode=True, *args, **kwargs) + self.connect("localhost:50051") + return address_info + else: + raise NotImplementedError( + "Please call ray.connect() in client mode") + ray = RayAPIStub() +_test_server = None + + +def _stop_test_server(*args): + global _test_server + _test_server.stop(*args) + + +def _is_client_test_env() -> bool: + return os.environ.get("RAY_TEST_CLIENT_MODE") == "1" + + # Someday we might add methods in this module so that someone who # tries to `import ray_client as ray` -- as a module, instead of # `from ray_client import ray` -- as the API stub diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py index 73df31c0e..2496199ea 100644 --- a/python/ray/experimental/client/client_pickler.py +++ b/python/ray/experimental/client/client_pickler.py @@ -28,11 +28,15 @@ import sys from typing import NamedTuple from typing import Any +from typing import Optional +from ray.experimental.client import RayAPIStub from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientActorRef +from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.common import ClientRemoteMethod from ray.experimental.client.common import SelfReferenceSentinel import ray.core.generated.ray_client_pb2 as ray_client_pb2 @@ -44,8 +48,11 @@ if sys.version_info < (3, 8): else: import pickle # noqa: F401 -PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str), - ("ref_id", bytes)]) +# NOTE(barakmich): These PickleStubs are really close to +# the data for an exectuion, with no arguments. Combine the two? +PickleStub = NamedTuple("PickleStub", + [("type", str), ("client_id", str), ("ref_id", bytes), + ("name", Optional[str])]) class ClientPickler(cloudpickle.CloudPickler): @@ -54,17 +61,26 @@ class ClientPickler(cloudpickle.CloudPickler): self.client_id = client_id def persistent_id(self, obj): - if isinstance(obj, ClientObjectRef): + if isinstance(obj, RayAPIStub): + return PickleStub( + type="Ray", + client_id=self.client_id, + ref_id=b"", + name=None, + ) + elif isinstance(obj, ClientObjectRef): return PickleStub( type="Object", client_id=self.client_id, ref_id=obj.id, + name=None, ) elif isinstance(obj, ClientActorHandle): return PickleStub( type="Actor", client_id=self.client_id, ref_id=obj._actor_id, + name=None, ) elif isinstance(obj, ClientRemoteFunc): # TODO(barakmich): This is going to have trouble with mutually @@ -77,11 +93,39 @@ class ClientPickler(cloudpickle.CloudPickler): return PickleStub( type="RemoteFuncSelfReference", client_id=self.client_id, - ref_id=b"") + ref_id=b"", + name=None, + ) return PickleStub( type="RemoteFunc", client_id=self.client_id, - ref_id=obj._ref.id) + ref_id=obj._ref.id, + name=None, + ) + elif isinstance(obj, ClientActorClass): + # TODO(barakmich): Mutual recursion, as above. + if obj._ref is None: + obj._ensure_ref() + if type(obj._ref) == SelfReferenceSentinel: + return PickleStub( + type="RemoteActorSelfReference", + client_id=self.client_id, + ref_id=b"", + name=None, + ) + return PickleStub( + type="RemoteActor", + client_id=self.client_id, + ref_id=obj._ref.id, + name=None, + ) + elif isinstance(obj, ClientRemoteMethod): + return PickleStub( + type="RemoteMethod", + client_id=self.client_id, + ref_id=obj.actor_handle.actor_ref.id, + name=obj.method_name, + ) return None diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 74f11c2c2..60901c661 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -1,6 +1,5 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray -from typing import Dict class ClientBaseRef: @@ -8,17 +7,20 @@ class ClientBaseRef: self.id: bytes = id ray.call_retain(id) + def binary(self): + return self.id + + def __eq__(self, other): + return self.id == other.id + def __repr__(self): return "%s(%s)" % ( type(self).__name__, self.id.hex(), ) - def __eq__(self, other): - return self.id == other.id - - def binary(self): - return self.id + def __hash__(self): + return hash(self.id) def __del__(self): if ray.is_connected(): @@ -107,18 +109,13 @@ class ClientActorClass(ClientStub): raise TypeError(f"Remote actor cannot be instantiated directly. " "Use {self._name}.remote() instead") - def __getstate__(self) -> Dict: - state = { - "actor_cls": self.actor_cls, - "_name": self._name, - "_ref": self._ref, - } - return state - - def __setstate__(self, state: Dict) -> None: - self.actor_cls = state["actor_cls"] - self._name = state["_name"] - self._ref = state["_ref"] + def _ensure_ref(self): + if self._ref is None: + # As before, set the state of the reference to be an + # in-progress self reference value, which + # the encoding can detect and handle correctly. + self._ref = SelfReferenceSentinel() + self._ref = ray.put(self.actor_cls) def remote(self, *args, **kwargs) -> "ClientActorHandle": # Actually instantiate the actor @@ -126,7 +123,7 @@ class ClientActorClass(ClientStub): return ClientActorHandle(ClientActorRef(ref_id), self) def __repr__(self): - return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) + return "ClientActorClass(%s, %s)" % (self._name, self._ref) def __getattr__(self, key): if key not in self.__dict__: @@ -134,8 +131,7 @@ class ClientActorClass(ClientStub): raise NotImplementedError("static methods") def _prepare_client_task(self) -> ray_client_pb2.ClientTask: - if self._ref is None: - self._ref = ray.put(self.actor_cls) + self._ensure_ref() task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.ACTOR task.name = self._name diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py index 7e16c015b..c6a745df8 100644 --- a/python/ray/experimental/client/dataclient.py +++ b/python/ray/experimental/client/dataclient.py @@ -53,24 +53,33 @@ class DataClient: resp_stream = stub.Datapath( iter(self.request_queue.get, None), metadata=(("client_id", self._client_id), )) - for response in resp_stream: - if response.req_id == 0: - # This is not being waited for. - logger.debug(f"Got unawaited response {response}") - continue - with self.cv: - self.ready_data[response.req_id] = response - self.cv.notify_all() + try: + for response in resp_stream: + if response.req_id == 0: + # This is not being waited for. + logger.debug(f"Got unawaited response {response}") + continue + with self.cv: + self.ready_data[response.req_id] = response + self.cv.notify_all() + except grpc.RpcError as e: + if grpc.StatusCode.CANCELLED == e.code(): + # Gracefully shutting down + logger.info("Cancelling data channel") + else: + logger.error( + f"Got Error from rpc channel -- shutting down: {e}") + raise e def close(self, close_channel: bool = False) -> None: if self.request_queue is not None: self.request_queue.put(None) self.request_queue = None + if close_channel: + self.channel.close() if self.data_thread is not None: self.data_thread.join() self.data_thread = None - if close_channel: - self.channel.close() def _blocking_send(self, req: ray_client_pb2.DataRequest ) -> ray_client_pb2.DataResponse: diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py index 2d930f352..0762cd0b1 100644 --- a/python/ray/experimental/client/server/core_ray_api.py +++ b/python/ray/experimental/client/server/core_ray_api.py @@ -79,8 +79,3 @@ class RayServerAPI(CoreRayAPI): def __init__(self, server_instance): self.server = server_instance - - def call_remote(self, instance: ClientStub, *args, **kwargs) -> bytes: - task = instance._prepare_client_task() - ticket = self.server.Schedule(task, prepared_args=args) - return ticket.return_id diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 0fd34eda4..2841384d8 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -21,7 +21,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server from ray.experimental.client.server.server_pickler import loads_from_client from ray.experimental.client.server.core_ray_api import RayServerAPI from ray.experimental.client.server.dataservicer import DataServicer -from ray.experimental.client.server.server_stubs import current_func +from ray.experimental.client.server.server_stubs import current_remote logger = logging.getLogger(__name__) @@ -205,82 +205,75 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ready_object_ids=ready_object_ids, remaining_object_ids=remaining_object_ids) - def Schedule(self, task, context=None, - prepared_args=None) -> ray_client_pb2.ClientTaskTicket: + def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: logger.info("schedule: %s %s" % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) with stash_api_for_tests(self._test_mode): - if task.type == ray_client_pb2.ClientTask.FUNCTION: - return self._schedule_function(task, context, prepared_args) - elif task.type == ray_client_pb2.ClientTask.ACTOR: - return self._schedule_actor(task, context, prepared_args) - elif task.type == ray_client_pb2.ClientTask.METHOD: - return self._schedule_method(task, context, prepared_args) - else: - raise NotImplementedError( - "Unimplemented Schedule task type: %s" % - ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)) + try: + if task.type == ray_client_pb2.ClientTask.FUNCTION: + result = self._schedule_function(task, context) + elif task.type == ray_client_pb2.ClientTask.ACTOR: + result = self._schedule_actor(task, context) + elif task.type == ray_client_pb2.ClientTask.METHOD: + result = self._schedule_method(task, context) + else: + raise NotImplementedError( + "Unimplemented Schedule task type: %s" % + ray_client_pb2.ClientTask.RemoteExecType.Name( + task.type)) + result.valid = True + return result + except Exception as e: + logger.error(f"Caught schedule exception {e}") + return ray_client_pb2.ClientTaskTicket( + valid=False, error=cloudpickle.dumps(e)) - def _schedule_method( - self, - task: ray_client_pb2.ClientTask, - context=None, - prepared_args=None) -> ray_client_pb2.ClientTaskTicket: + def _schedule_method(self, task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: actor_handle = self.actor_refs.get(task.payload_id) if actor_handle is None: raise Exception( "Can't run an actor the server doesn't have a handle for") - arglist = self._convert_args(task.args, prepared_args) - output = getattr(actor_handle, task.name).remote(*arglist) + arglist, kwargs = self._convert_args(task.args, task.kwargs) + output = getattr(actor_handle, task.name).remote(*arglist, **kwargs) self.object_refs[task.client_id][output.binary()] = output return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) - def _schedule_actor(self, - task: ray_client_pb2.ClientTask, - context=None, - prepared_args=None) -> ray_client_pb2.ClientTaskTicket: - if task.payload_id not in self.registered_actor_classes: - actor_class_ref = \ - self.object_refs[task.client_id][task.payload_id] - actor_class = ray.get(actor_class_ref) - if not inspect.isclass(actor_class): - raise Exception("Attempting to schedule actor that " - "isn't a class.") - reg_class = ray.remote(actor_class) - self.registered_actor_classes[task.payload_id] = reg_class - remote_class = self.registered_actor_classes[task.payload_id] - arglist = self._convert_args(task.args, prepared_args) - actor = remote_class.remote(*arglist) + def _schedule_actor(self, task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: + remote_class = self.lookup_or_register_actor(task.payload_id, + task.client_id) + + arglist, kwargs = self._convert_args(task.args, task.kwargs) + with current_remote(remote_class): + actor = remote_class.remote(*arglist, **kwargs) self.actor_refs[actor._actor_id.binary()] = actor self.actor_owners[task.client_id].add(actor._actor_id.binary()) return ray_client_pb2.ClientTaskTicket( return_id=actor._actor_id.binary()) - def _schedule_function( - self, - task: ray_client_pb2.ClientTask, - context=None, - prepared_args=None) -> ray_client_pb2.ClientTaskTicket: + def _schedule_function(self, task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: remote_func = self.lookup_or_register_func(task.payload_id, task.client_id) - arglist = self._convert_args(task.args, prepared_args) - # Prepare call if we're in a test - with current_func(remote_func): - output = remote_func.remote(*arglist) + arglist, kwargs = self._convert_args(task.args, task.kwargs) + with current_remote(remote_func): + output = remote_func.remote(*arglist, **kwargs) if output.binary() in self.object_refs[task.client_id]: raise Exception("already found it") self.object_refs[task.client_id][output.binary()] = output return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) - def _convert_args(self, arg_list, prepared_args=None): - if prepared_args is not None: - return prepared_args - out = [] + def _convert_args(self, arg_list, kwarg_map): + argout = [] for arg in arg_list: t = convert_from_arg(arg, self) - out.append(t) - return out + argout.append(t) + kwargout = {} + for k in kwarg_map: + kwargout[k] = convert_from_arg(kwarg_map[k], self) + return argout, kwargout def lookup_or_register_func(self, id: bytes, client_id: str ) -> ray.remote_function.RemoteFunction: @@ -293,6 +286,17 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): self.function_refs[id] = ray.remote(func) return self.function_refs[id] + def lookup_or_register_actor(self, id: bytes, client_id: str): + if id not in self.registered_actor_classes: + actor_class_ref = self.object_refs[client_id][id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule actor that " + "isn't a class.") + reg_class = ray.remote(actor_class) + self.registered_actor_classes[id] = reg_class + return self.registered_actor_classes[id] + def return_exception_in_context(err, context): if context is not None: @@ -319,6 +323,12 @@ def serve(connection_str, test_mode=False): return server +def init_and_serve(connection_str, test_mode=False, *args, **kwargs): + info = ray.init(*args, **kwargs) + server = serve(connection_str, test_mode) + return (server, info) + + if __name__ == "__main__": logging.basicConfig(level="INFO") # TODO(barakmich): Perhaps wrap ray init diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py index ea6bd74d0..c3cd161bd 100644 --- a/python/ray/experimental/client/server/server_pickler.py +++ b/python/ray/experimental/client/server/server_pickler.py @@ -21,7 +21,8 @@ from typing import Any from typing import TYPE_CHECKING from ray.experimental.client.client_pickler import PickleStub -from ray.experimental.client.server.server_stubs import ServerFunctionSentinel +from ray.experimental.client.server.server_stubs import ( + ServerSelfReferenceSentinel) if TYPE_CHECKING: from ray.experimental.client.server.server import RayletServicer @@ -54,6 +55,7 @@ class ServerPickler(cloudpickle.CloudPickler): type="Object", client_id=self.client_id, ref_id=obj_id, + name=None, ) elif isinstance(obj, ray.actor.ActorHandle): actor_id = obj._actor_id.binary() @@ -66,6 +68,7 @@ class ServerPickler(cloudpickle.CloudPickler): type="Actor", client_id=self.client_id, ref_id=obj._actor_id.binary(), + name=None, ) return None @@ -77,15 +80,25 @@ class ClientUnpickler(pickle.Unpickler): def persistent_load(self, pid): assert isinstance(pid, PickleStub) - if pid.type == "Object": + if pid.type == "Ray": + return ray + elif pid.type == "Object": return self.server.object_refs[pid.client_id][pid.ref_id] elif pid.type == "Actor": return self.server.actor_refs[pid.ref_id] elif pid.type == "RemoteFuncSelfReference": - return ServerFunctionSentinel() + return ServerSelfReferenceSentinel() elif pid.type == "RemoteFunc": return self.server.lookup_or_register_func(pid.ref_id, pid.client_id) + elif pid.type == "RemoteActorSelfReference": + return ServerSelfReferenceSentinel() + elif pid.type == "RemoteActor": + return self.server.lookup_or_register_actor( + pid.ref_id, pid.client_id) + elif pid.type == "RemoteMethod": + actor = self.server.actor_refs[pid.ref_id] + return getattr(actor, pid.name) else: raise NotImplementedError("Uncovered client data type") diff --git a/python/ray/experimental/client/server/server_stubs.py b/python/ray/experimental/client/server/server_stubs.py index f55f64f25..9a75747f2 100644 --- a/python/ray/experimental/client/server/server_stubs.py +++ b/python/ray/experimental/client/server/server_stubs.py @@ -1,28 +1,28 @@ from contextlib import contextmanager -_current_remote_func = None +_current_remote_obj = None @contextmanager -def current_func(f): - global _current_remote_func - remote_func = _current_remote_func - _current_remote_func = f +def current_remote(r): + global _current_remote_obj + remote = _current_remote_obj + _current_remote_obj = r try: yield finally: - _current_remote_func = remote_func + _current_remote_obj = remote -class ServerFunctionSentinel: +class ServerSelfReferenceSentinel: def __init__(self): pass def __reduce__(self): - global _current_remote_func - if _current_remote_func is None: - return (ServerFunctionSentinel, tuple()) - return (identity, (_current_remote_func, )) + global _current_remote_obj + if _current_remote_obj is None: + return (ServerSelfReferenceSentinel, tuple()) + return (identity, (_current_remote_obj, )) def identity(x): diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 54ac71711..6bfab6b75 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -109,9 +109,13 @@ class Worker: num_returns: int = 1, timeout: float = None ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: - assert isinstance(object_refs, list) + if not isinstance(object_refs, list): + raise TypeError("wait() expected a list of ClientObjectRef, " + f"got {type(object_refs)}") for ref in object_refs: - assert isinstance(ref, ClientObjectRef) + if not isinstance(ref, ClientObjectRef): + raise TypeError("wait() expected a list of ClientObjectRef, " + f"got list containing {type(ref)}") data = { "object_ids": [object_ref.id for object_ref in object_refs], "num_returns": num_returns, @@ -149,9 +153,16 @@ class Worker: for arg in args: pb_arg = convert_to_arg(arg, self._client_id) task.args.append(pb_arg) + for k, v in kwargs.items(): + task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id)) task.client_id = self._client_id logger.debug("Scheduling %s" % task) - ticket = self.server.Schedule(task, metadata=self.metadata) + try: + ticket = self.server.Schedule(task, metadata=self.metadata) + except grpc.RpcError as e: + raise e.details() + if not ticket.valid: + raise cloudpickle.loads(ticket.error) return ticket.return_id def call_release(self, id: bytes) -> None: @@ -171,10 +182,9 @@ class Worker: self.reference_count[id] += 1 def close(self): - self.data_client.close() + self.data_client.close(close_channel=True) self.server = None if self.channel: - self.channel.close() self.channel = None def terminate_actor(self, actor: ClientActorHandle, diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index 594431e2f..a479903ff 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -443,3 +443,7 @@ def format_web_url(url): def new_scheduler_enabled(): return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1" + + +def client_test_enabled() -> bool: + return os.environ.get("RAY_TEST_CLIENT_MODE") == "1" diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index c5837a158..e88986475 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -153,3 +153,20 @@ py_test( tags = ["exclusive"], deps = ["//:ray_lib"], ) + + +py_test_module_list( + files = [ + "test_actor.py", + "test_basic.py", + "test_basic_2.py", + ], + size = "medium", + extra_srcs = SRCS, + name_suffix = "_client_mode", + # TODO(barakmich): py_test will support env in Bazel 4.0.0... + # Until then, we can use tags. + #env = {"RAY_TEST_CLIENT_MODE": "true"}, + tags = ["exclusive", "client_tests"], + deps = ["//:ray_lib"], +) diff --git a/python/ray/tests/client_test_utils.py b/python/ray/tests/client_test_utils.py new file mode 100644 index 000000000..c7b0081d3 --- /dev/null +++ b/python/ray/tests/client_test_utils.py @@ -0,0 +1,20 @@ +import asyncio + + +def create_remote_signal_actor(ray): + # TODO(barakmich): num_cpus=0 + @ray.remote + class SignalActor: + def __init__(self): + self.ready_event = asyncio.Event() + + def send(self, clear=False): + self.ready_event.set() + if clear: + self.ready_event.clear() + + async def wait(self, should_wait=True): + if should_wait: + await self.ready_event.wait() + + return SignalActor diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 26d2bbdc0..05cd9d8ca 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -9,12 +9,18 @@ import subprocess import ray from ray.cluster_utils import Cluster from ray.test_utils import init_error_pubsub +from ray.test_utils import client_test_enabled +import ray.experimental.client as ray_client @pytest.fixture def shutdown_only(): yield None # The code after the yield will run as teardown code. + if client_test_enabled(): + ray_client.ray.disconnect() + ray_client._stop_test_server(1) + ray_client.reset_api() ray.shutdown() @@ -43,9 +49,17 @@ def _ray_start(**kwargs): init_kwargs = get_default_fixture_ray_kwargs() init_kwargs.update(kwargs) # Start the Ray processes. - address_info = ray.init(**init_kwargs) + if client_test_enabled(): + address_info = ray_client.ray.init(**init_kwargs) + else: + address_info = ray.init(**init_kwargs) + yield address_info # The code after the yield will run as teardown code. + if client_test_enabled(): + ray_client.ray.disconnect() + ray_client._stop_test_server(1) + ray_client.reset_api() ray.shutdown() @@ -130,9 +144,16 @@ def _ray_start_cluster(**kwargs): # We assume driver will connect to the head (first node), # so ray init will be invoked if do_init is true if len(remote_nodes) == 1 and do_init: - ray.init(address=cluster.address) + if client_test_enabled(): + ray_client.ray.init(address=cluster.address) + else: + ray.init(address=cluster.address) yield cluster # The code after the yield will run as teardown code. + if client_test_enabled(): + ray_client.ray.disconnect() + ray_client._stop_test_server(1) + ray_client.reset_api() ray.shutdown() cluster.shutdown() diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 05d53c1b3..1e761762e 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -11,15 +11,21 @@ import sys import tempfile import datetime -import ray -import ray.test_utils -import ray.cluster_utils +from ray.test_utils import client_test_enabled +from ray.test_utils import wait_for_condition +from ray.test_utils import wait_for_pid_to_exit +from ray.tests.client_test_utils import create_remote_signal_actor +if client_test_enabled(): + from ray.experimental.client import ray +else: + import ray # NOTE: We have to import setproctitle after ray because we bundle setproctitle # with ray. -import setproctitle +import setproctitle # noqa +@pytest.mark.skipif(client_test_enabled(), reason="test setup order") def test_caching_actors(shutdown_only): # Test defining actors before ray.init() has been called. @@ -238,6 +244,7 @@ def test_actor_import_counter(ray_start_10_cpus): assert ray.get(g.remote()) == num_remote_functions - 1 +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_actor_method_metadata_cache(ray_start_regular): class Actor(object): pass @@ -257,6 +264,7 @@ def test_actor_method_metadata_cache(ray_start_regular): assert [id(x) for x in list(cache.items())[0]] == cached_data_id +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_actor_class_name(ray_start_regular): @ray.remote class Foo: @@ -556,6 +564,7 @@ def test_actor_static_attributes(ray_start_regular_shared): assert ray.get(t.g.remote()) == 3 +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_decorator_args(ray_start_regular_shared): # This is an invalid way of using the actor decorator. with pytest.raises(Exception): @@ -618,6 +627,8 @@ def test_random_id_generation(ray_start_regular_shared): assert f1._actor_id != f2._actor_id +@pytest.mark.skipif( + client_test_enabled(), reason="differing inheritence structure") def test_actor_inheritance(ray_start_regular_shared): class NonActorBase: def __init__(self): @@ -630,8 +641,7 @@ def test_actor_inheritance(ray_start_regular_shared): pass # Test that you can't instantiate an actor class directly. - with pytest.raises( - Exception, match="Actors cannot be instantiated directly."): + with pytest.raises(Exception, match="cannot be instantiated directly"): ActorBase() # Test that you can't inherit from an actor class. @@ -645,6 +655,7 @@ def test_actor_inheritance(ray_start_regular_shared): pass +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_multiple_return_values(ray_start_regular_shared): @ray.remote class Foo: @@ -678,6 +689,7 @@ def test_multiple_return_values(ray_start_regular_shared): assert ray.get([id3a, id3b, id3c]) == [1, 2, 3] +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_options_num_returns(ray_start_regular_shared): @ray.remote class Foo: @@ -693,6 +705,7 @@ def test_options_num_returns(ray_start_regular_shared): assert ray.get([obj1, obj2]) == [1, 2] +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_options_name(ray_start_regular_shared): @ray.remote class Foo: @@ -734,13 +747,13 @@ def test_actor_deletion(ray_start_regular_shared): a = Actor.remote() pid = ray.get(a.getpid.remote()) a = None - ray.test_utils.wait_for_pid_to_exit(pid) + wait_for_pid_to_exit(pid) actors = [Actor.remote() for _ in range(10)] pids = ray.get([a.getpid.remote() for a in actors]) a = None actors = None - [ray.test_utils.wait_for_pid_to_exit(pid) for pid in pids] + [wait_for_pid_to_exit(pid) for pid in pids] def test_actor_method_deletion(ray_start_regular_shared): @@ -769,7 +782,8 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared): ray.get(signal.wait.remote()) return ray.get(actor.method.remote()) - signal = ray.test_utils.SignalActor.remote() + SignalActor = create_remote_signal_actor(ray) + signal = SignalActor.remote() a = Actor.remote() pid = ray.get(a.getpid.remote()) # Pass the handle to another task that cannot run yet. @@ -780,7 +794,7 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared): # Once the task finishes, the actor process should get killed. ray.get(signal.send.remote()) assert ray.get(x_id) == 1 - ray.test_utils.wait_for_pid_to_exit(pid) + wait_for_pid_to_exit(pid) def test_multiple_actors(ray_start_regular_shared): @@ -921,7 +935,7 @@ def test_atexit_handler(ray_start_regular_shared, exit_condition): if exit_condition == "ray.kill": assert not check_file_written() else: - ray.test_utils.wait_for_condition(check_file_written) + wait_for_condition(check_file_written) if __name__ == "__main__": diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index d0e98972a..709b467e6 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -8,14 +8,23 @@ import time import numpy as np import pytest -import ray import ray.cluster_utils -import ray.test_utils +from ray.test_utils import ( + client_test_enabled, + dicts_equal, + wait_for_pid_to_exit, +) + +if client_test_enabled(): + from ray.experimental.client import ray +else: + import ray logger = logging.getLogger(__name__) # https://github.com/ray-project/ray/issues/6662 +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_ignore_http_proxy(shutdown_only): ray.init(num_cpus=1) os.environ["http_proxy"] = "http://example.com" @@ -29,6 +38,7 @@ def test_ignore_http_proxy(shutdown_only): # https://github.com/ray-project/ray/issues/7263 +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_grpc_message_size(shutdown_only): ray.init(num_cpus=1) @@ -45,12 +55,14 @@ def test_grpc_message_size(shutdown_only): # https://github.com/ray-project/ray/issues/7287 +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_omp_threads_set(shutdown_only): ray.init(num_cpus=1) # Should have been auto set by ray init. assert os.environ["OMP_NUM_THREADS"] == "1" +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_submit_api(shutdown_only): ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) @@ -109,6 +121,7 @@ def test_submit_api(shutdown_only): assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_invalid_arguments(shutdown_only): ray.init(num_cpus=2) @@ -163,6 +176,7 @@ def test_invalid_arguments(shutdown_only): x = 1 +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_many_fractional_resources(shutdown_only): ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2}) @@ -178,7 +192,7 @@ def test_many_fractional_resources(shutdown_only): } if block: ray.get(g.remote()) - return ray.test_utils.dicts_equal(true_resources, accepted_resources) + return dicts_equal(true_resources, accepted_resources) # Check that the resource are assigned correctly. result_ids = [] @@ -230,6 +244,7 @@ def test_many_fractional_resources(shutdown_only): assert False, "Did not get correct available resources." +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_background_tasks_with_max_calls(shutdown_only): ray.init(num_cpus=2) @@ -257,7 +272,7 @@ def test_background_tasks_with_max_calls(shutdown_only): pid, g_id = nested.pop(0) ray.get(g_id) del g_id - ray.test_utils.wait_for_pid_to_exit(pid) + wait_for_pid_to_exit(pid) def test_fair_queueing(shutdown_only): @@ -327,6 +342,7 @@ def test_wait_timing(shutdown_only): assert len(not_ready) == 1 +@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet") def test_function_descriptor(): python_descriptor = ray._raylet.PythonFunctionDescriptor( "module_name", "function_name", "class_name", "function_hash") @@ -344,6 +360,7 @@ def test_function_descriptor(): assert d.get(python_descriptor2) == 123 +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_ray_options(shutdown_only): @ray.remote( num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1}) @@ -371,6 +388,7 @@ def test_ray_options(shutdown_only): assert without_options != with_options +@pytest.mark.skipif(client_test_enabled(), reason="message size") @pytest.mark.parametrize( "ray_start_cluster_head", [{ "num_cpus": 0, @@ -438,8 +456,11 @@ def test_nested_functions(ray_start_shared_local_modes): assert ray.get(factorial.remote(4)) == 24 assert ray.get(factorial.remote(5)) == 120 - # Test remote functions that recursively call each other. +@pytest.mark.skipif( + client_test_enabled(), reason="mutual recursion is a known issue") +def test_mutually_recursive_functions(ray_start_shared_local_modes): + # Test remote functions that recursively call each other. @ray.remote def factorial_even(n): assert n % 2 == 0 @@ -710,6 +731,7 @@ def test_args_stars_after(ray_start_shared_local_modes): ray.get(remote_test_function.remote(local_method, actor_method)) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_object_id_backward_compatibility(ray_start_shared_local_modes): # We've renamed Python's `ObjectID` to `ObjectRef`, and added a type # alias for backward compatibility. diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py index fc6befc7b..25688a6f7 100644 --- a/python/ray/tests/test_basic_2.py +++ b/python/ray/tests/test_basic_2.py @@ -9,10 +9,16 @@ import pytest from unittest.mock import MagicMock, patch -import ray import ray.cluster_utils -import ray.test_utils +from ray.test_utils import client_test_enabled +from ray.tests.client_test_utils import create_remote_signal_actor from ray.exceptions import GetTimeoutError +from ray.exceptions import RayTaskError + +if client_test_enabled(): + from ray.experimental.client import ray +else: + import ray logger = logging.getLogger(__name__) @@ -25,6 +31,8 @@ logger = logging.getLogger(__name__) }], indirect=True) def test_variable_number_of_args(shutdown_only): + ray.init(num_cpus=1) + @ray.remote def varargs_fct1(*a): return " ".join(map(str, a)) @@ -33,8 +41,6 @@ def test_variable_number_of_args(shutdown_only): def varargs_fct2(a, *b): return " ".join(map(str, b)) - ray.init(num_cpus=1) - x = varargs_fct1.remote(0, 1, 2) assert ray.get(x) == "0 1 2" x = varargs_fct2.remote(0, 1, 2) @@ -160,7 +166,7 @@ def test_redefining_remote_functions(shutdown_only): def g(): return nonexistent() - with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"): + with pytest.raises(RayTaskError, match="nonexistent"): ray.get(g.remote()) def nonexistent(): @@ -187,6 +193,7 @@ def test_redefining_remote_functions(shutdown_only): assert ray.get(ray.get(h.remote(i))) == i +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_call_matrix(shutdown_only): ray.init(object_store_memory=1000 * 1024 * 1024) @@ -312,6 +319,7 @@ def test_actor_pass_by_ref_order_optimization(shutdown_only): assert delta < 10, "did not skip slow value" +@pytest.mark.skipif(client_test_enabled(), reason="message size") @pytest.mark.parametrize( "ray_start_cluster", [{ "num_cpus": 1, @@ -332,6 +340,7 @@ def test_call_chain(ray_start_cluster): assert ray.get(x) == 100 +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_system_config_when_connecting(ray_start_cluster): config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200} cluster = ray.cluster_utils.Cluster() @@ -368,7 +377,8 @@ def test_get_multiple(ray_start_regular_shared): def test_get_with_timeout(ray_start_regular_shared): - signal = ray.test_utils.SignalActor.remote() + SignalActor = create_remote_signal_actor(ray) + signal = SignalActor.remote() # Check that get() returns early if object is ready. start = time.time() @@ -438,6 +448,7 @@ def test_inline_arg_memory_corruption(ray_start_regular_shared): ray.get(a.add.remote(f.remote())) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_skip_plasma(ray_start_regular_shared): @ray.remote class Actor: @@ -454,6 +465,8 @@ def test_skip_plasma(ray_start_regular_shared): assert ray.get(obj_ref) == 2 +@pytest.mark.skipif( + client_test_enabled(), reason="internal api and message size") def test_actor_large_objects(ray_start_regular_shared): @ray.remote class Actor: @@ -524,6 +537,7 @@ def test_actor_recursive(ray_start_regular_shared): assert result == [x * 2 for x in range(100)] +@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_actor_concurrent(ray_start_regular_shared): @ray.remote class Batcher: @@ -626,6 +640,7 @@ def test_duplicate_args(ray_start_regular_shared): arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1)) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_get_correct_node_ip(): with patch("ray.worker") as worker_mock: node_mock = MagicMock() diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 1231b6730..e68abb366 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -81,11 +81,11 @@ def test_wait(ray_start_regular_shared): with pytest.raises(Exception): # Reference not in the object store. ray.wait([ClientObjectRef("blabla")]) - with pytest.raises(AssertionError): + with pytest.raises(TypeError): ray.wait("blabla") - with pytest.raises(AssertionError): + with pytest.raises(TypeError): ray.wait(ClientObjectRef("blabla")) - with pytest.raises(AssertionError): + with pytest.raises(TypeError): ray.wait(["blabla"]) diff --git a/python/ray/tests/test_experimental_client_terminate.py b/python/ray/tests/test_experimental_client_terminate.py index e44c617e6..c475a5457 100644 --- a/python/ray/tests/test_experimental_client_terminate.py +++ b/python/ray/tests/test_experimental_client_terminate.py @@ -1,6 +1,6 @@ import pytest -import asyncio from ray.tests.test_experimental_client import ray_start_client_server +from ray.tests.client_test_utils import create_remote_signal_actor from ray.test_utils import wait_for_condition from ray.exceptions import TaskCancelledError from ray.exceptions import RayTaskError @@ -45,21 +45,7 @@ def test_kill_actor_immediately_after_creation(ray_start_regular): @pytest.mark.parametrize("use_force", [True, False]) def test_cancel_chain(ray_start_regular, use_force): with ray_start_client_server() as ray: - - @ray.remote - class SignalActor: - def __init__(self): - self.ready_event = asyncio.Event() - - def send(self, clear=False): - self.ready_event.set() - if clear: - self.ready_event.clear() - - async def wait(self, should_wait=True): - if should_wait: - await self.ready_event.wait() - + SignalActor = create_remote_signal_actor(ray) signaler = SignalActor.remote() @ray.remote diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index cdc3ee8aa..ea4939738 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -50,16 +50,22 @@ message ClientTask { string name = 2; // A reference to the payload. bytes payload_id = 3; - // The parameters to pass to this call. + // Positional parameters to pass to this call. repeated Arg args = 4; + // Keyword parameters to pass to this call. + map kwargs = 5; // The ID of the client namespace associated with the Datapath stream making this // request. - string client_id = 5; + string client_id = 6; } message ClientTaskTicket { + // Was the task successful? + bool valid = 1; // A reference to the returned value from the execution. - bytes return_id = 1; + bytes return_id = 2; + // If unsuccessful, an encoding of the error. + bytes error = 3; } // Delivers data to the server From 11f34f72d832af418de01d92fc7559971ee574c0 Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Mon, 21 Dec 2020 00:54:46 +0200 Subject: [PATCH 46/88] [autoscaler] Do not count head node with min_workers constraint. (#12980) --- python/ray/autoscaler/_private/autoscaler.py | 9 +- .../_private/resource_demand_scheduler.py | 11 +- python/ray/tests/test_autoscaler.py | 97 +++-- .../tests/test_resource_demand_scheduler.py | 406 +++++++++++------- 4 files changed, 330 insertions(+), 193 deletions(-) diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py index 2e55ed151..64167b4cb 100644 --- a/python/ray/autoscaler/_private/autoscaler.py +++ b/python/ray/autoscaler/_private/autoscaler.py @@ -149,7 +149,6 @@ class StandardAutoscaler: def _update(self): now = time.time() - # Throttle autoscaling updates to this interval to avoid exceeding # rate limits on API calls. if now - self.last_update_time < self.update_interval_s: @@ -333,7 +332,7 @@ class StandardAutoscaler: NodeIP, ResourceDict] = \ self.load_metrics.get_static_node_resources_by_ip() - head_node_resources = static_nodes[head_ip] + head_node_resources = static_nodes.get(head_ip, {}) else: head_node_resources = {} @@ -482,11 +481,13 @@ class StandardAutoscaler: # for legacy yamls. self.resource_demand_scheduler.reset_config( self.provider, self.available_node_types, - self.config["max_workers"], upscaling_speed) + self.config["max_workers"], self.config["head_node_type"], + upscaling_speed) else: self.resource_demand_scheduler = ResourceDemandScheduler( self.provider, self.available_node_types, - self.config["max_workers"], upscaling_speed) + self.config["max_workers"], self.config["head_node_type"], + upscaling_speed) except Exception as e: if errors_fatal: diff --git a/python/ray/autoscaler/_private/resource_demand_scheduler.py b/python/ray/autoscaler/_private/resource_demand_scheduler.py index 6bbae1762..f3ec607df 100644 --- a/python/ray/autoscaler/_private/resource_demand_scheduler.py +++ b/python/ray/autoscaler/_private/resource_demand_scheduler.py @@ -47,16 +47,19 @@ class ResourceDemandScheduler: provider: NodeProvider, node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int, + head_node_type: NodeType, upscaling_speed: float = 1) -> None: self.provider = provider self.node_types = copy.deepcopy(node_types) self.max_workers = max_workers + self.head_node_type = head_node_type self.upscaling_speed = upscaling_speed def reset_config(self, provider: NodeProvider, node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int, + head_node_type: NodeType, upscaling_speed: float = 1) -> None: """Updates the class state variables. @@ -89,6 +92,7 @@ class ResourceDemandScheduler: self.provider = provider self.node_types = copy.deepcopy(final_node_types) self.max_workers = max_workers + self.head_node_type = head_node_type self.upscaling_speed = upscaling_speed def is_legacy_yaml(self, @@ -153,7 +157,7 @@ class ResourceDemandScheduler: adjusted_min_workers) = \ _add_min_workers_nodes( node_resources, node_type_counts, self.node_types, - self.max_workers, ensure_min_cluster_size) + self.max_workers, self.head_node_type, ensure_min_cluster_size) # Step 3: add nodes for strict spread groups logger.info(f"Placement group demands: {pending_placement_groups}") @@ -490,7 +494,7 @@ def _add_min_workers_nodes( node_resources: List[ResourceDict], node_type_counts: Dict[NodeType, int], node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int, - ensure_min_cluster_size: List[ResourceDict] + head_node_type: NodeType, ensure_min_cluster_size: List[ResourceDict] ) -> (List[ResourceDict], Dict[NodeType, int], Dict[NodeType, int]): """Updates resource demands to respect the min_workers and request_resources() constraints. @@ -515,6 +519,9 @@ def _add_min_workers_nodes( existing = node_type_counts.get(node_type, 0) target = min( config.get("min_workers", 0), config.get("max_workers", 0)) + if node_type == head_node_type: + # Add 1 to account for head node. + target = target + 1 if existing < target: total_nodes_to_add_dict[node_type] = target - existing node_type_counts[node_type] = target diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 7ef1e9c5b..72f361fe2 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -12,7 +12,6 @@ import sys from jsonschema.exceptions import ValidationError import ray -import ray._private.services as services from ray.autoscaler._private.util import prepare_config, validate_config from ray.autoscaler._private import commands from ray.autoscaler.sdk import get_docker_host_mount_location @@ -559,8 +558,13 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() runner = MockProcessRunner() - runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)]) + runner.respond_to_call("json .Config.Env", ["[]" for i in range(12)]) lm = LoadMetrics() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD + }, 1) + lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {}) autoscaler = StandardAutoscaler( config_path, lm, @@ -569,16 +573,16 @@ class AutoscalingTest(unittest.TestCase): max_failures=0, process_runner=runner, update_interval_s=0) - self.waitForNodes(0) + self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # Update the config to reduce the cluster size new_config = SMALL_CLUSTER.copy() new_config["max_workers"] = 1 self.write_config(new_config) autoscaler.update() - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # Update the config to reduce the cluster size new_config["min_workers"] = 10 @@ -587,12 +591,13 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() # Because one worker already started, the scheduler waits for its # resources to be updated before it launches the remaining min_workers. - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) worker_ip = self.provider.non_terminated_node_ips( tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}, )[0] lm.update(worker_ip, {"CPU": 1}, {"CPU": 1}, {}) autoscaler.update() - self.waitForNodes(10) + self.waitForNodes( + 10, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testInitialWorkers(self): """initial_workers is deprecated, this tests that it is ignored.""" @@ -760,7 +765,10 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() - self.provider.create_node({}, {TAG_RAY_NODE_KIND: "head"}, 1) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: "head", + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD + }, 1) head_ip = self.provider.non_terminated_node_ips( tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0] @@ -964,8 +972,13 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() runner = MockProcessRunner() - runner.respond_to_call("json .Config.Env", ["[]" for i in range(10)]) + runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)]) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD + }, 1) lm = LoadMetrics() + lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {}) autoscaler = StandardAutoscaler( config_path, lm, @@ -975,7 +988,7 @@ class AutoscalingTest(unittest.TestCase): max_failures=0, update_interval_s=0) autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # Write a corrupted config self.write_config("asdf", call_prepare_config=False) @@ -983,7 +996,10 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() time.sleep(0.1) assert autoscaler.pending_launches.value == 0 - assert len(self.provider.non_terminated_nodes({})) == 2 + assert len( + self.provider.non_terminated_nodes({ + TAG_RAY_NODE_KIND: NODE_KIND_WORKER + })) == 2 # New a good config again new_config = SMALL_CLUSTER.copy() @@ -996,7 +1012,8 @@ class AutoscalingTest(unittest.TestCase): # resources to be updated before it launches the remaining min_workers. lm.update(worker_ip, {"CPU": 1}, {"CPU": 1}, {}) autoscaler.update() - self.waitForNodes(10) + self.waitForNodes( + 10, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testMaxFailures(self): config_path = self.write_config(SMALL_CLUSTER) @@ -1113,53 +1130,60 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() lm = LoadMetrics() runner = MockProcessRunner() - runner.respond_to_call("json .Config.Env", ["[]" for i in range(5)]) + runner.respond_to_call("json .Config.Env", ["[]" for i in range(6)]) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD + }, 1) + lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {}) autoscaler = StandardAutoscaler( config_path, lm, max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len( + self.provider.non_terminated_nodes({ + TAG_RAY_NODE_KIND: NODE_KIND_WORKER + })) == 0 autoscaler.update() - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.update() assert autoscaler.pending_launches.value == 0 - assert len(self.provider.non_terminated_nodes({})) == 1 + assert len( + self.provider.non_terminated_nodes({ + TAG_RAY_NODE_KIND: NODE_KIND_WORKER + })) == 1 - # Scales up as nodes are reported as used - local_ip = services.get_node_ip_address() - lm.update( - local_ip, {"CPU": 2}, {"CPU": 0}, {}, - waiting_bundles=2 * [{ - "CPU": 2 - }]) # head autoscaler.update() lm.update( - "172.0.0.0", {"CPU": 2}, {"CPU": 0}, {}, + "172.0.0.1", {"CPU": 2}, {"CPU": 0}, {}, waiting_bundles=2 * [{ "CPU": 2 }]) autoscaler.update() - self.waitForNodes(3) + self.waitForNodes(3, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) lm.update( - "172.0.0.1", {"CPU": 2}, {"CPU": 0}, {}, + "172.0.0.2", {"CPU": 2}, {"CPU": 0}, {}, waiting_bundles=3 * [{ "CPU": 2 }]) autoscaler.update() - self.waitForNodes(5) + self.waitForNodes(5, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # Holds steady when load is removed - lm.update("172.0.0.0", {"CPU": 2}, {"CPU": 2}, {}) lm.update("172.0.0.1", {"CPU": 2}, {"CPU": 2}, {}) + lm.update("172.0.0.2", {"CPU": 2}, {"CPU": 2}, {}) autoscaler.update() assert autoscaler.pending_launches.value == 0 - assert len(self.provider.non_terminated_nodes({})) == 5 + assert len( + self.provider.non_terminated_nodes({ + TAG_RAY_NODE_KIND: NODE_KIND_WORKER + })) == 5 # Scales down as nodes become unused - lm.last_used_time_by_ip["172.0.0.0"] = 0 lm.last_used_time_by_ip["172.0.0.1"] = 0 + lm.last_used_time_by_ip["172.0.0.2"] = 0 autoscaler.update() assert autoscaler.pending_launches.value == 0 @@ -1167,18 +1191,21 @@ class AutoscalingTest(unittest.TestCase): # are not connected and hence we rely more on connected nodes for # min_workers. When the "pending" nodes show up as connected, # then we can terminate the ones connected before. - assert len(self.provider.non_terminated_nodes({})) == 4 - lm.last_used_time_by_ip["172.0.0.2"] = 0 + assert len( + self.provider.non_terminated_nodes({ + TAG_RAY_NODE_KIND: NODE_KIND_WORKER + })) == 4 lm.last_used_time_by_ip["172.0.0.3"] = 0 + lm.last_used_time_by_ip["172.0.0.4"] = 0 autoscaler.update() assert autoscaler.pending_launches.value == 0 # 2 nodes and not 1 because 1 is needed for min_worker and the other 1 # is still not connected. - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # when we connect it, we will see 1 node. - lm.last_used_time_by_ip["172.0.0.4"] = 0 + lm.last_used_time_by_ip["172.0.0.5"] = 0 autoscaler.update() - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testTargetUtilizationFraction(self): config = SMALL_CLUSTER.copy() diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index 50d899af0..067b5f53d 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -256,19 +256,19 @@ def test_add_min_workers_nodes(): } assert _add_min_workers_nodes([], {}, - types, None, None) == \ + types, None, None, None) == \ ([{"CPU": 2}]*50+[{"GPU": 1}]*99999, {"m2.large": 50, "gpu": 99999}, {"m2.large": 50, "gpu": 99999}) assert _add_min_workers_nodes([{"CPU": 2}]*5, {"m2.large": 5}, - types, None, None) == \ + types, None, None, None) == \ ([{"CPU": 2}]*50+[{"GPU": 1}]*99999, {"m2.large": 50, "gpu": 99999}, {"m2.large": 45, "gpu": 99999}) assert _add_min_workers_nodes([{"CPU": 2}]*60, {"m2.large": 60}, - types, None, None) == \ + types, None, None, None) == \ ([{"CPU": 2}]*60+[{"GPU": 1}]*99999, {"m2.large": 60, "gpu": 99999}, {"gpu": 99999}) @@ -279,7 +279,7 @@ def test_add_min_workers_nodes(): }] * 99999, { "m2.large": 50, "gpu": 99999 - }, types, None, None) == ([{ + }, types, None, None, None) == ([{ "CPU": 2 }] * 50 + [{ "GPU": 1 @@ -289,11 +289,11 @@ def test_add_min_workers_nodes(): }, {}) assert _add_min_workers_nodes([], {}, {"gpubla": types["gpubla"]}, None, - None) == ([], {}, {}) + None, None) == ([], {}, {}) types["gpubla"]["max_workers"] = 10 assert _add_min_workers_nodes([], {}, {"gpubla": types["gpubla"]}, None, - None) == ([{ + None, None) == ([{ "GPU": 1 }] * 10, { "gpubla": 10 @@ -306,9 +306,13 @@ def test_get_nodes_to_launch_with_min_workers(): provider = MockProvider() new_types = copy.deepcopy(TYPES_A) new_types["p2.8xlarge"]["min_workers"] = 2 - scheduler = ResourceDemandScheduler(provider, new_types, 3) + scheduler = ResourceDemandScheduler( + provider, new_types, 3, head_node_type="p2.8xlarge") - provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 1) + provider.create_node({}, { + TAG_RAY_USER_NODE_TYPE: "p2.8xlarge", + TAG_RAY_NODE_KIND: NODE_KIND_HEAD + }, 1) nodes = provider.non_terminated_nodes({}) @@ -318,15 +322,19 @@ def test_get_nodes_to_launch_with_min_workers(): to_launch = scheduler.get_nodes_to_launch(nodes, {}, [{ "GPU": 8 }], utilizations, [], {}) - assert to_launch == {"p2.8xlarge": 1} + assert to_launch == {"p2.8xlarge": 2} def test_get_nodes_to_launch_with_min_workers_and_bin_packing(): provider = MockProvider() new_types = copy.deepcopy(TYPES_A) new_types["p2.8xlarge"]["min_workers"] = 2 - scheduler = ResourceDemandScheduler(provider, new_types, 10) - + scheduler = ResourceDemandScheduler( + provider, new_types, 10, head_node_type="p2.8xlarge") + provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "p2.8xlarge" + }, 1) provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 1) nodes = provider.non_terminated_nodes({}) @@ -336,17 +344,18 @@ def test_get_nodes_to_launch_with_min_workers_and_bin_packing(): utilizations = {ip: {"GPU": 8} for ip in ips} # 1 more on the way pending_nodes = {"p2.8xlarge": 1} - # requires 2 p2.8xls (only 2 are in cluster/pending) and 1 p2.xlarge + # requires 3 p2.8xls (only 2 are in cluster/pending) and 1 p2.xlarge demands = [{"GPU": 8}] * (len(utilizations) + 1) + [{"GPU": 1}] to_launch = scheduler.get_nodes_to_launch(nodes, pending_nodes, demands, utilizations, [], {}) assert to_launch == {"p2.xlarge": 1} - # 3 min_workers of p2.8xlarge covers the 2 p2.8xlarge + 1 p2.xlarge demand. - # 2 p2.8xlarge are running/pending. So we need 1 more p2.8xlarge only to - # meet the min_workers constraint and the demand. + # 3 min_workers + 1 head of p2.8xlarge covers the 3 p2.8xlarge + 1 + # p2.xlarge demand. 3 p2.8xlarge are running/pending. So we need 1 more + # p2.8xlarge only tomeet the min_workers constraint and the demand. new_types["p2.8xlarge"]["min_workers"] = 3 - scheduler = ResourceDemandScheduler(provider, new_types, 10) + scheduler = ResourceDemandScheduler( + provider, new_types, 10, head_node_type="p2.8xlarge") to_launch = scheduler.get_nodes_to_launch(nodes, pending_nodes, demands, utilizations, [], {}) # Make sure it does not return [("p2.8xlarge", 1), ("p2.xlarge", 1)] @@ -355,7 +364,8 @@ def test_get_nodes_to_launch_with_min_workers_and_bin_packing(): def test_get_nodes_to_launch_limits(): provider = MockProvider() - scheduler = ResourceDemandScheduler(provider, TYPES_A, 3) + scheduler = ResourceDemandScheduler( + provider, TYPES_A, 3, head_node_type="p2.8xlarge") provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 2) @@ -372,7 +382,8 @@ def test_get_nodes_to_launch_limits(): def test_calculate_node_resources(): provider = MockProvider() - scheduler = ResourceDemandScheduler(provider, TYPES_A, 10) + scheduler = ResourceDemandScheduler( + provider, TYPES_A, 10, head_node_type="p2.8xlarge") provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 2) @@ -403,7 +414,8 @@ def test_request_resources_existing_usage(): "max_workers": 40, }, } - scheduler = ResourceDemandScheduler(provider, TYPES, max_workers=100) + scheduler = ResourceDemandScheduler( + provider, TYPES, max_workers=100, head_node_type="empty_node") # 5 nodes with 32 CPU and 8 GPU each provider.create_node({}, { @@ -475,7 +487,10 @@ def test_backlog_queue_impact_on_binpacking_time(): num_available_nodes, time_to_assert, demand_request_shape): provider = MockProvider() scheduler = ResourceDemandScheduler( - provider, new_types, max_workers=10000) + provider, + new_types, + max_workers=10000, + head_node_type="m4.16xlarge") provider.create_node({}, { TAG_RAY_USER_NODE_TYPE: "m4.16xlarge", @@ -574,7 +589,8 @@ def test_backlog_queue_impact_on_binpacking_time(): class TestPlacementGroupScaling: def test_strategies(self): provider = MockProvider() - scheduler = ResourceDemandScheduler(provider, TYPES_A, 10) + scheduler = ResourceDemandScheduler( + provider, TYPES_A, 10, head_node_type="p2.8xlarge") provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 2) # At this point our cluster has 2 p2.8xlarge instances (16 GPUs) and is @@ -616,7 +632,8 @@ class TestPlacementGroupScaling: def test_many_strict_spreads(self): provider = MockProvider() - scheduler = ResourceDemandScheduler(provider, TYPES_A, 10) + scheduler = ResourceDemandScheduler( + provider, TYPES_A, 10, head_node_type="p2.8xlarge") provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 2) # At this point our cluster has 2 p2.8xlarge instances (16 GPUs) and is @@ -640,7 +657,8 @@ class TestPlacementGroupScaling: def test_packing(self): provider = MockProvider() - scheduler = ResourceDemandScheduler(provider, TYPES_A, 10) + scheduler = ResourceDemandScheduler( + provider, TYPES_A, 10, head_node_type="p2.8xlarge") provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 1) # At this point our cluster has 1 p2.8xlarge instances (8 GPUs) and is @@ -668,7 +686,8 @@ def test_get_concurrent_resource_demand_to_launch(): node_types["m4.large"]["min_workers"] = 2 node_types["m4.large"]["max_workers"] = 100 provider = MockProvider() - scheduler = ResourceDemandScheduler(provider, node_types, 200) + scheduler = ResourceDemandScheduler( + provider, node_types, 200, head_node_type="empty_node") # Sanity check. assert len(provider.non_terminated_nodes({})) == 0 @@ -776,7 +795,8 @@ def test_get_nodes_to_launch_max_launch_concurrency(): new_types["p2.8xlarge"]["min_workers"] = 4 new_types["p2.8xlarge"]["max_workers"] = 40 - scheduler = ResourceDemandScheduler(provider, new_types, 30) + scheduler = ResourceDemandScheduler( + provider, new_types, 30, head_node_type=None) to_launch = scheduler.get_nodes_to_launch([], {}, [], {}, [], {}) # Respects min_workers despite concurrency limitation. @@ -847,7 +867,10 @@ def test_handle_legacy_cluster_config_yaml(): cluster_config = rewrite_legacy_yaml_to_available_node_types( cluster_config) scheduler = ResourceDemandScheduler( - provider, cluster_config["available_node_types"], 0) + provider, + cluster_config["available_node_types"], + 0, + head_node_type=NODE_TYPE_LEGACY_HEAD) provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD @@ -1084,17 +1107,46 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + LoadMetrics("172.0.0.0"), max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(3) autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(3) + + def testScaleUpMinSanityWithHeadNode(self): + """Make sure when min_workers is used with head node it does not count + head_node in min_workers.""" + config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config["available_node_types"]["empty_node"]["min_workers"] = 2 + config["available_node_types"]["empty_node"]["max_workers"] = 2 + config_path = self.write_config(config) + self.provider = MockProvider() + runner = MockProcessRunner() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + autoscaler = StandardAutoscaler( + config_path, + LoadMetrics("172.0.0.0"), + max_failures=0, + process_runner=runner, + update_interval_s=0) + assert len(self.provider.non_terminated_nodes({})) == 1 + autoscaler.update() + self.waitForNodes(3) + autoscaler.update() + self.waitForNodes(3) def testPlacementGroup(self): # Note this is mostly an integration test. See @@ -1102,21 +1154,23 @@ class AutoscalingTest(unittest.TestCase): config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["min_workers"] = 0 config["max_workers"] = 999 + config["head_node_type"] = "m4.4xlarge" config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() - lm = LoadMetrics() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: "head", + TAG_RAY_USER_NODE_TYPE: "m4.4xlarge" + }, 1) + head_ip = self.provider.non_terminated_node_ips({})[0] + lm = LoadMetrics(head_ip) autoscaler = StandardAutoscaler( config_path, lm, max_failures=0, process_runner=runner, update_interval_s=0) - self.provider.create_node({}, { - TAG_RAY_NODE_KIND: "head", - TAG_RAY_USER_NODE_TYPE: "m4.4xlarge" - }, 1) - head_ip = self.provider.non_terminated_node_ips({})[0] + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) @@ -1172,20 +1226,24 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() - lm = LoadMetrics() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + lm = LoadMetrics("172.0.0.0") autoscaler = StandardAutoscaler( config_path, lm, max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() - self.waitForNodes(2) - assert len(self.provider.mock_nodes) == 2 + self.waitForNodes(3) + assert len(self.provider.mock_nodes) == 3 assert { - self.provider.mock_nodes[0].node_type, - self.provider.mock_nodes[1].node_type + self.provider.mock_nodes[1].node_type, + self.provider.mock_nodes[2].node_type } == {"p2.8xlarge", "m4.large"} self.provider.create_node({}, { TAG_RAY_USER_NODE_TYPE: "p2.8xlarge", @@ -1195,16 +1253,17 @@ class AutoscalingTest(unittest.TestCase): TAG_RAY_USER_NODE_TYPE: "m4.16xlarge", TAG_RAY_NODE_KIND: NODE_KIND_WORKER }, 2) - assert len(self.provider.non_terminated_nodes({})) == 6 + assert len(self.provider.non_terminated_nodes({})) == 7 # Make sure that after idle_timeout_minutes we don't kill idle # min workers. for node_id in self.provider.non_terminated_nodes({}): lm.last_used_time_by_ip[self.provider.internal_ip(node_id)] = -60 autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(3) cnt = 0 - for id in self.provider.mock_nodes: + # [1:] skips the head node. + for id in list(self.provider.mock_nodes.keys())[1:]: if self.provider.mock_nodes[id].state == "running" or \ self.provider.mock_nodes[id].state == "pending": assert self.provider.mock_nodes[id].node_type in { @@ -1218,6 +1277,7 @@ class AutoscalingTest(unittest.TestCase): # Commenting out this line causes the test case to fail?!?! config["min_workers"] = 0 config["target_utilization_fraction"] = 1.0 + config["head_node_type"] = "p2.xlarge" config_path = self.write_config(config) self.provider = MockProvider() self.provider.create_node({}, { @@ -1295,28 +1355,32 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + LoadMetrics("172.0.0.0"), max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 - autoscaler.update() - self.waitForNodes(0) - autoscaler.request_resources([{"CPU": 1}]) + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - assert self.provider.mock_nodes[0].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.request_resources([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) - assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" + assert self.provider.mock_nodes[1].node_type == "m4.large" + autoscaler.request_resources([{"GPU": 8}]) + autoscaler.update() + self.waitForNodes(3) + assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" autoscaler.request_resources([{"CPU": 32}] * 4) autoscaler.update() - self.waitForNodes(4) - assert self.provider.mock_nodes[2].node_type == "m4.16xlarge" + self.waitForNodes(5) assert self.provider.mock_nodes[3].node_type == "m4.16xlarge" + assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" def testResourcePassing(self): config = MULTI_WORKER_CLUSTER.copy() @@ -1326,23 +1390,27 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() runner = MockProcessRunner() runner.respond_to_call("json .Config.Env", ["[]" for i in range(2)]) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + LoadMetrics("172.0.0.0"), max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() - self.waitForNodes(0) + self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([{"CPU": 1}]) autoscaler.update() - self.waitForNodes(1) - assert self.provider.mock_nodes[0].node_type == "m4.large" + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) + assert self.provider.mock_nodes[1].node_type == "m4.large" autoscaler.request_resources([{"GPU": 8}]) autoscaler.update() - self.waitForNodes(2) - assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) + assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" # TODO (Alex): Autoscaler creates the node during one update then # starts the updater in the enxt update. The sleep is largely @@ -1353,11 +1421,11 @@ class AutoscalingTest(unittest.TestCase): # These checks are done separately because we have no guarantees on the # order the dict is serialized in. - runner.assert_has_call("172.0.0.0", "RAY_OVERRIDE_RESOURCES=") - runner.assert_has_call("172.0.0.0", "\"CPU\":2") runner.assert_has_call("172.0.0.1", "RAY_OVERRIDE_RESOURCES=") - runner.assert_has_call("172.0.0.1", "\"CPU\":32") - runner.assert_has_call("172.0.0.1", "\"GPU\":8") + runner.assert_has_call("172.0.0.1", "\"CPU\":2") + runner.assert_has_call("172.0.0.2", "RAY_OVERRIDE_RESOURCES=") + runner.assert_has_call("172.0.0.2", "\"CPU\":32") + runner.assert_has_call("172.0.0.2", "\"GPU\":8") def testScaleUpLoadMetrics(self): config = MULTI_WORKER_CLUSTER.copy() @@ -1366,16 +1434,20 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() - lm = LoadMetrics() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + lm = LoadMetrics("172.0.0.0") autoscaler = StandardAutoscaler( config_path, lm, max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() - self.waitForNodes(0) + self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.update() lm.update( "1.2.3.4", {}, {}, {}, @@ -1386,10 +1458,10 @@ class AutoscalingTest(unittest.TestCase): "CPU": 16 }]) autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) nodes = { - self.provider.mock_nodes[0].node_type, - self.provider.mock_nodes[1].node_type + self.provider.mock_nodes[1].node_type, + self.provider.mock_nodes[2].node_type } assert nodes == {"p2.xlarge", "m4.4xlarge"} @@ -1407,40 +1479,46 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() - runner.respond_to_call("json .Config.Env", ["[]" for i in range(3)]) + runner.respond_to_call("json .Config.Env", ["[]" for i in range(4)]) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + lm = LoadMetrics("172.0.0.0") + lm.update("172.0.0.0", {"CPU": 0}, {"CPU": 0}, {}) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + lm, max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 - autoscaler.update() - self.waitForNodes(0) - autoscaler.request_resources([{"CPU": 1}]) + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - assert self.provider.mock_nodes[0].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.request_resources([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) - assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" - autoscaler.request_resources([{"GPU": 1}] * 9) + assert self.provider.mock_nodes[1].node_type == "m4.large" + autoscaler.request_resources([{"GPU": 8}]) autoscaler.update() self.waitForNodes(3) - assert self.provider.mock_nodes[2].node_type == "p2.xlarge" + assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" + autoscaler.request_resources([{"GPU": 1}] * 9) + autoscaler.update() + self.waitForNodes(4) + assert self.provider.mock_nodes[3].node_type == "p2.xlarge" autoscaler.update() sleep(0.1) - runner.assert_has_call(self.provider.mock_nodes[1].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, "new_worker_setup_command") - runner.assert_not_has_call(self.provider.mock_nodes[1].internal_ip, - "setup_cmd") - runner.assert_not_has_call(self.provider.mock_nodes[1].internal_ip, - "worker_setup_cmd") - runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, - "new_worker_initialization_cmd") runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip, + "setup_cmd") + runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip, + "worker_setup_cmd") + runner.assert_has_call(self.provider.mock_nodes[3].internal_ip, + "new_worker_initialization_cmd") + runner.assert_not_has_call(self.provider.mock_nodes[3].internal_ip, "init_cmd") def testDockerWorkers(self): @@ -1461,28 +1539,32 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() - runner.respond_to_call("json .Config.Env", ["[]" for i in range(4)]) + runner.respond_to_call("json .Config.Env", ["[]" for i in range(5)]) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + LoadMetrics("172.0.0.0"), max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 - autoscaler.update() - self.waitForNodes(0) - autoscaler.request_resources([{"CPU": 1}]) + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - assert self.provider.mock_nodes[0].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.request_resources([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) - assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" - autoscaler.request_resources([{"GPU": 1}] * 9) + assert self.provider.mock_nodes[1].node_type == "m4.large" + autoscaler.request_resources([{"GPU": 8}]) autoscaler.update() self.waitForNodes(3) - assert self.provider.mock_nodes[2].node_type == "p2.xlarge" + assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" + autoscaler.request_resources([{"GPU": 1}] * 9) + autoscaler.update() + self.waitForNodes(4) + assert self.provider.mock_nodes[3].node_type == "p2.xlarge" autoscaler.update() # Fill up m4, p2.8, p2 and request 2 more CPUs autoscaler.request_resources([{ @@ -1495,33 +1577,33 @@ class AutoscalingTest(unittest.TestCase): "CPU": 2 }]) autoscaler.update() - self.waitForNodes(4) - assert self.provider.mock_nodes[3].node_type == "m4.16xlarge" + self.waitForNodes(5) + assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" autoscaler.update() sleep(0.1) - runner.assert_has_call(self.provider.mock_nodes[1].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, "p2.8x-run-options") - runner.assert_has_call(self.provider.mock_nodes[1].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, "p2.8x_image:latest") - runner.assert_not_has_call(self.provider.mock_nodes[1].internal_ip, + runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip, "default-image:nightly") - runner.assert_not_has_call(self.provider.mock_nodes[1].internal_ip, + runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip, "standard-run-options") - runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[3].internal_ip, "p2x_image:nightly") - runner.assert_has_call(self.provider.mock_nodes[2].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[3].internal_ip, "standard-run-options") - runner.assert_not_has_call(self.provider.mock_nodes[2].internal_ip, + runner.assert_not_has_call(self.provider.mock_nodes[3].internal_ip, "p2.8x-run-options") - runner.assert_has_call(self.provider.mock_nodes[3].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[4].internal_ip, "default-image:nightly") - runner.assert_has_call(self.provider.mock_nodes[3].internal_ip, + runner.assert_has_call(self.provider.mock_nodes[4].internal_ip, "standard-run-options") - runner.assert_not_has_call(self.provider.mock_nodes[3].internal_ip, + runner.assert_not_has_call(self.provider.mock_nodes[4].internal_ip, "p2.8x-run-options") - runner.assert_not_has_call(self.provider.mock_nodes[3].internal_ip, + runner.assert_not_has_call(self.provider.mock_nodes[4].internal_ip, "p2x_image:nightly") def testUpdateConfig(self): @@ -1531,21 +1613,25 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + LoadMetrics("172.0.0.0"), max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) config["available_node_types"]["m4.large"]["min_workers"] = 0 config["available_node_types"]["m4.large"]["node_config"][ "field_changed"] = 1 config_path = self.write_config(config) autoscaler.update() - self.waitForNodes(0) + self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testEmptyDocker(self): config = MULTI_WORKER_CLUSTER.copy() @@ -1555,23 +1641,27 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) autoscaler = StandardAutoscaler( config_path, - LoadMetrics(), + LoadMetrics("172.0.0.0"), max_failures=0, process_runner=runner, update_interval_s=0) - assert len(self.provider.non_terminated_nodes({})) == 0 - autoscaler.update() - self.waitForNodes(0) - autoscaler.request_resources([{"CPU": 1}]) + assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - assert self.provider.mock_nodes[0].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.request_resources([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) - assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" + assert self.provider.mock_nodes[1].node_type == "m4.large" + autoscaler.request_resources([{"GPU": 8}]) + autoscaler.update() + self.waitForNodes(3) + assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" def testRequestResourcesIdleTimeout(self): """Test request_resources() with and without idle timeout.""" @@ -1599,8 +1689,12 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() - lm = LoadMetrics() - runner.respond_to_call("json .Config.Env", ["[]" for i in range(2)]) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + lm = LoadMetrics("172.0.0.0") + runner.respond_to_call("json .Config.Env", ["[]" for i in range(3)]) autoscaler = StandardAutoscaler( config_path, lm, @@ -1608,14 +1702,14 @@ class AutoscalingTest(unittest.TestCase): process_runner=runner, update_interval_s=0) autoscaler.update() - self.waitForNodes(0) + self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) autoscaler.update() - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) non_terminated_nodes = autoscaler.provider.non_terminated_nodes({}) - assert len(non_terminated_nodes) == 1 - node_id = non_terminated_nodes[0] - node_ip = autoscaler.provider.non_terminated_node_ips({})[0] + assert len(non_terminated_nodes) == 2 + node_id = non_terminated_nodes[1] + node_ip = autoscaler.provider.non_terminated_node_ips({})[1] # A hack to check if the node was terminated when it shouldn't. autoscaler.provider.mock_nodes[node_id].state = "unterminatable" @@ -1629,10 +1723,10 @@ class AutoscalingTest(unittest.TestCase): }]) autoscaler.update() # this fits on request_resources()! - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}] * 2) autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) lm.update( node_ip, @@ -1642,7 +1736,7 @@ class AutoscalingTest(unittest.TestCase): "WORKER": 1.0 }]) autoscaler.update() - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) lm.update( node_ip, config["available_node_types"]["def_worker"]["resources"], @@ -1653,13 +1747,13 @@ class AutoscalingTest(unittest.TestCase): }]) autoscaler.update() # Still 2 as the second node did not show up a heart beat. - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # If node {node_id} was terminated any time then it's state will be set # to terminated. assert autoscaler.provider.mock_nodes[ node_id].state == "unterminatable" lm.update( - "172.0.0.1", + "172.0.0.2", config["available_node_types"]["def_worker"]["resources"], config["available_node_types"]["def_worker"]["resources"], {}, waiting_bundles=[{ @@ -1669,7 +1763,7 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() # Now it is 1 because it showed up in last used (heart beat). # The remaining one is 127.0.0.1. - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testRequestResourcesRaceConditionsLong(self): """Test request_resources(), race conditions & demands/min_workers. @@ -1704,7 +1798,11 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() runner = MockProcessRunner() runner.respond_to_call("json .Config.Env", ["[]" for i in range(3)]) - lm = LoadMetrics() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + lm = LoadMetrics("172.0.0.0") autoscaler = StandardAutoscaler( config_path, lm, @@ -1714,11 +1812,11 @@ class AutoscalingTest(unittest.TestCase): autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) autoscaler.update() # 1 min worker for both min_worker and request_resources() - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) non_terminated_nodes = autoscaler.provider.non_terminated_nodes({}) - assert len(non_terminated_nodes) == 1 - node_id = non_terminated_nodes[0] - node_ip = autoscaler.provider.non_terminated_node_ips({})[0] + assert len(non_terminated_nodes) == 2 + node_id = non_terminated_nodes[1] + node_ip = autoscaler.provider.non_terminated_node_ips({})[1] # A hack to check if the node was terminated when it shouldn't. autoscaler.provider.mock_nodes[node_id].state = "unterminatable" @@ -1733,12 +1831,12 @@ class AutoscalingTest(unittest.TestCase): autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}] * 2) autoscaler.update() # 2 requested_resource, 1 min worker, 1 free node -> 2 nodes total - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) autoscaler.update() # Still 2 because the second one is not connected and hence # request_resources occupies the connected node. - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}] * 3) lm.update( node_ip, @@ -1748,14 +1846,14 @@ class AutoscalingTest(unittest.TestCase): "WORKER": 1.0 }] * 3) autoscaler.update() - self.waitForNodes(3) + self.waitForNodes(3, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) autoscaler.request_resources([]) - lm.update("172.0.0.1", + lm.update("172.0.0.2", config["available_node_types"]["def_worker"]["resources"], config["available_node_types"]["def_worker"]["resources"], {}) - lm.update("172.0.0.2", + lm.update("172.0.0.3", config["available_node_types"]["def_worker"]["resources"], config["available_node_types"]["def_worker"]["resources"], {}) @@ -1763,7 +1861,7 @@ class AutoscalingTest(unittest.TestCase): config["available_node_types"]["def_worker"]["resources"], {}, {}) autoscaler.update() - self.waitForNodes(1) + self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # If node {node_id} was terminated any time then it's state will be set # to terminated. assert autoscaler.provider.mock_nodes[ @@ -1799,7 +1897,11 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() runner = MockProcessRunner() runner.respond_to_call("json .Config.Env", ["[]" for i in range(2)]) - lm = LoadMetrics() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node" + }, 1) + lm = LoadMetrics("172.0.0.0") autoscaler = StandardAutoscaler( config_path, lm, @@ -1809,7 +1911,7 @@ class AutoscalingTest(unittest.TestCase): autoscaler.request_resources([{"CPU": 2, "WORKER": 1.0}] * 2) autoscaler.update() # 2 min worker for both min_worker and request_resources(), not 3. - self.waitForNodes(2) + self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) def testRequestResourcesRaceConditionWithResourceDemands(self): """Test request_resources() with resource_demands. From 80f6dd16b2a7f5e90ff882751c37b7f2e02bd147 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Sun, 20 Dec 2020 15:43:48 -0800 Subject: [PATCH 47/88] [ray_client] Implement optional arguments to ray.remote() and f.options() (#12985) --- .../ray/experimental/client/client_pickler.py | 16 ++- python/ray/experimental/client/common.py | 119 +++++++++++++++--- python/ray/experimental/client/options.py | 54 ++++++++ .../ray/experimental/client/server/server.py | 91 ++++++++++---- .../client/server/server_pickler.py | 8 +- python/ray/experimental/client/worker.py | 64 +++++++--- python/ray/tests/BUILD | 1 + python/ray/tests/test_actor.py | 8 +- python/ray/tests/test_advanced.py | 24 ++-- python/ray/tests/test_basic.py | 12 +- python/ray/tests/test_basic_2.py | 1 - src/ray/protobuf/ray_client.proto | 28 ++++- 12 files changed, 336 insertions(+), 90 deletions(-) create mode 100644 python/ray/experimental/client/options.py diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py index 2496199ea..7ba83b3ac 100644 --- a/python/ray/experimental/client/client_pickler.py +++ b/python/ray/experimental/client/client_pickler.py @@ -28,6 +28,7 @@ import sys from typing import NamedTuple from typing import Any +from typing import Dict from typing import Optional from ray.experimental.client import RayAPIStub @@ -37,6 +38,7 @@ from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientRemoteMethod +from ray.experimental.client.common import OptionWrapper from ray.experimental.client.common import SelfReferenceSentinel import ray.core.generated.ray_client_pb2 as ray_client_pb2 @@ -52,7 +54,8 @@ else: # the data for an exectuion, with no arguments. Combine the two? PickleStub = NamedTuple("PickleStub", [("type", str), ("client_id", str), ("ref_id", bytes), - ("name", Optional[str])]) + ("name", Optional[str]), + ("baseline_options", Optional[Dict])]) class ClientPickler(cloudpickle.CloudPickler): @@ -67,6 +70,7 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=b"", name=None, + baseline_options=None, ) elif isinstance(obj, ClientObjectRef): return PickleStub( @@ -74,6 +78,7 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj.id, name=None, + baseline_options=None, ) elif isinstance(obj, ClientActorHandle): return PickleStub( @@ -81,6 +86,7 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj._actor_id, name=None, + baseline_options=None, ) elif isinstance(obj, ClientRemoteFunc): # TODO(barakmich): This is going to have trouble with mutually @@ -95,12 +101,14 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=b"", name=None, + baseline_options=None, ) return PickleStub( type="RemoteFunc", client_id=self.client_id, ref_id=obj._ref.id, name=None, + baseline_options=obj._options, ) elif isinstance(obj, ClientActorClass): # TODO(barakmich): Mutual recursion, as above. @@ -112,12 +120,14 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=b"", name=None, + baseline_options=None, ) return PickleStub( type="RemoteActor", client_id=self.client_id, ref_id=obj._ref.id, name=None, + baseline_options=obj._options, ) elif isinstance(obj, ClientRemoteMethod): return PickleStub( @@ -125,7 +135,11 @@ class ClientPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj.actor_handle.actor_ref.id, name=obj.method_name, + baseline_options=None, ) + elif isinstance(obj, OptionWrapper): + raise NotImplementedError( + "Sending a partial option is unimplemented") return None diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 60901c661..49eee05d6 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -1,9 +1,21 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray +from ray.experimental.client.options import validate_options + +import json +import threading +from typing import Any +from typing import List +from typing import Dict +from typing import Optional +from typing import Union class ClientBaseRef: def __init__(self, id: bytes): + self.id = None + if not isinstance(id, bytes): + raise TypeError("ClientRefs must be created with bytes IDs") self.id: bytes = id ray.call_retain(id) @@ -23,7 +35,7 @@ class ClientBaseRef: return hash(self.id) def __del__(self): - if ray.is_connected(): + if ray.is_connected() and self.id is not None: ray.call_release(self.id) @@ -52,33 +64,42 @@ class ClientRemoteFunc(ClientStub): _ref: The ClientObjectRef of the pickled code of the function, _func """ - def __init__(self, f): + def __init__(self, f, options=None): + self._lock = threading.Lock() self._func = f self._name = f.__name__ self._ref = None + self._options = validate_options(options) def __call__(self, *args, **kwargs): raise TypeError(f"Remote function cannot be called directly. " "Use {self._name}.remote method instead") def remote(self, *args, **kwargs): - return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) + return return_refs(ray.call_remote(self, *args, **kwargs)) + + def options(self, **kwargs): + return OptionWrapper(self, kwargs) + + def _remote(self, args=[], kwargs={}, **option_args): + return self.options(**option_args).remote(*args, **kwargs) def __repr__(self): return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref) def _ensure_ref(self): - if self._ref is None: - # While calling ray.put() on our function, if - # our function is recursive, it will attempt to - # encode the ClientRemoteFunc -- itself -- and - # infinitely recurse on _ensure_ref. - # - # So we set the state of the reference to be an - # in-progress self reference value, which - # the encoding can detect and handle correctly. - self._ref = SelfReferenceSentinel() - self._ref = ray.put(self._func) + with self._lock: + if self._ref is None: + # While calling ray.put() on our function, if + # our function is recursive, it will attempt to + # encode the ClientRemoteFunc -- itself -- and + # infinitely recurse on _ensure_ref. + # + # So we set the state of the reference to be an + # in-progress self reference value, which + # the encoding can detect and handle correctly. + self._ref = SelfReferenceSentinel() + self._ref = ray.put(self._func) def _prepare_client_task(self) -> ray_client_pb2.ClientTask: self._ensure_ref() @@ -86,6 +107,7 @@ class ClientRemoteFunc(ClientStub): task.type = ray_client_pb2.ClientTask.FUNCTION task.name = self._name task.payload_id = self._ref.id + set_task_options(task, self._options, "baseline_options") return task @@ -100,10 +122,11 @@ class ClientActorClass(ClientStub): _ref: The ClientObjectRef of the pickled `actor_cls` """ - def __init__(self, actor_cls): + def __init__(self, actor_cls, options=None): self.actor_cls = actor_cls self._name = actor_cls.__name__ self._ref = None + self._options = validate_options(options) def __call__(self, *args, **kwargs): raise TypeError(f"Remote actor cannot be instantiated directly. " @@ -119,8 +142,15 @@ class ClientActorClass(ClientStub): def remote(self, *args, **kwargs) -> "ClientActorHandle": # Actually instantiate the actor - ref_id = ray.call_remote(self, *args, **kwargs) - return ClientActorHandle(ClientActorRef(ref_id), self) + ref_ids = ray.call_remote(self, *args, **kwargs) + assert len(ref_ids) == 1 + return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + + def options(self, **kwargs): + return ActorOptionWrapper(self, kwargs) + + def _remote(self, args=[], kwargs={}, **option_args): + return self.options(**option_args).remote(*args, **kwargs) def __repr__(self): return "ClientActorClass(%s, %s)" % (self._name, self._ref) @@ -136,6 +166,7 @@ class ClientActorClass(ClientStub): task.type = ray_client_pb2.ClientTask.ACTOR task.name = self._name task.payload_id = self._ref.id + set_task_options(task, self._options, "baseline_options") return task @@ -160,7 +191,8 @@ class ClientActorHandle(ClientStub): self.actor_ref = actor_ref def __del__(self) -> None: - ray.call_release(self.actor_ref.id) + if ray.is_connected(): + ray.call_release(self.actor_ref.id) @property def _actor_id(self): @@ -193,12 +225,18 @@ class ClientRemoteMethod(ClientStub): f"Use {self._name}.remote() instead") def remote(self, *args, **kwargs): - return ClientObjectRef(ray.call_remote(self, *args, **kwargs)) + return return_refs(ray.call_remote(self, *args, **kwargs)) def __repr__(self): return "ClientRemoteMethod(%s, %s)" % (self.method_name, self.actor_handle) + def options(self, **kwargs): + return OptionWrapper(self, kwargs) + + def _remote(self, args=[], kwargs={}, **option_args): + return self.options(**option_args).remote(*args, **kwargs) + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: task = ray_client_pb2.ClientTask() task.type = ray_client_pb2.ClientTask.METHOD @@ -207,6 +245,49 @@ class ClientRemoteMethod(ClientStub): return task +class OptionWrapper: + def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]): + self.remote_stub = stub + self.options = validate_options(options) + + def remote(self, *args, **kwargs): + return return_refs(ray.call_remote(self, *args, **kwargs)) + + def __getattr__(self, key): + return getattr(self.remote_stub, key) + + def _prepare_client_task(self): + task = self.remote_stub._prepare_client_task() + set_task_options(task, self.options) + return task + + +class ActorOptionWrapper(OptionWrapper): + def remote(self, *args, **kwargs): + ref_ids = ray.call_remote(self, *args, **kwargs) + assert len(ref_ids) == 1 + return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + + +def set_task_options(task: ray_client_pb2.ClientTask, + options: Optional[Dict[str, Any]], + field: str = "options") -> None: + if options is None: + task.ClearField(field) + return + options_str = json.dumps(options) + getattr(task, field).json_options = options_str + + +def return_refs(ids: List[bytes] + ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]: + if len(ids) == 1: + return ClientObjectRef(ids[0]) + if len(ids) == 0: + return None + return [ClientObjectRef(id) for id in ids] + + class DataEncodingSentinel: def __repr__(self) -> str: return self.__class__.__name__ diff --git a/python/ray/experimental/client/options.py b/python/ray/experimental/client/options.py new file mode 100644 index 000000000..79727b126 --- /dev/null +++ b/python/ray/experimental/client/options.py @@ -0,0 +1,54 @@ +from typing import Any +from typing import Dict +from typing import Optional + +options = { + "num_returns": (int, lambda x: x >= 0, + "The keyword 'num_returns' only accepts 0 " + "or a positive integer"), + "num_cpus": (), + "num_gpus": (), + "resources": (), + "accelerator_type": (), + "max_calls": (int, lambda x: x >= 0, + "The keyword 'max_calls' only accepts 0 " + "or a positive integer"), + "max_restarts": (int, lambda x: x >= -1, + "The keyword 'max_restarts' only accepts -1, 0 " + "or a positive integer"), + "max_task_retries": (int, lambda x: x >= -1, + "The keyword 'max_task_retries' only accepts -1, 0 " + "or a positive integer"), + "max_retries": (int, lambda x: x >= -1, + "The keyword 'max_retries' only accepts 0, -1 " + "or a positive integer"), + "max_concurrency": (), + "name": (), + "lifetime": (), + "memory": (), + "object_store_memory": (), + "placement_group": (), + "placement_group_bundle_index": (), + "placement_group_capture_child_tasks": (), + "override_environment_variables": (), +} + + +def validate_options( + kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if kwargs_dict is None: + return None + if len(kwargs_dict) == 0: + return None + out = {} + for k, v in kwargs_dict.items(): + if k not in options.keys(): + raise TypeError(f"Invalid option passed to remote(): {k}") + validator = options[k] + if len(validator) != 0: + if not isinstance(v, validator[0]): + raise ValueError(validator[2]) + if not validator[1](v): + raise ValueError(validator[2]) + out[k] = v + return out diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 2841384d8..5f86ddee2 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -4,8 +4,10 @@ import grpc import base64 from collections import defaultdict +from typing import Any from typing import Dict from typing import Set +from typing import Optional from ray import cloudpickle import ray @@ -187,9 +189,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ready_object_refs, remaining_object_refs = ray.wait( object_refs, num_returns=num_returns, - timeout=timeout if timeout != -1 else None) - except Exception: + timeout=timeout if timeout != -1 else None, + ) + except Exception as e: # TODO(ameer): improve exception messages. + logger.error(f"Exception {e}") return ray_client_pb2.WaitResponse(valid=False) logger.debug("wait: %s %s" % (str(ready_object_refs), str(remaining_object_refs))) @@ -206,9 +210,10 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): remaining_object_ids=remaining_object_ids) def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket: - logger.info("schedule: %s %s" % - (task.name, - ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))) + logger.debug( + "schedule: %s %s" % (task.name, + ray_client_pb2.ClientTask.RemoteExecType.Name( + task.type))) with stash_api_for_tests(self._test_mode): try: if task.type == ray_client_pb2.ClientTask.FUNCTION: @@ -226,6 +231,7 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): return result except Exception as e: logger.error(f"Caught schedule exception {e}") + raise e return ray_client_pb2.ClientTaskTicket( valid=False, error=cloudpickle.dumps(e)) @@ -236,34 +242,44 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): raise Exception( "Can't run an actor the server doesn't have a handle for") arglist, kwargs = self._convert_args(task.args, task.kwargs) - output = getattr(actor_handle, task.name).remote(*arglist, **kwargs) - self.object_refs[task.client_id][output.binary()] = output - return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + method = getattr(actor_handle, task.name) + opts = decode_options(task.options) + if opts is not None: + method = method.options(**opts) + output = method.remote(*arglist, **kwargs) + ids = self.unify_and_track_outputs(output, task.client_id) + return ray_client_pb2.ClientTaskTicket(return_ids=ids) def _schedule_actor(self, task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket: - remote_class = self.lookup_or_register_actor(task.payload_id, - task.client_id) + remote_class = self.lookup_or_register_actor( + task.payload_id, task.client_id, + decode_options(task.baseline_options)) arglist, kwargs = self._convert_args(task.args, task.kwargs) + opts = decode_options(task.options) + if opts is not None: + remote_class = remote_class.options(**opts) with current_remote(remote_class): actor = remote_class.remote(*arglist, **kwargs) self.actor_refs[actor._actor_id.binary()] = actor self.actor_owners[task.client_id].add(actor._actor_id.binary()) return ray_client_pb2.ClientTaskTicket( - return_id=actor._actor_id.binary()) + return_ids=[actor._actor_id.binary()]) def _schedule_function(self, task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket: - remote_func = self.lookup_or_register_func(task.payload_id, - task.client_id) + remote_func = self.lookup_or_register_func( + task.payload_id, task.client_id, + decode_options(task.baseline_options)) arglist, kwargs = self._convert_args(task.args, task.kwargs) + opts = decode_options(task.options) + if opts is not None: + remote_func = remote_func.options(**opts) with current_remote(remote_func): output = remote_func.remote(*arglist, **kwargs) - if output.binary() in self.object_refs[task.client_id]: - raise Exception("already found it") - self.object_refs[task.client_id][output.binary()] = output - return ray_client_pb2.ClientTaskTicket(return_id=output.binary()) + ids = self.unify_and_track_outputs(output, task.client_id) + return ray_client_pb2.ClientTaskTicket(return_ids=ids) def _convert_args(self, arg_list, kwarg_map): argout = [] @@ -275,28 +291,50 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): kwargout[k] = convert_from_arg(kwarg_map[k], self) return argout, kwargout - def lookup_or_register_func(self, id: bytes, client_id: str - ) -> ray.remote_function.RemoteFunction: + def lookup_or_register_func( + self, id: bytes, client_id: str, + options: Optional[Dict]) -> ray.remote_function.RemoteFunction: if id not in self.function_refs: funcref = self.object_refs[client_id][id] func = ray.get(funcref) if not inspect.isfunction(func): raise Exception("Attempting to register function that " "isn't a function.") - self.function_refs[id] = ray.remote(func) + if options is None or len(options) == 0: + self.function_refs[id] = ray.remote(func) + else: + self.function_refs[id] = ray.remote(**options)(func) return self.function_refs[id] - def lookup_or_register_actor(self, id: bytes, client_id: str): + def lookup_or_register_actor(self, id: bytes, client_id: str, + options: Optional[Dict]): if id not in self.registered_actor_classes: actor_class_ref = self.object_refs[client_id][id] actor_class = ray.get(actor_class_ref) if not inspect.isclass(actor_class): raise Exception("Attempting to schedule actor that " "isn't a class.") - reg_class = ray.remote(actor_class) + if options is None or len(options) == 0: + reg_class = ray.remote(actor_class) + else: + reg_class = ray.remote(**options)(actor_class) self.registered_actor_classes[id] = reg_class + return self.registered_actor_classes[id] + def unify_and_track_outputs(self, output, client_id): + if output is None: + outputs = [] + elif isinstance(output, list): + outputs = output + else: + outputs = [output] + for out in outputs: + if out.binary() in self.object_refs[client_id]: + logger.warning(f"Already saw object_ref {out}") + self.object_refs[client_id][out.binary()] = out + return [out.binary() for out in outputs] + def return_exception_in_context(err, context): if context is not None: @@ -309,6 +347,15 @@ def encode_exception(exception) -> str: return base64.standard_b64encode(data).decode() +def decode_options( + options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]: + if options.json_options == "": + return None + opts = json.loads(options.json_options) + assert isinstance(opts, dict) + return opts + + def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py index c3cd161bd..10da70cc1 100644 --- a/python/ray/experimental/client/server/server_pickler.py +++ b/python/ray/experimental/client/server/server_pickler.py @@ -56,6 +56,7 @@ class ServerPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj_id, name=None, + baseline_options=None, ) elif isinstance(obj, ray.actor.ActorHandle): actor_id = obj._actor_id.binary() @@ -69,6 +70,7 @@ class ServerPickler(cloudpickle.CloudPickler): client_id=self.client_id, ref_id=obj._actor_id.binary(), name=None, + baseline_options=None, ) return None @@ -89,13 +91,13 @@ class ClientUnpickler(pickle.Unpickler): elif pid.type == "RemoteFuncSelfReference": return ServerSelfReferenceSentinel() elif pid.type == "RemoteFunc": - return self.server.lookup_or_register_func(pid.ref_id, - pid.client_id) + return self.server.lookup_or_register_func( + pid.ref_id, pid.client_id, pid.baseline_options) elif pid.type == "RemoteActorSelfReference": return ServerSelfReferenceSentinel() elif pid.type == "RemoteActor": return self.server.lookup_or_register_actor( - pid.ref_id, pid.client_id) + pid.ref_id, pid.client_id, pid.baseline_options) elif pid.type == "RemoteMethod": actor = self.server.actor_refs[pid.ref_id] return getattr(actor, pid.name) diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 6bfab6b75..d2ba52d62 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -21,12 +21,13 @@ import ray.cloudpickle as cloudpickle import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.client_pickler import convert_to_arg -from ray.experimental.client.client_pickler import loads_from_server from ray.experimental.client.client_pickler import dumps_from_client -from ray.experimental.client.common import ClientObjectRef +from ray.experimental.client.client_pickler import loads_from_server from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle +from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc +from ray.experimental.client.common import ClientStub from ray.experimental.client.dataclient import DataClient logger = logging.getLogger(__name__) @@ -80,7 +81,9 @@ class Worker: except grpc.RpcError as e: raise e.details() if not data.valid: - raise cloudpickle.loads(data.error) + err = cloudpickle.loads(data.error) + logger.error(err) + raise err return loads_from_server(data.data) def put(self, vals): @@ -98,6 +101,13 @@ class Worker: return out def _put(self, val): + if isinstance(val, ClientObjectRef): + raise TypeError( + "Calling 'put' on an ObjectRef is not allowed " + "(similarly, returning an ObjectRef from a remote " + "function is not allowed). If you really want to " + "do this, you can wrap the ObjectRef in a list and " + "call 'put' on it (or return it).") data = dumps_from_client(val, self._client_id) req = ray_client_pb2.PutRequest(data=data) resp = self.data_client.PutObject(req) @@ -107,7 +117,8 @@ class Worker: object_refs: List[ClientObjectRef], *, num_returns: int = 1, - timeout: float = None + timeout: float = None, + fetch_local: bool = True ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: if not isinstance(object_refs, list): raise TypeError("wait() expected a list of ClientObjectRef, " @@ -136,19 +147,22 @@ class Worker: return (client_ready_object_ids, client_remaining_object_ids) - def remote(self, function_or_class, *args, **kwargs): - # TODO(barakmich): Arguments to ray.remote - # get captured here. - if (inspect.isfunction(function_or_class) - or is_cython(function_or_class)): - return ClientRemoteFunc(function_or_class) - elif inspect.isclass(function_or_class): - return ClientActorClass(function_or_class) - else: - raise TypeError("The @ray.remote decorator must be applied to " - "either a function or to a class.") + def remote(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + return remote_decorator(options=None)(args[0]) + error_string = ("The @ray.remote decorator must be applied either " + "with no arguments and no parentheses, for example " + "'@ray.remote', or it must be applied using some of " + "the arguments 'num_returns', 'num_cpus', 'num_gpus', " + "'memory', 'object_store_memory', 'resources', " + "'max_calls', or 'max_restarts', like " + "'@ray.remote(num_returns=2, " + "resources={\"CustomResource\": 1})'.") + assert len(args) == 0 and len(kwargs) > 0, error_string + return remote_decorator(options=kwargs) - def call_remote(self, instance, *args, **kwargs) -> bytes: + def call_remote(self, instance, *args, **kwargs) -> List[bytes]: task = instance._prepare_client_task() for arg in args: pb_arg = convert_to_arg(arg, self._client_id) @@ -160,10 +174,10 @@ class Worker: try: ticket = self.server.Schedule(task, metadata=self.metadata) except grpc.RpcError as e: - raise e.details() + raise decode_exception(e.details) if not ticket.valid: raise cloudpickle.loads(ticket.error) - return ticket.return_id + return ticket.return_ids def call_release(self, id: bytes) -> None: self.reference_count[id] -= 1 @@ -234,6 +248,20 @@ class Worker: return False +def remote_decorator(options: Optional[Dict[str, Any]]): + def decorator(function_or_class) -> ClientStub: + if (inspect.isfunction(function_or_class) + or is_cython(function_or_class)): + return ClientRemoteFunc(function_or_class, options=options) + elif inspect.isclass(function_or_class): + return ClientActorClass(function_or_class, options=options) + else: + raise TypeError("The @ray.remote decorator must be applied to " + "either a function or to a class.") + + return decorator + + def make_client_id() -> str: id = uuid.uuid4() return id.hex diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index e88986475..7e552e616 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -158,6 +158,7 @@ py_test( py_test_module_list( files = [ "test_actor.py", + "test_advanced.py", "test_basic.py", "test_basic_2.py", ], diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 1e761762e..3ba2ed7eb 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -25,7 +25,9 @@ else: import setproctitle # noqa -@pytest.mark.skipif(client_test_enabled(), reason="test setup order") +@pytest.mark.skipif( + client_test_enabled(), + reason="defining early, no ray package injection yet") def test_caching_actors(shutdown_only): # Test defining actors before ray.init() has been called. @@ -564,7 +566,6 @@ def test_actor_static_attributes(ray_start_regular_shared): assert ray.get(t.g.remote()) == 3 -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_decorator_args(ray_start_regular_shared): # This is an invalid way of using the actor decorator. with pytest.raises(Exception): @@ -655,7 +656,7 @@ def test_actor_inheritance(ray_start_regular_shared): pass -@pytest.mark.skipif(client_test_enabled(), reason="remote args") +@pytest.mark.skipif(client_test_enabled(), reason="ray.method unimplemented") def test_multiple_return_values(ray_start_regular_shared): @ray.remote class Foo: @@ -689,7 +690,6 @@ def test_multiple_return_values(ray_start_regular_shared): assert ray.get([id3a, id3b, id3c]) == [1, 2, 3] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_options_num_returns(ray_start_regular_shared): @ray.remote class Foo: diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index 08dd168fa..ea2a6c693 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -10,16 +10,22 @@ import time import numpy as np import pytest -import ray import ray.cluster_utils import ray.test_utils +from ray.test_utils import client_test_enabled from ray.test_utils import RayTestTimeoutException +if client_test_enabled(): + from ray.experimental.client import ray +else: + import ray + logger = logging.getLogger(__name__) # issue https://github.com/ray-project/ray/issues/7105 +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_internal_free(shutdown_only): ray.init(num_cpus=1) @@ -60,14 +66,14 @@ def test_multiple_waits_and_gets(shutdown_only): return 1 @ray.remote - def g(l): - # The argument l should be a list containing one object ref. - ray.wait([l[0]]) + def g(input_list): + # The argument input_list should be a list containing one object ref. + ray.wait([input_list[0]]) @ray.remote - def h(l): - # The argument l should be a list containing one object ref. - ray.get(l[0]) + def h(input_list): + # The argument input_list should be a list containing one object ref. + ray.get(input_list[0]) # Make sure that multiple wait requests involving the same object ref # all return. @@ -80,6 +86,7 @@ def test_multiple_waits_and_gets(shutdown_only): ray.get([h.remote([x]), h.remote([x])]) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_caching_functions_to_run(shutdown_only): # Test that we export functions to run on all workers before the driver # is connected. @@ -125,6 +132,7 @@ def test_caching_functions_to_run(shutdown_only): ray.worker.global_worker.run_function_on_all_workers(f) +@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_running_function_on_all_workers(ray_start_regular): def f(worker_info): sys.path.append("fake_directory") @@ -152,6 +160,7 @@ def test_running_function_on_all_workers(ray_start_regular): assert "fake_directory" not in ray.get(get_path2.remote()) +@pytest.mark.skipif(client_test_enabled(), reason="ray.timeline") def test_profiling_api(ray_start_2_cpus): @ray.remote def f(): @@ -482,6 +491,7 @@ def test_multithreading(ray_start_2_cpus): ray.get(actor.join.remote()) == "ok" +@pytest.mark.skipif(client_test_enabled(), reason="message size") def test_wait_makes_object_local(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 709b467e6..38330645b 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) # https://github.com/ray-project/ray/issues/6662 -@pytest.mark.skipif(client_test_enabled(), reason="internal api") +@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc") def test_ignore_http_proxy(shutdown_only): ray.init(num_cpus=1) os.environ["http_proxy"] = "http://example.com" @@ -55,14 +55,12 @@ def test_grpc_message_size(shutdown_only): # https://github.com/ray-project/ray/issues/7287 -@pytest.mark.skipif(client_test_enabled(), reason="internal api") def test_omp_threads_set(shutdown_only): ray.init(num_cpus=1) # Should have been auto set by ray init. assert os.environ["OMP_NUM_THREADS"] == "1" -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_submit_api(shutdown_only): ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1}) @@ -121,7 +119,6 @@ def test_submit_api(shutdown_only): assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_invalid_arguments(shutdown_only): ray.init(num_cpus=2) @@ -176,7 +173,6 @@ def test_invalid_arguments(shutdown_only): x = 1 -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_many_fractional_resources(shutdown_only): ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2}) @@ -244,7 +240,6 @@ def test_many_fractional_resources(shutdown_only): assert False, "Did not get correct available resources." -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_background_tasks_with_max_calls(shutdown_only): ray.init(num_cpus=2) @@ -360,8 +355,9 @@ def test_function_descriptor(): assert d.get(python_descriptor2) == 123 -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_ray_options(shutdown_only): + ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) + @ray.remote( num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1}) def foo(): @@ -370,8 +366,6 @@ def test_ray_options(shutdown_only): time.sleep(0.1) return ray.available_resources() - ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2}) - without_options = ray.get(foo.remote()) with_options = ray.get( foo.options( diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py index 25688a6f7..cd8114aa8 100644 --- a/python/ray/tests/test_basic_2.py +++ b/python/ray/tests/test_basic_2.py @@ -537,7 +537,6 @@ def test_actor_recursive(ray_start_regular_shared): assert result == [x * 2 for x in range(100)] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_actor_concurrent(ray_start_regular_shared): @ray.remote class Batcher: diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index ea4939738..cbd6679dd 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -35,6 +35,18 @@ message Arg { Type type = 4; } +// A message representing the valid options to modify a task exectution +// +// TODO(barakmich): In the longer term, if everything were a client, +// this message could be the actual standard for which options are +// allowed in the API. Today, however, it's a bit flexible and defined in the +// Python code. So for now, it's a stand-in message with a json field, but +// this is forwards-compatible with deprecating that field and instituting +// strongly defined and typed fields, without migrating the original ClientTask. +message TaskOptions { + string json_options = 1; +} + // Represents one unit of work to be executed by the server. message ClientTask { enum RemoteExecType { @@ -45,8 +57,8 @@ message ClientTask { } // Which type of work this request represents. RemoteExecType type = 1; - // A name parameter, if the payload can be called in more than one way (like a method on - // a payload object). + // A name parameter, if the payload can be called in more than one way + // (like a method on a payload object). string name = 2; // A reference to the payload. bytes payload_id = 3; @@ -54,16 +66,20 @@ message ClientTask { repeated Arg args = 4; // Keyword parameters to pass to this call. map kwargs = 5; - // The ID of the client namespace associated with the Datapath stream making this - // request. + // The ID of the client namespace associated with the Datapath stream + // making this request. string client_id = 6; + // Options for modifying the remote task execution environment. + TaskOptions options = 7; + // Options passed to create the default remote task excution environment. + TaskOptions baseline_options = 8; } message ClientTaskTicket { // Was the task successful? bool valid = 1; - // A reference to the returned value from the execution. - bytes return_id = 2; + // A reference to the returned values from the execution. + repeated bytes return_ids = 2; // If unsuccessful, an encoding of the error. bytes error = 3; } From e715ade2d1026d1ca381b463fbb482c30acd3d93 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Sun, 20 Dec 2020 16:34:50 -0800 Subject: [PATCH 48/88] Support retrieval of named actor handles (#13000) Change-Id: I05d31c9c67943d2a0230782cbdaa98341584cbc7 --- python/ray/experimental/client/api.py | 3 ++ python/ray/experimental/client/common.py | 7 ++--- .../ray/experimental/client/server/server.py | 12 ++++++++ python/ray/experimental/client/worker.py | 15 +++++++++- python/ray/tests/test_experimental_client.py | 29 +++++++++++++++++++ src/ray/protobuf/ray_client.proto | 1 + 6 files changed, 62 insertions(+), 5 deletions(-) diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 5167e5988..93da6382f 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -197,6 +197,9 @@ class ClientAPI(APIImpl): def close(self) -> None: return self.worker.close() + def get_actor(self, name: str) -> "ClientActorHandle": + return self.worker.get_actor(name) + def kill(self, actor: "ClientActorHandle", *, no_restart=True): return self.worker.terminate_actor(actor, no_restart) diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index 49eee05d6..f68b26e2c 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -144,7 +144,7 @@ class ClientActorClass(ClientStub): # Actually instantiate the actor ref_ids = ray.call_remote(self, *args, **kwargs) assert len(ref_ids) == 1 - return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + return ClientActorHandle(ClientActorRef(ref_ids[0])) def options(self, **kwargs): return ActorOptionWrapper(self, kwargs) @@ -186,8 +186,7 @@ class ClientActorHandle(ClientStub): ray.actor.ActorHandle contained in the actor_id ref. """ - def __init__(self, actor_ref: ClientActorRef, - actor_class: ClientActorClass): + def __init__(self, actor_ref: ClientActorRef): self.actor_ref = actor_ref def __del__(self) -> None: @@ -266,7 +265,7 @@ class ActorOptionWrapper(OptionWrapper): def remote(self, *args, **kwargs): ref_ids = ray.call_remote(self, *args, **kwargs) assert len(ref_ids) == 1 - return ClientActorHandle(ClientActorRef(ref_ids[0]), self) + return ClientActorHandle(ClientActorRef(ref_ids[0])) def set_task_options(task: ray_client_pb2.ClientTask, diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 5f86ddee2..442cf1afa 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -222,6 +222,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): result = self._schedule_actor(task, context) elif task.type == ray_client_pb2.ClientTask.METHOD: result = self._schedule_method(task, context) + elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR: + result = self._schedule_named_actor(task, context) else: raise NotImplementedError( "Unimplemented Schedule task type: %s" % @@ -281,6 +283,16 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ids = self.unify_and_track_outputs(output, task.client_id) return ray_client_pb2.ClientTaskTicket(return_ids=ids) + def _schedule_named_actor(self, + task: ray_client_pb2.ClientTask, + context=None) -> ray_client_pb2.ClientTaskTicket: + assert len(task.payload_id) == 0 + actor = ray.get_actor(task.name) + self.actor_refs[actor._actor_id.binary()] = actor + self.actor_owners[task.client_id].add(actor._actor_id.binary()) + return ray_client_pb2.ClientTaskTicket( + return_ids=[actor._actor_id.binary()]) + def _convert_args(self, arg_list, kwarg_map): argout = [] for arg in arg_list: diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index d2ba52d62..bba23584b 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -25,6 +25,7 @@ from ray.experimental.client.client_pickler import dumps_from_client from ray.experimental.client.client_pickler import loads_from_server from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle +from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientStub @@ -169,8 +170,12 @@ class Worker: task.args.append(pb_arg) for k, v in kwargs.items(): task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id)) - task.client_id = self._client_id + return self._call_schedule_for_task(task) + + def _call_schedule_for_task( + self, task: ray_client_pb2.ClientTask) -> List[bytes]: logger.debug("Scheduling %s" % task) + task.client_id = self._client_id try: ticket = self.server.Schedule(task, metadata=self.metadata) except grpc.RpcError as e: @@ -201,6 +206,14 @@ class Worker: if self.channel: self.channel = None + def get_actor(self, name: str) -> ClientActorHandle: + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.NAMED_ACTOR + task.name = name + ids = self._call_schedule_for_task(task) + assert len(ids) == 1 + return ClientActorHandle(ClientActorRef(ids[0])) + def terminate_actor(self, actor: ClientActorHandle, no_restart: bool) -> None: if not isinstance(actor, ClientActorHandle): diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index e68abb366..cc15e7272 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -234,6 +234,35 @@ def test_pass_handles(ray_start_regular_shared): 4)) == local_fact(4) +def test_basic_named_actor(ray_start_regular_shared): + """ + Test that ray.get_actor() can create and return a detached actor. + """ + with ray_start_client_server() as ray: + + @ray.remote + class Accumulator: + def __init__(self): + self.x = 0 + + def inc(self): + self.x += 1 + + def get(self): + return self.x + + # Create the actor + actor = Accumulator.options(name="test_acc").remote() + + actor.inc.remote() + actor.inc.remote() + del actor + + new_actor = ray.get_actor("test_acc") + new_actor.inc.remote() + assert ray.get(new_actor.get.remote()) == 3 + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index cbd6679dd..a566f8031 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -54,6 +54,7 @@ message ClientTask { ACTOR = 1; METHOD = 2; STATIC_METHOD = 3; + NAMED_ACTOR = 4; } // Which type of work this request represents. RemoteExecType type = 1; From b2bcab711d333442c282cf64c66a9fac2c93218f Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 21 Dec 2020 02:22:32 +0100 Subject: [PATCH 49/88] [RLlib] Attention Nets: tf (#12753) --- rllib/BUILD | 7 - rllib/agents/callbacks.py | 4 +- rllib/agents/ppo/ppo_tf_policy.py | 3 +- .../collectors/simple_list_collector.py | 171 +++++++++--- rllib/evaluation/postprocessing.py | 4 - rllib/examples/attention_net.py | 10 +- rllib/examples/cartpole_lstm.py | 2 +- rllib/models/modelv2.py | 69 +++-- rllib/models/tests/test_attention_nets.py | 263 ------------------ rllib/models/tf/attention_net.py | 137 +++++---- rllib/models/tf/layers/__init__.py | 4 +- .../layers/relative_multi_head_attention.py | 45 ++- rllib/models/tf/layers/skip_connection.py | 1 - rllib/models/torch/modules/skip_connection.py | 1 - rllib/policy/dynamic_tf_policy.py | 26 +- rllib/policy/eager_tf_policy.py | 16 +- rllib/policy/policy.py | 53 +++- rllib/policy/rnn_sequencing.py | 125 +++++---- rllib/policy/sample_batch.py | 43 ++- rllib/policy/tf_policy.py | 32 ++- rllib/policy/torch_policy.py | 3 +- rllib/policy/view_requirement.py | 20 +- rllib/tests/test_attention_net_learning.py | 6 +- rllib/tests/test_lstm.py | 51 ++-- rllib/utils/sgd.py | 29 +- rllib/utils/typing.py | 3 + 26 files changed, 567 insertions(+), 561 deletions(-) delete mode 100644 rllib/models/tests/test_attention_nets.py diff --git a/rllib/BUILD b/rllib/BUILD index bd612d0ff..c645c27a0 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1074,13 +1074,6 @@ py_test( # Tag: models # -------------------------------------------------------------------- -py_test( - name = "test_attention_nets", - tags = ["models"], - size = "small", - srcs = ["models/tests/test_attention_nets.py"] -) - py_test( name = "test_convtranspose2d_stack", tags = ["models"], diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index bf5284740..e84cf4148 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -191,8 +191,8 @@ class DefaultCallbacks: **kwargs) -> None: """Called at the beginning of Policy.learn_on_batch(). - Note: This is called before the Model's `preprocess_train_batch()` - is called. + Note: This is called before 0-padding via + `pad_batch_to_sequences_of_same_size`. Args: policy (Policy): Reference to the current Policy object. diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 29266dfcc..957d68ce3 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -198,7 +198,8 @@ def postprocess_ppo_gae( # input_dict. if policy.config["_use_trajectory_view_api"]: # Create an input dict according to the Model's requirements. - input_dict = policy.model.get_input_dict(sample_batch, index=-1) + input_dict = policy.model.get_input_dict( + sample_batch, index="last") last_r = policy._value(**input_dict) # TODO: (sven) Remove once trajectory view API is all-algo default. else: diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index efcadf32f..1d5fe3f76 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -1,6 +1,7 @@ import collections from gym.spaces import Space import logging +import math import numpy as np from typing import Any, List, Dict, Tuple, TYPE_CHECKING, Union @@ -34,6 +35,9 @@ def to_float_np_array(v: List[Any]) -> np.ndarray: return arr +_INIT_COLS = [SampleBatch.OBS] + + class _AgentCollector: """Collects samples for one agent in one trajectory (episode). @@ -45,9 +49,18 @@ class _AgentCollector: _next_unroll_id = 0 # disambiguates unrolls within a single episode - def __init__(self, shift_before: int = 0): - self.shift_before = max(shift_before, 1) + def __init__(self, view_reqs): + # Determine the size of the buffer we need for data before the actual + # episode starts. This is used for 0-buffering of e.g. prev-actions, + # or internal state inputs. + self.shift_before = -min( + (int(vr.shift.split(":")[0]) + if isinstance(vr.shift, str) else vr.shift) + + (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0) + for k, vr in view_reqs.items()) + # The actual data buffers (lists holding each timestep's data). self.buffers: Dict[str, List] = {} + # The episode ID for the agent for which we collect data. self.episode_id = None # The simple timestep count for this agent. Gets increased by one # each time a (non-initial!) observation is added. @@ -137,31 +150,88 @@ class _AgentCollector: # -> skip. if data_col not in self.buffers: continue + # OBS are already shifted by -1 (the initial obs starts one ts # before all other data columns). - shift = view_req.shift - \ - (1 if data_col == SampleBatch.OBS else 0) + obs_shift = -1 if data_col == SampleBatch.OBS else 0 + + # Keep an np-array cache so we don't have to regenerate the + # np-array for different view_cols using to the same data_col. if data_col not in np_data: np_data[data_col] = to_float_np_array(self.buffers[data_col]) - # Shift is exactly 0: Send trajectory as is. - if shift == 0: - data = np_data[data_col][self.shift_before:] - # Shift is positive: We still need to 0-pad at the end here. - elif shift > 0: - data = to_float_np_array( - self.buffers[data_col][self.shift_before + shift:] + [ - np.zeros( - shape=view_req.space.shape, - dtype=view_req.space.dtype) for _ in range(shift) + + # Range of indices on time-axis, e.g. "-50:-1". Together with + # the `batch_repeat_value`, this determines the data produced. + # Example: + # batch_repeat_value=10, shift_from=-3, shift_to=-1 + # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # resulting data=[[-3, -2, -1], [7, 8, 9]] + # Range of 3 consecutive items repeats every 10 timesteps. + if view_req.shift_from is not None: + if view_req.batch_repeat_value > 1: + count = int( + math.ceil((len(np_data[data_col]) - self.shift_before) + / view_req.batch_repeat_value)) + data = np.asarray([ + np_data[data_col][self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_from + + obs_shift:self.shift_before + + (i * view_req.batch_repeat_value) + + view_req.shift_to + 1 + obs_shift] + for i in range(count) ]) - # Shift is negative: Shift into the already existing and 0-padded - # "before" area of our buffers. + else: + data = np_data[data_col][self.shift_before + + view_req.shift_from + + obs_shift:self.shift_before + + view_req.shift_to + 1 + obs_shift] + # Set of (probably non-consecutive) indices. + # Example: + # shift=[-3, 0] + # buffer=[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # resulting data=[[-3, 0], [-2, 1], [-1, 2], [0, 3], [1, 4], ...] + elif isinstance(view_req.shift, np.ndarray): + data = np_data[data_col][self.shift_before + obs_shift + + view_req.shift] + # Single shift int value. Use the trajectory as-is, and if + # `shift` != 0: shifted by that value. else: - data = np_data[data_col][self.shift_before + shift:shift] + shift = view_req.shift + obs_shift + + # Batch repeat (only provide a value every n timesteps). + if view_req.batch_repeat_value > 1: + count = int( + math.ceil((len(np_data[data_col]) - self.shift_before) + / view_req.batch_repeat_value)) + data = np.asarray([ + np_data[data_col][self.shift_before + ( + i * view_req.batch_repeat_value) + shift] + for i in range(count) + ]) + # Shift is exactly 0: Use trajectory as is. + elif shift == 0: + data = np_data[data_col][self.shift_before:] + # Shift is positive: We still need to 0-pad at the end. + elif shift > 0: + data = to_float_np_array( + self.buffers[data_col][self.shift_before + shift:] + [ + np.zeros( + shape=view_req.space.shape, + dtype=view_req.space.dtype) + for _ in range(shift) + ]) + # Shift is negative: Shift into the already existing and + # 0-padded "before" area of our buffers. + else: + data = np_data[data_col][self.shift_before + shift:shift] + if len(data) > 0: batch_data[view_col] = data - batch = SampleBatch(batch_data) + # Due to possible batch-repeats > 1, columns in the resulting batch + # may not all have the same batch size. + batch = SampleBatch(batch_data, _dont_check_lens=True) # Add EPS_ID and UNROLL_ID to batch. batch.data[SampleBatch.EPS_ID] = np.repeat(self.episode_id, @@ -230,15 +300,22 @@ class _PolicyCollector: appended to this policy's buffers. """ - def __init__(self): - """Initializes a _PolicyCollector instance.""" + def __init__(self, policy): + """Initializes a _PolicyCollector instance. + + Args: + policy (Policy): The policy object. + """ self.buffers: Dict[str, List] = collections.defaultdict(list) + self.policy = policy # The total timestep count for all agents that use this policy. # NOTE: This is not an env-step count (across n agents). AgentA and # agentB, both using this policy, acting in the same episode and both # doing n steps would increase the count by 2*n. self.agent_steps = 0 + # Seq-lens list of already added agent batches. + self.seq_lens = [] if policy.is_recurrent() else None def add_postprocessed_batch_for_training( self, batch: SampleBatch, @@ -257,11 +334,18 @@ class _PolicyCollector: # 1) If col is not in view_requirements, we must have a direct # child of the base Policy that doesn't do auto-view req creation. # 2) Col is in view-reqs and needed for training. - if view_col not in view_requirements or \ - view_requirements[view_col].used_for_training: + view_req = view_requirements.get(view_col) + if view_req is None or view_req.used_for_training: self.buffers[view_col].extend(data) # Add the agent's trajectory length to our count. self.agent_steps += batch.count + # Adjust the seq-lens array depending on the incoming agent sequences. + if self.seq_lens is not None: + max_seq_len = self.policy.config["model"]["max_seq_len"] + count = batch.count + while count > 0: + self.seq_lens.append(min(count, max_seq_len)) + count -= max_seq_len def build(self): """Builds a SampleBatch for this policy from the collected data. @@ -273,20 +357,22 @@ class _PolicyCollector: this policy. """ # Create batch from our buffers. - batch = SampleBatch(self.buffers) - assert SampleBatch.UNROLL_ID in batch.data + batch = SampleBatch( + self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True) # Clear buffers for future samples. self.buffers.clear() - # Reset agent steps to 0. + # Reset agent steps to 0 and seq-lens to empty list. self.agent_steps = 0 + if self.seq_lens is not None: + self.seq_lens = [] return batch class _PolicyCollectorGroup: def __init__(self, policy_map): self.policy_collectors = { - pid: _PolicyCollector() - for pid in policy_map.keys() + pid: _PolicyCollector(policy) + for pid, policy in policy_map.items() } # Total env-steps (1 env-step=up to N agents stepped). self.env_steps = 0 @@ -396,11 +482,14 @@ class _SimpleListCollector(_SampleCollector): self.agent_key_to_policy_id[agent_key] = policy_id else: assert self.agent_key_to_policy_id[agent_key] == policy_id + policy = self.policy_map[policy_id] + view_reqs = policy.model.inference_view_requirements if \ + getattr(policy, "model", None) else policy.view_requirements # Add initial obs to Trajectory. assert agent_key not in self.agent_collectors # TODO: determine exact shift-before based on the view-req shifts. - self.agent_collectors[agent_key] = _AgentCollector() + self.agent_collectors[agent_key] = _AgentCollector(view_reqs) self.agent_collectors[agent_key].add_init_obs( episode_id=episode.episode_id, agent_index=episode._agent_index(agent_id), @@ -466,11 +555,19 @@ class _SimpleListCollector(_SampleCollector): for view_col, view_req in view_reqs.items(): # Create the batch of data from the different buffers. data_col = view_req.data_col or view_col - time_indices = \ - view_req.shift - ( - 1 if data_col in [SampleBatch.OBS, "t", "env_id", - SampleBatch.AGENT_INDEX] else 0) + delta = -1 if data_col in [ + SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX + ] else 0 + # Range of shifts, e.g. "-100:0". Note: This includes index 0! + if view_req.shift_from is not None: + time_indices = (view_req.shift_from + delta, + view_req.shift_to + delta) + # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0]. + else: + time_indices = view_req.shift + delta data_list = [] + # Loop through agents and add-up their data (batch). for k in keys: if data_col == SampleBatch.EPS_ID: data_list.append(self.agent_collectors[k].episode_id) @@ -482,7 +579,15 @@ class _SimpleListCollector(_SampleCollector): self.agent_collectors[k]._build_buffers({ data_col: fill_value }) - data_list.append(buffers[k][data_col][time_indices]) + if isinstance(time_indices, tuple): + if time_indices[1] == -1: + data_list.append( + buffers[k][data_col][time_indices[0]:]) + else: + data_list.append(buffers[k][data_col][time_indices[ + 0]:time_indices[1] + 1]) + else: + data_list.append(buffers[k][data_col][time_indices]) input_dict[view_col] = np.array(data_list) self._reset_inference_calls(policy_id) diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index a19411433..0cb25d5c7 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -50,8 +50,6 @@ def compute_advantages(rollout: SampleBatch, processed rewards. """ - rollout_size = len(rollout[SampleBatch.ACTIONS]) - assert SampleBatch.VF_PREDS in rollout or not use_critic, \ "use_critic=True but values not found" assert use_critic or not use_gae, \ @@ -90,6 +88,4 @@ def compute_advantages(rollout: SampleBatch, rollout[Postprocessing.ADVANTAGES] = rollout[ Postprocessing.ADVANTAGES].astype(np.float32) - assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \ - "Rollout stacked incorrectly!" return rollout diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py index 49884d9f3..de3f06c29 100644 --- a/rllib/examples/attention_net.py +++ b/rllib/examples/attention_net.py @@ -39,6 +39,7 @@ if __name__ == "__main__": config = { "env": args.env, + # This env_config is only used for the RepeatAfterMeEnv env. "env_config": { "repeat_delay": 2, }, @@ -48,7 +49,7 @@ if __name__ == "__main__": "num_workers": 0, "num_envs_per_worker": 20, "entropy_coeff": 0.001, - "num_sgd_iter": 5, + "num_sgd_iter": 10, "vf_loss_coeff": 1e-5, "model": { "custom_model": GTrXLNet, @@ -56,9 +57,10 @@ if __name__ == "__main__": "custom_model_config": { "num_transformer_units": 1, "attn_dim": 64, - "num_heads": 2, - "memory_tau": 50, + "memory_inference": 100, + "memory_training": 50, "head_dim": 32, + "num_heads": 2, "ff_hidden_dim": 32, }, }, @@ -71,7 +73,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=1) + results = tune.run(args.run, config=config, stop=stop, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/cartpole_lstm.py b/rllib/examples/cartpole_lstm.py index 1c9edc655..de53c6ff1 100644 --- a/rllib/examples/cartpole_lstm.py +++ b/rllib/examples/cartpole_lstm.py @@ -59,7 +59,7 @@ if __name__ == "__main__": "episode_reward_mean": args.stop_reward, } - results = tune.run(args.run, config=config, stop=stop, verbose=1) + results = tune.run(args.run, config=config, stop=stop, verbose=2) if args.as_test: check_learning_achieved(results, args.stop_reward) diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 38478857c..fc45149a5 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -13,7 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ TensorType from ray.rllib.utils.spaces.repeated import Repeated -from ray.rllib.utils.typing import ModelConfigDict, TensorStructType +from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, \ + TensorStructType tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -238,14 +239,14 @@ class ModelV2: right input dict, state, and seq len arguments. """ - train_batch["is_training"] = is_training + input_dict = train_batch.copy() + input_dict["is_training"] = is_training states = [] i = 0 - while "state_in_{}".format(i) in train_batch: - states.append(train_batch["state_in_{}".format(i)]) + while "state_in_{}".format(i) in input_dict: + states.append(input_dict["state_in_{}".format(i)]) i += 1 - ret = self.__call__(train_batch, states, train_batch.get("seq_lens")) - del train_batch["is_training"] + ret = self.__call__(input_dict, states, input_dict.get("seq_lens")) return ret def import_from_h5(self, h5_file: str) -> None: @@ -316,21 +317,57 @@ class ModelV2: # TODO: (sven) Experimental method. def get_input_dict(self, sample_batch, - index: int = -1) -> Dict[str, TensorType]: - if index < 0: - index = sample_batch.count - 1 + index: Union[int, str] = "last") -> ModelInputDict: + """Creates single ts input-dict at given index from a SampleBatch. + + Args: + sample_batch (SampleBatch): A single-trajectory SampleBatch object + to generate the compute_actions input dict from. + index (Union[int, str]): An integer index value indicating the + position in the trajectory for which to generate the + compute_actions input dict. Set to "last" to generate the dict + at the very end of the trajectory (e.g. for value estimation). + Note that "last" is different from -1, as "last" will use the + final NEXT_OBS as observation input. + + Returns: + ModelInputDict: The (single-timestep) input dict for ModelV2 calls. + """ + last_mappings = { + SampleBatch.OBS: SampleBatch.NEXT_OBS, + SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, + SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, + } input_dict = {} for view_col, view_req in self.inference_view_requirements.items(): # Create batches of size 1 (single-agent input-dict). - - # Index range. - if isinstance(index, tuple): - data = sample_batch[view_col][index[0]:index[1] + 1] - input_dict[view_col] = np.array([data]) - # Single index. + data_col = view_req.data_col or view_col + if index == "last": + data_col = last_mappings.get(data_col, data_col) + if view_req.shift_from is not None: + data = sample_batch[view_col][-1] + traj_len = len(sample_batch[data_col]) + missing_at_end = traj_len % view_req.batch_repeat_value + input_dict[view_col] = np.array([ + np.concatenate([ + data, sample_batch[data_col][-missing_at_end:] + ])[view_req.shift_from:view_req.shift_to + + 1 if view_req.shift_to != -1 else None] + ]) + else: + data = sample_batch[data_col][-1] + input_dict[view_col] = np.array([data]) else: - input_dict[view_col] = sample_batch[view_col][index:index + 1] + # Index range. + if isinstance(index, tuple): + data = sample_batch[data_col][index[0]:index[1] + 1 + if index[1] != -1 else None] + input_dict[view_col] = np.array([data]) + # Single index. + else: + input_dict[view_col] = sample_batch[data_col][ + index:index + 1 if index != -1 else None] # Add valid `seq_lens`, just in case RNNs need it. input_dict["seq_lens"] = np.array([1], dtype=np.int32) diff --git a/rllib/models/tests/test_attention_nets.py b/rllib/models/tests/test_attention_nets.py deleted file mode 100644 index ac6ec134d..000000000 --- a/rllib/models/tests/test_attention_nets.py +++ /dev/null @@ -1,263 +0,0 @@ -import gym -import numpy as np -import unittest - -from ray.rllib.models.tf.attention_net import relative_position_embedding, \ - GTrXLNet -from ray.rllib.models.tf.layers import MultiHeadAttention -from ray.rllib.models.torch.attention_net import relative_position_embedding \ - as relative_position_embedding_torch, GTrXLNet as TorchGTrXLNet -from ray.rllib.models.torch.modules.multi_head_attention import \ - MultiHeadAttention as TorchMultiHeadAttention -from ray.rllib.utils.framework import try_import_torch, try_import_tf -from ray.rllib.utils.test_utils import framework_iterator - -torch, nn = try_import_torch() -tf1, tf, tfv = try_import_tf() - - -class TestAttentionNets(unittest.TestCase): - """Tests various torch/modules and tf/layers required for AttentionNet""" - - def train_torch_full_model(self, - model, - inputs, - outputs, - num_epochs=250, - state=None, - seq_lens=None): - """Convenience method that trains a Torch model for num_epochs epochs - and tests whether loss decreased, as expected. - - Args: - model (nn.Module): Torch model to be trained. - inputs (torch.Tensor): Training data - outputs (torch.Tensor): Training labels - num_epochs (int): Number of epochs to train for - state (torch.Tensor): Internal state of module - seq_lens (torch.Tensor): Tensor of sequence lengths - """ - - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) - - # Check that the layer trains correctly - for t in range(num_epochs): - y_pred = model(inputs, state, seq_lens) - loss = criterion(y_pred[0], torch.squeeze(outputs[0])) - - if t % 10 == 1: - print(t, loss.item()) - - # if t == 0: - # init_loss = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # final_loss = loss.item() - - # The final loss has decreased, which tests - # that the model is learning from the training data. - # self.assertLess(final_loss / init_loss, 0.99) - - def train_torch_layer(self, model, inputs, outputs, num_epochs=250): - """Convenience method that trains a Torch model for num_epochs epochs - and tests whether loss decreased, as expected. - - Args: - model (nn.Module): Torch model to be trained. - inputs (torch.Tensor): Training data - outputs (torch.Tensor): Training labels - num_epochs (int): Number of epochs to train for - """ - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) - - # Check that the layer trains correctly - for t in range(num_epochs): - y_pred = model(inputs) - loss = criterion(y_pred, outputs) - - if t == 1: - init_loss = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - final_loss = loss.item() - - # The final loss has decreased by a factor of 2, which tests - # that the model is learning from the training data. - self.assertLess(final_loss / init_loss, 0.5) - - def train_tf_model(self, - model, - inputs, - outputs, - num_epochs=250, - minibatch_size=32): - """Convenience method that trains a Tensorflow model for num_epochs - epochs and tests whether loss decreased, as expected. - - Args: - model (tf.Model): Torch model to be trained. - inputs (np.array): Training data - outputs (np.array): Training labels - num_epochs (int): Number of training epochs - batch_size (int): Number of samples in each minibatch - """ - - # Configure a model for mean-squared error loss. - model.compile(optimizer="SGD", loss="mse", metrics=["mae"]) - - hist = model.fit( - inputs, - outputs, - verbose=0, - epochs=num_epochs, - batch_size=minibatch_size).history - init_loss = hist["loss"][0] - final_loss = hist["loss"][-1] - - self.assertLess(final_loss / init_loss, 0.5) - - def test_multi_head_attention(self): - """Tests the MultiHeadAttention mechanism of Vaswani et al.""" - # B is batch size - B = 1 - # D_in is attention dim, L is memory_tau - L, D_in, D_out = 2, 32, 10 - - for fw, sess in framework_iterator( - frameworks=("tfe", "torch", "tf"), session=True): - # Create a single attention layer with 2 heads. - if fw == "torch": - - # Create random Tensors to hold inputs and outputs - x = torch.randn(B, L, D_in) - y = torch.randn(B, L, D_out) - - model = TorchMultiHeadAttention( - in_dim=D_in, out_dim=D_out, num_heads=2, head_dim=32) - - self.train_torch_layer(model, x, y, num_epochs=500) - - # Framework is tensorflow or tensorflow-eager. - else: - x = np.random.random((B, L, D_in)) - y = np.random.random((B, L, D_out)) - - inputs = tf.keras.layers.Input(shape=(L, D_in)) - - model = tf.keras.Sequential([ - inputs, - MultiHeadAttention( - out_dim=D_out, num_heads=2, head_dim=32) - ]) - self.train_tf_model(model, x, y) - - def test_attention_net(self): - """Tests the GTrXL. - - Builds a full AttentionNet and checks that it trains in a supervised - setting.""" - - # Checks that torch and tf embedding matrices are the same - with tf1.Session().as_default() as sess: - assert np.allclose( - relative_position_embedding(20, 15).eval(session=sess), - relative_position_embedding_torch(20, 15).numpy()) - - # B is batch size - B = 32 - # D_in is attention dim, L is memory_tau - L, D_in, D_out = 2, 16, 2 - - for fw, sess in framework_iterator(session=True): - - # Create a single attention layer with 2 heads - if fw == "torch": - # Create random Tensors to hold inputs and outputs - x = torch.randn(B, L, D_in) - y = torch.randn(B, L, D_out) - - value_labels = torch.randn(B, L, D_in) - memory_labels = torch.randn(B, L, D_out) - - attention_net = TorchGTrXLNet( - observation_space=gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(D_in, )), - action_space=gym.spaces.Discrete(D_out), - num_outputs=D_out, - model_config={"max_seq_len": 2}, - name="TestTorchAttentionNet", - num_transformer_units=2, - attn_dim=D_in, - num_heads=2, - memory_tau=L, - head_dim=D_out, - ff_hidden_dim=16, - init_gate_bias=2.0) - - init_state = attention_net.get_initial_state() - - # Get initial state and add a batch dimension. - init_state = [np.expand_dims(s, 0) for s in init_state] - seq_lens_init = torch.full( - size=(B, ), fill_value=L, dtype=torch.int32) - - # Torch implementation expects a formatted input_dict instead - # of a numpy array as input. - input_dict = {"obs": x} - self.train_torch_full_model( - attention_net, - input_dict, [y, value_labels, memory_labels], - num_epochs=250, - state=init_state, - seq_lens=seq_lens_init) - # Framework is tensorflow or tensorflow-eager. - else: - x = np.random.random((B, L, D_in)) - y = np.random.random((B, L, D_out)) - - value_labels = np.random.random((B, L, 1)) - memory_labels = np.random.random((B, L, D_in)) - - # We need to create (N-1) MLP labels for N transformer units - mlp_labels = np.random.random((B, L, D_in)) - - attention_net = GTrXLNet( - observation_space=gym.spaces.Box( - low=float("-inf"), high=float("inf"), shape=(D_in, )), - action_space=gym.spaces.Discrete(D_out), - num_outputs=D_out, - model_config={"max_seq_len": 2}, - name="TestTFAttentionNet", - num_transformer_units=2, - attn_dim=D_in, - num_heads=2, - memory_tau=L, - head_dim=D_out, - ff_hidden_dim=16, - init_gate_bias=2.0) - model = attention_net.trxl_model - - # Get initial state and add a batch dimension. - init_state = attention_net.get_initial_state() - init_state = [np.tile(s, (B, 1, 1)) for s in init_state] - - self.train_tf_model( - model, [x] + init_state, - [y, value_labels, memory_labels, mlp_labels], - num_epochs=200, - minibatch_size=B) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index 2ddbaf33b..ef49f4610 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -8,14 +8,17 @@ Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019. https://www.aclweb.org/anthology/P19-1285.pdf """ +from gym.spaces import Box import numpy as np import gym -from typing import Optional, Any +from typing import Any, Optional from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.layers import GRUGate, RelativeMultiHeadAttention, \ SkipConnection from ray.rllib.models.tf.recurrent_net import RecurrentNetwork +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import ModelConfigDict, TensorType, List @@ -60,7 +63,7 @@ class TrXLNet(RecurrentNetwork): model_config: ModelConfigDict, name: str, num_transformer_units: int, attn_dim: int, num_heads: int, head_dim: int, ff_hidden_dim: int): - """Initializes a TfXLNet object. + """Initializes a TrXLNet object. Args: num_transformer_units (int): The number of Transformer repeats to @@ -88,8 +91,6 @@ class TrXLNet(RecurrentNetwork): self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - pos_embedding = relative_position_embedding(self.max_seq_len, attn_dim) - inputs = tf.keras.layers.Input( shape=(self.max_seq_len, self.obs_dim), name="inputs") E_out = tf.keras.layers.Dense(attn_dim)(inputs) @@ -100,7 +101,6 @@ class TrXLNet(RecurrentNetwork): out_dim=attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=pos_embedding, input_layernorm=False, output_activation=None), fan_in_layer=None)(E_out) @@ -160,7 +160,8 @@ class GTrXLNet(RecurrentNetwork): >> num_transformer_units=1, >> attn_dim=32, >> num_heads=2, - >> memory_tau=50, + >> memory_inference=100, + >> memory_training=50, >> etc.. >> } """ @@ -174,11 +175,12 @@ class GTrXLNet(RecurrentNetwork): num_transformer_units: int, attn_dim: int, num_heads: int, - memory_tau: int, + memory_inference: int, + memory_training: int, head_dim: int, ff_hidden_dim: int, init_gate_bias: float = 2.0): - """Initializes a GTrXLNet. + """Initializes a GTrXLNet instance. Args: num_transformer_units (int): The number of Transformer repeats to @@ -187,9 +189,15 @@ class GTrXLNet(RecurrentNetwork): unit. num_heads (int): The number of attention heads to use in parallel. Denoted as `H` in [3]. - memory_tau (int): The number of timesteps to store in each - transformer block's memory M (concat'd over time and fed into - next transformer block as input). + memory_inference (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as inference + input. The first transformer unit will receive this number of + past observations (plus the current one), instead. + memory_training (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as training + input (plus the actual input sequence of len=max_seq_len). + The first transformer unit will receive this number of + past observations (plus the input sequence), instead. head_dim (int): The dimension of a single(!) head. Denoted as `d` in [3]. ff_hidden_dim (int): The dimension of the hidden layer within @@ -208,21 +216,18 @@ class GTrXLNet(RecurrentNetwork): self.num_transformer_units = num_transformer_units self.attn_dim = attn_dim self.num_heads = num_heads - self.memory_tau = memory_tau + self.memory_inference = memory_inference + self.memory_training = memory_training self.head_dim = head_dim self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - # Constant (non-trainable) sinusoid rel pos encoding matrix. - Phi = relative_position_embedding(self.max_seq_len + self.memory_tau, - self.attn_dim) - - # Raw observation input. + # Raw observation input (plus (None) time axis). input_layer = tf.keras.layers.Input( - shape=(self.max_seq_len, self.obs_dim), name="inputs") + shape=(None, self.obs_dim), name="inputs") memory_ins = [ tf.keras.layers.Input( - shape=(self.memory_tau, self.attn_dim), + shape=(None, self.attn_dim), dtype=tf.float32, name="memory_in_{}".format(i)) for i in range(self.num_transformer_units) @@ -242,7 +247,6 @@ class GTrXLNet(RecurrentNetwork): out_dim=self.attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=Phi, input_layernorm=True, output_activation=tf.nn.relu), fan_in_layer=GRUGate(init_gate_bias), @@ -280,69 +284,52 @@ class GTrXLNet(RecurrentNetwork): self.register_variables(self.trxl_model.variables) self.trxl_model.summary() - @override(RecurrentNetwork) - def forward_rnn(self, inputs: TensorType, state: List[TensorType], - seq_lens: TensorType) -> (TensorType, List[TensorType]): - # To make Attention work with current RLlib's ModelV2 API: - # We assume `state` is the history of L recent observations (all - # concatenated into one tensor) and append the current inputs to the - # end and only keep the most recent (up to `max_seq_len`). This allows - # us to deal with timestep-wise inference and full sequence training - # within the same logic. - observations = state[0] - memory = state[1:] + # Setup inference view (`memory-inference` x past observations + + # current one (0)) + # 1 to `num_transformer_units`: Memory data (one per transformer unit). + for i in range(self.num_transformer_units): + space = Box(-1.0, 1.0, shape=(self.attn_dim, )) + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift="-{}:-1".format(self.memory_inference), + # Repeat the incoming state every max-seq-len times. + batch_repeat_value=self.max_seq_len, + space=space) + self.inference_view_requirements["state_out_{}".format(i)] = \ + ViewRequirement( + space=space, + used_for_training=False) - observations = tf.concat( - (observations, inputs), axis=1)[:, -self.max_seq_len:] - all_out = self.trxl_model([observations] + memory) - logits, self._value_out = all_out[0], all_out[1] + @override(ModelV2) + def forward(self, input_dict, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + assert seq_lens is not None + + # Add the time dim to observations. + B = tf.shape(seq_lens)[0] + observations = input_dict[SampleBatch.OBS] + + shape = tf.shape(observations) + T = shape[0] // B + observations = tf.reshape(observations, + tf.concat([[-1, T], shape[1:]], axis=0)) + + all_out = self.trxl_model([observations] + state) + + logits = all_out[0] + self._value_out = all_out[1] memory_outs = all_out[2:] - # If memory_tau > max_seq_len -> overlap w/ previous `memory` input. - if self.memory_tau > self.max_seq_len: - memory_outs = [ - tf.concat( - [memory[i][:, -(self.memory_tau - self.max_seq_len):], m], - axis=1) for i, m in enumerate(memory_outs) - ] - else: - memory_outs = [m[:, -self.memory_tau:] for m in memory_outs] - T = tf.shape(inputs)[1] # Length of input segment (time). - logits = logits[:, -T:] - self._value_out = self._value_out[:, -T:] - - return logits, [observations] + memory_outs + return tf.reshape(logits, [-1, self.num_outputs]), [ + tf.reshape(m, [-1, self.attn_dim]) for m in memory_outs + ] # TODO: (sven) Deprecate this once trajectory view API has fully matured. @override(RecurrentNetwork) def get_initial_state(self) -> List[np.ndarray]: - # State is the T last observations concat'd together into one Tensor. - # Plus all Transformer blocks' E(l) outputs concat'd together (up to - # tau timesteps). - return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \ - [np.zeros((self.memory_tau, self.attn_dim), np.float32) - for _ in range(self.num_transformer_units)] + return [] @override(ModelV2) def value_function(self) -> TensorType: return tf.reshape(self._value_out, [-1]) - - -def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType: - """Creates a [seq_length x seq_length] matrix for rel. pos encoding. - - Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding - matrix. - - Args: - seq_length (int): The max. sequence length (time axis). - out_dim (int): The number of nodes to go into the first Tranformer - layer with. - - Returns: - tf.Tensor: The encoding matrix Phi. - """ - inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) - pos_offsets = tf.range(seq_length - 1., -1., -1.) - inputs = pos_offsets[:, None] * inverse_freq[None, :] - return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) diff --git a/rllib/models/tf/layers/__init__.py b/rllib/models/tf/layers/__init__.py index 68ae2ea53..0661aac98 100644 --- a/rllib/models/tf/layers/__init__.py +++ b/rllib/models/tf/layers/__init__.py @@ -1,11 +1,11 @@ from ray.rllib.models.tf.layers.gru_gate import GRUGate from ray.rllib.models.tf.layers.noisy_layer import NoisyLayer from ray.rllib.models.tf.layers.relative_multi_head_attention import \ - RelativeMultiHeadAttention + PositionalEmbedding, RelativeMultiHeadAttention from ray.rllib.models.tf.layers.skip_connection import SkipConnection from ray.rllib.models.tf.layers.multi_head_attention import MultiHeadAttention __all__ = [ - "GRUGate", "MultiHeadAttention", "NoisyLayer", + "GRUGate", "MultiHeadAttention", "NoisyLayer", "PositionalEmbedding", "RelativeMultiHeadAttention", "SkipConnection" ] diff --git a/rllib/models/tf/layers/relative_multi_head_attention.py b/rllib/models/tf/layers/relative_multi_head_attention.py index f7d70ab60..840449e1c 100644 --- a/rllib/models/tf/layers/relative_multi_head_attention.py +++ b/rllib/models/tf/layers/relative_multi_head_attention.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Optional from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.typing import TensorType @@ -16,9 +16,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): out_dim: int, num_heads: int, head_dim: int, - rel_pos_encoder: Any, input_layernorm: bool = False, - output_activation: Optional[Any] = None, + output_activation: Optional["tf.nn.activation"] = None, **kwargs): """Initializes a RelativeMultiHeadAttention keras Layer object. @@ -28,7 +27,6 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): Denoted `H` in [2]. head_dim (int): The dimension of a single(!) attention head Denoted `D` in [2]. - rel_pos_encoder (: input_layernorm (bool): Whether to prepend a LayerNorm before everything else. Should be True for building a GTrXL. output_activation (Optional[tf.nn.activation]): Optional tf.nn @@ -50,9 +48,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): self._uvar = self.add_weight(shape=(num_heads, head_dim)) self._vvar = self.add_weight(shape=(num_heads, head_dim)) + # Constant (non-trainable) sinusoid rel pos encoding matrix, which + # depends on this incoming time dimension. + # For inference, we prepend the memory to the current timestep's + # input: Tau + 1. For training, we prepend the memory to the input + # sequence: Tau + T. + self._pos_embedding = PositionalEmbedding(out_dim) self._pos_proj = tf.keras.layers.Dense( num_heads * head_dim, use_bias=False) - self._rel_pos_encoder = rel_pos_encoder self._input_layernorm = None if input_layernorm: @@ -66,9 +69,8 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): # Add previous memory chunk (as const, w/o gradient) to input. # Tau (number of (prev) time slices in each memory chunk). - Tau = memory.shape.as_list()[1] if memory is not None else 0 - if memory is not None: - inputs = tf.concat((tf.stop_gradient(memory), inputs), axis=1) + Tau = tf.shape(memory)[1] + inputs = tf.concat([tf.stop_gradient(memory), inputs], axis=1) # Apply the Layer-Norm. if self._input_layernorm is not None: @@ -77,15 +79,17 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): qkv = self._qkv_layer(inputs) queries, keys, values = tf.split(qkv, 3, -1) - # Cut out Tau memory timesteps from query. + # Cut out memory timesteps from query. queries = queries[:, -T:] + # Splitting up queries into per-head dims (d). queries = tf.reshape(queries, [-1, T, H, d]) - keys = tf.reshape(keys, [-1, T + Tau, H, d]) - values = tf.reshape(values, [-1, T + Tau, H, d]) + keys = tf.reshape(keys, [-1, Tau + T, H, d]) + values = tf.reshape(values, [-1, Tau + T, H, d]) - R = self._pos_proj(self._rel_pos_encoder) - R = tf.reshape(R, [T + Tau, H, d]) + R = self._pos_embedding(Tau + T) + R = self._pos_proj(R) + R = tf.reshape(R, [Tau + T, H, d]) # b=batch # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) @@ -96,9 +100,9 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): score = score + self.rel_shift(pos_score) score = score / d**0.5 - # causal mask of the same length as the sequence + # Causal mask of the same length as the sequence. mask = tf.sequence_mask( - tf.range(Tau + 1, T + Tau + 1), dtype=score.dtype) + tf.range(Tau + 1, Tau + T + 1), dtype=score.dtype) mask = mask[None, :, :, None] masked_score = score * mask + 1e30 * (mask - 1.) @@ -121,3 +125,14 @@ class RelativeMultiHeadAttention(tf.keras.layers.Layer if tf else object): x = tf.reshape(x, x_size) return x + + +class PositionalEmbedding(tf.keras.layers.Layer if tf else object): + def __init__(self, out_dim, **kwargs): + super().__init__(**kwargs) + self.inverse_freq = 1 / (10000**(tf.range(0, out_dim, 2.0) / out_dim)) + + def call(self, seq_length): + pos_offsets = tf.cast(tf.range(seq_length - 1, -1, -1), tf.float32) + inputs = pos_offsets[:, None] * self.inverse_freq[None, :] + return tf.concat((tf.sin(inputs), tf.cos(inputs)), axis=-1) diff --git a/rllib/models/tf/layers/skip_connection.py b/rllib/models/tf/layers/skip_connection.py index efb89f2e3..a44ae2bc1 100644 --- a/rllib/models/tf/layers/skip_connection.py +++ b/rllib/models/tf/layers/skip_connection.py @@ -16,7 +16,6 @@ class SkipConnection(tf.keras.layers.Layer if tf else object): def __init__(self, layer: Any, fan_in_layer: Optional[Any] = None, - add_memory: bool = False, **kwargs): """Initializes a SkipConnection keras layer object. diff --git a/rllib/models/torch/modules/skip_connection.py b/rllib/models/torch/modules/skip_connection.py index 126274b1d..8d79b7826 100644 --- a/rllib/models/torch/modules/skip_connection.py +++ b/rllib/models/torch/modules/skip_connection.py @@ -15,7 +15,6 @@ class SkipConnection(nn.Module): def __init__(self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, - add_memory: bool = False, **kwargs): """Initializes a SkipConnection nn Module object. diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 432e384f2..39b31f63b 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -183,11 +183,12 @@ class DynamicTFPolicy(TFPolicy): else: if self.config["_use_trajectory_view_api"]: self._state_inputs = [ - tf1.placeholder( - shape=(None, ) + vr.space.shape, dtype=vr.space.dtype) - for k, vr in + get_placeholder( + space=vr.space, + time_axis=not isinstance(vr.shift, int), + ) for k, vr in self.model.inference_view_requirements.items() - if k[:9] == "state_in_" + if k.startswith("state_in_") ] else: self._state_inputs = [ @@ -423,9 +424,14 @@ class DynamicTFPolicy(TFPolicy): input_dict[view_col] = existing_inputs[view_col] # All others. else: + time_axis = not isinstance(view_req.shift, int) if view_req.used_for_training: + # Create a +time-axis placeholder if the shift is not an + # int (range or list of ints). input_dict[view_col] = get_placeholder( - space=view_req.space, name=view_col) + space=view_req.space, + name=view_col, + time_axis=time_axis) dummy_batch = self._get_dummy_batch_from_view_requirements( batch_size=32) @@ -490,10 +496,10 @@ class DynamicTFPolicy(TFPolicy): dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) for k, v in self.extra_compute_action_fetches().items(): dummy_batch[k] = fake_array(v) + dummy_batch = SampleBatch(dummy_batch) - sb = SampleBatch(dummy_batch) - batch_for_postproc = UsageTrackingDict(sb) - batch_for_postproc.count = sb.count + batch_for_postproc = UsageTrackingDict(dummy_batch) + batch_for_postproc.count = dummy_batch.count logger.info("Testing `postprocess_trajectory` w/ dummy batch.") self.exploration.postprocess_trajectory(self, batch_for_postproc, self._sess) @@ -519,6 +525,7 @@ class DynamicTFPolicy(TFPolicy): train_batch.update({ SampleBatch.PREV_ACTIONS: self._prev_action_input, SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, }) for k, v in postprocessed_batch.items(): @@ -578,7 +585,8 @@ class DynamicTFPolicy(TFPolicy): for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ key not in self.model.inference_view_requirements: - self.view_requirements[key].used_for_training = False + if key in self.view_requirements: + self.view_requirements[key].used_for_training = False if key in self._loss_input_dict: del self._loss_input_dict[key] # Remove those not needed at all (leave those that are needed diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 758cfc948..f17d60e06 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -314,12 +314,16 @@ def build_eager_tf_policy(name, self.callbacks.on_learn_on_batch( policy=self, train_batch=postprocessed_batch) - # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( postprocessed_batch, shuffle=False, max_seq_len=self._max_seq_len, - batch_divisibility_req=self.batch_divisibility_req) + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + postprocessed_batch["is_training"] = True return self._learn_on_batch_eager(postprocessed_batch) @convert_eager_inputs @@ -332,12 +336,14 @@ def build_eager_tf_policy(name, @override(Policy) def compute_gradients(self, samples): - # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( samples, shuffle=False, max_seq_len=self._max_seq_len, batch_divisibility_req=self.batch_divisibility_req) + + self._is_training = True + samples["is_training"] = True return self._compute_gradients_eager(samples) @convert_eager_inputs @@ -369,7 +375,7 @@ def build_eager_tf_policy(name, # TODO: remove python side effect to cull sources of bugs. self._is_training = False - self._state_in = state_batches + self._state_in = state_batches or [] if not tf1.executing_eagerly(): tf1.enable_eager_execution() @@ -591,8 +597,6 @@ def build_eager_tf_policy(name, def _compute_gradients(self, samples): """Computes and returns grads as eager tensors.""" - self._is_training = True - with tf.GradientTape(persistent=gradients_fn is not None) as tape: loss = loss_fn(self, self.model, self.dist_class, samples) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index a1e92ac37..4695e366f 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -629,10 +629,9 @@ class Policy(metaclass=ABCMeta): batch_for_postproc.count = self._dummy_batch.count self.exploration.postprocess_trajectory(self, batch_for_postproc) postprocessed_batch = self.postprocess_trajectory(batch_for_postproc) + seq_lens = None if state_outs: B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size] - # TODO: (sven) This hack will not work for attention net traj. - # view setup. i = 0 while "state_in_{}".format(i) in postprocessed_batch: postprocessed_batch["state_in_{}".format(i)] = \ @@ -642,12 +641,11 @@ class Policy(metaclass=ABCMeta): postprocessed_batch["state_out_{}".format(i)][:B] i += 1 seq_len = sample_batch_size // B - postprocessed_batch["seq_lens"] = \ - np.array([seq_len for _ in range(B)], dtype=np.int32) - # Remove the UsageTrackingDict wrap to prep for wrapping the - # train batch with a to-tensor UsageTrackingDict. - train_batch = {k: v for k, v in postprocessed_batch.items()} - train_batch = self._lazy_tensor_dict(train_batch) + seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32) + # Wrap `train_batch` with a to-tensor UsageTrackingDict. + train_batch = self._lazy_tensor_dict(postprocessed_batch) + if seq_lens is not None: + train_batch["seq_lens"] = seq_lens train_batch.count = self._dummy_batch.count # Call the loss function, if it exists. if self._loss is not None: @@ -712,13 +710,33 @@ class Policy(metaclass=ABCMeta): ret[view_col] = \ np.zeros((batch_size, ) + shape[1:], np.float32) else: - if isinstance(view_req.space, gym.spaces.Space): - ret[view_col] = np.zeros_like( - [view_req.space.sample() for _ in range(batch_size)]) + # Range of indices on time-axis, e.g. "-50:-1". + if view_req.shift_from is not None: + ret[view_col] = np.zeros_like([[ + view_req.space.sample() + for _ in range(view_req.shift_to - + view_req.shift_from + 1) + ] for _ in range(batch_size)]) + # Set of (probably non-consecutive) indices. + elif isinstance(view_req.shift, (list, tuple)): + ret[view_col] = np.zeros_like([[ + view_req.space.sample() + for t in range(len(view_req.shift)) + ] for _ in range(batch_size)]) + # Single shift int value. else: - ret[view_col] = [view_req.space for _ in range(batch_size)] + if isinstance(view_req.space, gym.spaces.Space): + ret[view_col] = np.zeros_like([ + view_req.space.sample() for _ in range(batch_size) + ]) + else: + ret[view_col] = [ + view_req.space for _ in range(batch_size) + ] - return SampleBatch(ret) + # Due to different view requirements for the different columns, + # columns in the resulting batch may not all have the same batch size. + return SampleBatch(ret, _dont_check_lens=True) def _update_model_inference_view_requirements_from_init_state(self): """Uses Model's (or this Policy's) init state to add needed ViewReqs. @@ -737,8 +755,13 @@ class Policy(metaclass=ABCMeta): view_reqs = model.inference_view_requirements if model else \ self.view_requirements view_reqs["state_in_{}".format(i)] = ViewRequirement( - "state_out_{}".format(i), shift=-1, space=space) - view_reqs["state_out_{}".format(i)] = ViewRequirement(space=space) + "state_out_{}".format(i), + shift=-1, + batch_repeat_value=self.config.get("model", {}).get( + "max_seq_len", 1), + space=space) + view_reqs["state_out_{}".format(i)] = ViewRequirement( + space=space, used_for_training=True) def clip_action(action, action_space): diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 486bbf0db..1cf3fc4aa 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -19,7 +19,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.debug import summarize from ray.rllib.utils.framework import try_import_tf, try_import_torch -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.typing import TensorType, ViewRequirementsDict from ray.util import log_once tf1, tf, tfv = try_import_tf() @@ -35,6 +35,7 @@ def pad_batch_to_sequences_of_same_size( shuffle: bool = False, batch_divisibility_req: int = 1, feature_keys: Optional[List[str]] = None, + view_requirements: Optional[ViewRequirementsDict] = None, ): """Applies padding to `batch` so it's choppable into same-size sequences. @@ -55,6 +56,9 @@ def pad_batch_to_sequences_of_same_size( feature_keys (Optional[List[str]]): An optional list of keys to apply sequence-chopping to. If None, use all keys in batch that are not "state_in/out_"-type keys. + view_requirements (Optional[ViewRequirementsDict]): An optional + Policy ViewRequirements dict to be able to infer whether + e.g. dynamic max'ing should be applied over the seq_lens. """ if batch_divisibility_req > 1: meets_divisibility_reqs = ( @@ -64,46 +68,65 @@ def pad_batch_to_sequences_of_same_size( else: meets_divisibility_reqs = True - # RNN-case. + states_already_reduced_to_init = False + + # RNN/attention net case. Figure out whether we should apply dynamic + # max'ing over the list of sequence lengths. if "state_in_0" in batch or "state_out_0" in batch: - dynamic_max = True + # Check, whether the state inputs have already been reduced to their + # init values at the beginning of each max_seq_len chunk. + if batch.seq_lens is not None and \ + len(batch["state_in_0"]) == len(batch.seq_lens): + states_already_reduced_to_init = True + + # RNN (or single timestep state-in): Set the max dynamically. + if view_requirements["state_in_0"].shift_from is None: + dynamic_max = True + # Attention Nets (state inputs are over some range): No dynamic maxing + # possible. + else: + dynamic_max = False # Multi-agent case. elif not meets_divisibility_reqs: max_seq_len = batch_divisibility_req dynamic_max = False - # Simple case: not RNN nor do we need to pad. + # Simple case: No RNN/attention net, nor do we need to pad. else: if shuffle: batch.shuffle() return - # RNN or multi-agent case. + # RNN, attention net, or multi-agent case. state_keys = [] feature_keys_ = feature_keys or [] - for k in batch.keys(): - if "state_in_" in k: + for k, v in batch.items(): + if k.startswith("state_in_"): state_keys.append(k) - elif not feature_keys and "state_out_" not in k and k != "infos": + elif not feature_keys and not k.startswith("state_out_") and \ + k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray): feature_keys_.append(k) feature_sequences, initial_states, seq_lens = \ chop_into_sequences( - batch[SampleBatch.EPS_ID], - batch[SampleBatch.UNROLL_ID], - batch[SampleBatch.AGENT_INDEX], - [batch[k] for k in feature_keys_], - [batch[k] for k in state_keys], - max_seq_len, + feature_columns=[batch[k] for k in feature_keys_], + state_columns=[batch[k] for k in state_keys], + episode_ids=batch.get(SampleBatch.EPS_ID), + unroll_ids=batch.get(SampleBatch.UNROLL_ID), + agent_indices=batch.get(SampleBatch.AGENT_INDEX), + seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")), + max_seq_len=max_seq_len, dynamic_max=dynamic_max, + states_already_reduced_to_init=states_already_reduced_to_init, shuffle=shuffle) + for i, k in enumerate(feature_keys_): batch[k] = feature_sequences[i] for i, k in enumerate(state_keys): batch[k] = initial_states[i] - batch["seq_lens"] = seq_lens + batch["seq_lens"] = np.array(seq_lens) if log_once("rnn_ma_feed_dict"): - logger.info("Padded input for RNN:\n\n{}\n".format( + logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format( summarize({ "features": feature_sequences, "initial_states": initial_states, @@ -157,18 +180,18 @@ def add_time_dimension(padded_inputs: TensorType, return torch.reshape(padded_inputs, new_shape) -# NOTE: This function will be deprecated once chunks already come padded and -# correctly chopped from the _SampleCollector object (in time-major fashion -# or not). It is already no longer user iff `_use_trajectory_view_api` = True. @DeveloperAPI -def chop_into_sequences(episode_ids, - unroll_ids, - agent_indices, +def chop_into_sequences(*, feature_columns, state_columns, max_seq_len, + episode_ids=None, + unroll_ids=None, + agent_indices=None, dynamic_max=True, shuffle=False, + seq_lens=None, + states_already_reduced_to_init=False, _extra_padding=0): """Truncate and pad experiences into fixed-length sequences. @@ -212,23 +235,24 @@ def chop_into_sequences(episode_ids, [2, 3, 1] """ - prev_id = None - seq_lens = [] - seq_len = 0 - unique_ids = np.add( - np.add(episode_ids, agent_indices), - np.array(unroll_ids, dtype=np.int64) << 32) - for uid in unique_ids: - if (prev_id is not None and uid != prev_id) or \ - seq_len >= max_seq_len: + if seq_lens is None or len(seq_lens) == 0: + prev_id = None + seq_lens = [] + seq_len = 0 + unique_ids = np.add( + np.add(episode_ids, agent_indices), + np.array(unroll_ids, dtype=np.int64) << 32) + for uid in unique_ids: + if (prev_id is not None and uid != prev_id) or \ + seq_len >= max_seq_len: + seq_lens.append(seq_len) + seq_len = 0 + seq_len += 1 + prev_id = uid + if seq_len: seq_lens.append(seq_len) - seq_len = 0 - seq_len += 1 - prev_id = uid - if seq_len: - seq_lens.append(seq_len) - assert sum(seq_lens) == len(unique_ids) - seq_lens = np.array(seq_lens, dtype=np.int32) + seq_lens = np.array(seq_lens, dtype=np.int32) + assert sum(seq_lens) == len(feature_columns[0]) # Dynamically shrink max len as needed to optimize memory usage if dynamic_max: @@ -252,18 +276,23 @@ def chop_into_sequences(episode_ids, f_pad[seq_base + seq_offset] = f[i] i += 1 seq_base += max_seq_len - assert i == len(unique_ids), f + assert i == len(f), f feature_sequences.append(f_pad) - initial_states = [] - for s in state_columns: - s = np.array(s) - s_init = [] - i = 0 - for len_ in seq_lens: - s_init.append(s[i]) - i += len_ - initial_states.append(np.array(s_init)) + if states_already_reduced_to_init: + initial_states = state_columns + else: + initial_states = [] + for s in state_columns: + # Skip unnecessary copy. + if not isinstance(s, np.ndarray): + s = np.array(s) + s_init = [] + i = 0 + for len_ in seq_lens: + s_init.append(s[i]) + i += len_ + initial_states.append(np.array(s_init)) if shuffle: permutation = np.random.permutation(len(seq_lens)) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a2934fdb9..a1b4c43bc 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -61,6 +61,7 @@ class SampleBatch: # Possible seq_lens (TxB or BxT) setup. self.time_major = kwargs.pop("_time_major", None) self.seq_lens = kwargs.pop("_seq_lens", None) + self.dont_check_lens = kwargs.pop("_dont_check_lens", False) self.max_seq_len = None if self.seq_lens is not None and len(self.seq_lens) > 0: self.max_seq_len = max(self.seq_lens) @@ -76,8 +77,10 @@ class SampleBatch: self.data[k] = np.array(v) if not lengths: raise ValueError("Empty sample batch") - assert len(set(lengths)) == 1, \ - "Data columns must be same length, but lens are {}".format(lengths) + if not self.dont_check_lens: + assert len(set(lengths)) == 1, \ + "Data columns must be same length, but lens are " \ + "{}".format(lengths) if self.seq_lens is not None and len(self.seq_lens) > 0: self.count = sum(self.seq_lens) else: @@ -117,7 +120,8 @@ class SampleBatch: return SampleBatch( out, _seq_lens=np.array(seq_lens, dtype=np.int32), - _time_major=concat_samples[0].time_major) + _time_major=concat_samples[0].time_major, + _dont_check_lens=True) @PublicAPI def concat(self, other: "SampleBatch") -> "SampleBatch": @@ -248,12 +252,35 @@ class SampleBatch: SampleBatch: A new SampleBatch, which has a slice of this batch's data. """ - if self.time_major is not None: + if self.seq_lens is not None and len(self.seq_lens) > 0: + data = {k: v[start:end] for k, v in self.data.items()} + # Fix state_in_x data. + count = 0 + state_start = None + seq_lens = None + for i, seq_len in enumerate(self.seq_lens): + count += seq_len + if count >= end: + state_idx = 0 + state_key = "state_in_{}".format(state_idx) + while state_key in self.data: + data[state_key] = self.data[state_key][state_start:i + + 1] + state_idx += 1 + state_key = "state_in_{}".format(state_idx) + seq_lens = list(self.seq_lens[state_start:i]) + [ + seq_len - (count - end) + ] + assert sum(seq_lens) == (end - start) + break + elif state_start is None and count > start: + state_start = i + return SampleBatch( - {k: v[:, start:end] - for k, v in self.data.items()}, - _seq_lens=self.seq_lens[start:end], - _time_major=self.time_major) + data, + _seq_lens=np.array(seq_lens, dtype=np.int32), + _time_major=self.time_major, + _dont_check_lens=True) else: return SampleBatch( {k: v[start:end] diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index f6e48dad2..fe6ec900b 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -174,11 +174,6 @@ class TFPolicy(Policy): raise ValueError( "Number of state input and output tensors must match, got: " "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs {}".format(self.get_initial_state(), - self._state_inputs)) if self._state_inputs and self._seq_lens is None: raise ValueError( "seq_lens tensor must be given if state inputs are defined") @@ -263,6 +258,11 @@ class TFPolicy(Policy): (name, tf1.placeholders) needed for calculating the loss. """ self._loss_input_dict = dict(loss_inputs) + self._loss_input_dict_no_rnn = { + k: v + for k, v in self._loss_input_dict.items() + if (v not in self._state_inputs and v != self._seq_lens) + } for i, ph in enumerate(self._state_inputs): self._loss_input_dict["state_in_{}".format(i)] = ph @@ -791,11 +791,11 @@ class TFPolicy(Policy): **fetches[LEARNER_STATS_KEY]) return fetches - def _get_loss_inputs_dict(self, batch, shuffle): + def _get_loss_inputs_dict(self, train_batch, shuffle): """Return a feed dict from a batch. Args: - batch (SampleBatch): batch of data to derive inputs from + train_batch (SampleBatch): batch of data to derive inputs from. shuffle (bool): whether to shuffle batch sequences. Shuffle may be done in-place. This only makes sense if you're further applying minibatch SGD after getting the outputs. @@ -806,28 +806,30 @@ class TFPolicy(Policy): # Get batch ready for RNNs, if applicable. pad_batch_to_sequences_of_same_size( - batch, + train_batch, shuffle=shuffle, max_seq_len=self._max_seq_len, batch_divisibility_req=self._batch_divisibility_req, - feature_keys=[ - k for k in self._loss_input_dict.keys() if k != "seq_lens" - ], + feature_keys=list(self._loss_input_dict_no_rnn.keys()), + view_requirements=self.view_requirements, ) - batch["is_training"] = True + + # Mark the batch as "is_training" so the Model can use this + # information. + train_batch["is_training"] = True # Build the feed dict from the batch. feed_dict = {} for key, placeholder in self._loss_input_dict.items(): - feed_dict[placeholder] = batch[key] + feed_dict[placeholder] = train_batch[key] state_keys = [ "state_in_{}".format(i) for i in range(len(self._state_inputs)) ] for key in state_keys: - feed_dict[self._loss_input_dict[key]] = batch[key] + feed_dict[self._loss_input_dict[key]] = train_batch[key] if state_keys: - feed_dict[self._seq_lens] = batch["seq_lens"] + feed_dict[self._seq_lens] = train_batch["seq_lens"] return feed_dict diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f294b510d..c27a7603d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -345,12 +345,13 @@ class TorchPolicy(Policy): @DeveloperAPI def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients: - # Get batch ready for RNNs, if applicable. + pad_batch_to_sequences_of_same_size( postprocessed_batch, max_seq_len=self.max_seq_len, shuffle=False, batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, ) train_batch = self._lazy_tensor_dict(postprocessed_batch) diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index f9c7750d4..25a5e908a 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -1,4 +1,5 @@ import gym +import numpy as np from typing import List, Optional, Union from ray.rllib.utils.framework import try_import_torch @@ -29,8 +30,9 @@ class ViewRequirement: def __init__(self, data_col: Optional[str] = None, space: gym.Space = None, - shift: Union[int, List[int]] = 0, + shift: Union[int, str, List[int]] = 0, index: Optional[int] = None, + batch_repeat_value: int = 1, used_for_training: bool = True): """Initializes a ViewRequirement object. @@ -64,7 +66,19 @@ class ViewRequirement: self.space = space if space is not None else gym.spaces.Box( float("-inf"), float("inf"), shape=()) - self.index = index - self.shift = shift + if isinstance(self.shift, (list, tuple)): + self.shift = np.array(self.shift) + + # Special case: Providing a (probably larger) range of indices, e.g. + # "-100:0" (past 100 timesteps plus current one). + self.shift_from = self.shift_to = None + if isinstance(self.shift, str): + f, t = self.shift.split(":") + self.shift_from = int(f) + self.shift_to = int(t) + + self.index = index + self.batch_repeat_value = batch_repeat_value + self.used_for_training = used_for_training diff --git a/rllib/tests/test_attention_net_learning.py b/rllib/tests/test_attention_net_learning.py index b060651d6..35e5b3b08 100644 --- a/rllib/tests/test_attention_net_learning.py +++ b/rllib/tests/test_attention_net_learning.py @@ -44,7 +44,8 @@ class TestAttentionNetLearning(unittest.TestCase): "num_transformer_units": 1, "attn_dim": 32, "num_heads": 1, - "memory_tau": 5, + "memory_inference": 5, + "memory_training": 5, "head_dim": 32, "ff_hidden_dim": 32, }, @@ -71,7 +72,8 @@ class TestAttentionNetLearning(unittest.TestCase): # "num_transformer_units": 1, # "attn_dim": 64, # "num_heads": 1, - # "memory_tau": 10, + # "memory_inference": 10, + # "memory_training": 10, # "head_dim": 32, # "ff_hidden_dim": 32, # }, diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index 2685fa942..09b7aef73 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -18,9 +18,13 @@ class TestLSTMUtils(unittest.TestCase): f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, - np.ones_like(eps_ids), - agent_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual([f.tolist() for f in f_pad], [ [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0], [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0], @@ -35,9 +39,13 @@ class TestLSTMUtils(unittest.TestCase): obs = np.ones((84, 84, 4)) f = [[obs, obs * 2, obs * 3]] s = [[209, 208, 207]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, - np.ones_like(eps_ids), - agent_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual([f.tolist() for f in f_pad], [ np.array([obs, obs * 2, obs * 3]).tolist(), ]) @@ -51,8 +59,13 @@ class TestLSTMUtils(unittest.TestCase): f = [[101, 102, 103, 201, 202, 203, 204, 205], [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] - _, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f, - s, 4) + _, _, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=batch_ids, + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2]) def test_multi_agent(self): @@ -62,12 +75,12 @@ class TestLSTMUtils(unittest.TestCase): [[101], [102], [103], [201], [202], [203], [204], [205]]] s = [[209, 208, 207, 109, 108, 107, 106, 105]] f_pad, s_init, seq_lens = chop_into_sequences( - eps_ids, - np.ones_like(eps_ids), - agent_ids, - f, - s, - 4, + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4, dynamic_max=False) self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1]) self.assertEqual(len(f_pad[0]), 20) @@ -78,9 +91,13 @@ class TestLSTMUtils(unittest.TestCase): agent_ids = [2, 2, 2] f = [[1, 1, 1]] s = [[1, 1, 1]] - f_pad, s_init, seq_lens = chop_into_sequences(eps_ids, - np.ones_like(eps_ids), - agent_ids, f, s, 4) + f_pad, s_init, seq_lens = chop_into_sequences( + episode_ids=eps_ids, + unroll_ids=np.ones_like(eps_ids), + agent_indices=agent_ids, + feature_columns=f, + state_columns=s, + max_seq_len=4) self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]]) self.assertEqual([s.tolist() for s in s_init], [[1, 1]]) self.assertEqual(seq_lens.tolist(), [1, 2]) diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index d5576e0fa..b5b72d44d 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -72,18 +72,23 @@ def minibatches(samples, sgd_minibatch_size): i = 0 slices = [] - if samples.seq_lens: - seq_no = 0 - while i < samples.count: - seq_no_end = seq_no - actual_count = 0 - while actual_count < sgd_minibatch_size and len( - samples.seq_lens) > seq_no_end: - actual_count += samples.seq_lens[seq_no_end] - seq_no_end += 1 - slices.append((seq_no, seq_no_end)) - i += actual_count - seq_no = seq_no_end + if samples.seq_lens is not None and len(samples.seq_lens) > 0: + start_pos = 0 + minibatch_size = 0 + idx = 0 + while idx < len(samples.seq_lens): + seq_len = samples.seq_lens[idx] + minibatch_size += seq_len + # Complete minibatch -> Append to slices. + if minibatch_size >= sgd_minibatch_size: + slices.append((start_pos, start_pos + sgd_minibatch_size)) + start_pos += sgd_minibatch_size + if minibatch_size > sgd_minibatch_size: + overhead = minibatch_size - sgd_minibatch_size + start_pos -= (seq_len - overhead) + idx -= 1 + minibatch_size = 0 + idx += 1 else: while i < samples.count: slices.append((i, i + sgd_minibatch_size)) diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 592f0424d..3010366fc 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -100,6 +100,9 @@ ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]] # Type of dict returned by get_weights() representing model weights. ModelWeights = dict +# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls. +ModelInputDict = Dict[str, TensorType] + # Some kind of sample batch. SampleBatchType = Union["SampleBatch", "MultiAgentBatch"] From 4caa6c6d7806d1e406322bfb50fa01072f63985e Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Mon, 21 Dec 2020 11:00:25 +0800 Subject: [PATCH 50/88] [GCS]GCS resource manager remove cluster_resources_ (#12972) --- src/ray/gcs/gcs_server/gcs_node_manager.cc | 2 +- .../gcs/gcs_server/gcs_resource_manager.cc | 100 ++++++++++++------ src/ray/gcs/gcs_server/gcs_resource_manager.h | 6 +- .../gcs_placement_group_scheduler_test.cc | 1 + 4 files changed, 69 insertions(+), 40 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 499abc90f..57f878d60 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -240,7 +240,7 @@ void GcsNodeManager::AddNode(std::shared_ptr node) { for (auto &listener : node_added_listeners_) { listener(node); } - gcs_resource_manager_->OnNodeAdd(node_id); + gcs_resource_manager_->OnNodeAdd(*node); } } diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.cc b/src/ray/gcs/gcs_server/gcs_resource_manager.cc index 7357eeaaf..f0b3be06c 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.cc @@ -26,10 +26,13 @@ void GcsResourceManager::HandleGetResources(const rpc::GetResourcesRequest &requ rpc::GetResourcesReply *reply, rpc::SendReplyCallback send_reply_callback) { NodeID node_id = NodeID::FromBinary(request.node_id()); - auto iter = cluster_resources_.find(node_id); - if (iter != cluster_resources_.end()) { - for (const auto &resource : iter->second.items()) { - (*reply->mutable_resources())[resource.first] = resource.second; + auto iter = cluster_scheduling_resources_.find(node_id); + if (iter != cluster_scheduling_resources_.end()) { + const auto &resource_map = iter->second.GetTotalResources().GetResourceMap(); + rpc::ResourceTableData resource_table_data; + for (const auto &resource : resource_map) { + resource_table_data.set_resource_capacity(resource.second); + (*reply->mutable_resources())[resource.first] = resource_table_data; } } GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); @@ -41,29 +44,35 @@ void GcsResourceManager::HandleUpdateResources( rpc::SendReplyCallback send_reply_callback) { NodeID node_id = NodeID::FromBinary(request.node_id()); RAY_LOG(DEBUG) << "Updating resources, node id = " << node_id; - auto iter = cluster_resources_.find(node_id); - std::unordered_map to_be_updated_resources; + auto changed_resources = std::make_shared>(); for (const auto &entry : request.resources()) { - to_be_updated_resources.emplace(entry.first, entry.second.resource_capacity()); + changed_resources->emplace(entry.first, entry.second.resource_capacity()); } - if (iter != cluster_resources_.end()) { - for (const auto &entry : request.resources()) { - (*iter->second.mutable_items())[entry.first] = entry.second; + auto iter = cluster_scheduling_resources_.find(node_id); + if (iter != cluster_scheduling_resources_.end()) { + // Update `cluster_scheduling_resources_`. + SchedulingResources &scheduling_resources = iter->second; + for (const auto &entry : *changed_resources) { + scheduling_resources.UpdateResourceCapacity(entry.first, entry.second); } - UpdateResourceCapacity(node_id, to_be_updated_resources); - auto on_done = [this, node_id, to_be_updated_resources, reply, + + // Update gcs storage. + rpc::ResourceMap resource_map; + for (const auto &entry : iter->second.GetTotalResources().GetResourceMap()) { + (*resource_map.mutable_items())[entry.first].set_resource_capacity(entry.second); + } + for (const auto &entry : *changed_resources) { + (*resource_map.mutable_items())[entry.first].set_resource_capacity(entry.second); + } + + auto on_done = [this, node_id, changed_resources, reply, send_reply_callback](const Status &status) { RAY_CHECK_OK(status); rpc::NodeResourceChange node_resource_change; node_resource_change.set_node_id(node_id.Binary()); - for (const auto &it : to_be_updated_resources) { - const auto &resource_name = it.first; - const auto &resource_capacity = it.second; - auto &node_updated_resources = - (*node_resource_change.mutable_updated_resources()); - node_updated_resources[resource_name] = resource_capacity; - } + node_resource_change.mutable_updated_resources()->insert(changed_resources->begin(), + changed_resources->end()); RAY_CHECK_OK(gcs_pub_sub_->Publish(NODE_RESOURCE_CHANNEL, node_id.Hex(), node_resource_change.SerializeAsString(), nullptr)); @@ -73,7 +82,7 @@ void GcsResourceManager::HandleUpdateResources( }; RAY_CHECK_OK( - gcs_table_storage_->NodeResourceTable().Put(node_id, iter->second, on_done)); + gcs_table_storage_->NodeResourceTable().Put(node_id, resource_map, on_done)); } else { GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::Invalid("Node is not exist.")); RAY_LOG(ERROR) << "Failed to update resources as node " << node_id @@ -88,13 +97,23 @@ void GcsResourceManager::HandleDeleteResources( NodeID node_id = NodeID::FromBinary(request.node_id()); RAY_LOG(DEBUG) << "Deleting node resources, node id = " << node_id; auto resource_names = VectorFromProtobuf(request.resource_name_list()); - auto iter = cluster_resources_.find(node_id); - if (iter != cluster_resources_.end()) { - DeleteResources(node_id, resource_names); - + auto iter = cluster_scheduling_resources_.find(node_id); + if (iter != cluster_scheduling_resources_.end()) { + // Update `cluster_scheduling_resources_`. for (const auto &resource_name : resource_names) { - RAY_IGNORE_EXPR(iter->second.mutable_items()->erase(resource_name)); + iter->second.DeleteResource(resource_name); } + + // Update gcs storage. + rpc::ResourceMap resource_map; + auto resources = iter->second.GetTotalResources().GetResourceMap(); + for (const auto &resource_name : resource_names) { + resources.erase(resource_name); + } + for (const auto &entry : resources) { + (*resource_map.mutable_items())[entry.first].set_resource_capacity(entry.second); + } + auto on_done = [this, node_id, resource_names, reply, send_reply_callback](const Status &status) { RAY_CHECK_OK(status); @@ -110,7 +129,7 @@ void GcsResourceManager::HandleDeleteResources( GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); }; RAY_CHECK_OK( - gcs_table_storage_->NodeResourceTable().Put(node_id, iter->second, on_done)); + gcs_table_storage_->NodeResourceTable().Put(node_id, resource_map, on_done)); } else { GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); RAY_LOG(DEBUG) << "Finished deleting node resources, node id = " << node_id; @@ -136,10 +155,20 @@ void GcsResourceManager::HandleGetAllAvailableResources( void GcsResourceManager::Initialize(const GcsInitData &gcs_init_data) { const auto &nodes = gcs_init_data.Nodes(); - for (auto &entry : gcs_init_data.ClusterResources()) { - const auto &iter = nodes.find(entry.first); - if (iter->second.state() == rpc::GcsNodeInfo::ALIVE) { - cluster_resources_[entry.first] = entry.second; + for (const auto &entry : nodes) { + if (entry.second.state() == rpc::GcsNodeInfo::ALIVE) { + OnNodeAdd(entry.second); + } + } + + const auto &cluster_resources = gcs_init_data.ClusterResources(); + for (const auto &entry : cluster_resources) { + const auto &iter = cluster_scheduling_resources_.find(entry.first); + if (iter != cluster_scheduling_resources_.end()) { + for (const auto &resource : entry.second.items()) { + iter->second.UpdateResourceCapacity(resource.first, + resource.second.resource_capacity()); + } } } } @@ -173,19 +202,20 @@ void GcsResourceManager::DeleteResources( const NodeID &node_id, const std::vector &deleted_resources) { auto iter = cluster_scheduling_resources_.find(node_id); if (iter != cluster_scheduling_resources_.end()) { - for (auto &resource_name : deleted_resources) { + for (const auto &resource_name : deleted_resources) { iter->second.DeleteResource(resource_name); } } } -void GcsResourceManager::OnNodeAdd(const NodeID &node_id) { - // Add an empty resources for this node. - cluster_resources_.emplace(node_id, rpc::ResourceMap()); +void GcsResourceManager::OnNodeAdd(const rpc::GcsNodeInfo &node) { + auto node_id = NodeID::FromBinary(node.node_id()); + if (!cluster_scheduling_resources_.contains(node_id)) { + cluster_scheduling_resources_.emplace(node_id, SchedulingResources()); + } } void GcsResourceManager::OnNodeDead(const NodeID &node_id) { - cluster_resources_.erase(node_id); cluster_scheduling_resources_.erase(node_id); } diff --git a/src/ray/gcs/gcs_server/gcs_resource_manager.h b/src/ray/gcs/gcs_server/gcs_resource_manager.h index eda9ced4d..095e0c234 100644 --- a/src/ray/gcs/gcs_server/gcs_resource_manager.h +++ b/src/ray/gcs/gcs_server/gcs_resource_manager.h @@ -72,8 +72,8 @@ class GcsResourceManager : public rpc::NodeResourceInfoHandler { /// Handle a node registration. /// - /// \param node_id The specified node id. - void OnNodeAdd(const NodeID &node_id); + /// \param node The specified node to add. + void OnNodeAdd(const rpc::GcsNodeInfo &node); /// Handle a node death. /// @@ -130,8 +130,6 @@ class GcsResourceManager : public rpc::NodeResourceInfoHandler { std::shared_ptr gcs_pub_sub_; /// Storage for GCS tables. std::shared_ptr gcs_table_storage_; - /// Cluster resources. - absl::flat_hash_map cluster_resources_; /// Map from node id to the scheduling resources of the node. absl::flat_hash_map cluster_scheduling_resources_; diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index 3bf5923c9..ef81f8887 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -101,6 +101,7 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { rpc::ResourcesData resource; resource.set_node_id(node->node_id()); (*resource.mutable_resources_available())["CPU"] = cpu_num; + resource.set_resources_available_changed(true); gcs_node_manager_->UpdateNodeRealtimeResources(NodeID::FromBinary(node->node_id()), resource); } From c576f0b0737370507bfa1a55075977e7df6e82a1 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Sun, 20 Dec 2020 19:35:34 -0800 Subject: [PATCH 51/88] [ray_client] Implement a gRPC streaming logs API for the client (#13001) --- python/ray/experimental/client/dataclient.py | 9 +- python/ray/experimental/client/logsclient.py | 84 ++++++++++++++++ .../client/server/dataservicer.py | 2 +- .../experimental/client/server/logservicer.py | 99 +++++++++++++++++++ .../ray/experimental/client/server/server.py | 4 + python/ray/experimental/client/worker.py | 13 ++- python/ray/ray_logging.py | 27 +++++ python/ray/tests/test_experimental_client.py | 45 ++++++++- python/ray/worker.py | 56 ++++++----- src/ray/protobuf/ray_client.proto | 27 +++++ 10 files changed, 332 insertions(+), 34 deletions(-) create mode 100644 python/ray/experimental/client/logsclient.py create mode 100644 python/ray/experimental/client/server/logservicer.py diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py index c6a745df8..b0dda0a1b 100644 --- a/python/ray/experimental/client/dataclient.py +++ b/python/ray/experimental/client/dataclient.py @@ -26,6 +26,7 @@ class DataClient: Args: channel: connected gRPC channel + client_id: the generated ID representing this client """ self.channel = channel self.request_queue = queue.Queue() @@ -68,18 +69,14 @@ class DataClient: logger.info("Cancelling data channel") else: logger.error( - f"Got Error from rpc channel -- shutting down: {e}") + f"Got Error from data channel -- shutting down: {e}") raise e - def close(self, close_channel: bool = False) -> None: + def close(self) -> None: if self.request_queue is not None: self.request_queue.put(None) - self.request_queue = None - if close_channel: - self.channel.close() if self.data_thread is not None: self.data_thread.join() - self.data_thread = None def _blocking_send(self, req: ray_client_pb2.DataRequest ) -> ray_client_pb2.DataResponse: diff --git a/python/ray/experimental/client/logsclient.py b/python/ray/experimental/client/logsclient.py new file mode 100644 index 000000000..f26417e7e --- /dev/null +++ b/python/ray/experimental/client/logsclient.py @@ -0,0 +1,84 @@ +""" +This file implements a threaded stream controller to return logs back from +the ray clientserver. +""" +import sys +import logging +import queue +import threading +import grpc + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +logger = logging.getLogger(__name__) + + +class LogstreamClient: + def __init__(self, channel: "grpc._channel.Channel"): + """Initializes a thread-safe log stream over a Ray Client gRPC channel. + + Args: + channel: connected gRPC channel + """ + self.channel = channel + self.request_queue = queue.Queue() + self.log_thread = self._start_logthread() + self.log_thread.start() + + def _start_logthread(self) -> threading.Thread: + return threading.Thread(target=self._log_main, args=(), daemon=True) + + def _log_main(self) -> None: + stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel) + log_stream = stub.Logstream(iter(self.request_queue.get, None)) + try: + for record in log_stream: + if record.level < 0: + self.stdstream(level=record.level, msg=record.msg) + self.log(level=record.level, msg=record.msg) + except grpc.RpcError as e: + if grpc.StatusCode.CANCELLED != e.code(): + # Not just shutting down normally + logger.error( + f"Got Error from logger channel -- shutting down: {e}") + raise e + + def log(self, level: int, msg: str): + """ + Log the message from the log stream. + By default, calls logger.log but this can be overridden. + + Args: + level: The loglevel of the received log message + msg: The content of the message + """ + logger.log(level=level, msg=msg) + + def stdstream(self, level: int, msg: str): + """ + Log the stdout/stderr entry from the log stream. + By default, calls print but this can be overridden. + + Args: + level: The loglevel of the received log message + msg: The content of the message + """ + print_file = sys.stderr if level == -2 else sys.stdout + print(msg, file=print_file) + + def set_logstream_level(self, level: int): + req = ray_client_pb2.LogSettingsRequest() + req.enabled = True + req.loglevel = level + self.request_queue.put(req) + + def close(self) -> None: + self.request_queue.put(None) + if self.log_thread is not None: + self.log_thread.join() + + def disable_logs(self) -> None: + req = ray_client_pb2.LogSettingsRequest() + req.enabled = False + self.request_queue.put(req) diff --git a/python/ray/experimental/client/server/dataservicer.py b/python/ray/experimental/client/server/dataservicer.py index 874e741d9..925adca28 100644 --- a/python/ray/experimental/client/server/dataservicer.py +++ b/python/ray/experimental/client/server/dataservicer.py @@ -48,7 +48,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): resp.req_id = req.req_id yield resp except grpc.RpcError as e: - logger.debug(f"Closing channel: {e}") + logger.debug(f"Closing data channel: {e}") finally: logger.info(f"Lost data connection from client {client_id}") self.basic_service.release_all(client_id) diff --git a/python/ray/experimental/client/server/logservicer.py b/python/ray/experimental/client/server/logservicer.py new file mode 100644 index 000000000..9b2fa24bf --- /dev/null +++ b/python/ray/experimental/client/server/logservicer.py @@ -0,0 +1,99 @@ +""" +This file responds to log stream requests and forwards logs +with its handler. +""" +import io +import threading +import queue +import logging +import grpc +import uuid + +from ray.worker import print_worker_logs +from ray.ray_logging import global_worker_stdstream_dispatcher +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +logger = logging.getLogger(__name__) + + +class LogstreamHandler(logging.Handler): + def __init__(self, queue, level): + super().__init__() + self.queue = queue + self.level = level + + def emit(self, record: logging.LogRecord): + logdata = ray_client_pb2.LogData() + logdata.msg = record.getMessage() + logdata.level = record.levelno + logdata.name = record.name + self.queue.put(logdata) + + +class StdStreamHandler: + def __init__(self, queue): + self.queue = queue + self.id = str(uuid.uuid4()) + + def handle(self, data): + logdata = ray_client_pb2.LogData() + logdata.level = -2 if data["is_err"] else -1 + logdata.name = "stderr" if data["is_err"] else "stdout" + with io.StringIO() as file: + print_worker_logs(data, file) + logdata.msg = file.getvalue() + self.queue.put(logdata) + + def register_global(self): + global_worker_stdstream_dispatcher.add_handler(self.id, self.handle) + + def unregister_global(self): + global_worker_stdstream_dispatcher.remove_handler(self.id) + + +def log_status_change_thread(log_queue, request_iterator): + std_handler = StdStreamHandler(log_queue) + current_handler = None + root_logger = logging.getLogger("ray") + default_level = root_logger.getEffectiveLevel() + try: + for req in request_iterator: + if current_handler is not None: + root_logger.setLevel(default_level) + root_logger.removeHandler(current_handler) + std_handler.unregister_global() + if not req.enabled: + current_handler = None + continue + current_handler = LogstreamHandler(log_queue, req.loglevel) + std_handler.register_global() + root_logger.addHandler(current_handler) + root_logger.setLevel(req.loglevel) + finally: + if current_handler is not None: + root_logger.setLevel(default_level) + root_logger.removeHandler(current_handler) + std_handler.unregister_global() + log_queue.put(None) + + +class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer): + def Logstream(self, request_iterator, context): + logger.info("New logs connection") + log_queue = queue.Queue() + thread = threading.Thread( + target=log_status_change_thread, + args=(log_queue, request_iterator), + daemon=True) + thread.start() + try: + queue_iter = iter(log_queue.get, None) + for record in queue_iter: + if record is None: + break + yield record + except grpc.RpcError as e: + logger.debug(f"Closing log channel: {e}") + finally: + thread.join() diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 442cf1afa..7cc286de8 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -23,6 +23,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server from ray.experimental.client.server.server_pickler import loads_from_client from ray.experimental.client.server.core_ray_api import RayServerAPI from ray.experimental.client.server.dataservicer import DataServicer +from ray.experimental.client.server.logservicer import LogstreamServicer from ray.experimental.client.server.server_stubs import current_remote logger = logging.getLogger(__name__) @@ -372,11 +373,14 @@ def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) data_servicer = DataServicer(task_servicer) + logs_servicer = LogstreamServicer() _set_server_api(RayServerAPI(task_servicer)) ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( data_servicer, server) + ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server( + logs_servicer, server) server.add_insecure_port(connection_str) server.start() return server diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index bba23584b..8ed41bff4 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -30,6 +30,7 @@ from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientStub from ray.experimental.client.dataclient import DataClient +from ray.experimental.client.logsclient import LogstreamClient logger = logging.getLogger(__name__) @@ -54,9 +55,13 @@ class Worker: else: self.channel = grpc.insecure_channel(conn_str) self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + self.data_client = DataClient(self.channel, self._client_id) self.reference_count: Dict[bytes, int] = defaultdict(int) + self.log_client = LogstreamClient(self.channel) + self.log_client.set_logstream_level(logging.INFO) + def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False @@ -197,14 +202,16 @@ class Worker: ray_client_pb2.ReleaseRequest(ids=[id])) def call_retain(self, id: bytes) -> None: - logger.debug(f"Retaining {id}") + logger.debug(f"Retaining {id.hex()}") self.reference_count[id] += 1 def close(self): - self.data_client.close(close_channel=True) - self.server = None + self.log_client.close() + self.data_client.close() if self.channel: + self.channel.close() self.channel = None + self.server = None def get_actor(self, name: str) -> ClientActorHandle: task = ray_client_pb2.ClientTask() diff --git a/python/ray/ray_logging.py b/python/ray/ray_logging.py index 0668f397f..56df7b5c2 100644 --- a/python/ray/ray_logging.py +++ b/python/ray/ray_logging.py @@ -1,8 +1,11 @@ import logging import os import sys +import threading from logging.handlers import RotatingFileHandler +from typing import Callable + import ray from ray.utils import binary_to_hex @@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args, # logger to add a newline at the end of string. handler.terminator = "" return logger + + +class WorkerStandardStreamDispatcher: + def __init__(self): + self.handlers = [] + self._lock = threading.Lock() + + def add_handler(self, name: str, handler: Callable) -> None: + with self._lock: + self.handlers.append((name, handler)) + + def remove_handler(self, name: str) -> None: + with self._lock: + new_handlers = [pair for pair in self.handlers if pair[0] != name] + self.handlers = new_handlers + + def emit(self, data): + with self._lock: + for pair in self.handlers: + _, handle = pair + handle(data) + + +global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher() diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index cc15e7272..e6afee042 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -1,4 +1,7 @@ import pytest +import time +import sys +import logging from contextlib import contextmanager import ray.experimental.client.server.server as ray_client_server @@ -234,6 +237,47 @@ def test_pass_handles(ray_start_regular_shared): 4)) == local_fact(4) +def test_basic_log_stream(ray_start_regular_shared): + with ray_start_client_server() as ray: + log_msgs = [] + + def test_log(level, msg): + log_msgs.append(msg) + + ray.worker.log_client.log = test_log + ray.worker.log_client.set_logstream_level(logging.DEBUG) + # Allow some time to propogate + time.sleep(1) + x = ray.put("Foo") + assert ray.get(x) == "Foo" + time.sleep(1) + logs_with_id = [msg for msg in log_msgs if msg.find(x.id.hex()) >= 0] + assert len(logs_with_id) >= 2 + assert any((msg.find("get") >= 0 for msg in logs_with_id)) + assert any((msg.find("put") >= 0 for msg in logs_with_id)) + + +def test_stdout_log_stream(ray_start_regular_shared): + with ray_start_client_server() as ray: + log_msgs = [] + + def test_log(level, msg): + log_msgs.append(msg) + + ray.worker.log_client.stdstream = test_log + + @ray.remote + def print_on_stderr_and_stdout(s): + print(s) + print(s, file=sys.stderr) + + time.sleep(1) + print_on_stderr_and_stdout.remote("Hello world") + time.sleep(1) + assert len(log_msgs) == 2 + assert all((msg.find("Hello world") for msg in log_msgs)) + + def test_basic_named_actor(ray_start_regular_shared): """ Test that ray.get_actor() can create and return a detached actor. @@ -264,5 +308,4 @@ def test_basic_named_actor(ray_start_regular_shared): if __name__ == "__main__": - import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 495478ad7..631a82767 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -48,6 +48,7 @@ from ray.exceptions import ( ) from ray.function_manager import FunctionActorManager from ray.ray_logging import setup_logger +from ray.ray_logging import global_worker_stdstream_dispatcher from ray.utils import _random_string, check_oversized_pickle from ray.util.inspect import is_cython @@ -910,29 +911,8 @@ def print_logs(redis_client, threads_stopped, job_id): if data["job"] and ray.utils.binary_to_hex( job_id.binary()) != data["job"]: continue - - print_file = sys.stderr if data["is_err"] else sys.stdout - - def color_for(data): - if data["pid"] == "raylet": - return colorama.Fore.YELLOW - else: - return colorama.Fore.CYAN - - if data["ip"] == localhost: - for line in data["lines"]: - print( - "{}{}(pid={}){} {}".format( - colorama.Style.DIM, color_for(data), data["pid"], - colorama.Style.RESET_ALL, line), - file=print_file) - else: - for line in data["lines"]: - print( - "{}{}(pid={}, ip={}){} {}".format( - colorama.Style.DIM, color_for(data), data["pid"], - data["ip"], colorama.Style.RESET_ALL, line), - file=print_file) + data["localhost"] = localhost + global_worker_stdstream_dispatcher.emit(data) except (OSError, redis.exceptions.ConnectionError) as e: logger.error(f"print_logs: {e}") @@ -941,6 +921,34 @@ def print_logs(redis_client, threads_stopped, job_id): pubsub_client.close() +def print_to_stdstream(data): + print_file = sys.stderr if data["is_err"] else sys.stdout + print_worker_logs(data, print_file) + + +def print_worker_logs(data, print_file): + def color_for(data): + if data["pid"] == "raylet": + return colorama.Fore.YELLOW + else: + return colorama.Fore.CYAN + + if data["ip"] == data["localhost"]: + for line in data["lines"]: + print( + "{}{}(pid={}){} {}".format(colorama.Style.DIM, color_for(data), + data["pid"], + colorama.Style.RESET_ALL, line), + file=print_file) + else: + for line in data["lines"]: + print( + "{}{}(pid={}, ip={}){} {}".format( + colorama.Style.DIM, color_for(data), data["pid"], + data["ip"], colorama.Style.RESET_ALL, line), + file=print_file) + + def print_error_messages_raylet(task_error_queue, threads_stopped): """Prints message received in the given output queue. @@ -1201,6 +1209,8 @@ def connect(node, worker.printer_thread.daemon = True worker.printer_thread.start() if log_to_driver: + global_worker_stdstream_dispatcher.add_handler( + "ray_print_logs", print_to_stdstream) worker.logger_thread = threading.Thread( target=print_logs, name="ray_print_logs", diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index a566f8031..3dd3128b2 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -229,3 +229,30 @@ service RayletDataStreamer { rpc Datapath(stream DataRequest) returns (stream DataResponse) { } } + +// A request to change the quantity or type of the logs +// currently being streamed. Initially, all logs are disabled. +message LogSettingsRequest { + // Set to recieve logs. + bool enabled = 1; + // At what loglevel should logs be forwarded on the stream. + int32 loglevel = 2; + // TODO(barakmich): More log filtering options. +} + +message LogData { + // The message data in the log + string msg = 1; + // The loglevel at which this log should be displayed. + // * level > 0: Log leveling as per python's logging library + // * level == -1: stdout (fd 1) + // * level == -2: stderr (fd 2) + int32 level = 2; + // The name of the logger that generated this message. + string name = 3; +} + +service RayletLogStreamer { + rpc Logstream(stream LogSettingsRequest) returns (stream LogData) { + } +} From 85a4435ba01a6e1590c543de846a6e2492ac3d0a Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Mon, 21 Dec 2020 20:02:50 +0800 Subject: [PATCH 52/88] [GCS]Fix redis store client AsyncPutWithIndex unordered bug (#13002) --- .../gcs/store_client/redis_store_client.cc | 39 ++++++++----------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/ray/gcs/store_client/redis_store_client.cc b/src/ray/gcs/store_client/redis_store_client.cc index 26a8776d5..b104be3ad 100644 --- a/src/ray/gcs/store_client/redis_store_client.cc +++ b/src/ray/gcs/store_client/redis_store_client.cc @@ -38,29 +38,24 @@ Status RedisStoreClient::AsyncPutWithIndex(const std::string &table_name, const std::string &index_key, const std::string &data, const StatusCallback &callback) { - auto write_callback = [this, table_name, key, data, callback](Status status) { - if (!status.ok()) { - // Run callback if failed. - if (callback != nullptr) { - callback(status); - } - return; - } - - // Write data to Redis. - status = DoPut(GenRedisKey(table_name, key), data, callback); - - if (!status.ok()) { - // Run callback if failed. - if (callback != nullptr) { - callback(status); - } - } - }; - + // NOTE: To ensure the atomicity of `AsyncPutWithIndex`, we can't write data to Redis in + // the callback function of index writing. // Write index to Redis. - std::string index_table_key = GenRedisKey(table_name, key, index_key); - return DoPut(index_table_key, key, write_callback); + const auto &index_table_key = GenRedisKey(table_name, key, index_key); + RAY_CHECK_OK(DoPut(index_table_key, key, nullptr)); + + // Write data to Redis. + // The operation of redis client is executed in order, and it can ensure that index is + // written first and then data is written. The index and data are decoupled, so we don't + // need to write data in the callback function of index writing. + const auto &status = DoPut(GenRedisKey(table_name, key), data, callback); + if (!status.ok()) { + // Run callback if failed. + if (callback != nullptr) { + callback(status); + } + } + return status; } Status RedisStoreClient::AsyncGet(const std::string &table_name, const std::string &key, From 6e354690b6729fcf28b206aa327b0adea5b01643 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 21 Dec 2020 23:58:43 +0800 Subject: [PATCH 53/88] [Java] Make task options serializable (#13010) --- java/api/src/main/java/io/ray/api/options/BaseTaskOptions.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/java/api/src/main/java/io/ray/api/options/BaseTaskOptions.java b/java/api/src/main/java/io/ray/api/options/BaseTaskOptions.java index 32fb4285e..943e2608d 100644 --- a/java/api/src/main/java/io/ray/api/options/BaseTaskOptions.java +++ b/java/api/src/main/java/io/ray/api/options/BaseTaskOptions.java @@ -1,12 +1,13 @@ package io.ray.api.options; +import java.io.Serializable; import java.util.HashMap; import java.util.Map; /** * The options class for RayCall or ActorCreation. */ -public abstract class BaseTaskOptions { +public abstract class BaseTaskOptions implements Serializable { public final Map resources; From 5a6801dde7cf016639ec1b06a8b26a69fae9e0c1 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 22 Dec 2020 00:01:27 +0800 Subject: [PATCH 54/88] [Core] Remove `delete_creating_tasks` (#12962) --- .../java/io/ray/api/runtime/RayRuntime.java | 3 +- .../io/ray/runtime/AbstractRayRuntime.java | 4 +- .../java/io/ray/runtime/gcs/GcsClient.java | 35 ---------- .../runtime/object/LocalModeObjectStore.java | 2 +- .../ray/runtime/object/NativeObjectStore.java | 7 +- .../io/ray/runtime/object/ObjectStore.java | 4 +- .../main/java/io/ray/runtime/util/IdUtil.java | 69 ------------------- .../java/io/ray/runtime/UniqueIdTest.java | 8 --- .../src/main/java/io/ray/test/ActorTest.java | 4 +- .../main/java/io/ray/test/PlasmaFreeTest.java | 19 +---- python/ray/_raylet.pyx | 5 +- python/ray/includes/libcoreworker.pxd | 2 +- python/ray/internal/internal_api.py | 7 +- src/ray/core_worker/core_worker.cc | 6 +- src/ray/core_worker/core_worker.h | 5 +- ...io_ray_runtime_object_NativeObjectStore.cc | 7 +- .../io_ray_runtime_object_NativeObjectStore.h | 4 +- .../store_provider/plasma_store_provider.cc | 7 +- .../store_provider/plasma_store_provider.h | 3 +- src/ray/core_worker/test/core_worker_test.cc | 2 +- src/ray/gcs/accessor.h | 10 --- .../gcs/gcs_client/service_based_accessor.cc | 19 ----- .../gcs/gcs_client/service_based_accessor.h | 3 - .../test/service_based_gcs_client_test.cc | 11 --- .../gcs/gcs_server/task_info_handler_impl.cc | 25 ------- .../gcs/gcs_server/task_info_handler_impl.h | 4 -- .../gcs_server/test/gcs_server_rpc_test.cc | 17 ----- src/ray/gcs/redis_accessor.cc | 13 ---- src/ray/gcs/redis_accessor.h | 3 - src/ray/protobuf/gcs_service.proto | 10 --- src/ray/raylet/format/node_manager.fbs | 2 - src/ray/raylet/node_manager.cc | 8 --- src/ray/raylet_client/raylet_client.cc | 6 +- src/ray/raylet_client/raylet_client.h | 4 +- src/ray/rpc/gcs_server/gcs_rpc_client.h | 3 - src/ray/rpc/gcs_server/gcs_rpc_server.h | 5 -- 36 files changed, 33 insertions(+), 313 deletions(-) diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index 620a40042..8817d5b1b 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -72,9 +72,8 @@ public interface RayRuntime { * * @param objectRefs The object references to free. * @param localOnly Whether only free objects for local object store or not. - * @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS. */ - void free(List> objectRefs, boolean localOnly, boolean deleteCreatingTasks); + void free(List> objectRefs, boolean localOnly); /** * Set the resource for the specific node. diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index ac199fd95..6dabe3c3c 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -100,9 +100,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { } @Override - public void free(List> objectRefs, boolean localOnly, boolean deleteCreatingTasks) { + public void free(List> objectRefs, boolean localOnly) { objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl) ref).getId()).collect( - Collectors.toList()), localOnly, deleteCreatingTasks); + Collectors.toList()), localOnly); } @Override diff --git a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java index 9c5d10072..41a82c2d5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/io/ray/runtime/gcs/GcsClient.java @@ -3,10 +3,8 @@ package io.ray.runtime.gcs; import com.google.common.base.Preconditions; import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.id.ActorId; -import io.ray.api.id.BaseId; import io.ray.api.id.JobId; import io.ray.api.id.PlacementGroupId; -import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; import io.ray.api.placementgroup.PlacementGroup; import io.ray.api.runtimecontext.NodeInfo; @@ -14,12 +12,10 @@ import io.ray.runtime.generated.Gcs; import io.ray.runtime.generated.Gcs.GcsNodeInfo; import io.ray.runtime.generated.Gcs.TablePrefix; import io.ray.runtime.placementgroup.PlacementGroupUtils; -import io.ray.runtime.util.IdUtil; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,25 +27,10 @@ public class GcsClient { private static Logger LOGGER = LoggerFactory.getLogger(GcsClient.class); private RedisClient primary; - private List shards; private GlobalStateAccessor globalStateAccessor; public GcsClient(String redisAddress, String redisPassword) { primary = new RedisClient(redisAddress, redisPassword); - int numShards = 0; - try { - numShards = Integer.valueOf(primary.get("NumRedisShards", null)); - Preconditions.checkState(numShards > 0, - String.format("Expected at least one Redis shards, found %d.", numShards)); - } catch (NumberFormatException e) { - throw new RuntimeException("Failed to get number of redis shards.", e); - } - - List shardAddresses = primary.lrange("RedisShards".getBytes(), 0, -1); - Preconditions.checkState(shardAddresses.size() == numShards); - shards = shardAddresses.stream().map((byte[] address) -> { - return new RedisClient(new String(address), redisPassword); - }).collect(Collectors.toList()); globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword); } @@ -163,16 +144,6 @@ public class GcsClient { return actorTableData.getNumRestarts() != 0; } - /** - * Query whether the raylet task exists in Gcs. - */ - public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), - taskId.getBytes()); - RedisClient client = getShardClient(taskId); - return client.exists(key); - } - public JobId nextJobId() { int jobCounter = (int) primary.incr("JobCounter".getBytes()); return JobId.fromInt(jobCounter); @@ -186,10 +157,4 @@ public class GcsClient { LOGGER.debug("Destroying global state accessor."); GlobalStateAccessor.destroyInstance(); } - - private RedisClient getShardClient(BaseId key) { - return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key), - shards.size())); - } - } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index 87f0adc00..4614100ae 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -93,7 +93,7 @@ public class LocalModeObjectStore extends ObjectStore { } @Override - public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + public void delete(List objectIds, boolean localOnly) { for (ObjectId objectId : objectIds) { pool.remove(objectId); } diff --git a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java index ef85cf62c..38a6a16c9 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/NativeObjectStore.java @@ -50,8 +50,8 @@ public class NativeObjectStore extends ObjectStore { } @Override - public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks); + public void delete(List objectIds, boolean localOnly) { + nativeDelete(toBinaryList(objectIds), localOnly); } @Override @@ -116,8 +116,7 @@ public class NativeObjectStore extends ObjectStore { private static native List nativeWait(List objectIds, int numObjects, long timeoutMs); - private static native void nativeDelete(List objectIds, boolean localOnly, - boolean deleteCreatingTasks); + private static native void nativeDelete(List objectIds, boolean localOnly); private static native void nativeAddLocalReference(byte[] workerId, byte[] objectId); diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java index e72bed802..bfec229f1 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectStore.java @@ -167,10 +167,8 @@ public abstract class ObjectStore { * * @param objectIds IDs of the objects to delete. * @param localOnly Whether only delete the objects in local node, or all nodes in the cluster. - * @param deleteCreatingTasks Whether also delete the tasks that created these objects. */ - public abstract void delete(List objectIds, boolean localOnly, - boolean deleteCreatingTasks); + public abstract void delete(List objectIds, boolean localOnly); /** * Increase the local reference count for this object ID. diff --git a/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java index eca2860af..23b24728c 100644 --- a/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/io/ray/runtime/util/IdUtil.java @@ -1,7 +1,6 @@ package io.ray.runtime.util; import io.ray.api.id.ActorId; -import io.ray.api.id.BaseId; import io.ray.api.id.ObjectId; import io.ray.api.id.TaskId; @@ -11,74 +10,6 @@ import io.ray.api.id.TaskId; */ public class IdUtil { - /** - * Compute the murmur hash code of this ID. - */ - public static long murmurHashCode(BaseId id) { - return murmurHash64A(id.getBytes(), id.size(), 0); - } - - /** - * This method is the same as `Hash()` method of `ID` class in ray/src/ray/common/id.h - */ - private static long murmurHash64A(byte[] data, int length, int seed) { - final long m = 0xc6a4a7935bd1e995L; - final int r = 47; - - long h = (seed & 0xFFFFFFFFL) ^ (length * m); - - int length8 = length / 8; - - for (int i = 0; i < length8; i++) { - final int i8 = i * 8; - long k = ((long) data[i8] & 0xff) - + (((long) data[i8 + 1] & 0xff) << 8) - + (((long) data[i8 + 2] & 0xff) << 16) - + (((long) data[i8 + 3] & 0xff) << 24) - + (((long) data[i8 + 4] & 0xff) << 32) - + (((long) data[i8 + 5] & 0xff) << 40) - + (((long) data[i8 + 6] & 0xff) << 48) - + (((long) data[i8 + 7] & 0xff) << 56); - - k *= m; - k ^= k >>> r; - k *= m; - - h ^= k; - h *= m; - } - - final int remaining = length % 8; - if (remaining >= 7) { - h ^= (long) (data[(length & ~7) + 6] & 0xff) << 48; - } - if (remaining >= 6) { - h ^= (long) (data[(length & ~7) + 5] & 0xff) << 40; - } - if (remaining >= 5) { - h ^= (long) (data[(length & ~7) + 4] & 0xff) << 32; - } - if (remaining >= 4) { - h ^= (long) (data[(length & ~7) + 3] & 0xff) << 24; - } - if (remaining >= 3) { - h ^= (long) (data[(length & ~7) + 2] & 0xff) << 16; - } - if (remaining >= 2) { - h ^= (long) (data[(length & ~7) + 1] & 0xff) << 8; - } - if (remaining >= 1) { - h ^= (long) (data[length & ~7] & 0xff); - h *= m; - } - - h ^= h >>> r; - h *= m; - h ^= h >>> r; - - return h; - } - /** * Compute the actor ID of the task which created this object. * @return The actor ID of the task which created this object. diff --git a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java index 25704f321..ce1b61db1 100644 --- a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java @@ -1,7 +1,6 @@ package io.ray.runtime; import io.ray.api.id.UniqueId; -import io.ray.runtime.util.IdUtil; import java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; @@ -46,11 +45,4 @@ public class UniqueIdTest { Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } - - @Test - void testMurmurHash() { - UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); - long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); - Assert.assertEquals(remainder, 787616861); - } } diff --git a/java/test/src/main/java/io/ray/test/ActorTest.java b/java/test/src/main/java/io/ray/test/ActorTest.java index a7e9c6ac6..78b8f2468 100644 --- a/java/test/src/main/java/io/ray/test/ActorTest.java +++ b/java/test/src/main/java/io/ray/test/ActorTest.java @@ -128,7 +128,7 @@ public class ActorTest extends BaseTest { ObjectRef value = counter.task(Counter::getValue).remote(); Assert.assertEquals(100, value.get()); // Delete the object from the object store. - Ray.internal().free(ImmutableList.of(value), false, false); + Ray.internal().free(ImmutableList.of(value), false); // Wait for delete RPC to propagate TimeUnit.SECONDS.sleep(1); // Free deletes from in-memory store. @@ -138,7 +138,7 @@ public class ActorTest extends BaseTest { ObjectRef largeValue = counter.task(Counter::createLargeObject).remote(); Assert.assertTrue(largeValue.get() instanceof TestUtils.LargeObject); // Delete the object from the object store. - Ray.internal().free(ImmutableList.of(largeValue), false, false); + Ray.internal().free(ImmutableList.of(largeValue), false); // Wait for delete RPC to propagate TimeUnit.SECONDS.sleep(1); // Free deletes big objects from plasma store. diff --git a/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java b/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java index 1b924f3c0..59fafc0f4 100644 --- a/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/io/ray/test/PlasmaFreeTest.java @@ -3,9 +3,7 @@ package io.ray.test; import com.google.common.collect.ImmutableList; import io.ray.api.ObjectRef; import io.ray.api.Ray; -import io.ray.api.id.TaskId; import io.ray.runtime.object.ObjectRefImpl; -import java.util.Arrays; import org.testng.Assert; import org.testng.annotations.Test; @@ -20,7 +18,7 @@ public class PlasmaFreeTest extends BaseTest { ObjectRef helloId = Ray.task(PlasmaFreeTest::hello).remote(); String helloString = helloId.get(); Assert.assertEquals("hello", helloString); - Ray.internal().free(ImmutableList.of(helloId), true, false); + Ray.internal().free(ImmutableList.of(helloId), true); final boolean result = TestUtils.waitForCondition(() -> !TestUtils.getRuntime().getObjectStore() @@ -32,19 +30,4 @@ public class PlasmaFreeTest extends BaseTest { Assert.assertFalse(result); } } - - @Test(groups = {"cluster"}) - public void testDeleteCreatingTasks() { - ObjectRef helloId = Ray.task(PlasmaFreeTest::hello).remote(); - Assert.assertEquals("hello", helloId.get()); - Ray.internal().free(ImmutableList.of(helloId), true, true); - - TaskId taskId = TaskId.fromBytes( - Arrays.copyOf(((ObjectRefImpl) helloId).getId().getBytes(), TaskId.LENGTH)); - final boolean result = TestUtils.waitForCondition( - () -> !TestUtils.getRuntime().getGcsClient() - .rayletTaskExistsInGcs(taskId), 50); - Assert.assertTrue(result); - } - } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 4d5bb8ff9..356222bb9 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1031,14 +1031,13 @@ cdef class CoreWorker: return ready, not_ready - def free_objects(self, object_refs, c_bool local_only, - c_bool delete_creating_tasks): + def free_objects(self, object_refs, c_bool local_only): cdef: c_vector[CObjectID] free_ids = ObjectRefsToVector(object_refs) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().Delete( - free_ids, local_only, delete_creating_tasks)) + free_ids, local_only)) def global_gc(self): with nogil: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 68c1a95b3..9dd63aafe 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -182,7 +182,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: int64_t timeout_ms, c_vector[c_bool] *results, c_bool fetch_local) CRayStatus Delete(const c_vector[CObjectID] &object_ids, - c_bool local_only, c_bool delete_creating_tasks) + c_bool local_only) CRayStatus TriggerGlobalGC() c_string MemoryUsageString() diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index d3e25c1ec..601b3986a 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -37,7 +37,7 @@ def memory_summary(): return reply.memory_summary -def free(object_refs, local_only=False, delete_creating_tasks=False): +def free(object_refs, local_only=False): """Free a list of IDs from the in-process and plasma object stores. This function is a low-level API which should be used in restricted @@ -59,8 +59,6 @@ def free(object_refs, local_only=False, delete_creating_tasks=False): object_refs (List[ObjectRef]): List of object refs to delete. local_only (bool): Whether only deleting the list of objects in local object store or all object stores. - delete_creating_tasks (bool): Whether also delete the object creating - tasks. """ worker = ray.worker.global_worker @@ -83,5 +81,4 @@ def free(object_refs, local_only=False, delete_creating_tasks=False): if len(object_refs) == 0: return - worker.core_worker.free_objects(object_refs, local_only, - delete_creating_tasks) + worker.core_worker.free_objects(object_refs, local_only) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9d7099303..d2ab2c150 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1111,8 +1111,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, return Status::OK(); } -Status CoreWorker::Delete(const std::vector &object_ids, bool local_only, - bool delete_creating_tasks) { +Status CoreWorker::Delete(const std::vector &object_ids, bool local_only) { // Release the object from plasma. This does not affect the object's ref // count. If this was called from a non-owning worker, then a warning will be // logged and the object will not get released. @@ -1129,8 +1128,7 @@ Status CoreWorker::Delete(const std::vector &object_ids, bool local_on // We only delete from plasma, which avoids hangs (issue #7105). In-memory // objects can only be deleted once the ref count goes to 0. absl::flat_hash_set plasma_object_ids(object_ids.begin(), object_ids.end()); - return plasma_store_provider_->Delete(plasma_object_ids, local_only, - delete_creating_tasks); + return plasma_store_provider_->Delete(plasma_object_ids, local_only); } void CoreWorker::TriggerGlobalGC() { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 171a42d76..14136a895 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -571,11 +571,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] object_ids IDs of the objects to delete. /// \param[in] local_only Whether only delete the objects in local node, or all nodes in /// the cluster. - /// \param[in] delete_creating_tasks Whether also delete the tasks that - /// created these objects. /// \return Status. - Status Delete(const std::vector &object_ids, bool local_only, - bool delete_creating_tasks); + Status Delete(const std::vector &object_ids, bool local_only); /// Trigger garbage collection on each worker in the cluster. void TriggerGlobalGC(); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc index b62b19818..d66088de1 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.cc @@ -120,15 +120,14 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai } JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete( - JNIEnv *env, jclass, jobject objectIds, jboolean localOnly, - jboolean deleteCreatingTasks) { + JNIEnv *env, jclass, jobject objectIds, jboolean localOnly) { std::vector object_ids; JavaListToNativeVector( env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { return JavaByteArrayToId(env, static_cast(id)); }); - auto status = ray::CoreWorkerProcess::GetCoreWorker().Delete( - object_ids, (bool)localOnly, (bool)deleteCreatingTasks); + auto status = + ray::CoreWorkerProcess::GetCoreWorker().Delete(object_ids, (bool)localOnly); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h index 4e11c0456..b1da06e57 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_object_NativeObjectStore.h @@ -60,10 +60,10 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeWai /* * Class: io_ray_runtime_object_NativeObjectStore * Method: nativeDelete - * Signature: (Ljava/util/List;ZZ)V + * Signature: (Ljava/util/List;Z)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete( - JNIEnv *, jclass, jobject, jboolean, jboolean); + JNIEnv *, jclass, jobject, jboolean); /* * Class: io_ray_runtime_object_NativeObjectStore diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 25007a863..3079b99f5 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -361,10 +361,9 @@ Status CoreWorkerPlasmaStoreProvider::Wait( } Status CoreWorkerPlasmaStoreProvider::Delete( - const absl::flat_hash_set &object_ids, bool local_only, - bool delete_creating_tasks) { + const absl::flat_hash_set &object_ids, bool local_only) { std::vector object_id_vector(object_ids.begin(), object_ids.end()); - return raylet_client_->FreeObjects(object_id_vector, local_only, delete_creating_tasks); + return raylet_client_->FreeObjects(object_id_vector, local_only); } std::string CoreWorkerPlasmaStoreProvider::MemoryUsageString() { @@ -424,7 +423,7 @@ Status CoreWorkerPlasmaStoreProvider::WarmupStore() { RAY_RETURN_NOT_OK(Create(nullptr, 8, object_id, rpc::Address(), &data)); RAY_RETURN_NOT_OK(Seal(object_id)); RAY_RETURN_NOT_OK(Release(object_id)); - RAY_RETURN_NOT_OK(Delete({object_id}, false, false)); + RAY_RETURN_NOT_OK(Delete({object_id}, false)); return Status::OK(); } diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 88bed0428..6085a50c1 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -98,8 +98,7 @@ class CoreWorkerPlasmaStoreProvider { int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_set *ready); - Status Delete(const absl::flat_hash_set &object_ids, bool local_only, - bool delete_creating_tasks); + Status Delete(const absl::flat_hash_set &object_ids, bool local_only); /// Lists objects in used (pinned) by the current client. /// diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index f06e1a7f4..0c4d69149 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -822,7 +822,7 @@ TEST_F(SingleNodeTest, TestObjectInterface) { // Test Delete(). // clear the reference held by PlasmaBuffer. results.clear(); - RAY_CHECK_OK(core_worker.Delete(ids, true, false)); + RAY_CHECK_OK(core_worker.Delete(ids, true)); // Note that Delete() calls RayletClient::FreeObjects and would not // wait for objects being deleted, so wait a while for plasma store diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index 655c47aa7..e1e70d5c5 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -208,16 +208,6 @@ class TaskInfoAccessor { virtual Status AsyncGet(const TaskID &task_id, const OptionalItemCallback &callback) = 0; - /// Delete tasks from GCS asynchronously. - /// - /// \param task_ids The vector of IDs to delete from GCS. - /// \param callback Callback that is called after delete finished. - /// \return Status - // TODO(micafan) Will support callback of batch deletion in the future. - // Currently this callback will never be called. - virtual Status AsyncDelete(const std::vector &task_ids, - const StatusCallback &callback) = 0; - /// Subscribe asynchronously to the event that the given task is added in GCS. /// /// \param task_id The ID of the task to be subscribed to. diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 2cf1d2caf..7e7d67d44 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -886,25 +886,6 @@ Status ServiceBasedTaskInfoAccessor::AsyncGet( return Status::OK(); } -Status ServiceBasedTaskInfoAccessor::AsyncDelete(const std::vector &task_ids, - const StatusCallback &callback) { - RAY_LOG(DEBUG) << "Deleting tasks, task id list size = " << task_ids.size(); - rpc::DeleteTasksRequest request; - for (auto &task_id : task_ids) { - request.add_task_id_list(task_id.Binary()); - } - client_impl_->GetGcsRpcClient().DeleteTasks( - request, - [task_ids, callback](const Status &status, const rpc::DeleteTasksReply &reply) { - if (callback) { - callback(status); - } - RAY_LOG(DEBUG) << "Finished deleting tasks, status = " << status - << ", task id list size = " << task_ids.size(); - }); - return Status::OK(); -} - Status ServiceBasedTaskInfoAccessor::AsyncSubscribe( const TaskID &task_id, const SubscribeCallback &subscribe, const StatusCallback &done) { diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index f0e1f45bc..05f2d4316 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -278,9 +278,6 @@ class ServiceBasedTaskInfoAccessor : public TaskInfoAccessor { Status AsyncGet(const TaskID &task_id, const OptionalItemCallback &callback) override; - Status AsyncDelete(const std::vector &task_ids, - const StatusCallback &callback) override; - Status AsyncSubscribe(const TaskID &task_id, const SubscribeCallback &subscribe, const StatusCallback &done) override; diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index 9df66dfff..b470598d0 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -415,13 +415,6 @@ class ServiceBasedGcsClientTest : public ::testing::Test { return task_table_data; } - bool DeleteTask(const std::vector &task_ids) { - std::promise promise; - RAY_CHECK_OK(gcs_client_->Tasks().AsyncDelete( - task_ids, [&promise](Status status) { promise.set_value(status.ok()); })); - return WaitReady(promise.get_future(), timeout_ms_); - } - bool SubscribeTaskLease( const TaskID &task_id, const gcs::SubscribeCallback> @@ -875,10 +868,6 @@ TEST_F(ServiceBasedGcsClientTest, TestTaskInfo) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); EXPECT_EQ(task_count, 1); - // Delete tasks from GCS. - std::vector task_ids = {task_id}; - ASSERT_TRUE(DeleteTask(task_ids)); - // Subscribe to the event that the given task lease is added in GCS. std::atomic task_lease_count(0); auto task_lease_subscribe = [&task_lease_count]( diff --git a/src/ray/gcs/gcs_server/task_info_handler_impl.cc b/src/ray/gcs/gcs_server/task_info_handler_impl.cc index 7034c87a5..b47ab7cef 100644 --- a/src/ray/gcs/gcs_server/task_info_handler_impl.cc +++ b/src/ray/gcs/gcs_server/task_info_handler_impl.cc @@ -68,30 +68,6 @@ void DefaultTaskInfoHandler::HandleGetTask(const GetTaskRequest &request, ++counts_[CountType::GET_TASK_REQUEST]; } -void DefaultTaskInfoHandler::HandleDeleteTasks(const DeleteTasksRequest &request, - DeleteTasksReply *reply, - SendReplyCallback send_reply_callback) { - std::vector task_ids = IdVectorFromProtobuf(request.task_id_list()); - JobID job_id = task_ids.empty() ? JobID::Nil() : task_ids[0].JobId(); - RAY_LOG(DEBUG) << "Deleting tasks, job id = " << job_id - << ", task id list size = " << task_ids.size(); - auto on_done = [job_id, task_ids, request, reply, send_reply_callback](Status status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Failed to delete tasks, job id = " << job_id - << ", task id list size = " << task_ids.size(); - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }; - - Status status = gcs_table_storage_->TaskTable().BatchDelete(task_ids, on_done); - if (!status.ok()) { - on_done(status); - } - RAY_LOG(DEBUG) << "Finished deleting tasks, job id = " << job_id - << ", task id list size = " << task_ids.size(); - ++counts_[CountType::DELETE_TASKS_REQUEST]; -} - void DefaultTaskInfoHandler::HandleAddTaskLease(const AddTaskLeaseRequest &request, AddTaskLeaseReply *reply, SendReplyCallback send_reply_callback) { @@ -183,7 +159,6 @@ std::string DefaultTaskInfoHandler::DebugString() const { stream << "DefaultTaskInfoHandler: {AddTask request count: " << counts_[CountType::ADD_TASK_REQUEST] << ", GetTask request count: " << counts_[CountType::GET_TASK_REQUEST] - << ", DeleteTasks request count: " << counts_[CountType::DELETE_TASKS_REQUEST] << ", AddTaskLease request count: " << counts_[CountType::ADD_TASK_LEASE_REQUEST] << ", GetTaskLease request count: " << counts_[CountType::GET_TASK_LEASE_REQUEST] << ", AttemptTaskReconstruction request count: " diff --git a/src/ray/gcs/gcs_server/task_info_handler_impl.h b/src/ray/gcs/gcs_server/task_info_handler_impl.h index 98cd64bda..5a7599e8f 100644 --- a/src/ray/gcs/gcs_server/task_info_handler_impl.h +++ b/src/ray/gcs/gcs_server/task_info_handler_impl.h @@ -35,9 +35,6 @@ class DefaultTaskInfoHandler : public rpc::TaskInfoHandler { void HandleGetTask(const GetTaskRequest &request, GetTaskReply *reply, SendReplyCallback send_reply_callback) override; - void HandleDeleteTasks(const DeleteTasksRequest &request, DeleteTasksReply *reply, - SendReplyCallback send_reply_callback) override; - void HandleAddTaskLease(const AddTaskLeaseRequest &request, AddTaskLeaseReply *reply, SendReplyCallback send_reply_callback) override; @@ -58,7 +55,6 @@ class DefaultTaskInfoHandler : public rpc::TaskInfoHandler { enum CountType { ADD_TASK_REQUEST = 0, GET_TASK_REQUEST = 1, - DELETE_TASKS_REQUEST = 2, ADD_TASK_LEASE_REQUEST = 3, GET_TASK_LEASE_REQUEST = 4, ATTEMPT_TASK_RECONSTRUCTION_REQUEST = 5, diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index fd2084168..ea8ebc09d 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -280,16 +280,6 @@ class GcsServerTest : public ::testing::Test { return task_data; } - bool DeleteTasks(const rpc::DeleteTasksRequest &request) { - std::promise promise; - client_->DeleteTasks( - request, [&promise](const Status &status, const rpc::DeleteTasksReply &reply) { - RAY_CHECK_OK(status); - promise.set_value(true); - }); - return WaitReady(promise.get_future(), timeout_ms_); - } - bool AddTaskLease(const rpc::AddTaskLeaseRequest &request) { std::promise promise; client_->AddTaskLease( @@ -574,13 +564,6 @@ TEST_F(GcsServerTest, TestTaskInfo) { rpc::TaskTableData result = GetTask(task_id.Binary()); ASSERT_TRUE(result.task().task_spec().job_id() == job_id.Binary()); - // Delete task - rpc::DeleteTasksRequest delete_tasks_request; - delete_tasks_request.add_task_id_list(task_id.Binary()); - ASSERT_TRUE(DeleteTasks(delete_tasks_request)); - result = GetTask(task_id.Binary()); - ASSERT_TRUE(!result.has_task()); - // Add task lease NodeID node_id = NodeID::FromRandom(); auto task_lease_data = Mocker::GenTaskLeaseData(task_id.Binary(), node_id.Binary()); diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index bd3fe0604..248eb9a89 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -247,19 +247,6 @@ Status RedisTaskInfoAccessor::AsyncGet( return task_table.Lookup(task_id.JobId(), task_id, on_success, on_failure); } -Status RedisTaskInfoAccessor::AsyncDelete(const std::vector &task_ids, - const StatusCallback &callback) { - raylet::TaskTable &task_table = client_impl_->raylet_task_table(); - JobID job_id = task_ids.empty() ? JobID::Nil() : task_ids[0].JobId(); - task_table.Delete(job_id, task_ids); - if (callback) { - callback(Status::OK()); - } - // TODO(micafan) Always return OK here. - // Confirm if we need to handle the deletion failure and how to handle it. - return Status::OK(); -} - Status RedisTaskInfoAccessor::AsyncSubscribe( const TaskID &task_id, const SubscribeCallback &subscribe, const StatusCallback &done) { diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index c8263d0c8..ec5d389f6 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -182,9 +182,6 @@ class RedisTaskInfoAccessor : public TaskInfoAccessor { Status AsyncGet(const TaskID &task_id, const OptionalItemCallback &callback) override; - Status AsyncDelete(const std::vector &task_ids, - const StatusCallback &callback) override; - Status AsyncSubscribe(const TaskID &task_id, const SubscribeCallback &subscribe, const StatusCallback &done) override; diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index eb730a7cf..8bba86e56 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -348,14 +348,6 @@ message GetTaskReply { TaskTableData task_data = 2; } -message DeleteTasksRequest { - repeated bytes task_id_list = 1; -} - -message DeleteTasksReply { - GcsStatus status = 1; -} - message AddTaskLeaseRequest { TaskLeaseData task_lease_data = 1; } @@ -387,8 +379,6 @@ service TaskInfoGcsService { rpc AddTask(AddTaskRequest) returns (AddTaskReply); // Get task information from GCS Service. rpc GetTask(GetTaskRequest) returns (GetTaskReply); - // Delete tasks from GCS Service. - rpc DeleteTasks(DeleteTasksRequest) returns (DeleteTasksReply); // Add a task lease to GCS Service. rpc AddTaskLease(AddTaskLeaseRequest) returns (AddTaskLeaseReply); // Get task lease information from GCS Service. diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index fb95bbc61..1e405b5e3 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -249,8 +249,6 @@ table FreeObjectsRequest { // Whether keep this request with local object store // or send it to all the object stores. local_only: bool; - // Whether also delete objects' creating tasks from GCS. - delete_creating_tasks: bool; // List of object ids we'll delete from object store. object_ids: [string]; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e78820d42..b27a0e32c 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1224,14 +1224,6 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr & std::vector object_ids = from_flatbuf(*message->object_ids()); // Clean up objects from the object store. object_manager_.FreeObjects(object_ids, message->local_only()); - if (message->delete_creating_tasks()) { - // Clean up their creating tasks from GCS. - std::vector creating_task_ids; - for (const auto &object_id : object_ids) { - creating_task_ids.push_back(object_id.TaskId()); - } - RAY_CHECK_OK(gcs_client_->Tasks().AsyncDelete(creating_task_ids, nullptr)); - } } break; case protocol::MessageType::SubscribePlasmaReady: { ProcessSubscribePlasmaReady(client, message_data); diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 1c6365796..5582a68ba 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -274,10 +274,10 @@ Status raylet::RayletClient::PushProfileEvents(const ProfileTableData &profile_e } Status raylet::RayletClient::FreeObjects(const std::vector &object_ids, - bool local_only, bool delete_creating_tasks) { + bool local_only) { flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateFreeObjectsRequest( - fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids)); + auto message = + protocol::CreateFreeObjectsRequest(fbb, local_only, to_flatbuf(fbb, object_ids)); fbb.Finish(message); return conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb); } diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index 9fa1b7982..185ca445a 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -313,10 +313,8 @@ class RayletClient : public RayletClientInterface { /// \param object_ids A list of ObjectsIDs to be deleted. /// \param local_only Whether keep this request with local object store /// or send it to all the object stores. - /// \param delete_creating_tasks Whether also delete objects' creating tasks from GCS. /// \return ray::Status. - ray::Status FreeObjects(const std::vector &object_ids, bool local_only, - bool deleteCreatingTasks); + ray::Status FreeObjects(const std::vector &object_ids, bool local_only); /// Sets a resource with the specified capacity and client id /// \param resource_name Name of the resource to be set diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 82857123e..39641358f 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -219,9 +219,6 @@ class GcsRpcClient { /// Get task information from GCS Service. VOID_GCS_RPC_CLIENT_METHOD(TaskInfoGcsService, GetTask, task_info_grpc_client_, ) - /// Delete tasks from GCS Service. - VOID_GCS_RPC_CLIENT_METHOD(TaskInfoGcsService, DeleteTasks, task_info_grpc_client_, ) - /// Add a task lease to GCS Service. VOID_GCS_RPC_CLIENT_METHOD(TaskInfoGcsService, AddTaskLease, task_info_grpc_client_, ) diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 248ec9837..a39323e40 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -375,10 +375,6 @@ class TaskInfoGcsServiceHandler { virtual void HandleGetTask(const GetTaskRequest &request, GetTaskReply *reply, SendReplyCallback send_reply_callback) = 0; - virtual void HandleDeleteTasks(const DeleteTasksRequest &request, - DeleteTasksReply *reply, - SendReplyCallback send_reply_callback) = 0; - virtual void HandleAddTaskLease(const AddTaskLeaseRequest &request, AddTaskLeaseReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -410,7 +406,6 @@ class TaskInfoGrpcService : public GrpcService { std::vector> *server_call_factories) override { TASK_INFO_SERVICE_RPC_HANDLER(AddTask); TASK_INFO_SERVICE_RPC_HANDLER(GetTask); - TASK_INFO_SERVICE_RPC_HANDLER(DeleteTasks); TASK_INFO_SERVICE_RPC_HANDLER(AddTaskLease); TASK_INFO_SERVICE_RPC_HANDLER(GetTaskLease); TASK_INFO_SERVICE_RPC_HANDLER(AttemptTaskReconstruction); From ef95db51e190ffd6896582d4a813b9f9372f0bde Mon Sep 17 00:00:00 2001 From: roireshef Date: Mon, 21 Dec 2020 19:19:33 +0200 Subject: [PATCH 55/88] [RLlib] Arbitrary input to value() when not using GAE (#12941) --- rllib/agents/ppo/ppo_torch_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index a268e7487..fa2ca6c1d 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -244,7 +244,7 @@ class ValueNetworkMixin: # When not doing GAE, we do not require the value function's output. else: - def value(ob, prev_action, prev_reward, *state): + def value(*args, **kwargs): return 0.0 self._value = value From 5e2b850836b893a7276edbcf42079b6759dbd2f0 Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Mon, 21 Dec 2020 20:30:03 +0200 Subject: [PATCH 56/88] [autoscaler] Fixes max_workers bug. (#13008) --- .../_private/resource_demand_scheduler.py | 15 +- .../tests/test_resource_demand_scheduler.py | 133 +++++++++++++----- 2 files changed, 107 insertions(+), 41 deletions(-) diff --git a/python/ray/autoscaler/_private/resource_demand_scheduler.py b/python/ray/autoscaler/_private/resource_demand_scheduler.py index f3ec607df..aba8cff2d 100644 --- a/python/ray/autoscaler/_private/resource_demand_scheduler.py +++ b/python/ray/autoscaler/_private/resource_demand_scheduler.py @@ -192,7 +192,8 @@ class ResourceDemandScheduler: # Add 1 to account for the head node. max_to_add = self.max_workers + 1 - sum(node_type_counts.values()) nodes_to_add_based_on_demand = get_nodes_for( - self.node_types, node_type_counts, max_to_add, unfulfilled) + self.node_types, node_type_counts, self.head_node_type, max_to_add, + unfulfilled) # Merge nodes to add based on demand and nodes to add based on # min_workers constraint. We add them because nodes to add based on # demand was calculated after the min_workers constraint was respected. @@ -447,6 +448,7 @@ class ResourceDemandScheduler: to_launch = get_nodes_for( self.node_types, node_type_counts, + self.head_node_type, max_to_add, unfulfilled, strict_spread=True) @@ -544,7 +546,7 @@ def _add_min_workers_nodes( max_node_resources, ensure_min_cluster_size) # Get the nodes to meet the unfulfilled. nodes_to_add_request_resources = get_nodes_for( - node_types, node_type_counts, max_to_add, + node_types, node_type_counts, head_node_type, max_to_add, resource_requests_unfulfilled) # Update the resources, counts and total nodes to add. for node_type in nodes_to_add_request_resources: @@ -565,6 +567,7 @@ def _add_min_workers_nodes( def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict], existing_nodes: Dict[NodeType, int], + head_node_type: NodeType, max_to_add: int, resources: List[ResourceDict], strict_spread: bool = False) -> Dict[NodeType, int]: @@ -588,9 +591,13 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict], while resources and sum(nodes_to_add.values()) < max_to_add: utilization_scores = [] for node_type in node_types: + max_workers_of_node_type = node_types[node_type].get( + "max_workers", 0) + if head_node_type == node_type: + # Add 1 to account for head node. + max_workers_of_node_type = max_workers_of_node_type + 1 if (existing_nodes.get(node_type, 0) + nodes_to_add.get( - node_type, 0) >= node_types[node_type].get( - "max_workers", 0)): + node_type, 0) >= max_workers_of_node_type): continue node_resources = node_types[node_type]["resources"] if strict_spread: diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index 067b5f53d..2093f1e14 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -143,43 +143,100 @@ def test_bin_pack(): def test_get_nodes_packing_heuristic(): - assert get_nodes_for(TYPES_A, {}, 9999, [{"GPU": 8}]) == \ - {"p2.8xlarge": 1} - assert get_nodes_for(TYPES_A, {}, 9999, [{"GPU": 1}] * 6) == \ - {"p2.8xlarge": 1} - assert get_nodes_for(TYPES_A, {}, 9999, [{"GPU": 1}] * 4) == \ - {"p2.xlarge": 4} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 32, "GPU": 1}] * 3) \ - == {"p2.8xlarge": 3} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 64, "GPU": 1}] * 3) \ - == {} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 64}] * 3) == \ - {"m4.16xlarge": 3} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 64}, {"CPU": 1}]) \ - == {"m4.16xlarge": 1, "m4.large": 1} + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "GPU": 8 + }]) == { + "p2.8xlarge": 1 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "GPU": 1 + }] * 6) == { + "p2.8xlarge": 1 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "GPU": 1 + }] * 4) == { + "p2.xlarge": 4 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 32, + "GPU": 1 + }] * 3) == { + "p2.8xlarge": 3 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 64, + "GPU": 1 + }] * 3) == {} + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 64 + }] * 3) == { + "m4.16xlarge": 3 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 64 + }, { + "CPU": 1 + }]) == { + "m4.16xlarge": 1, + "m4.large": 1 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 64 + }, { + "CPU": 9 + }, { + "CPU": 9 + }]) == { + "m4.16xlarge": 1, + "m4.4xlarge": 2 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 16 + }] * 5) == { + "m4.16xlarge": 1, + "m4.4xlarge": 1 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 8 + }] * 10) == { + "m4.16xlarge": 1, + "m4.4xlarge": 1 + } + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "CPU": 1 + }] * 100) == { + "m4.16xlarge": 1, + "m4.4xlarge": 2, + "m4.large": 2 + } + + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, [{ + "GPU": 1 + }] + ([{ + "CPU": 1 + }] * 64)) == { + "m4.16xlarge": 1, + "p2.xlarge": 1 + } + + assert get_nodes_for(TYPES_A, {}, "empty_node", 9999, ([{ + "GPU": 1 + }] * 8) + ([{ + "CPU": 1 + }] * 64)) == { + "m4.16xlarge": 1, + "p2.8xlarge": 1 + } + assert get_nodes_for( - TYPES_A, {}, 9999, [{"CPU": 64}, {"CPU": 9}, {"CPU": 9}]) == \ - {"m4.16xlarge": 1, "m4.4xlarge": 2} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 16}] * 5) == \ - {"m4.16xlarge": 1, "m4.4xlarge": 1} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 8}] * 10) == \ - {"m4.16xlarge": 1, "m4.4xlarge": 1} - assert get_nodes_for(TYPES_A, {}, 9999, [{"CPU": 1}] * 100) == \ - {"m4.16xlarge": 1, "m4.4xlarge": 2, "m4.large": 2} - assert get_nodes_for( - TYPES_A, {}, 9999, [{"GPU": 1}] + ([{"CPU": 1}] * 64)) == \ - {"m4.16xlarge": 1, "p2.xlarge": 1} - assert get_nodes_for( - TYPES_A, {}, 9999, ([{"GPU": 1}] * 8) + ([{"CPU": 1}] * 64)) == \ - {"m4.16xlarge": 1, "p2.8xlarge": 1} - assert get_nodes_for( - TYPES_A, {}, 9999, [{ + TYPES_A, {}, "empty_node", 9999, [{ "GPU": 1 }] * 8, strict_spread=False) == { "p2.8xlarge": 1 } assert get_nodes_for( - TYPES_A, {}, 9999, [{ + TYPES_A, {}, "empty_node", 9999, [{ "GPU": 1 }] * 8, strict_spread=True) == { "p2.xlarge": 8 @@ -201,22 +258,22 @@ def test_get_nodes_respects_max_limit(): "max_workers": 99999, }, } - assert get_nodes_for(types, {}, 2, [{"CPU": 1}] * 10) == \ + assert get_nodes_for(types, {}, "empty_node", 2, [{"CPU": 1}] * 10) == \ {"m4.large": 2} - assert get_nodes_for(types, {"m4.large": 9999}, 9999, [{ + assert get_nodes_for(types, {"m4.large": 9999}, "empty_node", 9999, [{ "CPU": 1 }] * 10) == {} - assert get_nodes_for(types, {"m4.large": 0}, 9999, [{ + assert get_nodes_for(types, {"m4.large": 0}, "empty_node", 9999, [{ "CPU": 1 }] * 10) == { "m4.large": 5 } - assert get_nodes_for(types, {"m4.large": 7}, 4, [{ + assert get_nodes_for(types, {"m4.large": 7}, "m4.large", 4, [{ "CPU": 1 }] * 10) == { - "m4.large": 3 + "m4.large": 4 } - assert get_nodes_for(types, {"m4.large": 7}, 2, [{ + assert get_nodes_for(types, {"m4.large": 7}, "m4.large", 2, [{ "CPU": 1 }] * 10) == { "m4.large": 2 @@ -1355,6 +1412,7 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() runner = MockProcessRunner() + runner.respond_to_call("json .Config.Env", ["[]" for i in range(6)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, TAG_RAY_USER_NODE_TYPE: "empty_node" @@ -1379,6 +1437,7 @@ class AutoscalingTest(unittest.TestCase): autoscaler.request_resources([{"CPU": 32}] * 4) autoscaler.update() self.waitForNodes(5) + assert self.provider.mock_nodes[3].node_type == "m4.16xlarge" assert self.provider.mock_nodes[4].node_type == "m4.16xlarge" From 43b9c7811ead485c8fcc25e3b3029c9c9915d0b1 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Mon, 21 Dec 2020 12:17:44 -0800 Subject: [PATCH 57/88] [ray_client] add client microbenchmarks (#13007) --- .../ray/_private/ray_client_microbenchmark.py | 83 +++++++++++++++++++ .../_private/ray_microbenchmark_helpers.py | 39 +++++++++ .../experimental/client/ray_client_helpers.py | 16 ++++ python/ray/ray_perf.py | 32 ++----- python/ray/tests/test_experimental_client.py | 16 +--- .../test_experimental_client_metadata.py | 2 +- .../test_experimental_client_references.py | 2 +- .../test_experimental_client_terminate.py | 2 +- 8 files changed, 148 insertions(+), 44 deletions(-) create mode 100644 python/ray/_private/ray_client_microbenchmark.py create mode 100644 python/ray/_private/ray_microbenchmark_helpers.py create mode 100644 python/ray/experimental/client/ray_client_helpers.py diff --git a/python/ray/_private/ray_client_microbenchmark.py b/python/ray/_private/ray_client_microbenchmark.py new file mode 100644 index 000000000..c2b7d4486 --- /dev/null +++ b/python/ray/_private/ray_client_microbenchmark.py @@ -0,0 +1,83 @@ +import inspect +import logging +import sys + +from ray.experimental.client.ray_client_helpers import ray_start_client_server + +from ray._private.ray_microbenchmark_helpers import timeit +from ray._private.ray_microbenchmark_helpers import ray_setup_and_teardown + + +def benchmark_get_calls(ray): + value = ray.put(0) + + def get_small(): + ray.get(value) + + timeit("client: get calls", get_small) + + +def benchmark_put_calls(ray): + def put_small(): + ray.put(0) + + timeit("client: put calls", put_small) + + +def benchmark_remote_put_calls(ray): + @ray.remote + def do_put_small(): + for _ in range(100): + ray.put(0) + + def put_multi_small(): + ray.get([do_put_small.remote() for _ in range(10)]) + + timeit("client: remote put calls", put_multi_small, 1000) + + +def benchmark_simple_actor(ray): + @ray.remote(num_cpus=0) + class Actor: + def small_value(self): + return b"ok" + + def small_value_arg(self, x): + return b"ok" + + def small_value_batch(self, n): + ray.get([self.small_value.remote() for _ in range(n)]) + + a = Actor.remote() + + def actor_sync(): + ray.get(a.small_value.remote()) + + timeit("client: 1:1 actor calls sync", actor_sync) + + def actor_async(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + timeit("client: 1:1 actor calls async", actor_async, 1000) + + a = Actor.options(max_concurrency=16).remote() + + def actor_concurrent(): + ray.get([a.small_value.remote() for _ in range(1000)]) + + timeit("client: 1:1 actor calls concurrent", actor_concurrent, 1000) + + +def main(): + system_config = {"put_small_object_in_memory_store": True} + with ray_setup_and_teardown( + logging_level=logging.WARNING, _system_config=system_config): + for name, obj in inspect.getmembers(sys.modules[__name__]): + if not name.startswith("benchmark_"): + continue + with ray_start_client_server() as ray: + obj(ray) + + +if __name__ == "__main__": + main() diff --git a/python/ray/_private/ray_microbenchmark_helpers.py b/python/ray/_private/ray_microbenchmark_helpers.py new file mode 100644 index 000000000..dffa0afd8 --- /dev/null +++ b/python/ray/_private/ray_microbenchmark_helpers.py @@ -0,0 +1,39 @@ +import time +import os +import ray +import numpy as np + +from contextlib import contextmanager + +# Only run tests matching this filter pattern. +filter_pattern = os.environ.get("TESTS_TO_RUN", "") + + +def timeit(name, fn, multiplier=1): + if filter_pattern not in name: + return + # warmup + start = time.time() + while time.time() - start < 1: + fn() + # real run + stats = [] + for _ in range(4): + start = time.time() + count = 0 + while time.time() - start < 2: + fn() + count += 1 + end = time.time() + stats.append(multiplier * count / (end - start)) + print(name, "per second", round(np.mean(stats), 2), "+-", + round(np.std(stats), 2)) + + +@contextmanager +def ray_setup_and_teardown(**init_args): + ray.init(**init_args) + try: + yield None + finally: + ray.shutdown() diff --git a/python/ray/experimental/client/ray_client_helpers.py b/python/ray/experimental/client/ray_client_helpers.py new file mode 100644 index 000000000..ab9d7408a --- /dev/null +++ b/python/ray/experimental/client/ray_client_helpers.py @@ -0,0 +1,16 @@ +from contextlib import contextmanager + +import ray.experimental.client.server.server as ray_client_server +from ray.experimental.client import ray, reset_api + + +@contextmanager +def ray_start_client_server(): + server = ray_client_server.serve("localhost:50051", test_mode=True) + ray.connect("localhost:50051") + try: + yield ray + finally: + ray.disconnect() + server.stop(0) + reset_api() diff --git a/python/ray/ray_perf.py b/python/ray/ray_perf.py index 35c484d9e..d1d07a8c5 100644 --- a/python/ray/ray_perf.py +++ b/python/ray/ray_perf.py @@ -2,17 +2,15 @@ import asyncio import logging -import os -import time +from ray._private.ray_microbenchmark_helpers import timeit +from ray._private.ray_client_microbenchmark import (main as + client_microbenchmark_main) import numpy as np import multiprocessing import ray logger = logging.getLogger(__name__) -# Only run tests matching this filter pattern. -filter_pattern = os.environ.get("TESTS_TO_RUN", "") - @ray.remote(num_cpus=0) class Actor: @@ -71,27 +69,6 @@ def small_value_batch(n): return 0 -def timeit(name, fn, multiplier=1): - if filter_pattern not in name: - return - # warmup - start = time.time() - while time.time() - start < 1: - fn() - # real run - stats = [] - for _ in range(4): - start = time.time() - count = 0 - while time.time() - start < 2: - fn() - count += 1 - end = time.time() - stats.append(multiplier * count / (end - start)) - print(name, "per second", round(np.mean(stats), 2), "+-", - round(np.std(stats), 2)) - - def check_optimized_build(): if not ray._raylet.OPTIMIZED: msg = ("WARNING: Unoptimized build! " @@ -277,6 +254,9 @@ def main(): ray.get([async_actor_work.remote(a) for _ in range(m)]) timeit("n:n async-actor calls async", async_actor_multi, m * n) + ray.shutdown() + + client_microbenchmark_main() if __name__ == "__main__": diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index e6afee042..131954ede 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -2,23 +2,9 @@ import pytest import time import sys import logging -from contextlib import contextmanager -import ray.experimental.client.server.server as ray_client_server -from ray.experimental.client import ray, reset_api from ray.experimental.client.common import ClientObjectRef - - -@contextmanager -def ray_start_client_server(): - server = ray_client_server.serve("localhost:50051", test_mode=True) - ray.connect("localhost:50051") - try: - yield ray - finally: - ray.disconnect() - server.stop(0) - reset_api() +from ray.experimental.client.ray_client_helpers import ray_start_client_server def test_real_ray_fallback(ray_start_regular_shared): diff --git a/python/ray/tests/test_experimental_client_metadata.py b/python/ray/tests/test_experimental_client_metadata.py index d0bb86c9e..f5a65cd66 100644 --- a/python/ray/tests/test_experimental_client_metadata.py +++ b/python/ray/tests/test_experimental_client_metadata.py @@ -1,4 +1,4 @@ -from ray.tests.test_experimental_client import ray_start_client_server +from ray.experimental.client.ray_client_helpers import ray_start_client_server def test_get_ray_metadata(ray_start_regular_shared): diff --git a/python/ray/tests/test_experimental_client_references.py b/python/ray/tests/test_experimental_client_references.py index 9675b9c97..4875d1ae0 100644 --- a/python/ray/tests/test_experimental_client_references.py +++ b/python/ray/tests/test_experimental_client_references.py @@ -1,4 +1,4 @@ -from ray.tests.test_experimental_client import ray_start_client_server +from ray.experimental.client.ray_client_helpers import ray_start_client_server from ray.test_utils import wait_for_condition import ray as real_ray from ray.core.generated.gcs_pb2 import ActorTableData diff --git a/python/ray/tests/test_experimental_client_terminate.py b/python/ray/tests/test_experimental_client_terminate.py index c475a5457..3936dfb24 100644 --- a/python/ray/tests/test_experimental_client_terminate.py +++ b/python/ray/tests/test_experimental_client_terminate.py @@ -1,5 +1,5 @@ import pytest -from ray.tests.test_experimental_client import ray_start_client_server +from ray.experimental.client.ray_client_helpers import ray_start_client_server from ray.tests.client_test_utils import create_remote_signal_actor from ray.test_utils import wait_for_condition from ray.exceptions import TaskCancelledError From 5b48480e2932de0aaa1853b8abcc3e1f69e59cd5 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Mon, 21 Dec 2020 15:48:00 -0500 Subject: [PATCH 58/88] [Collective][PR 3/6] Other collectives (#12864) --- python/ray/util/collective/__init__.py | 10 +- python/ray/util/collective/collective.py | 149 ++++++++-- .../collective_group/base_collective_group.py | 28 +- .../collective_group/nccl_collective_group.py | 175 +++++++++-- .../collective/collective_group/nccl_util.py | 103 ++++++- python/ray/util/collective/const.py | 3 +- .../tests/distributed_tests/__init__.py | 0 .../test_distributed_allgather.py | 133 +++++++++ .../test_distributed_allreduce.py | 139 +++++++++ .../test_distributed_basic_apis.py | 135 +++++++++ .../test_distributed_broadcast.py | 67 +++++ .../test_distributed_reduce.py | 119 ++++++++ .../test_distributed_reducescatter.py | 128 ++++++++ .../util/collective/tests/test_allgather.py | 131 +++++++++ .../util/collective/tests/test_allreduce.py | 143 +++++++++ .../util/collective/tests/test_basic_apis.py | 127 ++++++++ .../util/collective/tests/test_broadcast.py | 67 +++++ .../tests/test_collective_2_nodes_4_gpus.py | 276 ------------------ .../test_collective_single_node_2_gpus.py | 267 ----------------- .../ray/util/collective/tests/test_reduce.py | 143 +++++++++ .../collective/tests/test_reducescatter.py | 127 ++++++++ python/ray/util/collective/tests/util.py | 67 ++++- python/ray/util/collective/types.py | 37 ++- 23 files changed, 1968 insertions(+), 606 deletions(-) create mode 100644 python/ray/util/collective/tests/distributed_tests/__init__.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py create mode 100644 python/ray/util/collective/tests/test_allgather.py create mode 100644 python/ray/util/collective/tests/test_allreduce.py create mode 100644 python/ray/util/collective/tests/test_basic_apis.py create mode 100644 python/ray/util/collective/tests/test_broadcast.py delete mode 100644 python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py delete mode 100644 python/ray/util/collective/tests/test_collective_single_node_2_gpus.py create mode 100644 python/ray/util/collective/tests/test_reduce.py create mode 100644 python/ray/util/collective/tests/test_reducescatter.py diff --git a/python/ray/util/collective/__init__.py b/python/ray/util/collective/__init__.py index 68fcb78d4..fcc879589 100644 --- a/python/ray/util/collective/__init__.py +++ b/python/ray/util/collective/__init__.py @@ -1,9 +1,11 @@ -from .collective import nccl_available, mpi_available, is_group_initialized, \ - init_collective_group, destroy_collective_group, get_rank, \ - get_world_size, allreduce, barrier +from ray.util.collective.collective import nccl_available, mpi_available, \ + is_group_initialized, init_collective_group, destroy_collective_group, \ + get_rank, get_world_size, allreduce, barrier, reduce, broadcast, \ + allgather, reducescatter __all__ = [ "nccl_available", "mpi_available", "is_group_initialized", "init_collective_group", "destroy_collective_group", "get_rank", - "get_world_size", "allreduce", "barrier" + "get_world_size", "allreduce", "barrier", "reduce", "broadcast", + "allgather", "reducescatter" ] diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 343487e71..464b116a0 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -32,8 +32,7 @@ def mpi_available(): class GroupManager(object): - """ - Use this class to manage the collective groups we created so far. + """Use this class to manage the collective groups we created so far. Each process will have an instance of `GroupManager`. Each process could belong to multiple collective groups. The membership information @@ -45,8 +44,7 @@ class GroupManager(object): self._group_name_map = {} def create_collective_group(self, backend, world_size, rank, group_name): - """ - The entry to create new collective groups and register in the manager. + """The entry to create new collective groups in the manager. Put the registration and the group information into the manager metadata as well. @@ -120,8 +118,7 @@ def init_collective_group(world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default"): - """ - Initialize a collective group inside an actor process. + """Initialize a collective group inside an actor process. Args: world_size (int): the total number of processed in the group. @@ -158,8 +155,7 @@ def destroy_collective_group(group_name: str = "default") -> None: def get_rank(group_name: str = "default") -> int: - """ - Return the rank of this process in the given group. + """Return the rank of this process in the given group. Args: group_name (str): the name of the group to query @@ -176,9 +172,8 @@ def get_rank(group_name: str = "default") -> int: return g.rank -def get_world_size(group_name="default") -> int: - """ - Return the size of the collective gropu with the given name. +def get_world_size(group_name: str = "default") -> int: + """Return the size of the collective gropu with the given name. Args: group_name: the name of the group to query @@ -195,9 +190,8 @@ def get_world_size(group_name="default") -> int: return g.world_size -def allreduce(tensor, group_name: str, op=types.ReduceOp.SUM): - """ - Collective allreduce the tensor across the group with name group_name. +def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): + """Collective allreduce the tensor across the group. Args: tensor: the tensor to be all-reduced on this process. @@ -214,9 +208,8 @@ def allreduce(tensor, group_name: str, op=types.ReduceOp.SUM): g.allreduce(tensor, opts) -def barrier(group_name): - """ - Barrier all processes in the collective group. +def barrier(group_name: str = "default"): + """Barrier all processes in the collective group. Args: group_name (str): the name of the group to barrier. @@ -228,6 +221,107 @@ def barrier(group_name): g.barrier() +def reduce(tensor, + dst_rank: int = 0, + group_name: str = "default", + op=types.ReduceOp.SUM): + """Reduce the tensor across the group to the destination rank. + + Args: + tensor: the tensor to be reduced on this process. + dst_rank: the rank of the destination process. + group_name: the collective group name to perform reduce. + op: The reduce operation. + + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + + # check dst rank + _check_rank_valid(g, dst_rank) + opts = types.ReduceOptions() + opts.reduceOp = op + opts.root_rank = dst_rank + g.reduce(tensor, opts) + + +def broadcast(tensor, src_rank: int = 0, group_name: str = "default"): + """Broadcast the tensor from a source process to all others. + + Args: + tensor: the tensor to be broadcasted (src) or received (destination). + src_rank: the rank of the source process. + group_name: he collective group name to perform broadcast. + + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + + # check src rank + _check_rank_valid(g, src_rank) + opts = types.BroadcastOptions() + opts.root_rank = src_rank + g.broadcast(tensor, opts) + + +def allgather(tensor_list: list, tensor, group_name: str = "default"): + """Allgather tensors from each process of the group into a list. + + Args: + tensor_list (list): the results, stored as a list of tensors. + tensor: the tensor (to be gathered) in the current process + group_name: the name of the collective group. + + Returns: + None + """ + _check_single_tensor_input(tensor) + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + if len(tensor_list) != g.world_size: + # Typically CLL lib requires len(tensor_list) >= world_size; + # Here we make it more strict: len(tensor_list) == world_size. + raise RuntimeError( + "The length of the tensor list operands to allgather " + "must not be equal to world_size.") + opts = types.AllGatherOptions() + g.allgather(tensor_list, tensor, opts) + + +def reducescatter(tensor, + tensor_list: list, + group_name: str = "default", + op=types.ReduceOp.SUM): + """Reducescatter a list of tensors across the group. + + Reduce the list of the tensors across each process in the group, then + scatter the reduced list of tensors -- one tensor for each process. + + Args: + tensor: the resulted tensor on this process. + tensor_list (list): The list of tensors to be reduced and scattered. + group_name (str): the name of the collective group. + op: The reduce operation. + + Returns: + None + """ + _check_single_tensor_input(tensor) + _check_tensor_list_input(tensor_list) + g = _check_and_get_group(group_name) + if len(tensor_list) != g.world_size: + raise RuntimeError( + "The length of the tensor list operands to reducescatter " + "must not be equal to world_size.") + opts = types.ReduceScatterOptions() + opts.reduceOp = op + g.reducescatter(tensor, tensor_list, opts) + + def _check_and_get_group(group_name): """Check the existence and return the group handle.""" _check_inside_actor() @@ -244,8 +338,6 @@ def _check_backend_availability(backend: types.Backend): if not mpi_available(): raise RuntimeError("MPI is not available.") elif backend == types.Backend.NCCL: - # expect some slowdown at the first call - # as I defer the import to invocation. if not nccl_available(): raise RuntimeError("NCCL is not available.") @@ -273,3 +365,22 @@ def _check_inside_actor(): else: raise RuntimeError("The collective APIs shall be only used inside " "a Ray actor or task.") + + +def _check_rank_valid(g, rank: int): + if rank < 0: + raise ValueError("rank '{}' is negative.".format(rank)) + if rank > g.world_size: + raise ValueError("rank '{}' is greater than world size " + "'{}'".format(rank, g.world_size)) + + +def _check_tensor_list_input(tensor_list): + """Check if the input is a list of supported tensor types.""" + if not isinstance(tensor_list, list): + raise RuntimeError("The input must be a list of tensors. " + "Got '{}'.".format(type(tensor_list))) + if not tensor_list: + raise RuntimeError("Got an empty list of tensors.") + for t in tensor_list: + _check_single_tensor_input(t) diff --git a/python/ray/util/collective/collective_group/base_collective_group.py b/python/ray/util/collective/collective_group/base_collective_group.py index a3f54fa26..81caf1a6b 100644 --- a/python/ray/util/collective/collective_group/base_collective_group.py +++ b/python/ray/util/collective/collective_group/base_collective_group.py @@ -2,13 +2,13 @@ from abc import ABCMeta from abc import abstractmethod -from ray.util.collective.types import AllReduceOptions, BarrierOptions +from ray.util.collective.types import AllReduceOptions, BarrierOptions, \ + ReduceOptions, AllGatherOptions, BroadcastOptions, ReduceScatterOptions class BaseGroup(metaclass=ABCMeta): def __init__(self, world_size, rank, group_name): - """ - Init the process group with basic information. + """Init the process group with basic information. Args: world_size (int): The total number of processes in the group. @@ -50,3 +50,25 @@ class BaseGroup(metaclass=ABCMeta): @abstractmethod def barrier(self, barrier_options=BarrierOptions()): raise NotImplementedError() + + @abstractmethod + def reduce(self, tensor, reduce_options=ReduceOptions()): + raise NotImplementedError() + + @abstractmethod + def allgather(self, + tensor_list, + tensor, + allgather_options=AllGatherOptions()): + raise NotImplementedError() + + @abstractmethod + def broadcast(self, tensor, broadcast_options=BroadcastOptions()): + raise NotImplementedError() + + @abstractmethod + def reducescatter(self, + tensor, + tensor_list, + reducescatter_options=ReduceScatterOptions()): + raise NotImplementedError() diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py index 31412b5a4..4341f8e67 100644 --- a/python/ray/util/collective/collective_group/nccl_collective_group.py +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -9,7 +9,8 @@ from ray.util.collective.collective_group import nccl_util from ray.util.collective.collective_group.base_collective_group \ import BaseGroup from ray.util.collective.types import AllReduceOptions, \ - BarrierOptions, Backend + BarrierOptions, Backend, ReduceOptions, BroadcastOptions, \ + AllGatherOptions, ReduceScatterOptions from ray.util.collective.const import get_nccl_store_name logger = logging.getLogger(__name__) @@ -21,8 +22,7 @@ logger = logging.getLogger(__name__) class Rendezvous: - """ - A rendezvous class for different actor/task processes to meet. + """A rendezvous class for different actor/task processes to meet. To initialize an NCCL collective communication group, different actors/tasks spawned in Ray in a collective group needs to meet @@ -42,8 +42,7 @@ class Rendezvous: self._store = None def meet(self, timeout_s=180): - """ - Meet at the named actor store. + """Meet at the named actor store. Args: timeout_s: timeout in seconds. @@ -80,8 +79,7 @@ class Rendezvous: return self._store def get_nccl_id(self, timeout_s=180): - """ - Get the NCCLUniqueID from the store through Ray. + """Get the NCCLUniqueID from the store through Ray. Args: timeout_s: timeout in seconds. @@ -132,10 +130,7 @@ class NCCLGroup(BaseGroup): self._barrier_tensor = cupy.array([1]) def _init_nccl_unique_id(self): - """ - Init the NCCL unique ID required for setting up NCCL communicator. - - """ + """Init the NCCLUniqueID required for creating NCCL communicators.""" self._nccl_uid = self._rendezvous.get_nccl_id() @property @@ -143,10 +138,7 @@ class NCCLGroup(BaseGroup): return self._nccl_uid def destroy_group(self): - """ - Destroy the group and release the NCCL communicators safely. - - """ + """Destroy the group and release the NCCL communicators safely.""" if self._nccl_comm is not None: self.barrier() # We also need a barrier call here. @@ -162,8 +154,7 @@ class NCCLGroup(BaseGroup): return Backend.NCCL def allreduce(self, tensor, allreduce_options=AllReduceOptions()): - """ - AllReduce a list of tensors following options. + """AllReduce the tensor across the collective group following options. Args: tensor: the tensor to be reduced, each tensor locates on a GPU @@ -186,8 +177,7 @@ class NCCLGroup(BaseGroup): comm.allReduce(ptr, ptr, n_elems, dtype, reduce_op, stream.ptr) def barrier(self, barrier_options=BarrierOptions()): - """ - Blocks until all processes reach this barrier. + """Blocks until all processes reach this barrier. Args: barrier_options: @@ -196,9 +186,108 @@ class NCCLGroup(BaseGroup): """ self.allreduce(self._barrier_tensor) - def _get_nccl_communicator(self): + def reduce(self, tensor, reduce_options=ReduceOptions()): + """Reduce tensor to a destination process following options. + + Args: + tensor: the tensor to be reduced. + reduce_options: reduce options + + Returns: + None """ - Create or use a cached NCCL communicator for the collective task. + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + reduce_op = nccl_util.get_nccl_reduce_op(reduce_options.reduceOp) + + # in-place reduce + comm.reduce(ptr, ptr, n_elems, dtype, reduce_op, + reduce_options.root_rank, stream.ptr) + + def broadcast(self, tensor, broadcast_options=BroadcastOptions()): + """Broadcast tensor to all other processes following options. + + Args: + tensor: the tensor to be broadcasted. + broadcast_options: broadcast options. + + Returns: + None + """ + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + # in-place broadcast + comm.broadcast(ptr, ptr, n_elems, dtype, broadcast_options.root_rank, + stream.ptr) + + def allgather(self, + tensor_list, + tensor, + allgather_options=AllGatherOptions()): + """Allgather tensors across the group into a list of tensors. + + Args: + tensor_list: the tensor list to store the results. + tensor: the tensor to be allgather-ed across the group. + allgather_options: allgather options. + + Returns: + None + """ + + _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + send_ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + flattened = _flatten_for_scatter_gather(tensor_list, copy=False) + recv_ptr = nccl_util.get_tensor_ptr(flattened) + comm.allGather(send_ptr, recv_ptr, n_elems, dtype, stream.ptr) + for i, t in enumerate(tensor_list): + nccl_util.copy_tensor(t, flattened[i]) + + def reducescatter(self, + tensor, + tensor_list, + reducescatter_options=ReduceScatterOptions()): + """Reducescatter a list of tensors across the group. + + Args: + tensor: the output after reducescatter (could be unspecified). + tensor_list: the list of tensor to be reduce and scattered. + reducescatter_options: reducescatter options. + + Returns: + None + """ + _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) + + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0]) + n_elems = nccl_util.get_tensor_n_elements(tensor_list[0]) + reduce_op = nccl_util.get_nccl_reduce_op( + reducescatter_options.reduceOp) + + # get the send_ptr + flattened = _flatten_for_scatter_gather(tensor_list, copy=True) + send_ptr = nccl_util.get_tensor_ptr(flattened) + recv_ptr = nccl_util.get_tensor_ptr(tensor) + comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op, + stream.ptr) + + def _get_nccl_communicator(self): + """Create or use a cached NCCL communicator for the collective task. """ # TODO(Hao): later change this to use device keys and query from cache. @@ -217,3 +306,47 @@ class NCCLGroup(BaseGroup): # def _collective_call(self, *args): # """Private method to encapsulate all collective calls""" # pass + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + """Flatten the tensor for gather/scatter operations. + + Args: + tensor_list: the list of tensors to be scattered/gathered. + copy: whether the copy the tensors in tensor_list into the buffer. + + Returns: + The flattened tensor buffer. + """ + if not tensor_list: + raise RuntimeError("Received an empty list.") + t = tensor_list[0] + # note we need a cupy dtype here. + dtype = nccl_util.get_cupy_tensor_dtype(t) + buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t) + buffer = cupy.empty(buffer_shape, dtype=dtype) + if copy: + for i, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(buffer[i], tensor) + return buffer + + +def _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list): + """Check the compatibility between tensor input and tensor list inputs.""" + if not tensor_list: + raise RuntimeError("Got empty list of tensors.") + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + shape = nccl_util.get_tensor_shape(tensor) + for t in tensor_list: + # check dtype + dt = nccl_util.get_nccl_tensor_dtype(t) + if dt != dtype: + raise RuntimeError("All tensor operands to scatter/gather must " + "have the same dtype. Got '{}' and '{}'" + "".format(dt, dtype)) + # Note: typically CCL libraries only requires they have the same + # number of elements; + # Here we make it more strict -- we require exact shape match. + if nccl_util.get_tensor_shape(t) != shape: + raise RuntimeError("All tensor operands to scatter/gather must " + "have the same shape.") diff --git a/python/ray/util/collective/collective_group/nccl_util.py b/python/ray/util/collective/collective_group/nccl_util.py index 4d2fc456f..da9ced35a 100644 --- a/python/ray/util/collective/collective_group/nccl_util.py +++ b/python/ray/util/collective/collective_group/nccl_util.py @@ -28,6 +28,7 @@ NUMPY_NCCL_DTYPE_MAP = { if torch_available(): import torch + import torch.utils.dlpack TORCH_NCCL_DTYPE_MAP = { torch.uint8: nccl.NCCL_UINT8, torch.float16: nccl.NCCL_FLOAT16, @@ -35,6 +36,13 @@ if torch_available(): torch.float64: nccl.NCCL_FLOAT64, } + TORCH_NUMPY_DTYPE_MAP = { + torch.uint8: numpy.uint8, + torch.float16: numpy.float16, + torch.float32: numpy.float32, + torch.float64: numpy.float64, + } + def get_nccl_build_version(): return get_build_version() @@ -49,8 +57,7 @@ def get_nccl_unique_id(): def create_nccl_communicator(world_size, nccl_unique_id, rank): - """ - Create an NCCL communicator using NCCL APIs. + """Create an NCCL communicator using NCCL APIs. Args: world_size (int): the number of processes of this communcator group. @@ -66,8 +73,7 @@ def create_nccl_communicator(world_size, nccl_unique_id, rank): def get_nccl_reduce_op(reduce_op): - """ - Map the reduce op to NCCL reduce op type. + """Map the reduce op to NCCL reduce op type. Args: reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX). @@ -87,8 +93,21 @@ def get_nccl_tensor_dtype(tensor): if torch_available(): if isinstance(tensor, torch.Tensor): return TORCH_NCCL_DTYPE_MAP[tensor.dtype] - raise ValueError("Unsupported tensor type. " - "Got: {}.".format(type(tensor))) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def get_cupy_tensor_dtype(tensor): + """Return the corresponded Cupy dtype given a tensor.""" + if isinstance(tensor, cupy.ndarray): + return tensor.dtype.type + if torch_available(): + if isinstance(tensor, torch.Tensor): + return TORCH_NUMPY_DTYPE_MAP[tensor.dtype] + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) def get_tensor_ptr(tensor): @@ -102,8 +121,9 @@ def get_tensor_ptr(tensor): if not tensor.is_cuda: raise RuntimeError("torch tensor must be on gpu.") return tensor.data_ptr() - raise ValueError("Unsupported tensor type. " - "Got: {}.".format(type(tensor))) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) def get_tensor_n_elements(tensor): @@ -113,5 +133,68 @@ def get_tensor_n_elements(tensor): if torch_available(): if isinstance(tensor, torch.Tensor): return torch.numel(tensor) - raise ValueError("Unsupported tensor type. " - "Got: {}.".format(type(tensor))) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def get_tensor_shape(tensor): + """Return the shape of the tensor as a list.""" + if isinstance(tensor, cupy.ndarray): + return list(tensor.shape) + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.size()) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def get_tensor_strides(tensor): + """Return the strides of the tensor as a list.""" + if isinstance(tensor, cupy.ndarray): + return [ + int(stride / tensor.dtype.itemsize) for stride in tensor.strides + ] + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.stride()) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def copy_tensor(dst_tensor, src_tensor): + """Copy the content from src_tensor to dst_tensor. + + Args: + dst_tensor: the tensor to copy from. + src_tensor: the tensor to copy to. + + Returns: + None + """ + copied = True + if isinstance(dst_tensor, cupy.ndarray) \ + and isinstance(src_tensor, cupy.ndarray): + cupy.copyto(dst_tensor, src_tensor) + elif torch_available(): + if isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, torch.Tensor): + dst_tensor.copy_(src_tensor) + elif isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, cupy.ndarray): + t = torch.utils.dlpack.from_dlpack(src_tensor.toDlpack()) + dst_tensor.copy_(t) + elif isinstance(dst_tensor, cupy.ndarray) and isinstance( + src_tensor, torch.Tensor): + t = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(src_tensor)) + cupy.copyto(dst_tensor, t) + else: + copied = False + else: + copied = False + if not copied: + raise ValueError("Unsupported tensor type. Got: {} and {}. Supported " + "GPU tensor types are: torch.Tensor, cupy.ndarray." + .format(type(dst_tensor), type(src_tensor))) diff --git a/python/ray/util/collective/const.py b/python/ray/util/collective/const.py index 6eded9c51..ebc48982d 100644 --- a/python/ray/util/collective/const.py +++ b/python/ray/util/collective/const.py @@ -7,8 +7,7 @@ import hashlib def get_nccl_store_name(group_name): - """ - Generate the unique name for the NCCLUniqueID store (named actor). + """Generate the unique name for the NCCLUniqueID store (named actor). Args: group_name (str): unique user name for the store. diff --git a/python/ray/util/collective/tests/distributed_tests/__init__.py b/python/ray/util/collective/tests/distributed_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py new file mode 100644 index 000000000..5a369c852 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py @@ -0,0 +1,133 @@ +"""Test the allgather API on a distributed Ray cluster.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_allgather_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size, tensor_backend): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if tensor_backend == "cupy": + assert (results[i][j] == cp.ones(array_size, dtype=cp.float32) + * (j + 1)).all() + else: + assert (results[i][j] == torch.ones( + array_size, dtype=torch.float32).cuda() * (j + 1)).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allgather_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(10, dtype=dtype) * (j + 1)).all() + + +@pytest.mark.parametrize("length", [0, 1, 3, 4, 7, 8]) +def test_unmatched_tensor_list_length(ray_start_distributed_2_nodes_4_gpus, + length): + world_size = 4 + actors, _ = create_collective_workers(world_size) + list_buffer = [cp.ones(10, dtype=cp.float32) for _ in range(length)] + ray.wait([a.set_list_buffer.remote(list_buffer) for a in actors]) + if length != world_size: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +@pytest.mark.parametrize("shape", [10, 20, [4, 5], [1, 3, 5, 7]]) +def test_unmatched_tensor_shape(ray_start_distributed_2_nodes_4_gpus, shape): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, array_size=10) + list_buffer = [cp.ones(shape, dtype=cp.float32) for _ in range(world_size)] + ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) + if shape != 10: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +def test_allgather_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if j % 2 == 0: + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + else: + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py new file mode 100644 index 000000000..35aae35b2 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py @@ -0,0 +1,139 @@ +"""Test the collective allreduice API on a distributed Ray cluster.""" +import pytest +import ray +from ray.util.collective.types import ReduceOp + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_allreduce_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, world_size): + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +def test_allreduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + + +def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus, + backend="nccl", + group_name="default"): + world_size = 4 + actors, _ = create_collective_workers(world_size) + + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + # destroy the group and try do work, should fail + ray.wait([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * world_size * world_size).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * world_size * world_size).all() + + +def test_allreduce_multiple_group(ray_start_distributed_2_nodes_4_gpus, + backend="nccl", + num_groups=5): + world_size = 4 + actors, _ = create_collective_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() + + +def test_allreduce_different_op(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_allreduce.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 120).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 120).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 5).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 5).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allreduce_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() + + +def test_allreduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + # import torch + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * world_size).all() + + ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) + ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py new file mode 100644 index 000000000..0f17b79ba --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py @@ -0,0 +1,135 @@ +"""Test the collective group APIs.""" +import pytest +import ray +from random import shuffle + +from ray.util.collective.tests.util import Worker, \ + create_collective_workers + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_distributed_2_nodes_4_gpus, world_size, + group_name): + actors, results = create_collective_workers(world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_init_multiple_groups(ray_start_distributed_2_nodes_4_gpus, + world_size): + num_groups = 1 + actors = [Worker.remote() for _ in range(world_size)] + for i in range(num_groups): + group_name = str(i) + init_results = ray.get([ + actor.init_group.remote(world_size, i, group_name=group_name) + for i, actor in enumerate(actors) + ]) + for j in range(world_size): + assert init_results[j] + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_get_rank(ray_start_distributed_2_nodes_4_gpus, world_size): + actors, _ = create_collective_workers(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, and different + # orders of ranks. + new_group_name = "default2" + ranks = list(range(world_size)) + shuffle(ranks) + _ = ray.get([ + actor.init_group.remote( + world_size, ranks[i], group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == ranks[0] + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == ranks[1] + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_get_world_size(ray_start_distributed_2_nodes_4_gpus, world_size): + actors, _ = create_collective_workers(world_size) + actor0_world_size = ray.get(actors[0].report_world_size.remote()) + actor1_world_size = ray.get(actors[1].report_world_size.remote()) + assert actor0_world_size == actor1_world_size == world_size + + +def test_availability(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_mpi_availability = ray.get( + actors[0].report_mpi_availability.remote()) + assert not actor0_mpi_availability + + +def test_is_group_initialized(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + for i in [2, 3]: + ray.wait([actors[i].destroy_group.remote("default")]) + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py new file mode 100644 index 000000000..408ebce76 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py @@ -0,0 +1,67 @@ +"""Test the broadcast API.""" +import pytest +import cupy as cp +import ray + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("src_rank", [0, 1, 2, 3]) +def test_broadcast_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, src_rank): + world_size = 4 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + ray.wait([ + a.set_buffer.remote(cp.ones((10, ), dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_broadcast.remote(group_name=group_name, src_rank=src_rank) + for a in actors + ]) + for i in range(world_size): + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("src_rank", [0, 1, 2, 3]) +def test_broadcast_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size, src_rank): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + for i in range(world_size): + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_torch_cupy(ray_start_distributed_2_nodes_4_gpus, src_rank): + import torch + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait( + [actors[1].set_buffer.remote(torch.ones(10, ).cuda() * world_size)]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + if src_rank == 0: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_broadcast_invalid_rank(ray_start_single_node_2_gpus, src_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_broadcast.remote(src_rank=src_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py new file mode 100644 index 000000000..9646f8d12 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py @@ -0,0 +1,119 @@ +"""Test the reduce API.""" +import pytest +import cupy as cp +import ray +from ray.util.collective.types import ReduceOp + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +def test_reduce_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, dst_rank): + world_size = 4 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get( + [a.do_reduce.remote(group_name, dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((10, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +def test_reduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size, dst_rank): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((array_size, ), + dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +def test_reduce_different_op(ray_start_distributed_2_nodes_4_gpus, dst_rank): + world_size = 4 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.PRODUCT) + for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * 120).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MIN) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 2).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MAX) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 5).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus, dst_rank): + import torch + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + if dst_rank == 0: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_reduce_invalid_rank(ray_start_distributed_2_nodes_4_gpus, dst_rank=7): + world_size = 4 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py new file mode 100644 index 000000000..63230402a --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py @@ -0,0 +1,128 @@ +"""Test the collective reducescatter API on a distributed Ray cluster.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_reducescatter_different_array_size( + ray_start_distributed_2_nodes_4_gpus, array_size, tensor_backend): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if tensor_backend == "cupy": + assert (results[i] == cp.ones(array_size, dtype=cp.float32) * + world_size).all() + else: + assert (results[i] == torch.ones( + array_size, dtype=torch.float32).cuda() * world_size).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_reducescatter_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i] == cp.ones(10, dtype=dtype) * world_size).all() + + +def test_reducescatter_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert (results[i] == torch.ones(shape, dtype=torch.float32).cuda() * + world_size).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert ( + results[i] == cp.ones(shape, dtype=cp.float32) * world_size).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + # mixed case + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + else: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_allgather.py b/python/ray/util/collective/tests/test_allgather.py new file mode 100644 index 000000000..33cf9a6d0 --- /dev/null +++ b/python/ray/util/collective/tests/test_allgather.py @@ -0,0 +1,131 @@ +"""Test the collective allgather API.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_allgather_different_array_size(ray_start_single_node_2_gpus, + array_size, tensor_backend): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if tensor_backend == "cupy": + assert (results[i][j] == cp.ones(array_size, dtype=cp.float32) + * (j + 1)).all() + else: + assert (results[i][j] == torch.ones( + array_size, dtype=torch.float32).cuda() * (j + 1)).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allgather_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(10, dtype=dtype) * (j + 1)).all() + + +@pytest.mark.parametrize("length", [0, 1, 2, 3]) +def test_unmatched_tensor_list_length(ray_start_single_node_2_gpus, length): + world_size = 2 + actors, _ = create_collective_workers(world_size) + list_buffer = [cp.ones(10, dtype=cp.float32) for _ in range(length)] + ray.wait([a.set_list_buffer.remote(list_buffer) for a in actors]) + if length != world_size: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +@pytest.mark.parametrize("shape", [10, 20, [4, 5], [1, 3, 5, 7]]) +def test_unmatched_tensor_shape(ray_start_single_node_2_gpus, shape): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, array_size=10) + list_buffer = [cp.ones(shape, dtype=cp.float32) for _ in range(world_size)] + ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) + if shape != 10: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +def test_allgather_torch_cupy(ray_start_single_node_2_gpus): + world_size = 2 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if j % 2 == 0: + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + else: + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_allreduce.py b/python/ray/util/collective/tests/test_allreduce.py new file mode 100644 index 000000000..1fbdf526b --- /dev/null +++ b/python/ray/util/collective/tests/test_allreduce.py @@ -0,0 +1,143 @@ +"""Test the collective allreduice API.""" +import pytest +import ray +from ray.util.collective.types import ReduceOp + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_allreduce_different_name(ray_start_single_node_2_gpus, group_name): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +def test_allreduce_different_array_size(ray_start_single_node_2_gpus, + array_size): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + + +def test_allreduce_destroy(ray_start_single_node_2_gpus, + backend="nccl", + group_name="default"): + world_size = 2 + actors, _ = create_collective_workers(world_size) + + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + # destroy the group and try do work, should fail + ray.wait([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * world_size * 2).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * world_size * 2).all() + + +def test_allreduce_multiple_group(ray_start_single_node_2_gpus, + backend="nccl", + num_groups=5): + world_size = 2 + actors, _ = create_collective_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() + + +def test_allreduce_different_op(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_allreduce.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 6).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 6).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 3).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 3).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allreduce_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() + + +def test_allreduce_torch_cupy(ray_start_single_node_2_gpus): + # import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * world_size).all() + + ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) + ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_basic_apis.py b/python/ray/util/collective/tests/test_basic_apis.py new file mode 100644 index 000000000..8c23442a3 --- /dev/null +++ b/python/ray/util/collective/tests/test_basic_apis.py @@ -0,0 +1,127 @@ +"""Test the collective group APIs.""" +import pytest +import ray + +from ray.util.collective.tests.util import Worker, \ + create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_single_node_2_gpus, group_name): + world_size = 2 + actors, results = create_collective_workers(world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +def test_init_multiple_groups(ray_start_single_node_2_gpus): + world_size = 2 + num_groups = 10 + actors = [Worker.remote() for i in range(world_size)] + for i in range(num_groups): + group_name = str(i) + init_results = ray.get([ + actor.init_group.remote(world_size, i, group_name=group_name) + for i, actor in enumerate(actors) + ]) + for j in range(world_size): + assert init_results[j] + + +def test_get_rank(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, + # and different order of ranks. + new_group_name = "default2" + _ = ray.get([ + actor.init_group.remote( + world_size, world_size - 1 - i, group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == 1 + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == 0 + + +def test_get_world_size(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + actor0_world_size = ray.get(actors[0].report_world_size.remote()) + actor1_world_size = ray.get(actors[1].report_world_size.remote()) + assert actor0_world_size == actor1_world_size == world_size + + +def test_availability(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_mpi_availability = ray.get( + actors[0].report_mpi_availability.remote()) + assert not actor0_mpi_availability + + +def test_is_group_initialized(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_broadcast.py b/python/ray/util/collective/tests/test_broadcast.py new file mode 100644 index 000000000..3d62b6d2e --- /dev/null +++ b/python/ray/util/collective/tests/test_broadcast.py @@ -0,0 +1,67 @@ +"""Test the broadcast API.""" +import pytest +import cupy as cp +import ray + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_different_name(ray_start_single_node_2_gpus, group_name, + src_rank): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + ray.wait([ + a.set_buffer.remote(cp.ones((10, ), dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_broadcast.remote(group_name=group_name, src_rank=src_rank) + for a in actors + ]) + for i in range(world_size): + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_different_array_size(ray_start_single_node_2_gpus, + array_size, src_rank): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + for i in range(world_size): + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_torch_cupy(ray_start_single_node_2_gpus, src_rank): + import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait( + [actors[1].set_buffer.remote(torch.ones(10, ).cuda() * world_size)]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + if src_rank == 0: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_broadcast_invalid_rank(ray_start_single_node_2_gpus, src_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_broadcast.remote(src_rank=src_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py b/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py deleted file mode 100644 index c35e48b9a..000000000 --- a/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Test the collective group APIs.""" -from random import shuffle -import pytest -import ray -from ray.util.collective.types import ReduceOp - -import cupy as cp -import torch - -from .util import Worker - - -def get_actors_group(num_workers=2, group_name="default", backend="nccl"): - actors = [Worker.remote() for i in range(num_workers)] - world_size = num_workers - init_results = ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - return actors, init_results - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -def test_init_two_actors(ray_start_distributed_2_nodes_4_gpus, world_size, - group_name): - actors, results = get_actors_group(world_size, group_name) - for i in range(world_size): - assert (results[i]) - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_init_multiple_groups(ray_start_distributed_2_nodes_4_gpus, - world_size): - num_groups = 1 - actors = [Worker.remote() for _ in range(world_size)] - for i in range(num_groups): - group_name = str(i) - init_results = ray.get([ - actor.init_group.remote(world_size, i, group_name=group_name) - for i, actor in enumerate(actors) - ]) - for j in range(world_size): - assert init_results[j] - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_get_rank(ray_start_distributed_2_nodes_4_gpus, world_size): - actors, _ = get_actors_group(world_size) - actor0_rank = ray.get(actors[0].report_rank.remote()) - assert actor0_rank == 0 - actor1_rank = ray.get(actors[1].report_rank.remote()) - assert actor1_rank == 1 - - # create a second group with a different name, and different - # orders of ranks. - new_group_name = "default2" - ranks = list(range(world_size)) - shuffle(ranks) - _ = ray.get([ - actor.init_group.remote( - world_size, ranks[i], group_name=new_group_name) - for i, actor in enumerate(actors) - ]) - actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) - assert actor0_rank == ranks[0] - actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) - assert actor1_rank == ranks[1] - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_get_world_size(ray_start_distributed_2_nodes_4_gpus, world_size): - actors, _ = get_actors_group(world_size) - actor0_world_size = ray.get(actors[0].report_world_size.remote()) - actor1_world_size = ray.get(actors[1].report_world_size.remote()) - assert actor0_world_size == actor1_world_size == world_size - - -def test_availability(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - actor0_nccl_availability = ray.get( - actors[0].report_nccl_availability.remote()) - assert actor0_nccl_availability - actor0_mpi_availability = ray.get( - actors[0].report_mpi_availability.remote()) - assert not actor0_mpi_availability - - -def test_is_group_initialized(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - # check group is_init - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("random")) - assert not actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("123")) - assert not actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - actor1_is_init = ray.get( - actors[0].report_is_group_initialized.remote("456")) - assert not actor1_is_init - - -def test_destroy_group(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - # Now destroy the group at actor0 - ray.wait([actors[0].destroy_group.remote()]) - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert not actor0_is_init - - # should go well as the group `random` does not exist at all - ray.wait([actors[0].destroy_group.remote("random")]) - - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("random")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("default")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert not actor1_is_init - for i in [2, 3]: - ray.wait([actors[i].destroy_group.remote("default")]) - - # Now reconstruct the group using the same name - init_results = ray.get([ - actor.init_group.remote(world_size, i) - for i, actor in enumerate(actors) - ]) - for i in range(world_size): - assert init_results[i] - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - - -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_allreduce_different_name(ray_start_distributed_2_nodes_4_gpus, - group_name, world_size): - actors, _ = get_actors_group(num_workers=world_size, group_name=group_name) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - -@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) -def test_allreduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, - array_size): - world_size = 4 - actors, _ = get_actors_group(world_size) - ray.wait([ - a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) - for a in actors - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - - -def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus, - backend="nccl", - group_name="default"): - world_size = 4 - actors, _ = get_actors_group(world_size) - - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - # destroy the group and try do work, should fail - ray.wait([a.destroy_group.remote() for a in actors]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - # reinit the same group and all reduce - ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * world_size * world_size).all() - assert (results[1] == cp.ones( - (10, ), dtype=cp.float32) * world_size * world_size).all() - - -def test_allreduce_multiple_group(ray_start_distributed_2_nodes_4_gpus, - backend="nccl", - num_groups=5): - world_size = 4 - actors, _ = get_actors_group(world_size) - for group_name in range(1, num_groups): - ray.get([ - actor.init_group.remote(world_size, i, backend, str(group_name)) - for i, actor in enumerate(actors) - ]) - for i in range(num_groups): - group_name = "default" if i == 0 else str(i) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() - - -def test_allreduce_different_op(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - - # check product - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.PRODUCT) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 120).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 120).all() - - # check min - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MIN) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() - - # check max - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MAX) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 5).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 5).all() - - -@pytest.mark.parametrize("dtype", - [cp.uint8, cp.float16, cp.float32, cp.float64]) -def test_allreduce_different_dtype(ray_start_distributed_2_nodes_4_gpus, - dtype): - world_size = 4 - actors, _ = get_actors_group(world_size) - ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() - - -def test_allreduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus): - # import torch - world_size = 4 - actors, _ = get_actors_group(world_size) - ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, )) * world_size).all() - - ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) - ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py b/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py deleted file mode 100644 index 267375e29..000000000 --- a/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Test the collective group APIs.""" -import pytest -import ray -from ray.util.collective.types import ReduceOp - -import cupy as cp -import torch - -from .util import Worker - - -def get_actors_group(num_workers=2, group_name="default", backend="nccl"): - actors = [Worker.remote() for _ in range(num_workers)] - world_size = num_workers - init_results = ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - return actors, init_results - - -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -def test_init_two_actors(ray_start_single_node_2_gpus, group_name): - world_size = 2 - actors, results = get_actors_group(world_size, group_name) - for i in range(world_size): - assert (results[i]) - - -def test_init_multiple_groups(ray_start_single_node_2_gpus): - world_size = 2 - num_groups = 10 - actors = [Worker.remote() for i in range(world_size)] - for i in range(num_groups): - group_name = str(i) - init_results = ray.get([ - actor.init_group.remote(world_size, i, group_name=group_name) - for i, actor in enumerate(actors) - ]) - for j in range(world_size): - assert init_results[j] - - -def test_get_rank(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - actor0_rank = ray.get(actors[0].report_rank.remote()) - assert actor0_rank == 0 - actor1_rank = ray.get(actors[1].report_rank.remote()) - assert actor1_rank == 1 - - # create a second group with a different name, - # and different order of ranks. - new_group_name = "default2" - _ = ray.get([ - actor.init_group.remote( - world_size, world_size - 1 - i, group_name=new_group_name) - for i, actor in enumerate(actors) - ]) - actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) - assert actor0_rank == 1 - actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) - assert actor1_rank == 0 - - -def test_get_world_size(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - actor0_world_size = ray.get(actors[0].report_world_size.remote()) - actor1_world_size = ray.get(actors[1].report_world_size.remote()) - assert actor0_world_size == actor1_world_size == world_size - - -def test_availability(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - actor0_nccl_availability = ray.get( - actors[0].report_nccl_availability.remote()) - assert actor0_nccl_availability - actor0_mpi_availability = ray.get( - actors[0].report_mpi_availability.remote()) - assert not actor0_mpi_availability - - -def test_is_group_initialized(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - # check group is_init - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("random")) - assert not actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("123")) - assert not actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - actor1_is_init = ray.get( - actors[0].report_is_group_initialized.remote("456")) - assert not actor1_is_init - - -def test_destroy_group(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - # Now destroy the group at actor0 - ray.wait([actors[0].destroy_group.remote()]) - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert not actor0_is_init - - # should go well as the group `random` does not exist at all - ray.wait([actors[0].destroy_group.remote("random")]) - - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("random")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("default")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert not actor1_is_init - - # Now reconstruct the group using the same name - init_results = ray.get([ - actor.init_group.remote(world_size, i) - for i, actor in enumerate(actors) - ]) - for i in range(world_size): - assert init_results[i] - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - - -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -# @pytest.mark.parametrize("group_name", ['123?34!']) -def test_allreduce_different_name(ray_start_single_node_2_gpus, group_name): - world_size = 2 - actors, _ = get_actors_group(num_workers=world_size, group_name=group_name) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - -@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) -def test_allreduce_different_array_size(ray_start_single_node_2_gpus, - array_size): - world_size = 2 - actors, _ = get_actors_group(world_size) - ray.wait([ - a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) - for a in actors - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - - -def test_allreduce_destroy(ray_start_single_node_2_gpus, - backend="nccl", - group_name="default"): - world_size = 2 - actors, _ = get_actors_group(world_size) - - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - # destroy the group and try do work, should fail - ray.wait([a.destroy_group.remote() for a in actors]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - # reinit the same group and all reduce - ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * world_size * 2).all() - assert (results[1] == cp.ones( - (10, ), dtype=cp.float32) * world_size * 2).all() - - -def test_allreduce_multiple_group(ray_start_single_node_2_gpus, - backend="nccl", - num_groups=5): - world_size = 2 - actors, _ = get_actors_group(world_size) - for group_name in range(1, num_groups): - ray.get([ - actor.init_group.remote(world_size, i, backend, str(group_name)) - for i, actor in enumerate(actors) - ]) - for i in range(num_groups): - group_name = "default" if i == 0 else str(i) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() - - -def test_allreduce_different_op(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - - # check product - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.PRODUCT) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 6).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 6).all() - - # check min - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MIN) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() - - # check max - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MAX) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 3).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 3).all() - - -@pytest.mark.parametrize("dtype", - [cp.uint8, cp.float16, cp.float32, cp.float64]) -def test_allreduce_different_dtype(ray_start_single_node_2_gpus, dtype): - world_size = 2 - actors, _ = get_actors_group(world_size) - ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() - - -def test_allreduce_torch_cupy(ray_start_single_node_2_gpus): - # import torch - world_size = 2 - actors, _ = get_actors_group(world_size) - ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, )) * world_size).all() - - ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) - ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_reduce.py b/python/ray/util/collective/tests/test_reduce.py new file mode 100644 index 000000000..89063620c --- /dev/null +++ b/python/ray/util/collective/tests/test_reduce.py @@ -0,0 +1,143 @@ +"""Test the reduce API.""" +import pytest +import cupy as cp +import ray +from ray.util.collective.types import ReduceOp + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_different_name(ray_start_single_node_2_gpus, group_name, + dst_rank): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get( + [a.do_reduce.remote(group_name, dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((10, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_different_array_size(ray_start_single_node_2_gpus, array_size, + dst_rank): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((array_size, ), + dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_multiple_group(ray_start_single_node_2_gpus, + dst_rank, + num_groups=5): + world_size = 2 + actors, _ = create_collective_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, "nccl", str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, group_name=group_name) + for a in actors + ]) + for j in range(world_size): + if j == dst_rank: + assert (results[j] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + else: + assert (results[j] == cp.ones((10, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_different_op(ray_start_single_node_2_gpus, dst_rank): + world_size = 2 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.PRODUCT) + for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 6).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MIN) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 2).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MAX) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 3).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_torch_cupy(ray_start_single_node_2_gpus, dst_rank): + import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + if dst_rank == 0: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_reduce_invalid_rank(ray_start_single_node_2_gpus, dst_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/test_reducescatter.py b/python/ray/util/collective/tests/test_reducescatter.py new file mode 100644 index 000000000..4b1322ed4 --- /dev/null +++ b/python/ray/util/collective/tests/test_reducescatter.py @@ -0,0 +1,127 @@ +"""Test the collective reducescatter API.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_reducescatter_different_array_size(ray_start_single_node_2_gpus, + array_size, tensor_backend): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if tensor_backend == "cupy": + assert (results[i] == cp.ones(array_size, dtype=cp.float32) * + world_size).all() + else: + assert (results[i] == torch.ones( + array_size, dtype=torch.float32).cuda() * world_size).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_reducescatter_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i] == cp.ones(10, dtype=dtype) * world_size).all() + + +def test_reducescatter_torch_cupy(ray_start_single_node_2_gpus): + world_size = 2 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert (results[i] == torch.ones(shape, dtype=torch.float32).cuda() * + world_size).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert ( + results[i] == cp.ones(shape, dtype=cp.float32) * world_size).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + # mixed case + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + else: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/util.py b/python/ray/util/collective/tests/util.py index d59294d3f..3cee4de59 100644 --- a/python/ray/util/collective/tests/util.py +++ b/python/ray/util/collective/tests/util.py @@ -4,11 +4,17 @@ import ray import ray.util.collective as col from ray.util.collective.types import Backend, ReduceOp +import torch + @ray.remote(num_gpus=1) class Worker: def __init__(self): self.buffer = cp.ones((10, ), dtype=cp.float32) + self.list_buffer = [ + cp.ones((10, ), dtype=cp.float32), + cp.ones((10, ), dtype=cp.float32) + ] def init_group(self, world_size, @@ -22,10 +28,30 @@ class Worker: self.buffer = data return self.buffer - def do_work(self, group_name="default", op=ReduceOp.SUM): + def set_list_buffer(self, list_of_arrays): + self.list_buffer = list_of_arrays + return self.list_buffer + + def do_allreduce(self, group_name="default", op=ReduceOp.SUM): col.allreduce(self.buffer, group_name, op) return self.buffer + def do_reduce(self, group_name="default", dst_rank=0, op=ReduceOp.SUM): + col.reduce(self.buffer, dst_rank, group_name, op) + return self.buffer + + def do_broadcast(self, group_name="default", src_rank=0): + col.broadcast(self.buffer, src_rank, group_name) + return self.buffer + + def do_allgather(self, group_name="default"): + col.allgather(self.list_buffer, self.buffer, group_name) + return self.list_buffer + + def do_reducescatter(self, group_name="default", op=ReduceOp.SUM): + col.reducescatter(self.buffer, self.list_buffer, group_name, op) + return self.buffer + def destroy_group(self, group_name="default"): col.destroy_collective_group(group_name) return True @@ -49,3 +75,42 @@ class Worker: def report_is_group_initialized(self, group_name="default"): is_init = col.is_group_initialized(group_name) return is_init + + +def create_collective_workers(num_workers=2, + group_name="default", + backend="nccl"): + actors = [Worker.remote() for _ in range(num_workers)] + world_size = num_workers + init_results = ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + return actors, init_results + + +def init_tensors_for_gather_scatter(actors, + array_size=10, + dtype=cp.float32, + tensor_backend="cupy"): + world_size = len(actors) + for i, a in enumerate(actors): + if tensor_backend == "cupy": + t = cp.ones(array_size, dtype=dtype) * (i + 1) + elif tensor_backend == "torch": + t = torch.ones(array_size, dtype=torch.float32).cuda() * (i + 1) + else: + raise RuntimeError("Unsupported tensor backend.") + ray.wait([a.set_buffer.remote(t)]) + if tensor_backend == "cupy": + list_buffer = [ + cp.ones(array_size, dtype=dtype) for _ in range(world_size) + ] + elif tensor_backend == "torch": + list_buffer = [ + torch.ones(array_size, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + else: + raise RuntimeError("Unsupported tensor backend.") + ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index ef037373a..be92b98f2 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -50,15 +50,46 @@ class ReduceOp(Enum): MAX = 3 -unset_timeout = timedelta(milliseconds=-1) +unset_timeout_ms = timedelta(milliseconds=-1) @dataclass class AllReduceOptions: reduceOp = ReduceOp.SUM - timeout = unset_timeout + timeout_ms = unset_timeout_ms @dataclass class BarrierOptions: - timeout = unset_timeout + timeout_ms = unset_timeout_ms + + +@dataclass +class ReduceOptions: + reduceOp = ReduceOp.SUM + root_rank = 0 + timeout_ms = unset_timeout_ms + + +@dataclass +class AllGatherOptions: + timeout_ms = unset_timeout_ms + + +# +# @dataclass +# class GatherOptions: +# root_rank = 0 +# timeout = unset_timeout + + +@dataclass +class BroadcastOptions: + root_rank = 0 + timeout_ms = unset_timeout_ms + + +@dataclass +class ReduceScatterOptions: + reduceOp = ReduceOp.SUM + timeout_ms = unset_timeout_ms From 8b4b4bf0a2ae0f1c08095f1d23a74698528c52e0 Mon Sep 17 00:00:00 2001 From: architkulkarni Date: Mon, 21 Dec 2020 13:34:15 -0800 Subject: [PATCH 59/88] [Serve] Migrate from Flask.Request to Starlette Request (#12852) --- doc/source/serve/faq.rst | 37 ++++----- doc/source/serve/index.rst | 3 + doc/source/serve/key-concepts.rst | 8 +- doc/source/serve/package-ref.rst | 2 +- doc/source/serve/tutorials/batch.rst | 10 +-- .../serve/examples/doc/quickstart_class.py | 2 +- .../serve/examples/doc/quickstart_function.py | 4 +- .../examples/doc/snippet_model_composition.py | 10 +-- .../ray/serve/examples/doc/tutorial_batch.py | 8 +- .../ray/serve/examples/doc/tutorial_deploy.py | 12 +-- .../serve/examples/doc/tutorial_pytorch.py | 4 +- .../serve/examples/doc/tutorial_sklearn.py | 6 +- .../serve/examples/doc/tutorial_tensorflow.py | 4 +- python/ray/serve/examples/echo.py | 4 +- python/ray/serve/examples/echo_actor.py | 7 +- python/ray/serve/examples/echo_actor_batch.py | 9 ++- python/ray/serve/examples/echo_batching.py | 2 +- python/ray/serve/examples/echo_full.py | 6 +- python/ray/serve/handle.py | 6 +- python/ray/serve/http_util.py | 79 ++++--------------- python/ray/serve/tests/test_api.py | 30 +++++-- python/ray/serve/tests/test_backend_worker.py | 6 +- python/ray/serve/tests/test_handle.py | 16 ++-- python/ray/serve/tests/test_persistence.py | 2 +- python/ray/serve/tests/test_regression.py | 4 +- python/ray/serve/utils.py | 33 ++++---- 26 files changed, 140 insertions(+), 174 deletions(-) diff --git a/doc/source/serve/faq.rst b/doc/source/serve/faq.rst index 4f6791c5b..451307ce3 100644 --- a/doc/source/serve/faq.rst +++ b/doc/source/serve/faq.rst @@ -117,7 +117,7 @@ policies `, finding the next available replica, and batching requests together. When the request arrives in the model, you can access the data similarly to how -you would with HTTP request. Here are some examples how ServeRequest mirrors Flask.Request: +you would with HTTP request. Here are some examples how ServeRequest mirrors Starlette.Request: .. list-table:: :header-rows: 1 @@ -125,25 +125,25 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla * - HTTP - ServeHandle - | Request - | (Flask.Request and ServeRequest) + | (Starlette.Request and ServeRequest) * - ``requests.get(..., headers={...})`` - ``handle.options(http_headers={...})`` - ``request.headers`` * - ``requests.post(...)`` - ``handle.options(http_method="POST")`` - - ``requests.method`` - * - ``request.get(..., json={...})`` + - ``request.method`` + * - ``requests.get(..., json={...})`` - ``handle.remote({...})`` - - ``request.json`` - * - ``request.get(..., form={...})`` + - ``await request.json()`` + * - ``requests.get(..., form={...})`` - ``handle.remote({...})`` - - ``request.form`` - * - ``request.get(..., params={"a":"b"})`` + - ``await request.form()`` + * - ``requests.get(..., params={"a":"b"})`` - ``handle.remote(a="b")`` - - ``request.args`` - * - ``request.get(..., data="long string")`` + - ``request.query_params`` + * - ``requests.get(..., data="long string")`` - ``handle.remote("long string")`` - - ``request.data`` + - ``await request.body()`` * - ``N/A`` - ``handle.remote(python_object)`` - ``request.data`` @@ -157,9 +157,9 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla .. code-block:: python - import flask + import starlette.requests - if isinstance(request, flask.Request): + if isinstance(request, starlette.requests.Request): print("Request coming from web!") elif isinstance(request, ServeRequest): print("Request coming from Python!") @@ -170,10 +170,10 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla .. code-block:: python - handle.remote(flask_request) + handle.remote(starlette_request) In this case, Serve will `not` wrap it in ServeRequest. You can directly - process the request as a ``flask.Request``. + process the request as a ``starlette.requests.Request``. How fast is Ray Serve? ---------------------- @@ -187,13 +187,6 @@ You can checkout our `microbenchmark instruction `_ as our web server, -alongside with the power of Python asyncio. -**Flask is ONLY the request object that we are using, Uvicorn (not flask) provides the webserver.** - Can I use asyncio along with Ray Serve? --------------------------------------- Yes! You can make your servable methods ``async def`` and Serve will run them diff --git a/doc/source/serve/index.rst b/doc/source/serve/index.rst index b9c0d1497..af64475a3 100644 --- a/doc/source/serve/index.rst +++ b/doc/source/serve/index.rst @@ -33,6 +33,9 @@ Since Serve is built on Ray, it also allows you to scale to many machines, in yo If you want to try out Serve, join our `community slack `_ and discuss in the #serve channel. +.. note:: + Starting with Ray version 1.3.0, Ray Serve backends must take in a Starlette Request object instead of a Flask Request object. + See the `migration guide `_ for details. Installation ============ diff --git a/doc/source/serve/key-concepts.rst b/doc/source/serve/key-concepts.rst index deada7233..f215ad4b5 100644 --- a/doc/source/serve/key-concepts.rst +++ b/doc/source/serve/key-concepts.rst @@ -19,10 +19,8 @@ Backends Backends define the implementation of your business logic or models that will handle requests when queries come in to :ref:`serve-endpoint`. In order to support seamless scalability backends can have many replicas, which are individual processes running in the Ray cluster to handle requests. To define a backend, first you must define the "handler" or the business logic you'd like to respond with. -The handler should take as input a `Flask Request object `_. -The handler should return any JSON-serializable object as output. For a more customizable response type, the handler may return a +The handler should take as input a `Starlette Request object `_ and return any JSON-serializable object as output. For a more customizable response type, the handler may return a `Starlette Response object `_. -In the future, Ray Serve will support `Starlette Request objects `_ as input as well. A backend is defined using :mod:`client.create_backend `, and the implementation can be defined as either a function or a class. Use a function when your response is stateless and a class when you might need to maintain some state (like a model). @@ -32,7 +30,7 @@ A backend consists of a number of *replicas*, which are individual copies of the .. code-block:: python - def handle_request(flask_request): + def handle_request(starlette_request): return "hello world" class RequestHandler: @@ -40,7 +38,7 @@ A backend consists of a number of *replicas*, which are individual copies of the def __init__(self, msg): self.msg = msg - def __call__(self, flask_request): + def __call__(self, starlette_request): return self.msg client.create_backend("simple_backend", handle_request) diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 5a9e947ff..e64794500 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -23,7 +23,7 @@ Handle API :members: remote, options When calling from Python, the backend implementation will receive ``ServeRequest`` -objects instead of Flask requests. +objects instead of Starlette requests. .. autoclass:: ray.serve.utils.ServeRequest :members: diff --git a/doc/source/serve/tutorials/batch.rst b/doc/source/serve/tutorials/batch.rst index e9a74f94f..48fbc5344 100644 --- a/doc/source/serve/tutorials/batch.rst +++ b/doc/source/serve/tutorials/batch.rst @@ -30,13 +30,13 @@ You can use the ``@serve.accept_batch`` decorator to annotate a function or a cl This annotation is needed because batched backends have different APIs compared to single request backends. In a batched backend, the inputs are a list of values. -For single query backend, the input type is a single Flask request or +For single query backend, the input type is a single Starlette request or :mod:`ServeRequest `: .. code-block:: python def single_request( - request: Union[Flask.Request, ServeRequest], + request: Union[starlette.requests.Request, ServeRequest], ): pass @@ -47,7 +47,7 @@ types: @serve.accept_batch def batched_request( - request: List[Union[Flask.Request, ServeRequest]], + request: List[Union[starlette.requests.Request, ServeRequest]], ): pass @@ -84,8 +84,8 @@ Ray Serve was able to evaluate them in batches. What if you want to evaluate a whole batch in Python? Ray Serve allows you to send queries via the Python API. A batch of queries can either come from the web server -or the Python API. Requests coming from the Python API will have the similar API -as Flask.Request. See more on the API :ref:`here`. +or the Python API. Requests coming from the Python API will have a similar API +to Starlette Request. See more on the API :ref:`here`. .. literalinclude:: ../../../../python/ray/serve/examples/doc/tutorial_batch.py :start-after: __doc_define_servable_v1_begin__ diff --git a/python/ray/serve/examples/doc/quickstart_class.py b/python/ray/serve/examples/doc/quickstart_class.py index 6c6ba2808..d4238ea6a 100644 --- a/python/ray/serve/examples/doc/quickstart_class.py +++ b/python/ray/serve/examples/doc/quickstart_class.py @@ -10,7 +10,7 @@ class Counter: def __init__(self): self.count = 0 - def __call__(self, flask_request): + def __call__(self, starlette_request): self.count += 1 return {"current_counter": self.count} diff --git a/python/ray/serve/examples/doc/quickstart_function.py b/python/ray/serve/examples/doc/quickstart_function.py index 9e7c9bb1f..81ae4b7f1 100644 --- a/python/ray/serve/examples/doc/quickstart_function.py +++ b/python/ray/serve/examples/doc/quickstart_function.py @@ -6,8 +6,8 @@ ray.init(num_cpus=8) client = serve.start() -def echo(flask_request): - return "hello " + flask_request.args.get("name", "serve!") +def echo(starlette_request): + return "hello " + starlette_request.query_params.get("name", "serve!") client.create_backend("hello", echo) diff --git a/python/ray/serve/examples/doc/snippet_model_composition.py b/python/ray/serve/examples/doc/snippet_model_composition.py index 67a1890bf..6439bb9bf 100644 --- a/python/ray/serve/examples/doc/snippet_model_composition.py +++ b/python/ray/serve/examples/doc/snippet_model_composition.py @@ -16,13 +16,13 @@ client = serve.start() def model_one(request): - print("Model 1 called with data ", request.args.get("data")) + print("Model 1 called with data ", request.query_params.get("data")) return random() def model_two(request): - print("Model 2 called with data ", request.args.get("data")) - return request.args.get("data") + print("Model 2 called with data ", request.query_params.get("data")) + return request.query_params.get("data") class ComposedModel: @@ -32,8 +32,8 @@ class ComposedModel: self.model_two = client.get_handle("model_two") # This method can be called concurrently! - async def __call__(self, flask_request): - data = flask_request.data + async def __call__(self, starlette_request): + data = await starlette_request.body() score = await self.model_one.remote(data=data) if score > 0.5: diff --git a/python/ray/serve/examples/doc/tutorial_batch.py b/python/ray/serve/examples/doc/tutorial_batch.py index 3f8129a05..8aa8828b8 100644 --- a/python/ray/serve/examples/doc/tutorial_batch.py +++ b/python/ray/serve/examples/doc/tutorial_batch.py @@ -14,8 +14,10 @@ import requests # __doc_define_servable_v0_begin__ @serve.accept_batch -def batch_adder_v0(flask_requests: List): - numbers = [int(request.args["number"]) for request in flask_requests] +def batch_adder_v0(starlette_requests: List): + numbers = [ + int(request.query_params["number"]) for request in starlette_requests + ] input_array = np.array(numbers) print("Our input array has shape:", input_array.shape) @@ -58,7 +60,7 @@ print("Result returned:", results) # __doc_define_servable_v1_begin__ @serve.accept_batch def batch_adder_v1(requests: List): - numbers = [int(request.args["number"]) for request in requests] + numbers = [int(request.query_params["number"]) for request in requests] input_array = np.array(numbers) print("Our input array has shape:", input_array.shape) # Sleep for 200ms, this could be performing CPU intensive computation diff --git a/python/ray/serve/examples/doc/tutorial_deploy.py b/python/ray/serve/examples/doc/tutorial_deploy.py index 018f28dca..a7d5c75ca 100644 --- a/python/ray/serve/examples/doc/tutorial_deploy.py +++ b/python/ray/serve/examples/doc/tutorial_deploy.py @@ -48,9 +48,9 @@ class BoostingModel: with open("/tmp/iris_labels.json") as f: self.label_list = json.load(f) - def __call__(self, flask_request): - payload = flask_request.json - print("Worker: received flask request with data", payload) + async def __call__(self, starlette_request): + payload = await starlette_request.json() + print("Worker: received starlette request with data", payload) input_vector = [ payload["sepal length"], @@ -143,9 +143,9 @@ class BoostingModelv2: with open("/tmp/iris_labels_2.json") as f: self.label_list = json.load(f) - def __call__(self, flask_request): - payload = flask_request.json - print("Worker: received flask request with data", payload) + async def __call__(self, starlette_request): + payload = await starlette_request.json() + print("Worker: received starlette request with data", payload) input_vector = [ payload["sepal length"], diff --git a/python/ray/serve/examples/doc/tutorial_pytorch.py b/python/ray/serve/examples/doc/tutorial_pytorch.py index 80e3b9d32..bc43534ef 100644 --- a/python/ray/serve/examples/doc/tutorial_pytorch.py +++ b/python/ray/serve/examples/doc/tutorial_pytorch.py @@ -27,8 +27,8 @@ class ImageModel: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) - def __call__(self, flask_request): - image_payload_bytes = flask_request.data + async def __call__(self, starlette_request): + image_payload_bytes = await starlette_request.body() pil_image = Image.open(BytesIO(image_payload_bytes)) print("[1/3] Parsed image data: {}".format(pil_image)) diff --git a/python/ray/serve/examples/doc/tutorial_sklearn.py b/python/ray/serve/examples/doc/tutorial_sklearn.py index 69f17953b..c2eeb8c8f 100644 --- a/python/ray/serve/examples/doc/tutorial_sklearn.py +++ b/python/ray/serve/examples/doc/tutorial_sklearn.py @@ -54,9 +54,9 @@ class BoostingModel: with open(LABEL_PATH) as f: self.label_list = json.load(f) - def __call__(self, flask_request): - payload = flask_request.json - print("Worker: received flask request with data", payload) + async def __call__(self, starlette_request): + payload = await starlette_request.json() + print("Worker: received starlette request with data", payload) input_vector = [ payload["sepal length"], diff --git a/python/ray/serve/examples/doc/tutorial_tensorflow.py b/python/ray/serve/examples/doc/tutorial_tensorflow.py index 07fb36381..022526296 100644 --- a/python/ray/serve/examples/doc/tutorial_tensorflow.py +++ b/python/ray/serve/examples/doc/tutorial_tensorflow.py @@ -51,10 +51,10 @@ class TFMnistModel: self.model_path = model_path self.model = tf.keras.models.load_model(model_path) - def __call__(self, flask_request): + async def __call__(self, starlette_request): # Step 1: transform HTTP request -> tensorflow input # Here we define the request schema to be a json array. - input_array = np.array(flask_request.json["array"]) + input_array = np.array((await starlette_request.json())["array"]) reshaped_array = input_array.reshape((1, 28, 28)) # Step 2: tensorflow input -> tensorflow output diff --git a/python/ray/serve/examples/echo.py b/python/ray/serve/examples/echo.py index 4f73d3da9..218beef01 100644 --- a/python/ray/serve/examples/echo.py +++ b/python/ray/serve/examples/echo.py @@ -9,8 +9,8 @@ import requests from ray import serve -def echo(flask_request): - return ["hello " + flask_request.args.get("name", "serve!")] +def echo(starlette_request): + return ["hello " + starlette_request.query_params.get("name", "serve!")] client = serve.start() diff --git a/python/ray/serve/examples/echo_actor.py b/python/ray/serve/examples/echo_actor.py index c8d94080a..e6ed0c1b8 100644 --- a/python/ray/serve/examples/echo_actor.py +++ b/python/ray/serve/examples/echo_actor.py @@ -1,6 +1,6 @@ """ Example actor that adds an increment to a number. This number can -come from either web (parsing Flask request) or python call. +come from either web (parsing Starlette request) or python call. This actor can be called from HTTP as well as from Python. """ @@ -30,9 +30,10 @@ class MagicCounter: def __init__(self, increment): self.increment = increment - def __call__(self, flask_request, base_number=None): + def __call__(self, starlette_request, base_number=None): if serve.context.web: - base_number = int(flask_request.args.get("base_number", "0")) + base_number = int( + starlette_request.query_params.get("base_number", "0")) return base_number + self.increment diff --git a/python/ray/serve/examples/echo_actor_batch.py b/python/ray/serve/examples/echo_actor_batch.py index f9ed1df4f..15c7ae67d 100644 --- a/python/ray/serve/examples/echo_actor_batch.py +++ b/python/ray/serve/examples/echo_actor_batch.py @@ -1,6 +1,6 @@ """ Example actor that adds an increment to a number. This number can -come from either web (parsing Flask request) or python call. +come from either web (parsing Starlette request) or python call. The queries incoming to this actor are batched. This actor can be called from HTTP as well as from Python. """ @@ -31,12 +31,13 @@ class MagicCounter: self.increment = increment @serve.accept_batch - def __call__(self, flask_request_list, base_number=None): + def __call__(self, starlette_request_list, base_number=None): # batch_size = serve.context.batch_size if serve.context.web: result = [] - for flask_request in flask_request_list: - base_number = int(flask_request.args.get("base_number", "0")) + for starlette_request in starlette_request_list: + base_number = int( + starlette_request.query_params.get("base_number", "0")) result.append(base_number) return list(map(lambda x: x + self.increment, result)) else: diff --git a/python/ray/serve/examples/echo_batching.py b/python/ray/serve/examples/echo_batching.py index b4c55249a..9ffb214eb 100644 --- a/python/ray/serve/examples/echo_batching.py +++ b/python/ray/serve/examples/echo_batching.py @@ -11,7 +11,7 @@ class MagicCounter: self.increment = increment @serve.accept_batch - def __call__(self, flask_request, base_number=None): + def __call__(self, starlette_request, base_number=None): # __call__ fn should preserve the batch size # base_number is a python list diff --git a/python/ray/serve/examples/echo_full.py b/python/ray/serve/examples/echo_full.py index 9639f1a25..046ec2d6a 100644 --- a/python/ray/serve/examples/echo_full.py +++ b/python/ray/serve/examples/echo_full.py @@ -12,8 +12,8 @@ client = serve.start() # a backend can be a function or class. # it can be made to be invoked from web as well as python. -def echo_v1(flask_request): - response = flask_request.args.get("response", "web") +def echo_v1(starlette_request): + response = starlette_request.query_params.get("response", "web") return response @@ -32,7 +32,7 @@ print(ray.get(client.get_handle("my_endpoint").remote(response="hello"))) # We can also add a new backend and split the traffic. -def echo_v2(flask_request): +def echo_v2(starlette_request): # magic, only from web. return "something new" diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 9e88959c8..4bfd663fd 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -96,7 +96,7 @@ class RayServeHandle: def remote(self, request_data: Optional[Union[Dict, Any]] = None, **kwargs): - """Issue an asynchrounous request to the endpoint. + """Issue an asynchronous request to the endpoint. Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get, respectively. @@ -106,9 +106,9 @@ class RayServeHandle: Args: request_data(dict, Any): If it's a dictionary, the data will be available in ``request.json()`` or ``request.form()``. - Otherwise, it will be available in ``request.data``. + Otherwise, it will be available in ``request.body()``. ``**kwargs``: All keyword arguments will be available in - ``request.args``. + ``request.query_params``. """ if not self.sync: raise RayServeException( diff --git a/python/ray/serve/http_util.py b/python/ray/serve/http_util.py index 1a057a88e..0aa4ccf84 100644 --- a/python/ray/serve/http_util.py +++ b/python/ray/serve/http_util.py @@ -1,76 +1,25 @@ -import io import json -import flask +import starlette.requests -def build_flask_request(asgi_scope_dict, request_body): - """Build and return a flask request from ASGI payload +def build_starlette_request(scope, serialized_body: bytes): + """Build and return a Starlette Request from ASGI payload. - This function is indented to be used immediately before task invocation - happen. + This function is intended to be used immediately before task invocation + happens. """ - wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body) - # We set populate_request=False to prevent self reference, which can lead - # to objects tracked by python garbage collector and memory growth. See - # https://github.com/ray-project/ray/issues/12395. - return flask.Request(wsgi_environ, populate_request=False) + # Simulates receiving HTTP body from TCP socket. In reality, the body has + # already been streamed in chunks and stored in serialized_body. + async def mock_receive(): + return { + "body": serialized_body, + "type": "http.request", + "more_body": False + } -def build_wsgi_environ(scope, body): - """ - Builds a scope and request body into a WSGI environ object. - - This code snippet is taken from https://github.com/django/asgiref/blob - /36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52 - - WSGI specification can be found at - https://www.python.org/dev/peps/pep-0333/ - - This function helps translate ASGI scope and body into a flask request. - """ - environ = { - "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": scope.get("root_path", ""), - "PATH_INFO": scope["path"], - "QUERY_STRING": scope["query_string"].decode("ascii"), - "SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]), - "wsgi.version": (1, 0), - "wsgi.url_scheme": scope.get("scheme", "http"), - "wsgi.input": body, - "wsgi.errors": io.BytesIO(), - "wsgi.multithread": True, - "wsgi.multiprocess": True, - "wsgi.run_once": False, - } - - # Get server name and port - required in WSGI, not in ASGI - environ["SERVER_NAME"] = scope["server"][0] - environ["SERVER_PORT"] = str(scope["server"][1]) - environ["REMOTE_ADDR"] = scope["client"][0] - - # Transforms headers into environ entries. - for name, value in scope.get("headers", []): - # name, values are both bytes, we need to decode them to string - name = name.decode("latin1") - value = value.decode("latin1") - - # Handle name correction to conform to WSGI spec - # https://www.python.org/dev/peps/pep-0333/#environ-variables - if name == "content-length": - corrected_name = "CONTENT_LENGTH" - elif name == "content-type": - corrected_name = "CONTENT_TYPE" - else: - corrected_name = "HTTP_%s" % name.upper().replace("-", "_") - - # If the header value repeated, - # we will just concatenate it to the field. - if corrected_name in environ: - value = environ[corrected_name] + "," + value - - environ[corrected_name] = value - return environ + return starlette.requests.Request(scope, mock_receive) class Response: diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 1f236f915..ea0f6f35d 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -19,8 +19,8 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name, def test_e2e(serve_instance): client = serve_instance - def function(flask_request): - return {"method": flask_request.method} + def function(starlette_request): + return {"method": starlette_request.method} client.create_backend("echo:v1", function) client.create_endpoint( @@ -97,7 +97,7 @@ def test_backend_user_config(serve_instance): def __init__(self): self.count = 10 - def __call__(self, flask_request): + def __call__(self, starlette_request): return self.count, os.getpid() def reconfigure(self, config): @@ -820,8 +820,8 @@ def test_serve_metrics(serve_instance): client = serve_instance @serve.accept_batch - def batcher(flask_requests): - return ["hello"] * len(flask_requests) + def batcher(starlette_requests): + return ["hello"] * len(starlette_requests) client.create_backend("metrics", batcher) client.create_endpoint("metrics", backend="metrics", route="/metrics") @@ -871,6 +871,26 @@ def test_serve_metrics(serve_instance): verify_metrics() +def test_starlette_request(serve_instance): + client = serve_instance + + async def echo_body(starlette_request): + data = await starlette_request.body() + return data + + UVICORN_HIGH_WATER_MARK = 65536 # max bytes in one message + + # Long string to test serialization of multiple messages. + long_string = "x" * 10 * UVICORN_HIGH_WATER_MARK + + client.create_backend("echo:v1", echo_body) + client.create_endpoint( + "endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"]) + + resp = requests.post("http://127.0.0.1:8000/api", data=long_string).text + assert resp == long_string + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 1b03c0835..4ae745ba6 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -85,7 +85,7 @@ async def test_runner_wraps_error(): async def test_servable_function(serve_instance, router, mock_controller_with_name): def echo(request): - return request.args["i"] + return request.query_params["i"] await add_servable_to_router(echo, router, mock_controller_with_name[0]) @@ -103,7 +103,7 @@ async def test_servable_class(serve_instance, router, self.increment = inc def __call__(self, request): - return request.args["i"] + self.increment + return request.query_params["i"] + self.increment await add_servable_to_router( MyAdder, router, mock_controller_with_name[0], init_args=(3, )) @@ -277,7 +277,7 @@ async def test_user_config_update(serve_instance, router, def __init__(self): self.reval = "" - def __call__(self, flask_request): + def __call__(self, starlette_request): return self.retval def reconfigure(self, config): diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index 6e5c91d85..cc6b1e72b 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -8,7 +8,7 @@ def test_handle_in_endpoint(serve_instance): client = serve_instance class Endpoint1: - def __call__(self, flask_request): + def __call__(self, starlette_request): return "hello" class Endpoint2: @@ -40,12 +40,12 @@ def test_handle_http_args(serve_instance): client = serve_instance class Endpoint: - def __call__(self, request): + async def __call__(self, request): return { - "args": dict(request.args), + "args": dict(request.query_params), "headers": dict(request.headers), "method": request.method, - "json": request.json + "json": await request.json() } client.create_backend("backend", Endpoint) @@ -58,7 +58,7 @@ def test_handle_http_args(serve_instance): "arg2": "2" }, "headers": { - "X-Custom-Header": "value" + "x-custom-header": "value" }, "method": "POST", "json": { @@ -81,10 +81,10 @@ def test_handle_http_args(serve_instance): for resp in [resp_web, resp_handle]: for field in ["args", "method", "json"]: assert resp[field] == ground_truth[field] - resp["headers"]["X-Custom-Header"] == "value" + resp["headers"]["x-custom-header"] == "value" -def test_handle_inject_flask_request(serve_instance): +def test_handle_inject_starlette_request(serve_instance): client = serve_instance def echo_request_type(request): @@ -103,7 +103,7 @@ def test_handle_inject_flask_request(serve_instance): for route in ["/echo", "/wrapper"]: resp = requests.get(f"http://127.0.0.1:8000{route}") request_type = resp.text - assert request_type == "" + assert request_type == "" if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_persistence.py b/python/ray/serve/tests/test_persistence.py index fec43f838..6124f41a0 100644 --- a/python/ray/serve/tests/test_persistence.py +++ b/python/ray/serve/tests/test_persistence.py @@ -12,7 +12,7 @@ ray.init(address="{}") from ray import serve client = serve.connect() -def driver(flask_request): +def driver(starlette_request): return "OK!" client.create_backend("driver", driver) diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py index 027b53a72..e2425edf9 100644 --- a/python/ray/serve/tests/test_regression.py +++ b/python/ray/serve/tests/test_regression.py @@ -15,7 +15,7 @@ def test_np_in_composed_model(serve_instance): # in cloudpickle _from_numpy_buffer def sum_model(request): - return np.sum(request.args["data"]) + return np.sum(request.query_params["data"]) class ComposedModel: def __init__(self): @@ -42,7 +42,7 @@ def test_backend_worker_memory_growth(serve_instance): # https://github.com/ray-project/ray/issues/12395 client = serve_instance - def gc_unreachable_objects(flask_request): + def gc_unreachable_objects(starlette_request): gc.set_debug(gc.DEBUG_SAVEALL) gc.collect() return len(gc.garbage) diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index e8c5a6d13..99a65125e 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -8,7 +8,6 @@ import random import string import time from typing import List, Dict -import io import os from ray.serve.exceptions import RayServeException from collections import UserDict @@ -16,18 +15,18 @@ from collections import UserDict import requests import numpy as np import pydantic -import flask +import starlette.requests import ray from ray.serve.constants import HTTP_PROXY_TIMEOUT from ray.serve.context import TaskContext -from ray.serve.http_util import build_flask_request +from ray.serve.http_util import build_starlette_request ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 class ServeMultiDict(UserDict): - """Compatible data structure to simulate Flask.Request.args API.""" + """Compatible data structure to simulate Starlette Request query_args.""" def getlist(self, key): """Return the list of items for a given key.""" @@ -35,11 +34,14 @@ class ServeMultiDict(UserDict): class ServeRequest: - """The request object used in Python context. + """The request object used when passing arguments via ServeHandle. - ServeRequest is built to have similar API as Flask.Request. You only need - to write your model serving code once; it can be queried by both HTTP and - Python. + ServeRequest partially implements the API of Starlette Request. You only + need to write your model serving code once; it can be queried by both HTTP + and Python. + + To use the full Starlette Request interface with ServeHandle, you may + instead directly pass in a Starlette Request object to the ServeHandle. """ def __init__(self, data, kwargs, headers, method): @@ -59,28 +61,25 @@ class ServeRequest: return self._method @property - def args(self): + def query_params(self): """The keyword arguments from ``handle.remote(**kwargs)``.""" return self._kwargs - @property - def json(self): + async def json(self): """The request dictionary, from ``handle.remote(dict)``.""" if not isinstance(self._data, dict): raise RayServeException("Request data is not a dictionary. " f"It is {type(self._data)}.") return self._data - @property - def form(self): + async def form(self): """The request dictionary, from ``handle.remote(dict)``.""" if not isinstance(self._data, dict): raise RayServeException("Request data is not a dictionary. " f"It is {type(self._data)}.") return self._data - @property - def data(self): + async def body(self): """The request data from ``handle.remote(obj)``.""" return self._data @@ -88,13 +87,13 @@ class ServeRequest: def parse_request_item(request_item): if request_item.metadata.request_context == TaskContext.Web: asgi_scope, body_bytes = request_item.args - return build_flask_request(asgi_scope, io.BytesIO(body_bytes)) + return build_starlette_request(asgi_scope, body_bytes) else: arg = request_item.args[0] if len(request_item.args) == 1 else None # If the input data from handle is web request, we don't need to wrap # it in ServeRequest. - if isinstance(arg, flask.Request): + if isinstance(arg, starlette.requests.Request): return arg return ServeRequest( From 03a5b90ed633809901e11575a83bb599e38fbf80 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 21 Dec 2020 15:16:42 -0800 Subject: [PATCH 60/88] =?UTF-8?q?Revert=20"Revert=20"Increase=20the=20numb?= =?UTF-8?q?er=20of=20unique=20bits=20for=20actors=20to=20avoi=E2=80=A6=20(?= =?UTF-8?q?#12990)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tests/test_stats_collector.py | 18 +++++++----------- dashboard/tests/test_memory_utils.py | 5 +++-- .../src/main/java/io/ray/api/id/ActorId.java | 2 +- .../src/main/java/io/ray/api/id/ObjectId.java | 2 +- .../src/main/java/io/ray/api/id/UniqueId.java | 2 +- .../test/java/io/ray/runtime/UniqueIdTest.java | 14 +++++++------- python/ray/exceptions.py | 6 ++++-- python/ray/includes/function_descriptor.pxi | 13 +++++++++---- python/ray/includes/unique_ids.pxi | 2 +- python/ray/log_monitor.py | 2 +- python/ray/ray_constants.py | 2 +- python/ray/serialization.py | 5 +++-- python/ray/tests/test_advanced_3.py | 6 +++--- python/ray/tests/test_multi_node.py | 8 ++++---- python/ray/utils.py | 4 ++-- python/ray/worker.py | 3 ++- src/ray/common/constants.h | 2 +- src/ray/common/id.h | 2 +- src/ray/core_worker/actor_manager.cc | 2 ++ src/ray/gcs/redis_context.cc | 3 +-- .../object_manager/test/object_manager_test.cc | 2 ++ 21 files changed, 57 insertions(+), 48 deletions(-) diff --git a/dashboard/modules/stats_collector/tests/test_stats_collector.py b/dashboard/modules/stats_collector/tests/test_stats_collector.py index f4246770a..bed6d650f 100644 --- a/dashboard/modules/stats_collector/tests/test_stats_collector.py +++ b/dashboard/modules/stats_collector/tests/test_stats_collector.py @@ -112,20 +112,16 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard): def check_mem_table(): resp = requests.get(f"{webui_url}/memory/memory_table") resp_data = resp.json() - if not resp_data["result"]: - return False + assert resp_data["result"] latest_memory_table = resp_data["data"]["memoryTable"] summary = latest_memory_table["summary"] - try: - # 1 ref per handle and per object the actor has a ref to - assert summary["totalActorHandles"] == len(actors) * 2 - # 1 ref for my_obj - assert summary["totalLocalRefCount"] == 1 - return True - except AssertionError: - return False + # 1 ref per handle and per object the actor has a ref to + assert summary["totalActorHandles"] == len(actors) * 2 + # 1 ref for my_obj + assert summary["totalLocalRefCount"] == 1 - wait_for_condition(check_mem_table, 10) + wait_until_succeeded_without_exception( + check_mem_table, (AssertionError, ), timeout_ms=1000) def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard): diff --git a/dashboard/tests/test_memory_utils.py b/dashboard/tests/test_memory_utils.py index f58ecd8ae..212eeefad 100644 --- a/dashboard/tests/test_memory_utils.py +++ b/dashboard/tests/test_memory_utils.py @@ -7,8 +7,9 @@ from ray.new_dashboard.memory_utils import ( NODE_ADDRESS = "127.0.0.1" IS_DRIVER = True PID = 1 -OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA=" -ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000" + +OBJECT_ID = "ZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZg==" +ACTOR_ID = "fffffffffffffffffffffffffffffffff66d17ba010000c801000000" DECODED_ID = decode_object_ref_if_needed(OBJECT_ID) OBJECT_SIZE = 100 diff --git a/java/api/src/main/java/io/ray/api/id/ActorId.java b/java/api/src/main/java/io/ray/api/id/ActorId.java index 65a0cf19a..a21d4e79f 100644 --- a/java/api/src/main/java/io/ray/api/id/ActorId.java +++ b/java/api/src/main/java/io/ray/api/id/ActorId.java @@ -7,7 +7,7 @@ import java.util.Random; public class ActorId extends BaseId implements Serializable { - private static final int UNIQUE_BYTES_LENGTH = 4; + private static final int UNIQUE_BYTES_LENGTH = 12; public static final int LENGTH = JobId.LENGTH + UNIQUE_BYTES_LENGTH; diff --git a/java/api/src/main/java/io/ray/api/id/ObjectId.java b/java/api/src/main/java/io/ray/api/id/ObjectId.java index 9b1fa246f..78b677ac8 100644 --- a/java/api/src/main/java/io/ray/api/id/ObjectId.java +++ b/java/api/src/main/java/io/ray/api/id/ObjectId.java @@ -10,7 +10,7 @@ import java.util.Random; */ public class ObjectId extends BaseId implements Serializable { - public static final int LENGTH = 20; + public static final int LENGTH = 28; /** * Create an ObjectId from a ByteBuffer. diff --git a/java/api/src/main/java/io/ray/api/id/UniqueId.java b/java/api/src/main/java/io/ray/api/id/UniqueId.java index 03de53943..44b19f6a7 100644 --- a/java/api/src/main/java/io/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/io/ray/api/id/UniqueId.java @@ -11,7 +11,7 @@ import java.util.Random; */ public class UniqueId extends BaseId implements Serializable { - public static final int LENGTH = 20; + public static final int LENGTH = 28; public static final UniqueId NIL = genNil(); /** diff --git a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java index ce1b61db1..7496f1baf 100644 --- a/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java +++ b/java/runtime/src/test/java/io/ray/runtime/UniqueIdTest.java @@ -12,12 +12,12 @@ public class UniqueIdTest { @Test public void testConstructUniqueId() { // Test `fromHexString()` - UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString()); + UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00"); + Assert.assertEquals("00000000123456789abcdef123456789abcdef0123456789abcdef00", id1.toString()); Assert.assertFalse(id1.isNil()); try { - UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF00"); + UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00"); // This shouldn't be happened. Assert.assertTrue(false); } catch (IllegalArgumentException e) { @@ -33,16 +33,16 @@ public class UniqueIdTest { } // Test `fromByteBuffer()` - byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567"); - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20); + byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF"); + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 28); UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer); Assert.assertTrue(Arrays.equals(bytes, id4.getBytes())); - Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString()); + Assert.assertEquals("0123456789abcdef0123456789abcdef012345670123456789abcdef", id4.toString()); // Test `genNil()` UniqueId id6 = UniqueId.NIL; - Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); + Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } } diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index b5a0b477c..56e943db6 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -142,7 +142,8 @@ class WorkerCrashedError(RayError): """Indicates that the worker died unexpectedly while executing a task.""" def __str__(self): - return "The worker died unexpectedly while executing this task." + return ("The worker died unexpectedly while executing this task. " + "Check python-core-worker-*.log files for more information.") class RayActorError(RayError): @@ -153,7 +154,8 @@ class RayActorError(RayError): """ def __str__(self): - return "The actor died unexpectedly before finishing this task." + return ("The actor died unexpectedly before finishing this task. " + "Check python-core-worker-*.log files for more information.") class RaySystemError(RayError): diff --git a/python/ray/includes/function_descriptor.pxi b/python/ray/includes/function_descriptor.pxi index a9ac11fdb..d2c4cbbf4 100644 --- a/python/ray/includes/function_descriptor.pxi +++ b/python/ray/includes/function_descriptor.pxi @@ -12,6 +12,7 @@ import hashlib import cython import inspect import uuid +import ray.ray_constants as ray_constants ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &) @@ -188,7 +189,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): function_name = function.__name__ class_name = "" - pickled_function_hash = hashlib.sha1(pickled_function).hexdigest() + pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest( + ray_constants.ID_SIZE) return cls(module_name, function_name, class_name, pickled_function_hash) @@ -208,7 +210,10 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): module_name = target_class.__module__ class_name = target_class.__name__ # Use a random uuid as function hash to solve actor name conflict. - return cls(module_name, "__init__", class_name, str(uuid.uuid4())) + return cls( + module_name, "__init__", class_name, + hashlib.shake_128( + uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE)) @property def module_name(self): @@ -268,14 +273,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor): Returns: ray.ObjectRef to represent the function descriptor. """ - function_id_hash = hashlib.sha1() + function_id_hash = hashlib.shake_128() # Include the function module and name in the hash. function_id_hash.update(self.typed_descriptor.ModuleName()) function_id_hash.update(self.typed_descriptor.FunctionName()) function_id_hash.update(self.typed_descriptor.ClassName()) function_id_hash.update(self.typed_descriptor.FunctionHash()) # Compute the function ID. - function_id = function_id_hash.digest() + function_id = function_id_hash.digest(ray_constants.ID_SIZE) return ray.FunctionID(function_id) def is_actor_method(self): diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index bcf766829..52a6730e6 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize): raise TypeError("Unsupported type: " + str(type(b))) if len(b) != size: raise ValueError("ID string needs to have length " + - str(size)) + str(size) + ", got " + str(len(b))) cdef extern from "ray/common/constants.h" nogil: diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index ac5fa5296..d6b3a314e 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger logger = logging.getLogger(__name__) # The groups are worker id, job id, and pid. -JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)") +JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)") class LogFileInfo: diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index be717ca3c..30b3b5c7b 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -19,7 +19,7 @@ def env_bool(key, default): return default -ID_SIZE = 20 +ID_SIZE = 28 # The default maximum number of bytes to allocate to the object store unless # overridden by the user. diff --git a/python/ray/serialization.py b/python/ray/serialization.py index dc9a2c40e..9a24f3ccc 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -74,7 +74,8 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): new_class_id = pickle.dumps(pickle.loads(class_id)) if new_class_id == class_id: # We appear to have reached a fix point, so use this as the ID. - return hashlib.sha1(new_class_id).digest() + return hashlib.shake_128(new_class_id).digest( + ray_constants.ID_SIZE) class_id = new_class_id # We have not reached a fixed point, so we may end up with a different @@ -82,7 +83,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5): # same class definition being exported many many times. logger.warning( f"WARNING: Could not produce a deterministic class ID for class {cls}") - return hashlib.sha1(new_class_id).digest() + return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE) def object_ref_deserializer(reduced_obj_ref, owner_address): diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 7f1e8e639..b1bc25fbb 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -284,14 +284,14 @@ def test_workers(shutdown_only): def test_object_ref_properties(): - id_bytes = b"00112233445566778899" + id_bytes = b"0011223344556677889900001111" object_ref = ray.ObjectRef(id_bytes) assert object_ref.binary() == id_bytes object_ref = ray.ObjectRef.nil() assert object_ref.is_nil() - with pytest.raises(ValueError, match=r".*needs to have length 20.*"): + with pytest.raises(ValueError, match=r".*needs to have length.*"): ray.ObjectRef(id_bytes + b"1234") - with pytest.raises(ValueError, match=r".*needs to have length 20.*"): + with pytest.raises(ValueError, match=r".*needs to have length.*"): ray.ObjectRef(b"0123456789") object_ref = ray.ObjectRef.from_random() assert not object_ref.is_nil() diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index cb206112d..fbce475c1 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -741,10 +741,10 @@ ray.get(main_wait.release.remote()) driver1_out_split = driver1_out.split("\n") driver2_out_split = driver2_out.split("\n") - assert driver1_out_split[0][-1] == "1" - assert driver1_out_split[1][-1] == "2" - assert driver2_out_split[0][-1] == "3" - assert driver2_out_split[1][-1] == "4" + assert driver1_out_split[0][-1] == "1", driver1_out_split + assert driver1_out_split[1][-1] == "2", driver1_out_split + assert driver2_out_split[0][-1] == "3", driver2_out_split + assert driver2_out_split[1][-1] == "4", driver2_out_split if __name__ == "__main__": diff --git a/python/ray/utils.py b/python/ray/utils.py index a3940d6e8..2704e07cc 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -50,9 +50,9 @@ def get_ray_temp_dir(): def _random_string(): - id_hash = hashlib.sha1() + id_hash = hashlib.shake_128() id_hash.update(uuid.uuid4().bytes) - id_bytes = id_hash.digest() + id_bytes = id_hash.digest(ray_constants.ID_SIZE) assert len(id_bytes) == ray_constants.ID_SIZE return id_bytes diff --git a/python/ray/worker.py b/python/ray/worker.py index 631a82767..a3d07e5ee 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -346,7 +346,8 @@ class Worker: # actually run the function locally. pickled_function = pickle.dumps(function) - function_to_run_id = hashlib.sha1(pickled_function).digest() + function_to_run_id = hashlib.shake_128(pickled_function).digest( + ray_constants.ID_SIZE) key = b"FunctionsToRun:" + function_to_run_id # First run the function on the driver. # We always run the task locally. diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 1636846f0..3a3461f2c 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -18,7 +18,7 @@ #include /// Length of Ray full-length IDs in bytes. -constexpr size_t kUniqueIDSize = 20; +constexpr size_t kUniqueIDSize = 28; /// An ObjectID's bytes are split into the task ID itself and the index of the /// object's creation. This is the maximum width of the object index in bits. diff --git a/src/ray/common/id.h b/src/ray/common/id.h index d12ba550d..bd55b27e5 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -124,7 +124,7 @@ class JobID : public BaseID { class ActorID : public BaseID { private: - static constexpr size_t kUniqueBytesLength = 4; + static constexpr size_t kUniqueBytesLength = 12; public: /// Length of `ActorID` in bytes. diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index e6ef4fc87..6b931082a 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -91,6 +91,8 @@ bool ActorManager::AddActorHandle(std::unique_ptr actor_handle, std::placeholders::_1, std::placeholders::_2); RAY_CHECK_OK(gcs_client_->Actors().AsyncSubscribe( actor_id, actor_notification_callback, nullptr)); + } else { + RAY_LOG(ERROR) << "Actor handle already exists " << actor_id.Hex(); } return inserted; diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index c4edbb688..afc904d60 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -336,8 +336,7 @@ Status ConnectWithRetries(const std::string &address, int port, RAY_LOG(WARNING) << errorMessage << " Will retry in " << RayConfig::instance().redis_db_connect_wait_milliseconds() << " milliseconds."; - } - if ((*context)->err) { + } else if ((*context)->err) { RAY_LOG(WARNING) << "Could not establish connection to Redis " << address << ":" << port << " (context.err = " << (*context)->err << "), will retry in " diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 493127000..48fa9a65a 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -467,12 +467,14 @@ class TestObjectManager : public TestObjectManagerBase { } }; +/* TODO(ekl) this seems to be hanging occasionally on Linux TEST_F(TestObjectManager, StartTestObjectManager) { // TODO: Break this test suite into unit tests. auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); }); AsyncStartTests(); main_service.run(); } +*/ } // namespace ray From 015a0f9935bd3e061d046d57e6e5fbc64ec706c5 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 21 Dec 2020 17:19:39 -0600 Subject: [PATCH 61/88] [serve] Rename replica_tag -> replica in metrics for consistency (#13022) --- python/ray/serve/backend_worker.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 73088b558..1644c1690 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -186,10 +186,10 @@ class RayServeReplica: "backend_replica_starts", description=("The number of time this replica " "has been restarted due to failure."), - tag_keys=("backend", "replica_tag")) + tag_keys=("backend", "replica")) self.restart_counter.set_default_tags({ "backend": self.backend_tag, - "replica_tag": self.replica_tag + "replica": self.replica_tag }) self.queuing_latency_tracker = metrics.Histogram( @@ -198,39 +198,39 @@ class RayServeReplica: "The latency for queries waiting in the replica's queue " "waiting to be processed or batched."), boundaries=DEFAULT_LATENCY_BUCKET_MS, - tag_keys=("backend", "replica_tag")) + tag_keys=("backend", "replica")) self.queuing_latency_tracker.set_default_tags({ "backend": self.backend_tag, - "replica_tag": self.replica_tag + "replica": self.replica_tag }) self.processing_latency_tracker = metrics.Histogram( "backend_processing_latency_ms", description="The latency for queries to be processed", boundaries=DEFAULT_LATENCY_BUCKET_MS, - tag_keys=("backend", "replica_tag", "batch_size")) + tag_keys=("backend", "replica", "batch_size")) self.processing_latency_tracker.set_default_tags({ "backend": self.backend_tag, - "replica_tag": self.replica_tag + "replica": self.replica_tag }) self.num_queued_items = metrics.Gauge( "replica_queued_queries", description=("Current number of queries queued in the " "the backend replicas"), - tag_keys=("backend", "replica_tag")) + tag_keys=("backend", "replica")) self.num_queued_items.set_default_tags({ "backend": self.backend_tag, - "replica_tag": self.replica_tag + "replica": self.replica_tag }) self.num_processing_items = metrics.Gauge( "replica_processing_queries", description="Current number of queries being processed", - tag_keys=("backend", "replica_tag")) + tag_keys=("backend", "replica")) self.num_processing_items.set_default_tags({ "backend": self.backend_tag, - "replica_tag": self.replica_tag + "replica": self.replica_tag }) self.restart_counter.record(1) From 8068041006545e1c1fae44f2322a9655f69e68e1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 21 Dec 2020 18:32:40 -0800 Subject: [PATCH 62/88] Don't release resources during plasma fetch (#13025) --- python/ray/tests/test_object_spilling.py | 54 +++++++++++++++++++ src/ray/common/ray_config_def.h | 4 ++ .../memory_store/memory_store.cc | 20 +++---- .../memory_store/memory_store.h | 8 ++- .../store_provider/plasma_store_provider.cc | 10 ++-- .../store_provider/plasma_store_provider.h | 2 +- 6 files changed, 75 insertions(+), 23 deletions(-) diff --git a/python/ray/tests/test_object_spilling.py b/python/ray/tests/test_object_spilling.py index 624fcb85d..4cefca998 100644 --- a/python/ray/tests/test_object_spilling.py +++ b/python/ray/tests/test_object_spilling.py @@ -563,5 +563,59 @@ def test_fusion_objects(tmp_path, shutdown_only): assert is_test_passing +# https://github.com/ray-project/ray/issues/12912 +def do_test_release_resource(tmp_path, expect_released): + temp_folder = tmp_path / "spill" + ray.init( + num_cpus=1, + object_store_memory=75 * 1024 * 1024, + _system_config={ + "max_io_workers": 1, + "release_resources_during_plasma_fetch": expect_released, + "automatic_object_spilling_enabled": True, + "object_spilling_config": json.dumps({ + "type": "filesystem", + "params": { + "directory_path": str(temp_folder) + } + }), + }) + plasma_obj = ray.put(np.ones(50 * 1024 * 1024, dtype=np.uint8)) + for _ in range(5): + ray.put(np.ones(50 * 1024 * 1024, dtype=np.uint8)) # Force spilling + + @ray.remote + def sneaky_task_tries_to_steal_released_resources(): + print("resources were released!") + + @ray.remote + def f(dep): + while True: + try: + ray.get(dep[0], timeout=0.001) + except ray.exceptions.GetTimeoutError: + pass + + done = f.remote([plasma_obj]) # noqa + canary = sneaky_task_tries_to_steal_released_resources.remote() + ready, _ = ray.wait([canary], timeout=2) + if expect_released: + assert ready + else: + assert not ready + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="Failing on Windows.") +def test_no_release_during_plasma_fetch(tmp_path, shutdown_only): + do_test_release_resource(tmp_path, expect_released=False) + + +@pytest.mark.skipif( + platform.system() == "Windows", reason="Failing on Windows.") +def test_release_during_plasma_fetch(tmp_path, shutdown_only): + do_test_release_resource(tmp_path, expect_released=True) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 9f9392bf7..fe41477f7 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -274,6 +274,10 @@ RAY_CONFIG(int32_t, minimum_gcs_reconnect_interval_milliseconds, 5000) /// Whether start the Plasma Store as a Raylet thread. RAY_CONFIG(bool, plasma_store_as_thread, false) +/// Whether to release worker CPUs during plasma fetches. +/// See https://github.com/ray-project/ray/issues/12912 for further discussion. +RAY_CONFIG(bool, release_resources_during_plasma_fetch, false) + /// The interval at which the gcs client will check if the address of gcs service has /// changed. When the address changed, we will resubscribe again. RAY_CONFIG(int64_t, gcs_service_address_check_interval_milliseconds, 1000) diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 0391b7a1d..6dad1b37b 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -232,18 +232,16 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, - std::vector> *results, - bool release_resources) { + std::vector> *results) { return GetImpl(object_ids, num_objects, timeout_ms, ctx, remove_after_get, results, - /*abort_if_any_object_is_exception=*/true, release_resources); + /*abort_if_any_object_is_exception=*/true); } Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, std::vector> *results, - bool abort_if_any_object_is_exception, - bool release_resources) { + bool abort_if_any_object_is_exception) { (*results).resize(object_ids.size(), nullptr); std::shared_ptr get_request; @@ -301,8 +299,7 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector &object_ids, // Wait for remaining objects (or timeout). if (should_notify_raylet) { - // SANG-TODO Implement memory store get - RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked(release_resources)); + RAY_CHECK_OK(raylet_client_->NotifyDirectCallTaskBlocked(/*release_resources=*/true)); } bool done = false; @@ -377,11 +374,11 @@ Status CoreWorkerMemoryStore::Get( const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception, bool release_resources) { + bool *got_exception) { const std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_RETURN_NOT_OK(Get(id_vector, id_vector.size(), timeout_ms, ctx, - /*remove_after_get=*/false, &result_objects, release_resources)); + /*remove_after_get=*/false, &result_objects)); for (size_t i = 0; i < id_vector.size(); i++) { if (result_objects[i] != nullptr) { @@ -404,9 +401,8 @@ Status CoreWorkerMemoryStore::Wait(const absl::flat_hash_set &object_i std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); - auto status = - GetImpl(id_vector, num_objects, timeout_ms, ctx, false, &result_objects, - /*abort_if_any_object_is_exception=*/false, /*release_resources=*/true); + auto status = GetImpl(id_vector, num_objects, timeout_ms, ctx, false, &result_objects, + /*abort_if_any_object_is_exception=*/false); // Ignore TimedOut statuses since we return ready objects explicitly. if (!status.IsTimedOut()) { RAY_RETURN_NOT_OK(status); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index faadafaff..709227f65 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -58,14 +58,13 @@ class CoreWorkerMemoryStore { /// \return Status. Status Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, - std::vector> *results, - bool release_resources = true); + std::vector> *results); /// Convenience wrapper around Get() that stores results in a given result map. Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception, bool release_resources = true); + bool *got_exception); /// Convenience wrapper around Get() that stores ready objects in a given result set. Status Wait(const absl::flat_hash_set &object_ids, int num_objects, @@ -138,12 +137,11 @@ class CoreWorkerMemoryStore { private: /// See the public version of `Get` for meaning of the other arguments. /// \param[in] abort_if_any_object_is_exception Whether we should abort if any object - /// \param[in] release_resources true if memory store blocking get needs to release /// resources. is an exception. Status GetImpl(const std::vector &object_ids, int num_objects, int64_t timeout_ms, const WorkerContext &ctx, bool remove_after_get, std::vector> *results, - bool abort_if_any_object_is_exception, bool release_resources); + bool abort_if_any_object_is_exception); /// Optional callback for putting objects into the plasma store. std::function store_in_plasma_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 3079b99f5..f7559e9b9 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -226,7 +226,7 @@ Status CoreWorkerPlasmaStoreProvider::Get( const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception, bool release_resources) { + bool *got_exception) { int64_t batch_size = RayConfig::instance().worker_fetch_request_size(); std::vector batch_ids; absl::flat_hash_set remaining(object_ids.begin(), object_ids.end()); @@ -277,7 +277,8 @@ Status CoreWorkerPlasmaStoreProvider::Get( size_t previous_size = remaining.size(); // This is a separate IPC from the FetchAndGet in direct call mode. if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { - RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked(release_resources)); + RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked( + RayConfig::instance().release_resources_during_plasma_fetch())); } RAY_RETURN_NOT_OK( FetchAndGetFromPlasmaStore(remaining, batch_ids, batch_timeout, @@ -334,9 +335,8 @@ Status CoreWorkerPlasmaStoreProvider::Wait( // This is a separate IPC from the Wait in direct call mode. if (ctx.CurrentTaskIsDirectCall() && ctx.ShouldReleaseResourcesOnBlockingCalls()) { - // SANG-TODO Implement wait - RAY_RETURN_NOT_OK( - raylet_client_->NotifyDirectCallTaskBlocked(/*release_resources*/ true)); + RAY_RETURN_NOT_OK(raylet_client_->NotifyDirectCallTaskBlocked( + RayConfig::instance().release_resources_during_plasma_fetch())); } const auto owner_addresses = reference_counter_->GetOwnerAddresses(id_vector); RAY_RETURN_NOT_OK( diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 6085a50c1..e9c7a23ee 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -90,7 +90,7 @@ class CoreWorkerPlasmaStoreProvider { Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, const WorkerContext &ctx, absl::flat_hash_map> *results, - bool *got_exception, bool release_resources = true); + bool *got_exception); Status Contains(const ObjectID &object_id, bool *has_object); From d5604eaba321c11c1b9616c283262c4ddea55049 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 21 Dec 2020 21:38:34 -0500 Subject: [PATCH 63/88] [RLlib] Attention nets PyTorch support and cleanup (using traj. view API). (#12029) --- rllib/BUILD | 19 ++- rllib/agents/ppo/ppo_torch_policy.py | 3 +- .../collectors/simple_list_collector.py | 8 +- .../tests/test_trajectory_view_api.py | 97 +++++++++++- rllib/examples/attention_net.py | 5 +- .../examples/custom_metrics_and_callbacks.py | 1 + rllib/examples/env/debug_counter_env.py | 19 ++- .../models/centralized_critic_models.py | 9 +- rllib/models/torch/attention_net.py | 140 +++++++++--------- rllib/models/torch/modules/gru_gate.py | 25 ++-- .../modules/relative_multi_head_attention.py | 84 ++++++++--- rllib/policy/eager_tf_policy.py | 12 +- rllib/policy/tf_policy_template.py | 3 - rllib/policy/torch_policy.py | 15 +- 14 files changed, 292 insertions(+), 148 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index c645c27a0..44a147b6d 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1480,20 +1480,19 @@ py_test( name = "examples/attention_net_tf", main = "examples/attention_net.py", tags = ["examples", "examples_A"], - size = "large", + size = "medium", srcs = ["examples/attention_net.py"], args = ["--as-test", "--stop-reward=80"] ) -# TODO(sven): GTrXL PyTorch. -# py_test( -# name = "examples/attention_net_torch", -# main = "examples/attention_net.py", -# tags = ["examples", "examples_A"], -# size = "large", -# srcs = ["examples/attention_net.py"], -# args = ["--as-test", "--torch", "--stop-reward=90"] -# ) +py_test( + name = "examples/attention_net_torch", + main = "examples/attention_net.py", + tags = ["examples", "examples_A"], + size = "medium", + srcs = ["examples/attention_net.py"], + args = ["--as-test", "--stop-reward=80", "--torch"] +) py_test( name = "examples/autoregressive_action_dist_tf", diff --git a/rllib/agents/ppo/ppo_torch_policy.py b/rllib/agents/ppo/ppo_torch_policy.py index fa2ca6c1d..d99251298 100644 --- a/rllib/agents/ppo/ppo_torch_policy.py +++ b/rllib/agents/ppo/ppo_torch_policy.py @@ -49,7 +49,8 @@ def ppo_surrogate_loss( # RNN case: Mask away 0-padded chunks at end of time axis. if state: - max_seq_len = torch.max(train_batch["seq_lens"]) + B = len(train_batch["seq_lens"]) + max_seq_len = logits.shape[0] // B mask = sequence_mask( train_batch["seq_lens"], max_seq_len, diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 1d5fe3f76..96e6d0624 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -35,9 +35,6 @@ def to_float_np_array(v: List[Any]) -> np.ndarray: return arr -_INIT_COLS = [SampleBatch.OBS] - - class _AgentCollector: """Collects samples for one agent in one trajectory (episode). @@ -55,8 +52,9 @@ class _AgentCollector: # or internal state inputs. self.shift_before = -min( (int(vr.shift.split(":")[0]) - if isinstance(vr.shift, str) else vr.shift) + - (-1 if vr.data_col in _INIT_COLS or k in _INIT_COLS else 0) + if isinstance(vr.shift, str) else vr.shift) - + (1 + if vr.data_col == SampleBatch.OBS or k == SampleBatch.OBS else 0) for k, vr in view_reqs.items()) # The actual data buffers (lists holding each timestep's data). self.buffers: Dict[str, List] = {} diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index a50978bfd..1a13300de 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -7,19 +7,41 @@ import unittest import ray from ray import tune +from ray.rllib.agents.callbacks import DefaultCallbacks import ray.rllib.agents.dqn as dqn import ray.rllib.agents.ppo as ppo from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.policy.episode_env_aware_policy import \ - EpisodeEnvAwareLSTMPolicy + EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy +from ray.rllib.models.tf.attention_net import GTrXLNet from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override from ray.rllib.utils.test_utils import framework_iterator, check +class MyCallbacks(DefaultCallbacks): + @override(DefaultCallbacks) + def on_learn_on_batch(self, *, policy, train_batch, **kwargs): + assert train_batch.count == 201 + assert sum(train_batch.seq_lens) == 201 + for k, v in train_batch.data.items(): + if k == "state_in_0": + assert len(v) == len(train_batch.seq_lens) + else: + assert len(v) == 201 + current = None + for o in train_batch[SampleBatch.OBS]: + if current: + assert o == current + 1 + current = o + if o == 15: + current = None + + class TestTrajectoryViewAPI(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -116,6 +138,45 @@ class TestTrajectoryViewAPI(unittest.TestCase): assert view_req_policy[key].shift == 1 trainer.stop() + def test_traj_view_attention_net(self): + config = ppo.DEFAULT_CONFIG.copy() + # Setup attention net. + config["model"] = config["model"].copy() + config["model"]["max_seq_len"] = 50 + config["model"]["custom_model"] = GTrXLNet + config["model"]["custom_model_config"] = { + "num_transformer_units": 1, + "attn_dim": 64, + "num_heads": 2, + "memory_inference": 50, + "memory_training": 50, + "head_dim": 32, + "ff_hidden_dim": 32, + } + # Test with odd batch numbers. + config["train_batch_size"] = 1031 + config["sgd_minibatch_size"] = 201 + config["num_sgd_iter"] = 5 + config["num_workers"] = 0 + config["callbacks"] = MyCallbacks + config["env_config"] = { + "config": { + "start_at_t": 1 + } + } # first obs is [1.0] + + for _ in framework_iterator(config, frameworks="tf2"): + trainer = ppo.PPOTrainer( + config, + env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv", + ) + rw = trainer.workers.local_worker() + sample = rw.sample() + assert sample.count == config["rollout_fragment_length"] + results = trainer.train() + assert results["train_batch_size"] == config["train_batch_size"] + trainer.stop() + def test_traj_view_simple_performance(self): """Test whether PPOTrainer runs faster w/ `_use_trajectory_view_api`. """ @@ -298,6 +359,40 @@ class TestTrajectoryViewAPI(unittest.TestCase): pol_batch_wo = result.policy_batches["pol0"] check(pol_batch_w.data, pol_batch_wo.data) + def test_traj_view_attention_functionality(self): + action_space = Box(-float("inf"), float("inf"), shape=(3, )) + obs_space = Box(float("-inf"), float("inf"), (4, )) + max_seq_len = 50 + rollout_fragment_length = 201 + policies = { + "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, + {}), + } + + def policy_fn(agent_id): + return "pol0" + + config = { + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_fn, + }, + "model": { + "max_seq_len": max_seq_len, + }, + }, + + rollout_worker_w_api = RolloutWorker( + env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), + policy_config=dict(config, **{"_use_trajectory_view_api": True}), + rollout_fragment_length=rollout_fragment_length, + policy_spec=policies, + policy_mapping_fn=policy_fn, + num_envs=1, + ) + batch = rollout_worker_w_api.sample() + print(batch) + def test_counting_by_agent_steps(self): """Test whether a PPOTrainer can be built with all frameworks.""" config = copy.deepcopy(ppo.DEFAULT_CONFIG) diff --git a/rllib/examples/attention_net.py b/rllib/examples/attention_net.py index de3f06c29..a490b73e9 100644 --- a/rllib/examples/attention_net.py +++ b/rllib/examples/attention_net.py @@ -4,6 +4,7 @@ import os import ray from ray import tune from ray.rllib.models.tf.attention_net import GTrXLNet +from ray.rllib.models.torch.attention_net import GTrXLNet as TorchGTrXLNet from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv @@ -27,8 +28,6 @@ parser.add_argument("--stop-reward", type=float, default=80) if __name__ == "__main__": args = parser.parse_args() - assert not args.torch, "PyTorch not supported for AttentionNets yet!" - ray.init(num_cpus=args.num_cpus or None) registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c)) @@ -52,7 +51,7 @@ if __name__ == "__main__": "num_sgd_iter": 10, "vf_loss_coeff": 1e-5, "model": { - "custom_model": GTrXLNet, + "custom_model": TorchGTrXLNet if args.torch else GTrXLNet, "max_seq_len": 50, "custom_model_config": { "num_transformer_units": 1, diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py index d7a2c849d..745a94029 100644 --- a/rllib/examples/custom_metrics_and_callbacks.py +++ b/rllib/examples/custom_metrics_and_callbacks.py @@ -28,6 +28,7 @@ class MyCallbacks(DefaultCallbacks): episode: MultiAgentEpisode, env_index: int, **kwargs): print("episode {} (env-idx={}) started.".format( episode.episode_id, env_index)) + episode.user_data["pole_angles"] = [] episode.hist_data["pole_angles"] = [] diff --git a/rllib/examples/env/debug_counter_env.py b/rllib/examples/env/debug_counter_env.py index c14d49951..aa3a9b3b7 100644 --- a/rllib/examples/env/debug_counter_env.py +++ b/rllib/examples/env/debug_counter_env.py @@ -12,18 +12,25 @@ class DebugCounterEnv(gym.Env): Reward is always: current ts % 3. """ - def __init__(self): + def __init__(self, config=None): + config = config or {} self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box(0, 100, (1, )) - self.i = 0 + self.observation_space = \ + gym.spaces.Box(0, 100, (1, ), dtype=np.float32) + self.start_at_t = int(config.get("start_at_t", 0)) + self.i = self.start_at_t def reset(self): - self.i = 0 - return [self.i] + self.i = self.start_at_t + return self._get_obs() def step(self, action): self.i += 1 - return [self.i], self.i % 3, self.i >= 15, {} + return self._get_obs(), float(self.i % 3), \ + self.i >= 15 + self.start_at_t, {} + + def _get_obs(self): + return np.array([self.i], dtype=np.float32) class MultiAgentDebugCounterEnv(MultiAgentEnv): diff --git a/rllib/examples/models/centralized_critic_models.py b/rllib/examples/models/centralized_critic_models.py index 276f42381..23f1e8b92 100644 --- a/rllib/examples/models/centralized_critic_models.py +++ b/rllib/examples/models/centralized_critic_models.py @@ -45,9 +45,10 @@ class CentralizedCriticModel(TFModelV2): def central_value_function(self, obs, opponent_obs, opponent_actions): return tf.reshape( - self.central_vf( - [obs, opponent_obs, - tf.one_hot(opponent_actions, 2)]), [-1]) + self.central_vf([ + obs, opponent_obs, + tf.one_hot(tf.cast(opponent_actions, tf.int32), 2) + ]), [-1]) @override(ModelV2) def value_function(self): @@ -124,7 +125,7 @@ class TorchCentralizedCriticModel(TorchModelV2, nn.Module): def central_value_function(self, obs, opponent_obs, opponent_actions): input_ = torch.cat([ obs, opponent_obs, - torch.nn.functional.one_hot(opponent_actions, 2).float() + torch.nn.functional.one_hot(opponent_actions.long(), 2).float() ], 1) return torch.reshape(self.central_vf(input_), [-1]) diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py index 58480a64b..27d2d494e 100644 --- a/rllib/models/torch/attention_net.py +++ b/rllib/models/torch/attention_net.py @@ -10,12 +10,15 @@ """ import numpy as np import gym +from gym.spaces import Box from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.modules import GRUGate, \ RelativeMultiHeadAttention, SkipConnection from ray.rllib.models.torch.recurrent_net import RecurrentNetwork +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import ModelConfigDict, TensorType, List @@ -23,26 +26,6 @@ from ray.rllib.utils.typing import ModelConfigDict, TensorType, List torch, nn = try_import_torch() -def relative_position_embedding(seq_length: int, out_dim: int) -> TensorType: - """Creates a [seq_length x seq_length] matrix for rel. pos encoding. - - Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding - matrix. - - Args: - seq_length (int): The max. sequence length (time axis). - out_dim (int): The number of nodes to go into the first Tranformer - layer with. - - Returns: - torch.Tensor: The encoding matrix Phi. - """ - inverse_freq = 1 / (10000**(torch.arange(0, out_dim, 2.0) / out_dim)) - pos_offsets = torch.arange(seq_length - 1, -1, -1) - inputs = pos_offsets[:, None] * inverse_freq[None, :] - return torch.cat((torch.sin(inputs), torch.cos(inputs)), dim=-1) - - class GTrXLNet(RecurrentNetwork, nn.Module): """A GTrXL net Model described in [2]. @@ -74,7 +57,8 @@ class GTrXLNet(RecurrentNetwork, nn.Module): num_transformer_units: int, attn_dim: int, num_heads: int, - memory_tau: int, + memory_inference: int, + memory_training: int, head_dim: int, ff_hidden_dim: int, init_gate_bias: float = 2.0): @@ -87,9 +71,15 @@ class GTrXLNet(RecurrentNetwork, nn.Module): unit. num_heads (int): The number of attention heads to use in parallel. Denoted as `H` in [3]. - memory_tau (int): The number of timesteps to store in each - transformer block's memory M (concat'd over time and fed into - next transformer block as input). + memory_inference (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as inference + input. The first transformer unit will receive this number of + past observations (plus the current one), instead. + memory_training (int): The number of timesteps to concat (time + axis) and feed into the next transformer unit as training + input (plus the actual input sequence of len=max_seq_len). + The first transformer unit will receive this number of + past observations (plus the input sequence), instead. head_dim (int): The dimension of a single(!) head. Denoted as `d` in [3]. ff_hidden_dim (int): The dimension of the hidden layer within @@ -110,20 +100,18 @@ class GTrXLNet(RecurrentNetwork, nn.Module): self.num_transformer_units = num_transformer_units self.attn_dim = attn_dim self.num_heads = num_heads - self.memory_tau = memory_tau + self.memory_inference = memory_inference + self.memory_training = memory_training self.head_dim = head_dim self.max_seq_len = model_config["max_seq_len"] self.obs_dim = observation_space.shape[0] - # Constant (non-trainable) sinusoid rel pos encoding matrix. - Phi = relative_position_embedding(self.max_seq_len + self.memory_tau, - self.attn_dim) - self.linear_layer = SlimFC( in_size=self.obs_dim, out_size=self.attn_dim) self.layers = [self.linear_layer] + attention_layers = [] # 2) Create L Transformer blocks according to [2]. for i in range(self.num_transformer_units): # RelativeMultiHeadAttention part. @@ -133,7 +121,6 @@ class GTrXLNet(RecurrentNetwork, nn.Module): out_dim=self.attn_dim, num_heads=num_heads, head_dim=head_dim, - rel_pos_encoder=Phi, input_layernorm=True, output_activation=nn.ReLU), fan_in_layer=GRUGate(self.attn_dim, init_gate_bias)) @@ -154,8 +141,13 @@ class GTrXLNet(RecurrentNetwork, nn.Module): activation_fn=nn.ReLU)), fan_in_layer=GRUGate(self.attn_dim, init_gate_bias)) - # Build a list of all layers in order. - self.layers.extend([MHA_layer, E_layer]) + # Build a list of all attanlayers in order. + attention_layers.extend([MHA_layer, E_layer]) + + # Create a Sequential such that all parameters inside the attention + # layers are automatically registered with this top-level model. + self.attention_layers = nn.Sequential(*attention_layers) + self.layers.extend(attention_layers) # Postprocess GTrXL output with another hidden layer. self.logits = SlimFC( @@ -168,62 +160,64 @@ class GTrXLNet(RecurrentNetwork, nn.Module): self.values_out = SlimFC( in_size=self.attn_dim, out_size=1, activation_fn=None) - @override(RecurrentNetwork) - def forward_rnn(self, inputs: TensorType, state: List[TensorType], - seq_lens: TensorType) -> (TensorType, List[TensorType]): - # To make Attention work with current RLlib's ModelV2 API: - # We assume `state` is the history of L recent observations (all - # concatenated into one tensor) and append the current inputs to the - # end and only keep the most recent (up to `max_seq_len`). This allows - # us to deal with timestep-wise inference and full sequence training - # within the same logic. - state = [torch.from_numpy(item) for item in state] - observations = state[0] - memory = state[1:] + # Setup inference view (`memory-inference` x past observations + + # current one (0)) + # 1 to `num_transformer_units`: Memory data (one per transformer unit). + for i in range(self.num_transformer_units): + space = Box(-1.0, 1.0, shape=(self.attn_dim, )) + self.inference_view_requirements["state_in_{}".format(i)] = \ + ViewRequirement( + "state_out_{}".format(i), + shift="-{}:-1".format(self.memory_inference), + # Repeat the incoming state every max-seq-len times. + batch_repeat_value=self.max_seq_len, + space=space) + self.inference_view_requirements["state_out_{}".format(i)] = \ + ViewRequirement( + space=space, + used_for_training=False) - inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]]) - observations = torch.cat( - (observations, inputs), axis=1)[:, -self.max_seq_len:] + @override(ModelV2) + def forward(self, input_dict, state: List[TensorType], + seq_lens: TensorType) -> (TensorType, List[TensorType]): + assert seq_lens is not None + + # Add the needed batch rank (tf Models' Input requires this). + observations = input_dict[SampleBatch.OBS] + # Add the time dim to observations. + B = len(seq_lens) + T = observations.shape[0] // B + observations = torch.reshape(observations, + [-1, T] + list(observations.shape[1:])) all_out = observations + memory_outs = [] for i in range(len(self.layers)): # MHA layers which need memory passed in. if i % 2 == 1: - all_out = self.layers[i](all_out, memory=memory[i // 2]) - # Either linear layers or MultiLayerPerceptrons. + all_out = self.layers[i](all_out, memory=state[i // 2]) + # Either self.linear_layer (initial obs -> attn. dim layer) or + # MultiLayerPerceptrons. The output of these layers is always the + # memory for the next forward pass. else: all_out = self.layers[i](all_out) + memory_outs.append(all_out) + + # Discard last output (not needed as a memory since it's the last + # layer). + memory_outs = memory_outs[:-1] logits = self.logits(all_out) self._value_out = self.values_out(all_out) - memory_outs = all_out[2:] - # If memory_tau > max_seq_len -> overlap w/ previous `memory` input. - if self.memory_tau > self.max_seq_len: - memory_outs = [ - torch.cat( - [memory[i][:, -(self.memory_tau - self.max_seq_len):], m], - axis=1) for i, m in enumerate(memory_outs) - ] - else: - memory_outs = [m[:, -self.memory_tau:] for m in memory_outs] - - T = list(inputs.size())[1] # Length of input segment (time). - - # Postprocessing final output. - logits = logits[:, -T:] - self._value_out = self._value_out[:, -T:] - - return logits, [observations] + memory_outs + return torch.reshape(logits, [-1, self.num_outputs]), [ + torch.reshape(m, [-1, self.attn_dim]) for m in memory_outs + ] + # TODO: (sven) Deprecate this once trajectory view API has fully matured. @override(RecurrentNetwork) def get_initial_state(self) -> List[np.ndarray]: - # State is the T last observations concat'd together into one Tensor. - # Plus all Transformer blocks' E(l) outputs concat'd together (up to - # tau timesteps). - return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)] + \ - [np.zeros((self.memory_tau, self.attn_dim), np.float32) - for _ in range(self.num_transformer_units)] + return [] @override(ModelV2) def value_function(self) -> TensorType: diff --git a/rllib/models/torch/modules/gru_gate.py b/rllib/models/torch/modules/gru_gate.py index 4cabc5eb2..724c41464 100644 --- a/rllib/models/torch/modules/gru_gate.py +++ b/rllib/models/torch/modules/gru_gate.py @@ -13,26 +13,29 @@ class GRUGate(nn.Module): init_bias (int): Bias added to every input to stabilize training """ super().__init__(**kwargs) - self._init_bias = init_bias - # Xavier initialization of torch tensors - self._w_r = torch.zeros(dim, dim) - self._w_z = torch.zeros(dim, dim) - self._w_h = torch.zeros(dim, dim) - - self._u_r = torch.zeros(dim, dim) - self._u_z = torch.zeros(dim, dim) - self._u_h = torch.zeros(dim, dim) - + self._w_r = nn.Parameter(torch.zeros(dim, dim)) + self._w_z = nn.Parameter(torch.zeros(dim, dim)) + self._w_h = nn.Parameter(torch.zeros(dim, dim)) nn.init.xavier_uniform_(self._w_r) nn.init.xavier_uniform_(self._w_z) nn.init.xavier_uniform_(self._w_h) + self.register_parameter("_w_r", self._w_r) + self.register_parameter("_w_z", self._w_z) + self.register_parameter("_w_h", self._w_h) + self._u_r = nn.Parameter(torch.zeros(dim, dim)) + self._u_z = nn.Parameter(torch.zeros(dim, dim)) + self._u_h = nn.Parameter(torch.zeros(dim, dim)) nn.init.xavier_uniform_(self._u_r) nn.init.xavier_uniform_(self._u_z) nn.init.xavier_uniform_(self._u_h) + self.register_parameter("_u_r", self._u_r) + self.register_parameter("_u_z", self._u_z) + self.register_parameter("_u_h", self._u_h) - self._bias_z = torch.zeros(dim, ).fill_(self._init_bias) + self._bias_z = nn.Parameter(torch.zeros(dim, ).fill_(init_bias)) + self.register_parameter("_bias_z", self._bias_z) def forward(self, inputs: TensorType, **kwargs) -> TensorType: # Pass in internal state first. diff --git a/rllib/models/torch/modules/relative_multi_head_attention.py b/rllib/models/torch/modules/relative_multi_head_attention.py index fe28d6f73..3efa9c664 100644 --- a/rllib/models/torch/modules/relative_multi_head_attention.py +++ b/rllib/models/torch/modules/relative_multi_head_attention.py @@ -1,11 +1,47 @@ +from typing import Union + from ray.rllib.utils.framework import try_import_torch from ray.rllib.models.torch.misc import SlimFC from ray.rllib.utils.torch_ops import sequence_mask -from ray.rllib.utils.typing import TensorType, Any +from ray.rllib.utils.typing import TensorType torch, nn = try_import_torch() +class RelativePositionEmbedding(nn.Module): + """Creates a [seq_length x seq_length] matrix for rel. pos encoding. + + Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding + matrix. + + Args: + seq_length (int): The max. sequence length (time axis). + out_dim (int): The number of nodes to go into the first Tranformer + layer with. + + Returns: + torch.Tensor: The encoding matrix Phi. + """ + + def __init__(self, out_dim, **kwargs): + super().__init__() + self.out_dim = out_dim + + out_range = torch.arange(0, self.out_dim, 2.0) + inverse_freq = 1 / (10000**(out_range / self.out_dim)) + self.register_buffer("inverse_freq", inverse_freq) + + def forward(self, seq_length): + pos_input = torch.arange( + seq_length - 1, -1, -1.0, + dtype=torch.float).to(self.inverse_freq.device) + sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq) + pos_embeddings = torch.cat( + [torch.sin(sinusoid_input), + torch.cos(sinusoid_input)], dim=-1) + return pos_embeddings[:, None, :] + + class RelativeMultiHeadAttention(nn.Module): """A RelativeMultiHeadAttention layer as described in [3]. @@ -17,24 +53,24 @@ class RelativeMultiHeadAttention(nn.Module): out_dim: int, num_heads: int, head_dim: int, - rel_pos_encoder: Any, input_layernorm: bool = False, - output_activation: Any = None, + output_activation: Union[str, callable] = None, **kwargs): """Initializes a RelativeMultiHeadAttention nn.Module object. Args: in_dim (int): - out_dim (int): + out_dim (int): The output dimension of this module. Also known as + "attention dim". num_heads (int): The number of attention heads to use. Denoted `H` in [2]. head_dim (int): The dimension of a single(!) attention head Denoted `D` in [2]. - rel_pos_encoder (: input_layernorm (bool): Whether to prepend a LayerNorm before everything else. Should be True for building a GTrXL. - output_activation (Optional[tf.nn.activation]): Optional tf.nn - activation function. Should be relu for GTrXL. + output_activation (Union[str, callable]): Optional activation + function or activation function specifier (str). + Should be "relu" for GTrXL. **kwargs: """ super().__init__(**kwargs) @@ -53,17 +89,18 @@ class RelativeMultiHeadAttention(nn.Module): use_bias=False, activation_fn=output_activation) - self._pos_proj = SlimFC( - in_size=in_dim, out_size=num_heads * head_dim, use_bias=False) - - self._uvar = torch.zeros(num_heads, head_dim) - self._vvar = torch.zeros(num_heads, head_dim) + self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim)) + self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim)) nn.init.xavier_uniform_(self._uvar) nn.init.xavier_uniform_(self._vvar) + self.register_parameter("_uvar", self._uvar) + self.register_parameter("_vvar", self._vvar) + + self._pos_proj = SlimFC( + in_size=in_dim, out_size=num_heads * head_dim, use_bias=False) + self._rel_pos_embedding = RelativePositionEmbedding(out_dim) - self._rel_pos_encoder = rel_pos_encoder self._input_layernorm = None - if input_layernorm: self._input_layernorm = torch.nn.LayerNorm(in_dim) @@ -75,10 +112,8 @@ class RelativeMultiHeadAttention(nn.Module): # Add previous memory chunk (as const, w/o gradient) to input. # Tau (number of (prev) time slices in each memory chunk). - Tau = list(memory.shape)[1] if memory is not None else 0 - if memory is not None: - memory.requires_grad_(False) - inputs = torch.cat((memory, inputs), dim=1) + Tau = list(memory.shape)[1] + inputs = torch.cat((memory.detach(), inputs), dim=1) # Apply the Layer-Norm. if self._input_layernorm is not None: @@ -91,11 +126,11 @@ class RelativeMultiHeadAttention(nn.Module): queries = queries[:, -T:] queries = torch.reshape(queries, [-1, T, H, d]) - keys = torch.reshape(keys, [-1, T + Tau, H, d]) - values = torch.reshape(values, [-1, T + Tau, H, d]) + keys = torch.reshape(keys, [-1, Tau + T, H, d]) + values = torch.reshape(values, [-1, Tau + T, H, d]) - R = self._pos_proj(self._rel_pos_encoder) - R = torch.reshape(R, [T + Tau, H, d]) + R = self._pos_proj(self._rel_pos_embedding(Tau + T)) + R = torch.reshape(R, [Tau + T, H, d]) # b=batch # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) @@ -108,10 +143,11 @@ class RelativeMultiHeadAttention(nn.Module): # causal mask of the same length as the sequence mask = sequence_mask( - torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype) + torch.arange(Tau + 1, Tau + T + 1), + dtype=score.dtype).to(score.device) mask = mask[None, :, :, None] - masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.) + masked_score = score * mask + 1e30 * (mask.float() - 1.) wmat = nn.functional.softmax(masked_score, dim=2) out = torch.einsum("bijh,bjhd->bihd", wmat, values) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index f17d60e06..af4fa512c 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -259,10 +259,8 @@ def build_eager_tf_policy(name, self._update_model_inference_view_requirements_from_init_state() self.exploration = self._create_exploration() - self._state_in = [ - tf.convert_to_tensor([s]) - for s in self.model.get_initial_state() - ] + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 # Combine view_requirements for Model and Policy. self.view_requirements.update( @@ -375,6 +373,8 @@ def build_eager_tf_policy(name, # TODO: remove python side effect to cull sources of bugs. self._is_training = False + self._is_recurrent = \ + state_batches is not None and state_batches != [] self._state_in = state_batches or [] if not tf1.executing_eagerly(): @@ -552,11 +552,11 @@ def build_eager_tf_policy(name, @override(Policy) def is_recurrent(self): - return len(self._state_in) > 0 + return self._is_recurrent @override(Policy) def num_state_tensors(self): - return len(self._state_in) + return len(self._state_inputs) @override(Policy) def get_initial_state(self): diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index a4f5e12b2..34e7da360 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -170,9 +170,6 @@ def build_tf_policy( mixins (Optional[List[type]]): Optional list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher precedence than the DynamicTFPolicy class. - view_requirements_fn (Callable[[Policy], - Dict[str, ViewRequirement]]): An optional callable to retrieve - additional train view requirements for this policy. get_batch_divisibility_req (Optional[Callable[[Policy], int]]): Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1. diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index c27a7603d..10e875d50 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -110,6 +110,8 @@ class TorchPolicy(Policy): logger.info("TorchPolicy running on CPU.") self.device = torch.device("cpu") self.model = model.to(self.device) + self._state_inputs = self.model.get_initial_state() + self._is_recurrent = len(self._state_inputs) > 0 # Auto-update model's inference view requirements, if recurrent. self._update_model_inference_view_requirements_from_init_state() # Combine view_requirements for Model and Policy. @@ -203,6 +205,11 @@ class TorchPolicy(Policy): Tuple: - actions, state_out, extra_fetches, logp. """ + self._is_recurrent = state_batches is not None and state_batches != [] + # Switch to eval mode. + if self.model: + self.model.eval() + if self.action_sampler_fn: action_dist = dist_inputs = None state_out = state_batches @@ -325,6 +332,9 @@ class TorchPolicy(Policy): @DeveloperAPI def learn_on_batch( self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]: + # Set Model to train mode. + if self.model: + self.model.train() # Callback handling. self.callbacks.on_learn_on_batch( policy=self, train_batch=postprocessed_batch) @@ -354,6 +364,9 @@ class TorchPolicy(Policy): view_requirements=self.view_requirements, ) + # Mark the batch as "is_training" so the Model can use this + # information. + postprocessed_batch["is_training"] = True train_batch = self._lazy_tensor_dict(postprocessed_batch) # Calculate the actual policy loss. @@ -448,7 +461,7 @@ class TorchPolicy(Policy): @override(Policy) @DeveloperAPI def is_recurrent(self) -> bool: - return len(self.model.get_initial_state()) > 0 + return self._is_recurrent @override(Policy) @DeveloperAPI From b52cce6632dc0a91c72e2fae6312eac4e2ad4d0d Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 21 Dec 2020 20:39:13 -0600 Subject: [PATCH 64/88] [serve] Refactor SystemState into EndpointState and BackendState (#13018) --- python/ray/serve/controller.py | 198 +++++++++++++++++---------------- 1 file changed, 100 insertions(+), 98 deletions(-) diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 17a543048..4a4b754ff 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -93,33 +93,16 @@ class BackendInfo(BaseModel): arbitrary_types_allowed = True -@dataclass -class SystemState: - backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict) - traffic_policies: Dict[EndpointTag, TrafficPolicy] = field( - default_factory=dict) - routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field( - default_factory=dict) +class EndpointState: + def __init__(self, checkpoint: bytes = None): + self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() + self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() - backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict) - traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict) - route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict) + if checkpoint is not None: + self.routes, self.traffic_policies = pickle.loads(checkpoint) - def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: - return { - tag: info.backend_config - for tag, info in self.backends.items() - } - - def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: - return self.backends.get(backend_tag) - - def add_backend(self, - backend_tag: BackendTag, - backend_info: BackendInfo, - goal_id: GoalId = 0) -> None: - self.backends[backend_tag] = backend_info - self.backend_goal_ids = goal_id + def checkpoint(self): + return pickle.dumps((self.routes, self.traffic_policies)) def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: endpoints = {} @@ -141,6 +124,32 @@ class SystemState: return endpoints +class BackendState: + def __init__(self, checkpoint: bytes = None): + self.backends: Dict[BackendTag, BackendInfo] = dict() + + if checkpoint is not None: + self.backends = pickle.loads(checkpoint) + + def checkpoint(self): + return pickle.dumps(self.backends) + + def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: + return { + tag: info.backend_config + for tag, info in self.backends.items() + } + + def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: + return self.backends.get(backend_tag) + + def add_backend(self, + backend_tag: BackendTag, + backend_info: BackendInfo, + goal_id: GoalId = 0) -> None: + self.backends[backend_tag] = backend_info + + @dataclass class ActorStateReconciler: controller_name: str = field(init=True) @@ -192,7 +201,7 @@ class ActorStateReconciler: for replica_dict in self.backend_replicas.values() ])) - async def _start_backend_replica(self, current_state: SystemState, + async def _start_backend_replica(self, backend_state: BackendState, backend_tag: BackendTag, replica_tag: ReplicaTag) -> ActorHandle: """Start a replica and return its actor handle. @@ -210,7 +219,7 @@ class ActorStateReconciler: except ValueError: logger.debug("Starting replica '{}' for backend '{}'.".format( replica_tag, backend_tag)) - backend_info = current_state.get_backend(backend_tag) + backend_info = backend_state.get_backend(backend_tag) replica_handle = ray.remote(backend_info.worker_class).options( name=replica_name, @@ -284,12 +293,12 @@ class ActorStateReconciler: self.backend_replicas_to_stop[backend_tag].append(replica_tag) async def _enqueue_pending_scale_changes_loop(self, - current_state: SystemState): + backend_state: BackendState): for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ items(): for replica_tag in replicas_to_create: replica_handle = await self._start_backend_replica( - current_state, backend_tag, replica_tag) + backend_state, backend_tag, replica_tag) ready_future = replica_handle.ready.remote().as_future() self.currently_starting_replicas[ready_future] = ( backend_tag, replica_tag, replica_handle) @@ -456,19 +465,19 @@ class ActorStateReconciler: replica_tag] = ray.get_actor(replica_name) async def _recover_from_checkpoint( - self, current_state: SystemState, controller: "ServeController" + self, backend_state: BackendState, controller: "ServeController" ) -> Dict[BackendTag, BasicAutoscalingPolicy]: self._recover_actor_handles() autoscaling_policies = dict() - for backend, info in current_state.backends.items(): + for backend, info in backend_state.backends.items(): metadata = info.backend_config.internal_metadata if metadata.autoscaling_config is not None: autoscaling_policies[backend] = BasicAutoscalingPolicy( backend, metadata.autoscaling_config) # Start/stop any pending backend replicas. - await self._enqueue_pending_scale_changes_loop(current_state) + await self._enqueue_pending_scale_changes_loop(backend_state) await self.backend_control_loop() return autoscaling_policies @@ -482,8 +491,8 @@ class FutureResult: @dataclass class Checkpoint: - goal_state: SystemState - current_state: SystemState + endpoint_state_checkpoint: bytes + backend_state_checkpoint: bytes reconciler: ActorStateReconciler # TODO(ilr) Rename reconciler to PendingState inflight_reqs: Dict[uuid4, FutureResult] @@ -523,13 +532,6 @@ class ServeController: detached: bool = False): # Used to read/write checkpoints. self.kv_store = RayInternalKVStore(namespace=controller_name) - # Current State - self.current_state = SystemState() - # Goal State - # TODO(ilr) This is currently *unused* until the refactor of the serve - # controller. - self.goal_state = SystemState() - # ActorStateReconciler self.actor_reconciler = ActorStateReconciler(controller_name, detached) # backend -> AutoscalingPolicy @@ -556,10 +558,17 @@ class ServeController: self.inflight_results: Dict[UUID, asyncio.Event] = dict() self._serializable_inflight_results: Dict[UUID, FutureResult] = dict() - checkpoint = self.kv_store.get(CHECKPOINT_KEY) - if checkpoint is None: + checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY) + if checkpoint_bytes is None: logger.debug("No checkpoint found") + self.backend_state = BackendState() + self.endpoint_state = EndpointState() else: + checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) + self.backend_state = BackendState( + checkpoint=checkpoint.backend_state_checkpoint) + self.endpoint_state = EndpointState( + checkpoint=checkpoint.endpoint_state_checkpoint) await self._recover_from_checkpoint(checkpoint) # NOTE(simon): Currently we do all-to-all broadcast. This means @@ -618,17 +627,17 @@ class ServeController: def notify_traffic_policies_changed(self): self.long_poll_host.notify_changed( LongPollKey.TRAFFIC_POLICIES, - self.current_state.traffic_policies, + self.endpoint_state.traffic_policies, ) def notify_backend_configs_changed(self): self.long_poll_host.notify_changed( LongPollKey.BACKEND_CONFIGS, - self.current_state.get_backend_configs()) + self.backend_state.get_backend_configs()) def notify_route_table_changed(self): self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE, - self.current_state.routes) + self.endpoint_state.routes) async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): """Proxy long pull client's listen request. @@ -652,19 +661,19 @@ class ServeController: start = time.time() checkpoint = pickle.dumps( - Checkpoint(self.goal_state, self.current_state, - self.actor_reconciler, + Checkpoint(self.endpoint_state.checkpoint(), + self.backend_state.checkpoint(), self.actor_reconciler, self._serializable_inflight_results)) self.kv_store.put(CHECKPOINT_KEY, checkpoint) - logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start)) + logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start)) if random.random( ) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached: logger.warning("Intentionally crashing after checkpoint") os._exit(0) - async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None: + async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None: """Recover the instance state from the provided checkpoint. This should be called in the constructor to ensure that the internal @@ -679,12 +688,9 @@ class ServeController: start = time.time() logger.info("Recovering from checkpoint") - restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) - self.current_state = restored_checkpoint.current_state + self.actor_reconciler = checkpoint.reconciler - self.actor_reconciler = restored_checkpoint.reconciler - - self._serializable_inflight_results = restored_checkpoint.inflight_reqs + self._serializable_inflight_results = checkpoint.inflight_reqs for uuid, fut_result in self._serializable_inflight_results.items(): self._create_event_with_result(fut_result.requested_goal, uuid) @@ -704,7 +710,7 @@ class ServeController: async def finish_recover_from_checkpoint(): assert self.write_lock.locked() self.autoscaling_policies = await self.actor_reconciler.\ - _recover_from_checkpoint(self.current_state, self) + _recover_from_checkpoint(self.backend_state, self) self.write_lock.release() logger.info( "Recovered from checkpoint in {:.3f}s".format(time.time() - @@ -714,7 +720,7 @@ class ServeController: asyncio.get_event_loop().create_task(finish_recover_from_checkpoint()) async def do_autoscale(self) -> None: - for backend, info in self.current_state.backends.items(): + for backend, info in self.backend_state.backends.items(): if backend not in self.autoscaling_policies: continue @@ -726,9 +732,6 @@ class ServeController: async def reconcile_current_and_goal_backends(self): pass - # backends_to_delete = set( - # self.current_state.backends.keys()).difference( - # self.goal_state.backends.keys()) async def run_control_loop(self) -> None: while True: @@ -750,15 +753,15 @@ class ServeController: def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: """Returns a dictionary of backend tag to backend config.""" - return self.current_state.get_backend_configs() + return self.backend_state.get_backend_configs() def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]: """Returns a dictionary of backend tag to backend config.""" - return self.current_state.get_endpoints() + return self.endpoint_state.get_endpoints() async def _set_traffic(self, endpoint_name: str, traffic_dict: Dict[str, float]) -> UUID: - if endpoint_name not in self.current_state.get_endpoints(): + if endpoint_name not in self.endpoint_state.get_endpoints(): raise ValueError("Attempted to assign traffic for an endpoint '{}'" " that is not registered.".format(endpoint_name)) @@ -766,13 +769,13 @@ class ServeController: dict), "Traffic policy must be a dictionary." for backend in traffic_dict: - if self.current_state.get_backend(backend) is None: + if self.backend_state.get_backend(backend) is None: raise ValueError( "Attempted to assign traffic to a backend '{}' that " "is not registered.".format(backend)) traffic_policy = TrafficPolicy(traffic_dict) - self.current_state.traffic_policies[endpoint_name] = traffic_policy + self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy return_uuid = self._create_event_with_result({ endpoint_name: traffic_policy @@ -795,20 +798,21 @@ class ServeController: proportion: float) -> UUID: """Shadow traffic from the endpoint to the backend.""" async with self.write_lock: - if endpoint_name not in self.current_state.get_endpoints(): + if endpoint_name not in self.endpoint_state.get_endpoints(): raise ValueError("Attempted to shadow traffic from an " "endpoint '{}' that is not registered." .format(endpoint_name)) - if self.current_state.get_backend(backend_tag) is None: + if self.backend_state.get_backend(backend_tag) is None: raise ValueError( "Attempted to shadow traffic to a backend '{}' that " "is not registered.".format(backend_tag)) - self.current_state.traffic_policies[endpoint_name].set_shadow( + self.endpoint_state.traffic_policies[endpoint_name].set_shadow( backend_tag, proportion) - traffic_policy = self.current_state.traffic_policies[endpoint_name] + traffic_policy = self.endpoint_state.traffic_policies[ + endpoint_name] return_uuid = self._create_event_with_result({ endpoint_name: traffic_policy @@ -839,10 +843,10 @@ class ServeController: # TODO(edoakes): move this to client side. err_prefix = "Cannot create endpoint." - if route in self.current_state.routes: + if route in self.endpoint_state.routes: # Ensures this method is idempotent - if self.current_state.routes[route] == (endpoint, methods): + if self.endpoint_state.routes[route] == (endpoint, methods): return else: @@ -850,7 +854,7 @@ class ServeController: "{} Route '{}' is already registered.".format( err_prefix, route)) - if endpoint in self.current_state.get_endpoints(): + if endpoint in self.endpoint_state.get_endpoints(): raise ValueError( "{} Endpoint '{}' is already registered.".format( err_prefix, endpoint)) @@ -859,7 +863,7 @@ class ServeController: "Registering route '{}' to endpoint '{}' with methods '{}'.". format(route, endpoint, methods)) - self.current_state.routes[route] = (endpoint, methods) + self.endpoint_state.routes[route] = (endpoint, methods) # NOTE(edoakes): checkpoint is written in self._set_traffic. return_uuid = await self._set_traffic(endpoint, traffic_dict) @@ -876,7 +880,7 @@ class ServeController: # This method must be idempotent. We should validate that the # specified endpoint exists on the client. for route, (route_endpoint, - _) in self.current_state.routes.items(): + _) in self.endpoint_state.routes.items(): if route_endpoint == endpoint: route_to_delete = route break @@ -885,11 +889,11 @@ class ServeController: return # Remove the routing entry. - del self.current_state.routes[route_to_delete] + del self.endpoint_state.routes[route_to_delete] # Remove the traffic policy entry if it exists. - if endpoint in self.current_state.traffic_policies: - del self.current_state.traffic_policies[endpoint] + if endpoint in self.endpoint_state.traffic_policies: + del self.endpoint_state.traffic_policies[endpoint] return_uuid = self._create_event_with_result({ route_to_delete: None, @@ -908,7 +912,7 @@ class ServeController: """Register a new backend under the specified tag.""" async with self.write_lock: # Ensures this method is idempotent. - backend_info = self.current_state.get_backend(backend_tag) + backend_info = self.backend_state.get_backend(backend_tag) if backend_info is not None: if (backend_info.backend_config == backend_config and backend_info.replica_config == replica_config): @@ -923,7 +927,7 @@ class ServeController: worker_class=backend_replica, backend_config=backend_config, replica_config=replica_config) - self.current_state.add_backend(backend_tag, backend_info) + self.backend_state.add_backend(backend_tag, backend_info) metadata = backend_config.internal_metadata if metadata.autoscaling_config is not None: self.autoscaling_policies[ @@ -933,10 +937,10 @@ class ServeController: try: # This call should be to run control loop self.actor_reconciler._scale_backend_replicas( - self.current_state.backends, backend_tag, + self.backend_state.backends, backend_tag, backend_config.num_replicas) except RayServeException as e: - del self.current_state.backends[backend_tag] + del self.backend_state.backends[backend_tag] raise e return_uuid = self._create_event_with_result({ @@ -947,7 +951,7 @@ class ServeController: # crash while making the change. self._checkpoint() await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.current_state) + self.backend_state) await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -961,11 +965,11 @@ class ServeController: async with self.write_lock: # This method must be idempotent. We should validate that the # specified backend exists on the client. - if self.current_state.get_backend(backend_tag) is None: + if self.backend_state.get_backend(backend_tag) is None: return # Check that the specified backend isn't used by any endpoints. - for endpoint, traffic_policy in self.current_state.\ + for endpoint, traffic_policy in self.endpoint_state.\ traffic_policies.items(): if (backend_tag in traffic_policy.traffic_dict or backend_tag in traffic_policy.shadow_dict): @@ -975,17 +979,15 @@ class ServeController: "again.".format(backend_tag, endpoint)) # Scale its replicas down to 0. This will also remove the backend - # from self.current_state.backends and + # from self.backend_state.backends and # self.actor_reconciler.backend_replicas. - self.goal_state.backends[backend_tag] = None - # This should be a call to the control loop self.actor_reconciler._scale_backend_replicas( - self.current_state.backends, backend_tag, 0) + self.backend_state.backends, backend_tag, 0) # Remove the backend's metadata. - del self.current_state.backends[backend_tag] + del self.backend_state.backends[backend_tag] if backend_tag in self.autoscaling_policies: del self.autoscaling_policies[backend_tag] @@ -998,7 +1000,7 @@ class ServeController: # after pushing the update. self._checkpoint() await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.current_state) + self.backend_state) await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -1008,24 +1010,24 @@ class ServeController: config_options: BackendConfig) -> UUID: """Set the config for the specified backend.""" async with self.write_lock: - assert (self.current_state.get_backend(backend_tag) + assert (self.backend_state.get_backend(backend_tag) ), "Backend {} is not registered.".format(backend_tag) assert isinstance(config_options, BackendConfig) - stored_backend_config = self.current_state.get_backend( + stored_backend_config = self.backend_state.get_backend( backend_tag).backend_config backend_config = stored_backend_config.copy( update=config_options.dict(exclude_unset=True)) backend_config._validate_complete() - self.current_state.get_backend( + self.backend_state.get_backend( backend_tag).backend_config = backend_config - backend_info = self.current_state.get_backend(backend_tag) + backend_info = self.backend_state.get_backend(backend_tag) # Scale the replicas with the new configuration. # This should be to run the control loop self.actor_reconciler._scale_backend_replicas( - self.current_state.backends, backend_tag, + self.backend_state.backends, backend_tag, backend_config.num_replicas) return_uuid = self._create_event_with_result({ @@ -1040,7 +1042,7 @@ class ServeController: # (particularly for setting max_batch_size). await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.current_state) + self.backend_state) await self.actor_reconciler.backend_control_loop() self.notify_replica_handles_changed() @@ -1049,9 +1051,9 @@ class ServeController: def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig: """Get the current config for the specified backend.""" - assert (self.current_state.get_backend(backend_tag) + assert (self.backend_state.get_backend(backend_tag) ), "Backend {} is not registered.".format(backend_tag) - return self.current_state.get_backend(backend_tag).backend_config + return self.backend_state.get_backend(backend_tag).backend_config def get_http_config(self): """Return the HTTP proxy configuration.""" From ea8d782be1554262f2dde9d7c84d8ac25478a35d Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Mon, 21 Dec 2020 19:17:51 -0800 Subject: [PATCH 65/88] [core] Pull Manager exponential backoff (#13024) --- src/ray/object_manager/pull_manager.cc | 7 ++++++- src/ray/object_manager/pull_manager.h | 5 +++-- src/ray/object_manager/test/pull_manager_test.cc | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/ray/object_manager/pull_manager.cc b/src/ray/object_manager/pull_manager.cc index 7632c5c7b..8ced0f51b 100644 --- a/src/ray/object_manager/pull_manager.cc +++ b/src/ray/object_manager/pull_manager.cc @@ -111,6 +111,10 @@ void PullManager::TryPull(const ObjectID &object_id) { RAY_LOG(DEBUG) << "Sending pull request from " << self_node_id_ << " to " << node_id << " of object " << object_id; + const auto time = get_time_(); + auto &request = it->second; + auto retry_timeout_len = (pull_timeout_ms_ / 1000.) * (1UL << request.num_retries); + request.next_pull_time = time + retry_timeout_len; send_pull_request_(object_id, node_id); } @@ -131,7 +135,8 @@ void PullManager::Tick() { const auto time = get_time_(); if (time >= request.next_pull_time) { TryPull(object_id); - request.next_pull_time = time + pull_timeout_ms_ / 1000; + // Bound the retry time at 10 * 1024 seconds. + request.num_retries = std::min(request.num_retries + 1, 10); } } } diff --git a/src/ray/object_manager/pull_manager.h b/src/ray/object_manager/pull_manager.h index 023f72d0e..f312af17a 100644 --- a/src/ray/object_manager/pull_manager.h +++ b/src/ray/object_manager/pull_manager.h @@ -81,9 +81,10 @@ class PullManager { /// A helper structure for tracking information about each ongoing object pull. struct PullRequest { PullRequest(double first_retry_time) - : client_locations(), next_pull_time(first_retry_time) {} + : client_locations(), next_pull_time(first_retry_time), num_retries(0) {} std::vector client_locations; double next_pull_time; + uint8_t num_retries; }; /// See the constructor's arguments. @@ -92,7 +93,7 @@ class PullManager { const std::function send_pull_request_; const RestoreSpilledObjectCallback restore_spilled_object_; const std::function get_time_; - int pull_timeout_ms_; + uint64_t pull_timeout_ms_; /// The objects that this object manager is currently trying to fetch from /// remote object managers. diff --git a/src/ray/object_manager/test/pull_manager_test.cc b/src/ray/object_manager/test/pull_manager_test.cc index 90f34048c..fb7b1c1c2 100644 --- a/src/ray/object_manager/test/pull_manager_test.cc +++ b/src/ray/object_manager/test/pull_manager_test.cc @@ -125,7 +125,7 @@ TEST_F(PullManagerTest, TestRetryTimer) { ASSERT_EQ(num_send_pull_request_calls_, 1); ASSERT_EQ(num_restore_spilled_object_calls_, 0); - for (; fake_time_ <= 127 * 10; fake_time_ += 0.1) { + for (; fake_time_ <= 127 * 10; fake_time_ += 1.) { pull_manager_.Tick(); } @@ -140,7 +140,7 @@ TEST_F(PullManagerTest, TestRetryTimer) { // OnLocationChange also doesn't count towards the retry timer. // To the casual observer, this may seem off-by-one, but this is due to floating point // error (0.1 + 0.1 ... 10k times > 10 == True) - ASSERT_EQ(num_send_pull_request_calls_, 127 * 2); + ASSERT_EQ(num_send_pull_request_calls_, 1 + 7 + 127); ASSERT_EQ(num_restore_spilled_object_calls_, 0); pull_manager_.CancelPull(obj1); From 01faeabc17dc007f3781d11612fd5aa4dd558931 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 22 Dec 2020 15:28:07 +0100 Subject: [PATCH 66/88] [RLlib] Issue 12789: RLlib throws the warning "The given NumPy array is not writeable" (#12793) --- rllib/utils/torch_ops.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 9b9c52c37..53c26dc4d 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -1,6 +1,7 @@ from gym.spaces import Discrete, MultiDiscrete import numpy as np import tree +import warnings from ray.rllib.models.repeated_values import RepeatedValues from ray.rllib.utils.framework import try_import_torch @@ -62,7 +63,13 @@ def convert_to_torch_tensor(x, device=None): return RepeatedValues( tree.map_structure(mapping, item.values), item.lengths, item.max_len) - tensor = torch.from_numpy(np.asarray(item)) + # Non-writable numpy-arrays will cause PyTorch warning. + if isinstance(item, np.ndarray) and item.flags.writeable is False: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + tensor = torch.from_numpy(item) + else: + tensor = torch.from_numpy(np.asarray(item)) # Floatify all float64 tensors. if tensor.dtype == torch.double: tensor = tensor.float() From a79c9fcac3f1a6b3211c04dcf20c9c2f62a50e4c Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Tue, 22 Dec 2020 11:05:33 -0800 Subject: [PATCH 67/88] [release tests] test_many_tasks fix (#12984) --- release/stress_tests/autoscaler-cluster.yaml | 108 ++++++++++ .../stress_tests/workloads/test_many_tasks.py | 185 ++++++++++-------- 2 files changed, 211 insertions(+), 82 deletions(-) create mode 100644 release/stress_tests/autoscaler-cluster.yaml diff --git a/release/stress_tests/autoscaler-cluster.yaml b/release/stress_tests/autoscaler-cluster.yaml new file mode 100644 index 000000000..a75b04fae --- /dev/null +++ b/release/stress_tests/autoscaler-cluster.yaml @@ -0,0 +1,108 @@ +#################################################################### +# All nodes in this cluster will auto-terminate in 1 hour +#################################################################### + +# An unique identifier for the head node and workers of this cluster. +cluster_name: autoscaler-stress-test-1.1.0-alex + +# The minimum number of workers nodes to launch in addition to the head +# node. This number should be >= 0. +min_workers: 100 + +# The maximum number of workers nodes to launch in addition to the head +# node. This takes precedence over min_workers. +max_workers: 100 + +# The autoscaler will scale up the cluster to this target fraction of resource +# usage. For example, if a cluster of 10 nodes is 100% busy and +# target_utilization is 0.8, it would resize the cluster to 13. This fraction +# can be decreased to increase the aggressiveness of upscaling. +# This value must be less than 1.0 for scaling to happen. +target_utilization_fraction: 0.8 + +# If a node is idle for this many minutes, it will be removed. +idle_timeout_minutes: 5 + +# Cloud-provider specific configuration. +provider: + type: aws + region: us-west-1 + availability_zone: us-west-1a + cache_stopped_nodes: False + +# How Ray will authenticate with newly launched nodes. +auth: + ssh_user: ubuntu +# By default Ray creates a new private keypair, but you can also use your own. +# If you do so, make sure to also set "KeyName" in the head and worker node +# configurations below. +# ssh_private_key: /path/to/your/key.pem + +# Provider-specific config for the head node, e.g. instance type. By default +# Ray will auto-configure unspecified fields such as SubnetId and KeyName. +# For more documentation on available fields, see: +# http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances +head_node: + InstanceType: m4.16xlarge + ImageId: ami-0cc472544ce594a19 # Custom ami + + # Set primary volume to 25 GiB + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 100 + + # Additional options in the boto docs. + +# Provider-specific config for worker nodes, e.g. instance type. By default +# Ray will auto-configure unspecified fields such as SubnetId and KeyName. +# For more documentation on available fields, see: +# http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances +worker_nodes: + InstanceType: m4.large + ImageId: ami-0cc472544ce594a19 # Custom ami + + # Set primary volume to 25 GiB + BlockDeviceMappings: + - DeviceName: /dev/sda1 + Ebs: + VolumeSize: 100 + + # Run workers on spot by default. Comment this out to use on-demand. + InstanceMarketOptions: + MarketType: spot + # Additional options can be found in the boto docs, e.g. + # SpotOptions: + # MaxPrice: MAX_HOURLY_PRICE + + # Additional options in the boto docs. + +# List of shell commands to run to set up nodes. +setup_commands: + # Uncomment these if you want to build ray from source. + # - sudo apt-get -qq update + # - sudo apt-get install -y build-essential curl unzip + # # Build Ray. + # - git clone https://github.com/ray-project/ray || true + # - ray/ci/travis/install-bazel.sh + - pip install -U pip + - pip install terminado + - pip install boto3==1.4.8 cython==0.29.0 + # - cd ray/python; git checkout master; git pull; pip install -e . --verbose + - pip install -U pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-1.2.0.dev0-cp38-cp38-manylinux2014_x86_64.whl + +# Custom commands that will be run on the head node after common setup. +head_setup_commands: [] + +# Custom commands that will be run on worker nodes after common setup. +worker_setup_commands: [] + +# Command to start ray on the head node. You don't need to change this. +head_start_ray_commands: + - ray stop + - ulimit -n 65536; ray start --head --port=6379 --autoscaling-config=~/ray_bootstrap_config.yaml + +# Command to start ray on worker nodes. You don't need to change this. +worker_start_ray_commands: + - ray stop + - ulimit -n 65536; ray start --address=$RAY_HEAD_IP:6379 --num-gpus=100 diff --git a/release/stress_tests/workloads/test_many_tasks.py b/release/stress_tests/workloads/test_many_tasks.py index 1abebfb51..bc95b0407 100644 --- a/release/stress_tests/workloads/test_many_tasks.py +++ b/release/stress_tests/workloads/test_many_tasks.py @@ -45,72 +45,91 @@ class Actor(object): # Stage 0: Submit a bunch of small tasks with large returns. -stage_0_iterations = [] -start_time = time.time() -logger.info("Submitting many tasks with large returns.") -for i in range(10): - iteration_start = time.time() - logger.info("Iteration %s", i) - ray.get([f.remote(1000000) for _ in range(1000)]) - stage_0_iterations.append(time.time() - iteration_start) +def stage0(): + stage_0_iterations = [] + start_time = time.time() + logger.info("Submitting many tasks with large returns.") + for i in range(10): + iteration_start = time.time() + logger.info("Iteration %s", i) + ray.get([f.remote(1000000) for _ in range(1000)]) + stage_0_iterations.append(time.time() - iteration_start) -stage_0_time = time.time() - start_time + return time.time() - start_time + + +stage_0_time = stage0() logger.info("Finished stage 0 after %s seconds.", stage_0_time) -# Stage 1: Launch a bunch of tasks. -stage_1_iterations = [] -start_time = time.time() -logger.info("Submitting many tasks.") -for i in range(10): - iteration_start = time.time() - logger.info("Iteration %s", i) - ray.get([f.remote(0) for _ in range(100000)]) - stage_1_iterations.append(time.time() - iteration_start) -stage_1_time = time.time() - start_time +# Stage 1: Launch a bunch of tasks. +def stage1(): + stage_1_iterations = [] + start_time = time.time() + logger.info("Submitting many tasks.") + for i in range(10): + iteration_start = time.time() + logger.info("Iteration %s", i) + ray.get([f.remote(0) for _ in range(100000)]) + stage_1_iterations.append(time.time() - iteration_start) + + return time.time() - start_time, stage_1_iterations + + +stage_1_time, stage_1_iterations = stage1() logger.info("Finished stage 1 after %s seconds.", stage_1_time) + # Launch a bunch of tasks, each with a bunch of dependencies. TODO(rkn): This # test starts to fail if we increase the number of tasks in the inner loop from # 500 to 1000. (approximately 615 seconds) -stage_2_iterations = [] -start_time = time.time() -logger.info("Submitting tasks with many dependencies.") -x_ids = [] -for _ in range(5): - iteration_start = time.time() - for i in range(20): - logger.info("Iteration %s. Cumulative time %s seconds", i, - time.time() - start_time) - x_ids = [f.remote(0, *x_ids) for _ in range(500)] - ray.get(x_ids) - stage_2_iterations.append(time.time() - iteration_start) - logger.info("Finished after %s seconds.", time.time() - start_time) +def stage2(): + stage_2_iterations = [] + start_time = time.time() + logger.info("Submitting tasks with many dependencies.") + x_ids = [] + for _ in range(5): + iteration_start = time.time() + for i in range(20): + logger.info("Iteration %s. Cumulative time %s seconds", i, + time.time() - start_time) + x_ids = [f.remote(0, *x_ids) for _ in range(500)] + ray.get(x_ids) + stage_2_iterations.append(time.time() - iteration_start) + logger.info("Finished after %s seconds.", time.time() - start_time) + return time.time() - start_time, stage_2_iterations -stage_2_time = time.time() - start_time + +stage_2_time, stage_2_iterations = stage2() logger.info("Finished stage 2 after %s seconds.", stage_2_time) -# Create a bunch of actors. -start_time = time.time() -logger.info("Creating %s actors.", num_remote_cpus) -actors = [Actor.remote() for _ in range(num_remote_cpus)] -stage_3_creation_time = time.time() - start_time -logger.info("Finished stage 3 actor creation in %s seconds.", - stage_3_creation_time) -# Submit a bunch of small tasks to each actor. (approximately 1070 seconds) -start_time = time.time() -logger.info("Submitting many small actor tasks.") -for N in [1000, 100000]: - x_ids = [] - for i in range(N): - x_ids = [a.method.remote(0) for a in actors] - if i % 100 == 0: - logger.info("Submitted {}".format(i * len(actors))) - ray.get(x_ids) -stage_3_time = time.time() - start_time +# Create a bunch of actors. +def stage3(): + start_time = time.time() + logger.info("Creating %s actors.", num_remote_cpus) + actors = [Actor.remote() for _ in range(num_remote_cpus)] + stage_3_creation_time = time.time() - start_time + logger.info("Finished stage 3 actor creation in %s seconds.", + stage_3_creation_time) + + # Submit a bunch of small tasks to each actor. (approximately 1070 seconds) + start_time = time.time() + logger.info("Submitting many small actor tasks.") + for N in [1000, 100000]: + x_ids = [] + for i in range(N): + x_ids = [a.method.remote(0) for a in actors] + if i % 100 == 0: + logger.info("Submitted {}".format(i * len(actors))) + ray.get(x_ids) + return time.time() - start_time, stage_3_creation_time + + +stage_3_time, stage_3_creation_time = stage3() logger.info("Finished stage 3 in %s seconds.", stage_3_time) + # This tests https://github.com/ray-project/ray/issues/10150. The only way to # integration test this is via performance. The goal is to fill up the cluster # so that all tasks can be run, but spillback is required. Since the driver @@ -119,38 +138,39 @@ logger.info("Finished stage 3 in %s seconds.", stage_3_time) # task will require O(N) queries. Since we limit the number of inflight # requests, we will run into head of line blocking and we should be able to # measure this timing. -num_tasks = int(ray.cluster_resources()["GPU"]) -logger.info(f"Scheduling many tasks for spillback.") +def stage4(): + num_tasks = int(ray.cluster_resources()["GPU"]) + logger.info(f"Scheduling many tasks for spillback.") + + @ray.remote(num_gpus=1) + def func(t): + if t % 100 == 0: + logger.info(f"[spillback test] {t}/{num_tasks}") + start = time.perf_counter() + time.sleep(1) + end = time.perf_counter() + return start, end, ray.worker.global_worker.node.unique_id + + results = ray.get([func.remote(i) for i in range(num_tasks)]) + + host_to_start_times = defaultdict(list) + for start, end, host in results: + host_to_start_times[host].append(start) + + spreads = [] + for host in host_to_start_times: + last = max(host_to_start_times[host]) + first = min(host_to_start_times[host]) + spread = last - first + spreads.append(spread) + logger.info(f"Spread: {last - first}\tLast: {last}\tFirst: {first}") + + avg_spread = sum(spreads) / len(spreads) + logger.info(f"Avg spread: {sum(spreads)/len(spreads)}") + return avg_spread -@ray.remote(num_gpus=1) -def func(t): - if t % 100 == 0: - logger.info(f"[spillback test] {t}/{num_tasks}") - start = time.perf_counter() - time.sleep(1) - end = time.perf_counter() - return start, end, ray.worker.global_worker.node.unique_id - - -results = ray.get([func.remote(i) for i in range(num_tasks)]) - -host_to_start_times = defaultdict(list) -for start, end, host in results: - host_to_start_times[host].append(start) - -spreads = [] -for host in host_to_start_times: - last = max(host_to_start_times[host]) - first = min(host_to_start_times[host]) - spread = last - first - spreads.append(spread) - logger.info(f"Spread: {last - first}\tLast: {last}\tFirst: {first}") - -# avg_spread ~ 115 with Ray 1.0 scheduler. ~695 with (buggy) 0.8.7 scheduler. -avg_spread = sum(spreads) / len(spreads) -logger.info(f"Avg spread: {sum(spreads)/len(spreads)}") - +stage_4_spread = stage4() print("Stage 0 results:") print("\tTotal time: {}".format(stage_0_time)) @@ -173,7 +193,8 @@ print("\tActor creation time: {}".format(stage_3_creation_time)) print("\tTotal time: {}".format(stage_3_time)) print("Stage 4 results:") -print(f"\tScheduling spread: {avg_spread}.") +# avg_spread ~ 115 with Ray 1.0 scheduler. ~695 with (buggy) 0.8.7 scheduler. +print(f"\tScheduling spread: {stage_4_spread}.") # TODO(rkn): The test below is commented out because it currently does not # pass. From 81d3cbaa77b698f930189c92ef2c2642551ee8af Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 22 Dec 2020 16:08:41 -0800 Subject: [PATCH 68/88] Add "beta" documentation for enabling object spilling manually (#13047) --- doc/source/advanced.rst | 36 ++++++++++++++++++++++++++++++++++++ doc/source/walkthrough.rst | 2 ++ 2 files changed, 38 insertions(+) diff --git a/doc/source/advanced.rst b/doc/source/advanced.rst index 742fbfa68..ea431bb24 100644 --- a/doc/source/advanced.rst +++ b/doc/source/advanced.rst @@ -392,3 +392,39 @@ To get information about the current available resource capacity of your cluster .. autofunction:: ray.available_resources :noindex: + +Object Spilling +--------------- + +Ray 1.2.0+ has *beta* support for spilling objects to external storage once the capacity +of the object store is used up. Please file a `GitHub issue `__ +if you encounter any problems with this new feature. Eventually, object spilling will be +enabled by default, but for now you need to enable it manually: + +To enable object spilling to the local filesystem (single node clusters only): + +.. code-block:: python + + ray.init( + _system_config={ + "automatic_object_spilling_enabled": True, + "object_spilling_config": json.dumps( + {"type": "filesystem", "params": {"directory_path": "/tmp/spill"}}, + ) + }, + ) + +To enable object spilling to remote storage (any URI supported by `smart_open `__): + +.. code-block:: python + + ray.init( + _system_config={ + "automatic_object_spilling_enabled": True, + "max_io_workers": 4, # More IO workers for remote storage. + "min_spilling_size": 100 * 1024 * 1024, # Spill at least 100MB at a time. + "object_spilling_config": json.dumps( + {"type": "smart_open", "params": {"uri": "s3:///bucket/path"}}, + ) + }, + ) diff --git a/doc/source/walkthrough.rst b/doc/source/walkthrough.rst index 2a4b91716..8680c1af3 100644 --- a/doc/source/walkthrough.rst +++ b/doc/source/walkthrough.rst @@ -417,6 +417,8 @@ actors. to the object ref returned by the put exists. This only applies to the specific ref returned by put, not refs in general or copies of that refs. +See also: `object spilling `__. + Remote Classes (Actors) ----------------------- From bc6826014454244195781a2b64e2302a628bc93d Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 22 Dec 2020 19:13:16 -0800 Subject: [PATCH 69/88] [Serve] Handle Bug Fixes (#12971) --- doc/source/serve/advanced.rst | 20 +++ python/ray/serve/api.py | 107 +++++++++++--- python/ray/serve/handle.py | 201 +++++++++++--------------- python/ray/serve/tests/test_api.py | 2 +- python/ray/serve/tests/test_handle.py | 30 ++++ 5 files changed, 221 insertions(+), 139 deletions(-) diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 3c3ca3940..fdf2684e7 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -267,6 +267,26 @@ That's it. Let's take a look at an example: .. literalinclude:: ../../../python/ray/serve/examples/doc/snippet_model_composition.py + +.. _serve-sync-async-handles: + +Sync and Async Handles +====================== + +Ray Serve offers two types of ``ServeHandle``. You can use the ``client.get_handle(..., sync=True|False)`` +flag to toggle between them. + +- When you set ``sync=True`` (the default), a synchronous handle is returned. + Calling ``handle.remote()`` should return a Ray ObjectRef. +- When you set ``sync=False``, an asyncio based handle is returned. You need to + Call it with ``await handle.remote()`` to return a Ray ObjectRef. To use ``await``, + you have to run ``client.get_handle`` and ``handle.remote`` in Python asyncio event loop. + +The async handle has performance advantage because it uses asyncio directly; as compared +to the sync handle, which talks to an asyncio event loop in a thread. To learn more about +the reasoning behind these, checkout our `architecture documentation <./architecture.html>`_. + + Monitoring ========== diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index ec1593654..3e4b53b28 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -4,22 +4,40 @@ import time from functools import wraps import os from uuid import UUID +import threading +from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union + +from ray.serve.context import TaskContext import ray from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT) from ray.serve.controller import ServeController -from ray.serve.handle import RayServeHandle +from ray.serve.handle import RayServeHandle, RayServeSyncHandle from ray.serve.utils import (block_until_http_ready, format_actor_name, get_random_letters, logger, get_conda_env_dir) from ray.serve.exceptions import RayServeException from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata from ray.serve.env import CondaEnv +from ray.serve.router import RequestMetadata, Router from ray.actor import ActorHandle -from typing import Any, Callable, Dict, List, Optional, Type, Union _INTERNAL_CONTROLLER_NAME = None +global_async_loop = None + + +def create_or_get_async_loop_in_thread(): + global global_async_loop + if global_async_loop is None: + global_async_loop = asyncio.new_event_loop() + thread = threading.Thread( + daemon=True, + target=global_async_loop.run_forever, + ) + thread.start() + return global_async_loop + def _set_internal_controller_name(name): global _INTERNAL_CONTROLLER_NAME @@ -36,6 +54,36 @@ def _ensure_connected(f: Callable) -> Callable: return check +class ThreadProxiedRouter: + def __init__(self, controller_handle, sync: bool): + self.router = Router(controller_handle) + + if sync: + self.async_loop = create_or_get_async_loop_in_thread() + asyncio.run_coroutine_threadsafe( + self.router.setup_in_async_loop(), + self.async_loop, + ) + else: + self.async_loop = asyncio.get_event_loop() + self.async_loop.create_task(self.router.setup_in_async_loop()) + + def _remote(self, endpoint_name, handle_options, request_data, + kwargs) -> Coroutine: + request_metadata = RequestMetadata( + get_random_letters(10), # Used for debugging. + endpoint_name, + TaskContext.Python, + call_method=handle_options.method_name, + shard_key=handle_options.shard_key, + http_method=handle_options.http_method, + http_headers=handle_options.http_headers, + ) + coro = self.router.assign_request(request_metadata, request_data, + **kwargs) + return coro + + class Client: def __init__(self, controller: ActorHandle, @@ -48,12 +96,8 @@ class Client: self._http_host, self._http_port = ray.get( controller.get_http_config.remote()) - # NOTE(simon): Used to cache client.get_handle(endpoint) call. It will - # mostly grow in size, it will only shrink when user calls the - # .remove_endpoint method. This is fine because we expect the number of - # endpoints to be fairly small. However, in case this dictionary does - # grow very big, we can replace it with a LRU cache instead. - self._handle_cache: Dict[str, ActorHandle] = dict() + self._sync_proxied_router = None + self._async_proxied_router = None # NOTE(edoakes): Need this because the shutdown order isn't guaranteed # when the interpreter is exiting so we can't rely on __del__ (it @@ -65,6 +109,18 @@ class Client: atexit.register(shutdown_serve_client) + def _get_proxied_router(self, sync: bool): + if sync: + if self._sync_proxied_router is None: + self._sync_proxied_router = ThreadProxiedRouter( + self._controller, sync=True) + return self._sync_proxied_router + else: + if self._async_proxied_router is None: + self._async_proxied_router = ThreadProxiedRouter( + self._controller, sync=False) + return self._async_proxied_router + def __del__(self): if not self._detached: logger.debug("Shutting down Ray Serve because client went out of " @@ -198,8 +254,6 @@ class Client: Does not delete any associated backends. """ - if endpoint in self._handle_cache: - del self._handle_cache[endpoint] self._get_result(self._controller.delete_endpoint.remote(endpoint)) @_ensure_connected @@ -410,10 +464,11 @@ class Client: proportion)) @_ensure_connected - def get_handle(self, - endpoint_name: str, - missing_ok: Optional[bool] = False, - sync: bool = True) -> RayServeHandle: + def get_handle( + self, + endpoint_name: str, + missing_ok: Optional[bool] = False, + sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]: """Retrieve RayServeHandle for service endpoint to invoke it from Python. Args: @@ -433,14 +488,26 @@ class Client: if asyncio.get_event_loop().is_running() and sync: logger.warning( - "You are retrieving a ServeHandle inside an asyncio loop. " + "You are retrieving a sync handle inside an asyncio loop. " "Try getting client.get_handle(.., sync=False) to get better " - "performance.") + "performance. Learn more at https://docs.ray.io/en/master/" + "serve/advanced.html#sync-and-async-handles") - if endpoint_name not in self._handle_cache: - handle = RayServeHandle(self._controller, endpoint_name, sync=sync) - self._handle_cache[endpoint_name] = handle - return self._handle_cache[endpoint_name] + if not asyncio.get_event_loop().is_running() and not sync: + logger.warning( + "You are retrieving an async handle outside an asyncio loop. " + "You should make sure client.get_handle is called inside a " + "running event loop. Or call client.get_handle(.., sync=True) " + "to create sync handle. Learn more at https://docs.ray.io/en/" + "master/serve/advanced.html#sync-and-async-handles") + + if sync: + handle = RayServeSyncHandle( + self._get_proxied_router(sync=sync), endpoint_name) + else: + handle = RayServeHandle( + self._get_proxied_router(sync=sync), endpoint_name) + return handle def start(detached: bool = False, diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 4bfd663fd..381c8b833 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -1,27 +1,23 @@ import asyncio import concurrent.futures -import threading -from typing import Any, Coroutine, Dict, Optional, Union - -import ray -from ray.serve.context import TaskContext -from ray.serve.router import RequestMetadata, Router -from ray.serve.utils import get_random_letters -from ray.serve.exceptions import RayServeException - -global_async_loop = None +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Union +from enum import Enum -def create_or_get_async_loop_in_thread(): - global global_async_loop - if global_async_loop is None: - global_async_loop = asyncio.new_event_loop() - thread = threading.Thread( - daemon=True, - target=global_async_loop.run_forever, - ) - thread.start() - return global_async_loop +@dataclass(frozen=True) +class HandleOptions: + """Options for each ServeHandle instances. These fields are immutable.""" + method_name: str = "__call__" + shard_key: Optional[str] = None + http_method: str = "GET" + http_headers: Dict[str, str] = field(default_factory=dict) + + +# Use a global singleton enum to emulate default options. We cannot use None +# for those option because None is a valid new value. +class DEFAULT(Enum): + VALUE = 1 class RayServeHandle: @@ -31,75 +27,59 @@ class RayServeHandle: an HTTP endpoint. Example: - >>> handle = serve.get_handle("my_endpoint") + >>> handle = serve_client.get_handle("my_endpoint") >>> handle - RayServeHandle( - Endpoint="my_endpoint", - Traffic=... - ) - >>> handle.remote(my_request_content) + RayServeHandle(endpoint="my_endpoint") + >>> await handle.remote(my_request_content) ObjectRef(...) - >>> ray.get(handle.remote(...)) + >>> ray.get(await handle.remote(...)) # result - >>> ray.get(handle.remote(let_it_crash_request)) + >>> ray.get(await handle.remote(let_it_crash_request)) # raises RayTaskError Exception """ - def __init__( - self, - controller_handle, - endpoint_name, - sync: bool, - *, - method_name=None, - shard_key=None, - http_method=None, - http_headers=None, - ): - self.controller_handle = controller_handle + def __init__(self, + router, + endpoint_name, + handle_options: Optional[HandleOptions] = None): + self.router = router self.endpoint_name = endpoint_name + self.handle_options = handle_options or HandleOptions() - self.method_name = method_name - self.shard_key = shard_key - self.http_method = http_method - self.http_headers = http_headers + def options(self, + *, + method_name: Union[str, DEFAULT] = DEFAULT.VALUE, + shard_key: Union[str, DEFAULT] = DEFAULT.VALUE, + http_method: Union[str, DEFAULT] = DEFAULT.VALUE, + http_headers: Union[Dict[str, str], DEFAULT] = DEFAULT.VALUE): + """Set options for this handle. - self.router = Router(self.controller_handle) - self.sync = sync - # In the synchrounous mode, we create a new event loop in a separate - # thread and run the Router.setup in that loop. In the async mode, we - # can just use the current loop we are in right now. - if self.sync: - self.async_loop = create_or_get_async_loop_in_thread() - asyncio.run_coroutine_threadsafe( - self.router.setup_in_async_loop(), - self.async_loop, - ) - else: # async - self.async_loop = asyncio.get_event_loop() - # create_task is not threadsafe. - self.async_loop.create_task(self.router.setup_in_async_loop()) + Args: + method_name(str): The method to invoke on the backend. + http_method(str): The HTTP method to use for the request. + shard_key(str): A string to use to deterministically map this + request to a backend if there are multiple for this endpoint. + """ + new_options_dict = self.handle_options.__dict__.copy() + user_modified_options_dict = { + key: value + for key, value in + zip(["method_name", "shard_key", "http_method", "http_headers"], + [method_name, shard_key, http_method, http_headers]) + if value != DEFAULT.VALUE + } + new_options_dict.update(user_modified_options_dict) + new_options = HandleOptions(**new_options_dict) - def _remote(self, request_data, kwargs) -> Coroutine: - request_metadata = RequestMetadata( - get_random_letters(10), # Used for debugging. - self.endpoint_name, - TaskContext.Python, - call_method=self.method_name or "__call__", - shard_key=self.shard_key, - http_method=self.http_method or "GET", - http_headers=self.http_headers or dict(), - ) - coro = self.router.assign_request(request_metadata, request_data, - **kwargs) - return coro + return self.__class__(self.router, self.endpoint_name, new_options) - def remote(self, request_data: Optional[Union[Dict, Any]] = None, - **kwargs): - """Issue an asynchronous request to the endpoint. + async def remote(self, + request_data: Optional[Union[Dict, Any]] = None, + **kwargs): + """Issue an asynchrounous request to the endpoint. Returns a Ray ObjectRef whose results can be waited for or retrieved - using ray.wait or ray.get, respectively. + using ray.wait or ray.get (or ``await object_ref``), respectively. Returns: ray.ObjectRef @@ -110,47 +90,32 @@ class RayServeHandle: ``**kwargs``: All keyword arguments will be available in ``request.query_params``. """ - if not self.sync: - raise RayServeException( - "You are trying to call handle.remote() with async handle. " - "Please use `await handle.remote_async()` instead.") - - coro = self._remote(request_data, kwargs) - future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - coro, self.async_loop) - - # Block until the result is ready. - return future.result() - - async def remote_async(self, - request_data: Optional[Union[Dict, Any]] = None, - **kwargs) -> ray.ObjectRef: - """Experimental API for enqueue a request in async context.""" - if not asyncio.get_event_loop().is_running(): - raise RayServeException( - "remote_async must be called from a running event loop.") - return await self._remote(request_data, kwargs) - - def options(self, - method_name: Optional[str] = None, - *, - shard_key: Optional[str] = None, - http_method: Optional[str] = None, - http_headers: Optional[Dict[str, str]] = None): - """Set options for this handle. - - Args: - method_name(str): The method to invoke on the backend. - http_method(str): The HTTP method to use for the request. - shard_key(str): A string to use to deterministically map this - request to a backend if there are multiple for this endpoint. - """ - # Don't override default non-null values. - self.method_name = self.method_name or method_name - self.shard_key = self.shard_key or shard_key - self.http_method = self.http_method or http_method - self.http_headers = self.http_headers or http_headers - return self + return await self.router._remote( + self.endpoint_name, self.handle_options, request_data, kwargs) def __repr__(self): - return f"RayServeHandle(endpoint='{self.endpoint_name}')" + return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')" + + +class RayServeSyncHandle(RayServeHandle): + def remote(self, request_data: Optional[Union[Dict, Any]] = None, + **kwargs): + """Issue an asynchrounous request to the endpoint. + + Returns a Ray ObjectRef whose results can be waited for or retrieved + using ray.wait or ray.get (or ``await object_ref``), respectively. + + Returns: + ray.ObjectRef + Args: + request_data(dict, Any): If it's a dictionary, the data will be + available in ``request.json()`` or ``request.form()``. + Otherwise, it will be available in ``request.data``. + ``**kwargs``: All keyword arguments will be available in + ``request.args``. + """ + coro = self.router._remote(self.endpoint_name, self.handle_options, + request_data, kwargs) + future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( + coro, self.router.async_loop) + return future.result() diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index ea0f6f35d..318c93732 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -146,7 +146,7 @@ def test_call_method(serve_instance): # Test serve handle path. handle = client.get_handle("endpoint") - assert ray.get(handle.options("method").remote()) == "hello" + assert ray.get(handle.options(method_name="method").remote()) == "hello" def test_no_route(serve_instance): diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index cc6b1e72b..c17db7686 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -106,6 +106,36 @@ def test_handle_inject_starlette_request(serve_instance): assert request_type == "" +def test_handle_option_chaining(serve_instance): + # https://github.com/ray-project/ray/issues/12802 + # https://github.com/ray-project/ray/issues/12798 + + client = serve_instance + + class MultiMethod: + def method_a(self, _): + return "method_a" + + def method_b(self, _): + return "method_b" + + def __call__(self, _): + return "__call__" + + client.create_backend("m", MultiMethod) + client.create_endpoint("m", backend="m") + + # get_handle should give you a clean handle + handle1 = client.get_handle("m").options(method_name="method_a") + handle2 = client.get_handle("m") + # options().options() override should work + handle3 = handle1.options(method_name="method_b") + + assert ray.get(handle1.remote()) == "method_a" + assert ray.get(handle2.remote()) == "__call__" + assert ray.get(handle3.remote()) == "method_b" + + if __name__ == "__main__": import sys import pytest From 62a5832007b1282ecf8e1984de2592bba7716241 Mon Sep 17 00:00:00 2001 From: fyrestone Date: Wed, 23 Dec 2020 11:14:23 +0800 Subject: [PATCH 70/88] [Dashboard] Add GET /logical/actors API (#12913) --- dashboard/datacenter.py | 3 +- .../modules/logical_view/logical_view_head.py | 10 +++- .../tests/test_logical_view_head.py | 57 +++++++++++++++++++ .../stats_collector/stats_collector_head.py | 3 +- src/ray/gcs/gcs_server/gcs_actor_manager.cc | 3 + src/ray/gcs/gcs_server/gcs_actor_scheduler.cc | 1 + src/ray/protobuf/gcs.proto | 2 + src/ray/protobuf/node_manager.proto | 2 + src/ray/raylet/node_manager.cc | 5 +- 9 files changed, 82 insertions(+), 4 deletions(-) diff --git a/dashboard/datacenter.py b/dashboard/datacenter.py index 357bab3c1..c6b05d9e8 100644 --- a/dashboard/datacenter.py +++ b/dashboard/datacenter.py @@ -77,7 +77,8 @@ class DataOrganizer: job_workers = {} node_workers = {} core_worker_stats = {} - for node_id in DataSource.nodes.keys(): + # await inside for loop, so we create a copy of keys(). + for node_id in list(DataSource.nodes.keys()): workers = await cls.get_node_workers(node_id) for worker in workers: job_id = worker["jobId"] diff --git a/dashboard/modules/logical_view/logical_view_head.py b/dashboard/modules/logical_view/logical_view_head.py index 674ac64e9..cf29db637 100644 --- a/dashboard/modules/logical_view/logical_view_head.py +++ b/dashboard/modules/logical_view/logical_view_head.py @@ -4,7 +4,7 @@ import ray.utils import ray.new_dashboard.utils as dashboard_utils import ray.new_dashboard.actor_utils as actor_utils from ray.new_dashboard.utils import rest_response -from ray.new_dashboard.datacenter import DataOrganizer +from ray.new_dashboard.datacenter import DataOrganizer, DataSource from ray.core.generated import core_worker_pb2 from ray.core.generated import core_worker_pb2_grpc @@ -29,6 +29,14 @@ class LogicalViewHead(dashboard_utils.DashboardHeadModule): message="Fetched actor groups.", actor_groups=actor_groups) + @routes.get("/logical/actors") + @dashboard_utils.aiohttp_cache + async def get_all_actors(self, req) -> aiohttp.web.Response: + return dashboard_utils.rest_response( + success=True, + message="All actors fetched.", + actors=DataSource.actors) + @routes.get("/logical/kill_actor") async def kill_actor(self, req) -> aiohttp.web.Response: try: diff --git a/dashboard/modules/logical_view/tests/test_logical_view_head.py b/dashboard/modules/logical_view/tests/test_logical_view_head.py index 5e4a8bb6c..ceee063dd 100644 --- a/dashboard/modules/logical_view/tests/test_logical_view_head.py +++ b/dashboard/modules/logical_view/tests/test_logical_view_head.py @@ -79,6 +79,63 @@ def test_actor_groups(ray_start_with_dashboard): raise Exception(f"Timed out while testing, {ex_stack}") +def test_actors(disable_aiohttp_cache, ray_start_with_dashboard): + @ray.remote + class Foo: + def __init__(self, num): + self.num = num + + def do_task(self): + return self.num + + @ray.remote(num_gpus=1) + class InfeasibleActor: + pass + + foo_actors = [Foo.remote(4), Foo.remote(5)] + infeasible_actor = InfeasibleActor.remote() # noqa + results = [actor.do_task.remote() for actor in foo_actors] # noqa + webui_url = ray_start_with_dashboard["webui_url"] + assert wait_until_server_available(webui_url) + webui_url = format_web_url(webui_url) + + timeout_seconds = 5 + start_time = time.time() + last_ex = None + while True: + time.sleep(1) + try: + resp = requests.get(f"{webui_url}/logical/actors") + resp_json = resp.json() + resp_data = resp_json["data"] + actors = resp_data["actors"] + assert len(actors) == 3 + one_entry = list(actors.values())[0] + assert "jobId" in one_entry + assert "taskSpec" in one_entry + assert "functionDescriptor" in one_entry["taskSpec"] + assert type(one_entry["taskSpec"]["functionDescriptor"]) is dict + assert "address" in one_entry + assert type(one_entry["address"]) is dict + assert "state" in one_entry + assert "name" in one_entry + assert "numRestarts" in one_entry + assert "pid" in one_entry + all_pids = [entry["pid"] for entry in actors.values()] + assert 0 in all_pids # The infeasible actor + assert len(all_pids) > 1 + break + except Exception as ex: + last_ex = ex + finally: + if time.time() > start_time + timeout_seconds: + ex_stack = traceback.format_exception( + type(last_ex), last_ex, + last_ex.__traceback__) if last_ex else [] + ex_stack = "".join(ex_stack) + raise Exception(f"Timed out while testing, {ex_stack}") + + def test_kill_actor(ray_start_with_dashboard): @ray.remote class Actor: diff --git a/dashboard/modules/stats_collector/stats_collector_head.py b/dashboard/modules/stats_collector/stats_collector_head.py index 1224b6d62..ae75864e5 100644 --- a/dashboard/modules/stats_collector/stats_collector_head.py +++ b/dashboard/modules/stats_collector/stats_collector_head.py @@ -246,7 +246,8 @@ class StatsCollector(dashboard_utils.DashboardHeadModule): @async_loop_forever( stats_collector_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS) async def _update_node_stats(self): - for node_id, stub in self._stubs.items(): + # Copy self._stubs to avoid `dictionary changed size during iteration`. + for node_id, stub in list(self._stubs.items()): node_info = DataSource.nodes.get(node_id) if node_info["state"] != "ALIVE": continue diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index e30cf2569..17c7c0f5a 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -341,6 +341,9 @@ Status GcsActorManager::RegisterActor(const ray::rpc::RegisterActorRequest &requ // the actor state to DEAD to avoid race condition. return; } + RAY_CHECK_OK(gcs_pub_sub_->Publish(ACTOR_CHANNEL, actor->GetActorID().Hex(), + actor->GetActorTableData().SerializeAsString(), + nullptr)); // Invoke all callbacks for all registration requests of this actor (duplicated // requests are included) and remove all of them from // actor_to_register_callbacks_. diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index aa420f463..bbdd58856 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -323,6 +323,7 @@ void GcsActorScheduler::HandleWorkerLeasedReply( .emplace(leased_worker->GetWorkerID(), leased_worker) .second); actor->UpdateAddress(leased_worker->GetAddress()); + actor->GetMutableActorTableData()->set_pid(reply.worker_pid()); // Make sure to connect to the client before persisting actor info to GCS. // Without this, there could be a possible race condition. Related issues: // https://github.com/ray-project/ray/pull/9215/files#r449469320 diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index fe2511bd0..82704bac4 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -143,6 +143,8 @@ message ActorTableData { // Resource mapping ids acquired by the leased worker. This field is only set when this // actor already has a leased worker. repeated ResourceMapEntry resource_mapping = 15; + // The process id of this actor. + uint32 pid = 16; } message ErrorTableData { diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 087135e39..dcf6fe783 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -37,6 +37,8 @@ message RequestWorkerLeaseReply { // Whether this lease request was canceled. In this case, the // client should try again if the resources are still required. bool canceled = 4; + // PID of the worker process. + uint32 worker_pid = 5; } message PrepareBundleResourcesRequest { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index b27a0e32c..52e9354a2 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1711,10 +1711,14 @@ void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest [this, owner_address, reply, send_reply_callback]( const std::shared_ptr granted, const std::string &address, int port, const WorkerID &worker_id, const ResourceIdSet &resource_ids) { + auto worker = std::static_pointer_cast(granted); + uint32_t worker_pid = static_cast(worker->GetProcess().GetId()); + reply->mutable_worker_address()->set_ip_address(address); reply->mutable_worker_address()->set_port(port); reply->mutable_worker_address()->set_worker_id(worker_id.Binary()); reply->mutable_worker_address()->set_raylet_id(self_node_id_.Binary()); + reply->set_worker_pid(worker_pid); for (const auto &mapping : resource_ids.AvailableResources()) { auto resource = reply->add_resource_mapping(); resource->set_name(mapping.first); @@ -1743,7 +1747,6 @@ void NodeManager::HandleRequestWorkerLease(const rpc::RequestWorkerLeaseRequest RAY_CHECK(leased_workers_.find(worker_id) == leased_workers_.end()) << "Worker is already leased out " << worker_id; - auto worker = std::static_pointer_cast(granted); leased_workers_[worker_id] = worker; }); task.OnSpillbackInstead( From 646c4201ac3279713800cd84113d6136e69c5208 Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Wed, 23 Dec 2020 11:25:01 +0800 Subject: [PATCH 71/88] [GCS]Decouple gcs resource manager and gcs node manager (#13012) --- src/ray/gcs/gcs_server/gcs_node_manager.cc | 38 ++++++++----------- src/ray/gcs/gcs_server/gcs_node_manager.h | 24 +++++++----- src/ray/gcs/gcs_server/gcs_server.cc | 12 +++++- .../test/gcs_actor_scheduler_test.cc | 4 +- .../gcs_server/test/gcs_node_manager_test.cc | 6 +-- .../test/gcs_object_manager_test.cc | 4 +- .../test/gcs_placement_group_manager_test.cc | 4 +- .../gcs_placement_group_scheduler_test.cc | 17 +++++---- 8 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.cc b/src/ray/gcs/gcs_server/gcs_node_manager.cc index 57f878d60..322b0349f 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_node_manager.cc @@ -23,16 +23,14 @@ namespace ray { namespace gcs { ////////////////////////////////////////////////////////////////////////////////////////// -GcsNodeManager::GcsNodeManager( - boost::asio::io_service &main_io_service, std::shared_ptr gcs_pub_sub, - std::shared_ptr gcs_table_storage, - std::shared_ptr gcs_resource_manager) +GcsNodeManager::GcsNodeManager(boost::asio::io_service &main_io_service, + std::shared_ptr gcs_pub_sub, + std::shared_ptr gcs_table_storage) : resource_timer_(main_io_service), light_report_resource_usage_enabled_( RayConfig::instance().light_report_resource_usage_enabled()), gcs_pub_sub_(gcs_pub_sub), - gcs_table_storage_(gcs_table_storage), - gcs_resource_manager_(gcs_resource_manager) { + gcs_table_storage_(gcs_table_storage) { SendBatchedResourceUsage(); } @@ -104,10 +102,19 @@ void GcsNodeManager::HandleReportResourceUsage( auto resources_data = std::make_shared(); resources_data->CopyFrom(request.resources()); - UpdateNodeResourceUsage(node_id, request); + // We use `node_resource_usages_` to filter out the nodes that report resource + // information for the first time. `UpdateNodeResourceUsage` will modify + // `node_resource_usages_`, so we need to do it before `UpdateNodeResourceUsage`. + if (!light_report_resource_usage_enabled_ || + node_resource_usages_.count(node_id) == 0 || + resources_data->resources_available_changed()) { + const auto &resource_changed = MapFromProtobuf(resources_data->resources_available()); + for (auto &listener : node_resource_changed_listeners_) { + listener(node_id, resource_changed); + } + } - // Update node realtime resources. - UpdateNodeRealtimeResources(node_id, *resources_data); + UpdateNodeResourceUsage(node_id, request); if (!light_report_resource_usage_enabled_ || resources_data->should_global_gc() || resources_data->resources_total_size() > 0 || @@ -240,7 +247,6 @@ void GcsNodeManager::AddNode(std::shared_ptr node) { for (auto &listener : node_added_listeners_) { listener(node); } - gcs_resource_manager_->OnNodeAdd(*node); } } @@ -255,8 +261,6 @@ std::shared_ptr GcsNodeManager::RemoveNode( stats::NodeFailureTotal.Record(1); // Remove from alive nodes. alive_nodes_.erase(iter); - // Remove from cluster resources. - gcs_resource_manager_->OnNodeDead(node_id); resources_buffer_.erase(node_id); node_resource_usages_.erase(node_id); if (!is_intended) { @@ -313,16 +317,6 @@ void GcsNodeManager::Initialize(const GcsInitData &gcs_init_data) { const std::pair &right) { return left.second < right.second; }); } -void GcsNodeManager::UpdateNodeRealtimeResources( - const NodeID &node_id, const rpc::ResourcesData &resource_data) { - if (!light_report_resource_usage_enabled_ || - gcs_resource_manager_->GetClusterResources().count(node_id) == 0 || - resource_data.resources_available_changed()) { - gcs_resource_manager_->SetAvailableResources( - node_id, ResourceSet(MapFromProtobuf(resource_data.resources_available()))); - } -} - void GcsNodeManager::UpdatePlacementGroupLoad( const std::shared_ptr placement_group_load) { placement_group_load_ = absl::make_optional(placement_group_load); diff --git a/src/ray/gcs/gcs_server/gcs_node_manager.h b/src/ray/gcs/gcs_server/gcs_node_manager.h index 3c62af3a7..8b99eaa13 100644 --- a/src/ray/gcs/gcs_server/gcs_node_manager.h +++ b/src/ray/gcs/gcs_server/gcs_node_manager.h @@ -39,11 +39,9 @@ class GcsNodeManager : public rpc::NodeInfoHandler { /// \param main_io_service The main event loop. /// \param gcs_pub_sub GCS message publisher. /// \param gcs_table_storage GCS table external storage accessor. - /// \param gcs_resource_manager GCS resource manager. explicit GcsNodeManager(boost::asio::io_service &main_io_service, std::shared_ptr gcs_pub_sub, - std::shared_ptr gcs_table_storage, - std::shared_ptr gcs_resource_manager); + std::shared_ptr gcs_table_storage); /// Handle register rpc request come from raylet. void HandleRegisterNode(const rpc::RegisterNodeRequest &request, @@ -135,16 +133,22 @@ class GcsNodeManager : public rpc::NodeInfoHandler { node_added_listeners_.emplace_back(std::move(listener)); } + /// Add listener to monitor the resource change of nodes. + /// + /// \param listener The handler which process the resource change of nodes. + void AddNodeResourceChangedListener( + std::function &)> + listener) { + RAY_CHECK(listener); + node_resource_changed_listeners_.emplace_back(std::move(listener)); + } + /// Initialize with the gcs tables data synchronously. /// This should be called when GCS server restarts after a failure. /// /// \param gcs_init_data. void Initialize(const GcsInitData &gcs_init_data); - // Update node realtime resources. - void UpdateNodeRealtimeResources(const NodeID &node_id, - const rpc::ResourcesData &heartbeat); - /// Update the placement group load information so that it will be reported through /// heartbeat. /// @@ -185,12 +189,14 @@ class GcsNodeManager : public rpc::NodeInfoHandler { /// Listeners which monitors the removal of nodes. std::vector)>> node_removed_listeners_; + /// Listeners which monitors the resource change of nodes. + std::vector &)>> + node_resource_changed_listeners_; /// A publisher for publishing gcs messages. std::shared_ptr gcs_pub_sub_; /// Storage for GCS tables. std::shared_ptr gcs_table_storage_; - /// Gcs resource manager. - std::shared_ptr gcs_resource_manager_; /// Placement group load information that is used for autoscaler. absl::optional> placement_group_load_; diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 23a12f6ec..71e2a6d81 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -133,8 +133,8 @@ void GcsServer::Stop() { void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) { RAY_CHECK(redis_gcs_client_ && gcs_table_storage_ && gcs_pub_sub_); - gcs_node_manager_ = std::make_shared( - main_service_, gcs_pub_sub_, gcs_table_storage_, gcs_resource_manager_); + gcs_node_manager_ = + std::make_shared(main_service_, gcs_pub_sub_, gcs_table_storage_); // Initialize by gcs tables data. gcs_node_manager_->Initialize(gcs_init_data); // Register service. @@ -292,6 +292,7 @@ void GcsServer::InstallEventListeners() { gcs_node_manager_->AddNodeAddedListener([this](std::shared_ptr node) { // Because a new node has been added, we need to try to schedule the pending // placement groups and the pending actors. + gcs_resource_manager_->OnNodeAdd(*node); gcs_placement_group_manager_->SchedulePendingPlacementGroups(); gcs_actor_manager_->SchedulePendingActors(); gcs_heartbeat_manager_->AddNode(NodeID::FromBinary(node->node_id())); @@ -301,10 +302,17 @@ void GcsServer::InstallEventListeners() { auto node_id = NodeID::FromBinary(node->node_id()); // All of the related placement groups and actors should be reconstructed when a // node is removed from the GCS. + gcs_resource_manager_->OnNodeDead(node_id); gcs_placement_group_manager_->OnNodeDead(node_id); gcs_actor_manager_->OnNodeDead(node_id); raylet_client_pool_->Disconnect(NodeID::FromBinary(node->node_id())); }); + gcs_node_manager_->AddNodeResourceChangedListener( + [this](const NodeID &node_id, + const std::unordered_map &resource_changed) { + gcs_resource_manager_->SetAvailableResources(node_id, + ResourceSet(resource_changed)); + }); // Install worker event listener. gcs_worker_manager_->AddWorkerDeadListener( diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index 4ddba0627..7bb1ca716 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -27,8 +27,8 @@ class GcsActorSchedulerTest : public ::testing::Test { gcs_pub_sub_ = std::make_shared(redis_client_); gcs_table_storage_ = std::make_shared(redis_client_); gcs_resource_manager_ = std::make_shared(nullptr, nullptr); - gcs_node_manager_ = std::make_shared( - io_service_, gcs_pub_sub_, gcs_table_storage_, gcs_resource_manager_); + gcs_node_manager_ = std::make_shared(io_service_, gcs_pub_sub_, + gcs_table_storage_); store_client_ = std::make_shared(io_service_); gcs_actor_table_ = std::make_shared(store_client_); diff --git a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc index 74c4b8fd1..25f80733a 100644 --- a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc @@ -35,8 +35,7 @@ class GcsNodeManagerTest : public ::testing::Test { TEST_F(GcsNodeManagerTest, TestManagement) { boost::asio::io_service io_service; - gcs::GcsNodeManager node_manager(io_service, gcs_pub_sub_, gcs_table_storage_, - gcs_resource_manager_); + gcs::GcsNodeManager node_manager(io_service, gcs_pub_sub_, gcs_table_storage_); // Test Add/Get/Remove functionality. auto node = Mocker::GenNodeInfo(); auto node_id = NodeID::FromBinary(node->node_id()); @@ -82,8 +81,7 @@ TEST_F(GcsNodeManagerTest, TestManagement) { TEST_F(GcsNodeManagerTest, TestListener) { boost::asio::io_service io_service; - gcs::GcsNodeManager node_manager(io_service, gcs_pub_sub_, gcs_table_storage_, - gcs_resource_manager_); + gcs::GcsNodeManager node_manager(io_service, gcs_pub_sub_, gcs_table_storage_); // Test AddNodeAddedListener. int node_count = 1000; std::vector> added_nodes; diff --git a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc index 15f96a6a8..700fdfc10 100644 --- a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc @@ -55,8 +55,8 @@ class GcsObjectManagerTest : public ::testing::Test { void SetUp() override { gcs_table_storage_ = std::make_shared(io_service_); gcs_resource_manager_ = std::make_shared(nullptr, nullptr); - gcs_node_manager_ = std::make_shared( - io_service_, gcs_pub_sub_, gcs_table_storage_, gcs_resource_manager_); + gcs_node_manager_ = std::make_shared(io_service_, gcs_pub_sub_, + gcs_table_storage_); gcs_object_manager_ = std::make_shared( gcs_table_storage_, gcs_pub_sub_, *gcs_node_manager_); GenTestData(); diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc index e74b5fe1b..70bfdce31 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_manager_test.cc @@ -69,8 +69,8 @@ class GcsPlacementGroupManagerTest : public ::testing::Test { gcs_pub_sub_ = std::make_shared(redis_client_); gcs_table_storage_ = std::make_shared(io_service_); gcs_resource_manager_ = std::make_shared(nullptr, nullptr); - gcs_node_manager_ = std::make_shared( - io_service_, gcs_pub_sub_, gcs_table_storage_, gcs_resource_manager_); + gcs_node_manager_ = std::make_shared(io_service_, gcs_pub_sub_, + gcs_table_storage_); gcs_placement_group_manager_.reset( new gcs::GcsPlacementGroupManager(io_service_, mock_placement_group_scheduler_, gcs_table_storage_, *gcs_node_manager_)); diff --git a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc index ef81f8887..6a0a5839b 100644 --- a/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_placement_group_scheduler_test.cc @@ -40,8 +40,8 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { gcs_table_storage_ = std::make_shared(io_service_); gcs_pub_sub_ = std::make_shared(redis_client_); gcs_resource_manager_ = std::make_shared(nullptr, nullptr); - gcs_node_manager_ = std::make_shared( - io_service_, gcs_pub_sub_, gcs_table_storage_, gcs_resource_manager_); + gcs_node_manager_ = std::make_shared(io_service_, gcs_pub_sub_, + gcs_table_storage_); gcs_table_storage_ = std::make_shared(io_service_); store_client_ = std::make_shared(io_service_); raylet_client_pool_ = std::make_shared( @@ -98,12 +98,13 @@ class GcsPlacementGroupSchedulerTest : public ::testing::Test { void AddNode(const std::shared_ptr &node, int cpu_num = 10) { gcs_node_manager_->AddNode(node); - rpc::ResourcesData resource; - resource.set_node_id(node->node_id()); - (*resource.mutable_resources_available())["CPU"] = cpu_num; - resource.set_resources_available_changed(true); - gcs_node_manager_->UpdateNodeRealtimeResources(NodeID::FromBinary(node->node_id()), - resource); + gcs_resource_manager_->OnNodeAdd(*node); + + const auto &node_id = NodeID::FromBinary(node->node_id()); + std::unordered_map resource_map; + resource_map["CPU"] = cpu_num; + ResourceSet resources(resource_map); + gcs_resource_manager_->SetAvailableResources(node_id, resources); } void ScheduleFailedWithZeroNodeTest(rpc::PlacementStrategy strategy) { From c4e273920f517b18c99fbabca49135dd6e30e683 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Tue, 22 Dec 2020 22:51:45 -0800 Subject: [PATCH 72/88] [ray_client]: Insert decorators into the real ray module to allow for client mode (#13031) --- .travis.yml | 2 +- ci/travis/ci.sh | 7 + python/ray/_private/client_mode_hook.py | 47 ++++ python/ray/_raylet.pyx | 7 + python/ray/experimental/client/__init__.py | 200 +++++++---------- python/ray/experimental/client/api.py | 212 +++++++----------- .../ray/experimental/client/client_pickler.py | 13 +- python/ray/experimental/client/common.py | 21 +- python/ray/experimental/client/dataclient.py | 3 +- .../experimental/client/examples/run_tune.py | 7 + python/ray/experimental/client/logsclient.py | 14 +- .../experimental/client/ray_client_helpers.py | 7 +- .../client/server/core_ray_api.py | 81 ------- .../experimental/client/server/logservicer.py | 6 +- .../ray/experimental/client/server/server.py | 132 ++++++----- .../client/server/server_pickler.py | 17 +- python/ray/experimental/client/worker.py | 42 +--- python/ray/state.py | 4 + python/ray/test_utils.py | 2 +- python/ray/tests/BUILD | 2 +- python/ray/tests/conftest.py | 24 +- python/ray/tests/test_actor.py | 9 +- python/ray/tests/test_advanced.py | 2 + python/ray/tests/test_basic.py | 5 +- python/ray/tests/test_experimental_client.py | 30 ++- .../test_experimental_client_metadata.py | 3 +- .../test_experimental_client_references.py | 6 +- python/ray/worker.py | 11 + 28 files changed, 419 insertions(+), 497 deletions(-) create mode 100644 python/ray/_private/client_mode_hook.py create mode 100644 python/ray/experimental/client/examples/run_tune.py delete mode 100644 python/ray/experimental/client/server/core_ray_api.py diff --git a/.travis.yml b/.travis.yml index dc133de49..8173ec1ac 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ matrix: script: # bazel python tests for medium size tests. Used for parallelization. - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi - - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_TEST_CLIENT_MODE=1 python/ray/tests/...; fi + - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_CLIENT_MODE=1 python/ray/tests/...; fi - os: linux env: diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index 843515400..e4d9741cd 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -262,6 +262,11 @@ _bazel_build_before_install() { bazel build "${target}" } + +_bazel_build_protobuf() { + bazel build "//:install_py_proto" +} + install_ray() { # TODO(mehrdadn): This function should be unified with the one in python/build-wheel-windows.sh. ( @@ -457,6 +462,8 @@ init() { build() { if [ "${LINT-}" != 1 ]; then _bazel_build_before_install + else + _bazel_build_protobuf fi if ! need_wheels; then diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py new file mode 100644 index 000000000..4fbc568c8 --- /dev/null +++ b/python/ray/_private/client_mode_hook.py @@ -0,0 +1,47 @@ +import os +from contextlib import contextmanager + +client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" + +_client_hook_enabled = True + + +def _enable_client_hook(val: bool): + global _client_hook_enabled + _client_hook_enabled = val + + +def _disable_client_hook(): + global _client_hook_enabled + out = _client_hook_enabled + _client_hook_enabled = False + return out + + +def _explicitly_enable_client_mode(): + global client_mode_enabled + client_mode_enabled = True + + +@contextmanager +def disable_client_hook(): + val = _disable_client_hook() + try: + yield None + finally: + _enable_client_hook(val) + + +def client_mode_hook(func): + """ + Decorator for ray module methods to delegate to ray client + """ + from ray.experimental.client import ray + + def wrapper(*args, **kwargs): + global _client_hook_enabled + if client_mode_enabled and _client_hook_enabled: + return getattr(ray, func.__name__)(*args, **kwargs) + return func(*args, **kwargs) + + return wrapper diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 356222bb9..1360c96ce 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -107,6 +107,10 @@ from ray.exceptions import ( TaskCancelledError ) from ray.utils import decode +from ray._private.client_mode_hook import ( + _enable_client_hook, + _disable_client_hook, +) import msgpack cimport cpython @@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler( with gil: try: + client_was_enabled = _disable_client_hook() try: # The call to execute_task should never raise an exception. If # it does, that indicates that there was an internal error. @@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler( else: logger.exception("SystemExit was raised from the worker") return CRayStatus.UnexpectedSystemExit() + finally: + _enable_client_hook(client_was_enabled) return CRayStatus.OK() diff --git a/python/ray/experimental/client/__init__.py b/python/ray/experimental/client/__init__.py index ed1983528..674dfa7f7 100644 --- a/python/ray/experimental/client/__init__.py +++ b/python/ray/experimental/client/__init__.py @@ -1,148 +1,104 @@ -from ray.experimental.client.api import ClientAPI -from ray.experimental.client.api import APIImpl -from typing import Optional, List, Tuple -from contextlib import contextmanager +from typing import List, Tuple import logging -import os logger = logging.getLogger(__name__) -# About these global variables: Ray 1.0 uses exported module functions to -# provide its API, and we need to match that. However, we want different -# behaviors depending on where, exactly, in the client stack this is running. -# -# The reason for these differences depends on what's being pickled and passed -# to functions, or functions inside functions. So there are three cases to care -# about -# -# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process) -# -# * _client_api should be set if we're inside the client -# * _server_api should be set if we're inside the clientserver -# * Both will be set if we're running both (as in a test) -# * Neither should be set if we're inside the raylet (but we still need to shim -# from the client API surface to the Ray API) -# -# The job of RayAPIStub (below) delegates to the appropriate one of these -# depending on what's set or not. Then, all users importing the ray object -# from this package get the stub which routes them to the appropriate APIImpl. -_client_api: Optional[APIImpl] = None -_server_api: Optional[APIImpl] = None - -# The reason for _is_server is a hack around the above comment while running -# tests. If we have both a client and a server trying to control these static -# variables then we need a way to decide which to use. In this case, both -# _client_api and _server_api are set. -# This boolean flips between the two -_is_server: bool = False - - -@contextmanager -def stash_api_for_tests(in_test: bool): - global _is_server - is_server = _is_server - if in_test: - _is_server = True - try: - yield _server_api - finally: - if in_test: - _is_server = is_server - - -def _set_client_api(val: Optional[APIImpl]): - global _client_api - global _is_server - if _client_api is not None: - raise Exception("Trying to set more than one client API") - _client_api = val - _is_server = False - - -def _set_server_api(val: Optional[APIImpl]): - global _server_api - global _is_server - if _server_api is not None: - raise Exception("Trying to set more than one server API") - _server_api = val - _is_server = True - - -def reset_api(): - global _client_api - global _server_api - global _is_server - _client_api = None - _server_api = None - _is_server = False - - -def _get_client_api() -> APIImpl: - global _client_api - return _client_api - - -def _get_server_instance(): - """Used inside tests to inspect the running server. - """ - global _server_api - if _server_api is not None: - return _server_api.server - class RayAPIStub: + """This class stands in as the replacement API for the `import ray` module. + + Much like the ray module, this mostly delegates the work to the + _client_worker. As parts of the ray API are covered, they are piped through + here or on the client worker API. + """ + + def __init__(self): + from ray.experimental.client.api import ClientAPI + self.api = ClientAPI() + self.client_worker = None + self._server = None + self._connected_with_init = False + self._inside_client_test = False + def connect(self, conn_str: str, secure: bool = False, - metadata: List[Tuple[str, str]] = None, - stub=None) -> None: + metadata: List[Tuple[str, str]] = None) -> None: + """Connect the Ray Client to a server. + + Args: + conn_str: Connection string, in the form "[host]:port" + secure: Whether to use a TLS secured gRPC channel + metadata: gRPC metadata to send on connect + """ + # Delay imports until connect to avoid circular imports. from ray.experimental.client.worker import Worker - _client_worker = Worker(conn_str, secure=secure, metadata=metadata) - _set_client_api(ClientAPI(_client_worker)) + import ray._private.client_mode_hook + if self.client_worker is not None: + if self._connected_with_init: + return + raise Exception( + "ray.connect() called, but ray client is already connected") + if not self._inside_client_test: + # If we're calling a client connect specifically and we're not + # currently in client mode, ensure we are. + ray._private.client_mode_hook._explicitly_enable_client_mode() + self.client_worker = Worker(conn_str, secure=secure, metadata=metadata) + self.api.worker = self.client_worker def disconnect(self): - global _client_api - if _client_api is not None: - _client_api.close() - _client_api = None + """Disconnect the Ray Client. + """ + if self.client_worker is not None: + self.client_worker.close() + self.client_worker = None + + # remote can be called outside of a connection, which is why it + # exists on the same API layer as connect() itself. + def remote(self, *args, **kwargs): + """remote is the hook stub passed on to replace `ray.remote`. + + This sets up remote functions or actors, as the decorator, + but does not execute them. + + Args: + args: opaque arguments + kwargs: opaque keyword arguments + """ + return self.api.remote(*args, **kwargs) def __getattr__(self, key: str): - global _get_client_api - api = _get_client_api() - return getattr(api, key) + if not self.is_connected(): + raise Exception("Ray Client is not connected. " + "Please connect by calling `ray.connect`.") + return getattr(self.api, key) def is_connected(self) -> bool: - global _client_api - return _client_api is not None + return self.api is not None def init(self, *args, **kwargs): - if _is_client_test_env(): - global _test_server - import ray.experimental.client.server.server as ray_client_server - _test_server, address_info = ray_client_server.init_and_serve( - "localhost:50051", test_mode=True, *args, **kwargs) - self.connect("localhost:50051") - return address_info - else: - raise NotImplementedError( - "Please call ray.connect() in client mode") + if self._server is not None: + raise Exception("Trying to start two instances of ray via client") + import ray.experimental.client.server.server as ray_client_server + self._server, address_info = ray_client_server.init_and_serve( + "localhost:50051", *args, **kwargs) + self.connect("localhost:50051") + self._connected_with_init = True + return address_info + + def shutdown(self, _exiting_interpreter=False): + self.disconnect() + import ray.experimental.client.server.server as ray_client_server + if self._server is None: + return + ray_client_server.shutdown_with_server(self._server, + _exiting_interpreter) + self._server = None ray = RayAPIStub() -_test_server = None - - -def _stop_test_server(*args): - global _test_server - _test_server.stop(*args) - - -def _is_client_test_env() -> bool: - return os.environ.get("RAY_TEST_CLIENT_MODE") == "1" - - # Someday we might add methods in this module so that someone who # tries to `import ray_client as ray` -- as a module, instead of # `from ray_client import ray` -- as the API stub diff --git a/python/ray/experimental/client/api.py b/python/ray/experimental/client/api.py index 93da6382f..58680bf9f 100644 --- a/python/ray/experimental/client/api.py +++ b/python/ray/experimental/client/api.py @@ -1,74 +1,51 @@ -# This file defines an interface and client-side API stub -# for referring either to the core Ray API or the same interface -# from the Ray client. -# -# In tandem with __init__.py, we want to expose an API that's -# close to `python/ray/__init__.py` but with more than one implementation. -# The stubs in __init__ should call into a well-defined interface. -# Only the core Ray API implementation should actually `import ray` -# (and thus import all the raylet worker C bindings and such). -# But to make sure that we're matching these calls, we define this API. - -from abc import ABC -from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Union, Optional -import ray.core.generated.ray_client_pb2 as ray_client_pb2 +"""This file defines the interface between the ray client worker +and the overall ray module API. +""" +from typing import TYPE_CHECKING if TYPE_CHECKING: - from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientStub + from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientObjectRef - from ray._raylet import ObjectRef - - # Use the imports for type checking. This is a python 3.6 limitation. - # See https://www.python.org/dev/peps/pep-0563/ - PutType = Union[ClientObjectRef, ObjectRef] -class APIImpl(ABC): - """ - APIImpl is the interface to implement for whichever version of the core - Ray API that needs abstracting when run in client mode. +class ClientAPI: + """The Client-side methods corresponding to the ray API. Delegates + to the Client Worker that contains the connection to the ClientServer. """ - @abstractmethod - def get(self, vals, *, timeout: Optional[float] = None) -> Any: - """ - get is the hook stub passed on to replace `ray.get` + def __init__(self, worker=None): + self.worker = worker + + def get(self, vals, *, timeout=None): + """get is the hook stub passed on to replace `ray.get` Args: vals: [Client]ObjectRef or list of these refs to retrieve. timeout: Optional timeout in milliseconds """ - pass + return self.worker.get(vals, timeout=timeout) - @abstractmethod - def put(self, vals: Any, *args, - **kwargs) -> Union["ClientObjectRef", "ObjectRef"]: - """ - put is the hook stub passed on to replace `ray.put` + def put(self, *args, **kwargs): + """put is the hook stub passed on to replace `ray.put` Args: vals: The value or list of values to `put`. args: opaque arguments kwargs: opaque keyword arguments """ - pass + return self.worker.put(*args, **kwargs) - @abstractmethod def wait(self, *args, **kwargs): - """ - wait is the hook stub passed on to replace `ray.wait` + """wait is the hook stub passed on to replace `ray.wait` Args: args: opaque arguments kwargs: opaque keyword arguments """ - pass + return self.worker.wait(*args, **kwargs) - @abstractmethod def remote(self, *args, **kwargs): - """ - remote is the hook stub passed on to replace `ray.remote`. + """remote is the hook stub passed on to replace `ray.remote`. This sets up remote functions or actors, as the decorator, but does not execute them. @@ -77,12 +54,24 @@ class APIImpl(ABC): args: opaque arguments kwargs: opaque keyword arguments """ - pass + # Delayed import to avoid a cyclic import + from ray.experimental.client.common import remote_decorator + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + return remote_decorator(options=None)(args[0]) + error_string = ("The @ray.remote decorator must be applied either " + "with no arguments and no parentheses, for example " + "'@ray.remote', or it must be applied using some of " + "the arguments 'num_returns', 'num_cpus', 'num_gpus', " + "'memory', 'object_store_memory', 'resources', " + "'max_calls', or 'max_restarts', like " + "'@ray.remote(num_returns=2, " + "resources={\"CustomResource\": 1})'.") + assert len(args) == 0 and len(kwargs) > 0, error_string + return remote_decorator(options=kwargs) - @abstractmethod def call_remote(self, instance: "ClientStub", *args, **kwargs): - """ - call_remote is called by stub objects to execute them remotely. + """call_remote is called by stub objects to execute them remotely. This is used by stub objects in situations where they're called with .remote, eg, `f.remote()` or `actor_cls.remote()`. @@ -95,31 +84,57 @@ class APIImpl(ABC): args: opaque arguments kwargs: opaque keyword arguments """ - pass + return self.worker.call_remote(instance, *args, **kwargs) - @abstractmethod - def close(self) -> None: + def call_release(self, id: bytes) -> None: + """Attempts to release an object reference. + + When client references are destructed, they release their reference, + which can opportunistically send a notification through the datachannel + to release the reference being held for that object on the server. + + Args: + id: The id of the reference to release on the server side. """ - close cleans up an API connection by closing any channels or + return self.worker.call_release(id) + + def call_retain(self, id: bytes) -> None: + """Attempts to retain a client object reference. + + Increments the reference count on the client side, to prevent + the client worker from attempting to release the server reference. + + Args: + id: The id of the reference to retain on the client side. + """ + return self.worker.call_retain(id) + + def close(self) -> None: + """close cleans up an API connection by closing any channels or shutting down any servers gracefully. """ - pass + return self.worker.close() - @abstractmethod - def kill(self, actor, *, no_restart=True): + def get_actor(self, name: str) -> "ClientActorHandle": + """Returns a handle to an actor by name. + + Args: + name: The name passed to this actor by + Actor.options(name="name").remote() """ - kill forcibly stops an actor running in the cluster + return self.worker.get_actor(name) + + def kill(self, actor: "ClientActorHandle", *, no_restart=True): + """kill forcibly stops an actor running in the cluster Args: no_restart: Whether this actor should be restarted if it's a restartable actor. """ - pass + return self.worker.terminate_actor(actor, no_restart) - @abstractmethod - def cancel(self, obj, *, force=False, recursive=True): - """ - Cancels a task on the cluster. + def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True): + """Cancels a task on the cluster. If the specified task is pending execution, it will not be executed. If the task is currently executing, the behavior depends on the ``force`` @@ -136,80 +151,11 @@ class APIImpl(ABC): recursive (boolean): Whether to try to cancel tasks submitted by the task specified. """ - pass - - @abstractmethod - def call_release(self, id: bytes) -> None: - """ - Attempts to release an object reference. - - When client references are destructed, they release their reference, - which can opportunistically send a notification through the datachannel - to release the reference being held for that object on the server. - - Args: - id: The id of the reference to release on the server side. - """ - - @abstractmethod - def call_retain(self, id: bytes) -> None: - """ - Attempts to retain a client object reference. - - Increments the reference count on the client side, to prevent - the client worker from attempting to release the server reference. - - Args: - id: The id of the reference to retain on the client side. - """ - - -class ClientAPI(APIImpl): - """ - The Client-side methods corresponding to the ray API. Delegates - to the Client Worker that contains the connection to the ClientServer. - """ - - def __init__(self, worker): - self.worker = worker - - def get(self, vals, *, timeout=None): - return self.worker.get(vals, timeout=timeout) - - def put(self, *args, **kwargs): - return self.worker.put(*args, **kwargs) - - def wait(self, *args, **kwargs): - return self.worker.wait(*args, **kwargs) - - def remote(self, *args, **kwargs): - return self.worker.remote(*args, **kwargs) - - def call_remote(self, instance: "ClientStub", *args, **kwargs): - return self.worker.call_remote(instance, *args, **kwargs) - - def call_release(self, id: bytes) -> None: - return self.worker.call_release(id) - - def call_retain(self, id: bytes) -> None: - return self.worker.call_retain(id) - - def close(self) -> None: - return self.worker.close() - - def get_actor(self, name: str) -> "ClientActorHandle": - return self.worker.get_actor(name) - - def kill(self, actor: "ClientActorHandle", *, no_restart=True): - return self.worker.terminate_actor(actor, no_restart) - - def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True): return self.worker.terminate_task(obj, force, recursive) # Various metadata methods for the client that are defined in the protocol. def is_initialized(self) -> bool: - """ True if our client is connected, and if the server is initialized. - + """True if our client is connected, and if the server is initialized. Returns: A boolean determining if the client is connected and server initialized. @@ -222,6 +168,8 @@ class ClientAPI(APIImpl): Returns: Information about the Ray clients in the cluster. """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.NODES) @@ -235,6 +183,8 @@ class ClientAPI(APIImpl): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES) @@ -250,6 +200,8 @@ class ClientAPI(APIImpl): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + # This should be imported here, otherwise, it will error doc build. + import ray.core.generated.ray_client_pb2 as ray_client_pb2 return self.worker.get_cluster_info( ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES) diff --git a/python/ray/experimental/client/client_pickler.py b/python/ray/experimental/client/client_pickler.py index 7ba83b3ac..863884687 100644 --- a/python/ray/experimental/client/client_pickler.py +++ b/python/ray/experimental/client/client_pickler.py @@ -1,5 +1,4 @@ -""" -Implements the client side of the client/server pickling protocol. +"""Implements the client side of the client/server pickling protocol. All ray client client/server data transfer happens through this pickling protocol. The model is as follows: @@ -41,6 +40,7 @@ from ray.experimental.client.common import ClientRemoteMethod from ray.experimental.client.common import OptionWrapper from ray.experimental.client.common import SelfReferenceSentinel import ray.core.generated.ray_client_pb2 as ray_client_pb2 +from ray._private.client_mode_hook import disable_client_hook if sys.version_info < (3, 8): try: @@ -155,10 +155,11 @@ class ServerUnpickler(pickle.Unpickler): def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes: - with io.BytesIO() as file: - cp = ClientPickler(client_id, file, protocol=protocol) - cp.dump(obj) - return file.getvalue() + with disable_client_hook(): + with io.BytesIO() as file: + cp = ClientPickler(client_id, file, protocol=protocol) + cp.dump(obj) + return file.getvalue() def loads_from_server(data: bytes, diff --git a/python/ray/experimental/client/common.py b/python/ray/experimental/client/common.py index f68b26e2c..18708f279 100644 --- a/python/ray/experimental/client/common.py +++ b/python/ray/experimental/client/common.py @@ -2,6 +2,8 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 from ray.experimental.client import ray from ray.experimental.client.options import validate_options +import inspect +from ray.util.inspect import is_cython import json import threading from typing import Any @@ -52,8 +54,7 @@ class ClientStub: class ClientRemoteFunc(ClientStub): - """ - A stub created on the Ray Client to represent a remote + """A stub created on the Ray Client to represent a remote function that can be exectued on the cluster. This class is allowed to be passed around between remote functions. @@ -112,7 +113,7 @@ class ClientRemoteFunc(ClientStub): class ClientActorClass(ClientStub): - """ A stub created on the Ray Client to represent an actor class. + """A stub created on the Ray Client to represent an actor class. It is wrapped by ray.remote and can be executed on the cluster. @@ -294,3 +295,17 @@ class DataEncodingSentinel: class SelfReferenceSentinel(DataEncodingSentinel): pass + + +def remote_decorator(options: Optional[Dict[str, Any]]): + def decorator(function_or_class) -> ClientStub: + if (inspect.isfunction(function_or_class) + or is_cython(function_or_class)): + return ClientRemoteFunc(function_or_class, options=options) + elif inspect.isclass(function_or_class): + return ClientActorClass(function_or_class, options=options) + else: + raise TypeError("The @ray.remote decorator must be applied to " + "either a function or to a class.") + + return decorator diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py index b0dda0a1b..add66c82c 100644 --- a/python/ray/experimental/client/dataclient.py +++ b/python/ray/experimental/client/dataclient.py @@ -1,5 +1,4 @@ -""" -This file implements a threaded stream controller to abstract a data stream +"""This file implements a threaded stream controller to abstract a data stream back to the ray clientserver. """ import logging diff --git a/python/ray/experimental/client/examples/run_tune.py b/python/ray/experimental/client/examples/run_tune.py new file mode 100644 index 000000000..9e0592c1e --- /dev/null +++ b/python/ray/experimental/client/examples/run_tune.py @@ -0,0 +1,7 @@ +from ray.experimental.client import ray + +from ray.tune import tune + +ray.connect("localhost:50051") + +tune.run("PG", config={"env": "CartPole-v0"}) diff --git a/python/ray/experimental/client/logsclient.py b/python/ray/experimental/client/logsclient.py index f26417e7e..acf2619c9 100644 --- a/python/ray/experimental/client/logsclient.py +++ b/python/ray/experimental/client/logsclient.py @@ -1,5 +1,4 @@ -""" -This file implements a threaded stream controller to return logs back from +"""This file implements a threaded stream controller to return logs back from the ray clientserver. """ import sys @@ -12,6 +11,10 @@ import ray.core.generated.ray_client_pb2 as ray_client_pb2 import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc logger = logging.getLogger(__name__) +# TODO(barakmich): Running a logger in a logger causes loopback. +# The client logger need its own root -- possibly this one. +# For the moment, let's just not propogate beyond this point. +logger.propagate = False class LogstreamClient: @@ -45,8 +48,7 @@ class LogstreamClient: raise e def log(self, level: int, msg: str): - """ - Log the message from the log stream. + """Log the message from the log stream. By default, calls logger.log but this can be overridden. Args: @@ -56,8 +58,7 @@ class LogstreamClient: logger.log(level=level, msg=msg) def stdstream(self, level: int, msg: str): - """ - Log the stdout/stderr entry from the log stream. + """Log the stdout/stderr entry from the log stream. By default, calls print but this can be overridden. Args: @@ -68,6 +69,7 @@ class LogstreamClient: print(msg, file=print_file) def set_logstream_level(self, level: int): + logger.setLevel(level) req = ray_client_pb2.LogSettingsRequest() req.enabled = True req.loglevel = level diff --git a/python/ray/experimental/client/ray_client_helpers.py b/python/ray/experimental/client/ray_client_helpers.py index ab9d7408a..975918cef 100644 --- a/python/ray/experimental/client/ray_client_helpers.py +++ b/python/ray/experimental/client/ray_client_helpers.py @@ -1,16 +1,17 @@ from contextlib import contextmanager import ray.experimental.client.server.server as ray_client_server -from ray.experimental.client import ray, reset_api +from ray.experimental.client import ray @contextmanager def ray_start_client_server(): - server = ray_client_server.serve("localhost:50051", test_mode=True) + ray._inside_client_test = True + server = ray_client_server.serve("localhost:50051") ray.connect("localhost:50051") try: yield ray finally: + ray._inside_client_test = False ray.disconnect() server.stop(0) - reset_api() diff --git a/python/ray/experimental/client/server/core_ray_api.py b/python/ray/experimental/client/server/core_ray_api.py deleted file mode 100644 index 0762cd0b1..000000000 --- a/python/ray/experimental/client/server/core_ray_api.py +++ /dev/null @@ -1,81 +0,0 @@ -# Along with `api.py` this is the stub that interfaces with -# the real (C-binding, raylet) ray core. -# -# Ideally, the first import line is the only time we actually -# import ray in this library (excluding the main function for the server) -# -# While the stub is trivial, it allows us to check that the calls we're -# making into the core-ray module are contained and well-defined. - -from typing import Any -from typing import Optional -from typing import Union - -import logging -import ray - -from ray.experimental.client.api import APIImpl -from ray.experimental.client.common import ClientObjectRef -from ray.experimental.client.common import ClientStub - -logger = logging.getLogger(__name__) - - -class CoreRayAPI(APIImpl): - """ - Implements the equivalent client-side Ray API by simply passing along to - the Core Ray API. Primarily used inside of Ray Workers as a trampoline back - to core ray when passed client stubs. - """ - - def get(self, vals, *, timeout: Optional[float] = None) -> Any: - return ray.get(vals, timeout=timeout) - - def put(self, vals: Any, *args, - **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]: - return ray.put(vals, *args, **kwargs) - - def wait(self, *args, **kwargs): - return ray.wait(*args, **kwargs) - - def remote(self, *args, **kwargs): - return ray.remote(*args, **kwargs) - - def call_remote(self, instance: ClientStub, *args, **kwargs): - raise NotImplementedError( - "Should not attempt execution of a client stub inside the raylet") - - def close(self) -> None: - return None - - def kill(self, actor, *, no_restart=True): - return ray.kill(actor, no_restart=no_restart) - - def cancel(self, obj, *, force=False, recursive=True): - return ray.cancel(obj, force=force, recursive=recursive) - - def is_initialized(self) -> bool: - return ray.is_initialized() - - def call_release(self, id: bytes) -> None: - return None - - def call_retain(self, id: bytes) -> None: - return None - - # Allow for generic fallback to ray.* in remote methods. This allows calls - # like ray.nodes() to be run in remote functions even though the client - # doesn't currently support them. - def __getattr__(self, key: str): - return getattr(ray, key) - - -class RayServerAPI(CoreRayAPI): - """ - Ray Client server-side API shim. By default, simply calls the default Core - Ray API calls, but also accepts scheduling calls from functions running - inside of other remote functions that need to create more work. - """ - - def __init__(self, server_instance): - self.server = server_instance diff --git a/python/ray/experimental/client/server/logservicer.py b/python/ray/experimental/client/server/logservicer.py index 9b2fa24bf..25e4ccbd5 100644 --- a/python/ray/experimental/client/server/logservicer.py +++ b/python/ray/experimental/client/server/logservicer.py @@ -1,5 +1,4 @@ -""" -This file responds to log stream requests and forwards logs +"""This file responds to log stream requests and forwards logs with its handler. """ import io @@ -70,6 +69,9 @@ def log_status_change_thread(log_queue, request_iterator): std_handler.register_global() root_logger.addHandler(current_handler) root_logger.setLevel(req.loglevel) + except grpc.RpcError as e: + logger.debug(f"closing log thread " + f"grpc error reading request_iterator: {e}") finally: if current_handler is not None: root_logger.setLevel(default_level) diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 7cc286de8..c1b7d6be8 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -17,27 +17,25 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc import time import inspect import json -from ray.experimental.client import stash_api_for_tests, _set_server_api from ray.experimental.client.server.server_pickler import convert_from_arg from ray.experimental.client.server.server_pickler import dumps_from_server from ray.experimental.client.server.server_pickler import loads_from_client -from ray.experimental.client.server.core_ray_api import RayServerAPI from ray.experimental.client.server.dataservicer import DataServicer from ray.experimental.client.server.logservicer import LogstreamServicer from ray.experimental.client.server.server_stubs import current_remote +from ray._private.client_mode_hook import disable_client_hook logger = logging.getLogger(__name__) class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): - def __init__(self, test_mode=False): + def __init__(self): self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict( dict) self.function_refs = {} self.actor_refs: Dict[bytes, ray.ActorHandle] = {} self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set) self.registered_actor_classes = {} - self._test_mode = test_mode self._current_function_stub = None def ClusterInfo(self, request, @@ -45,7 +43,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): resp = ray_client_pb2.ClusterInfoResponse() resp.type = request.type if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES: - resources = ray.cluster_resources() + with disable_client_hook(): + resources = ray.cluster_resources() # Normalize resources into floats # (the function may return values that are ints) float_resources = {k: float(v) for k, v in resources.items()} @@ -54,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): table=float_resources)) elif request.type == \ ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES: - resources = ray.available_resources() + with disable_client_hook(): + resources = ray.available_resources() # Normalize resources into floats # (the function may return values that are ints) float_resources = {k: float(v) for k, v in resources.items()} @@ -62,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): ray_client_pb2.ClusterInfoResponse.ResourceTable( table=float_resources)) else: - resp.json = self._return_debug_cluster_info(request, context) + with disable_client_hook(): + resp.json = self._return_debug_cluster_info(request, context) return resp def _return_debug_cluster_info(self, request, context=None) -> str: @@ -118,16 +119,18 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): try: object_ref = \ self.object_refs[req.client_id][req.task_object.id] - ray.cancel( - object_ref, - force=req.task_object.force, - recursive=req.task_object.recursive) + with disable_client_hook(): + ray.cancel( + object_ref, + force=req.task_object.force, + recursive=req.task_object.recursive) except Exception as e: return_exception_in_context(e, context) elif req.WhichOneof("terminate_type") == "actor": try: actor_ref = self.actor_refs[req.actor.id] - ray.kill(actor_ref, no_restart=req.actor.no_restart) + with disable_client_hook(): + ray.kill(actor_ref, no_restart=req.actor.no_restart) except Exception as e: return_exception_in_context(e, context) else: @@ -145,7 +148,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): objectref = self.object_refs[client_id][request.id] logger.debug("get: %s" % objectref) try: - item = ray.get(objectref, timeout=request.timeout) + with disable_client_hook(): + item = ray.get(objectref, timeout=request.timeout) except Exception as e: return ray_client_pb2.GetResponse( valid=False, error=cloudpickle.dumps(e)) @@ -171,7 +175,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): context: gRPC context. """ obj = loads_from_client(request.data, self) - objectref = ray.put(obj) + with disable_client_hook(): + objectref = ray.put(obj) self.object_refs[client_id][objectref.binary()] = objectref logger.debug("put: %s" % objectref) return ray_client_pb2.PutResponse(id=objectref.binary()) @@ -187,11 +192,12 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): num_returns = request.num_returns timeout = request.timeout try: - ready_object_refs, remaining_object_refs = ray.wait( - object_refs, - num_returns=num_returns, - timeout=timeout if timeout != -1 else None, - ) + with disable_client_hook(): + ready_object_refs, remaining_object_refs = ray.wait( + object_refs, + num_returns=num_returns, + timeout=timeout if timeout != -1 else None, + ) except Exception as e: # TODO(ameer): improve exception messages. logger.error(f"Exception {e}") @@ -215,8 +221,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): "schedule: %s %s" % (task.name, ray_client_pb2.ClientTask.RemoteExecType.Name( task.type))) - with stash_api_for_tests(self._test_mode): - try: + try: + with disable_client_hook(): if task.type == ray_client_pb2.ClientTask.FUNCTION: result = self._schedule_function(task, context) elif task.type == ray_client_pb2.ClientTask.ACTOR: @@ -232,11 +238,11 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): task.type)) result.valid = True return result - except Exception as e: - logger.error(f"Caught schedule exception {e}") - raise e - return ray_client_pb2.ClientTaskTicket( - valid=False, error=cloudpickle.dumps(e)) + except Exception as e: + logger.error(f"Caught schedule exception {e}") + raise e + return ray_client_pb2.ClientTaskTicket( + valid=False, error=cloudpickle.dumps(e)) def _schedule_method(self, task: ray_client_pb2.ClientTask, context=None) -> ray_client_pb2.ClientTaskTicket: @@ -307,31 +313,33 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer): def lookup_or_register_func( self, id: bytes, client_id: str, options: Optional[Dict]) -> ray.remote_function.RemoteFunction: - if id not in self.function_refs: - funcref = self.object_refs[client_id][id] - func = ray.get(funcref) - if not inspect.isfunction(func): - raise Exception("Attempting to register function that " - "isn't a function.") - if options is None or len(options) == 0: - self.function_refs[id] = ray.remote(func) - else: - self.function_refs[id] = ray.remote(**options)(func) + with disable_client_hook(): + if id not in self.function_refs: + funcref = self.object_refs[client_id][id] + func = ray.get(funcref) + if not inspect.isfunction(func): + raise Exception("Attempting to register function that " + "isn't a function.") + if options is None or len(options) == 0: + self.function_refs[id] = ray.remote(func) + else: + self.function_refs[id] = ray.remote(**options)(func) return self.function_refs[id] def lookup_or_register_actor(self, id: bytes, client_id: str, options: Optional[Dict]): - if id not in self.registered_actor_classes: - actor_class_ref = self.object_refs[client_id][id] - actor_class = ray.get(actor_class_ref) - if not inspect.isclass(actor_class): - raise Exception("Attempting to schedule actor that " - "isn't a class.") - if options is None or len(options) == 0: - reg_class = ray.remote(actor_class) - else: - reg_class = ray.remote(**options)(actor_class) - self.registered_actor_classes[id] = reg_class + with disable_client_hook(): + if id not in self.registered_actor_classes: + actor_class_ref = self.object_refs[client_id][id] + actor_class = ray.get(actor_class_ref) + if not inspect.isclass(actor_class): + raise Exception("Attempting to schedule actor that " + "isn't a class.") + if options is None or len(options) == 0: + reg_class = ray.remote(actor_class) + else: + reg_class = ray.remote(**options)(actor_class) + self.registered_actor_classes[id] = reg_class return self.registered_actor_classes[id] @@ -369,12 +377,22 @@ def decode_options( return opts -def serve(connection_str, test_mode=False): +_current_servicer: Optional[RayletServicer] = None + + +# Used by tests to peek inside the servicer +def _get_current_servicer(): + global _current_servicer + return _current_servicer + + +def serve(connection_str): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - task_servicer = RayletServicer(test_mode=test_mode) + task_servicer = RayletServicer() data_servicer = DataServicer(task_servicer) logs_servicer = LogstreamServicer() - _set_server_api(RayServerAPI(task_servicer)) + global _current_servicer + _current_servicer = task_servicer ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( @@ -386,12 +404,20 @@ def serve(connection_str, test_mode=False): return server -def init_and_serve(connection_str, test_mode=False, *args, **kwargs): - info = ray.init(*args, **kwargs) - server = serve(connection_str, test_mode) +def init_and_serve(connection_str, *args, **kwargs): + with disable_client_hook(): + # Disable client mode inside the worker's environment + info = ray.init(*args, **kwargs) + server = serve(connection_str) return (server, info) +def shutdown_with_server(server, _exiting_interpreter=False): + server.stop(1) + with disable_client_hook(): + ray.shutdown(_exiting_interpreter) + + if __name__ == "__main__": logging.basicConfig(level="INFO") # TODO(barakmich): Perhaps wrap ray init diff --git a/python/ray/experimental/client/server/server_pickler.py b/python/ray/experimental/client/server/server_pickler.py index 10da70cc1..4f25d728f 100644 --- a/python/ray/experimental/client/server/server_pickler.py +++ b/python/ray/experimental/client/server/server_pickler.py @@ -1,5 +1,4 @@ -""" -Implements the client side of the client/server pickling protocol. +"""Implements the client side of the client/server pickling protocol. These picklers are aware of the server internals and can find the references held for the client within the server. @@ -20,6 +19,7 @@ import ray from typing import Any from typing import TYPE_CHECKING +from ray._private.client_mode_hook import disable_client_hook from ray.experimental.client.client_pickler import PickleStub from ray.experimental.client.server.server_stubs import ( ServerSelfReferenceSentinel) @@ -121,12 +121,13 @@ def loads_from_client(data: bytes, fix_imports=True, encoding="ASCII", errors="strict") -> Any: - if isinstance(data, str): - raise TypeError("Can't load pickle from unicode string") - file = io.BytesIO(data) - return ClientUnpickler( - server_instance, file, fix_imports=fix_imports, - encoding=encoding).load() + with disable_client_hook(): + if isinstance(data, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(data) + return ClientUnpickler( + server_instance, file, fix_imports=fix_imports, + encoding=encoding).load() def convert_from_arg(pb: "ray_client_pb2.Arg", diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index 8ed41bff4..b9124a9a7 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -3,7 +3,6 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ import base64 -import inspect import json import logging import uuid @@ -14,7 +13,6 @@ from typing import List from typing import Tuple from typing import Optional -from ray.util.inspect import is_cython import grpc import ray.cloudpickle as cloudpickle @@ -23,12 +21,9 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.experimental.client.client_pickler import convert_to_arg from ray.experimental.client.client_pickler import dumps_from_client from ray.experimental.client.client_pickler import loads_from_server -from ray.experimental.client.common import ClientActorClass from ray.experimental.client.common import ClientActorHandle from ray.experimental.client.common import ClientActorRef from ray.experimental.client.common import ClientObjectRef -from ray.experimental.client.common import ClientRemoteFunc -from ray.experimental.client.common import ClientStub from ray.experimental.client.dataclient import DataClient from ray.experimental.client.logsclient import LogstreamClient @@ -61,6 +56,7 @@ class Worker: self.log_client = LogstreamClient(self.channel) self.log_client.set_logstream_level(logging.INFO) + self.closed = False def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] @@ -153,21 +149,6 @@ class Worker: return (client_ready_object_ids, client_remaining_object_ids) - def remote(self, *args, **kwargs): - if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): - # This is the case where the decorator is just @ray.remote. - return remote_decorator(options=None)(args[0]) - error_string = ("The @ray.remote decorator must be applied either " - "with no arguments and no parentheses, for example " - "'@ray.remote', or it must be applied using some of " - "the arguments 'num_returns', 'num_cpus', 'num_gpus', " - "'memory', 'object_store_memory', 'resources', " - "'max_calls', or 'max_restarts', like " - "'@ray.remote(num_returns=2, " - "resources={\"CustomResource\": 1})'.") - assert len(args) == 0 and len(kwargs) > 0, error_string - return remote_decorator(options=kwargs) - def call_remote(self, instance, *args, **kwargs) -> List[bytes]: task = instance._prepare_client_task() for arg in args: @@ -190,6 +171,8 @@ class Worker: return ticket.return_ids def call_release(self, id: bytes) -> None: + if self.closed: + return self.reference_count[id] -= 1 if self.reference_count[id] == 0: self._release_server(id) @@ -212,6 +195,7 @@ class Worker: self.channel.close() self.channel = None self.server = None + self.closed = True def get_actor(self, name: str) -> ClientActorHandle: task = ray_client_pb2.ClientTask() @@ -258,7 +242,9 @@ class Worker: req.type = type resp = self.server.ClusterInfo(req) if resp.WhichOneof("response_type") == "resource_table": - return resp.resource_table.table + # translate from a proto map to a python dict + output_dict = {k: v for k, v in resp.resource_table.table.items()} + return output_dict return json.loads(resp.json) def is_initialized(self) -> bool: @@ -268,20 +254,6 @@ class Worker: return False -def remote_decorator(options: Optional[Dict[str, Any]]): - def decorator(function_or_class) -> ClientStub: - if (inspect.isfunction(function_or_class) - or is_cython(function_or_class)): - return ClientRemoteFunc(function_or_class, options=options) - elif inspect.isclass(function_or_class): - return ClientActorClass(function_or_class, options=options) - else: - raise TypeError("The @ray.remote decorator must be applied to " - "either a function or to a class.") - - return decorator - - def make_client_id() -> str: id = uuid.uuid4() return id.hex diff --git a/python/ray/state.py b/python/ray/state.py index 6d9df7870..aa3488e20 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -9,6 +9,7 @@ import ray from ray import gcs_utils from google.protobuf.json_format import MessageToDict from ray._private import services +from ray._private.client_mode_hook import client_mode_hook from ray.utils import (decode, binary_to_hex, hex_to_binary) from ray._raylet import GlobalStateAccessor @@ -851,6 +852,7 @@ def jobs(): return state.job_table() +@client_mode_hook def nodes(): """Get a list of the nodes in the cluster (for debugging only). @@ -964,6 +966,7 @@ def object_transfer_timeline(filename=None): return state.chrome_tracing_object_transfer_dump(filename=filename) +@client_mode_hook def cluster_resources(): """Get the current total cluster resources. @@ -977,6 +980,7 @@ def cluster_resources(): return state.cluster_resources() +@client_mode_hook def available_resources(): """Get the current available cluster resources. diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index a479903ff..4185d3f0c 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -446,4 +446,4 @@ def new_scheduler_enabled(): def client_test_enabled() -> bool: - return os.environ.get("RAY_TEST_CLIENT_MODE") == "1" + return os.environ.get("RAY_CLIENT_MODE") == "1" diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 7e552e616..903377ec8 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -167,7 +167,7 @@ py_test_module_list( name_suffix = "_client_mode", # TODO(barakmich): py_test will support env in Bazel 4.0.0... # Until then, we can use tags. - #env = {"RAY_TEST_CLIENT_MODE": "true"}, + #env = {"RAY_CLIENT_MODE": "1"}, tags = ["exclusive", "client_tests"], deps = ["//:ray_lib"], ) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 05cd9d8ca..4fdfe68c6 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -9,18 +9,12 @@ import subprocess import ray from ray.cluster_utils import Cluster from ray.test_utils import init_error_pubsub -from ray.test_utils import client_test_enabled -import ray.experimental.client as ray_client @pytest.fixture def shutdown_only(): yield None # The code after the yield will run as teardown code. - if client_test_enabled(): - ray_client.ray.disconnect() - ray_client._stop_test_server(1) - ray_client.reset_api() ray.shutdown() @@ -49,17 +43,10 @@ def _ray_start(**kwargs): init_kwargs = get_default_fixture_ray_kwargs() init_kwargs.update(kwargs) # Start the Ray processes. - if client_test_enabled(): - address_info = ray_client.ray.init(**init_kwargs) - else: - address_info = ray.init(**init_kwargs) + address_info = ray.init(**init_kwargs) yield address_info # The code after the yield will run as teardown code. - if client_test_enabled(): - ray_client.ray.disconnect() - ray_client._stop_test_server(1) - ray_client.reset_api() ray.shutdown() @@ -144,16 +131,9 @@ def _ray_start_cluster(**kwargs): # We assume driver will connect to the head (first node), # so ray init will be invoked if do_init is true if len(remote_nodes) == 1 and do_init: - if client_test_enabled(): - ray_client.ray.init(address=cluster.address) - else: - ray.init(address=cluster.address) + ray.init(address=cluster.address) yield cluster # The code after the yield will run as teardown code. - if client_test_enabled(): - ray_client.ray.disconnect() - ray_client._stop_test_server(1) - ray_client.reset_api() ray.shutdown() cluster.shutdown() diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 3ba2ed7eb..4db4bdd4b 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -16,18 +16,12 @@ from ray.test_utils import wait_for_condition from ray.test_utils import wait_for_pid_to_exit from ray.tests.client_test_utils import create_remote_signal_actor -if client_test_enabled(): - from ray.experimental.client import ray -else: - import ray +import ray # NOTE: We have to import setproctitle after ray because we bundle setproctitle # with ray. import setproctitle # noqa -@pytest.mark.skipif( - client_test_enabled(), - reason="defining early, no ray package injection yet") def test_caching_actors(shutdown_only): # Test defining actors before ray.init() has been called. @@ -705,7 +699,6 @@ def test_options_num_returns(ray_start_regular_shared): assert ray.get([obj1, obj2]) == [1, 2] -@pytest.mark.skipif(client_test_enabled(), reason="remote args") def test_options_name(ray_start_regular_shared): @ray.remote class Foo: diff --git a/python/ray/tests/test_advanced.py b/python/ray/tests/test_advanced.py index ea2a6c693..50c27b07a 100644 --- a/python/ray/tests/test_advanced.py +++ b/python/ray/tests/test_advanced.py @@ -354,6 +354,8 @@ def test_illegal_api_calls(ray_start_regular): ray.get(3) +@pytest.mark.skipif( + client_test_enabled(), reason="grpc interaction with releasing resources") def test_multithreading(ray_start_2_cpus): # This test requires at least 2 CPUs to finish since the worker does not # release resources when joining the threads. diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 38330645b..7d0e7ae83 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -15,10 +15,7 @@ from ray.test_utils import ( wait_for_pid_to_exit, ) -if client_test_enabled(): - from ray.experimental.client import ray -else: - import ray +import ray logger = logging.getLogger(__name__) diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index 131954ede..c01030e58 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -160,8 +160,7 @@ def test_basic_actor(ray_start_regular_shared): def test_pass_handles(ray_start_regular_shared): - """ - Test that passing client handles to actors and functions to remote actors + """Test that passing client handles to actors and functions to remote actors in functions (on the server or raylet side) works transparently to the caller. """ @@ -264,9 +263,32 @@ def test_stdout_log_stream(ray_start_regular_shared): assert all((msg.find("Hello world") for msg in log_msgs)) -def test_basic_named_actor(ray_start_regular_shared): +def test_create_remote_before_start(ray_start_regular_shared): + """Creates remote objects (as though in a library) before + starting the client. """ - Test that ray.get_actor() can create and return a detached actor. + from ray.experimental.client import ray + + @ray.remote + class Returner: + def doit(self): + return "foo" + + @ray.remote + def f(x): + return x + 20 + + # Prints in verbose tests + print("Created remote functions") + + with ray_start_client_server() as ray: + assert ray.get(f.remote(3)) == 23 + a = Returner.remote() + assert ray.get(a.doit.remote()) == "foo" + + +def test_basic_named_actor(ray_start_regular_shared): + """Test that ray.get_actor() can create and return a detached actor. """ with ray_start_client_server() as ray: diff --git a/python/ray/tests/test_experimental_client_metadata.py b/python/ray/tests/test_experimental_client_metadata.py index f5a65cd66..a35f01649 100644 --- a/python/ray/tests/test_experimental_client_metadata.py +++ b/python/ray/tests/test_experimental_client_metadata.py @@ -2,8 +2,7 @@ from ray.experimental.client.ray_client_helpers import ray_start_client_server def test_get_ray_metadata(ray_start_regular_shared): - """ - Test the ClusterInfo client data pathway and API surface + """Test the ClusterInfo client data pathway and API surface """ with ray_start_client_server() as ray: ip_address = ray_start_regular_shared["node_ip_address"] diff --git a/python/ray/tests/test_experimental_client_references.py b/python/ray/tests/test_experimental_client_references.py index 4875d1ae0..7e5b4d184 100644 --- a/python/ray/tests/test_experimental_client_references.py +++ b/python/ray/tests/test_experimental_client_references.py @@ -2,11 +2,11 @@ from ray.experimental.client.ray_client_helpers import ray_start_client_server from ray.test_utils import wait_for_condition import ray as real_ray from ray.core.generated.gcs_pb2 import ActorTableData -from ray.experimental.client import _get_server_instance +from ray.experimental.client.server.server import _get_current_servicer def server_object_ref_count(n): - server = _get_server_instance() + server = _get_current_servicer() assert server is not None def test_cond(): @@ -20,7 +20,7 @@ def server_object_ref_count(n): def server_actor_ref_count(n): - server = _get_server_instance() + server = _get_current_servicer() assert server is not None def test_cond(): diff --git a/python/ray/worker.py b/python/ray/worker.py index a3d07e5ee..888cf680b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -51,6 +51,7 @@ from ray.ray_logging import setup_logger from ray.ray_logging import global_worker_stdstream_dispatcher from ray.utils import _random_string, check_oversized_pickle from ray.util.inspect import is_cython +from ray._private.client_mode_hook import client_mode_hook SCRIPT_MODE = 0 WORKER_MODE = 1 @@ -469,6 +470,7 @@ _global_node = None """ray.node.Node: The global node object that is created by ray.init().""" +@client_mode_hook def init( address=None, *, @@ -781,6 +783,7 @@ def init( _post_init_hooks = [] +@client_mode_hook def shutdown(_exiting_interpreter=False): """Disconnect the worker, and terminate processes started by ray.init(). @@ -1044,6 +1047,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): worker.error_message_pubsub_client.close() +@client_mode_hook def is_initialized(): """Check if ray.init has been called yet. @@ -1322,6 +1326,7 @@ def show_in_dashboard(message, key="", dtype="text"): blocking_get_inside_async_warned = False +@client_mode_hook def get(object_refs, *, timeout=None): """Get a remote object or a list of remote objects from the object store. @@ -1400,6 +1405,7 @@ def get(object_refs, *, timeout=None): return values +@client_mode_hook def put(value): """Store an object in the object store. @@ -1428,6 +1434,7 @@ def put(value): blocking_wait_inside_async_warned = False +@client_mode_hook def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True): """Return a list of IDs that are ready and a list of IDs that are not. @@ -1528,6 +1535,7 @@ def wait(object_refs, *, num_returns=1, timeout=None, fetch_local=True): return ready_ids, remaining_ids +@client_mode_hook def get_actor(name): """Get a handle to a detached actor. @@ -1548,6 +1556,7 @@ def get_actor(name): return handle +@client_mode_hook def kill(actor, *, no_restart=True): """Kill an actor forcefully. @@ -1575,6 +1584,7 @@ def kill(actor, *, no_restart=True): worker.core_worker.kill_actor(actor._ray_actor_id, no_restart) +@client_mode_hook def cancel(object_ref, *, force=False, recursive=True): """Cancels a task according to the following conditions. @@ -1691,6 +1701,7 @@ def make_decorator(num_returns=None, return decorator +@client_mode_hook def remote(*args, **kwargs): """Defines a remote function or an actor class. From 6e19facc7f38ac3f23768080acb34bea98570443 Mon Sep 17 00:00:00 2001 From: "DK.Pino" Date: Wed, 23 Dec 2020 20:31:46 +0800 Subject: [PATCH 73/88] [GCS] Delete redis gcs client and redis_xxx_accessor (#12996) --- BUILD.bazel | 106 -- ci/travis/ci.sh | 1 - src/ray/core_worker/actor_handle.cc | 4 +- src/ray/core_worker/actor_handle.h | 3 +- src/ray/core_worker/actor_manager.cc | 11 +- src/ray/core_worker/actor_manager.h | 4 +- src/ray/core_worker/core_worker.cc | 4 +- src/ray/core_worker/core_worker.h | 3 +- src/ray/core_worker/profiling.h | 2 +- .../core_worker/test/actor_manager_test.cc | 22 +- .../test/direct_actor_transport_test.cc | 5 - .../transport/direct_actor_transport.h | 1 - .../gcs/gcs_client/global_state_accessor.cc | 2 +- .../gcs/gcs_client/service_based_accessor.cc | 2 + .../gcs/gcs_client/service_based_accessor.h | 29 +- .../gcs_client/service_based_gcs_client.cc | 18 +- .../gcs/gcs_client/service_based_gcs_client.h | 7 +- src/ray/gcs/gcs_server/gcs_actor_manager.h | 1 - src/ray/gcs/gcs_server/gcs_job_manager.h | 1 - src/ray/gcs/gcs_server/gcs_object_manager.h | 1 - src/ray/gcs/gcs_server/gcs_server.cc | 19 +- src/ray/gcs/gcs_server/gcs_server.h | 4 +- src/ray/gcs/gcs_server/gcs_worker_manager.h | 1 - src/ray/gcs/gcs_server/stats_handler_impl.h | 1 - .../gcs/gcs_server/task_info_handler_impl.h | 1 - .../test/gcs_object_manager_test.cc | 1 - src/ray/gcs/redis_accessor.cc | 697 -------- src/ray/gcs/redis_accessor.h | 491 ------ src/ray/gcs/redis_gcs_client.cc | 144 -- src/ray/gcs/redis_gcs_client.h | 131 -- src/ray/gcs/subscription_executor.cc | 215 --- src/ray/gcs/subscription_executor.h | 108 -- src/ray/gcs/tables.cc | 847 ---------- src/ray/gcs/tables.h | 978 ----------- src/ray/gcs/test/accessor_test_base.h | 95 -- .../test/redis_actor_info_accessor_test.cc | 82 - src/ray/gcs/test/redis_gcs_client_test.cc | 1505 ----------------- .../gcs/test/redis_job_info_accessor_test.cc | 99 -- .../gcs/test/redis_node_info_accessor_test.cc | 181 -- .../test/redis_object_info_accessor_test.cc | 160 -- src/ray/object_manager/object_directory.h | 2 +- .../ownership_based_object_directory.h | 2 +- .../test/object_manager_stress_test.cc | 33 +- .../test/object_manager_test.cc | 14 +- src/ray/raylet/node_manager.cc | 6 +- src/ray/raylet/raylet.cc | 4 +- src/ray/raylet/reconstruction_policy.h | 2 +- src/ray/raylet/reconstruction_policy_test.cc | 19 +- src/ray/raylet/task_dependency_manager.h | 1 - .../raylet/task_dependency_manager_test.cc | 2 - src/ray/raylet/worker_pool.h | 2 +- src/ray/test/run_object_manager_tests.sh | 8 +- 52 files changed, 131 insertions(+), 5951 deletions(-) delete mode 100644 src/ray/gcs/redis_accessor.cc delete mode 100644 src/ray/gcs/redis_accessor.h delete mode 100644 src/ray/gcs/redis_gcs_client.cc delete mode 100644 src/ray/gcs/redis_gcs_client.h delete mode 100644 src/ray/gcs/subscription_executor.cc delete mode 100644 src/ray/gcs/subscription_executor.h delete mode 100644 src/ray/gcs/tables.cc delete mode 100644 src/ray/gcs/tables.h delete mode 100644 src/ray/gcs/test/accessor_test_base.h delete mode 100644 src/ray/gcs/test/redis_actor_info_accessor_test.cc delete mode 100644 src/ray/gcs/test/redis_gcs_client_test.cc delete mode 100644 src/ray/gcs/test/redis_job_info_accessor_test.cc delete mode 100644 src/ray/gcs/test/redis_node_info_accessor_test.cc delete mode 100644 src/ray/gcs/test/redis_object_info_accessor_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 16b9a315f..8782dbdf8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1020,7 +1020,6 @@ cc_test( cc_library( name = "gcs_test_util_lib", hdrs = [ - "src/ray/gcs/test/accessor_test_base.h", "src/ray/gcs/test/gcs_test_util.h", ], copts = COPTS, @@ -1621,111 +1620,6 @@ cc_library( ], ) -# TODO(micafan) Support test group in future. Use test group we can run all gcs test once. -cc_test( - name = "redis_gcs_client_test", - srcs = ["src/ray/gcs/test/redis_gcs_client_test.cc"], - args = [ - "$(location redis-server)", - "$(location redis-cli)", - "$(location libray_redis_module.so)", - ], - copts = COPTS, - data = [ - "//:libray_redis_module.so", - "//:redis-cli", - "//:redis-server", - ], - deps = [ - ":gcs", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "redis_actor_info_accessor_test", - srcs = ["src/ray/gcs/test/redis_actor_info_accessor_test.cc"], - args = [ - "$(location redis-server)", - "$(location redis-cli)", - "$(location libray_redis_module.so)", - ], - copts = COPTS, - data = [ - "//:libray_redis_module.so", - "//:redis-cli", - "//:redis-server", - ], - deps = [ - ":gcs", - ":gcs_test_util_lib", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "redis_object_info_accessor_test", - srcs = ["src/ray/gcs/test/redis_object_info_accessor_test.cc"], - args = [ - "$(location redis-server)", - "$(location redis-cli)", - "$(location libray_redis_module.so)", - ], - copts = COPTS, - data = [ - "//:libray_redis_module.so", - "//:redis-cli", - "//:redis-server", - ], - deps = [ - ":gcs", - ":gcs_test_util_lib", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "redis_job_info_accessor_test", - srcs = ["src/ray/gcs/test/redis_job_info_accessor_test.cc"], - args = [ - "$(location redis-server)", - "$(location redis-cli)", - "$(location libray_redis_module.so)", - ], - copts = COPTS, - data = [ - "//:libray_redis_module.so", - "//:redis-cli", - "//:redis-server", - ], - deps = [ - ":gcs", - ":gcs_test_util_lib", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "redis_node_info_accessor_test", - srcs = ["src/ray/gcs/test/redis_node_info_accessor_test.cc"], - args = [ - "$(location redis-server)", - "$(location redis-cli)", - "$(location libray_redis_module.so)", - ], - copts = COPTS, - data = [ - "//:libray_redis_module.so", - "//:redis-cli", - "//:redis-server", - ], - deps = [ - ":gcs", - ":gcs_test_util_lib", - "@com_google_googletest//:gtest_main", - ], -) - cc_test( name = "asio_test", srcs = ["src/ray/gcs/test/asio_test.cc"], diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index e4d9741cd..9a8c0ecbf 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -120,7 +120,6 @@ test_core() { case "${OSTYPE}" in msys) args+=( - -//:redis_gcs_client_test -//:core_worker_test -//:event_test -//:gcs_pub_sub_test diff --git a/src/ray/core_worker/actor_handle.cc b/src/ray/core_worker/actor_handle.cc index 73e56df54..5448b7057 100644 --- a/src/ray/core_worker/actor_handle.cc +++ b/src/ray/core_worker/actor_handle.cc @@ -45,7 +45,7 @@ ray::rpc::ActorHandle CreateInnerActorHandleFromString(const std::string &serial } ray::rpc::ActorHandle CreateInnerActorHandleFromActorTableData( - const ray::gcs::ActorTableData &actor_table_data) { + const ray::rpc::ActorTableData &actor_table_data) { ray::rpc::ActorHandle inner; inner.set_actor_id(actor_table_data.actor_id()); inner.set_owner_id(actor_table_data.parent_id()); @@ -80,7 +80,7 @@ ActorHandle::ActorHandle( ActorHandle::ActorHandle(const std::string &serialized) : ActorHandle(CreateInnerActorHandleFromString(serialized)) {} -ActorHandle::ActorHandle(const gcs::ActorTableData &actor_table_data) +ActorHandle::ActorHandle(const rpc::ActorTableData &actor_table_data) : ActorHandle(CreateInnerActorHandleFromActorTableData(actor_table_data)) {} void ActorHandle::SetActorTaskSpec(TaskSpecBuilder &builder, const ObjectID new_cursor) { diff --git a/src/ray/core_worker/actor_handle.h b/src/ray/core_worker/actor_handle.h index 12b47cb53..e23929303 100644 --- a/src/ray/core_worker/actor_handle.h +++ b/src/ray/core_worker/actor_handle.h @@ -20,7 +20,6 @@ #include "ray/common/task/task_util.h" #include "ray/core_worker/common.h" #include "ray/core_worker/context.h" -#include "ray/gcs/redis_gcs_client.h" #include "src/ray/protobuf/core_worker.pb.h" #include "src/ray/protobuf/gcs.pb.h" @@ -42,7 +41,7 @@ class ActorHandle { ActorHandle(const std::string &serialized); /// Constructs an ActorHandle from a gcs::ActorTableData message. - ActorHandle(const gcs::ActorTableData &actor_table_data); + ActorHandle(const rpc::ActorTableData &actor_table_data); ActorID GetActorID() const { return ActorID::FromBinary(inner_.actor_id()); }; diff --git a/src/ray/core_worker/actor_manager.cc b/src/ray/core_worker/actor_manager.cc index 6b931082a..73ca9ec34 100644 --- a/src/ray/core_worker/actor_manager.cc +++ b/src/ray/core_worker/actor_manager.cc @@ -15,7 +15,6 @@ #include "ray/core_worker/actor_manager.h" #include "ray/gcs/pb_util.h" -#include "ray/gcs/redis_accessor.h" namespace ray { @@ -124,8 +123,8 @@ void ActorManager::WaitForActorOutOfScope( } void ActorManager::HandleActorStateNotification(const ActorID &actor_id, - const gcs::ActorTableData &actor_data) { - const auto &actor_state = gcs::ActorTableData::ActorState_Name(actor_data.state()); + const rpc::ActorTableData &actor_data) { + const auto &actor_state = rpc::ActorTableData::ActorState_Name(actor_data.state()); RAY_LOG(INFO) << "received notification on actor, state: " << actor_state << ", actor_id: " << actor_id << ", ip address: " << actor_data.address().ip_address() @@ -133,14 +132,14 @@ void ActorManager::HandleActorStateNotification(const ActorID &actor_id, << WorkerID::FromBinary(actor_data.address().worker_id()) << ", raylet_id: " << NodeID::FromBinary(actor_data.address().raylet_id()) << ", num_restarts: " << actor_data.num_restarts(); - if (actor_data.state() == gcs::ActorTableData::RESTARTING) { + if (actor_data.state() == rpc::ActorTableData::RESTARTING) { direct_actor_submitter_->DisconnectActor(actor_id, actor_data.num_restarts(), false); - } else if (actor_data.state() == gcs::ActorTableData::DEAD) { + } else if (actor_data.state() == rpc::ActorTableData::DEAD) { direct_actor_submitter_->DisconnectActor(actor_id, actor_data.num_restarts(), true); // We cannot erase the actor handle here because clients can still // submit tasks to dead actors. This also means we defer unsubscription, // otherwise we crash when bulk unsubscribing all actor handles. - } else if (actor_data.state() == gcs::ActorTableData::ALIVE) { + } else if (actor_data.state() == rpc::ActorTableData::ALIVE) { direct_actor_submitter_->ConnectActor(actor_id, actor_data.address(), actor_data.num_restarts()); } else { diff --git a/src/ray/core_worker/actor_manager.h b/src/ray/core_worker/actor_manager.h index e3c72913a..ff47b7403 100644 --- a/src/ray/core_worker/actor_manager.h +++ b/src/ray/core_worker/actor_manager.h @@ -18,7 +18,7 @@ #include "ray/core_worker/actor_handle.h" #include "ray/core_worker/reference_count.h" #include "ray/core_worker/transport/direct_actor_transport.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/gcs_client.h" namespace ray { @@ -177,7 +177,7 @@ class ActorManager { /// \param[in] actor_id The actor id of this notification. /// \param[in] actor_data The GCS actor data. void HandleActorStateNotification(const ActorID &actor_id, - const gcs::ActorTableData &actor_data); + const rpc::ActorTableData &actor_data); /// GCS client. std::shared_ptr gcs_client_; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index d2ab2c150..219b5e062 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -464,7 +464,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), GetCallerId(), rpc_address_); - std::shared_ptr data = std::make_shared(); + std::shared_ptr data = std::make_shared(); data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage()); if (!options_.is_local_mode) { RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(data, nullptr)); @@ -1639,7 +1639,7 @@ std::pair CoreWorker::GetNamedActorHandle( std::make_shared>(std::promise()); RAY_CHECK_OK(gcs_client_->Actors().AsyncGetByName( name, [this, &actor_id, name, ready_promise]( - Status status, const boost::optional &result) { + Status status, const boost::optional &result) { if (status.ok() && result) { auto actor_handle = std::unique_ptr(new ActorHandle(*result)); actor_id = actor_handle->GetActorID(); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 14136a895..256f0c42d 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -30,8 +30,7 @@ #include "ray/core_worker/store_provider/plasma_store_provider.h" #include "ray/core_worker/transport/direct_actor_transport.h" #include "ray/core_worker/transport/direct_task_transport.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/gcs/subscription_executor.h" +#include "ray/gcs/gcs_client.h" #include "ray/raylet_client/raylet_client.h" #include "ray/rpc/node_manager/node_manager_client.h" #include "ray/rpc/worker/core_worker_client.h" diff --git a/src/ray/core_worker/profiling.h b/src/ray/core_worker/profiling.h index 908fb77a3..24c15a29c 100644 --- a/src/ray/core_worker/profiling.h +++ b/src/ray/core_worker/profiling.h @@ -18,7 +18,7 @@ #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "ray/core_worker/context.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/gcs_client.h" namespace ray { diff --git a/src/ray/core_worker/test/actor_manager_test.cc b/src/ray/core_worker/test/actor_manager_test.cc index 06cb9a70e..cd4a21408 100644 --- a/src/ray/core_worker/test/actor_manager_test.cc +++ b/src/ray/core_worker/test/actor_manager_test.cc @@ -20,17 +20,17 @@ #include "ray/common/test_util.h" #include "ray/core_worker/reference_count.h" #include "ray/core_worker/transport/direct_actor_transport.h" -#include "ray/gcs/redis_accessor.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/gcs_client/service_based_accessor.h" +#include "ray/gcs/gcs_client/service_based_gcs_client.h" namespace ray { using ::testing::_; -class MockActorInfoAccessor : public gcs::RedisActorInfoAccessor { +class MockActorInfoAccessor : public gcs::ServiceBasedActorInfoAccessor { public: - MockActorInfoAccessor(gcs::RedisGcsClient *client) - : gcs::RedisActorInfoAccessor(client) {} + MockActorInfoAccessor(gcs::ServiceBasedGcsClient *client) + : gcs::ServiceBasedActorInfoAccessor(client) {} ~MockActorInfoAccessor() {} @@ -44,7 +44,7 @@ class MockActorInfoAccessor : public gcs::RedisActorInfoAccessor { } bool ActorStateNotificationPublished(const ActorID &actor_id, - const gcs::ActorTableData &actor_data) { + const rpc::ActorTableData &actor_data) { auto it = callback_map_.find(actor_id); if (it == callback_map_.end()) return false; auto actor_state_notification_callback = it->second; @@ -60,15 +60,13 @@ class MockActorInfoAccessor : public gcs::RedisActorInfoAccessor { callback_map_; }; -class MockGcsClient : public gcs::RedisGcsClient { +class MockGcsClient : public gcs::ServiceBasedGcsClient { public: - MockGcsClient(const gcs::GcsClientOptions &options) : gcs::RedisGcsClient(options) {} + MockGcsClient(gcs::GcsClientOptions options) : gcs::ServiceBasedGcsClient(options) {} - void Init(MockActorInfoAccessor *actor_accesor_mock) { - actor_accessor_.reset(actor_accesor_mock); + void Init(MockActorInfoAccessor *actor_info_accessor) { + actor_accessor_.reset(actor_info_accessor); } - - ~MockGcsClient() {} }; class MockDirectActorSubmitter : public CoreWorkerDirectActorTaskSubmitterInterface { diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index dffb8c4b5..8c196163e 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -223,7 +223,6 @@ TEST_F(DirectActorSubmitterTest, TestActorDead) { addr.set_worker_id(worker_id.Binary()); ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); submitter_.AddActorQueueIfNotExists(actor_id); - gcs::ActorTableData actor_data; submitter_.ConnectActor(actor_id, addr, 0); ASSERT_EQ(worker_client_->callbacks.size(), 0); @@ -256,7 +255,6 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartNoRetry) { addr.set_worker_id(worker_id.Binary()); ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); submitter_.AddActorQueueIfNotExists(actor_id); - gcs::ActorTableData actor_data; addr.set_port(0); submitter_.ConnectActor(actor_id, addr, 0); ASSERT_EQ(worker_client_->callbacks.size(), 0); @@ -299,7 +297,6 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartRetry) { addr.set_worker_id(worker_id.Binary()); ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); submitter_.AddActorQueueIfNotExists(actor_id); - gcs::ActorTableData actor_data; addr.set_port(0); submitter_.ConnectActor(actor_id, addr, 0); ASSERT_EQ(worker_client_->callbacks.size(), 0); @@ -351,7 +348,6 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderRetry) { addr.set_worker_id(worker_id.Binary()); ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); submitter_.AddActorQueueIfNotExists(actor_id); - gcs::ActorTableData actor_data; addr.set_port(0); submitter_.ConnectActor(actor_id, addr, 0); ASSERT_EQ(worker_client_->callbacks.size(), 0); @@ -401,7 +397,6 @@ TEST_F(DirectActorSubmitterTest, TestActorRestartOutOfOrderGcs) { addr.set_worker_id(worker_id.Binary()); ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); submitter_.AddActorQueueIfNotExists(actor_id); - gcs::ActorTableData actor_data; addr.set_port(0); submitter_.ConnectActor(actor_id, addr, 0); ASSERT_EQ(worker_client_->callbacks.size(), 0); diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index cb7637c9f..ab28dc85a 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -32,7 +32,6 @@ #include "ray/core_worker/store_provider/memory_store/memory_store.h" #include "ray/core_worker/task_manager.h" #include "ray/core_worker/transport/dependency_resolver.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/grpc_server.h" #include "ray/rpc/worker/core_worker_client.h" diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index 8d188ba07..5791515bc 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -225,7 +225,7 @@ std::vector GlobalStateAccessor::GetAllWorkerInfo() { } bool GlobalStateAccessor::AddWorkerInfo(const std::string &serialized_string) { - auto data_ptr = std::make_shared(); + auto data_ptr = std::make_shared(); data_ptr->ParseFromString(serialized_string); std::promise promise; RAY_CHECK_OK( diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 7e7d67d44..0e610a68e 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -19,6 +19,8 @@ namespace ray { namespace gcs { +using namespace ray::rpc; + ServiceBasedJobInfoAccessor::ServiceBasedJobInfoAccessor( ServiceBasedGcsClient *client_impl) : client_impl_(client_impl) {} diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 05f2d4316..167814bb2 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -16,7 +16,6 @@ #include "ray/common/task/task_spec.h" #include "ray/gcs/accessor.h" -#include "ray/gcs/subscription_executor.h" #include "ray/util/sequencer.h" #include "src/ray/protobuf/gcs_service.pb.h" @@ -38,12 +37,12 @@ class ServiceBasedJobInfoAccessor : public JobInfoAccessor { virtual ~ServiceBasedJobInfoAccessor() = default; - Status AsyncAdd(const std::shared_ptr &data_ptr, + Status AsyncAdd(const std::shared_ptr &data_ptr, const StatusCallback &callback) override; Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; - Status AsyncSubscribeAll(const SubscribeCallback &subscribe, + Status AsyncSubscribeAll(const SubscribeCallback &subscribe, const StatusCallback &done) override; Status AsyncGetAll(const MultiItemCallback &callback) override; @@ -71,7 +70,7 @@ class ServiceBasedActorInfoAccessor : public ActorInfoAccessor { virtual ~ServiceBasedActorInfoAccessor() = default; - Status GetAll(std::vector *actor_table_data_list) override; + Status GetAll(std::vector *actor_table_data_list) override; Status AsyncGet(const ActorID &actor_id, const OptionalItemCallback &callback) override; @@ -136,30 +135,30 @@ class ServiceBasedNodeInfoAccessor : public NodeInfoAccessor { virtual ~ServiceBasedNodeInfoAccessor() = default; - Status RegisterSelf(const GcsNodeInfo &local_node_info, + Status RegisterSelf(const rpc::GcsNodeInfo &local_node_info, const StatusCallback &callback) override; Status UnregisterSelf() override; const NodeID &GetSelfId() const override; - const GcsNodeInfo &GetSelfInfo() const override; + const rpc::GcsNodeInfo &GetSelfInfo() const override; Status AsyncRegister(const rpc::GcsNodeInfo &node_info, const StatusCallback &callback) override; Status AsyncUnregister(const NodeID &node_id, const StatusCallback &callback) override; - Status AsyncGetAll(const MultiItemCallback &callback) override; + Status AsyncGetAll(const MultiItemCallback &callback) override; Status AsyncSubscribeToNodeChange( - const SubscribeCallback &subscribe, + const SubscribeCallback &subscribe, const StatusCallback &done) override; - boost::optional Get(const NodeID &node_id, - bool filter_dead_nodes = false) const override; + boost::optional Get(const NodeID &node_id, + bool filter_dead_nodes = false) const override; - const std::unordered_map &GetAll() const override; + const std::unordered_map &GetAll() const override; bool IsRemoved(const NodeID &node_id) const override; @@ -207,21 +206,21 @@ class ServiceBasedNodeInfoAccessor : public NodeInfoAccessor { /// from a failure. rpc::ReportResourceUsageRequest cached_resource_usage_ GUARDED_BY(mutex_); - void HandleNotification(const GcsNodeInfo &node_info); + void HandleNotification(const rpc::GcsNodeInfo &node_info); ServiceBasedGcsClient *client_impl_; using NodeChangeCallback = - std::function; + std::function; - GcsNodeInfo local_node_info_; + rpc::GcsNodeInfo local_node_info_; NodeID local_node_id_; /// The callback to call when a new node is added or a node is removed. NodeChangeCallback node_change_callback_{nullptr}; /// A cache for information about all nodes. - std::unordered_map node_cache_; + std::unordered_map node_cache_; /// The set of removed nodes. std::unordered_set removed_nodes_; }; diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.cc b/src/ray/gcs/gcs_client/service_based_gcs_client.cc index f643496b8..900d3e50d 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.cc +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.cc @@ -37,21 +37,23 @@ Status ServiceBasedGcsClient::Connect(boost::asio::io_service &io_service) { return Status::Invalid("gcs service address is invalid!"); } - // Connect to gcs. - redis_gcs_client_.reset(new RedisGcsClient(options_)); - RAY_CHECK_OK(redis_gcs_client_->Connect(io_service)); + // Connect to redis. + RedisClientOptions redis_client_options(options_.server_ip_, options_.server_port_, + options_.password_, options_.is_test_client_); + redis_client_.reset(new RedisClient(redis_client_options)); + RAY_CHECK_OK(redis_client_->Connect(io_service)); // Init gcs pub sub instance. - gcs_pub_sub_.reset(new GcsPubSub(redis_gcs_client_->GetRedisClient())); + gcs_pub_sub_.reset(new GcsPubSub(redis_client_)); // Get gcs service address. get_server_address_func_ = [this](std::pair *address) { return GetGcsServerAddressFromRedis( - redis_gcs_client_->primary_context()->sync_context(), address); + redis_client_->GetPrimaryContext()->sync_context(), address); }; std::pair address; RAY_CHECK(GetGcsServerAddressFromRedis( - redis_gcs_client_->primary_context()->sync_context(), &address, + redis_client_->GetPrimaryContext()->sync_context(), &address, RayConfig::instance().gcs_service_connect_retries())) << "Failed to get gcs server address when init gcs client."; @@ -96,8 +98,8 @@ void ServiceBasedGcsClient::Disconnect() { is_connected_ = false; detect_timer_->cancel(); gcs_pub_sub_.reset(); - redis_gcs_client_->Disconnect(); - redis_gcs_client_.reset(); + redis_client_->Disconnect(); + redis_client_.reset(); RAY_LOG(DEBUG) << "ServiceBasedGcsClient Disconnected."; } diff --git a/src/ray/gcs/gcs_client/service_based_gcs_client.h b/src/ray/gcs/gcs_client/service_based_gcs_client.h index 906165099..9b0e79806 100644 --- a/src/ray/gcs/gcs_client/service_based_gcs_client.h +++ b/src/ray/gcs/gcs_client/service_based_gcs_client.h @@ -14,8 +14,9 @@ #pragma once +#include "ray/gcs/gcs_client.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/redis_client.h" #include "ray/rpc/gcs_server/gcs_rpc_client.h" namespace ray { @@ -31,8 +32,6 @@ class RAY_EXPORT ServiceBasedGcsClient : public GcsClient { GcsPubSub &GetGcsPubSub() { return *gcs_pub_sub_; } - RedisGcsClient &GetRedisGcsClient() { return *redis_gcs_client_; } - rpc::GcsRpcClient &GetGcsRpcClient() { return *gcs_rpc_client_; } private: @@ -59,7 +58,7 @@ class RAY_EXPORT ServiceBasedGcsClient : public GcsClient { /// Reconnect to GCS RPC server. void ReconnectGcsServer(); - std::unique_ptr redis_gcs_client_; + std::shared_ptr redis_client_; std::unique_ptr gcs_pub_sub_; diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.h b/src/ray/gcs/gcs_server/gcs_actor_manager.h index c2f23ac2d..e10be2fe8 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.h +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.h @@ -24,7 +24,6 @@ #include "ray/gcs/gcs_server/gcs_init_data.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" #include "ray/rpc/worker/core_worker_client.h" #include "src/ray/protobuf/gcs_service.pb.h" diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h index 24d8f7dfe..da8628967 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.h +++ b/src/ray/gcs/gcs_server/gcs_job_manager.h @@ -17,7 +17,6 @@ #include "ray/gcs/gcs_server/gcs_object_manager.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" namespace ray { diff --git a/src/ray/gcs/gcs_server/gcs_object_manager.h b/src/ray/gcs/gcs_server/gcs_object_manager.h index 4d728e8e0..bd21bfd1b 100644 --- a/src/ray/gcs/gcs_server/gcs_object_manager.h +++ b/src/ray/gcs/gcs_server/gcs_object_manager.h @@ -18,7 +18,6 @@ #include "ray/gcs/gcs_server/gcs_node_manager.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" namespace ray { diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 71e2a6d81..672c593df 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -43,23 +43,22 @@ GcsServer::~GcsServer() { Stop(); } void GcsServer::Start() { // Init backend client. - GcsClientOptions options(config_.redis_address, config_.redis_port, - config_.redis_password, config_.is_test); - redis_gcs_client_ = std::make_shared(options); - auto status = redis_gcs_client_->Connect(main_service_); + RedisClientOptions redis_client_options(config_.redis_address, config_.redis_port, + config_.redis_password, config_.is_test); + redis_client_ = std::make_shared(redis_client_options); + auto status = redis_client_->Connect(main_service_); RAY_CHECK(status.ok()) << "Failed to init redis gcs client as " << status; // Init redis failure detector. gcs_redis_failure_detector_ = std::make_shared( - main_service_, redis_gcs_client_->primary_context(), [this]() { Stop(); }); + main_service_, redis_client_->GetPrimaryContext(), [this]() { Stop(); }); gcs_redis_failure_detector_->Start(); // Init gcs pub sub instance. - gcs_pub_sub_ = std::make_shared(redis_gcs_client_->GetRedisClient()); + gcs_pub_sub_ = std::make_shared(redis_client_); // Init gcs table storage. - gcs_table_storage_ = - std::make_shared(redis_gcs_client_->GetRedisClient()); + gcs_table_storage_ = std::make_shared(redis_client_); // Load gcs tables data asynchronously. auto gcs_init_data = std::make_shared(gcs_table_storage_); @@ -132,7 +131,7 @@ void GcsServer::Stop() { } void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) { - RAY_CHECK(redis_gcs_client_ && gcs_table_storage_ && gcs_pub_sub_); + RAY_CHECK(redis_client_ && gcs_table_storage_ && gcs_pub_sub_); gcs_node_manager_ = std::make_shared(main_service_, gcs_pub_sub_, gcs_table_storage_); // Initialize by gcs tables data. @@ -255,7 +254,7 @@ void GcsServer::StoreGcsServerAddressInRedis() { std::string address = ip + ":" + std::to_string(GetPort()); RAY_LOG(INFO) << "Gcs server address = " << address; - RAY_CHECK_OK(redis_gcs_client_->primary_context()->RunArgvAsync( + RAY_CHECK_OK(redis_client_->GetPrimaryContext()->RunArgvAsync( {"SET", "GcsServerAddress", address})); RAY_LOG(INFO) << "Finished setting gcs server address: " << address; } diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index a2082539f..1527ca7cf 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -21,7 +21,7 @@ #include "ray/gcs/gcs_server/gcs_resource_manager.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/redis_client.h" #include "ray/rpc/client_call.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" #include "ray/rpc/node_manager/node_manager_client_pool.h" @@ -176,7 +176,7 @@ class GcsServer { /// Placement Group info handler and service std::unique_ptr placement_group_info_service_; /// Backend client - std::shared_ptr redis_gcs_client_; + std::shared_ptr redis_client_; /// A publisher for publishing gcs messages. std::shared_ptr gcs_pub_sub_; /// The gcs table storage. diff --git a/src/ray/gcs/gcs_server/gcs_worker_manager.h b/src/ray/gcs/gcs_server/gcs_worker_manager.h index 094e881e6..60001aa12 100644 --- a/src/ray/gcs/gcs_server/gcs_worker_manager.h +++ b/src/ray/gcs/gcs_server/gcs_worker_manager.h @@ -16,7 +16,6 @@ #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" namespace ray { diff --git a/src/ray/gcs/gcs_server/stats_handler_impl.h b/src/ray/gcs/gcs_server/stats_handler_impl.h index d9de7e40b..2e065c621 100644 --- a/src/ray/gcs/gcs_server/stats_handler_impl.h +++ b/src/ray/gcs/gcs_server/stats_handler_impl.h @@ -16,7 +16,6 @@ #include "ray/common/ray_config.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" namespace ray { diff --git a/src/ray/gcs/gcs_server/task_info_handler_impl.h b/src/ray/gcs/gcs_server/task_info_handler_impl.h index 5a7599e8f..c32eb4894 100644 --- a/src/ray/gcs/gcs_server/task_info_handler_impl.h +++ b/src/ray/gcs/gcs_server/task_info_handler_impl.h @@ -16,7 +16,6 @@ #include "ray/gcs/gcs_server/gcs_table_storage.h" #include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/gcs_server/gcs_rpc_server.h" namespace ray { diff --git a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc index 700fdfc10..f6842d287 100644 --- a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc @@ -86,7 +86,6 @@ class GcsObjectManagerTest : public ::testing::Test { boost::asio::io_service io_service_; std::shared_ptr gcs_resource_manager_; std::shared_ptr gcs_node_manager_; - std::shared_ptr gcs_client_; std::shared_ptr gcs_pub_sub_; std::shared_ptr gcs_object_manager_; std::shared_ptr gcs_table_storage_; diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc deleted file mode 100644 index 248eb9a89..000000000 --- a/src/ray/gcs/redis_accessor.cc +++ /dev/null @@ -1,697 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/gcs/redis_accessor.h" - -#include - -#include "ray/gcs/pb_util.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/util/logging.h" - -namespace ray { - -namespace gcs { - -RedisLogBasedActorInfoAccessor::RedisLogBasedActorInfoAccessor( - RedisGcsClient *client_impl) - : client_impl_(client_impl), - log_based_actor_sub_executor_(client_impl_->log_based_actor_table()) {} - -std::vector RedisLogBasedActorInfoAccessor::GetAllActorID() const { - return client_impl_->log_based_actor_table().GetAllActorID(); -} - -Status RedisLogBasedActorInfoAccessor::Get(const ActorID &actor_id, - ActorTableData *actor_table_data) const { - return client_impl_->log_based_actor_table().Get(actor_id, actor_table_data); -} - -Status RedisLogBasedActorInfoAccessor::GetAll( - std::vector *actor_table_data_list) { - RAY_CHECK(actor_table_data_list); - auto actor_id_list = GetAllActorID(); - actor_table_data_list->resize(actor_id_list.size()); - for (size_t i = 0; i < actor_id_list.size(); ++i) { - RAY_CHECK_OK(Get(actor_id_list[i], &(*actor_table_data_list)[i])); - } - return Status::OK(); -} - -Status RedisLogBasedActorInfoAccessor::AsyncGet( - const ActorID &actor_id, const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_done = [callback](RedisGcsClient *client, const ActorID &actor_id, - const std::vector &data) { - boost::optional result; - if (!data.empty()) { - result = data.back(); - } - callback(Status::OK(), result); - }; - - return client_impl_->log_based_actor_table().Lookup(actor_id.JobId(), actor_id, - on_done); -} - -Status RedisLogBasedActorInfoAccessor::AsyncRegisterActor( - const ray::TaskSpecification &task_spec, const ray::gcs::StatusCallback &callback) { - const std::string error_msg = - "Unsupported method of AsyncRegisterActor in RedisLogBasedActorInfoAccessor."; - RAY_LOG(FATAL) << error_msg; - return Status::Invalid(error_msg); -} - -Status RedisLogBasedActorInfoAccessor::AsyncCreateActor( - const ray::TaskSpecification &task_spec, const ray::gcs::StatusCallback &callback) { - const std::string error_msg = - "Unsupported method of AsyncCreateActor in " - "RedisLogBasedActorInfoAccessor."; - RAY_LOG(FATAL) << error_msg; - return Status::Invalid(error_msg); -} - -Status RedisLogBasedActorInfoAccessor::AsyncSubscribeAll( - const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return log_based_actor_sub_executor_.AsyncSubscribeAll(NodeID::Nil(), subscribe, done); -} - -Status RedisLogBasedActorInfoAccessor::AsyncSubscribe( - const ActorID &actor_id, const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return log_based_actor_sub_executor_.AsyncSubscribe(subscribe_id_, actor_id, subscribe, - done); -} - -Status RedisLogBasedActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) { - return log_based_actor_sub_executor_.AsyncUnsubscribe(subscribe_id_, actor_id, nullptr); -} - -RedisActorInfoAccessor::RedisActorInfoAccessor(RedisGcsClient *client_impl) - : RedisLogBasedActorInfoAccessor(client_impl), - actor_sub_executor_(client_impl_->actor_table()) {} - -std::vector RedisActorInfoAccessor::GetAllActorID() const { - return client_impl_->actor_table().GetAllActorID(); -} - -Status RedisActorInfoAccessor::Get(const ActorID &actor_id, - ActorTableData *actor_table_data) const { - return client_impl_->actor_table().Get(actor_id, actor_table_data); -} - -Status RedisActorInfoAccessor::AsyncGet( - const ActorID &actor_id, const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_done = [callback](RedisGcsClient *client, const ActorID &actor_id, - const ActorTableData &data) { callback(Status::OK(), data); }; - - auto on_failure = [callback](RedisGcsClient *client, const ActorID &actor_id) { - if (callback != nullptr) { - callback(Status::Invalid("Get actor failed."), boost::none); - } - }; - - return client_impl_->actor_table().Lookup(JobID::Nil(), actor_id, on_done, on_failure); -} - -Status RedisActorInfoAccessor::AsyncGetAll( - const MultiItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto actor_id_list = GetAllActorID(); - if (actor_id_list.empty()) { - callback(Status::OK(), std::vector()); - return Status::OK(); - } - - auto finished_count = std::make_shared(0); - auto result = std::make_shared>(); - int size = actor_id_list.size(); - for (auto &actor_id : actor_id_list) { - auto on_done = [finished_count, size, result, callback]( - const Status &status, - const boost::optional &data) { - ++(*finished_count); - if (data) { - result->push_back(*data); - } - if (*finished_count == size) { - callback(Status::OK(), *result); - } - }; - RAY_CHECK_OK(AsyncGet(actor_id, on_done)); - } - - return Status::OK(); -} - -Status RedisActorInfoAccessor::AsyncSubscribeAll( - const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return actor_sub_executor_.AsyncSubscribeAll(NodeID::Nil(), subscribe, done); -} - -Status RedisActorInfoAccessor::AsyncSubscribe( - const ActorID &actor_id, const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return actor_sub_executor_.AsyncSubscribe(subscribe_id_, actor_id, subscribe, done); -} - -Status RedisActorInfoAccessor::AsyncUnsubscribe(const ActorID &actor_id) { - return actor_sub_executor_.AsyncUnsubscribe(subscribe_id_, actor_id, nullptr); -} - -RedisJobInfoAccessor::RedisJobInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), job_sub_executor_(client_impl->job_table()) {} - -Status RedisJobInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - return DoAsyncAppend(data_ptr, callback); -} - -Status RedisJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, - const StatusCallback &callback) { - std::shared_ptr data_ptr = - CreateJobTableData(job_id, /*is_dead*/ true, /*time_stamp*/ std::time(nullptr), - /*driver_ip_address*/ "", /*driver_pid*/ -1); - return DoAsyncAppend(data_ptr, callback); -} - -Status RedisJobInfoAccessor::DoAsyncAppend(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - JobTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const JobID &job_id, - const JobTableData &data) { callback(Status::OK()); }; - } - - JobID job_id = JobID::FromBinary(data_ptr->job_id()); - return client_impl_->job_table().Append(job_id, job_id, data_ptr, on_done); -} - -Status RedisJobInfoAccessor::AsyncSubscribeAll( - const SubscribeCallback &subscribe, const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return job_sub_executor_.AsyncSubscribeAll(NodeID::Nil(), subscribe, done); -} - -RedisTaskInfoAccessor::RedisTaskInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), - task_sub_executor_(client_impl->raylet_task_table()), - task_lease_sub_executor_(client_impl->task_lease_table()) {} - -Status RedisTaskInfoAccessor::AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - raylet::TaskTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { callback(Status::OK()); }; - } - - TaskID task_id = TaskID::FromBinary(data_ptr->task().task_spec().task_id()); - raylet::TaskTable &task_table = client_impl_->raylet_task_table(); - return task_table.Add(task_id.JobId(), task_id, data_ptr, on_done); -} - -Status RedisTaskInfoAccessor::AsyncGet( - const TaskID &task_id, const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_success = [callback](RedisGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { - boost::optional result(data); - callback(Status::OK(), result); - }; - - auto on_failure = [callback](RedisGcsClient *client, const TaskID &task_id) { - boost::optional result; - callback(Status::Invalid("Task not exist."), result); - }; - - raylet::TaskTable &task_table = client_impl_->raylet_task_table(); - return task_table.Lookup(task_id.JobId(), task_id, on_success, on_failure); -} - -Status RedisTaskInfoAccessor::AsyncSubscribe( - const TaskID &task_id, const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return task_sub_executor_.AsyncSubscribe(subscribe_id_, task_id, subscribe, done); -} - -Status RedisTaskInfoAccessor::AsyncUnsubscribe(const TaskID &task_id) { - return task_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, nullptr); -} - -Status RedisTaskInfoAccessor::AsyncAddTaskLease( - const std::shared_ptr &data_ptr, const StatusCallback &callback) { - TaskLeaseTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const TaskID &id, - const TaskLeaseData &data) { callback(Status::OK()); }; - } - TaskID task_id = TaskID::FromBinary(data_ptr->task_id()); - TaskLeaseTable &task_lease_table = client_impl_->task_lease_table(); - return task_lease_table.Add(task_id.JobId(), task_id, data_ptr, on_done); -} - -Status RedisTaskInfoAccessor::AsyncGetTaskLease( - const TaskID &task_id, const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_success = [callback](RedisGcsClient *client, const TaskID &task_id, - const TaskLeaseData &data) { - boost::optional result(data); - callback(Status::OK(), result); - }; - - auto on_failure = [callback](RedisGcsClient *client, const TaskID &task_id) { - boost::optional result; - callback(Status::Invalid("Task lease not exist."), result); - }; - - TaskLeaseTable &task_lease_table = client_impl_->task_lease_table(); - return task_lease_table.Lookup(task_id.JobId(), task_id, on_success, on_failure); -} - -Status RedisTaskInfoAccessor::AsyncSubscribeTaskLease( - const TaskID &task_id, - const SubscribeCallback> &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return task_lease_sub_executor_.AsyncSubscribe(subscribe_id_, task_id, subscribe, done); -} - -Status RedisTaskInfoAccessor::AsyncUnsubscribeTaskLease(const TaskID &task_id) { - return task_lease_sub_executor_.AsyncUnsubscribe(subscribe_id_, task_id, nullptr); -} - -Status RedisTaskInfoAccessor::AttemptTaskReconstruction( - const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - TaskReconstructionLog::WriteCallback on_success = nullptr; - TaskReconstructionLog::WriteCallback on_failure = nullptr; - if (callback != nullptr) { - on_success = [callback](RedisGcsClient *client, const TaskID &id, - const TaskReconstructionData &data) { - callback(Status::OK()); - }; - on_failure = [callback](RedisGcsClient *client, const TaskID &id, - const TaskReconstructionData &data) { - callback(Status::Invalid("Updating task reconstruction failed.")); - }; - } - - TaskID task_id = TaskID::FromBinary(data_ptr->task_id()); - int reconstruction_attempt = data_ptr->num_reconstructions(); - TaskReconstructionLog &task_reconstruction_log = - client_impl_->task_reconstruction_log(); - return task_reconstruction_log.AppendAt(task_id.JobId(), task_id, data_ptr, on_success, - on_failure, reconstruction_attempt); -} - -RedisObjectInfoAccessor::RedisObjectInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), object_sub_executor_(client_impl->object_table()) {} - -Status RedisObjectInfoAccessor::AsyncGetLocations( - const ObjectID &object_id, - const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_done = [callback](RedisGcsClient *client, const ObjectID &object_id, - const std::vector &data) { - rpc::ObjectLocationInfo info; - info.set_object_id(object_id.Binary()); - for (const auto &item : data) { - auto item_ptr = info.add_locations(); - item_ptr->CopyFrom(item); - } - callback(Status::OK(), info); - }; - - ObjectTable &object_table = client_impl_->object_table(); - return object_table.Lookup(object_id.TaskId().JobId(), object_id, on_done); -} - -Status RedisObjectInfoAccessor::AsyncAddLocation(const ObjectID &object_id, - const NodeID &node_id, - const StatusCallback &callback) { - std::function - on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const ObjectID &object_id, - const ObjectTableData &data) { callback(Status::OK()); }; - } - - std::shared_ptr data_ptr = std::make_shared(); - data_ptr->set_manager(node_id.Binary()); - - ObjectTable &object_table = client_impl_->object_table(); - return object_table.Add(object_id.TaskId().JobId(), object_id, data_ptr, on_done); -} - -Status RedisObjectInfoAccessor::AsyncRemoveLocation(const ObjectID &object_id, - const NodeID &node_id, - const StatusCallback &callback) { - std::function - on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const ObjectID &object_id, - const ObjectTableData &data) { callback(Status::OK()); }; - } - - std::shared_ptr data_ptr = std::make_shared(); - data_ptr->set_manager(node_id.Binary()); - - ObjectTable &object_table = client_impl_->object_table(); - return object_table.Remove(object_id.TaskId().JobId(), object_id, data_ptr, on_done); -} - -Status RedisObjectInfoAccessor::AsyncSubscribeToLocations( - const ObjectID &object_id, - const SubscribeCallback> &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return object_sub_executor_.AsyncSubscribe( - subscribe_id_, object_id, - [subscribe](const ObjectID &id, const ObjectChangeNotification ¬ification_data) { - std::vector updates; - for (const auto &item : notification_data.GetData()) { - rpc::ObjectLocationChange update; - update.set_is_add(notification_data.IsAdded()); - update.set_node_id(item.manager()); - updates.push_back(update); - } - subscribe(id, updates); - }, - done); -} - -Status RedisObjectInfoAccessor::AsyncUnsubscribeToLocations(const ObjectID &object_id) { - return object_sub_executor_.AsyncUnsubscribe(subscribe_id_, object_id, nullptr); -} - -RedisNodeInfoAccessor::RedisNodeInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), - resource_usage_batch_sub_executor_(client_impl->resource_usage_batch_table()) {} - -Status RedisNodeInfoAccessor::RegisterSelf(const GcsNodeInfo &local_node_info, - const StatusCallback &callback) { - NodeTable &node_table = client_impl_->node_table(); - Status status = node_table.Connect(local_node_info); - if (callback != nullptr) { - callback(Status::OK()); - } - return status; -} - -Status RedisNodeInfoAccessor::UnregisterSelf() { - NodeTable &node_table = client_impl_->node_table(); - return node_table.Disconnect(); -} - -const NodeID &RedisNodeInfoAccessor::GetSelfId() const { - NodeTable &node_table = client_impl_->node_table(); - return node_table.GetLocalNodeId(); -} - -const GcsNodeInfo &RedisNodeInfoAccessor::GetSelfInfo() const { - NodeTable &node_table = client_impl_->node_table(); - return node_table.GetLocalNode(); -} - -Status RedisNodeInfoAccessor::AsyncRegister(const GcsNodeInfo &node_info, - const StatusCallback &callback) { - NodeTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const NodeID &id, - const GcsNodeInfo &data) { callback(Status::OK()); }; - } - NodeTable &node_table = client_impl_->node_table(); - return node_table.MarkConnected(node_info, on_done); -} - -Status RedisNodeInfoAccessor::AsyncUnregister(const NodeID &node_id, - const StatusCallback &callback) { - NodeTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const NodeID &id, - const GcsNodeInfo &data) { callback(Status::OK()); }; - } - NodeTable &node_table = client_impl_->node_table(); - return node_table.MarkDisconnected(node_id, on_done); -} - -Status RedisNodeInfoAccessor::AsyncSubscribeToNodeChange( - const SubscribeCallback &subscribe, const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - NodeTable &node_table = client_impl_->node_table(); - return node_table.SubscribeToNodeChange(subscribe, done); -} - -Status RedisNodeInfoAccessor::AsyncGetAll( - const MultiItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_done = [callback](RedisGcsClient *client, const NodeID &id, - const std::vector &data) { - std::vector result; - std::set node_ids; - for (int index = data.size() - 1; index >= 0; --index) { - if (node_ids.insert(data[index].node_id()).second) { - result.emplace_back(data[index]); - } - } - callback(Status::OK(), result); - }; - NodeTable &node_table = client_impl_->node_table(); - return node_table.Lookup(on_done); -} - -boost::optional RedisNodeInfoAccessor::Get(const NodeID &node_id, - bool filter_dead_nodes) const { - GcsNodeInfo node_info; - NodeTable &node_table = client_impl_->node_table(); - bool found = node_table.GetNode(node_id, &node_info); - boost::optional optional_node; - if (found) { - optional_node = std::move(node_info); - } - return optional_node; -} - -const std::unordered_map &RedisNodeInfoAccessor::GetAll() const { - NodeTable &node_table = client_impl_->node_table(); - return node_table.GetAllNodes(); -} - -bool RedisNodeInfoAccessor::IsRemoved(const NodeID &node_id) const { - NodeTable &node_table = client_impl_->node_table(); - return node_table.IsRemoved(node_id); -} -Status RedisNodeInfoAccessor::AsyncReportHeartbeat( - const std::shared_ptr &data_ptr, const StatusCallback &callback) { - HeartbeatTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const NodeID &node_id, - const HeartbeatTableData &data) { callback(Status::OK()); }; - } - - NodeID node_id = NodeID::FromBinary(data_ptr->node_id()); - HeartbeatTable &heartbeat_table = client_impl_->heartbeat_table(); - return heartbeat_table.Add(JobID::Nil(), node_id, data_ptr, on_done); -} - -Status RedisNodeInfoAccessor::AsyncReportResourceUsage( - const std::shared_ptr &data_ptr, const StatusCallback &callback) { - return Status::Invalid("Not implemented"); -} - -void RedisNodeInfoAccessor::AsyncReReportResourceUsage() {} - -Status RedisNodeInfoAccessor::AsyncSubscribeBatchedResourceUsage( - const ItemCallback &subscribe, const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const NodeID &node_id, - const ResourceUsageBatchData &data) { - subscribe(data); - }; - - return resource_usage_batch_sub_executor_.AsyncSubscribeAll(NodeID::Nil(), on_subscribe, - done); -} - -RedisNodeResourceInfoAccessor::RedisNodeResourceInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), resource_sub_executor_(client_impl_->resource_table()) {} - -Status RedisNodeResourceInfoAccessor::AsyncGetResources( - const NodeID &node_id, const OptionalItemCallback &callback) { - RAY_CHECK(callback != nullptr); - auto on_done = [callback](RedisGcsClient *client, const NodeID &id, - const ResourceMap &data) { - boost::optional result; - if (!data.empty()) { - result = data; - } - callback(Status::OK(), result); - }; - - DynamicResourceTable &resource_table = client_impl_->resource_table(); - return resource_table.Lookup(JobID::Nil(), node_id, on_done); -} - -Status RedisNodeResourceInfoAccessor::AsyncUpdateResources( - const NodeID &node_id, const ResourceMap &resources, const StatusCallback &callback) { - Hash::HashCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const NodeID &node_id, - const ResourceMap &resources) { callback(Status::OK()); }; - } - - DynamicResourceTable &resource_table = client_impl_->resource_table(); - return resource_table.Update(JobID::Nil(), node_id, resources, on_done); -} - -Status RedisNodeResourceInfoAccessor::AsyncDeleteResources( - const NodeID &node_id, const std::vector &resource_names, - const StatusCallback &callback) { - Hash::HashRemoveCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const NodeID &node_id, - const std::vector &resource_names) { - callback(Status::OK()); - }; - } - - DynamicResourceTable &resource_table = client_impl_->resource_table(); - return resource_table.RemoveEntries(JobID::Nil(), node_id, resource_names, on_done); -} - -Status RedisNodeResourceInfoAccessor::AsyncSubscribeToResources( - const ItemCallback &subscribe, const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const NodeID &id, - const ResourceChangeNotification &result) { - rpc::NodeResourceChange node_resource_change; - node_resource_change.set_node_id(id.Binary()); - if (result.IsAdded()) { - for (auto &it : result.GetData()) { - (*node_resource_change.mutable_updated_resources())[it.first] = - it.second->resource_capacity(); - } - } else { - for (auto &it : result.GetData()) { - node_resource_change.add_deleted_resources(it.first); - } - } - subscribe(node_resource_change); - }; - return resource_sub_executor_.AsyncSubscribeAll(NodeID::Nil(), on_subscribe, done); -} - -RedisErrorInfoAccessor::RedisErrorInfoAccessor(RedisGcsClient *client_impl) {} - -Status RedisErrorInfoAccessor::AsyncReportJobError( - const std::shared_ptr &data_ptr, const StatusCallback &callback) { - return Status::Invalid("Not implemented"); -} - -RedisStatsInfoAccessor::RedisStatsInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl) {} - -Status RedisStatsInfoAccessor::AsyncAddProfileData( - const std::shared_ptr &data_ptr, const StatusCallback &callback) { - ProfileTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const UniqueID &id, - const ProfileTableData &data) { callback(Status::OK()); }; - } - - ProfileTable &profile_table = client_impl_->profile_table(); - return profile_table.Append(JobID::Nil(), UniqueID::FromRandom(), data_ptr, on_done); -} - -RedisWorkerInfoAccessor::RedisWorkerInfoAccessor(RedisGcsClient *client_impl) - : client_impl_(client_impl), - worker_failure_sub_executor_(client_impl->worker_table()) {} - -Status RedisWorkerInfoAccessor::AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(subscribe != nullptr); - return worker_failure_sub_executor_.AsyncSubscribeAll(NodeID::Nil(), subscribe, done); -} - -Status RedisWorkerInfoAccessor::AsyncReportWorkerFailure( - const std::shared_ptr &data_ptr, const StatusCallback &callback) { - WorkerTable::WriteCallback on_done = nullptr; - if (callback != nullptr) { - on_done = [callback](RedisGcsClient *client, const WorkerID &id, - const WorkerTableData &data) { callback(Status::OK()); }; - } - - WorkerID worker_id = WorkerID::FromBinary(data_ptr->worker_address().worker_id()); - WorkerTable &worker_failure_table = client_impl_->worker_table(); - return worker_failure_table.Add(JobID::Nil(), worker_id, data_ptr, on_done); -} - -Status RedisWorkerInfoAccessor::AsyncGet( - const WorkerID &worker_id, - const OptionalItemCallback &callback) { - return Status::Invalid("Not implemented"); -} - -Status RedisWorkerInfoAccessor::AsyncGetAll( - const MultiItemCallback &callback) { - return Status::Invalid("Not implemented"); -} - -Status RedisWorkerInfoAccessor::AsyncAdd( - const std::shared_ptr &data_ptr, - const StatusCallback &callback) { - return Status::Invalid("Not implemented"); -} - -Status RedisPlacementGroupInfoAccessor::AsyncCreatePlacementGroup( - const PlacementGroupSpecification &placement_group_spec) { - return Status::Invalid("Not implemented"); -} - -Status RedisPlacementGroupInfoAccessor::AsyncRemovePlacementGroup( - const PlacementGroupID &placement_group_id, const StatusCallback &callback) { - return Status::Invalid("Not implemented"); -} - -Status RedisPlacementGroupInfoAccessor::AsyncGet( - const PlacementGroupID &placement_group_id, - const OptionalItemCallback &callback) { - return Status::Invalid("Not implemented"); -} - -Status RedisPlacementGroupInfoAccessor::AsyncGetAll( - const MultiItemCallback &callback) { - return Status::Invalid("Not implemented"); -} - -Status RedisPlacementGroupInfoAccessor::AsyncWaitUntilReady( - const PlacementGroupID &placement_group_id, const StatusCallback &callback) { - return Status::Invalid("Not implemented"); -} - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h deleted file mode 100644 index ec5d389f6..000000000 --- a/src/ray/gcs/redis_accessor.h +++ /dev/null @@ -1,491 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "ray/common/id.h" -#include "ray/common/task/task_spec.h" -#include "ray/gcs/accessor.h" -#include "ray/gcs/callback.h" -#include "ray/gcs/subscription_executor.h" -#include "ray/gcs/tables.h" - -namespace ray { - -namespace gcs { - -class RedisGcsClient; - -/// \class RedisLogBasedActorInfoAccessor -/// `RedisLogBasedActorInfoAccessor` is an implementation of `ActorInfoAccessor` -/// that uses Redis as the backend storage. -class RedisLogBasedActorInfoAccessor : public ActorInfoAccessor { - public: - explicit RedisLogBasedActorInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisLogBasedActorInfoAccessor() {} - - Status GetAll(std::vector *actor_table_data_list) override; - - Status AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) override; - - Status AsyncGetAll(const MultiItemCallback &callback) override { - return Status::NotImplemented( - "RedisLogBasedActorInfoAccessor does not support AsyncGetAll."); - } - - Status AsyncGetByName(const std::string &name, - const OptionalItemCallback &callback) override { - return Status::NotImplemented( - "RedisLogBasedActorInfoAccessor does not support named detached actors."); - } - - Status AsyncRegisterActor(const TaskSpecification &task_spec, - const StatusCallback &callback) override; - - Status AsyncCreateActor(const TaskSpecification &task_spec, - const StatusCallback &callback) override; - - Status AsyncSubscribeAll(const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncSubscribe(const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncUnsubscribe(const ActorID &actor_id) override; - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - bool IsActorUnsubscribed(const ActorID &actor_id) override { return false; } - - protected: - virtual std::vector GetAllActorID() const; - virtual Status Get(const ActorID &actor_id, ActorTableData *actor_table_data) const; - - RedisGcsClient *client_impl_{nullptr}; - // Use a random NodeID for actor subscription. Because: - // If we use NodeID::Nil, GCS will still send all actors' updates to this GCS Client. - // Even we can filter out irrelevant updates, but there will be extra overhead. - // And because the new GCS Client will no longer hold the local NodeID, so we use - // random NodeID instead. - // TODO(micafan): Remove this random id, once GCS becomes a service. - NodeID subscribe_id_{NodeID::FromRandom()}; - - private: - typedef SubscriptionExecutor - ActorSubscriptionExecutor; - ActorSubscriptionExecutor log_based_actor_sub_executor_; -}; - -/// \class RedisActorInfoAccessor -/// `RedisActorInfoAccessor` is an implementation of `ActorInfoAccessor` -/// that uses Redis as the backend storage. -class RedisActorInfoAccessor : public RedisLogBasedActorInfoAccessor { - public: - explicit RedisActorInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisActorInfoAccessor() {} - - Status AsyncGet(const ActorID &actor_id, - const OptionalItemCallback &callback) override; - - Status AsyncGetAll(const MultiItemCallback &callback) override; - - Status AsyncGetByName(const std::string &name, - const OptionalItemCallback &callback) override { - return Status::NotImplemented( - "RedisActorInfoAccessor does not support named detached actors."); - } - - Status AsyncSubscribeAll(const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncSubscribe(const ActorID &actor_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncUnsubscribe(const ActorID &actor_id) override; - - protected: - std::vector GetAllActorID() const override; - Status Get(const ActorID &actor_id, ActorTableData *actor_table_data) const override; - - private: - typedef SubscriptionExecutor - ActorSubscriptionExecutor; - ActorSubscriptionExecutor actor_sub_executor_; -}; - -/// \class RedisJobInfoAccessor -/// RedisJobInfoAccessor is an implementation of `JobInfoAccessor` -/// that uses Redis as the backend storage. -class RedisJobInfoAccessor : public JobInfoAccessor { - public: - explicit RedisJobInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisJobInfoAccessor() {} - - Status AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; - - Status AsyncSubscribeAll(const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncGetAll(const MultiItemCallback &callback) override { - return Status::NotImplemented("AsyncGetAll not implemented"); - } - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - private: - /// Append job information to GCS asynchronously. - /// - /// \param data_ptr The job information that will be appended to GCS. - /// \param callback Callback that will be called after append done. - /// \return Status - Status DoAsyncAppend(const std::shared_ptr &data_ptr, - const StatusCallback &callback); - - RedisGcsClient *client_impl_{nullptr}; - - typedef SubscriptionExecutor JobSubscriptionExecutor; - JobSubscriptionExecutor job_sub_executor_; -}; - -/// \class RedisTaskInfoAccessor -/// `RedisTaskInfoAccessor` is an implementation of `TaskInfoAccessor` -/// that uses Redis as the backend storage. -class RedisTaskInfoAccessor : public TaskInfoAccessor { - public: - explicit RedisTaskInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisTaskInfoAccessor() {} - - Status AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncGet(const TaskID &task_id, - const OptionalItemCallback &callback) override; - - Status AsyncSubscribe(const TaskID &task_id, - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncUnsubscribe(const TaskID &task_id) override; - - Status AsyncAddTaskLease(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncGetTaskLease(const TaskID &task_id, - const OptionalItemCallback &callback) override; - - Status AsyncSubscribeTaskLease( - const TaskID &task_id, - const SubscribeCallback> &subscribe, - const StatusCallback &done) override; - - Status AsyncUnsubscribeTaskLease(const TaskID &task_id) override; - - Status AttemptTaskReconstruction( - const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - bool IsTaskUnsubscribed(const TaskID &task_id) override { return false; } - - bool IsTaskLeaseUnsubscribed(const TaskID &task_id) override { return false; } - - private: - RedisGcsClient *client_impl_{nullptr}; - // Use a random NodeID for task subscription. Because: - // If we use NodeID::Nil, GCS will still send all tasks' updates to this GCS Client. - // Even we can filter out irrelevant updates, but there will be extra overhead. - // And because the new GCS Client will no longer hold the local NodeID, so we use - // random NodeID instead. - // TODO(micafan): Remove this random id, once GCS becomes a service. - NodeID subscribe_id_{NodeID::FromRandom()}; - - typedef SubscriptionExecutor - TaskSubscriptionExecutor; - TaskSubscriptionExecutor task_sub_executor_; - - typedef SubscriptionExecutor, TaskLeaseTable> - TaskLeaseSubscriptionExecutor; - TaskLeaseSubscriptionExecutor task_lease_sub_executor_; -}; - -/// \class RedisObjectInfoAccessor -/// RedisObjectInfoAccessor is an implementation of `ObjectInfoAccessor` -/// that uses Redis as the backend storage. -class RedisObjectInfoAccessor : public ObjectInfoAccessor { - public: - explicit RedisObjectInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisObjectInfoAccessor() {} - - Status AsyncGetLocations( - const ObjectID &object_id, - const OptionalItemCallback &callback) override; - - Status AsyncGetAll( - const MultiItemCallback &callback) override { - return Status::NotImplemented("AsyncGetAll not implemented"); - } - - Status AsyncAddLocation(const ObjectID &object_id, const NodeID &node_id, - const StatusCallback &callback) override; - - Status AsyncAddSpilledUrl(const ObjectID &object_id, const std::string &spilled_url, - const StatusCallback &callback) override { - return Status::NotImplemented("AsyncAddSpilledUrl not implemented"); - } - - Status AsyncRemoveLocation(const ObjectID &object_id, const NodeID &node_id, - const StatusCallback &callback) override; - - Status AsyncSubscribeToLocations( - const ObjectID &object_id, - const SubscribeCallback> - &subscribe, - const StatusCallback &done) override; - - Status AsyncUnsubscribeToLocations(const ObjectID &object_id) override; - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - bool IsObjectUnsubscribed(const ObjectID &object_id) override { return false; } - - private: - RedisGcsClient *client_impl_{nullptr}; - - // Use a random NodeID for object subscription. Because: - // If we use NodeID::Nil, GCS will still send all objects' updates to this GCS Client. - // Even we can filter out irrelevant updates, but there will be extra overhead. - // And because the new GCS Client will no longer hold the local NodeID, so we use - // random NodeID instead. - // TODO(micafan): Remove this random id, once GCS becomes a service. - NodeID subscribe_id_{NodeID::FromRandom()}; - - typedef SubscriptionExecutor - ObjectSubscriptionExecutor; - ObjectSubscriptionExecutor object_sub_executor_; -}; - -/// \class RedisNodeInfoAccessor -/// RedisNodeInfoAccessor is an implementation of `NodeInfoAccessor` -/// that uses Redis as the backend storage. -class RedisNodeInfoAccessor : public NodeInfoAccessor { - public: - explicit RedisNodeInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisNodeInfoAccessor() {} - - Status RegisterSelf(const GcsNodeInfo &local_node_info, - const StatusCallback &callback) override; - - Status UnregisterSelf() override; - - const NodeID &GetSelfId() const override; - - const GcsNodeInfo &GetSelfInfo() const override; - - Status AsyncRegister(const GcsNodeInfo &node_info, - const StatusCallback &callback) override; - - Status AsyncUnregister(const NodeID &node_id, const StatusCallback &callback) override; - - Status AsyncGetAll(const MultiItemCallback &callback) override; - - Status AsyncSubscribeToNodeChange( - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - boost::optional Get(const NodeID &node_id, - bool filter_dead_nodes = true) const override; - - const std::unordered_map &GetAll() const override; - - bool IsRemoved(const NodeID &node_id) const override; - - Status AsyncReportHeartbeat(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncReportResourceUsage(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - void AsyncReReportResourceUsage() override; - - Status AsyncGetAllResourceUsage( - const ItemCallback &callback) override { - return Status::NotImplemented("AsyncGetAllResourceUsage not implemented"); - } - - Status AsyncSubscribeBatchedResourceUsage( - const ItemCallback &subscribe, - const StatusCallback &done) override; - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - Status AsyncSetInternalConfig( - std::unordered_map &config) override { - return Status::NotImplemented("SetInternaConfig not implemented."); - } - - Status AsyncGetInternalConfig( - const OptionalItemCallback> &callback) - override { - return Status::NotImplemented("GetInternalConfig not implemented."); - } - - private: - RedisGcsClient *client_impl_{nullptr}; - - typedef SubscriptionExecutor - HeartbeatBatchSubscriptionExecutor; - HeartbeatBatchSubscriptionExecutor resource_usage_batch_sub_executor_; -}; - -/// \class RedisNodeResourceInfoAccessor -/// RedisNodeResourceInfoAccessor is an implementation of `NodeResourceInfoAccessor` -/// that uses Redis as the backend storage. -class RedisNodeResourceInfoAccessor : public NodeResourceInfoAccessor { - public: - explicit RedisNodeResourceInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisNodeResourceInfoAccessor() {} - - Status AsyncGetResources(const NodeID &node_id, - const OptionalItemCallback &callback) override; - - Status AsyncGetAllAvailableResources( - const MultiItemCallback &callback) override { - return Status::NotImplemented("AsyncGetAllAvailableResources not implemented"); - } - - Status AsyncUpdateResources(const NodeID &node_id, const ResourceMap &resources, - const StatusCallback &callback) override; - - Status AsyncDeleteResources(const NodeID &node_id, - const std::vector &resource_names, - const StatusCallback &callback) override; - - Status AsyncSubscribeToResources(const ItemCallback &subscribe, - const StatusCallback &done) override; - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - private: - RedisGcsClient *client_impl_{nullptr}; - - typedef SubscriptionExecutor - DynamicResourceSubscriptionExecutor; - DynamicResourceSubscriptionExecutor resource_sub_executor_; -}; - -/// \class RedisErrorInfoAccessor -/// RedisErrorInfoAccessor is an implementation of `ErrorInfoAccessor` -/// that uses Redis as the backend storage. -class RedisErrorInfoAccessor : public ErrorInfoAccessor { - public: - explicit RedisErrorInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisErrorInfoAccessor() = default; - - Status AsyncReportJobError(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; -}; - -/// \class RedisStatsInfoAccessor -/// RedisStatsInfoAccessor is an implementation of `StatsInfoAccessor` -/// that uses Redis as the backend storage. -class RedisStatsInfoAccessor : public StatsInfoAccessor { - public: - explicit RedisStatsInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisStatsInfoAccessor() = default; - - Status AsyncAddProfileData(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncGetAll(const MultiItemCallback &callback) override { - return Status::NotImplemented("AsyncGetAll not implemented"); - } - - private: - RedisGcsClient *client_impl_{nullptr}; -}; - -/// \class RedisWorkerInfoAccessor -/// RedisWorkerInfoAccessor is an implementation of `WorkerInfoAccessor` -/// that uses Redis as the backend storage. -class RedisWorkerInfoAccessor : public WorkerInfoAccessor { - public: - explicit RedisWorkerInfoAccessor(RedisGcsClient *client_impl); - - virtual ~RedisWorkerInfoAccessor() = default; - - Status AsyncSubscribeToWorkerFailures( - const SubscribeCallback &subscribe, - const StatusCallback &done) override; - - Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - Status AsyncGet(const WorkerID &worker_id, - const OptionalItemCallback &callback) override; - - Status AsyncGetAll(const MultiItemCallback &callback) override; - - Status AsyncAdd(const std::shared_ptr &data_ptr, - const StatusCallback &callback) override; - - void AsyncResubscribe(bool is_pubsub_server_restarted) override {} - - private: - RedisGcsClient *client_impl_{nullptr}; - - typedef SubscriptionExecutor - WorkerFailureSubscriptionExecutor; - WorkerFailureSubscriptionExecutor worker_failure_sub_executor_; -}; - -class RedisPlacementGroupInfoAccessor : public PlacementGroupInfoAccessor { - public: - virtual ~RedisPlacementGroupInfoAccessor() = default; - - Status AsyncCreatePlacementGroup( - const PlacementGroupSpecification &placement_group_spec) override; - - Status AsyncRemovePlacementGroup(const PlacementGroupID &placement_group_id, - const StatusCallback &callback) override; - - Status AsyncGet( - const PlacementGroupID &placement_group_id, - const OptionalItemCallback &callback) override; - - Status AsyncGetAll( - const MultiItemCallback &callback) override; - - Status AsyncWaitUntilReady(const PlacementGroupID &placement_group_id, - const StatusCallback &callback) override; -}; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/redis_gcs_client.cc b/src/ray/gcs/redis_gcs_client.cc deleted file mode 100644 index 1b2359346..000000000 --- a/src/ray/gcs/redis_gcs_client.cc +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/gcs/redis_gcs_client.h" - -#include "ray/common/ray_config.h" -#include "ray/gcs/redis_accessor.h" -#include "ray/gcs/redis_context.h" - -namespace ray { - -namespace gcs { - -RedisGcsClient::RedisGcsClient(const GcsClientOptions &options) - : RedisGcsClient(options, CommandType::kRegular) {} - -RedisGcsClient::RedisGcsClient(const GcsClientOptions &options, CommandType command_type) - : GcsClient(options), command_type_(command_type) { - RedisClientOptions redis_client_options(options.server_ip_, options.server_port_, - options.password_, options.is_test_client_); - redis_client_.reset(new RedisClient(redis_client_options)); -} - -Status RedisGcsClient::Connect(boost::asio::io_service &io_service) { - RAY_CHECK(!is_connected_); - - Status status = redis_client_->Connect(io_service); - if (!status.ok()) { - RAY_LOG(INFO) << "RedisGcsClient::Connect failed, status " << status.ToString(); - return status; - } - - std::shared_ptr primary_context = redis_client_->GetPrimaryContext(); - std::vector> shard_contexts = - redis_client_->GetShardContexts(); - - log_based_actor_table_.reset(new LogBasedActorTable({primary_context}, this)); - actor_table_.reset(new ActorTable({primary_context}, this)); - - // TODO(micafan) Modify NodeTable' Constructor(remove NodeID) in future. - // We will use NodeID instead of NodeID. - // For worker/driver, it might not have this field(NodeID). - // For raylet, NodeID should be initialized in raylet layer(not here). - node_table_.reset(new NodeTable({primary_context}, this)); - - job_table_.reset(new JobTable({primary_context}, this)); - resource_usage_batch_table_.reset(new ResourceUsageBatchTable({primary_context}, this)); - // Tables below would be sharded. - object_table_.reset(new ObjectTable(shard_contexts, this)); - raylet_task_table_.reset(new raylet::TaskTable(shard_contexts, this, command_type_)); - task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts, this)); - task_lease_table_.reset(new TaskLeaseTable(shard_contexts, this)); - heartbeat_table_.reset(new HeartbeatTable(shard_contexts, this)); - profile_table_.reset(new ProfileTable(shard_contexts, this)); - resource_table_.reset(new DynamicResourceTable({primary_context}, this)); - worker_table_.reset(new WorkerTable(shard_contexts, this)); - - actor_accessor_.reset(new RedisActorInfoAccessor(this)); - - job_accessor_.reset(new RedisJobInfoAccessor(this)); - object_accessor_.reset(new RedisObjectInfoAccessor(this)); - node_accessor_.reset(new RedisNodeInfoAccessor(this)); - node_resource_accessor_.reset(new RedisNodeResourceInfoAccessor(this)); - task_accessor_.reset(new RedisTaskInfoAccessor(this)); - error_accessor_.reset(new RedisErrorInfoAccessor(this)); - stats_accessor_.reset(new RedisStatsInfoAccessor(this)); - worker_accessor_.reset(new RedisWorkerInfoAccessor(this)); - placement_group_accessor_.reset(new RedisPlacementGroupInfoAccessor()); - - is_connected_ = true; - - RAY_LOG(DEBUG) << "RedisGcsClient connected."; - - return Status::OK(); -} - -void RedisGcsClient::Disconnect() { - RAY_CHECK(is_connected_); - is_connected_ = false; - redis_client_->Disconnect(); - RAY_LOG(DEBUG) << "RedisGcsClient Disconnected."; -} - -std::string RedisGcsClient::DebugString() const { - std::stringstream result; - result << "RedisGcsClient:"; - result << "\n- TaskTable: " << raylet_task_table_->DebugString(); - result << "\n- LogBasedActorTable: " << log_based_actor_table_->DebugString(); - result << "\n- ActorTable: " << actor_table_->DebugString(); - result << "\n- TaskReconstructionLog: " << task_reconstruction_log_->DebugString(); - result << "\n- TaskLeaseTable: " << task_lease_table_->DebugString(); - result << "\n- HeartbeatTable: " << heartbeat_table_->DebugString(); - result << "\n- ProfileTable: " << profile_table_->DebugString(); - result << "\n- NodeTable: " << node_table_->DebugString(); - result << "\n- JobTable: " << job_table_->DebugString(); - return result.str(); -} - -ObjectTable &RedisGcsClient::object_table() { return *object_table_; } - -raylet::TaskTable &RedisGcsClient::raylet_task_table() { return *raylet_task_table_; } - -LogBasedActorTable &RedisGcsClient::log_based_actor_table() { - return *log_based_actor_table_; -} - -ActorTable &RedisGcsClient::actor_table() { return *actor_table_; } - -WorkerTable &RedisGcsClient::worker_table() { return *worker_table_; } - -TaskReconstructionLog &RedisGcsClient::task_reconstruction_log() { - return *task_reconstruction_log_; -} - -TaskLeaseTable &RedisGcsClient::task_lease_table() { return *task_lease_table_; } - -NodeTable &RedisGcsClient::node_table() { return *node_table_; } - -HeartbeatTable &RedisGcsClient::heartbeat_table() { return *heartbeat_table_; } - -ResourceUsageBatchTable &RedisGcsClient::resource_usage_batch_table() { - return *resource_usage_batch_table_; -} - -JobTable &RedisGcsClient::job_table() { return *job_table_; } - -ProfileTable &RedisGcsClient::profile_table() { return *profile_table_; } - -DynamicResourceTable &RedisGcsClient::resource_table() { return *resource_table_; } - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/redis_gcs_client.h b/src/ray/gcs/redis_gcs_client.h deleted file mode 100644 index 748b1da72..000000000 --- a/src/ray/gcs/redis_gcs_client.h +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "ray/common/id.h" -#include "ray/common/status.h" -#include "ray/gcs/asio.h" -#include "ray/gcs/gcs_client.h" -#include "ray/gcs/redis_client.h" -#include "ray/gcs/tables.h" -#include "ray/util/logging.h" - -namespace ray { - -namespace gcs { - -class RedisContext; - -class RAY_EXPORT RedisGcsClient : public GcsClient { - public: - /// Constructor of RedisGcsClient. - /// Connect() must be called(and return ok) before you call any other methods. - /// TODO(micafan) To read and write from the GCS tables requires a further - /// call to Connect() to the client table. Will fix this in next pr. - /// - /// \param options Options of this client, e.g. server address, password and so on. - RedisGcsClient(const GcsClientOptions &options); - - /// This constructor is only used for testing. - /// Connect() must be called(and return ok) before you call any other methods. - /// - /// \param options Options of this client, e.g. server address, password and so on. - /// \param command_type The commands issued type. - RedisGcsClient(const GcsClientOptions &options, CommandType command_type); - - /// Connect to GCS Service. Non-thread safe. - /// Call this function before calling other functions. - /// - /// \param io_service The event loop for this client. - /// Must be single-threaded io_service (get more information from RedisAsioClient). - /// - /// \return Status - Status Connect(boost::asio::io_service &io_service) override; - - /// Disconnect with GCS Service. Non-thread safe. - void Disconnect() override; - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const override; - - // We also need something to export generic code to run on workers from the - // driver (to set the PYTHONPATH) - using GetExportCallback = std::function; - Status AddExport(const std::string &job_id, std::string &export_data); - Status GetExport(const std::string &job_id, int64_t export_index, - const GetExportCallback &done_callback); - - std::vector> shard_contexts() { - return redis_client_->GetShardContexts(); - } - - std::shared_ptr primary_context() { - return redis_client_->GetPrimaryContext(); - } - - std::shared_ptr GetRedisClient() const { return redis_client_; } - - /// The following xxx_table methods implement the Accessor interfaces. - /// Implements the Actors() interface. - LogBasedActorTable &log_based_actor_table(); - ActorTable &actor_table(); - /// Implements the Jobs() interface. - JobTable &job_table(); - /// Implements the Objects() interface. - ObjectTable &object_table(); - /// Implements the Nodes() interface. - NodeTable &node_table(); - HeartbeatTable &heartbeat_table(); - ResourceUsageBatchTable &resource_usage_batch_table(); - DynamicResourceTable &resource_table(); - /// Implements the Tasks() interface. - virtual raylet::TaskTable &raylet_task_table(); - TaskLeaseTable &task_lease_table(); - TaskReconstructionLog &task_reconstruction_log(); - /// Implements the Stats() interface. - ProfileTable &profile_table(); - /// Implements the Workers() interface. - WorkerTable &worker_table(); - - private: - // GCS command type. If CommandType::kChain, chain-replicated versions of the tables - // might be used, if available. - CommandType command_type_{CommandType::kUnknown}; - - std::shared_ptr redis_client_; - - std::unique_ptr object_table_; - std::unique_ptr raylet_task_table_; - std::unique_ptr log_based_actor_table_; - std::unique_ptr actor_table_; - std::unique_ptr task_reconstruction_log_; - std::unique_ptr task_lease_table_; - std::unique_ptr heartbeat_table_; - std::unique_ptr resource_usage_batch_table_; - std::unique_ptr profile_table_; - std::unique_ptr node_table_; - std::unique_ptr resource_table_; - std::unique_ptr worker_table_; - std::unique_ptr job_table_; -}; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/subscription_executor.cc b/src/ray/gcs/subscription_executor.cc deleted file mode 100644 index d9617985a..000000000 --- a/src/ray/gcs/subscription_executor.cc +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/gcs/subscription_executor.h" - -namespace ray { - -namespace gcs { - -template -Status SubscriptionExecutor::AsyncSubscribeAll( - const NodeID &node_id, const SubscribeCallback &subscribe, - const StatusCallback &done) { - // TODO(micafan) Optimize the lock when necessary. - // Consider avoiding locking in single-threaded processes. - std::unique_lock lock(mutex_); - - if (subscribe_all_callback_ != nullptr) { - RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to all elements."; - return Status::Invalid("Duplicate subscription!"); - } - - if (registration_status_ != RegistrationStatus::kNotRegistered) { - if (subscribe != nullptr) { - RAY_LOG(DEBUG) << "Duplicate subscription! Already subscribed to specific elements" - ", can't subscribe to all elements."; - return Status::Invalid("Duplicate subscription!"); - } - } - - if (registration_status_ == RegistrationStatus::kRegistered) { - // Already registered to GCS, just invoke the `done` callback. - lock.unlock(); - if (done != nullptr) { - done(Status::OK()); - } - return Status::OK(); - } - - // Registration to GCS is not finished yet, add the `done` callback to the pending list - // to be invoked when registration is done. - if (done != nullptr) { - pending_subscriptions_.emplace_back(done); - } - - // If there's another registration request that's already on-going, then wait for it - // to finish. - if (registration_status_ == RegistrationStatus::kRegistering) { - return Status::OK(); - } - - auto on_subscribe = [this](RedisGcsClient *client, const ID &id, - const std::vector &result) { - if (result.empty()) { - return; - } - - SubscribeCallback sub_one_callback = nullptr; - SubscribeCallback sub_all_callback = nullptr; - { - std::unique_lock lock(mutex_); - const auto it = id_to_callback_map_.find(id); - if (it != id_to_callback_map_.end()) { - sub_one_callback = it->second; - } - sub_all_callback = subscribe_all_callback_; - } - if (sub_one_callback != nullptr) { - sub_one_callback(id, result.back()); - } - if (sub_all_callback != nullptr) { - RAY_CHECK(sub_one_callback == nullptr); - sub_all_callback(id, result.back()); - } - }; - - auto on_done = [this](RedisGcsClient *client) { - std::list pending_callbacks; - { - std::unique_lock lock(mutex_); - registration_status_ = RegistrationStatus::kRegistered; - pending_callbacks.swap(pending_subscriptions_); - RAY_CHECK(pending_subscriptions_.empty()); - } - - for (const auto &callback : pending_callbacks) { - callback(Status::OK()); - } - }; - - Status status = table_.Subscribe(JobID::Nil(), node_id, on_subscribe, on_done); - if (status.ok()) { - registration_status_ = RegistrationStatus::kRegistering; - subscribe_all_callback_ = subscribe; - } - - return status; -} - -template -Status SubscriptionExecutor::AsyncSubscribe( - const NodeID &node_id, const ID &id, const SubscribeCallback &subscribe, - const StatusCallback &done) { - RAY_CHECK(node_id != NodeID::Nil()); - - // NOTE(zhijunfu): `Subscribe` and other operations use different redis contexts, - // thus we need to call `RequestNotifications` in the Subscribe callback to ensure - // it's processed after the `Subscribe` request. Otherwise if `RequestNotifications` - // is processed first we will miss the initial notification. - auto on_subscribe_done = [this, node_id, id, subscribe, done](Status status) { - auto on_request_notification_done = [this, done, id](Status status) { - if (!status.ok()) { - std::unique_lock lock(mutex_); - id_to_callback_map_.erase(id); - } - if (done != nullptr) { - done(status); - } - }; - - { - std::unique_lock lock(mutex_); - status = table_.RequestNotifications(JobID::Nil(), id, node_id, - on_request_notification_done); - if (!status.ok()) { - id_to_callback_map_.erase(id); - } - } - }; - - { - std::unique_lock lock(mutex_); - const auto it = id_to_callback_map_.find(id); - if (it != id_to_callback_map_.end()) { - RAY_LOG(DEBUG) << "Duplicate subscription to id " << id << " node_id " << node_id; - return Status::Invalid("Duplicate subscription to element!"); - } - id_to_callback_map_[id] = subscribe; - } - - auto status = AsyncSubscribeAll(node_id, nullptr, on_subscribe_done); - if (!status.ok()) { - std::unique_lock lock(mutex_); - id_to_callback_map_.erase(id); - } - return status; -} - -template -Status SubscriptionExecutor::AsyncUnsubscribe( - const NodeID &node_id, const ID &id, const StatusCallback &done) { - SubscribeCallback subscribe = nullptr; - { - std::unique_lock lock(mutex_); - const auto it = id_to_callback_map_.find(id); - if (it == id_to_callback_map_.end()) { - RAY_LOG(DEBUG) << "Invalid Unsubscribe! id " << id << " node_id " << node_id; - return Status::Invalid("Invalid Unsubscribe, no existing subscription found."); - } - subscribe = std::move(it->second); - id_to_callback_map_.erase(it); - } - - RAY_CHECK(subscribe != nullptr); - auto on_done = [this, id, subscribe, done](Status status) { - if (!status.ok()) { - std::unique_lock lock(mutex_); - const auto it = id_to_callback_map_.find(id); - if (it != id_to_callback_map_.end()) { - // The initial AsyncUnsubscribe deleted the callback, but the client - // has subscribed again in the meantime. This new callback will be - // called if we receive more notifications. - RAY_LOG(WARNING) - << "Client called AsyncSubscribe on " << id - << " while AsyncUnsubscribe was pending, but the unsubscribe failed."; - } else { - // The Unsubscribe failed, so restore the initial callback. - id_to_callback_map_[id] = subscribe; - } - } - if (done != nullptr) { - done(status); - } - }; - - return table_.CancelNotifications(JobID::Nil(), id, node_id, on_done); -} - -template class SubscriptionExecutor; -template class SubscriptionExecutor; -template class SubscriptionExecutor; -template class SubscriptionExecutor; -template class SubscriptionExecutor; -template class SubscriptionExecutor, - TaskLeaseTable>; -template class SubscriptionExecutor; -template class SubscriptionExecutor; -template class SubscriptionExecutor; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/subscription_executor.h b/src/ray/gcs/subscription_executor.h deleted file mode 100644 index 48a912f3e..000000000 --- a/src/ray/gcs/subscription_executor.h +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "ray/gcs/callback.h" -#include "ray/gcs/tables.h" - -namespace ray { - -namespace gcs { - -/// \class SubscriptionExecutor -/// SubscriptionExecutor class encapsulates the implementation details of -/// subscribe/unsubscribe to elements (e.g.: actors or tasks or objects or nodes). -/// Support subscribing to a specific element or subscribing to all elements. -template -class SubscriptionExecutor { - public: - explicit SubscriptionExecutor(Table &table) : table_(table) {} - - ~SubscriptionExecutor() {} - - /// Subscribe to operations of all elements. - /// Repeated subscription will return a failure. - /// - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each update will be received. Else, only - /// messages for the given node will be received. - /// \param subscribe Callback that will be called each time when an element - /// is registered or updated. - /// \param done Callback that will be called when subscription is complete. - /// \return Status - Status AsyncSubscribeAll(const NodeID &node_id, - const SubscribeCallback &subscribe, - const StatusCallback &done); - - /// Subscribe to operations of an element. - /// Repeated subscription to an element will return a failure. - /// - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each update will be received. Else, only - /// messages for the given node will be received. - /// \param id The id of the element to be subscribe to. - /// \param subscribe Callback that will be called each time when the element - /// is registered or updated. - /// \param done Callback that will be called when subscription is complete. - /// \return Status - Status AsyncSubscribe(const NodeID &node_id, const ID &id, - const SubscribeCallback &subscribe, - const StatusCallback &done); - - /// Cancel subscription to an element. - /// Unsubscribing can only be called after the subscription request is completed. - /// - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each update will be received. Else, only - /// messages for the given node will be received. - /// \param id The id of the element to be unsubscribed to. - /// \param done Callback that will be called when cancel subscription is complete. - /// \return Status - Status AsyncUnsubscribe(const NodeID &node_id, const ID &id, - const StatusCallback &done); - - private: - Table &table_; - - std::mutex mutex_; - - enum class RegistrationStatus : uint8_t { - kNotRegistered, - kRegistering, - kRegistered, - }; - - /// Whether successfully registered subscription to GCS. - RegistrationStatus registration_status_{RegistrationStatus::kNotRegistered}; - - /// List of subscriptions before registration to GCS is done, these callbacks - /// will be called when the registration to GCS finishes. - std::list pending_subscriptions_; - - /// Subscribe Callback of all elements. - SubscribeCallback subscribe_all_callback_{nullptr}; - - /// A mapping from element ID to subscription callback. - typedef std::unordered_map> IDToCallbackMap; - IDToCallbackMap id_to_callback_map_; -}; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc deleted file mode 100644 index 2017d05de..000000000 --- a/src/ray/gcs/tables.cc +++ /dev/null @@ -1,847 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/gcs/tables.h" - -#include "absl/time/clock.h" -#include "ray/common/common_protocol.h" -#include "ray/common/grpc_util.h" -#include "ray/common/ray_config.h" -#include "ray/gcs/redis_gcs_client.h" - -extern "C" { -#include "hiredis/hiredis.h" -} - -namespace { - -static const std::string kTableAppendCommand = "RAY.TABLE_APPEND"; -static const std::string kChainTableAppendCommand = "RAY.CHAIN.TABLE_APPEND"; - -static const std::string kTableAddCommand = "RAY.TABLE_ADD"; -static const std::string kChainTableAddCommand = "RAY.CHAIN.TABLE_ADD"; - -std::string GetLogAppendCommand(const ray::gcs::CommandType command_type) { - if (command_type == ray::gcs::CommandType::kRegular) { - return kTableAppendCommand; - } else { - RAY_CHECK(command_type == ray::gcs::CommandType::kChain); - return kChainTableAppendCommand; - } -} - -std::string GetTableAddCommand(const ray::gcs::CommandType command_type) { - if (command_type == ray::gcs::CommandType::kRegular) { - return kTableAddCommand; - } else { - RAY_CHECK(command_type == ray::gcs::CommandType::kChain); - return kChainTableAddCommand; - } -} - -} // namespace - -namespace ray { - -namespace gcs { - -template -Status Log::Append(const JobID &job_id, const ID &id, - const std::shared_ptr &data, - const WriteCallback &done) { - num_appends_++; - auto callback = [this, id, data, done](std::shared_ptr reply) { - const auto status = reply->ReadAsStatus(); - // Failed to append the entry. - RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" - << status.ToString(); - if (done != nullptr) { - (done)(client_, id, *data); - } - }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_, - std::move(callback)); -} - -template -Status Log::SyncAppend(const JobID &job_id, const ID &id, - const std::shared_ptr &data) { - num_appends_++; - std::string str = data->SerializeAsString(); - auto reply = - GetRedisContext(id)->RunSync(GetLogAppendCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_); - Status status = reply ? reply->ReadAsStatus() : Status::RedisError("Redis error"); - return status; -} - -template -Status Log::AppendAt(const JobID &job_id, const ID &id, - const std::shared_ptr &data, - const WriteCallback &done, const WriteCallback &failure, - int log_length) { - num_appends_++; - auto callback = [this, id, data, done, failure](std::shared_ptr reply) { - const auto status = reply->ReadAsStatus(); - if (status.ok()) { - if (done != nullptr) { - (done)(client_, id, *data); - } - } else { - if (failure != nullptr) { - (failure)(client_, id, *data); - } - } - }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_, - std::move(callback), log_length); -} - -template -Status Log::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) { - num_lookups_++; - auto callback = [this, id, lookup](std::shared_ptr reply) { - if (lookup != nullptr) { - std::vector results; - if (!reply->IsNil()) { - GcsEntry gcs_entry; - gcs_entry.ParseFromString(reply->ReadAsString()); - RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); - for (int64_t i = 0; i < gcs_entry.entries_size(); i++) { - Data data; - data.ParseFromString(gcs_entry.entries(i)); - results.emplace_back(std::move(data)); - } - } - lookup(client_, id, results); - } - }; - std::vector nil; - return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), - prefix_, pubsub_channel_, std::move(callback)); -} - -template -Status Log::Subscribe(const JobID &job_id, const NodeID &node_id, - const Callback &subscribe, - const SubscriptionCallback &done) { - auto subscribe_wrapper = [subscribe](RedisGcsClient *client, const ID &id, - const GcsChangeMode change_mode, - const std::vector &data) { - RAY_CHECK(change_mode != GcsChangeMode::REMOVE); - subscribe(client, id, data); - }; - return Subscribe(job_id, node_id, subscribe_wrapper, done); -} - -template -Status Log::Subscribe(const JobID &job_id, const NodeID &node_id, - const NotificationCallback &subscribe, - const SubscriptionCallback &done) { - RAY_CHECK(subscribe_callback_index_ == -1) - << "Client called Subscribe twice on the same table"; - auto callback = [this, subscribe, done](std::shared_ptr reply) { - const auto data = reply->ReadAsPubsubData(); - - if (data.empty()) { - // No notification data is provided. This is the callback for the - // initial subscription request. - if (done != nullptr) { - done(client_); - } - } else { - // Data is provided. This is the callback for a message. - if (subscribe != nullptr) { - // Parse the notification. - GcsEntry gcs_entry; - gcs_entry.ParseFromString(data); - ID id = ID::FromBinary(gcs_entry.id()); - std::vector results; - for (int64_t i = 0; i < gcs_entry.entries_size(); i++) { - Data result; - result.ParseFromString(gcs_entry.entries(i)); - results.emplace_back(std::move(result)); - } - subscribe(client_, id, gcs_entry.change_mode(), results); - } - } - }; - - subscribe_callback_index_ = 1; - for (auto &context : shard_contexts_) { - RAY_RETURN_NOT_OK(context->SubscribeAsync(node_id, pubsub_channel_, callback, - &subscribe_callback_index_)); - } - return Status::OK(); -} - -template -Status Log::RequestNotifications(const JobID &job_id, const ID &id, - const NodeID &node_id, - const StatusCallback &done) { - RAY_CHECK(subscribe_callback_index_ >= 0) - << "Client requested notifications on a key before Subscribe completed"; - - RedisCallback callback = nullptr; - if (done != nullptr) { - callback = [done](std::shared_ptr reply) { - const auto status = reply->IsNil() - ? Status::OK() - : Status::RedisError("request notifications failed."); - done(status); - }; - } - - return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, - node_id.Data(), node_id.Size(), prefix_, - pubsub_channel_, callback); -} - -template -Status Log::CancelNotifications(const JobID &job_id, const ID &id, - const NodeID &node_id, - const StatusCallback &done) { - RAY_CHECK(subscribe_callback_index_ >= 0) - << "Client canceled notifications on a key before Subscribe completed"; - - RedisCallback callback = nullptr; - if (done != nullptr) { - callback = [done](std::shared_ptr reply) { - const auto status = reply->ReadAsStatus(); - done(status); - }; - } - - return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, - node_id.Data(), node_id.Size(), prefix_, - pubsub_channel_, callback); -} - -template -void Log::Delete(const JobID &job_id, const std::vector &ids) { - if (ids.empty()) { - return; - } - std::unordered_map sharded_data; - for (const auto &id : ids) { - sharded_data[GetRedisContext(id).get()] << id.Binary(); - } - // Breaking really large deletion commands into batches of smaller size. - const size_t batch_size = - RayConfig::instance().maximum_gcs_deletion_batch_size() * ID::Size(); - for (const auto &pair : sharded_data) { - std::string current_data = pair.second.str(); - for (size_t cur = 0; cur < pair.second.str().size(); cur += batch_size) { - size_t data_field_size = std::min(batch_size, current_data.size() - cur); - uint16_t id_count = data_field_size / ID::Size(); - // Send data contains id count and all the id data. - std::string send_data(data_field_size + sizeof(id_count), 0); - uint8_t *buffer = reinterpret_cast(&send_data[0]); - *reinterpret_cast(buffer) = id_count; - RAY_IGNORE_EXPR( - std::copy_n(reinterpret_cast(current_data.c_str() + cur), - data_field_size, buffer + sizeof(uint16_t))); - - RAY_IGNORE_EXPR( - pair.first->RunAsync("RAY.TABLE_DELETE", UniqueID::Nil(), - reinterpret_cast(send_data.c_str()), - send_data.size(), prefix_, pubsub_channel_, - /*redisCallback=*/nullptr)); - } - } -} - -template -void Log::Delete(const JobID &job_id, const ID &id) { - Delete(job_id, std::vector({id})); -} - -template -std::string Log::DebugString() const { - std::stringstream result; - result << "num lookups: " << num_lookups_ << ", num appends: " << num_appends_; - return result.str(); -} - -template -Status Table::Add(const JobID &job_id, const ID &id, - const std::shared_ptr &data, - const WriteCallback &done) { - num_adds_++; - auto callback = [this, id, data, done](std::shared_ptr reply) { - if (done != nullptr) { - (done)(client_, id, *data); - } - }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_, - std::move(callback)); -} - -template -Status Table::Lookup(const JobID &job_id, const ID &id, const Callback &lookup, - const FailureCallback &failure) { - num_lookups_++; - return Log::Lookup(job_id, id, - [lookup, failure](RedisGcsClient *client, const ID &id, - const std::vector &data) { - if (data.empty()) { - if (failure != nullptr) { - (failure)(client, id); - } - } else { - RAY_CHECK(data.size() == 1); - if (lookup != nullptr) { - (lookup)(client, id, data[0]); - } - } - }); -} - -template -Status Table::Subscribe(const JobID &job_id, const NodeID &node_id, - const Callback &subscribe, - const FailureCallback &failure, - const SubscriptionCallback &done) { - return Log::Subscribe( - job_id, node_id, - [subscribe, failure](RedisGcsClient *client, const ID &id, - const std::vector &data) { - RAY_CHECK(data.empty() || data.size() == 1); - if (data.size() == 1) { - subscribe(client, id, data[0]); - } else { - if (failure != nullptr) { - failure(client, id); - } - } - }, - done); -} - -template -Status Table::Subscribe(const JobID &job_id, const NodeID &node_id, - const Callback &subscribe, - const SubscriptionCallback &done) { - return Subscribe(job_id, node_id, subscribe, /*failure*/ nullptr, done); -} - -template -std::string Table::DebugString() const { - std::stringstream result; - result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_; - return result.str(); -} - -template -Status Set::Add(const JobID &job_id, const ID &id, - const std::shared_ptr &data, const WriteCallback &done) { - num_adds_++; - auto callback = [this, id, data, done](std::shared_ptr reply) { - if (done != nullptr) { - (done)(client_, id, *data); - } - }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), - prefix_, pubsub_channel_, std::move(callback)); -} - -template -Status Set::Remove(const JobID &job_id, const ID &id, - const std::shared_ptr &data, - const WriteCallback &done) { - num_removes_++; - auto callback = [this, id, data, done](std::shared_ptr reply) { - if (done != nullptr) { - (done)(client_, id, *data); - } - }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), - prefix_, pubsub_channel_, std::move(callback)); -} - -template -Status Set::Subscribe(const JobID &job_id, const NodeID &node_id, - const NotificationCallback &subscribe, - const SubscriptionCallback &done) { - auto on_subscribe = [subscribe](RedisGcsClient *client, const ID &id, - const GcsChangeMode change_mode, - const std::vector &data) { - ArrayNotification change_notification(change_mode, data); - std::vector> notification_vec; - notification_vec.emplace_back(std::move(change_notification)); - subscribe(client, id, notification_vec); - }; - return Log::Subscribe(job_id, node_id, on_subscribe, done); -} - -template -std::string Set::DebugString() const { - std::stringstream result; - result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_ - << ", num removes: " << num_removes_; - return result.str(); -} - -template -Status Hash::Update(const JobID &job_id, const ID &id, const DataMap &data_map, - const HashCallback &done) { - num_adds_++; - auto callback = [this, id, data_map, done](std::shared_ptr reply) { - if (done != nullptr) { - (done)(client_, id, data_map); - } - }; - GcsEntry gcs_entry; - gcs_entry.set_id(id.Binary()); - gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); - for (const auto &pair : data_map) { - gcs_entry.add_entries(pair.first); - gcs_entry.add_entries(pair.second->SerializeAsString()); - } - std::string str = gcs_entry.SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), - prefix_, pubsub_channel_, std::move(callback)); -} - -template -Status Hash::RemoveEntries(const JobID &job_id, const ID &id, - const std::vector &keys, - const HashRemoveCallback &remove_callback) { - num_removes_++; - auto callback = [this, id, keys, - remove_callback](std::shared_ptr reply) { - if (remove_callback != nullptr) { - (remove_callback)(client_, id, keys); - } - }; - GcsEntry gcs_entry; - gcs_entry.set_id(id.Binary()); - gcs_entry.set_change_mode(GcsChangeMode::REMOVE); - for (const auto &key : keys) { - gcs_entry.add_entries(key); - } - std::string str = gcs_entry.SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), - prefix_, pubsub_channel_, std::move(callback)); -} - -template -std::string Hash::DebugString() const { - std::stringstream result; - result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_ - << ", num removes: " << num_removes_; - return result.str(); -} - -template -Status Hash::Lookup(const JobID &job_id, const ID &id, - const HashCallback &lookup) { - num_lookups_++; - auto callback = [this, id, lookup](std::shared_ptr reply) { - if (lookup != nullptr) { - DataMap results; - if (!reply->IsNil()) { - const auto data = reply->ReadAsString(); - GcsEntry gcs_entry; - gcs_entry.ParseFromString(reply->ReadAsString()); - RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); - RAY_CHECK(gcs_entry.entries_size() % 2 == 0); - for (int i = 0; i < gcs_entry.entries_size(); i += 2) { - const auto &key = gcs_entry.entries(i); - const auto value = std::make_shared(); - value->ParseFromString(gcs_entry.entries(i + 1)); - results.emplace(key, std::move(value)); - } - } - lookup(client_, id, results); - } - }; - std::vector nil; - return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), - prefix_, pubsub_channel_, std::move(callback)); -} - -template -Status Hash::Subscribe(const JobID &job_id, const NodeID &node_id, - const HashNotificationCallback &subscribe, - const SubscriptionCallback &done) { - RAY_CHECK(subscribe_callback_index_ == -1) - << "Client called Subscribe twice on the same table"; - auto callback = [this, subscribe, done](std::shared_ptr reply) { - const auto data = reply->ReadAsPubsubData(); - if (data.empty()) { - // No notification data is provided. This is the callback for the - // initial subscription request. - if (done != nullptr) { - done(client_); - } - } else { - // Data is provided. This is the callback for a message. - if (subscribe != nullptr) { - // Parse the notification. - GcsEntry gcs_entry; - gcs_entry.ParseFromString(data); - ID id = ID::FromBinary(gcs_entry.id()); - DataMap data_map; - if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { - for (const auto &key : gcs_entry.entries()) { - data_map.emplace(key, std::shared_ptr()); - } - } else { - RAY_CHECK(gcs_entry.entries_size() % 2 == 0); - for (int i = 0; i < gcs_entry.entries_size(); i += 2) { - const auto &key = gcs_entry.entries(i); - const auto value = std::make_shared(); - value->ParseFromString(gcs_entry.entries(i + 1)); - data_map.emplace(key, std::move(value)); - } - } - MapNotification notification(gcs_entry.change_mode(), - data_map); - std::vector> notification_vec; - notification_vec.emplace_back(std::move(notification)); - subscribe(client_, id, notification_vec); - } - } - }; - - subscribe_callback_index_ = 1; - for (auto &context : shard_contexts_) { - RAY_RETURN_NOT_OK(context->SubscribeAsync(node_id, pubsub_channel_, callback, - &subscribe_callback_index_)); - } - return Status::OK(); -} - -std::string ProfileTable::DebugString() const { - return Log::DebugString(); -} - -void NodeTable::RegisterNodeChangeCallback(const NodeChangeCallback &callback) { - RAY_CHECK(node_change_callback_ == nullptr); - node_change_callback_ = callback; - // Call the callback for any added clients that are cached. - for (const auto &entry : node_cache_) { - if (!entry.first.IsNil()) { - RAY_CHECK(entry.second.state() == GcsNodeInfo::ALIVE || - entry.second.state() == GcsNodeInfo::DEAD); - node_change_callback_(entry.first, entry.second); - } - } -} - -void NodeTable::HandleNotification(RedisGcsClient *client, const GcsNodeInfo &node_info) { - NodeID node_id = NodeID::FromBinary(node_info.node_id()); - bool is_alive = (node_info.state() == GcsNodeInfo::ALIVE); - // It's possible to get duplicate notifications from the client table, so - // check whether this notification is new. - auto entry = node_cache_.find(node_id); - bool is_notif_new; - if (entry == node_cache_.end()) { - // If the entry is not in the cache, then the notification is new. - is_notif_new = true; - } else { - // If the entry is in the cache, then the notification is new if the client - // was alive and is now dead or resources have been updated. - bool was_alive = (entry->second.state() == GcsNodeInfo::ALIVE); - is_notif_new = was_alive && !is_alive; - // Once a node with a given ID has been removed, it should never be added - // again. If the entry was in the cache and the node was deleted, check - // that this new notification is not an insertion. - if (!was_alive) { - RAY_CHECK(!is_alive) - << "Notification for addition of a node that was already removed:" << node_id; - } - } - - // Add the notification to our cache. Notifications are idempotent. - RAY_LOG(DEBUG) << "[NodeTableNotification] NodeTable Insertion/Deletion " - "notification for node id " - << node_id << ". IsAlive: " << is_alive - << ". Setting the node cache to data."; - node_cache_[node_id] = node_info; - - // If the notification is new, call any registered callbacks. - GcsNodeInfo &cache_data = node_cache_[node_id]; - if (is_notif_new) { - if (is_alive) { - RAY_CHECK(removed_nodes_.find(node_id) == removed_nodes_.end()); - } else { - // NOTE(swang): The node should be added to this data structure before - // the callback gets called, in case the callback depends on the data - // structure getting updated. - removed_nodes_.insert(node_id); - } - if (node_change_callback_ != nullptr) { - node_change_callback_(node_id, cache_data); - } - } -} - -const NodeID &NodeTable::GetLocalNodeId() const { - RAY_CHECK(!local_node_id_.IsNil()); - return local_node_id_; -} - -const GcsNodeInfo &NodeTable::GetLocalNode() const { return local_node_info_; } - -bool NodeTable::IsRemoved(const NodeID &node_id) const { - return removed_nodes_.count(node_id) == 1; -} - -Status NodeTable::Connect(const GcsNodeInfo &local_node_info) { - RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected node."; - RAY_CHECK(local_node_id_.IsNil()) << "This node is already connected."; - RAY_CHECK(local_node_info.state() == GcsNodeInfo::ALIVE); - - auto node_info_ptr = std::make_shared(local_node_info); - Status status = SyncAppend(JobID::Nil(), node_log_key_, node_info_ptr); - if (status.ok()) { - local_node_id_ = NodeID::FromBinary(local_node_info.node_id()); - local_node_info_ = local_node_info; - } - return status; -} - -Status NodeTable::Disconnect() { - local_node_info_.set_state(GcsNodeInfo::DEAD); - auto node_info_ptr = std::make_shared(local_node_info_); - Status status = SyncAppend(JobID::Nil(), node_log_key_, node_info_ptr); - - if (status.ok()) { - // We successfully added the deletion entry. Mark ourselves as disconnected. - disconnected_ = true; - } - return status; -} - -ray::Status NodeTable::MarkConnected(const GcsNodeInfo &node_info, - const WriteCallback &done) { - RAY_CHECK(node_info.state() == GcsNodeInfo::ALIVE); - auto node_info_ptr = std::make_shared(node_info); - return Append(JobID::Nil(), node_log_key_, node_info_ptr, done); -} - -ray::Status NodeTable::MarkDisconnected(const NodeID &dead_node_id, - const WriteCallback &done) { - auto node_info = std::make_shared(); - node_info->set_node_id(dead_node_id.Binary()); - node_info->set_state(GcsNodeInfo::DEAD); - return Append(JobID::Nil(), node_log_key_, node_info, done); -} - -ray::Status NodeTable::SubscribeToNodeChange( - const SubscribeCallback &subscribe, const StatusCallback &done) { - // Callback for a notification from the client table. - auto on_subscribe = [this](RedisGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { - RAY_CHECK(log_key == node_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; - for (auto ¬ification : notifications) { - // This is temporary fix for Issue 4140 to avoid connect to dead nodes. - // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.state() == GcsNodeInfo::ALIVE) { - connected_nodes.emplace(notification.node_id(), notification); - } else { - auto iter = connected_nodes.find(notification.node_id()); - if (iter != connected_nodes.end()) { - connected_nodes.erase(iter); - } - disconnected_nodes.emplace(notification.node_id(), notification); - } - } - for (const auto &pair : connected_nodes) { - HandleNotification(client, pair.second); - } - for (const auto &pair : disconnected_nodes) { - HandleNotification(client, pair.second); - } - }; - - // Callback to request notifications from the client table once we've - // successfully subscribed. - auto on_done = [this, subscribe, done](RedisGcsClient *client) { - auto on_request_notification_done = [this, subscribe, done](Status status) { - RAY_CHECK_OK(status); - if (done != nullptr) { - done(status); - } - // Register node change callbacks after RequestNotification finishes. - RegisterNodeChangeCallback(subscribe); - }; - RAY_CHECK_OK(RequestNotifications(JobID::Nil(), node_log_key_, subscribe_id_, - on_request_notification_done)); - }; - - // Subscribe to the client table. - return Subscribe(JobID::Nil(), subscribe_id_, on_subscribe, on_done); -} - -bool NodeTable::GetNode(const NodeID &node_id, GcsNodeInfo *node_info) const { - RAY_CHECK(!node_id.IsNil()); - auto entry = node_cache_.find(node_id); - auto found = (entry != node_cache_.end()); - if (found) { - *node_info = entry->second; - } - return found; -} - -const std::unordered_map &NodeTable::GetAllNodes() const { - return node_cache_; -} - -Status NodeTable::Lookup(const Callback &lookup) { - RAY_CHECK(lookup != nullptr); - return Log::Lookup(JobID::Nil(), node_log_key_, lookup); -} - -std::string NodeTable::DebugString() const { - std::stringstream result; - result << Log::DebugString(); - result << ", cache size: " << node_cache_.size() - << ", num removed: " << removed_nodes_.size(); - return result.str(); -} - -Status TaskLeaseTable::Subscribe(const JobID &job_id, const NodeID &node_id, - const Callback &subscribe, - const SubscriptionCallback &done) { - auto on_subscribe = [subscribe](RedisGcsClient *client, const TaskID &task_id, - const std::vector &data) { - std::vector> result; - for (const auto &item : data) { - boost::optional optional_item(item); - result.emplace_back(std::move(optional_item)); - } - if (result.empty()) { - boost::optional optional_item; - result.emplace_back(std::move(optional_item)); - } - subscribe(client, task_id, result); - }; - return Table::Subscribe(job_id, node_id, on_subscribe, done); -} - -std::vector SyncGetAllActorID(redisContext *redis_context, - const std::string &table_prefix) { - std::unordered_set actor_id_set; - size_t cursor = 0; - do { - auto r = redisCommand(redis_context, "SCAN %d match %s* count 100", cursor, - table_prefix.c_str()); - auto reply = reinterpret_cast(r); - RAY_CHECK(reply != nullptr && reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 2); - - // current cursor - redisReply *cursor_reply = reply->element[0]; - RAY_CHECK(cursor_reply != nullptr && cursor_reply->type == REDIS_REPLY_STRING); - cursor = std::stoi(std::string(cursor_reply->str, cursor_reply->len)); - - // actor ids - redisReply *array_reply = reply->element[1]; - RAY_CHECK(array_reply != nullptr && array_reply->type == REDIS_REPLY_ARRAY); - for (size_t i = 0; i < array_reply->elements; ++i) { - redisReply *id_reply = array_reply->element[i]; - RAY_CHECK(id_reply != nullptr && id_reply->type == REDIS_REPLY_STRING); - auto id_with_prefix = std::string(id_reply->str, id_reply->len); - // The key of actor_checkpoint table and actor_checkpoint_id table have the same - // prefix of `ACTOR`, so we should check the length of the key to filter them. - if (id_with_prefix.size() == table_prefix.size() + ActorID::Size()) { - auto id = ActorID::FromBinary(id_with_prefix.substr(table_prefix.size())); - actor_id_set.emplace(id); - } - } - } while (cursor != 0); - std::vector actor_id_list; - actor_id_list.reserve(actor_id_set.size()); - actor_id_list.insert(actor_id_list.end(), actor_id_set.begin(), actor_id_set.end()); - return actor_id_list; -} - -std::vector LogBasedActorTable::GetAllActorID() { - auto redis_context = client_->primary_context()->sync_context(); - return SyncGetAllActorID(redis_context, TablePrefix_Name(prefix_)); -} - -Status LogBasedActorTable::Get(const ray::ActorID &actor_id, - ray::rpc::ActorTableData *actor_table_data) { - RAY_CHECK(actor_table_data != nullptr); - auto key = TablePrefix_Name(prefix_) + actor_id.Binary(); - auto reply = GetRedisContext(actor_id)->RunArgvSync({"LRANGE", key, "-1", "-1"}); - if (!reply || reply->IsNil()) { - return Status::IOError("Failed to get actor data by actor_id " + actor_id.Hex()); - } - - const auto &data_list = reply->ReadAsStringArray(); - if (data_list.empty()) { - return Status::IOError("Failed to get actor data by actor_id " + actor_id.Hex()); - } - - RAY_CHECK(data_list.size() == 1); - actor_table_data->ParseFromString(data_list.front()); - return Status::OK(); -} - -std::vector ActorTable::GetAllActorID() { - auto redis_context = client_->primary_context()->sync_context(); - return SyncGetAllActorID(redis_context, TablePrefix_Name(prefix_)); -} - -Status ActorTable::Get(const ray::ActorID &actor_id, - ray::rpc::ActorTableData *actor_table_data) { - RAY_CHECK(actor_table_data != nullptr); - auto key = TablePrefix_Name(prefix_) + actor_id.Binary(); - auto reply = GetRedisContext(actor_id)->RunArgvSync({"GET", key}); - if (!reply || reply->IsNil()) { - return Status::IOError("Failed to get actor data by actor_id " + actor_id.Hex()); - } - actor_table_data->ParseFromString(reply->ReadAsString()); - return Status::OK(); -} - -template class Log; -template class Set; -template class Log; -template class Table; -template class Log; -template class Log; -template class Table; -template class Table; -template class Table; -template class Log; -template class Log; -template class Log; -template class Log; -template class Log; -template class Log; -template class Table; -template class Table; - -template class Log; -template class Hash; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h deleted file mode 100644 index c7c647162..000000000 --- a/src/ray/gcs/tables.h +++ /dev/null @@ -1,978 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "ray/common/constants.h" -#include "ray/common/id.h" -#include "ray/common/status.h" -#include "ray/gcs/callback.h" -#include "ray/gcs/entry_change_notification.h" -#include "ray/gcs/redis_context.h" -#include "ray/util/logging.h" -#include "src/ray/protobuf/gcs.pb.h" - -struct redisAsyncContext; - -namespace ray { - -namespace gcs { - -using rpc::ActorTableData; -using rpc::ErrorTableData; -using rpc::GcsChangeMode; -using rpc::GcsEntry; -using rpc::GcsNodeInfo; -using rpc::HeartbeatTableData; -using rpc::JobTableData; -using rpc::ObjectTableData; -using rpc::ProfileTableData; -using rpc::ResourceTableData; -using rpc::ResourceUsageBatchData; -using rpc::TablePrefix; -using rpc::TablePubsub; -using rpc::TaskLeaseData; -using rpc::TaskReconstructionData; -using rpc::TaskTableData; -using rpc::WorkerTableData; - -class RedisContext; - -class RedisGcsClient; - -/// Specifies whether commands issued to a table should be regular or chain-replicated -/// (when available). -enum class CommandType { kRegular, kChain, kUnknown }; - -/// \class PubsubInterface -/// -/// The interface for a pubsub storage system. The client of a storage system -/// that implements this interface can request and cancel notifications for -/// specific keys. -template -class PubsubInterface { - public: - virtual Status RequestNotifications(const JobID &job_id, const ID &id, - const NodeID &node_id, - const StatusCallback &done) = 0; - virtual Status CancelNotifications(const JobID &job_id, const ID &id, - const NodeID &node_id, - const StatusCallback &done) = 0; - virtual ~PubsubInterface(){}; -}; - -template -class LogInterface { - public: - using WriteCallback = - std::function; - virtual Status Append(const JobID &job_id, const ID &id, - const std::shared_ptr &data, const WriteCallback &done) = 0; - virtual Status AppendAt(const JobID &job_id, const ID &id, - const std::shared_ptr &data, const WriteCallback &done, - const WriteCallback &failure, int log_length) = 0; - virtual ~LogInterface(){}; -}; - -/// \class Log -/// -/// A GCS table where every entry is an append-only log. This class is not -/// meant to be used directly. All log classes should derive from this class -/// and override the prefix_ member with a unique prefix for that log, and the -/// pubsub_channel_ member if pubsub is required. -/// -/// Example tables backed by Log: -/// NodeTable: Stores a log of which GCS clients have been added or deleted -/// from the system. -template -class Log : public LogInterface, virtual public PubsubInterface { - public: - using Callback = std::function &data)>; - - using NotificationCallback = - std::function &data)>; - - /// The callback to call when a write to a key succeeds. - using WriteCallback = typename LogInterface::WriteCallback; - /// The callback to call when a SUBSCRIBE call completes and we are ready to - /// request and receive notifications. - using SubscriptionCallback = std::function; - - struct CallbackData { - ID id; - std::shared_ptr data; - Callback callback; - // An optional callback to call for subscription operations, where the - // first message is a notification of subscription success. - SubscriptionCallback subscription_callback; - Log *log; - RedisGcsClient *client; - }; - - Log(const std::vector> &contexts, RedisGcsClient *client) - : shard_contexts_(contexts), - client_(client), - pubsub_channel_(TablePubsub::NO_PUBLISH), - prefix_(TablePrefix::UNUSED), - subscribe_callback_index_(-1){}; - - /// Append a log entry to a key. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is added to the GCS. - /// \param data Data to append to the log. TODO(rkn): This can be made const, - /// right? - /// \param done Callback that is called once the data has been written to the - /// GCS. - /// \return Status - Status Append(const JobID &job_id, const ID &id, const std::shared_ptr &data, - const WriteCallback &done); - - /// Append a log entry to a key synchronously. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is added to the GCS. - /// \param data Data to append to the log. - /// \return Status - Status SyncAppend(const JobID &job_id, const ID &id, const std::shared_ptr &data); - - /// Append a log entry to a key if and only if the log has the given number - /// of entries. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is added to the GCS. - /// \param data Data to append to the log. - /// \param done Callback that is called if the data was appended to the log. - /// \param failure Callback that is called if the data was not appended to - /// the log because the log length did not match the given `log_length`. - /// \param log_length The number of entries that the log must have for the - /// append to succeed. - /// \return Status - Status AppendAt(const JobID &job_id, const ID &id, const std::shared_ptr &data, - const WriteCallback &done, const WriteCallback &failure, - int log_length); - - /// Lookup the log values at a key asynchronously. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is looked up in the GCS. - /// \param lookup Callback that is called after lookup. If the callback is - /// called with an empty vector, then there was no data at the key. - /// \return Status - Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup); - - /// Subscribe to any Append operations to this table. The caller may choose - /// requests notifications for. This may only be called once per Log - /// - /// \param job_id The ID of the job. - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each Add to the table will be received. Else, only - /// messages for the given node will be received. In the latter - /// case, the node may request notifications on specific keys in the - /// table via `RequestNotifications`. - /// \param subscribe Callback that is called on each received message. If the - /// callback is called with an empty vector, then there was no data at the key. - /// \param done Callback that is called when subscription is complete and we - /// are ready to receive messages. - /// \return Status - Status Subscribe(const JobID &job_id, const NodeID &node_id, const Callback &subscribe, - const SubscriptionCallback &done); - - /// Request notifications about a key in this table. - /// - /// The notifications will be returned via the subscribe callback that was - /// registered by `Subscribe`. An initial notification will be returned for - /// the current values at the key, if any, and a subsequent notification will - /// be published for every following `Append` to the key. Before - /// notifications can be requested, the caller must first call `Subscribe`, - /// with the same `node_id`. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the key to request notifications for. - /// \param node_id The node who is requesting notifications. - /// \param done Callback that is called when request notifications is complete. - /// notifications can be requested, a call to `Subscribe` to this - /// table with the same `node_id` must complete successfully. - /// \return Status - Status RequestNotifications(const JobID &job_id, const ID &id, const NodeID &node_id, - const StatusCallback &done); - - /// Cancel notifications about a key in this table. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the key to request notifications for. - /// \param node_id The node who originally requested notifications. - /// \param done Callback that is called when cancel notifications is complete. - /// \return Status - Status CancelNotifications(const JobID &job_id, const ID &id, const NodeID &node_id, - const StatusCallback &done); - - /// Subscribe to any modifications to the key. The caller may choose - /// to subscribe to all modifications, or to subscribe only to keys that it - /// requests notifications for. This may only be called once per Log - /// instance. This function is different from public version due to - /// an additional parameter change_mode in NotificationCallback. Therefore this - /// function supports notifications of remove operations. - /// - /// \param job_id The ID of the job. - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each Add to the table will be received. Else, only - /// messages for the given node will be received. In the latter - /// case, the node may request notifications on specific keys in the - /// table via `RequestNotifications`. - /// \param subscribe Callback that is called on each received message. If the - /// callback is called with an empty vector, then there was no data at the key. - /// \param done Callback that is called when subscription is complete and we - /// are ready to receive messages. - /// \return Status - Status Subscribe(const JobID &job_id, const NodeID &node_id, - const NotificationCallback &subscribe, - const SubscriptionCallback &done); - - /// Delete an entire key from redis. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data to delete from the GCS. - /// \return Void. - void Delete(const JobID &job_id, const ID &id); - - /// Delete several keys from redis. - /// - /// \param job_id The ID of the job. - /// \param ids The vector of IDs to delete from the GCS. - /// \return Void. - void Delete(const JobID &job_id, const std::vector &ids); - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; - - protected: - std::shared_ptr GetRedisContext(const ID &id) { - static std::hash index; - return shard_contexts_[index(id) % shard_contexts_.size()]; - } - - /// The connection to the GCS. - std::vector> shard_contexts_; - /// The GCS client. - RedisGcsClient *client_; - /// The pubsub channel to subscribe to for notifications about keys in this - /// table. If no notifications are required, this should be set to - /// TablePubsub_NO_PUBLISH. If notifications are required, then this must be - /// unique across all instances of Log. - TablePubsub pubsub_channel_; - /// The prefix to use for keys in this table. This must be unique across all - /// instances of Log. - TablePrefix prefix_; - /// The index in the RedisCallbackManager for the callback that is called - /// when we receive notifications. This is >= 0 iff we have subscribed to the - /// table, otherwise -1. - int64_t subscribe_callback_index_; - - /// Commands to a GCS table can either be regular (default) or chain-replicated. - CommandType command_type_ = CommandType::kRegular; - - int64_t num_appends_ = 0; - int64_t num_lookups_ = 0; -}; - -template -class TableInterface { - public: - using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const JobID &job_id, const ID &task_id, - const std::shared_ptr &data, const WriteCallback &done) = 0; - virtual ~TableInterface(){}; -}; - -/// \class Table -/// -/// A GCS table where every entry is a single data item. This class is not -/// meant to be used directly. All table classes should derive from this class -/// and override the prefix_ member with a unique prefix for that table, and -/// the pubsub_channel_ member if pubsub is required. -/// -/// Example tables backed by Log: -/// TaskTable: Stores Task metadata needed for executing the task. -template -class Table : private Log, - public TableInterface, - virtual public PubsubInterface { - public: - using Callback = - std::function; - using WriteCallback = typename Log::WriteCallback; - /// The callback to call when a Lookup call returns an empty entry. - using FailureCallback = std::function; - /// The callback to call when a Subscribe call completes and we are ready to - /// request and receive notifications. - using SubscriptionCallback = typename Log::SubscriptionCallback; - - Table(const std::vector> &contexts, - RedisGcsClient *client) - : Log(contexts, client) {} - - using Log::RequestNotifications; - using Log::CancelNotifications; - /// Expose this interface for use by subscription tools class SubscriptionExecutor. - /// In this way TaskTable() can also reuse class SubscriptionExecutor. - using Log::Subscribe; - - /// Add an entry to the table. This overwrites any existing data at the key. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is added to the GCS. - /// \param data Data that is added to the GCS. - /// \param done Callback that is called once the data has been written to the - /// GCS. - /// \return Status - Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data, - const WriteCallback &done); - - /// Lookup an entry asynchronously. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is looked up in the GCS. - /// \param lookup Callback that is called after lookup if there was data the - /// key. - /// \param failure Callback that is called after lookup if there was no data - /// at the key. - /// \return Status - Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup, - const FailureCallback &failure); - - /// Subscribe to any Add operations to this table. The caller may choose to - /// subscribe to all Adds, or to subscribe only to keys that it requests - /// notifications for. This may only be called once per Table instance. - /// - /// \param job_id The ID of the job. - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each Add to the table will be received. Else, only - /// messages for the given node will be received. In the latter - /// case, the node may request notifications on specific keys in the - /// table via `RequestNotifications`. - /// \param subscribe Callback that is called on each received message. If the - /// callback is called with an empty vector, then there was no data at the key. - /// \param failure Callback that is called if the key is empty at the time - /// that notifications are requested. - /// \param done Callback that is called when subscription is complete and we - /// are ready to receive messages. - /// \return Status - Status Subscribe(const JobID &job_id, const NodeID &node_id, const Callback &subscribe, - const FailureCallback &failure, const SubscriptionCallback &done); - - /// Subscribe to any Add operations to this table. The caller may choose to - /// subscribe to all Adds, or to subscribe only to keys that it requests - /// notifications for. This may only be called once per Table instance. - /// - /// \param job_id The ID of the job. - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each Add to the table will be received. Else, only - /// messages for the given node will be received. In the latter - /// case, the node may request notifications on specific keys in the - /// table via `RequestNotifications`. - /// \param subscribe Callback that is called on each received message. If the - /// callback is called with an empty vector, then there was no data at the key. - /// \param done Callback that is called when subscription is complete and we - /// are ready to receive messages. - /// \return Status - Status Subscribe(const JobID &job_id, const NodeID &node_id, const Callback &subscribe, - const SubscriptionCallback &done); - - void Delete(const JobID &job_id, const ID &id) { Log::Delete(job_id, id); } - - void Delete(const JobID &job_id, const std::vector &ids) { - Log::Delete(job_id, ids); - } - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; - - protected: - using Log::shard_contexts_; - using Log::client_; - using Log::pubsub_channel_; - using Log::prefix_; - using Log::command_type_; - using Log::GetRedisContext; - - int64_t num_adds_ = 0; - int64_t num_lookups_ = 0; -}; - -template -class SetInterface { - public: - using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data, - const WriteCallback &done) = 0; - virtual Status Remove(const JobID &job_id, const ID &id, - const std::shared_ptr &data, const WriteCallback &done) = 0; - virtual ~SetInterface(){}; -}; - -/// \class Set -/// -/// A GCS table where every entry is an addable & removable set. This class is not -/// meant to be used directly. All set classes should derive from this class -/// and override the prefix_ member with a unique prefix for that set, and the -/// pubsub_channel_ member if pubsub is required. -/// -/// Example tables backed by Set: -/// ObjectTable: Stores a set of which clients have added an object. -template -class Set : private Log, - public SetInterface, - virtual public PubsubInterface { - public: - using Callback = typename Log::Callback; - using WriteCallback = typename Log::WriteCallback; - using SubscriptionCallback = typename Log::SubscriptionCallback; - - Set(const std::vector> &contexts, RedisGcsClient *client) - : Log(contexts, client) {} - - using Log::RequestNotifications; - using Log::CancelNotifications; - using Log::Lookup; - using Log::Delete; - - /// Add an entry to the set. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is added to the GCS. - /// \param data Data to add to the set. - /// \param done Callback that is called once the data has been written to the - /// GCS. - /// \return Status - Status Add(const JobID &job_id, const ID &id, const std::shared_ptr &data, - const WriteCallback &done); - - /// Remove an entry from the set. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is removed from the GCS. - /// \param data Data to remove from the set. - /// \param done Callback that is called once the data has been written to the - /// GCS. - /// \return Status - Status Remove(const JobID &job_id, const ID &id, const std::shared_ptr &data, - const WriteCallback &done); - - using NotificationCallback = - std::function> &data)>; - /// Subscribe to any add or remove operations to this table. - /// - /// \param job_id The ID of the job. - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each add or remove to the table will be received. Else, only - /// messages for the given node will be received. In the latter - /// case, the node may request notifications on specific keys in the - /// table via `RequestNotifications`. - /// \param subscribe Callback that is called on each received message. - /// \param done Callback that is called when subscription is complete and we - /// are ready to receive messages. - /// \return Status - Status Subscribe(const JobID &job_id, const NodeID &node_id, - const NotificationCallback &subscribe, - const SubscriptionCallback &done); - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; - - protected: - using Log::shard_contexts_; - using Log::client_; - using Log::pubsub_channel_; - using Log::prefix_; - using Log::GetRedisContext; - - int64_t num_adds_ = 0; - int64_t num_removes_ = 0; - using Log::num_lookups_; -}; - -template -class HashInterface { - public: - using DataMap = std::unordered_map>; - // Reuse Log's SubscriptionCallback when Subscribe is successfully called. - using SubscriptionCallback = typename Log::SubscriptionCallback; - - /// The callback function used by function Update & Lookup. - /// - /// \param client The client on which the RemoveEntries is called. - /// \param id The ID of the Hash Table whose entries are removed. - /// \param data Map data contains the change to the Hash Table. - /// \return Void - using HashCallback = - std::function; - - /// The callback function used by function RemoveEntries. - /// - /// \param client The client on which the RemoveEntries is called. - /// \param id The ID of the Hash Table whose entries are removed. - /// \param keys The keys that are moved from this Hash Table. - /// \return Void - using HashRemoveCallback = std::function &keys)>; - - /// The notification function used by function Subscribe. - /// - /// \param client The client on which the Subscribe is called. - /// \param change_mode The mode to identify the data is removed or updated. - /// \param data Map data contains the change to the Hash Table. - /// \return Void - using HashNotificationCallback = - std::function> &data)>; - - /// Add entries of a hash table. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is added to the GCS. - /// \param pairs Map data to add to the hash table. - /// \param done HashCallback that is called once the request data has been written to - /// the GCS. - /// \return Status - virtual Status Update(const JobID &job_id, const ID &id, const DataMap &pairs, - const HashCallback &done) = 0; - - /// Remove entries from the hash table. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is removed from the GCS. - /// \param keys The entry keys of the hash table. - /// \param remove_callback HashRemoveCallback that is called once the data has been - /// written to the GCS no matter whether the key exists in the hash table. - /// \return Status - virtual Status RemoveEntries(const JobID &job_id, const ID &id, - const std::vector &keys, - const HashRemoveCallback &remove_callback) = 0; - - /// Lookup the map data of a hash table. - /// - /// \param job_id The ID of the job. - /// \param id The ID of the data that is looked up in the GCS. - /// \param lookup HashCallback that is called after lookup. If the callback is - /// called with an empty hash table, then there was no data in the callback. - /// \return Status - virtual Status Lookup(const JobID &job_id, const ID &id, - const HashCallback &lookup) = 0; - - /// Subscribe to any Update or Remove operations to this hash table. - /// - /// \param job_id The ID of the job. - /// \param node_id The type of update to listen to. If this is nil, then a - /// message for each Update to the table will be received. Else, only - /// messages for the given node will be received. In the latter - /// case, the node may request notifications on specific keys in the - /// table via `RequestNotifications`. - /// \param subscribe HashNotificationCallback that is called on each received message. - /// \param done SubscriptionCallback that is called when subscription is complete and - /// we are ready to receive messages. - /// \return Status - virtual Status Subscribe(const JobID &job_id, const NodeID &node_id, - const HashNotificationCallback &subscribe, - const SubscriptionCallback &done) = 0; - - virtual ~HashInterface(){}; -}; - -template -class Hash : private Log, - public HashInterface, - virtual public PubsubInterface { - public: - using DataMap = std::unordered_map>; - using HashCallback = typename HashInterface::HashCallback; - using HashRemoveCallback = typename HashInterface::HashRemoveCallback; - using HashNotificationCallback = - typename HashInterface::HashNotificationCallback; - using SubscriptionCallback = typename Log::SubscriptionCallback; - - Hash(const std::vector> &contexts, RedisGcsClient *client) - : Log(contexts, client) {} - - using Log::RequestNotifications; - using Log::CancelNotifications; - - Status Update(const JobID &job_id, const ID &id, const DataMap &pairs, - const HashCallback &done) override; - - Status Subscribe(const JobID &job_id, const NodeID &node_id, - const HashNotificationCallback &subscribe, - const SubscriptionCallback &done) override; - - Status Lookup(const JobID &job_id, const ID &id, const HashCallback &lookup) override; - - Status RemoveEntries(const JobID &job_id, const ID &id, - const std::vector &keys, - const HashRemoveCallback &remove_callback) override; - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; - - protected: - using Log::shard_contexts_; - using Log::client_; - using Log::pubsub_channel_; - using Log::prefix_; - using Log::subscribe_callback_index_; - using Log::GetRedisContext; - - int64_t num_adds_ = 0; - int64_t num_removes_ = 0; - using Log::num_lookups_; -}; - -class DynamicResourceTable : public Hash { - public: - DynamicResourceTable(const std::vector> &contexts, - RedisGcsClient *client) - : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; - prefix_ = TablePrefix::NODE_RESOURCE; - }; - - virtual ~DynamicResourceTable(){}; -}; - -class ObjectTable : public Set { - public: - ObjectTable(const std::vector> &contexts, - RedisGcsClient *client) - : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; - prefix_ = TablePrefix::OBJECT; - }; - - virtual ~ObjectTable(){}; -}; - -class HeartbeatTable : public Table { - public: - HeartbeatTable(const std::vector> &contexts, - RedisGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; - prefix_ = TablePrefix::HEARTBEAT; - } - virtual ~HeartbeatTable() {} -}; - -class ResourceUsageBatchTable : public Table { - public: - ResourceUsageBatchTable(const std::vector> &contexts, - RedisGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RESOURCE_USAGE_BATCH_PUBSUB; - prefix_ = TablePrefix::RESOURCE_USAGE_BATCH; - } - virtual ~ResourceUsageBatchTable() {} -}; - -class JobTable : public Log { - public: - JobTable(const std::vector> &contexts, - RedisGcsClient *client) - : Log(contexts, client) { - pubsub_channel_ = TablePubsub::JOB_PUBSUB; - prefix_ = TablePrefix::JOB; - }; - - virtual ~JobTable() {} -}; - -/// Log-based Actor table starts with an ALIVE entry, which represents the first time the -/// actor is created. This may be followed by 0 or more pairs of RESTARTING, ALIVE -/// entries, which represent each time the actor fails (RESTARTING) and gets recreated -/// (ALIVE). These may be followed by a DEAD entry, which means that the actor has failed -/// and will not be reconstructed. -class LogBasedActorTable : public Log { - public: - LogBasedActorTable(const std::vector> &contexts, - RedisGcsClient *client) - : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; - prefix_ = TablePrefix::ACTOR; - } - - /// Get all actor id synchronously. - std::vector GetAllActorID(); - - /// Get actor table data by actor id synchronously. - Status Get(const ActorID &actor_id, ActorTableData *actor_table_data); -}; - -/// Actor table. -/// This table is only used for GCS-based actor management. And when completely migrate to -/// GCS service, the log-based actor table could be removed. -class ActorTable : public Table { - public: - ActorTable(const std::vector> &contexts, - RedisGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; - prefix_ = TablePrefix::ACTOR; - } - - /// Get all actor id synchronously. - std::vector GetAllActorID(); - - /// Get actor table data by actor id synchronously. - Status Get(const ActorID &actor_id, ActorTableData *actor_table_data); -}; - -class WorkerTable : public Table { - public: - WorkerTable(const std::vector> &contexts, - RedisGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::WORKER_FAILURE_PUBSUB; - prefix_ = TablePrefix::WORKERS; - } - virtual ~WorkerTable() {} -}; - -class TaskReconstructionLog : public Log { - public: - TaskReconstructionLog(const std::vector> &contexts, - RedisGcsClient *client) - : Log(contexts, client) { - prefix_ = TablePrefix::TASK_RECONSTRUCTION; - } -}; - -class TaskLeaseTable : public Table { - public: - /// Use boost::optional to represent subscription results, so that we can - /// notify raylet whether the entry of task lease is empty. - using Callback = - std::function> &data)>; - - TaskLeaseTable(const std::vector> &contexts, - RedisGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; - prefix_ = TablePrefix::TASK_LEASE; - } - - Status Add(const JobID &job_id, const TaskID &id, - const std::shared_ptr &data, - const WriteCallback &done) override { - RAY_RETURN_NOT_OK((Table::Add(job_id, id, data, done))); - // Mark the entry for expiration in Redis. It's okay if this command fails - // since the lease entry itself contains the expiration period. In the - // worst case, if the command fails, then a client that looks up the lease - // entry will overestimate the expiration time. - // TODO(swang): Use a common helper function to format the key instead of - // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), - std::to_string(data->timeout())}; - - return GetRedisContext(id)->RunArgvAsync(args); - } - - /// Implement this method for the subscription tools class SubscriptionExecutor. - /// In this way TaskLeaseTable() can also reuse class SubscriptionExecutor. - Status Subscribe(const JobID &job_id, const NodeID &node_id, const Callback &subscribe, - const SubscriptionCallback &done); -}; - -namespace raylet { - -class TaskTable : public Table { - public: - TaskTable(const std::vector> &contexts, - RedisGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; - prefix_ = TablePrefix::RAYLET_TASK; - } - - TaskTable(const std::vector> &contexts, - RedisGcsClient *client, gcs::CommandType command_type) - : TaskTable(contexts, client) { - command_type_ = command_type; - }; -}; - -} // namespace raylet - -class ProfileTable : public Log { - public: - ProfileTable(const std::vector> &contexts, - RedisGcsClient *client) - : Log(contexts, client) { - prefix_ = TablePrefix::PROFILE; - }; - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; -}; - -/// \class NodeTable -/// -/// The NodeTable stores information about active and inactive nodes. It is -/// structured as a single log stored at a key known to all nodes. When a -/// node connects, it appends an entry to the log indicating that it is -/// alive. When a node disconnects, or if another node detects its failure, -/// it should append an entry to the log indicating that it is dead. A node -/// that is marked as dead should never again be marked as alive; if it needs -/// to reconnect, it must connect with a different NodeID. -class NodeTable : public Log { - public: - NodeTable(const std::vector> &contexts, - RedisGcsClient *client) - : Log(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_PUBSUB; - prefix_ = TablePrefix::NODE; - }; - - /// Connect as a NODE to the GCS. This registers us in the NODE table - /// and begins subscription to NODE table notifications. - /// - /// \param local_node_info Information about the connecting NODE. This must have the - /// same id as the one set in the NODE table. - /// \return Status - ray::Status Connect(const GcsNodeInfo &local_node_info); - - /// Disconnect the NODE from the GCS. The NODE ID assigned during - /// registration should never be reused after disconnecting. - /// - /// \return Status - ray::Status Disconnect(); - - /// Mark a new node as connected to GCS asynchronously. - /// - /// \param node_info Information about the node. - /// \param done Callback that is called once the node has been marked to connected. - /// \return Status - ray::Status MarkConnected(const GcsNodeInfo &node_info, const WriteCallback &done); - - /// Mark a different node as disconnected. The NODE ID should never be - /// reused for a new node. - /// - /// \param dead_node_id The ID of the node to mark as dead. - /// \param done Callback that is called once the node has been marked to - /// disconnected. - /// \return Status - ray::Status MarkDisconnected(const NodeID &dead_node_id, const WriteCallback &done); - - ray::Status SubscribeToNodeChange( - const SubscribeCallback &subscribe, - const StatusCallback &done); - - /// Get a node's information from the cache. The cache only contains - /// information for nodes that we've heard a notification for. - /// - /// \param node The node to get information about. - /// \param node_info The node information will be copied here if - /// we have the node in the cache. - /// a nil node ID. - /// \return Whether the node is in the cache. - bool GetNode(const NodeID &node, GcsNodeInfo *node_info) const; - - /// Get the local node's ID. - /// - /// \return The local node's ID. - const NodeID &GetLocalNodeId() const; - - /// Get the local node's information. - /// - /// \return The local node's information. - const GcsNodeInfo &GetLocalNode() const; - - /// Check whether the given node is removed. - /// - /// \param node_id The ID of the node to check. - /// \return Whether the node with specified ID is removed. - bool IsRemoved(const NodeID &node_id) const; - - /// Get the information of all nodes. - /// - /// \return The node ID to node information map. - const std::unordered_map &GetAllNodes() const; - - /// Lookup the node data in the node table. - /// - /// \param lookup Callback that is called after lookup. If the callback is - /// called with an empty vector, then there was no data at the key. - /// \return Status. - Status Lookup(const Callback &lookup); - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; - - /// The key at which the log of node information is stored. This key must - /// be kept the same across all instances of the NodeTable, so that all - /// nodes append and read from the same key. - NodeID node_log_key_; - - private: - using NodeChangeCallback = - std::function; - - /// Register a callback to call when a new node is added or a node is removed. - /// - /// \param callback The callback to register. - void RegisterNodeChangeCallback(const NodeChangeCallback &callback); - - /// Handle a node table notification. - void HandleNotification(RedisGcsClient *client, const GcsNodeInfo &node_info); - - /// Whether this node has called Disconnect(). - bool disconnected_{false}; - /// This node's ID. It will be initialized when we call method `Connect(...)`. - NodeID local_node_id_; - /// Information about this node. - GcsNodeInfo local_node_info_; - /// This ID is used in method `SubscribeToNodeChange(...)` to Subscribe and - /// RequestNotification. - /// The reason for not using `local_node_id_` is because it is only initialized - /// for registered nodes. - NodeID subscribe_id_{NodeID::FromRandom()}; - /// The callback to call when a new node is added or a node is removed. - NodeChangeCallback node_change_callback_{nullptr}; - /// A cache for information about all nodes. - std::unordered_map node_cache_; - /// The set of removed nodes. - std::unordered_set removed_nodes_; -}; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/test/accessor_test_base.h b/src/ray/gcs/test/accessor_test_base.h deleted file mode 100644 index 7ce8d0bfa..000000000 --- a/src/ray/gcs/test/accessor_test_base.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "ray/common/test_util.h" -#include "ray/gcs/redis_accessor.h" -#include "ray/gcs/redis_gcs_client.h" - -namespace ray { - -namespace gcs { - -template -class AccessorTestBase : public ::testing::Test { - public: - AccessorTestBase() { TestSetupUtil::StartUpRedisServers(std::vector()); } - - virtual ~AccessorTestBase() { TestSetupUtil::ShutDownRedisServers(); } - - virtual void SetUp() { - GenTestData(); - - GcsClientOptions options = - GcsClientOptions("127.0.0.1", TEST_REDIS_SERVER_PORTS.front(), "", true); - gcs_client_.reset(new RedisGcsClient(options)); - RAY_CHECK_OK(gcs_client_->Connect(io_service_)); - - work_thread_.reset(new std::thread([this] { - std::unique_ptr work( - new boost::asio::io_service::work(io_service_)); - io_service_.run(); - })); - } - - virtual void TearDown() { - gcs_client_->Disconnect(); - - io_service_.stop(); - work_thread_->join(); - work_thread_.reset(); - - gcs_client_.reset(); - - ClearTestData(); - } - - protected: - virtual void GenTestData() = 0; - - void ClearTestData() { id_to_data_.clear(); } - - void WaitPendingDone(std::chrono::milliseconds timeout) { - WaitPendingDone(pending_count_, timeout); - } - - void WaitPendingDone(std::atomic &pending_count, - std::chrono::milliseconds timeout) { - auto condition = [&pending_count]() { return pending_count == 0; }; - EXPECT_TRUE(WaitForCondition(condition, timeout.count())); - } - - protected: - std::unique_ptr gcs_client_; - - boost::asio::io_service io_service_; - std::unique_ptr work_thread_; - - std::unordered_map> id_to_data_; - - std::atomic pending_count_{0}; - std::chrono::milliseconds wait_pending_timeout_{10000}; -}; - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/test/redis_actor_info_accessor_test.cc b/src/ray/gcs/test/redis_actor_info_accessor_test.cc deleted file mode 100644 index 49f474621..000000000 --- a/src/ray/gcs/test/redis_actor_info_accessor_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "gtest/gtest.h" -#include "ray/common/test_util.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/gcs/test/accessor_test_base.h" - -namespace ray { - -namespace gcs { - -class ActorInfoAccessorTest : public AccessorTestBase { - protected: - virtual void GenTestData() { - for (size_t i = 0; i < 100; ++i) { - std::shared_ptr actor = std::make_shared(); - actor->set_max_restarts(1); - actor->set_num_restarts(0); - JobID job_id = JobID::FromInt(i); - actor->set_job_id(job_id.Binary()); - actor->set_state(ActorTableData::ALIVE); - ActorID actor_id = ActorID::Of(job_id, RandomTaskId(), /*parent_task_counter=*/i); - actor->set_actor_id(actor_id.Binary()); - id_to_data_[actor_id] = actor; - } - } - - size_t checkpoint_number_{2}; -}; - -TEST_F(ActorInfoAccessorTest, Subscribe) { - ActorInfoAccessor &actor_accessor = gcs_client_->Actors(); - // subscribe - std::atomic sub_pending_count(0); - std::atomic do_sub_pending_count(0); - auto subscribe = [this, &sub_pending_count](const ActorID &actor_id, - const ActorTableData &data) { - const auto it = id_to_data_.find(actor_id); - ASSERT_TRUE(it != id_to_data_.end()); - --sub_pending_count; - }; - auto done = [&do_sub_pending_count](Status status) { - RAY_CHECK_OK(status); - --do_sub_pending_count; - }; - - ++do_sub_pending_count; - RAY_CHECK_OK(actor_accessor.AsyncSubscribeAll(subscribe, done)); - // Wait until subscribe finishes. - WaitPendingDone(do_sub_pending_count, wait_pending_timeout_); -} - -} // namespace gcs - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 4); - ray::TEST_REDIS_SERVER_EXEC_PATH = argv[1]; - ray::TEST_REDIS_CLIENT_EXEC_PATH = argv[2]; - ray::TEST_REDIS_MODULE_LIBRARY_PATH = argv[3]; - return RUN_ALL_TESTS(); -} diff --git a/src/ray/gcs/test/redis_gcs_client_test.cc b/src/ray/gcs/test/redis_gcs_client_test.cc deleted file mode 100644 index 771d6a703..000000000 --- a/src/ray/gcs/test/redis_gcs_client_test.cc +++ /dev/null @@ -1,1505 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/gcs/redis_gcs_client.h" - -#include "gtest/gtest.h" -#include "ray/common/ray_config.h" -#include "ray/common/test_util.h" -#include "ray/gcs/pb_util.h" -#include "ray/gcs/tables.h" - -extern "C" { -#include "hiredis/hiredis.h" -} - -namespace ray { - -namespace gcs { - -/* Flush redis. */ -static inline void flushall_redis(void) { - redisContext *context = redisConnect("127.0.0.1", TEST_REDIS_SERVER_PORTS.front()); - freeReplyObject(redisCommand(context, "FLUSHALL")); - redisFree(context); -} - -/// A helper function to generate an unique JobID. -inline JobID NextJobID() { - static int32_t counter = 0; - return JobID::FromInt(++counter); -} - -class TestGcs : public ::testing::Test { - public: - TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) { - TestSetupUtil::StartUpRedisServers(std::vector()); - job_id_ = NextJobID(); - } - - virtual ~TestGcs() { - // Clear all keys in the GCS. - flushall_redis(); - TestSetupUtil::ShutDownRedisServers(); - }; - - virtual void Start() = 0; - - virtual void Stop() = 0; - - uint64_t NumCallbacks() const { return num_callbacks_; } - - void IncrementNumCallbacks() { num_callbacks_++; } - - protected: - uint64_t num_callbacks_; - gcs::CommandType command_type_; - std::shared_ptr client_; - JobID job_id_; -}; - -TestGcs *test; -NodeID local_node_id = NodeID::FromRandom(); - -class TestGcsWithAsio : public TestGcs { - public: - TestGcsWithAsio(CommandType command_type) - : TestGcs(command_type), io_service_(), work_(io_service_) {} - - TestGcsWithAsio() : TestGcsWithAsio(CommandType::kRegular) {} - - ~TestGcsWithAsio() { - // Destroy the client first since it has a reference to the event loop. - client_->Disconnect(); - client_.reset(); - } - - void SetUp() override { - GcsClientOptions options("127.0.0.1", TEST_REDIS_SERVER_PORTS.front(), "", true); - client_ = std::make_shared(options, command_type_); - RAY_CHECK_OK(client_->Connect(io_service_)); - } - - void Start() override { io_service_.run(); } - void Stop() override { io_service_.stop(); } - - private: - boost::asio::io_service io_service_; - // Give the event loop some work so that it's forced to run until Stop() is - // called. - boost::asio::io_service::work work_; -}; - -class TestGcsWithChainAsio : public TestGcsWithAsio { - public: - TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){}; -}; - -class TaskTableTestHelper { - public: - /// A helper function that creates a GCS `TaskTableData` object. - static std::shared_ptr CreateTaskTableData(const TaskID &task_id, - uint64_t num_returns = 0) { - auto data = std::make_shared(); - data->mutable_task()->mutable_task_spec()->set_task_id(task_id.Binary()); - data->mutable_task()->mutable_task_spec()->set_num_returns(num_returns); - return data; - } - - /// A helper function that compare whether 2 `TaskTableData` objects are equal. - /// Note, this function only compares fields set by `CreateTaskTableData`. - static bool TaskTableDataEqual(const TaskTableData &data1, const TaskTableData &data2) { - const auto &spec1 = data1.task().task_spec(); - const auto &spec2 = data2.task().task_spec(); - return (spec1.task_id() == spec2.task_id() && - spec1.num_returns() == spec2.num_returns()); - } - - static void TestTableLookup(const JobID &job_id, - std::shared_ptr client) { - const auto task_id = RandomTaskId(); - const auto data = CreateTaskTableData(task_id); - - // Check that we added the correct task. - auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(TaskTableDataEqual(*data, d)); - }; - - // Check that the lookup returns the added task. - auto lookup_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(TaskTableDataEqual(*data, d)); - test->Stop(); - }; - - // Check that the lookup does not return an empty entry. - auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - }; - - // Add the task, then do a lookup. - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, - failure_callback)); - // Run the event loop. The loop will only stop if the Lookup callback is - // called (or an assertion failure). - test->Start(); - } - - static void TestTableLookupFailure(const JobID &job_id, - std::shared_ptr client) { - TaskID task_id = RandomTaskId(); - - // Check that the lookup does not return data. - auto lookup_callback = [](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { RAY_CHECK(false); }; - - // Check that the lookup returns an empty entry. - auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id) { - ASSERT_EQ(id, task_id); - test->Stop(); - }; - - // Lookup the task. We have not done any writes, so the key should be empty. - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, - failure_callback)); - // Run the event loop. The loop will only stop if the failure callback is - // called (or an assertion failure). - test->Start(); - } - - static void TestDeleteKeysFromTable( - const JobID &job_id, std::shared_ptr client, - std::vector> &data_vector, bool stop_at_end) { - std::vector ids; - TaskID task_id; - for (auto &data : data_vector) { - task_id = RandomTaskId(); - ids.push_back(task_id); - // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &d) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(TaskTableDataEqual(*data, d)); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); - } - for (const auto &task_id : ids) { - auto task_lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { - ASSERT_EQ(id, task_id); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, - task_lookup_callback, nullptr)); - } - if (ids.size() == 1) { - client->raylet_task_table().Delete(job_id, ids[0]); - } else { - client->raylet_task_table().Delete(job_id, ids); - } - auto expected_failure_callback = [](RedisGcsClient *client, const TaskID &id) { - ASSERT_TRUE(true); - test->IncrementNumCallbacks(); - }; - auto undesired_callback = [](gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { ASSERT_TRUE(false); }; - for (size_t i = 0; i < ids.size(); ++i) { - RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback, - expected_failure_callback)); - } - if (stop_at_end) { - auto stop_callback = [](RedisGcsClient *client, const TaskID &id) { test->Stop(); }; - RAY_CHECK_OK( - client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback)); - } - } - - static void TestTableSubscribeId(const JobID &job_id, - std::shared_ptr client) { - size_t num_modifications = 3; - - // Add a table entry. - TaskID task_id1 = RandomTaskId(); - - // Add a table entry at a second key. - TaskID task_id2 = RandomTaskId(); - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto notification_callback = [task_id2, num_modifications]( - gcs::RedisGcsClient *client, const TaskID &id, - const TaskTableData &data) { - // Check that we only get notifications for the requested key. - ASSERT_EQ(id, task_id2); - // Check that we get notifications in the same order as the writes. - ASSERT_TRUE( - TaskTableDataEqual(data, *CreateTaskTableData(task_id2, test->NumCallbacks()))); - test->IncrementNumCallbacks(); - if (test->NumCallbacks() == num_modifications) { - test->Stop(); - } - }; - - // The failure callback should be called once since both keys start as empty. - bool failure_notification_received = false; - auto failure_callback = [task_id2, &failure_notification_received]( - gcs::RedisGcsClient *client, const TaskID &id) { - ASSERT_EQ(id, task_id2); - // The failure notification should be the first notification received. - ASSERT_EQ(test->NumCallbacks(), 0); - failure_notification_received = true; - }; - - // The callback for subscription success. Once we've subscribed, request - // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [job_id, task_id1, task_id2, - num_modifications](gcs::RedisGcsClient *client) { - // Request notifications for one of the keys. - RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id2, local_node_id, nullptr)); - // Write both keys. We should only receive notifications for the key that - // we requested them for. - for (uint64_t i = 0; i < num_modifications; i++) { - auto data = CreateTaskTableData(task_id1, i); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr)); - } - for (uint64_t i = 0; i < num_modifications; i++) { - auto data = CreateTaskTableData(task_id2, i); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr)); - } - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->raylet_task_table().Subscribe( - job_id, local_node_id, notification_callback, failure_callback, - subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that the failure callback was called since the key was initially - // empty. - ASSERT_TRUE(failure_notification_received); - // Check that we received one notification callback for each write to the - // requested key. - ASSERT_EQ(test->NumCallbacks(), num_modifications); - } - - static void TestTableSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { - // Add a table entry. - const auto task_id = RandomTaskId(); - uint64_t num_modifications = 3; - const auto data = CreateTaskTableData(task_id, 0); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); - - // The failure callback should not be called since all keys are non-empty - // when notifications are requested. - auto failure_callback = [](gcs::RedisGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - }; - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto notification_callback = [task_id, num_modifications](gcs::RedisGcsClient *client, - const TaskID &id, - const TaskTableData &data) { - ASSERT_EQ(id, task_id); - // Check that we only get notifications for the first and last writes, - // since notifications are canceled in between. - if (test->NumCallbacks() == 0) { - ASSERT_TRUE(TaskTableDataEqual(data, *CreateTaskTableData(task_id, 0))); - } else { - ASSERT_TRUE(TaskTableDataEqual( - data, *CreateTaskTableData(task_id, num_modifications - 1))); - } - test->IncrementNumCallbacks(); - if (test->NumCallbacks() == num_modifications - 1) { - test->Stop(); - } - }; - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto subscribe_callback = [job_id, task_id, - num_modifications](gcs::RedisGcsClient *client) { - // Request notifications, then cancel immediately. We should receive a - // notification for the current value at the key. - RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, local_node_id, nullptr)); - RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( - job_id, task_id, local_node_id, nullptr)); - // Write to the key. Since we canceled notifications, we should not receive - // a notification for these writes. - for (uint64_t i = 1; i < num_modifications; i++) { - auto data = CreateTaskTableData(task_id, i); - RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); - } - // Request notifications again. We should receive a notification for the - // current value at the key. - RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - job_id, task_id, local_node_id, nullptr)); - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->raylet_task_table().Subscribe( - job_id, local_node_id, notification_callback, failure_callback, - subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that we received a notification callback for the first and least - // writes to the key, since notifications are canceled in between. - ASSERT_EQ(test->NumCallbacks(), 2); - } -}; - -// Convenient macro to test across {ae, asio} x {regular, chain} x {the tests}. -// Undefined at the end. -#define TEST_TASK_TABLE_MACRO(FIXTURE, TEST) \ - TEST_F(FIXTURE, TEST) { \ - test = this; \ - TaskTableTestHelper::TEST(job_id_, client_); \ - } - -TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableLookup); - -class LogLookupTestHelper { - public: - static void TestLogLookup(const JobID &job_id, - std::shared_ptr client) { - // Append some entries to the log at an object ID. - TaskID task_id = RandomTaskId(); - std::vector node_manager_ids = {"abc", "def", "ghi"}; - for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->set_node_manager_id(node_manager_id); - // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskReconstructionData &d) { - ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); - }; - RAY_CHECK_OK( - client->task_reconstruction_log().Append(job_id, task_id, data, add_callback)); - } - - // Check that lookup returns the added object entries. - auto lookup_callback = [task_id, node_manager_ids]( - gcs::RedisGcsClient *client, const TaskID &id, - const std::vector &data) { - ASSERT_EQ(id, task_id); - for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); - test->IncrementNumCallbacks(); - } - if (test->NumCallbacks() == node_manager_ids.size()) { - test->Stop(); - } - }; - - // Do a lookup at the object ID. - RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); - // Run the event loop. The loop will only stop if the Lookup callback is - // called (or an assertion failure). - test->Start(); - ASSERT_EQ(test->NumCallbacks(), node_manager_ids.size()); - } - - static void TestLogAppendAt(const JobID &job_id, - std::shared_ptr client) { - TaskID task_id = RandomTaskId(); - std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; - for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->set_node_manager_id(node_manager_id); - data_log.push_back(data); - } - - // Check that we added the correct task. - auto failure_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, - const TaskReconstructionData &d) { - ASSERT_EQ(id, task_id); - test->IncrementNumCallbacks(); - }; - - // Will succeed. - RAY_CHECK_OK(client->task_reconstruction_log().Append(job_id, task_id, - data_log.front(), - /*done callback=*/nullptr)); - // Append at index 0 will fail. - RAY_CHECK_OK(client->task_reconstruction_log().AppendAt( - job_id, task_id, data_log[1], - /*done callback=*/nullptr, failure_callback, /*log_length=*/0)); - - // Append at index 2 will fail. - RAY_CHECK_OK(client->task_reconstruction_log().AppendAt( - job_id, task_id, data_log[1], - /*done callback=*/nullptr, failure_callback, /*log_length=*/2)); - - // Append at index 1 will succeed. - RAY_CHECK_OK(client->task_reconstruction_log().AppendAt( - job_id, task_id, data_log[1], - /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); - - auto lookup_callback = [node_manager_ids]( - gcs::RedisGcsClient *client, const TaskID &id, - const std::vector &data) { - std::vector appended_managers; - for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id()); - } - ASSERT_EQ(appended_managers, node_manager_ids); - test->Stop(); - }; - RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); - // Run the event loop. The loop will only stop if the Lookup callback is - // called (or an assertion failure). - test->Start(); - ASSERT_EQ(test->NumCallbacks(), 2); - } -}; - -TEST_F(TestGcsWithAsio, TestLogLookup) { - test = this; - LogLookupTestHelper::TestLogLookup(job_id_, client_); -} - -TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableLookupFailure); - -TEST_F(TestGcsWithAsio, TestLogAppendAt) { - test = this; - LogLookupTestHelper::TestLogAppendAt(job_id_, client_); -} - -class SetTestHelper { - public: - static void TestSet(const JobID &job_id, std::shared_ptr client) { - // Add some entries to the set at an object ID. - ObjectID object_id = ObjectID::FromRandom(); - std::vector managers = {"abc", "def", "ghi"}; - for (auto &manager : managers) { - auto data = std::make_shared(); - data->set_manager(manager); - // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::RedisGcsClient *client, - const ObjectID &id, - const ObjectTableData &d) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager(), d.manager()); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback)); - } - - // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector &data) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data.size(), managers.size()); - test->IncrementNumCallbacks(); - }; - - // Do a lookup at the object ID. - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); - - for (auto &manager : managers) { - auto data = std::make_shared(); - data->set_manager(manager); - // Check that we added the correct object entries. - auto remove_entry_callback = [object_id, data](gcs::RedisGcsClient *client, - const ObjectID &id, - const ObjectTableData &d) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager(), d.manager()); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK( - client->object_table().Remove(job_id, object_id, data, remove_entry_callback)); - } - - // Check that the entries are removed. - auto lookup_callback2 = [object_id, managers]( - gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector &data) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data.size(), 0); - test->IncrementNumCallbacks(); - test->Stop(); - }; - - // Do a lookup at the object ID. - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback2)); - // Run the event loop. The loop will only stop if the Lookup callback is - // called (or an assertion failure). - test->Start(); - ASSERT_EQ(test->NumCallbacks(), managers.size() * 2 + 2); - } - - static void TestDeleteKeysFromSet( - const JobID &job_id, std::shared_ptr client, - std::vector> &data_vector) { - std::vector ids; - ObjectID object_id; - for (auto &data : data_vector) { - object_id = ObjectID::FromRandom(); - ids.push_back(object_id); - // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::RedisGcsClient *client, - const ObjectID &id, - const ObjectTableData &d) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager(), d.manager()); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback)); - } - for (const auto &object_id : ids) { - // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, data_vector]( - gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector &data) { - ASSERT_EQ(id, object_id); - ASSERT_EQ(data.size(), 1); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); - } - if (ids.size() == 1) { - client->object_table().Delete(job_id, ids[0]); - } else { - client->object_table().Delete(job_id, ids); - } - for (const auto &object_id : ids) { - auto lookup_callback = [object_id](gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector &data) { - ASSERT_EQ(id, object_id); - ASSERT_TRUE(data.size() == 0); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); - } - } - - static void TestSetSubscribeAll(const JobID &job_id, - std::shared_ptr client) { - std::vector object_ids; - for (int i = 0; i < 3; i++) { - object_ids.emplace_back(ObjectID::FromRandom()); - } - std::vector managers = {"abc", "def", "ghi"}; - - // Callback for a notification. - auto notification_callback = - [object_ids, managers]( - gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector ¬ifications) { - if (test->NumCallbacks() < 3 * 3) { - ASSERT_EQ(notifications[0].GetGcsChangeMode(), GcsChangeMode::APPEND_OR_ADD); - } else { - ASSERT_EQ(notifications[0].GetGcsChangeMode(), GcsChangeMode::REMOVE); - } - ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); - // Check that we get notifications in the same order as the writes. - for (const auto &entry : notifications[0].GetData()) { - ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); - test->IncrementNumCallbacks(); - } - if (test->NumCallbacks() == object_ids.size() * 3 * 2) { - test->Stop(); - } - }; - - // Callback for subscription success. We are guaranteed to receive - // notifications after this is called. - auto subscribe_callback = [job_id, object_ids, - managers](gcs::RedisGcsClient *client) { - // We have subscribed. Do the writes to the table. - for (size_t i = 0; i < object_ids.size(); i++) { - for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->set_manager(managers[j]); - for (int k = 0; k < 3; k++) { - // Add the same entry several times. - // Expect no notification if the entry already exists. - RAY_CHECK_OK( - client->object_table().Add(job_id, object_ids[i], data, nullptr)); - } - } - } - for (size_t i = 0; i < object_ids.size(); i++) { - for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->set_manager(managers[j]); - for (int k = 0; k < 3; k++) { - // Remove the same entry several times. - // Expect no notification if the entry doesn't exist. - RAY_CHECK_OK( - client->object_table().Remove(job_id, object_ids[i], data, nullptr)); - } - } - } - }; - - // Subscribe to all driver table notifications. Once we have successfully - // subscribed, we will append to the key several times and check that we get - // notified for each. - RAY_CHECK_OK(client->object_table().Subscribe( - job_id, NodeID::Nil(), notification_callback, subscribe_callback)); - - // Run the event loop. The loop will only stop if the registered subscription - // callback is called (or an assertion failure). - test->Start(); - // Check that we received one notification callback for each write. - ASSERT_EQ(test->NumCallbacks(), object_ids.size() * 3 * 2); - } - - static void TestSetSubscribeId(const JobID &job_id, - std::shared_ptr client) { - // Add a set entry. - ObjectID object_id1 = ObjectID::FromRandom(); - std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->set_manager(managers1[0]); - RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr)); - - // Add a set entry at a second key. - ObjectID object_id2 = ObjectID::FromRandom(); - std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->set_manager(managers2[0]); - RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr)); - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto notification_callback = - [object_id2, managers2]( - gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector ¬ifications) { - ASSERT_EQ(notifications[0].GetGcsChangeMode(), GcsChangeMode::APPEND_OR_ADD); - // Check that we only get notifications for the requested key. - ASSERT_EQ(id, object_id2); - // Check that we get notifications in the same order as the writes. - for (const auto &entry : notifications[0].GetData()) { - ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); - test->IncrementNumCallbacks(); - } - if (test->NumCallbacks() == managers2.size()) { - test->Stop(); - } - }; - - // The callback for subscription success. Once we've subscribed, request - // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [job_id, object_id1, object_id2, managers1, - managers2](gcs::RedisGcsClient *client) { - // Request notifications for one of the keys. - RAY_CHECK_OK(client->object_table().RequestNotifications(job_id, object_id2, - local_node_id, nullptr)); - // Write both keys. We should only receive notifications for the key that - // we requested them for. - auto remaining = std::vector(++managers1.begin(), managers1.end()); - for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->set_manager(manager); - RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data, nullptr)); - } - remaining = std::vector(++managers2.begin(), managers2.end()); - for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->set_manager(manager); - RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr)); - } - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->object_table().Subscribe( - job_id, local_node_id, notification_callback, subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that we received one notification callback for each write to the - // requested key. - ASSERT_EQ(test->NumCallbacks(), managers2.size()); - } - - static void TestSetSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { - // Add a set entry. - ObjectID object_id = ObjectID::FromRandom(); - std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_manager(managers[0]); - RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr)); - - // The callback for a notification from the object table. This should only be - // received for the object that we requested notifications for. - auto notification_callback = - [object_id, managers]( - gcs::RedisGcsClient *client, const ObjectID &id, - const std::vector ¬ifications) { - ASSERT_EQ(notifications[0].GetGcsChangeMode(), GcsChangeMode::APPEND_OR_ADD); - ASSERT_EQ(id, object_id); - // Check that we get a duplicate notification for the first write. We get a - // duplicate notification because notifications - // are canceled after the first write, then requested again. - const std::vector &data = notifications[0].GetData(); - if (data.size() == 1) { - // first notification - ASSERT_EQ(data[0].manager(), managers[0]); - test->IncrementNumCallbacks(); - } else { - // second notification - ASSERT_EQ(data.size(), managers.size()); - std::unordered_set managers_set(managers.begin(), - managers.end()); - std::unordered_set data_managers_set; - for (const auto &entry : data) { - data_managers_set.insert(entry.manager()); - test->IncrementNumCallbacks(); - } - ASSERT_EQ(managers_set, data_managers_set); - } - if (test->NumCallbacks() == managers.size() + 1) { - test->Stop(); - } - }; - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto subscribe_callback = [job_id, object_id, managers](gcs::RedisGcsClient *client) { - // Request notifications, then cancel immediately. We should receive a - // notification for the current value at the key. - RAY_CHECK_OK(client->object_table().RequestNotifications(job_id, object_id, - local_node_id, nullptr)); - RAY_CHECK_OK(client->object_table().CancelNotifications(job_id, object_id, - local_node_id, nullptr)); - // Add to the key. Since we canceled notifications, we should not - // receive a notification for these writes. - auto remaining = std::vector(++managers.begin(), managers.end()); - for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->set_manager(manager); - RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr)); - } - // Request notifications again. We should receive a notification for the - // current values at the key. - RAY_CHECK_OK(client->object_table().RequestNotifications(job_id, object_id, - local_node_id, nullptr)); - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->object_table().Subscribe( - job_id, local_node_id, notification_callback, subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that we received a notification callback for the first append to the - // key, then a notification for all of the appends, because we cancel - // notifications in between. - ASSERT_EQ(test->NumCallbacks(), managers.size() + 1); - } -}; - -TEST_F(TestGcsWithAsio, TestSet) { - test = this; - SetTestHelper::TestSet(job_id_, client_); -} - -class LogDeleteTestHelper { - public: - static void TestDeleteKeysFromLog( - const JobID &job_id, std::shared_ptr client, - std::vector> &data_vector) { - std::vector ids; - TaskID task_id; - for (auto &data : data_vector) { - task_id = RandomTaskId(); - ids.push_back(task_id); - // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::RedisGcsClient *client, const TaskID &id, - const TaskReconstructionData &d) { - ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK( - client->task_reconstruction_log().Append(job_id, task_id, data, add_callback)); - } - for (const auto &task_id : ids) { - // Check that lookup returns the added object entries. - auto lookup_callback = [task_id, data_vector]( - gcs::RedisGcsClient *client, const TaskID &id, - const std::vector &data) { - ASSERT_EQ(id, task_id); - ASSERT_EQ(data.size(), 1); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); - } - if (ids.size() == 1) { - client->task_reconstruction_log().Delete(job_id, ids[0]); - } else { - client->task_reconstruction_log().Delete(job_id, ids); - } - for (const auto &task_id : ids) { - auto lookup_callback = [task_id](gcs::RedisGcsClient *client, const TaskID &id, - const std::vector &data) { - ASSERT_EQ(id, task_id); - ASSERT_TRUE(data.size() == 0); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); - } - } -}; - -// Test delete function for keys of Log or Table. -void TestDeleteKeys(const JobID &job_id, std::shared_ptr client) { - // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; - auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { - for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->set_node_manager_id(ObjectID::FromRandom().Hex()); - task_reconstruction_vector.push_back(data); - } - }; - // Test one element case. - AppendTaskReconstructionData(1); - ASSERT_EQ(task_reconstruction_vector.size(), 1); - LogDeleteTestHelper::TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); - // Test the case for more than one elements and less than - // maximum_gcs_deletion_batch_size. - AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() / - 2); - ASSERT_GT(task_reconstruction_vector.size(), 1); - ASSERT_LT(task_reconstruction_vector.size(), - RayConfig::instance().maximum_gcs_deletion_batch_size()); - LogDeleteTestHelper::TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); - // Test the case for more than maximum_gcs_deletion_batch_size. - // The Delete function will split the data into two commands. - AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() / - 2); - ASSERT_GT(task_reconstruction_vector.size(), - RayConfig::instance().maximum_gcs_deletion_batch_size()); - LogDeleteTestHelper::TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); - - // Test delete function for keys of Table. - std::vector> task_vector; - auto AppendTaskData = [&task_vector](size_t add_count) { - for (size_t i = 0; i < add_count; ++i) { - task_vector.push_back(TaskTableTestHelper::CreateTaskTableData(RandomTaskId())); - } - }; - AppendTaskData(1); - ASSERT_EQ(task_vector.size(), 1); - TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, false); - - AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); - ASSERT_GT(task_vector.size(), 1); - ASSERT_LT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, false); - - AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); - ASSERT_GT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TaskTableTestHelper::TestDeleteKeysFromTable(job_id, client, task_vector, true); - - test->Start(); - ASSERT_GT(test->NumCallbacks(), - 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); - - // Test delete function for keys of Set. - std::vector> object_vector; - auto AppendObjectData = [&object_vector](size_t add_count) { - for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->set_manager(ObjectID::FromRandom().Hex()); - object_vector.push_back(data); - } - }; - // Test one element case. - AppendObjectData(1); - ASSERT_EQ(object_vector.size(), 1); - SetTestHelper::TestDeleteKeysFromSet(job_id, client, object_vector); - // Test the case for more than one elements and less than - // maximum_gcs_deletion_batch_size. - AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); - ASSERT_GT(object_vector.size(), 1); - ASSERT_LT(object_vector.size(), - RayConfig::instance().maximum_gcs_deletion_batch_size()); - SetTestHelper::TestDeleteKeysFromSet(job_id, client, object_vector); - // Test the case for more than maximum_gcs_deletion_batch_size. - // The Delete function will split the data into two commands. - AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); - ASSERT_GT(object_vector.size(), - RayConfig::instance().maximum_gcs_deletion_batch_size()); - SetTestHelper::TestDeleteKeysFromSet(job_id, client, object_vector); -} - -TEST_F(TestGcsWithAsio, TestDeleteKey) { - test = this; - TestDeleteKeys(job_id_, client_); -} - -/// A helper class for Log Subscribe testing. -class LogSubscribeTestHelper { - public: - static void TestLogSubscribeAll(const JobID &job_id, - std::shared_ptr client) { - std::vector job_ids; - for (int i = 0; i < 3; i++) { - job_ids.emplace_back(NextJobID()); - } - // Callback for a notification. - auto notification_callback = [job_ids](gcs::RedisGcsClient *client, const JobID &id, - const std::vector data) { - ASSERT_EQ(id, job_ids[test->NumCallbacks()]); - // Check that we get notifications in the same order as the writes. - for (const auto &entry : data) { - ASSERT_EQ(entry.job_id(), job_ids[test->NumCallbacks()].Binary()); - test->IncrementNumCallbacks(); - } - if (test->NumCallbacks() == job_ids.size()) { - test->Stop(); - } - }; - - // Callback for subscription success. We are guaranteed to receive - // notifications after this is called. - auto subscribe_callback = [job_ids](gcs::RedisGcsClient *client) { - // We have subscribed. Do the writes to the table. - for (size_t i = 0; i < job_ids.size(); i++) { - auto job_info_ptr = CreateJobTableData(job_ids[i], false, 0, "localhost", 1); - RAY_CHECK_OK( - client->job_table().Append(job_ids[i], job_ids[i], job_info_ptr, nullptr)); - } - }; - - // Subscribe to all driver table notifications. Once we have successfully - // subscribed, we will append to the key several times and check that we get - // notified for each. - RAY_CHECK_OK(client->job_table().Subscribe( - job_id, NodeID::Nil(), notification_callback, subscribe_callback)); - - // Run the event loop. The loop will only stop if the registered subscription - // callback is called (or an assertion failure). - test->Start(); - // Check that we received one notification callback for each write. - ASSERT_EQ(test->NumCallbacks(), job_ids.size()); - } - - static void TestLogSubscribeId(const JobID &job_id, - std::shared_ptr client) { - // Add a log entry. - JobID job_id1 = NextJobID(); - std::vector job_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->set_job_id(job_ids1[0]); - RAY_CHECK_OK(client->job_table().Append(job_id, job_id1, data1, nullptr)); - - // Add a log entry at a second key. - JobID job_id2 = NextJobID(); - std::vector job_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->set_job_id(job_ids2[0]); - RAY_CHECK_OK(client->job_table().Append(job_id, job_id2, data2, nullptr)); - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto notification_callback = [job_id2, job_ids2]( - gcs::RedisGcsClient *client, const JobID &id, - const std::vector &data) { - // Check that we only get notifications for the requested key. - ASSERT_EQ(id, job_id2); - // Check that we get notifications in the same order as the writes. - for (const auto &entry : data) { - ASSERT_EQ(entry.job_id(), job_ids2[test->NumCallbacks()]); - test->IncrementNumCallbacks(); - } - if (test->NumCallbacks() == job_ids2.size()) { - test->Stop(); - } - }; - - // The callback for subscription success. Once we've subscribed, request - // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [job_id, job_id1, job_id2, job_ids1, - job_ids2](gcs::RedisGcsClient *client) { - // Request notifications for one of the keys. - RAY_CHECK_OK(client->job_table().RequestNotifications(job_id, job_id2, - local_node_id, nullptr)); - // Write both keys. We should only receive notifications for the key that - // we requested them for. - auto remaining = std::vector(++job_ids1.begin(), job_ids1.end()); - for (const auto &job_id_it : remaining) { - auto data = std::make_shared(); - data->set_job_id(job_id_it); - RAY_CHECK_OK(client->job_table().Append(job_id, job_id1, data, nullptr)); - } - remaining = std::vector(++job_ids2.begin(), job_ids2.end()); - for (const auto &job_id_it : remaining) { - auto data = std::make_shared(); - data->set_job_id(job_id_it); - RAY_CHECK_OK(client->job_table().Append(job_id, job_id2, data, nullptr)); - } - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->job_table().Subscribe( - job_id, local_node_id, notification_callback, subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that we received one notification callback for each write to the - // requested key. - ASSERT_EQ(test->NumCallbacks(), job_ids2.size()); - } - - static void TestLogSubscribeCancel(const JobID &job_id, - std::shared_ptr client) { - // Add a log entry. - JobID random_job_id = NextJobID(); - std::vector job_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_job_id(job_ids[0]); - RAY_CHECK_OK(client->job_table().Append(job_id, random_job_id, data, nullptr)); - - // The callback for a notification from the object table. This should only be - // received for the object that we requested notifications for. - auto notification_callback = [random_job_id, job_ids]( - gcs::RedisGcsClient *client, const JobID &id, - const std::vector &data) { - ASSERT_EQ(id, random_job_id); - // Check that we get a duplicate notification for the first write. We get a - // duplicate notification because the log is append-only and notifications - // are canceled after the first write, then requested again. - auto job_ids_copy = job_ids; - job_ids_copy.insert(job_ids_copy.begin(), job_ids_copy.front()); - for (const auto &entry : data) { - ASSERT_EQ(entry.job_id(), job_ids_copy[test->NumCallbacks()]); - test->IncrementNumCallbacks(); - } - if (test->NumCallbacks() == job_ids_copy.size()) { - test->Stop(); - } - }; - - // The callback for a notification from the table. This should only be - // received for keys that we requested notifications for. - auto subscribe_callback = [job_id, random_job_id, - job_ids](gcs::RedisGcsClient *client) { - // Request notifications, then cancel immediately. We should receive a - // notification for the current value at the key. - RAY_CHECK_OK(client->job_table().RequestNotifications(job_id, random_job_id, - local_node_id, nullptr)); - RAY_CHECK_OK(client->job_table().CancelNotifications(job_id, random_job_id, - local_node_id, nullptr)); - // Append to the key. Since we canceled notifications, we should not - // receive a notification for these writes. - auto remaining = std::vector(++job_ids.begin(), job_ids.end()); - for (const auto &remaining_job_id : remaining) { - auto data = std::make_shared(); - data->set_job_id(remaining_job_id); - RAY_CHECK_OK(client->job_table().Append(job_id, random_job_id, data, nullptr)); - } - // Request notifications again. We should receive a notification for the - // current values at the key. - RAY_CHECK_OK(client->job_table().RequestNotifications(job_id, random_job_id, - local_node_id, nullptr)); - }; - - // Subscribe to notifications for this client. This allows us to request and - // receive notifications for specific keys. - RAY_CHECK_OK(client->job_table().Subscribe( - job_id, local_node_id, notification_callback, subscribe_callback)); - // Run the event loop. The loop will only stop if the registered subscription - // callback is called for the requested key. - test->Start(); - // Check that we received a notification callback for the first append to the - // key, then a notification for all of the appends, because we cancel - // notifications in between. - ASSERT_EQ(test->NumCallbacks(), job_ids.size() + 1); - } -}; - -TEST_F(TestGcsWithAsio, TestLogSubscribeAll) { - test = this; - LogSubscribeTestHelper::TestLogSubscribeAll(job_id_, client_); -} - -TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { - test = this; - SetTestHelper::TestSetSubscribeAll(job_id_, client_); -} - -TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableSubscribeId); - -TEST_F(TestGcsWithAsio, TestLogSubscribeId) { - test = this; - LogSubscribeTestHelper::TestLogSubscribeId(job_id_, client_); -} - -TEST_F(TestGcsWithAsio, TestSetSubscribeId) { - test = this; - SetTestHelper::TestSetSubscribeId(job_id_, client_); -} - -TEST_TASK_TABLE_MACRO(TestGcsWithAsio, TestTableSubscribeCancel); - -TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { - test = this; - LogSubscribeTestHelper::TestLogSubscribeCancel(job_id_, client_); -} - -TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { - test = this; - SetTestHelper::TestSetSubscribeCancel(job_id_, client_); -} - -/// A helper class for NodeTable testing. -class NodeTableTestHelper { - public: - static void NodeTableNotification(std::shared_ptr client, - const NodeID &node_id, const GcsNodeInfo &data, - bool is_alive) { - NodeID added_id = local_node_id; - ASSERT_EQ(node_id, added_id); - ASSERT_EQ(NodeID::FromBinary(data.node_id()), added_id); - ASSERT_EQ(data.state() == GcsNodeInfo::ALIVE, is_alive); - - GcsNodeInfo cached_node; - ASSERT_TRUE(client->node_table().GetNode(added_id, &cached_node)); - ASSERT_EQ(NodeID::FromBinary(cached_node.node_id()), added_id); - ASSERT_EQ(cached_node.state() == GcsNodeInfo::ALIVE, is_alive); - } - - static void TestNodeTableConnect(const JobID &job_id, - std::shared_ptr client) { - // Subscribe to a node gets added and removed. The latter - // event will stop the event loop. - RAY_CHECK_OK(client->node_table().SubscribeToNodeChange( - [client](const NodeID &id, const GcsNodeInfo &data) { - // TODO(micafan) - RAY_LOG(INFO) << "Test alive=" << data.state() << " id=" << id; - if (data.state() == GcsNodeInfo::ALIVE) { - NodeTableNotification(client, id, data, true); - test->Stop(); - } - }, - nullptr)); - - // Connect and disconnect to node table. We should receive notifications - // for the addition and removal of our own entry. - GcsNodeInfo local_node_info; - local_node_info.set_node_id(local_node_id.Binary()); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - RAY_CHECK_OK(client->node_table().Connect(local_node_info)); - test->Start(); - } - - static void TestNodeTableDisconnect(const JobID &job_id, - std::shared_ptr client) { - // Register callbacks for when a node gets added and removed. The latter - // event will stop the event loop. - RAY_CHECK_OK(client->node_table().SubscribeToNodeChange( - [client](const NodeID &id, const GcsNodeInfo &data) { - if (data.state() == GcsNodeInfo::ALIVE) { - NodeTableNotification(client, id, data, /*is_insertion=*/true); - // Disconnect from the node table. We should receive a notification - // for the removal of our own entry. - RAY_CHECK_OK(client->node_table().Disconnect()); - } else { - NodeTableNotification(client, id, data, /*is_insertion=*/false); - test->Stop(); - } - }, - nullptr)); - - // Connect to the node table. We should receive notification for the - // addition of our own entry. - GcsNodeInfo local_node_info; - local_node_info.set_node_id(local_node_id.Binary()); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - RAY_CHECK_OK(client->node_table().Connect(local_node_info)); - test->Start(); - } - - static void TestNodeTableImmediateDisconnect( - const JobID &job_id, std::shared_ptr client) { - // Register callbacks for when a node gets added and removed. The latter - // event will stop the event loop. - RAY_CHECK_OK(client->node_table().SubscribeToNodeChange( - [client](const NodeID &id, const GcsNodeInfo &data) { - if (data.state() == GcsNodeInfo::ALIVE) { - NodeTableNotification(client, id, data, true); - } else { - NodeTableNotification(client, id, data, false); - test->Stop(); - } - }, - nullptr)); - // Connect to then immediately disconnect from the node table. We should - // receive notifications for the addition and removal of our own entry. - GcsNodeInfo local_node_info; - local_node_info.set_node_id(local_node_id.Binary()); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - RAY_CHECK_OK(client->node_table().Connect(local_node_info)); - RAY_CHECK_OK(client->node_table().Disconnect()); - test->Start(); - } - - static void TestNodeTableMarkDisconnected(const JobID &job_id, - std::shared_ptr client) { - GcsNodeInfo local_node_info; - local_node_info.set_node_id(local_node_id.Binary()); - local_node_info.set_node_manager_address("127.0.0.1"); - local_node_info.set_node_manager_port(0); - local_node_info.set_object_manager_port(0); - // Connect to the node table to start receiving notifications. - RAY_CHECK_OK(client->node_table().Connect(local_node_info)); - // Mark a different node as dead. - NodeID dead_node_id = NodeID::FromRandom(); - RAY_CHECK_OK(client->node_table().MarkDisconnected(dead_node_id, nullptr)); - // Make sure we only get a notification for the removal of the node we - // marked as dead. - RAY_CHECK_OK(client->node_table().SubscribeToNodeChange( - [dead_node_id](const UniqueID &id, const GcsNodeInfo &data) { - if (data.state() == GcsNodeInfo::DEAD) { - ASSERT_EQ(NodeID::FromBinary(data.node_id()), dead_node_id); - test->Stop(); - } - }, - nullptr)); - test->Start(); - } -}; - -TEST_F(TestGcsWithAsio, TestNodeTableConnect) { - test = this; - NodeTableTestHelper::TestNodeTableConnect(job_id_, client_); -} - -TEST_F(TestGcsWithAsio, TestNodeTableDisconnect) { - test = this; - NodeTableTestHelper::TestNodeTableDisconnect(job_id_, client_); -} - -TEST_F(TestGcsWithAsio, TestNodeTableImmediateDisconnect) { - test = this; - NodeTableTestHelper::TestNodeTableImmediateDisconnect(job_id_, client_); -} - -TEST_F(TestGcsWithAsio, TestNodeTableMarkDisconnected) { - test = this; - NodeTableTestHelper::TestNodeTableMarkDisconnected(job_id_, client_); -} - -class HashTableTestHelper { - public: - static void TestHashTable(const JobID &job_id, - std::shared_ptr client) { - uint64_t expected_count = 14; - NodeID node_id = NodeID::FromRandom(); - // Prepare the first resource map: data_map1. - DynamicResourceTable::DataMap data_map1; - auto cpu_data = std::make_shared(); - cpu_data->set_resource_capacity(100); - data_map1.emplace("CPU", cpu_data); - auto gpu_data = std::make_shared(); - gpu_data->set_resource_capacity(2); - data_map1.emplace("GPU", gpu_data); - // Prepare the second resource map: data_map2 which decreases CPU, - // increases GPU and add a new CUSTOM compared to data_map1. - DynamicResourceTable::DataMap data_map2; - auto data_cpu = std::make_shared(); - data_cpu->set_resource_capacity(50); - data_map2.emplace("CPU", data_cpu); - auto data_gpu = std::make_shared(); - data_gpu->set_resource_capacity(10); - data_map2.emplace("GPU", data_gpu); - auto data_custom = std::make_shared(); - data_custom->set_resource_capacity(2); - data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->set_resource_capacity(50); - // This is a common comparison function for the test. - auto compare_test = [](const DynamicResourceTable::DataMap &data1, - const DynamicResourceTable::DataMap &data2) { - ASSERT_EQ(data1.size(), data2.size()); - for (const auto &data : data1) { - auto iter = data2.find(data.first); - ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); - } - }; - auto subscribe_callback = [](RedisGcsClient *client) { - ASSERT_TRUE(true); - test->IncrementNumCallbacks(); - }; - auto notification_callback = - [data_map1, data_map2, compare_test, expected_count]( - RedisGcsClient *client, const NodeID &id, - const std::vector &result) { - RAY_CHECK(result.size() == 1); - const ResourceChangeNotification ¬ification = result.back(); - if (notification.IsRemoved()) { - ASSERT_EQ(notification.GetData().size(), 2); - ASSERT_TRUE(notification.GetData().find("GPU") != - notification.GetData().end()); - ASSERT_TRUE( - notification.GetData().find("CUSTOM") != notification.GetData().end() || - notification.GetData().find("CPU") != notification.GetData().end()); - // The key "None-Existent" will not appear in the notification. - } else { - if (notification.GetData().size() == 2) { - compare_test(data_map1, notification.GetData()); - } else if (notification.GetData().size() == 3) { - compare_test(data_map2, notification.GetData()); - } else { - ASSERT_TRUE(false); - } - } - test->IncrementNumCallbacks(); - // It is not sure which of the notification or lookup callback will come first. - if (test->NumCallbacks() == expected_count) { - test->Stop(); - } - }; - // Step 0: Subscribe the change of the hash table. - RAY_CHECK_OK(client->resource_table().Subscribe( - job_id, NodeID::Nil(), notification_callback, subscribe_callback)); - RAY_CHECK_OK(client->resource_table().RequestNotifications(job_id, node_id, - local_node_id, nullptr)); - - // Step 1: Add elements to the hash table. - auto update_callback1 = [data_map1, compare_test]( - RedisGcsClient *client, const NodeID &id, - const DynamicResourceTable::DataMap &callback_data) { - compare_test(data_map1, callback_data); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK( - client->resource_table().Update(job_id, node_id, data_map1, update_callback1)); - auto lookup_callback1 = [data_map1, compare_test]( - RedisGcsClient *client, const NodeID &id, - const DynamicResourceTable::DataMap &callback_data) { - compare_test(data_map1, callback_data); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->resource_table().Lookup(job_id, node_id, lookup_callback1)); - - // Step 2: Decrease one element, increase one and add a new one. - RAY_CHECK_OK(client->resource_table().Update(job_id, node_id, data_map2, nullptr)); - auto lookup_callback2 = [data_map2, compare_test]( - RedisGcsClient *client, const NodeID &id, - const DynamicResourceTable::DataMap &callback_data) { - compare_test(data_map2, callback_data); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->resource_table().Lookup(job_id, node_id, lookup_callback2)); - std::vector delete_keys({"GPU", "CUSTOM", "None-Existent"}); - auto remove_callback = [delete_keys](RedisGcsClient *client, const NodeID &id, - const std::vector &callback_data) { - for (size_t i = 0; i < callback_data.size(); ++i) { - // All deleting keys exist in this argument even if the key doesn't exist. - ASSERT_EQ(callback_data[i], delete_keys[i]); - } - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->resource_table().RemoveEntries(job_id, node_id, delete_keys, - remove_callback)); - DynamicResourceTable::DataMap data_map3(data_map2); - data_map3.erase("GPU"); - data_map3.erase("CUSTOM"); - auto lookup_callback3 = [data_map3, compare_test]( - RedisGcsClient *client, const NodeID &id, - const DynamicResourceTable::DataMap &callback_data) { - compare_test(data_map3, callback_data); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->resource_table().Lookup(job_id, node_id, lookup_callback3)); - - // Step 3: Reset the the resources to data_map1. - RAY_CHECK_OK( - client->resource_table().Update(job_id, node_id, data_map1, update_callback1)); - auto lookup_callback4 = [data_map1, compare_test]( - RedisGcsClient *client, const NodeID &id, - const DynamicResourceTable::DataMap &callback_data) { - compare_test(data_map1, callback_data); - test->IncrementNumCallbacks(); - }; - RAY_CHECK_OK(client->resource_table().Lookup(job_id, node_id, lookup_callback4)); - - // Step 4: Removing all elements will remove the home Hash table from GCS. - RAY_CHECK_OK(client->resource_table().RemoveEntries( - job_id, node_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); - auto lookup_callback5 = [expected_count]( - RedisGcsClient *client, const NodeID &id, - const DynamicResourceTable::DataMap &callback_data) { - ASSERT_EQ(callback_data.size(), 0); - test->IncrementNumCallbacks(); - // It is not sure which of notification or lookup callback will come first. - if (test->NumCallbacks() == expected_count) { - test->Stop(); - } - }; - RAY_CHECK_OK(client->resource_table().Lookup(job_id, node_id, lookup_callback5)); - test->Start(); - ASSERT_EQ(test->NumCallbacks(), expected_count); - } -}; - -TEST_F(TestGcsWithAsio, TestHashTable) { - test = this; - HashTableTestHelper::TestHashTable(job_id_, client_); -} - -#undef TEST_TASK_TABLE_MACRO - -} // namespace gcs -} // namespace ray - -int main(int argc, char **argv) { - InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], - ray::RayLogLevel::INFO, - /*log_dir=*/""); - ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 4); - ray::TEST_REDIS_SERVER_EXEC_PATH = argv[1]; - ray::TEST_REDIS_CLIENT_EXEC_PATH = argv[2]; - ray::TEST_REDIS_MODULE_LIBRARY_PATH = argv[3]; - return RUN_ALL_TESTS(); -} diff --git a/src/ray/gcs/test/redis_job_info_accessor_test.cc b/src/ray/gcs/test/redis_job_info_accessor_test.cc deleted file mode 100644 index 31dc69393..000000000 --- a/src/ray/gcs/test/redis_job_info_accessor_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "gtest/gtest.h" -#include "ray/common/test_util.h" -#include "ray/gcs/pb_util.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/gcs/test/accessor_test_base.h" - -namespace ray { - -namespace gcs { - -class RedisJobInfoAccessorTest : public AccessorTestBase { - protected: - virtual void GenTestData() { - for (size_t i = 0; i < total_job_number_; ++i) { - JobID job_id = JobID::FromInt(i); - std::shared_ptr job_data_ptr = - CreateJobTableData(job_id, /*is_dead*/ false, /*timestamp*/ 1, - /*driver_ip_address*/ "", /*driver_pid*/ i); - id_to_data_[job_id] = job_data_ptr; - } - } - std::atomic subscribe_pending_count_{0}; - size_t total_job_number_{100}; -}; - -TEST_F(RedisJobInfoAccessorTest, AddAndSubscribe) { - JobInfoAccessor &job_accessor = gcs_client_->Jobs(); - // SubscribeAll - auto on_subscribe = [this](const JobID &job_id, const JobTableData &data) { - const auto it = id_to_data_.find(job_id); - RAY_CHECK(it != id_to_data_.end()); - if (data.is_dead()) { - --subscribe_pending_count_; - } - }; - - auto on_done = [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - }; - - ++pending_count_; - RAY_CHECK_OK(job_accessor.AsyncSubscribeAll(on_subscribe, on_done)); - - WaitPendingDone(wait_pending_timeout_); - WaitPendingDone(subscribe_pending_count_, wait_pending_timeout_); - - // Register - for (const auto &item : id_to_data_) { - ++pending_count_; - RAY_CHECK_OK(job_accessor.AsyncAdd(item.second, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - })); - } - WaitPendingDone(wait_pending_timeout_); - WaitPendingDone(subscribe_pending_count_, wait_pending_timeout_); - - // Update - for (auto &item : id_to_data_) { - ++pending_count_; - ++subscribe_pending_count_; - RAY_CHECK_OK(job_accessor.AsyncMarkFinished(item.first, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - })); - } - WaitPendingDone(wait_pending_timeout_); - WaitPendingDone(subscribe_pending_count_, wait_pending_timeout_); -} - -} // namespace gcs - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 4); - ray::TEST_REDIS_SERVER_EXEC_PATH = argv[1]; - ray::TEST_REDIS_CLIENT_EXEC_PATH = argv[2]; - ray::TEST_REDIS_MODULE_LIBRARY_PATH = argv[3]; - return RUN_ALL_TESTS(); -} diff --git a/src/ray/gcs/test/redis_node_info_accessor_test.cc b/src/ray/gcs/test/redis_node_info_accessor_test.cc deleted file mode 100644 index e4435184e..000000000 --- a/src/ray/gcs/test/redis_node_info_accessor_test.cc +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "gtest/gtest.h" -#include "ray/gcs/redis_accessor.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/gcs/test/accessor_test_base.h" - -namespace ray { - -namespace gcs { - -class NodeDynamicResourceTest : public AccessorTestBase { - protected: - typedef NodeResourceInfoAccessor::ResourceMap ResourceMap; - virtual void GenTestData() { - for (size_t node_index = 0; node_index < node_number_; ++node_index) { - NodeID id = NodeID::FromRandom(); - ResourceMap resource_map; - for (size_t rs_index = 0; rs_index < resource_type_number_; ++rs_index) { - std::shared_ptr rs_data = - std::make_shared(); - rs_data->set_resource_capacity(rs_index); - std::string resource_name = std::to_string(rs_index); - resource_map[resource_name] = rs_data; - if (resource_to_delete_.empty()) { - resource_to_delete_.emplace_back(resource_name); - } - } - id_to_resource_map_[id] = std::move(resource_map); - } - } - - std::unordered_map id_to_resource_map_; - - size_t node_number_{100}; - size_t resource_type_number_{5}; - - std::vector resource_to_delete_; - - std::atomic sub_pending_count_{0}; - std::atomic do_sub_pending_count_{0}; -}; - -TEST_F(NodeDynamicResourceTest, UpdateAndGet) { - NodeResourceInfoAccessor &node_resource_accessor = gcs_client_->NodeResources(); - for (const auto &node_rs : id_to_resource_map_) { - ++pending_count_; - const NodeID &id = node_rs.first; - // Update - Status status = node_resource_accessor.AsyncUpdateResources( - node_rs.first, node_rs.second, - [this, &node_resource_accessor, id](Status status) { - RAY_CHECK_OK(status); - auto get_callback = [this, id](Status status, - const boost::optional &result) { - --pending_count_; - RAY_CHECK_OK(status); - const auto it = id_to_resource_map_.find(id); - ASSERT_TRUE(result); - ASSERT_EQ(it->second.size(), result->size()); - }; - // Get - status = node_resource_accessor.AsyncGetResources(id, get_callback); - RAY_CHECK_OK(status); - }); - } - WaitPendingDone(wait_pending_timeout_); -} - -TEST_F(NodeDynamicResourceTest, Delete) { - NodeResourceInfoAccessor &node_resource_accessor = gcs_client_->NodeResources(); - for (const auto &node_rs : id_to_resource_map_) { - ++pending_count_; - // Update - Status status = node_resource_accessor.AsyncUpdateResources( - node_rs.first, node_rs.second, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - }); - } - WaitPendingDone(wait_pending_timeout_); - - for (const auto &node_rs : id_to_resource_map_) { - ++pending_count_; - const NodeID &id = node_rs.first; - // Delete - Status status = node_resource_accessor.AsyncDeleteResources( - id, resource_to_delete_, [this, &node_resource_accessor, id](Status status) { - RAY_CHECK_OK(status); - // Get - status = node_resource_accessor.AsyncGetResources( - id, [this, id](Status status, const boost::optional &result) { - --pending_count_; - RAY_CHECK_OK(status); - const auto it = id_to_resource_map_.find(id); - ASSERT_TRUE(result); - ASSERT_EQ(it->second.size() - resource_to_delete_.size(), result->size()); - }); - }); - } - WaitPendingDone(wait_pending_timeout_); -} - -TEST_F(NodeDynamicResourceTest, Subscribe) { - NodeResourceInfoAccessor &node_resource_accessor = gcs_client_->NodeResources(); - for (const auto &node_rs : id_to_resource_map_) { - ++pending_count_; - // Update - Status status = node_resource_accessor.AsyncUpdateResources( - node_rs.first, node_rs.second, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - }); - } - WaitPendingDone(wait_pending_timeout_); - - auto subscribe = [this](const rpc::NodeResourceChange ¬ification) { - auto id = NodeID::FromBinary(notification.node_id()); - RAY_LOG(INFO) << "receive client id=" << id; - auto it = id_to_resource_map_.find(id); - ASSERT_TRUE(it != id_to_resource_map_.end()); - if (0 == notification.deleted_resources_size()) { - ASSERT_EQ(notification.updated_resources_size(), it->second.size()); - } else { - ASSERT_EQ(notification.deleted_resources_size(), resource_to_delete_.size()); - } - --sub_pending_count_; - }; - - auto done = [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - }; - - // Subscribe - ++pending_count_; - Status status = node_resource_accessor.AsyncSubscribeToResources(subscribe, done); - RAY_CHECK_OK(status); - - for (const auto &node_rs : id_to_resource_map_) { - // Delete - ++pending_count_; - ++sub_pending_count_; - Status status = node_resource_accessor.AsyncDeleteResources( - node_rs.first, resource_to_delete_, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - }); - RAY_CHECK_OK(status); - } - - WaitPendingDone(wait_pending_timeout_); - WaitPendingDone(sub_pending_count_, wait_pending_timeout_); -} - -} // namespace gcs - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 4); - ray::TEST_REDIS_SERVER_EXEC_PATH = argv[1]; - ray::TEST_REDIS_CLIENT_EXEC_PATH = argv[2]; - ray::TEST_REDIS_MODULE_LIBRARY_PATH = argv[3]; - return RUN_ALL_TESTS(); -} diff --git a/src/ray/gcs/test/redis_object_info_accessor_test.cc b/src/ray/gcs/test/redis_object_info_accessor_test.cc deleted file mode 100644 index bbe310b97..000000000 --- a/src/ray/gcs/test/redis_object_info_accessor_test.cc +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "gtest/gtest.h" -#include "ray/common/test_util.h" -#include "ray/gcs/redis_accessor.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/gcs/test/accessor_test_base.h" - -namespace ray { - -namespace gcs { - -class RedisObjectInfoAccessorTest : public AccessorTestBase { - protected: - void GenTestData() { - for (size_t i = 0; i < object_count_; ++i) { - ObjectVector object_vec; - for (size_t j = 0; j < copy_count_; ++j) { - auto object = std::make_shared(); - NodeID node_id = NodeID::FromRandom(); - object->set_manager(node_id.Binary()); - object_vec.emplace_back(std::move(object)); - } - ObjectID id = ObjectID::FromRandom(); - object_id_to_data_[id] = object_vec; - } - } - - typedef std::vector> ObjectVector; - std::unordered_map object_id_to_data_; - - size_t object_count_{100}; - size_t copy_count_{5}; -}; - -TEST_F(RedisObjectInfoAccessorTest, TestGetAddRemove) { - ObjectInfoAccessor &object_accessor = gcs_client_->Objects(); - // add && get - // add - for (const auto &elem : object_id_to_data_) { - for (const auto &item : elem.second) { - ++pending_count_; - NodeID node_id = NodeID::FromBinary(item->manager()); - RAY_CHECK_OK( - object_accessor.AsyncAddLocation(elem.first, node_id, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - })); - } - } - WaitPendingDone(wait_pending_timeout_); - // get - for (const auto &elem : object_id_to_data_) { - ++pending_count_; - size_t total_size = elem.second.size(); - RAY_CHECK_OK(object_accessor.AsyncGetLocations( - elem.first, - [this, total_size](Status status, - const boost::optional &result) { - RAY_CHECK_OK(status); - ASSERT_EQ(total_size, result->locations().size()); - --pending_count_; - })); - } - WaitPendingDone(wait_pending_timeout_); - - RAY_LOG(INFO) << "Case Add && Get done."; - - // subscribe && delete - // subscribe - std::atomic sub_pending_count(0); - auto subscribe = [this, &sub_pending_count]( - const ObjectID &object_id, - const std::vector &result) { - const auto it = object_id_to_data_.find(object_id); - ASSERT_TRUE(it != object_id_to_data_.end()); - static size_t response_count = 1; - size_t cur_count = response_count <= object_count_ ? copy_count_ : 1; - ASSERT_EQ(result.size(), cur_count); - bool change_mode = response_count <= object_count_; - for (const auto &res : result) { - ASSERT_EQ(change_mode, res.is_add()); - } - ++response_count; - --sub_pending_count; - }; - for (const auto &elem : object_id_to_data_) { - ++pending_count_; - ++sub_pending_count; - RAY_CHECK_OK(object_accessor.AsyncSubscribeToLocations(elem.first, subscribe, - [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - })); - } - WaitPendingDone(wait_pending_timeout_); - WaitPendingDone(sub_pending_count, wait_pending_timeout_); - // delete - for (const auto &elem : object_id_to_data_) { - ++pending_count_; - ++sub_pending_count; - const ObjectVector &object_vec = elem.second; - NodeID node_id = NodeID::FromBinary(object_vec[0]->manager()); - RAY_CHECK_OK( - object_accessor.AsyncRemoveLocation(elem.first, node_id, [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - })); - } - WaitPendingDone(wait_pending_timeout_); - WaitPendingDone(sub_pending_count, wait_pending_timeout_); - // get - for (const auto &elem : object_id_to_data_) { - ++pending_count_; - size_t total_size = elem.second.size(); - RAY_CHECK_OK(object_accessor.AsyncGetLocations( - elem.first, - [this, total_size](Status status, - const boost::optional &result) { - RAY_CHECK_OK(status); - ASSERT_EQ(total_size - 1, result->locations().size()); - --pending_count_; - })); - } - WaitPendingDone(wait_pending_timeout_); - - RAY_LOG(INFO) << "Case Subscribe && Delete done."; -} - -} // namespace gcs - -} // namespace ray - -int main(int argc, char **argv) { - InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], - ray::RayLogLevel::INFO, - /*log_dir=*/""); - ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 4); - ray::TEST_REDIS_SERVER_EXEC_PATH = argv[1]; - ray::TEST_REDIS_CLIENT_EXEC_PATH = argv[2]; - ray::TEST_REDIS_MODULE_LIBRARY_PATH = argv[3]; - return RUN_ALL_TESTS(); -} diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index 7133d1e94..3ce15882b 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -22,7 +22,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/gcs_client.h" #include "ray/object_manager/format/object_manager_generated.h" namespace ray { diff --git a/src/ray/object_manager/ownership_based_object_directory.h b/src/ray/object_manager/ownership_based_object_directory.h index 68d5140b9..5b07f7999 100644 --- a/src/ray/object_manager/ownership_based_object_directory.h +++ b/src/ray/object_manager/ownership_based_object_directory.h @@ -23,7 +23,7 @@ #include "absl/container/flat_hash_map.h" #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/gcs_client.h" #include "ray/object_manager/format/object_manager_generated.h" #include "ray/object_manager/object_directory.h" #include "ray/rpc/worker/core_worker_client.h" diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 83daf8297..018bc357b 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -20,6 +20,7 @@ #include "gtest/gtest.h" #include "ray/common/status.h" #include "ray/common/test_util.h" +#include "ray/gcs/gcs_client/service_based_gcs_client.h" #include "ray/object_manager/object_manager.h" #include "ray/util/filesystem.h" #include "src/ray/protobuf/common.pb.h" @@ -32,10 +33,24 @@ namespace ray { using rpc::GcsNodeInfo; -static inline void flushall_redis(void) { +static inline bool flushall_redis(void) { redisContext *context = redisConnect("127.0.0.1", 6379); + if (context == nullptr || context->err) { + return false; + } freeReplyObject(redisCommand(context, "FLUSHALL")); + freeReplyObject(redisCommand(context, "SET NumRedisShards 1")); + freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380")); redisFree(context); + + redisContext *shard_context = redisConnect("127.0.0.1", 6380); + if (shard_context == nullptr || shard_context->err) { + return false; + } + freeReplyObject(redisCommand(shard_context, "FLUSHALL")); + redisFree(shard_context); + + return true; } int64_t current_time_ms() { @@ -71,6 +86,7 @@ class MockServer { node_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->Nodes().RegisterSelf(node_info, nullptr); + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); return status; } @@ -85,7 +101,7 @@ class MockServer { class TestObjectManagerBase : public ::testing::Test { public: void SetUp() { - flushall_redis(); + WaitForCondition(flushall_redis, 7000); // start store socket_name_1 = TestSetupUtil::StartObjectStore(); @@ -96,9 +112,10 @@ class TestObjectManagerBase : public ::testing::Test { int push_timeout_ms = 10000; // start first server + gcs_server_socket_name_ = TestSetupUtil::StartGcsServer("127.0.0.1"); gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", - /*is_test_client=*/true); - gcs_client_1 = std::make_shared(client_options); + /*is_test_client=*/false); + gcs_client_1 = std::make_shared(client_options); RAY_CHECK_OK(gcs_client_1->Connect(main_service)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = socket_name_1; @@ -110,7 +127,7 @@ class TestObjectManagerBase : public ::testing::Test { server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::make_shared(client_options); + gcs_client_2 = std::make_shared(client_options); RAY_CHECK_OK(gcs_client_2->Connect(main_service)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = socket_name_2; @@ -139,6 +156,10 @@ class TestObjectManagerBase : public ::testing::Test { TestSetupUtil::StopObjectStore(socket_name_1); TestSetupUtil::StopObjectStore(socket_name_2); + + if (!gcs_server_socket_name_.empty()) { + TestSetupUtil::StopGcsServer(gcs_server_socket_name_); + } } ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { @@ -172,6 +193,7 @@ class TestObjectManagerBase : public ::testing::Test { std::vector v1; std::vector v2; + std::string gcs_server_socket_name_; std::string socket_name_1; std::string socket_name_2; }; @@ -421,5 +443,6 @@ TEST_F(StressTestObjectManager, StartStressTestObjectManager) { int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); ray::TEST_STORE_EXEC_PATH = std::string(argv[1]); + ray::TEST_GCS_SERVER_EXEC_PATH = std::string(argv[2]); return RUN_ALL_TESTS(); } diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 48fa9a65a..9fbecc4ca 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -20,6 +20,7 @@ #include "gtest/gtest.h" #include "ray/common/status.h" #include "ray/common/test_util.h" +#include "ray/gcs/gcs_client/service_based_gcs_client.h" #include "ray/util/filesystem.h" #include "src/ray/protobuf/common.pb.h" @@ -38,6 +39,8 @@ using rpc::GcsNodeInfo; static inline void flushall_redis(void) { redisContext *context = redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); + freeReplyObject(redisCommand(context, "SET NumRedisShards 1")); + freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380")); redisFree(context); } @@ -91,9 +94,10 @@ class TestObjectManagerBase : public ::testing::Test { push_timeout_ms = 1500; // start first server + gcs_server_socket_name_ = TestSetupUtil::StartGcsServer("127.0.0.1"); gcs::GcsClientOptions client_options("127.0.0.1", 6379, /*password*/ "", /*is_test_client=*/true); - gcs_client_1 = std::make_shared(client_options); + gcs_client_1 = std::make_shared(client_options); RAY_CHECK_OK(gcs_client_1->Connect(main_service)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = socket_name_1; @@ -105,7 +109,7 @@ class TestObjectManagerBase : public ::testing::Test { server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::make_shared(client_options); + gcs_client_2 = std::make_shared(client_options); RAY_CHECK_OK(gcs_client_2->Connect(main_service)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = socket_name_2; @@ -134,6 +138,10 @@ class TestObjectManagerBase : public ::testing::Test { TestSetupUtil::StopObjectStore(socket_name_1); TestSetupUtil::StopObjectStore(socket_name_2); + + if (!gcs_server_socket_name_.empty()) { + TestSetupUtil::StopGcsServer(gcs_server_socket_name_); + } } ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { @@ -171,6 +179,7 @@ class TestObjectManagerBase : public ::testing::Test { std::vector v1; std::vector v2; + std::string gcs_server_socket_name_; std::string socket_name_1; std::string socket_name_2; @@ -482,5 +491,6 @@ int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); ray::TEST_STORE_EXEC_PATH = std::string(argv[1]); wait_timeout_ms = std::stoi(std::string(argv[2])); + ray::TEST_GCS_SERVER_EXEC_PATH = std::string(argv[3]); return RUN_ALL_TESTS(); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 52e9354a2..c8dcae7f9 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -306,7 +306,7 @@ ray::Status NodeManager::RegisterGcs() { // node failure. These workers can be identified by comparing the raylet_id // in their rpc::Address to the ID of a failed raylet. const auto &worker_failure_handler = - [this](const WorkerID &id, const gcs::WorkerTableData &worker_failure_data) { + [this](const WorkerID &id, const rpc::WorkerTableData &worker_failure_data) { HandleUnexpectedWorkerFailure(worker_failure_data.worker_address()); }; RAY_CHECK_OK(gcs_client_->Workers().AsyncSubscribeToWorkerFailures( @@ -1984,8 +1984,8 @@ void NodeManager::ProcessSetResourceRequest( RAY_CHECK_OK(gcs_client_->NodeResources().AsyncDeleteResources( node_id, {resource_name}, nullptr)); } else { - std::unordered_map> data_map; - auto resource_table_data = std::make_shared(); + std::unordered_map> data_map; + auto resource_table_data = std::make_shared(); resource_table_data->set_resource_capacity(capacity); data_map.emplace(resource_name, resource_table_data); RAY_CHECK_OK( diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 6336f3160..3a683a952 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -139,10 +139,10 @@ ray::Status Raylet::RegisterGcs() { // Add resource information. const NodeManagerConfig &node_manager_config = node_manager_.GetInitialConfig(); - std::unordered_map> resources; + std::unordered_map> resources; for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - auto resource = std::make_shared(); + auto resource = std::make_shared(); resource->set_resource_capacity(resource_pair.second); resources.emplace(resource_pair.first, resource); } diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index 2300fb1c2..e221faffe 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -20,7 +20,7 @@ #include #include "ray/common/id.h" -#include "ray/gcs/tables.h" +#include "ray/gcs/gcs_client.h" #include "ray/object_manager/object_directory.h" namespace ray { diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 17b2f46d6..199e4d51e 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -21,9 +21,11 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "ray/gcs/callback.h" -#include "ray/gcs/redis_accessor.h" +#include "ray/gcs/gcs_client/service_based_accessor.h" +#include "ray/gcs/gcs_client/service_based_gcs_client.h" #include "ray/object_manager/object_directory.h" #include "ray/raylet/format/node_manager_generated.h" +#include "ray/raylet/reconstruction_policy.h" namespace ray { @@ -97,17 +99,18 @@ class MockObjectDirectory : public ObjectDirectoryInterface { std::unordered_map> locations_; }; -class MockNodeInfoAccessor : public gcs::RedisNodeInfoAccessor { +class MockNodeInfoAccessor : public gcs::ServiceBasedNodeInfoAccessor { public: - MockNodeInfoAccessor(gcs::RedisGcsClient *client) - : gcs::RedisNodeInfoAccessor(client) {} + MockNodeInfoAccessor(gcs::ServiceBasedGcsClient *client) + : gcs::ServiceBasedNodeInfoAccessor(client) {} bool IsRemoved(const NodeID &node_id) const override { return false; } }; -class MockTaskInfoAccessor : public gcs::RedisTaskInfoAccessor { +class MockTaskInfoAccessor : public gcs::ServiceBasedTaskInfoAccessor { public: - MockTaskInfoAccessor(gcs::RedisGcsClient *client) : RedisTaskInfoAccessor(client) {} + MockTaskInfoAccessor(gcs::ServiceBasedGcsClient *client) + : ServiceBasedTaskInfoAccessor(client) {} Status AsyncSubscribeTaskLease( const TaskID &task_id, @@ -180,9 +183,9 @@ class MockTaskInfoAccessor : public gcs::RedisTaskInfoAccessor { task_reconstruction_log_; }; -class MockGcs : public gcs::RedisGcsClient { +class MockGcs : public gcs::ServiceBasedGcsClient { public: - MockGcs() : gcs::RedisGcsClient(gcs::GcsClientOptions("", 0, "")){}; + MockGcs() : gcs::ServiceBasedGcsClient(gcs::GcsClientOptions("", 0, "")){}; void Init(gcs::TaskInfoAccessor *task_accessor, gcs::NodeInfoAccessor *node_accessor) { task_accessor_.reset(task_accessor); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index d35d644e7..75654698f 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -17,7 +17,6 @@ // clang-format off #include "ray/common/id.h" #include "ray/common/task/task.h" -#include "ray/gcs/redis_gcs_client.h" #include "ray/object_manager/object_manager.h" #include "ray/raylet/reconstruction_policy.h" // clang-format on diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 99f6d5622..d65b0aced 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -21,8 +21,6 @@ #include "gtest/gtest.h" #include "ray/common/task/task_util.h" #include "ray/common/test_util.h" -#include "ray/gcs/redis_accessor.h" -#include "ray/gcs/redis_gcs_client.h" namespace ray { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 62dfd20a5..66d4b94c7 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -26,7 +26,7 @@ #include "ray/common/client_connection.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" -#include "ray/gcs/redis_gcs_client.h" +#include "ray/gcs/gcs_client.h" #include "ray/raylet/worker.h" namespace ray { diff --git a/src/ray/test/run_object_manager_tests.sh b/src/ray/test/run_object_manager_tests.sh index 641c53050..4fba8d2dc 100755 --- a/src/ray/test/run_object_manager_tests.sh +++ b/src/ray/test/run_object_manager_tests.sh @@ -25,18 +25,22 @@ fi REDIS_MODULE="./bazel-bin/libray_redis_module.so" LOAD_MODULE_ARGS=(--loadmodule "${REDIS_MODULE}") STORE_EXEC="./bazel-bin/plasma_store_server" +GCS_SERVER_EXEC="./bazel-bin/gcs_server" # Allow cleanup commands to fail. bazel run //:redis-cli -- -p 6379 shutdown || true +bazel run //:redis-cli -- -p 6380 shutdown || true sleep 1s bazel run //:redis-server -- --loglevel warning "${LOAD_MODULE_ARGS[@]}" --port 6379 & +bazel run //:redis-server -- --loglevel warning "${LOAD_MODULE_ARGS[@]}" --port 6380 & sleep 1s # Run tests. -./bazel-bin/object_manager_stress_test $STORE_EXEC +./bazel-bin/object_manager_stress_test $STORE_EXEC $GCS_SERVER_EXEC sleep 1s # Use timeout=1000ms for the Wait tests. -./bazel-bin/object_manager_test $STORE_EXEC 1000 +./bazel-bin/object_manager_test $STORE_EXEC 1000 $GCS_SERVER_EXEC bazel run //:redis-cli -- -p 6379 shutdown +bazel run //:redis-cli -- -p 6380 shutdown sleep 1s # Include raylet integration test once it's ready. From 670d083a56381510538aaa36d05b486d200917f7 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 23 Dec 2020 11:29:58 -0500 Subject: [PATCH 74/88] [RLlib] Fix broken unity3d_env import in example server script. (#13040) --- rllib/examples/serving/unity3d_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/examples/serving/unity3d_server.py b/rllib/examples/serving/unity3d_server.py index 04f0b8a05..0a39d3e6b 100755 --- a/rllib/examples/serving/unity3d_server.py +++ b/rllib/examples/serving/unity3d_server.py @@ -34,8 +34,8 @@ import ray from ray.tune import register_env from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.env.policy_server_input import PolicyServerInput +from ray.rllib.env.unity3d_env import Unity3DEnv from ray.rllib.examples.env.random_env import RandomMultiAgentEnv -from ray.rllib.examples.env.unity3d_env import Unity3DEnv SERVER_ADDRESS = "localhost" SERVER_PORT = 9900 From 1e74187179ca7d24564d8f5601f9a38f7e447227 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 23 Dec 2020 11:30:50 -0500 Subject: [PATCH 75/88] [RLlib] TorchPolicies: Accessing "infos" dict in train_batch causes `TypeError`. (#13039) --- rllib/utils/torch_ops.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 53c26dc4d..ce6c86a16 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -63,11 +63,20 @@ def convert_to_torch_tensor(x, device=None): return RepeatedValues( tree.map_structure(mapping, item.values), item.lengths, item.max_len) - # Non-writable numpy-arrays will cause PyTorch warning. - if isinstance(item, np.ndarray) and item.flags.writeable is False: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") + # Numpy arrays. + if isinstance(item, np.ndarray): + # np.object_ type (e.g. info dicts in train batch): leave as-is. + if item.dtype == np.object_: + return item + # Non-writable numpy-arrays will cause PyTorch warning. + elif item.flags.writeable is False: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + tensor = torch.from_numpy(item) + # Already numpy: Wrap as torch tensor. + else: tensor = torch.from_numpy(item) + # Everything else: Convert to numpy, then wrap as torch tensor. else: tensor = torch.from_numpy(np.asarray(item)) # Floatify all float64 tensors. From d37e2c3a20f8b601f84a74903f8a0137e0ab5afa Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Wed, 23 Dec 2020 18:43:34 +0200 Subject: [PATCH 76/88] [joblib] Fix flaky joblib test. (#13046) --- python/ray/tests/mnist_784_100_samples.pkl | Bin 0 -> 627706 bytes python/ray/tests/test_joblib.py | 20 +++++++------------- 2 files changed, 7 insertions(+), 13 deletions(-) create mode 100644 python/ray/tests/mnist_784_100_samples.pkl diff --git a/python/ray/tests/mnist_784_100_samples.pkl b/python/ray/tests/mnist_784_100_samples.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c2dbf316fd6d154e7d038d601e6e1676f3673a7e GIT binary patch literal 627706 zcmeFa%d+*xb?2EbcXzSeJ>Y-?4m3b~!C){LUofaUb!GBU;+_0tS+eD_WP@9>TUG5# z(QOYr%Sm^2P5TM(D0mb+g5IE;M*e?Ep0!UNU|#?vHuDV5+5qyh{%hrR<;u+P$G`o@ zfBlF5_HVEK_MiUwZ(sX|fBpM^`7fXT^FM$3Z-4SP|NOVV{ilEYKmX;Q|Mj>3@+Y6U z80t^{@ptj-XMgme`@cT>4{!g`;j=&f@1OloKmY81e&xeI`tWxjKmW)7_hW@Bu_ruTr*H1qAn9A8G$P zz5gfCf7E;8xrP6J9P@K`k%qsB)c;mg{C47RiZ*|*bWA)iT7cpI^`$*4kN9vaTQ+aFukC@*S|9tNy|Nie{yg_T)8Hnm4P#b?-V3_hNh9)}q=E`NO2k7@We z(m~@G$|rPSJm6xuF;ap5g2s#C5c%vheg=o7UV3EwRve@^B98ED;h8x;ZKSX8XE9ze zxV|eqny+;KRk%!jvU*~io%)SZ$|3Ux&Clk}AJV*ye(*_@gWLEod>Hu%!s93E4|)pp z<7!9jykTQlUu18_$?y*iDvyJ31n{VT`y!$Gt4g3sph}=hph}=hAiVEKAq=uB(x39| z^nG`dGIsBvajK=A?Bct@r*U3u#&_Tvgr^bZefTKq*}W)Q3gDN-&nKe71>c>~n`b`z z%NYJB++QmlSH@o}`yYH-YVq;i{VL+ zZG{i}5c$(?Y=)m*9sssyS8ssyS8sswIW0zLZxu-nUCcQwe*{Fu}iemGZ4 zd0s-Yj|S{?SH{Qwl9u_)Jg%+~%8zL@T0isHQGXKOUm=4ri6HyK>iFc}+H1VN)V#Of z2;ZRcnDAYDzme8CI@tl&)*tp&@UcJVTIJHIUX%W6_~ef`U^i3w#NKf&e5`|m54_h( zKbH=2O!;Deys|%=#h3b_9Al@xf9l5=*yFH2Iq3_0{X+YCEI#Z?IT%6pm3?!L{S1eQ zqZe(c&!S%acB}S;?}2i8<^BNv;ipqRo%q1-+fu?IcMpHUL3=UC9+UJUdzE`3%7Jym z6d#5bqU%FyCt{dA%1R!{b; zf^Sf}#6IlR1-(#)R)6^UER@~q$y$8Oi&kf3p$x%C`MfoJl*iORXkNRJIBT+10#yQ4 z0#yQ40#yQQNL^v3$l z%3NJ7yW!x+v?!03a_-I5GF&Sm`ImRB+tXi*hK>CGuLa-RaR|7mZ?E`kWwl8<)Gu2e z{X*-oYk{P_7&7gNV;s2_L)Xej^^I{P`Ww_A1lIBPHa{OU&v0uIoF-aG3hjICf>uTUH|Rt%HOmb30E%1yBIH><88|$?a0eYKbOwZ z@4plCryN4vkHwe1<9FZ5?%5bU5w8_aU&^`V(jWVS;IlLO;62o{PAkfL(pg$9r>0N8 z%eo}^8rkJNcEO)Yud>rP=`-OY+^Gzd7XDepIY`gQ!*l;iY5bnZ59dRH3;k`~znutU zGWf0HyLJ5qb1&Jgl~3*k8ch3gYxU2o|491#$WPYUT?H`nj1K=0`pI2^d4O5w6T*ca zeJSJ52UtGqm2kEPKSl?qyA?g2C6pSeN}x)hN}x)>z3QV7IEPaE`vd2->pb4QNGix* zaySUr#QCS5aHRs|Tnzb<2V>9fG|qOi$Fv47c6~=WTjP99dJ&$%*bf<>Kpp40$S*D2 z#rS-Ox=)m@+xX9no8iw=>`(SQp3?DkW_={#lAea&Ne|dbjXJP@rg)U!QuOzI@sEA| zV?E2>Yv3F|(tZjTIEyKt#~h#ZrlTF&DZU^(C-;%i&+*LtIq1#H4(y)uC&_8?z%NC9 z!UvJIs+XPg0v*gV?}Pda?&V&b->D2(hy74^;7|Ow^m}^GJl*r<0@aeT*G0@JyfbL&)ra!mwc~0y2YmL{WuR-l3d?d>HChLUY(mHc%_E!zk zBkP)->@X=k$&GzVLF@ONLk1pkS(P|jq(bzk@QvY;M1;UXU6m*j6itfz{v z`;j=$F6Cp&$-p>>_-9ft{^5`MOJ^6@!d%;=!5T;I&e`AX63qGpuvx&IgUJvE%M`KQ4Ok5mzTWZ^!&W6n`Ir z?~cA4@eP_M-5fp?a)>x@4j+E7@&ma!e6nYEvL9G{X}+uV71X^qiXF}33uu2DC-U>eK64)tj4&Rx2 zuhFUmssyS8ssyS8sssvd=CGz zAE|so-_SFH$Wi%4+VY(6QV-YibItL2=6mkfZLL40^tfC5h5SdS%I8+?%Bkw&Uf&@9 z>mj4BR4(x^ed15e>bl{z-p+a$=M7uh<=PFs@3mgfeZTI9+v4I}iC%WGB=-!CDSQq7 zvDcOU*{OkV$k_eB1095$=Nt|4<35|QYe&>NbI<-#?O$>4_?T0M;s+e8zDvV*uDo80 z_-_?o;5^Jy^r!tRJ>};~l>e#f;~sVPQEaWh@JYlmsC|Q<+50rUAvxfGPd{Mv#`xuE zU(2ZiJT-h~KU8~&eD@kJ=cwIxgnSZ}A5i~`nBJ-ET^v4xFY#yekbB=L9P^5Xd0n#+ zE(wNxTDg6sIXsQ}MfSh=L&QP9spQKcnxFLA58T>%q38^=dE<(`PF zT=w$8=ns6i;&<)Xr=xik_^=D^MNLzDvi@1|O-ZWE?^b`F^U%s<_hhatKBslDlM^a5 zLb|_+{_n={Vb`R6>A<)1{eKjvUiYhjuLqy}fC#sD+RILya31{LA>HR+cpCJ#vEBMU zb60p0zAxHOe}zN~N<#Q3pWzp`g^zWWf$~|(`MTFC0$QJF)XuQq<66nrbgBfZ1gZq8 z1gZpXQ3CkE_S~CvBOtN2G44rzC40xG{S@Hx8rR{sws&|=>rFv+jkdIM`OJ?ZZr}?) zvB5A$Ba7hHHh-;sDJp?vp-7rIyL(4!G%BQ z!t75bpShPFT&X=5-$4|-E&Jy5e$ew|z6%H7K86;)G@gLJ2I6>c_!%YwIT)7oVDal|Yq1l|Yq1mB6+VU|uBqL#uIhuklRKd+65Z56u-8b`$J- zy~dwYlzaW==CfHi?9;{{c?w@TS7XXPx*0H&G4}<^egQrwI}6X2c3Mv{sTKP;$UX)R zZ~f~``E?uFUCtJu>=3=d30_OXUj3Q0j~c%vi;L%0{!M4gbLEJfN7SfZb*4Om3+jID zSLX1<>{Pyb)iXbV2VBt7>mTcMtgFEHTrd(qvXl0vf0)9#xBAC<-mEt#s`U$J z@b3*@w)Xzv&G<=nL-fix6+2@#e(4`Lr^B{C_`UwnmyY03{=E7N^agdBS3DYd$&cUk zW9kj(6rH;6uXRmZ?-IU^$7ezFH~1M;p3?gIUBo{M;q!mt{rF6~k{hSB(x=;aKLq<` z@oNBQqO1!_zl~`h3D>FKb<{WcPav1AX^^}*HyFO>(xCa_);K2fS3Bs{Eo1ZlB!rKv# z`?2Uzf`z>}L_R5Jjh@{r@x9y2-%F8S+Jhdu1?Px}7FWJz%D0!@7CjA`&+iqFg*_O% zZ9rc(;7*@u57D33y1UO*_JKgFlwO4|?8>-Bb2l?6{4} zDb(vZunxmGkUgu{{&xIud-;X5*!S>n z(Kwp-S*rPx`@nB1xSW3&^Pfv;(7egx|HQslzYgwT=5H1c^LEZ($WiR<={xq@%)Bw( zE4Z+B!&^~~-1C{9_1@tpcPbyRN9*xTVsmYcSKFHtDmWM(0IBZ-mPnHvQHQOGU z^2z;sC#Szr1ongV;4*j`p8wKE4ja*p_??}t?4d3EXQbEQ9{HO___)XgSHTCS( zu5y1h=Q&MexBh3@Z9Db%*wLo&+3x}UYCPe-3&wqJs^3SSTl?2%a7@1DyfK~AWcA0h-+BMvslI{7($)7zXq*=QCf19y&mEQB zKB%5qT%PjzNz7-mJX{(evmevFc4GeDjp3UKy)~L=T*~(B_}WKm<(%$?dlLCRN>A^5 z>6`UVo6aD8dB&yBGU**GMhzEn=rx;GhaMYX@fY2K6CZ}xkrz4uVH?N8;$9}@RP3utM zZs)k+jEn3zchAhi!aQy=W%X0=|DFxz*^74{yLcU-+lCOxeULK@6)}w zD#yqNN_uXqdaH673@-e2kLulf5mzt!4e_AKE}{OuQ@Vrd^=jwdNN--}EWuyU_|a4E zC}8R@@b`u920vffIhej(*#boOP9gd3LoM6-};e5Pd&x&Uh7AxoqQ#RTM5olHrev@k@(X2pQ)ZvKTSKE+Lxbi zi^l>?=PFVE4d3W1M-%QtS-33DUUKiSzZ*S)uT#0CUfK4=rdyw@1gZq81gZq81gZq8 z1gZq81Wu9w=Lc)vgk5E~c0vnY=Z9dA={4?l_#@l!JAF?6Jw2u6loNKAUgJ%#^|F@rmI>;ydvAr4bp-iGC%;GeKQCXKsdhOR4gbSQd`iZ5>&(+hHeV6_#;So;8HO~2~XT4~G`1h#X1*5kXKN0O8 zT$Gn!<`}UT?%)Ui% zCA#zeEc{e@Gv_a)@v`>6EhX2b(`_{VMPD8J-FJ=)xiOvw-`>w%09U0|0#yQ40#yRb zNWkkpz@T|TEBmb1cr;PjZB9NHbs4tvsbas=IHqGh#Q748Z$w+!CwhfTuQ5<^~`;t=xbo!I+GzUy9Ih23Y$l9K0foilgDT8`cV=I zAL|THKg+z+aV~r*^J(b||LgFph=Y0ubw;Ox9tPww;In|fpU zC1ABynEFvC%$dC^sMm>`XxSkcR%wfXc0e5x$0G48pj6mhRh2-MK$Sq1K$Sq1K$Sq1 zz+)14*^ghD7kAi2eSH5|)FQrpKk^x5--SBeXW46gCe1fAPIs@Cfsgx7nU6CshGK8% zG%in+al6LlgT4p+Xyuo_wT^=P%o5)QN3Z#9BfC&g`Lp4B#dFf<*0``+cH$p?(>xJ> z@o$@QCOZJ4d0d2)Ua;F?-?;F31e6z2GctP?Bg{?zyZy$ERqzCMiY8Rc_ z7q59g;ezZjUi*_aQd2%#*>79fEfU|f@3N1gSH45{axMkyBk*l$Z`vE$S*P@K1}ok9>$Ns zv^Ng>@NtSoy}^$MJm6!05#z<9aKc*@m@hLUyZWI13Vcrci_&;Q9E`pO#=r0;zXA3)ddY$CrHxBq8ybWUee~cHF8@nE2IYGimy}*E1+X#gMZ-FS z)=$6%WnUcfDV}P42>;Mt{Wkh)q-XcU7Z2$kgX}N;P2{`Pc%OP$?D+T|i26_^P$f_$ zP$f{T1m5x^Ym5&(_Q0RS_r>x(#n{J@e~%qe$aePsjV0r0#O@s`H# zTzo zcJaZy{Gwax+xxzFo=;)0rM{(hZ5I~Dur%S9&6WVi6FC#xPN zzVQn+_9oh&FV0IzF^c+;##gd0CFVE>M&lI6cde)QdhV&Oh53t8zOcVIt%D_eANld; zaENq#c;H{F?+x|$fiPRV!e1VIoRgAx!S7+pKS1YMbd0l80^9zu_^`{L%$3G9rQvfr z*KR=W=<7s)$xcc=JTd-K!7tXI?B1QmIn=|Y_&qsYos)?lm7nYPAY5tvjDF=?{$v)` zh(P~EJD+f>Ua{{1|Eos*MFXxMl!?pA&f=uKfI8JLkNsZ#qfMtiR|!-JR0&iGR0&iG zR0&iGR0&iGJRpH@`0*w4iB@)P#@o=av-_z7b|vH26tItG=uo;Jm*%If=5^q4vU@u0 zq_P8YE*0Sh=7WXKn>y(We*1~_71jq^;cqoRO#6}^DDqpr3R-ta_$JO@xA?GgxI*Xi zo$z(kFYMJ?@5CO=`e3i{sXV_E+oLUY^l08HXXuZ0+eUuZyoV-x{iBJN^RZL=&{2=1 z$D~g^d+eODTbF}iCMW72(2vFE*q1ajxl)6%-|-0F+rCtmPvXn}dBic8a=M#+Ia6qV z5%Hlv+E?e}`>E)qerKJmrGIDqLVr;z<5PUS=Eo3?#?e9Ll6%#rqL==i{R7|%H_}eB z4oJVpIP&CkPfJsv+KEo}3w)f{R~oys`FY~Upy42@dWOI1_;JVog4Qu<_XBP$jTi31}XO9VUC;$h82z>SIds+i^?pPxy7rkMS!0 z&6DS+o{5+J!zn-Ie`a!%(eV2vKd8I&lX1@pVDUNGffgGlYd#1r=vw&5uV;Sx!Ik9t zN=&CW_HOc9YQ3eGpLEJc<7S@m^K%#OgTfz;bdTs#>q*Z=xl+%M@)ce;VRxUb8=O-l zKRcVxQsp!8B|C%5&c*pkiRS7Z z<>PK5U7OFj&nk!Dv6OY$l>EoyOZ)#_8=q~w>@(|sXO@e^$B93%6LpgBTsoW>o^)yF zPGs}!;8XP_2#2NcYu`^jx2hLsGkjvibMfJaOgqK-VVrZ5!&w?Wul_%%U1i-NC&npx zt&Pt;^&_i2m|myllV=5%To8Uzh0&rq!gH!&`#O@=GngSA=04r-d^o+tMSj?^#{I6 zs|2b9ssyS8ssz3g63{smoyIEz> z&rt4(ygEPUCiuFKF;)Hs?5JALfbT?QH^V~%nR3SSka^E{_QgBOuL-{yKH&C%uC&rK51GXlrQ`R6E5o)uXxFRm}~c6iR}{mV9?`Q_&Sw0@&#Wympv)M z{J4vJ&$Wl)Z$Lha3-xO`(T6J7e7PnhB}wO z6HcdjZxlv>|6i1To2hWehi=R6?(o6?Dq28hTkJQH36Z>L3RmV7e_e{ zOzRD4KlSteVOCWigX}NZT^#E?;Wv9Pzc|Vx_KX}q z_|dg$PrT~qQg*-`A9LRiTF>m!+YA2TsBgsgjmj_bNbSzUu{M88`EqLKj~J5wHtRwO z*TKi2`hFB|;}_Tu!Ds!K z$4=I$Jxu2c>Rw|nyq)Zx>7L`c8I%_84{wefH3uK5^p^ar$G1EFC%+fk5zn}6$B@3v zdav@B>Y>(Ez8q$kzd!IJ+g|0-&AsoUbg)gtgC3G4YBwyt((Tt@#_-VZUhVy0@TKyo`Ts+4OMcfK{K#IX_7MH0dGNR6 z_c73({(Ttv1{dWKeAqMg1nC~HI@bg|(2n&bo$q!p&K@3g-eha)S*K^WS6n~xqx6Gm z|2;Sl@%tEPPyg;B-`LO46SUd+IJwt23OttX-S67x_N6k;y6_y$&5JVYdek%8k)7Yu zk?x@~egd3J=d}K}x%|&$pl9A`^wx83e|P;B%B`8-8!ky04-)>nxm;vNt$6e(m@y>};Pv4j1n&zT)a1jGjrs*QuSE+J|*! zHdl|9+DonbIZ?))Y28TUu}P#K(Fy!(EgCCGn`BC#mEQ!MY?Y!Gfr{n zDbqjFzcl(zKf`;!*80qMBQEyckbjP{ZX48IHd>E>-?=pPKRJ0{={@J2G5s-!`_N6I#t%Wbui>JcDIaiuL zomzatM=_loKF+<_jU7t!-WI!q^q$4Hlk)+*=Jj*iqZa$q8!9KK7L1j`Gr#vU`-xuX z{{*$uE3r908}a4E=Oiw*pRL*}?j4<<;ld$RpRgA;Dp$5XC7kpJqP_g>7v}Q((oS&CC-sAV8$8gU@umA-#CN^`*NvI__gs4cu2=VkE7h+|f0g!uedk16 zT5oXLhst_jr}cVE8?d<=DBdP@wD#F`brMpwck7Xd)i;t|Feu<_(AJ& zw2$ZFajM6RuhV@HuldLz)am&`woi_(a_eP(;QRrLcQ{1;QaFXbHNVf3Mn1$0$>Nye8pJ>obGqQ;XdrYVw`?=(Wz5?C3M? zVayvE&HozBd*|l6R=6O$CH{=F^Ob#``Gzaho;zqBhdA2VD2Kjo5o)TlnXx5RIT1@JlP zZ;;)zxbwkuJ_vS8_#S8%g}0GEb^H5>o9~3je>NkI?O&^W1=&|R;iCLO>E{wI{KvIV zaC;z^alm;o;4AgMm2*+DetC2F)UF1#m!5ONp7?V`RaQK$qjzd|`a;KJ=L)UYJ~;4& zQvJ=&1lS$?7!!PFCR1EA)_G}9=>Lji^jwDI_k!Mr6UV9NX1i8A#&2Hhf^KK))>Co< zU(mR0Dt!H3B~T?$B~T?$C9o?AJn@lT*?R}gg9hdc8H)d_Q@lhu*J33h1ljL;`7v4Q zWRGARb}3Es!VfLEy2} zOAc9o^%YpXD2r!)j81!E@$u|FiQz8FYm3M}^4(KK{a8RD9=Y!_Il2tPk!<5`zX*?M7I*JaQQ2Z%diV# z=kKu>TKe%ixRHA&J*`Ds>`Naki?dxBmz7^b{@VYM&PBRXd1fDdula4FjQ?d9;ht>P zRZ=*`^YD2gUF&?XQ^m)*D4g>G?J{;3t4E$o-QSXx$E&Je@(X2slywJiL0ON{V<#>aMvgJf-_g?%)jFE{PLx9~nd07y-(Kyb#nq}kqMs^`f6I{U68I_i z(u>7~oo8?D3AijC`^-KaE1uGxJ@=?G9?7&D19m~y?{$uUX?Qpfslg9M{kzpqdRF=D zv8xpa=srV1>wSrTGhShw-hAp~#8T{@_hICpru+uL^-TC0(R_+!e0d@ANj=Kx&%O`L z54CTO{i;x>^&N*CSQlLzwdRSm>lVkgo*AF9uC+GIUiYNK2lKdd@deHoQ~QhGLRyDe z8=%ea13$i@-%FJC)wDmzfq5DJVC<7;d`3DpyZT%uP$f_$P$f_$P$f_$P$f_$@Tdf2 zZwhyQ3}&gj(tK#K`yZc1z8!wNvIBCiZG(M6%nn%pMSwPBGh-vA9};@m3=Qq5et8TzFMQN z2K(=wFC4+c?BR#on0kqQ#iRd+qZj+O7vJ9K%k5lW;A)M&TD2RytzTM4u==9?a4Y#$ zU%f#8$d5@j>L(H&tqb+C|GGcCdG)LpP4HV0msfpCxTs&x;`c>)STJ$!Z3mZEePf+$ z!K_zFnsb!`?IYu_RdQc36V_dy`V}*+Syc&C2~-JG2~-JG2~-JG3EZp%bnc1!rXL|D zKNtC14QJA==JDx#?nA`muy1Pq(_M)+*g04BJ01-Gt)8(@LOE|eNPqa<{OA)?otcf}&3AYtBb`EyeGt=~CWucusxaP=2+}dn!Ly5C2m?Cdq!X z98LC6>8qDrIQ&(NKgeE%JX-15Js-c1nH2ahofiYY*rhD5??)WAUE+RY8+`NUkHjza zP0y2j{W``&kDctslkrv#5LAy?cb~lv^W%umneoQT>`!MWy`cxk{o>EWze&N%-y&qj z119+Pj|V^U06u7|c4W{v6zYw>2=7Wgw+}xZ`9r|(T$=nSo{MrCg;nA|&N2H&)bp2lL%V`m;36pO`Nef##SAi_cB1JLr4{ z;`P#Fn*V*8o|8OL@jbEhw`6`b&eNpZ__M=kI6_Ncs z-H)B=7o6;7X?@{czp*AB&z^O@Uh6(NJ@KCN-X2eN`>)(z+hp_v9!uSy`t6y2jDz>$Tps8+j#CEPulrj{Ve)m zpI^2pOAdq|Sl=qGz5G!1&8(|3PE7dTj^R4Bqru?o5Ak~nU&s&Uo9H1?_1n}t)_>A> z;p>^-JS`KSdnfYMpV$us>D%q){5IjF9$7u{+|!P7zE>gTu(DAlP$f_$uy+Xr_5TCo zJjOMz_VGQoucSbRE{D3@E>{~&$SPlC@0C-Q1# zPtK*c_vRwk!A30KJKyjto~F#2j&gZk9!(%j9~u#pZv73 zi-Rw%x21sV{s?dRF~dRir9H&wj*Dj6m34S=A z-lHez!B502{Ik|AlOKg}wYhzpg~PGz^`>74tnCwL4 zOnRlg4)XT_Ps_Y1Lz{z-5T1#~r*i)L$QOQ)jJxOJBOLn;$^jNR!H?7G3BK)n`%Jh} z@%0+V<@n}3=SGD?)Hm--!^e6SeC4Rm?2F)hcH#}Hm#&a`B>37)d%^o$Jz_pgxK8Pm zQa-)xF7Rh*!}@~-sCiW*y%Mi9e7&hZ@CofR^C|pV2?y;8(;xi6AGBjV=v%S_{o^oOe{DwEL^zGNpf@y~6COfria$l?|tIx)$Ga^O@2gda)FoytjIuoBz_`yw+1ro%f`E1HYES*W2+Y zvwA^4mM$&##nS&9KPC+x#+!`Gpo`^S6Rr}d5~vcW5~vcW5~vc`S_1OJ=sxu$C!7~M zXr99J)@)AY&;7wU9N@9kW4D^hV6E>)ANV_2zU?!5LJn)Odul$&+w#wQ_=ldF{#M~B zcjznO`N(g@(pB=Jx8pWuyWg@S#U=hn?z=@&_((axR4}r={x?J%9sT;IaIp zN5a|frWkDg2oCt=Img)E@Fg5G+YQPm@+e(@3E#vI^a?KYI1z2}&$M^vW&|$u1}#Nz z$eZ?K3NLyg9+Y;q6kO;Fc_%)#jt;-}JK>N=X?VZ|MPKj>|4?wEH~I_WS6U@dB~T?$ zB~T?$B~T@>R|&8_J?Q)m&MyrazrwxYY!7~9xx@Y}3SNib*3N(hok#1nU(n*?*?m9a z-wQnac81LPHNEub@n2$};Z6YhL(~WQv9zTg-otQ`8^M0vT)eh`UuH4;9I z{B)Ud3jC*YVdNKp{`hVw{fzsz@UDjF;f$u*7ur{+eVy#bwdZtLoSY~2aIP!7{lv#i z%=p>z$us+Lob=vkp9|kD9k{P%1tcmO_fuFt?en1a)xPJsmAu$z@L=9+zSOx= z-AtTQ_(Ey^JIM08KK{~vMZ13h{p9ea^7tL4L%O9>-ACfoulJf44eIYH*L-(-J*Tf! z`B%QG1gZq81gZq81gZosmH_8)Fiyj+!TVC?d*=cLN2~dRrG2U6>Kx$L{rH9RZi4Jk z78lRS?vzWfwD->4A@kmb zC!8zfe4>Q&;K!eBru_hi!;Yr&hZ{W;k8qvli<{XCe-QC-FGsNFo#>hKL;~xXuf=%S zuXE=UftU5tQFtr-<9zbg*uUTh+E1*9g}aE8@X*fae{Fno&!3pWhhITB+{bx@552pE z|K#V;iy!zHkT>V{$euRI&9&2;%<%C5LFKIlk0>}#T?fmRPs+V}*FVnCmA_*D%Heae z3%Z%-g5>aRbuUD(b4_yi!c=?iqdus=N9`v4#$T?hH;O@{ahAo$`ij#z;Donya(un9 zzZTKk;R9v=J%2^_h0hriV0GWkulFYukC zX+2u~2l`;<*odd z2H^|L6Y!rqlRBkyBq!}tT3wAVseI0~hfVUQeMXJs1J1S(`@|&wQvtN@04`|q=begQ z`u7{h$Gz|94SF;BlYJOmPWCgWb$W_lulA%pe zxd`Sx3yqm1%ICX51I6H{e4jT(;fU6lk_&jCwn&g*x;wrdT$1v zbKsynQVwPR4S!JxKF)iAPf^ZuxA@r4z&Z(bY zwAv%qWm5ZV*3DA7OaGuOK98d*0tqzmzB8>{PG3DYxKr_|q^ylD&6tc3$SU z)F<#|`~6b%bRr+DThs43=TA(b2CovR5~vcW5~vcW5~vc`MgkuHJI#Z4YoE8`d(n?? zj^?-651+>GHrj5U#lO$l0rzUpMggxx{S3yg#QdG}Z3w zm$T2@jQ-5J3+1yt#Byu2y}!o{KK$$V2S^-i9i=h$!B*#m6|??zu7I`fM|p(0o%>BB zaa4Rs6ubQK71bZ>o4eH6mzmbNY_u!TkL6r&sC&!bm!j}7F6d?V zz8|Q7C+({ZvM=?VPk^5x>nyUz;4g#UZ^Ad_Z@D(qPyCo8>nXwDfp7ahyc6T$AM3G0 zZS0?n_}KTK>W#sLy)zu{y9ppFD|pj^zKR!fEu(I_OdOjbAh>U*Jnr>tgOb;eo3B5w4{?>ifoUjoo!rr_loEd#~Dvv|y`kD}H?%w7Be6&9k>6G@FeE`AK zucR;fbNV;+d!^y9{BVwx#nqZ}=uNqGw6iwB!k*{08%Oxq&r`@a#fp`Fn0AA9k$bCb zc#E&Y4`%LtEjGc0p1s*8M7!Y*{?QyeeMcUX`;{I|ZMl!V<^|}hQM<^wb2ew@N8lO? z)f;>dYTv=v(JnIX5@kP%ogek3&3gbQ+8O#|g3nu1FTblna%cVWi*v~R<189{?yo1pKb~4~Uwf@jpg((WKyeJ8$Pjj$ zTgT^Fe|XNHD4{}wtKV>HhkRk}=A}~O-fhZdYF{r+xRChP9z1oAb0KWcW@FI3+dcJe zcFcVQyjgrt{ip!OIiCB|_o0?aF1hD}cE$f$#5);rCjl78fe$~Cov=U2CitBE>)a3g zWSl((;9NB3Nqge`FeF!W4>R^Gz7Mj?PlkUu zmDS%;{wQZzeBz9HMtIf*S7&@?hO{T>BejQGPwF*)puGW?V|}yGa%tg*6WdGlZqq6K zIsC$pEVf$r1V?YqIj5Z|9nW0;D3?L~->u+7pVUwKk-2gyH5~0utA0M|?^f&;W_&_D zEL9(K`9g1$L+abncxG;bg@v=@8E_WH=IQKce2M=E^{w>r$y}gze9{*_ADD}a_8)x( z=BqmycV%(D?8i@xHwX3GZDD>%S^g(Ju#biI(a8>RYwPP1VyK_11gZq81gZq81a>I_ z>?qvVm-_Yh{fKAovnzCcfq7{fUuM=1roJ<;AU)=Zr6|3ZU7dLWe9omu_Ol!FCrss7 z`uFUcC`Ex6^1=Kfm1EgSFHCmx=-4a3 zAG#TTdG=RHU-WmZGbH+nA4@`M{oiEqd=~jW>K|pF?}Z!v+Gp%kllhbu;3vYZbqDOB zcD*pQKc(@1E>g%nnPteY!`iDI@0pIBl&5x*Ov@^Fdp0Vo_%qw@Qt;9As z&{vDUQVOw;AB`vhn&ulvfJB_K#O${(*fdlD`L!><;do44?1C_l?yf`*U8; z#B&S(fy?NN`NE_fSNGkw!eRC1WEWi+MK<|#&b+BFZtC0x>2Z+%mc_;M*7*FN`w=_* z1eLGc`N;jRV))da(C+A5G}6Hy<5WJf0$%t-WA<+hs()TMpgr$ZAMr0o4xOx#ahsh=;K{YmaDzBDc~VYJe(%DuF72DuEl9fRFfOhiKLBcYhb(cg*`fR=BT= z##sCLFS`T&6lwmzIj*wLU^fkK6^c{xXCCXdk6@6V7|(PXhxXiKEIinG**~zk0$+%7 zW?iI{9(wFA6SQT#E&q9s{~6=fjrE8>QY(I1@y9vd?z<{yccR=Afc``Sc16vj(6`&@ zk9+CC)yqyln;oqyr28CI4+Hx9p5!q*!7B$;d*La6@TK*j$=qIw{D(uNDV@UXF^$@d zUN~w0dbKO=#mK)ie_s6s?FsV^6Kwm(=X~Vhok&~RRlw_L|7BllWDjoD4i(!EOh}e~ z!O@OTjzxp|g^6T}|4e_5KA}$c9~BSah(SNTp?-DfOYKIl@|tk!ex##h3VmVEnNRC4 z@cXJCf6))nzPQrMtIj{^&G^OrDB=MgG<+D}2h~s7rBeNQgVu8?cP;BCnf?Ksj(x7Q zKZVxI{%@i_z~_DwpA(MU55{lQQ>Bv&zSVn+FFffwD(T@tggQ-O=c~6nIB`=lCtv zE(2f4*9w;J_I=^cgY;naQ#jkx>}Bec%js`A#iBx3eWAyqn7n|QBYgB*w%yoiy`vmv zae^27)85N1{EOyr7d!4awJg*A+-m!ZJgh#cpM2jd&FQaHeCPrGa_u6xO#4E8q+KdT zZQ#oA3*VCzxfBB|GBU~c1n%PWW#yaEAAF~Ai^AFH^|?x*N}x)hN}x)hN}x)hI0?M( zN0u7-^D>^pkI=mqzZb{%6hrStzOCQHfd9D0r9tQHP9;~`cQ3!OUN}1WhqyN*?(PsN z{;s?)4M&b2ix0nS!Uy>`fuH^2d@r3=e123B`U~=(FI{hmKjX*QX?`Al8F6{Fr{MFP z_p_JwRDP_3^a2iO(7u62erLSz<(Hl0RQeD6b=40#sP@}fKi@Rp*FLS%;Z5@Q#E(IG z?QcYn_#IBh-${U8`{A7WOYAp0VShH8o%XA84u`vyaw-Wh4xpSm)zjGmDmvT!ES%&N2p%=c2oOO z=>D9f7xr`0&TC)iVD@!-=82*$^TlsP+$p`?{D^#e?Mw2S*Q2+feIUDmQyHxis1m3W zs1m3W_zFvaeYo`By~Zh>8Ncv6=)L=aFH&lpQ3GJi*V^yN z`J>pwCZyTppJmrmjGl71c%SfP^0|TQQn}o+&S5=IY(qKc0tY+1Q*oeYgr$oo!(3%CuWR&mUdjU zH}^$=oA#A@dSY(H1!tYQQ~!-#p|nF=-`{bm;KB!fOroJ4nsGp>pqXx=+o)k zxRcW>ErR-!L3UEgDficu#_p;7aLy#{4fr_s=G62q4zKg$6CbKKx%xu8NW0-_Pc`l$ z9cUVtzUD7x^^^p-&$c!51NiE-zQ8$&UhM|{N`woUc|m%=)%p9cMtRaN!_Qp4ZMf9$ z(SClouD-Im)p!H`EZ^X6m>;PeGEZe*;h8UHGurk8^~#P*I?e;#n_tVe;;gw>2~-JG z2~-K}T>`K8h~T40gZe|}wFBo4V*gD2q~gWc#r)^|YV1nA?6Uc(e$i2pq{t)%h%U?eEb^bWM5AQ_k?jmL0m-Oli`xD>ia)n zrK5etPR2UQH_FD#yl~LG7CdulBRg6tYC7UWc3bAX_(S&CRbH2V3}xM@kzPILO&?{j z?-#B5QS2Z4M)LsozwqPBTgMkPuG@XTVO1zQ>aFU}tH0kI-}AAaF#kDwUyat;`$J5R zeuQ?{s&}<#&1c}JQ@iKp`WL+w@dV}{j6>=Fw)QjEJl1R7ZV-R;>rVH1IPHgU@~`T& z?#_G=`7n=2_m!=tsK+y6KIbKO_{~w|!1tDYtnWpAKAzilUxV%=qu$Uyu+C1u-J@5= zFRFjrab9z-5~vcW5~vcW5~vcW639s4T|a*C^3P>_mZOa07zZ&v%A~&fYy8}kebCs` zI*rHY(w2F})_yXqAIRQm`R`A}W#iAjckfynWt^;i&_VNW%NNf%e5|i99w)q|YvaQo z$?}nV?sNKMex>!!T>RDW$kQa4m6Z=zhj4b^=)4#*#6wTtp|7B_Zr_$Ft z$}lm{%i#oH!+cSGqV9P=m2>@GB~T?$B~T?$B~T?$B~T^sm;~g<&%FUT%KV4#_{T7Q zv)?zn7vkG~L>pvhPUj(MJ`gfA;h6comA_Xnd$;>(%=a--U*((T4LKaS_ow2s>0lRW z@w0z<8N5^aa>$!`&Z$)M5sS~tXR-6KRZa%k?JOVo*IinF(z>Xb2b14o_$cHr#PZS0 z9t1A(cdVyq9f4?bX%63@ddWEvbJM#xoN%;SCxf1=NBn`2S0_F83)zt`#y>X!mH$q5 zR`3m4N80VZP{Qvs_G72}Tjq+Xv~cvNgY1Inje1=gzq9#aynx+Q^~|dsIk*1K&VT8E z2S4UO`+?p9{piy9FOR6b zgnMb<_;TY{_iW#|a(*mcdSm_<{%oH(kJ?9-OUh~bzSH`b_&2z~m)5s;TK{tR9PP>q zZClNvN}x)hN}x)hN}x($WeLcx6z~^eoU^xa_ji31oc6sV#UQ&&FIcEDHtEZ%y~lbIEIfj}=tKA8uhpH6mQC%-w~hmT|Uxtwf< z?@8nnJ5l1F^Y|#IoD)sC{NZW%{Mz$H>5Fr@!R2H(a_Hxch{NLtNqK$!G<@1P0*n4oKv3?vOi=G9c1se>H4MY03Z0siNc1v0Kv41lgTP&r{!x9C9`_E8_LK3|t<^WysS_@h%j8e7a|O0$9^e_5>>L;D z51N;vKgWF~&wGOGn%eRSPS#^~4eoR1@#nxV#<5SB`S0EVR{pC5ssyS8ssyS89+SY2 z{P@*tJOjlaa**8-n%48>Z|P*OW?tnq&sxgP`xun*o9wXofj6>0Fn+m|I?2P@btd_` zG{6u1m_jEzp~c}ee;7WE;gVg#jKgg_o;~x%XEFYm*&6>cxUfHP@5<4+5-NYq1im1- z_>KINyF-*Ga_;emd@H`siP&pCLA)GgoSsW(>Gu|2veU}$a+JUJJA?_nIp&MTetMMN z!Y^3Q3&!q;z1rSe?`6-<=?goX%gJY9Jh|jqKe2q{$K_<-9*mtZg_nH7M)fR*E4_c> z<4@VAz0qgFVbc`X(v;^ zZ2K?y6SNwi7K@MZgvN#Fi+vBdaeZ3PR{P>Ujpenp@)>kaB)FKDF`l&idaZ-?l1Iz< z^sY>N$Q^uXUE$Z_S2WFMO#k4tKhVmdm}iwM@Xe(?f^@f07i~Qwqwe)kcM{tk2 zbA{sqU!_$7RRUE4RRUE4=SV>Q?btno?AN`;D7&Kl2Df zv1_DzWH$lzrA&Y6$A=sq`yM^o_x3q3FV=d-mvUU?C;Key{Lk`fpL;Xjw%_eD=iRRY z>fbYY>c^~3#RcE3?9$kAd7q=?bLswrk?9%#f{*xe{Mh$S`sy{_PFYZIGk&i}=}4|A z{z5;(gx&9mFN-U~S1RvX&+OIikY148I1~Ni*W`~r$k$Tl_mtyv@Asn|lRo)6^*FjQ z6~12eA-r{EJ^nbR2QKJR`m^?D@`Dc4`@8GlY%~2V{ja1!wA zWxdm^YqazWUq3OuQ{~Xa^m~mXEk66~rH9FW`DfzWtZR^78ZVs6@44Zn7y4WH$Wrm$ zhz`;Zl=TVX0Z8b**!ZR<@RGgCgHKe|9yOhUuds>jAu(354pnkCxyF+FYVKn zU&xTbM|t1M4{KCT{Z}PWB~T?$B~T?$B~T^s6_P;siH~UIH`r)=oyMcglV$JSi)ebC zE7h8LPcY+9o(bnm!zur{AbVFQ{;htnr-V!{z8-6RCcK-=>lPP|J{CtysrElI(#b4`tqj??jpPMh`H0fLT+_(JPBvuKB z{D;Flz9^62B0lwR6{b(f1U?=16xJ_Q{wQyQu}1~dey06oC*-|GP|owFT}L0_PS3PA zX5W#Mou;&Mc_cga-;`4*=hCJ2Q25YavL7BLxUUzqtLW3IUD_9lJhzl%%5C60k$jY0 zd^u__v5)S8`#h73Pm;dqzZsY1qn7m%TzzNL1iZ0;wyv*VeB>m@%Kx8wlKg9gS~ zH_|=^&C}@L(>RNHL&-{+!JC~wFbNf| zJWEqL_flEmLGy9TkA3cCkLDbP-mGid@b=lQZ8v-t_2EzC)ABv{*?k!ETR~$n3$0(` zzc`oAlf$V$-OTQ*d3mpUSG>yMB*hO~@tRCIa7r(&->#;>D`#}@V}d#TZq~kB31-mw zD7T8Q(fY%!;&ZoJK6MXCuXg0-^!LNZ+ZFKcwtdk()gjYAH1ZE}?9aJUd$x8m^xxP+1bIB+7Ix_-g~y( zwvH&jH1q}S@Kf7sJL3H)+19j=Nq@hN-?tY3>G?;#X-B{VZKcoC(=8-M?mQvZh1Op8 zy0Td4x6v0gaK6oz^II)B?3L0JxH|f^rOy*u8feDe&VHPwvtKIx@KZk~-CDh=e{1od z*ZS{Ld@Po())R(|{v74=X_Uiad@P)Blfg%OvvBq``6_`bfhvJ2fhvJ2fhvKMB@jOF zC&zg-!Y`HH#^csreyC}Dja_B#J^>4oeH|K~H=NgM!!7-+`JT_r8)oklFnlIIX4!3e zrQfVIKYrPd9N~w{96kj>#9Sho_X52D`2{pgz6i5OH}jBdtonl2$}s30YA7W{({sWiK=uv z{RsG=poV*gM~i->;V4+6jFS`R~C6)%oXZ zvoklpbS^vV6FQ$G#eYS({;Q~P4TOI^zE9yx`5`+d{>jOn_oMs#db3|F;du82DRhxO}%ouH>Z=`Hn8q2{j%PW7-sze`)T^H%@6uraLuIy z_=Gp$2dna&?7`c>`XX1ghn%;VsLopoMz8eO_??3b>af4d@A-?2%KC-%L zin^bcdy89@!(z(YB!I^Q(B4oFyWjiABkbktlZc~Nd*9Pve-z*E1xAj~qjPj=Cvp+? z_Fm)TLGAsu*H^a|M^JejZscB-B+n;){7HQq4)J+>c?{fq-pDDii?zG+zr1#uD zj)8TF)e+s#a0;Je-H7!tyG||t0q{dxtzX>FZT+Vye4WZ?jEZot3U)+nsItAgNjH~fK*appUuAIZvnQ{i-Q?ylibXxB|&Y_SuTEA?CGnnzU#mV=y z->wiCbJ@VJBKhseFNu4q>Hp@YSHo2aR0&iGR0&iGR0&iGtS$l0Q_9#Y9QPou&e&xc zdd~3{F2>V2n!;a}?Q*eYpKde{Y-FEce4WFU-Yb0a%X`5u$Lk3xbN>IJ`9tE@%p>f3 z``jr!?-R}wcv&Rq!H>zpm!+0(`<-VSZqWMEB0SGX6#qK-%F)5-ac(%~lUw=6!e?vB zm3_W69Ou5z5cze2{J@XtWcikTKG(kVYDXH`nS$(T%zLeT@Z$@g#C)#<;y9Hcdr@oF zWfMQFJIJ2WYu&-h0bDi*?z>3onR3CofHvO6&(asT$lvlEa&}9d>zDX@$v?09bJ@BZ8jY&_|M^52^BLjJt>s!ukVxzFGu zJxi1R=O&m5M}M8emGU+HIqgQm`LfUCpx4?rhFv#g(o6L03UQdIn!>$_xJ1?%+?$Ix|0Fe^t*s>aU;-~$FUnqR@Y-#e>qF>fMLNl|i!KwtR1gZq81gZq8 z1WJ+sb{E-WJbu}y^7yKcG2RpXhG=}d`v1_z{9EDbHGcD_`du|jN8@tR3G#zp9T*vE z{0u*xa605qxGAc6Un4sm>6M1h={&3{I7i=Qw`e4vR^54*Xz=y2(=4o}FOq*T;&QUPf(v^*cJ41C%=^W? z5(yXbx>J7hlDRyT?4p*BUjD`>XQ1}dshuTV$2^~Y;^c%+jPRTvBLwLUK04~#x&4#M zGwESRJ-OXI5snS+e;4s3K4||hoR@d_Pts1RU$FTv_N;UDJLM0)i~RxAe*}{*_XsS8 z$0VHm;gK)vIXbt;1pngk&X2#b|M^3FZj{cEywKZP^=;(NIg78?I1KwT>q}1iQX1_O z=;#lPe$cl~IZD}U_0PUB*i z+27#}zO%;z=K{*b`<{*i*k|3T{_WK-a;x%D^Q{u75~vcW5~vcW5~va=P6D#K zhC4s91}%=y97D2$a=&oEZmoI7pz%55S735n<7IaVvM<4RuW>lfmImy<8v#k{ zJ-_fH2d(C_gPFgBgLN6s<>+PS;5qG={~)HfIxMYgVNdcI`6oQsX|Q{1-EVcia~aBB z-^pHxe{m;VgY0Wg_Os#O$CNnNF5QcfOMmP4A4h({gI;)cvU|JdW4Nu^JvslkBj2rE zmzMt0``o^WlQY~VLOO3y^{FxK2X>cjxVwI?tOHU`2i6s?pVc-g>Re3lvHrCUGt;@I zerK-#iTYjeu}^S1r)@&?s!wTs;a&eQBih&C)h~dL^NAjgwEbIe)}z5oy`tY6OuItA zxIKai4mdxT^=UKC?bUv;{^b}KYdkZNdU>Dm54rP^x0glZQ~h5hP$f_$P$f_$P$e)! z0?B_>c7PdCb`gZ}T%sG>!>%PX^p$XF+w9CZLmB(A}&uI5grKOZ(@lE}@Hr=|u&c)*<^p|aq zh__6w%BMP^w}j`!a%Sa~(#dnO$QR28&u3?rk=O3=<=Wki$J^-3(n~{wqkkxb&y(5EIbZnu-Wopcy#d#NT}|gTrF~T=3$Ba^UjDq` zgR;M>3@*>$;+Y>K;a?7}hJ8r%2f3tk_X6j_ul%0-Nx%i|rtH(;D!hLP;tyO<$N5S* zdDH(*={FaDt@qjo2d+-#lKrMRHqyN<+6S2nxA}X6k9&SM=j%cS!iRoLVNm;z{@hmk zNH4@=qjSy>xac38`gt3EZ_k6uWl%rv3U3Ebl`r_xIdrEcP{UUVR0&iGR0&iG+=2vT z&&3`+Fn<^AHads)Y|#o&Z_YolxV+AB;63xjv-4dzB6@&dD0*|)<;4eh2nQvdh1s1> zo^ogOfZoDKemafMLs&h6llS4=dlhEo{y-q%f)D7qe4sD%Ogf|sCENoXUG)`mC13bD zI}^z(r*Ckj{8V1f&My;@{8K(tc9cvy=f1Xjr5>D{o{eYqX1}lXOnE`h$o1X6iy}4R`nQ`Z8IL-KwdUG{=B$UBPxgprKDD@q^ zL9dnmR65|IenQbJ?ZH$!)N{gvBc(_GB>7IJcjW-^MZJKpODXaL|CQN3KRL;ndV)Tu zC~`VK;~KL{ph}=hph}=hph}=h;5-Rfzt*626-(2(q4+6(%b%CwTCtu*{yO6?;)Tyr z@bcdMPWVd0DO|1mKH=B$&wFU@d%6#KExly&3G(|xU+@jh@t^yi_L09NJ!bM<>vf;0 zSKvDPdm(<3{UOS^m;Wkyfr67~Xe0Y`E6RII!`A+iI#-$VD317(ev|($`^VINgb93Z zV|<-yPmb{Edv5+6GVLnqtX0pl`Dp(hd|H}~rw{n`+IT$Ek4~j??eEkJ`0z##)UT<0 zgZ4q%`04rFd*r7w&V84@2HCSLzCrh7yP5s7jh}EyUO6@vd#`)K>BlX9^n>gNo8lv= zAF#N1ZW+%^C3LOtYA8G}Lx%#ViX>EMz{u|*!f8D{KB|+k&FX(3Dvri%p z%3_1IIh#UrD44G%oB27RP*0K_IYqY z2li8!I!?(*+y6Z3qt`qaJeGRl%Oq3WYjAO1jm2Z1_l7UsFRXe69_+o8N8X=1&)-V1 z_)o$A@M$ck_#rh0chJ0QD}GPT_esR#?nJ*P`ZUtU)W_Z7bC3V^vXi3UitllGmG%`o z_n>;%@2&n#2-KF(TRCqd$PU`6Kd1gph+{XO!>|0vE4ZLj0Ma=(uf+IM;kWue8{d+{W~K;=aj;4j#_+9_F$*;X^R5&C1(0R^W&4n)kvSz zN7mcJvl!mJAUejb|EdJ41gZq81gZq81gZq81gZq81isP|;M{!8e>(Z=LWBHanKvb> z`EP%!zva@xe#wtH^x!b}KG^xw)HC>(W;>g{=IufAKQBC%Z}j3de{Zz@3H8Q~PyE&K zNk0RA;b#2=zAWuo*OVNRyp&$f-bOi2;g|aH#E&VrsxOv*o}Ki-IR=zdpYba|@5F~X zl>;aIOTkL%KJnvE>{`iw$37w1tDTSG9PP>&>-Zn`32*eD{H=eYbdIBz{pLyJ8-Gaj zm2eeqx6TUpW#KY@73?!utsG_S{96NsTrLF^=SY>?b06oe;d8QY->Uw&xAJE3$uIP)f{%5Fy|*jy z-mh)hQ;VSZu0K%h~XeV<)C$=R33NkXNNup`TKy2@;UchbJ8h|;?LR} zy}6A4eB%FGKK}eD(sYhU3b64H`F#e{j&OdMlfPQ>?|K$-c$HV|Z0I}LYYOo{?4=+0 z8)zR5u2%k_gqxz_5aqC$_O#C47=K;X#o%km>e{dQ1MgvIF{G zO8(+^P`z}*bF?pj{X48@9~2)!two>2d|5A{oH5RH@=t4oi|@^ZoiM0hY_(4!+{Jv{ z%70dGPB?fE4sdyuPxpg|`K7-m&JV)8Qbst9ypZJN2PxzjqL2_}g#B`7H zTm41kBiHUv;Ysa>^ekpL8zB(R;mM-~Vp$<$&ou z_N4*)1yp{^6Vc20{KSvEgm*<{Kf(@uq(AOyRJ<{A!T-G0!&1F6?adKBefQX_OU3rg zkDLY0n=rUIC#97=&Eog?n?H%^jER->Px6QVWbc%H9=jNJdy9|fo7q>UdPjMMhC`wH zJFzUK_Q#ZC^vJmA#AM3}E_NG!WG~Q;sHg$ex zO27W75~vcW5~vcW5~va=R|3!c$QyoToyKhg<2w0G^1U2CrpHAZ+efWSr+pr;-LFsBwVkP_p8BWz zLAI)B^kcSsf!95EA5V@?yBTeY8@GZc)z*jF_dSHi1xHE=;!$EWfVRDR&Y38(w+ zJbfix{Lp#~=|R`V*QtFV|72&M+8+jv9Nik9*4goU0AH*2Bggja_aFN)DCws5C9@x7 zcDgeI_S&~V{!piOer7U9!SFw5&G>Zf{^8pASSPc%2H|wl=aE1bdT}v6onO#uo!+iL zfGhdeEQBMQ>{CB}3bNNyKE31(zn*i1k-w;yKM3)l=x=R(<#5{e6FyI+{Xax`m4XYr z>|yUk$pxg52RZv?NuJZ3%=ptN0zS}+a#Lk8J96V*Jp*SN*fpmA&XnO~5*`ALVj#m6&I`N@Y2{}xZ9 z@g|h-oyK?lMmX``$KC>8mU_+OgU&NMq8i7g@tl!M!#;!gBEL#VEAJqAC;I{Q=o`}O zRIWdc;hIA5F&~9G{NAT<TrKU%k9LMjLdPCuai!T9!?I(l! z0oG%Ru2YXOXYL`P-AI0{kHszc8~dZY`u#@wO6Tzu%=;YHI<2FpdZuynoVa@kliCgH znW(#w`vMq$qdzFP9rj%I4WKlQ!`*wHH}~o5dQbcuy&voPPpk`O#wXzQ8kc$X6P@(T zehA(}>HmqJ?Bq&^ao+U-J)9!#Od5yL-w6kN&ZQmeF_N2O{`qkBtG{KPBB)%RO5Lk| zb)IX6RRUE4RRUE4RRUE4RRWJnK>l>BLognNJ{orq9}j-yvSplY=pGm|w6zw_ps9u@;P<{LSF=euV91uO6f?^yoNu_%Y!W{Ec(g z(HFRQ-tGNIFGjsJ^7jM}_7-@GSBl?6`#f`_la%8_T0VM10gUDC>Z^`zby0vOD=B&hx5`(CLx`*zS9>$!Wy1>d=GbH#^*%BHujzp|xYxvLVW5~vcW64;Fd zG>_^TukO^Io0E4i=O+i@RmJ<4=|bh0aYHY{%Cg|h+6ckp#6=M zL-b{z_qspF`l*7emmVA0gNs(u1I%I1ZRJ-Fu3*~Be&c(XA7FjfE7}uqp)a1%>#W2I z3WVOk5sbcgzVuw9)gjBzywiN3H@J8n4t{mku9Bhd4M#tn@bT>AZ@#pu-q_}v zlRr6r*S+>TH1bEz@o(Sr3_fSi1y~#CAKr{P>bx4uAJ3M;zkSbh`u>(*6%l^A>}znc z&%zfp@uzavYaTzyAC&X2bMf6_U0j^|+-2~9%gYXfUz!&l_tJfSnGX{WO8nK=sdfHb z5WWFBAor2rU;XBN9M1l!y*bja?A%W6kd3z5^D}>5R@m9p-heOZYinHRp-Qo6Pr>D+ z|JBQ3;^;a*@L%TqBI?=Raj~D#;u=_ozK-}dOJ(4^O!W^5m+o2FjInbw)_FqU8tC^l z4%@4J9&x1eYlMq-B=vs;+!+m=A7uLdfprJdukQsO<~>`p zcWM4!`(}e_7dfBx()2fruX~6(+C$;ne13K%fiQ1HZyonKm2N+{68@S@l|Yq1mB8W> zP^Pz zgwGZ}u{Wt5>EsVeI7e!^p?FRZ`lDUU`T>}&A9 z6n?E^xbMVg$2s1L=hg1g&Y0xZdwk+!Ui4e?zrK_K+7;-f>75(j8-MF`4*R(YDqg2@ zuyy;g6h7pHUCyaqUnsVdqxUM8v>Wu(Cnsbhs9o_If4i@H8)B8`Ui)alhukmCoOXrt z^quCXs1m3Ws1m3Ws1jI70zdR) zMyT`xGAR4vlcQccQ-CcO&&?zPvd;`7=%7;yZrIQdc6( zYY7M493S=*BX{@d37CEqi*HJQT34`m(zD7pdPOeK&GnbTN4v=VqY01s z5kBza8}L9I^?&Z;_}xTX{*iT&zm56Bm&5)-yD`B@4j=3J6VVs@JjkvGE+}@ExfvIS zPy302%1^?@I_8OfEEZp@_Q=aVK37iH3KxVAT+nmxD-Na{(EnYF-`VNB>c^nm3jnTO zcu)2}@-aPYm){KEv&av)p#DUAGyPv{_F0i$x;JQ)=8FF^_)>dv7xB9c9S-q(Z^r@j z|LSL4qjcQLe{aQn@ux!Gmr^^vUK!O5R*x7GTk|5=pVQAQiR=vgl^{rX&d zjKel!e}1zb!Cln3HnS`ex3h0`FZkW)UHuSUFXF9&18na5EN}x)hN?_Fz(7cNAf^`9QvKIeG{5~e~$D3Vd@)N)> zE@*y|qiKB9{F-~sXXo>V1z3A*4j1pi$8*EFSbo?yWMj5ur{cqp2YiX*2c-R&*%&@C z<1u_*`U%*<)c%-q)8o~C*+lD_z0^l0oecLtvn zK!4RUPx!)@p&ISqa*kt+hRiyCqjpFAOy@h#hEnp>`G;h$)c8#M$+wy-Z-!{E{CO_=l?n3ourg}5@0{ipRI=jskPo9Kn?=)L&2{EhN0JlS~aKk=iPi0jj{ z(YQ4gKK3Ep2mW+kYKkg>DuF72DuF72l}Vsyeskva!R%sU9(3368C-Wxf2xPzg3exF z<}LueY5kUP>HXZr&z(Pb5}kRSIyWBj=e)*U!{u`AZ0+@tcBXzGy}2=5gfk98c~6B{ zn4Vl8GkQiYmM#p>9q^Nr{|K&Z{YJ0n25R_)zXLW*yd;lYe`?}q^F1hNq8HLn;*B{(9!GI4!D}tEN}x)hN}x)hN}x)hN}x)hN}x($9tr4v(a!jpI{Crf zly+L@x1u};o#Q26{M6#}lm{Pcr$Os^>cd^b*V}r((ETV=*Awc0OT83b`}c8?(?))}op5ciI}u@}6uog@ z5BhVT`jsMjnr;8fh$G>WpY_`7C3p47uYdmFo#kc5R(=2^5i`b+wm!r32Xum80l;(%+zdT2Cd z;SqG+lyw&RyEF3>%|n9P8|}~OyDMw2pPY-LUZOwrNxdz+7UtRGJdx%p;0#*FuYDet z2VO&}1gZq81gZqClYswMKLl?y9x(nvdyV@%x7b~pm@egg8J~mh7je&<@D%Y(fBgP~ zp&;(lYJ~%u>@CbgWe4=&)%w@zzNyXD5BYigo*u^M;pGxP`mP0d8m;rNw{t$;vepP6 z=fgMP;hqZJ+ho(7`(Jw7_pYj32dZBgShV~?y?vu#uleOH7p#8bI?+eMX zf9!kKvFIa5sn5B7=h*Vj9`9p66q&=tJ?v+{pXauM(E8kIe51doKLzhi@8iEMd$&#B zt9{Y0girjCWyihQYL97rEA-yTBnS2LZX0p5-0S%-@j1nNJw6B1Z+4~nUyE`7i^w9=f?FQ^4CdU)XN~e zu#@c4de8nI@f-AZ5%VM6;#A+!YR&#+qjt2>dpnM!uRuF%3d3at8!jV%b{`h~(7rt5 z-_q-C!+(9O5~vcW5~vcW5~vcmgA%}AAj&>T_mv;Oc3S6+QT$fZzQ~E+b(~AG@XdS+ z8ib4aCUm3uC-HJL%};%l(_H+I<(uZmT9-5bcR7FW??--#2W4H}NG{+`=Q+k0DK%XC zNAB~8J4jC8!;W1l`&qu!UeOzPpiXu}Cw*u0DKr4P8F(!9(qn^PG&w(L<*eO^7;-G{ixM1`=aixr_+fAF?ueCe?}Ej*4) zCQseR-x-_@{s-9qCo6EFj2=h5Zj8TSPk-cM_(I%GRxJM7&ddu={{Hi19L2F!;ybu#p~}?4j&a5YJMNzuH3hxYx__ z6@IPy#eTzuo*#)i{nPwA$bM$?>5~171m~gn(RIe22YO0PmWIgB7CI&?!Cx&I`{n`K6m;q zDE(>pS$q#3OT)h_AGgpyWA_QFS4aE7tADk)TIt`txiG$WV)?ykFB?63*<08r;k)iF z<65qNU0QFpx~RW%<~+lJUNny7-c!$uJ*PvvQ@ssVKh$6Omsp(mBUtg3`pkYN_grE> z>g9J0zH}~Ae$S=i$?~OfyxaOA3jI*P&tR>%1UU68YJVHXHO_e}-s*_PyVmrte(n2P zulqq2c3y)&y<2O2Y=leW8~sPmz60~J)m6mEF!c^xo%S1??q^sjEa#p?y?HaBSjTkEIdWe>&L% z8842dPWFJ(U^%USHW^&3gVK3$>}zH{PkLXNeZtPm(3hkB$vzm&x|VpBdg10??8Wq{ z`3d^5e7axwfiY8?mP&~NRQ;@{X?f^Z+OEB~!t z^K$ZY;QS)ycaU!C7u;)R)INvPh>~OE>Hf@WGBBN`=iM&{Sbxnpn?5{(F~4^p!*zHM^cXC9yHGX9GTtG(JuWE zmGYp-`{;a*jOb1KP35UyJ~Er*qp9Dc9B7#9oQ?LUDF+&6!sj#N*j>Sg-#&7LV&^=* zA{-g>%=&Yeeqan-(8>E3_Qy0XP|k*N0{_q>#dLHuAMx4!lb(Gq{8*39Fcb0PSPrjzMBNM5n0a{GhdGnoDu<7K=Tl)}XSBGI-nd8R|VRNWlm7>{H{%q;*8vm&o}|FwZOHJ382@ z%{~!0kEM-pg@2Cn8iZGWssyS8ssyS8ssyS8ssyS8%9TK7{@39bd~VL4_d|*4-qRb} zpRnI|vfl@-Go0`)b)S2ubp!ltuv;E$|COH;>kaaOCVRnDdXwMPXm(mRwc0O(ALP$> z)yDGiT1Sw+rG8%dm~`TG2KTCG`I26O>=y6^b;tVEaqd|9COv-L)Qc=1QwK<27MD{y zX|_I=_)s)HWq;Y<1Xr(m(0rxz6YDL@2y)gs9Dcyn$?lL%Zvw(+k^hc<1A7JiiuR#Q zvgscmh!0Z_d;I!y`+FmtG_q?VCn)>Y?jn{ulFiM7wV#9CjPmI>@V7p-fKGOn20ck{ z$&Pzuy*{;E8=rdxzy)o@Pmung>_;ZQc=daPw~D_pxX8ENddtGR&nmB_UEGrPCcRfX zZnS?hFPLZKE_=C`zv@jX{3EB0`V;5zz7_}*QLx0XM&9&*C9QU6H4=sxsoLGnyKx!)3frE#DA ziD#x-&&w8nI_a(BI~MDKa4o0cOY2VR>$RzBV3k0XK$XB5B%t%j%ugHa0ykVQpMm4+ zi@5KH@Marn(7c;<1r)!LA@YiTLdL&2%?q%XmU?c$$U12Kh@V<#*0(+Oo5Icq7#N3- zS|0`Mv%!eQh6Xt33gWz*Z?f2ZaE~p-&uSWI< zM}16oh?Uhto%_bW3Ozy7`ID8-x4TR~-(p8PlfUL|t6wg?mz}<2-6z`UeJSU*KF9`oLvYhf&x);p!|BiYo z{I*Q5{<9&T#zp!ir+KHvm-?#@&V#SjIAHl`#h@?}SoTk`t>4!IVzH@GT@ByAN+Djf*uh>Jq!oBSHoJ&sm zlHSN;wnx@Cn&&4zRR8R=v-`^OSIG5Gq)&F?3qKo_>oUe)8h&z=eP`yMt?8FY$9_^! zd!?RS$HS~)+J|Q!Gi3Td_=k2zSeZp_N`Ei%9GLt@V1dW~D)f;PgJ?sLOm;TA>L zq*Ve{0#yQ40#yQ40#yQ40#yQ40#yQQk^uKHL7nU}teaT>5)O6RUt1ij?uEdvo$jMC zI0EaxBk6{Hr!xag_vY(-d&jc?F0z~@v?xc3QrJNxQR`{&rz+>!gv zUgrcXu8wx6{l(evu@CZTl>gM_Y|MF)jp_^Xn!3ER<+t@iQP@Y(UzhfK&&J6#1LF(>Um^yo>HaUp$IXX@|48l2 zoDT@JuPg=U%$GU{=kQv zO6y!$6mll{F^{oNKH_Z}37@6A{rC&NekZEjZ2Pi*$1(doxW@%=<{P`^()DEpTi z`6Z7Qe5MG@H+uRP*_XWR6ySyukLMgc6Z@mZc%83;&sI3uZw<0r(C_g597|i(OJAs+ zKHV>e-S%BSDw{uy;DxWW`uT@E-so@H;n3e*!`ErwlzM&F@Hw3a^IA9DHGJt_KlMv@ z4Igdx&h{??c-QeUPq?e~lX_Zdl|Yq1l|Yrihy=7R8nkZcSwC{$%}DZ6gYoxdUz>R~ zH2F1`3LwYVg&#s1?Js}_ieK}&?w821dGdSw1!Y$OPoi5Nv%D0IPwV;BIzHi1JL~Wx zDa1~8vOslTPcM6l#piSn!^!y;lEB)d!RN8xuk^kmC;uPz@xYb(5AE|MHC*`kC-$g! z^c>$7_uHEL3L4s*@$ct;l#S_6iSM_Ghs5B;R?J$Atg^ z$Sl5&{VnDvnqN0Y&**0(|KA?HGcR38Ulu2NfL~~{jr!=cUxxjba7zR0TD3o0&SKA* z@xk(+dv4J8!zho%aO4Wc?~LK6P!8)L=5amuww~*IqPIo+VZ_1tfsL^*1=*9bv@kn&I^Usn9P?_+zf(B( zygKeXXMKxZP4>~-iutzbS?4v{uROpb`=9$f>N8uO?l(d%jc{^LP&VDn0oW;9@#Des zSa@Xr8)GA#GuJw6!@7RqbFi)Kus!Wb_yYLG?hlq7!~H7C135X?zhi=#J6zvWfv*5g z>X+`<$g$Ds`v;`M&hP#y;?1Qy+j}4R=;${LKI}_Qe)RCwNiMWE@HzBO#8T*^h_j)7 z8k`%w$B!=HcP%>_cn|H`r2rHk+E<}nq8DguaC@^~Q5c#$JF2&>^aHM5_}td#)2<+d zd%~!PR)4F#FGc;O?Kf7BPxm%heSr(wQa?ElQg~g&@2WNZCgXFOm#Dv78$SB|UV8DG z&+t7GFFk=PpufV`LFwK=%@cd(*XlQ^c6_G`Tn+nrL@1Q@TF=3E3qJ0r<2;}66A!%> z3M;X7fKT-f9G3c(=na6Jukq4PYw*#(xG#N7@<5afDy_am-W_I3ClrP!HtFM@k0`4k4&pNdQM)Z%Ad zSqePQy1~pxE#ExjPw8bh+lKfgq!=M!soI6VXr10v=n?n`<3vO zqn`C33VAQ$%f&ARU#oRpj?d&**F-$PbIJ5x>>IJt9UDu8arbNA=0_oqKkg58!X-3O`SL zod{}Jl|Yq1l|Yq1l|Yq1l|Yq1!4i=FcqhNYmUW-(NxkN!(8k(<3xO6s_5&x_-LWr( z%>F``vA?x?zcj2)-@EN(Z(_d${&M{HvUlT`NBl%(XIdJc^;=5(QQxpP$v&OFi~45N zv%Mdh^z6SZjnAn)VON44Yd`DM4m;$b_IPai#qsADkbgaP8T5v|DVJ<5@B6JESVeDM zcBya~-w%N+{_*Uy(|@XeS}~-^y6j{0OF{USKL0kET+RMMLw_<^0n22>xNpbD)ZZ?{ z)?}qJpP`?Gt{&eDevnbXe@X3)_BWXTm&!=z_|O+Ld;i%|eBQ9!v`%0g5KZ&68|GU| zGK;TaU;cSt3X8YmqgDUkp*Q&}-73=(5^H`5F4`UEc}igMRy@e>%HrE-p2@wEw<@E6 z#6k5lFyG`JhC9;!3J`H+p-P}iph{q!62K4bM?Uf|%DMmSJoNP#ANwfljqp?D6LVYc zSLfi}HuA&#C1{=o#m?*HXTtr$;BJiHvn%AhHQx^aM-HD2clg)qUMXJ~zGU|kpYZQl zH|Sgo`hkK!r8oE85WblHWN)Xvz)!+8(o#Cvow&!RHGZH=+o5RpL3T{^)-m54!RI6& z+9TYrb${hIm82tLM#W5u?3JlS2+J(toK z_3%J{jEjpiek23>P7ZjW*x|j#58gw=7yifsmWjsvu~9$SU`OKK5V3VIen`Ki_AB^N zCq2OLzZreJ5uZEkO^R1+|Dy2=eIU*lJrxHH3*TV$sdgIN#Dvt9J>w@tux3Xh&23OLbu|IvA zu+YC`M@r|T%swFPlzq%JJ}6zGs1jS~iFRcv`0RW3#T@(EDaXS8ko~&Vc;U5g(rBDW z_c2gk&A5U+pfFT7W{3N=&?}U4O*hWCl!UU&^j41=S4&~@wtV!)u6_?2;m%SFSt=6^7+o0~pek~BL<@3(s`#9<+?ZdE6 zlD%}Pdy}tK#tDJF@e4}zlJ<8(UL{7wzMXKPH{G*eBF{(h^uWhxT1SD8^D9T?QE(*o zIl%?(ux}U4`?ah+^g}Ryhu~rz&-WIt&7vX|dP$k`FgMf4PB^d~|4C*b0I=Tni_^dLV2JO6T=S6S%04!0Je>Z8}W zJ?dkFU03_@w=QR)3_97VIqyb&3?KQ0IGVQDUVb}CZ}cO6<@j`-oA%agU03S&`;OMC z)}f8oH>`KUCo9(%0#Lo9K62iQ;7Y3mssx@Xf!&{cz5CPcmAiWGvHQ*L&#u1n+uaL4 z-~IW6$Im_f)o*rxarMVv{?R_&7w_)=@^}C9+udLNeD~K69)IV{=l1XZYxg%#o;>;N zfA{}}pT{43z5Cnk@BU-=;)5al2fObM^v&H%1O5Bmy@9^9yFbwPb}tX~?cFO^@y+*l zuU_51x_9;J)pxI69m4)$*C_0r-D_8`4Da6EU0gL+FYn{MwEw-o|L>I{%_~>0?Ni3E zm-aa{`|#KHzxVg=Ufaifb^rgx{{1(9y*h;58{T}A%e|r8dqbLgL)!cM7%%PLzPA6r w8RGA`9G+j^zkg-_?zMex-{kqyKJ_0Tm10$p8QV literal 0 HcmV?d00001 diff --git a/python/ray/tests/test_joblib.py b/python/ray/tests/test_joblib.py index cc40c21a2..eb5cf09dd 100644 --- a/python/ray/tests/test_joblib.py +++ b/python/ray/tests/test_joblib.py @@ -1,12 +1,13 @@ import joblib import sys import time +import os +import pickle import numpy as np from sklearn.datasets import load_digits, load_iris from sklearn.model_selection import RandomizedSearchCV -from sklearn.datasets import fetch_openml from sklearn.ensemble import ExtraTreesClassifier from sklearn.ensemble import RandomForestClassifier from sklearn.kernel_approximation import Nystroem @@ -14,7 +15,6 @@ from sklearn.kernel_approximation import RBFSampler from sklearn.pipeline import make_pipeline from sklearn.svm import LinearSVC, SVC from sklearn.tree import DecisionTreeClassifier -from sklearn.utils import check_array from sklearn.linear_model import LogisticRegression from sklearn.neural_network import MLPClassifier from sklearn.model_selection import cross_val_score @@ -112,20 +112,14 @@ def test_sklearn_benchmarks(ray_start_cluster_2_nodes): } # Load dataset. print("Loading dataset...") - data = fetch_openml("mnist_784") - X = check_array(data["data"], dtype=np.float32, order="C") - y = data["target"] - + unnormalized_X_train, y_train = pickle.load( + open( + os.path.join( + os.path.dirname(__file__), "mnist_784_100_samples.pkl"), "rb")) # Normalize features. - X = X / 255 + X_train = unnormalized_X_train / 255 - # Create train-test split. - print("Creating train-test split...") - n_train = 100 - X_train = X[:n_train] - y_train = y[:n_train] register_ray() - train_time = {} random_seed = 0 # Use two workers per classifier. From a4f2dd2138658cc9aef64ffe73877f82cb2a4856 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 23 Dec 2020 18:27:16 +0100 Subject: [PATCH 77/88] [Tune]Add integer loguniform support (#12994) * Add integer quantization and loguniform support * Fix hyperopt qloguniform not being np.log'd first * Add tests, __init__ * Try to fix tests, better exceptions * Tweak docstrings * Type checks in SearchSpaceTest * Update docs * Lint, tests * Update doc/source/tune/api_docs/search_space.rst Co-authored-by: Kai Fricke Co-authored-by: Kai Fricke --- doc/source/tune/api_docs/search_space.rst | 10 +++- python/ray/tune/__init__.py | 12 ++--- python/ray/tune/sample.py | 59 +++++++++++++++++++++-- python/ray/tune/suggest/bohb.py | 10 +++- python/ray/tune/suggest/hyperopt.py | 24 ++++++--- python/ray/tune/suggest/nevergrad.py | 15 ++++-- python/ray/tune/suggest/skopt.py | 31 ++++++------ python/ray/tune/tests/test_sample.py | 26 ++++++++-- 8 files changed, 147 insertions(+), 40 deletions(-) diff --git a/doc/source/tune/api_docs/search_space.rst b/doc/source/tune/api_docs/search_space.rst index 3c069760f..9e5014031 100644 --- a/doc/source/tune/api_docs/search_space.rst +++ b/doc/source/tune/api_docs/search_space.rst @@ -192,10 +192,18 @@ For a high-level overview, see this example: # Sample a integer uniformly between -9 (inclusive) and 15 (exclusive) "randint": tune.randint(-9, 15), + # Sample a integer uniformly between 1 (inclusive) and 10 (exclusive), + # while sampling in log space + "lograndint": tune.lograndint(1, 10), + # Sample a random uniformly between -21 (inclusive) and 12 (inclusive (!)) # rounding to increments of 3 (includes 12) "qrandint": tune.qrandint(-21, 12, 3), + # Sample a integer uniformly between 1 (inclusive) and 10 (inclusive (!)), + # while sampling in log space and rounding to increments of 2 + "qlograndint": tune.qlograndint(1, 10, 2), + # Sample an option uniformly from the specified choices "choice": tune.choice(["a", "b", "c"]), @@ -266,4 +274,4 @@ Grid Search API References ---------- -See also :ref:`tune-basicvariant`. \ No newline at end of file +See also :ref:`tune-basicvariant`. diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 8171b0756..58906af40 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -16,8 +16,8 @@ from ray.tune.session import ( from ray.tune.progress_reporter import (ProgressReporter, CLIReporter, JupyterNotebookReporter) from ray.tune.sample import (function, sample_from, uniform, quniform, choice, - randint, qrandint, randn, qrandn, loguniform, - qloguniform) + randint, lograndint, qrandint, qlograndint, randn, + qrandn, loguniform, qloguniform) from ray.tune.suggest import create_searcher from ray.tune.schedulers import create_scheduler @@ -26,10 +26,10 @@ __all__ = [ "register_env", "register_trainable", "run", "run_experiments", "with_parameters", "Stopper", "EarlyStopping", "Experiment", "function", "sample_from", "track", "uniform", "quniform", "choice", "randint", - "qrandint", "randn", "qrandn", "loguniform", "qloguniform", - "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter", - "ProgressReporter", "report", "get_trial_dir", "get_trial_name", - "get_trial_id", "make_checkpoint_dir", "save_checkpoint", + "lograndint", "qrandint", "qlograndint", "randn", "qrandn", "loguniform", + "qloguniform", "ExperimentAnalysis", "Analysis", "CLIReporter", + "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir", + "get_trial_name", "get_trial_id", "make_checkpoint_dir", "save_checkpoint", "is_session_enabled", "checkpoint_dir", "SyncConfig", "create_searcher", "create_scheduler" ] diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index 7190c69d2..4a2180b5d 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -228,6 +228,22 @@ class Integer(Domain): items = np.random.randint(domain.lower, domain.upper, size=size) return items if len(items) > 1 else domain.cast(items[0]) + class _LogUniform(LogUniform): + def sample(self, + domain: "Integer", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1): + assert domain.lower > 0, \ + "LogUniform needs a lower bound greater than 0" + assert 0 < domain.upper < float("inf"), \ + "LogUniform needs a upper bound greater than 0" + logmin = np.log(domain.lower) / np.log(self.base) + logmax = np.log(domain.upper) / np.log(self.base) + + items = self.base**(np.random.uniform(logmin, logmax, size=size)) + items = np.round(items).astype(int) + return items if len(items) > 1 else domain.cast(items[0]) + default_sampler_cls = _Uniform def __init__(self, lower, upper): @@ -247,6 +263,23 @@ class Integer(Domain): new.set_sampler(self._Uniform()) return new + def loguniform(self, base: float = 10): + if not self.lower > 0: + raise ValueError( + "LogUniform requires a lower bound greater than 0." + f"Got: {self.lower}. Did you pass a variable that has " + "been log-transformed? If so, pass the non-transformed value " + "instead.") + if not 0 < self.upper < float("inf"): + raise ValueError( + "LogUniform requires a upper bound greater than 0. " + f"Got: {self.lower}. Did you pass a variable that has " + "been log-transformed? If so, pass the non-transformed value " + "instead.") + new = copy(self) + new.set_sampler(self._LogUniform(base)) + return new + def is_valid(self, value: int): return self.lower <= value <= self.upper @@ -445,6 +478,16 @@ def randint(lower: int, upper: int): return Integer(lower, upper).uniform() +def lograndint(lower: int, upper: int, base: float = 10): + """Sample an integer value log-uniformly between ``lower`` and ``upper``, + with ``base`` being the base of logarithm. + + ``lower`` is inclusive, ``upper`` is exclusive. + + """ + return Integer(lower, upper).loguniform(base) + + def qrandint(lower: int, upper: int, q: int = 1): """Sample an integer value uniformly between ``lower`` and ``upper``. @@ -453,13 +496,23 @@ def qrandint(lower: int, upper: int, q: int = 1): The value will be quantized, i.e. rounded to an integer increment of ``q``. Quantization makes the upper bound inclusive. - Sampling from ``tune.randint(10)`` is equivalent to sampling from - ``np.random.randint(10)`` - """ return Integer(lower, upper).uniform().quantized(q) +def qlograndint(lower: int, upper: int, q: int, base: float = 10): + """Sample an integer value log-uniformly between ``lower`` and ``upper``, + with ``base`` being the base of logarithm. + + ``lower`` is inclusive, ``upper`` is also inclusive (!). + + The value will be quantized, i.e. rounded to an integer increment of ``q``. + Quantization makes the upper bound inclusive. + + """ + return Integer(lower, upper).loguniform(base).quantized(q) + + def randn(mean: float = 0., sd: float = 1.): """Sample a float value normally with ``mean`` and ``sd``. diff --git a/python/ray/tune/suggest/bohb.py b/python/ray/tune/suggest/bohb.py index b173d5e85..21de8fe14 100644 --- a/python/ray/tune/suggest/bohb.py +++ b/python/ray/tune/suggest/bohb.py @@ -281,7 +281,15 @@ class TuneBOHB(Searcher): log=False) elif isinstance(domain, Integer): - if isinstance(sampler, Uniform): + if isinstance(sampler, LogUniform): + lower = domain.lower + upper = domain.upper + if quantize: + lower = math.ceil(domain.lower / quantize) * quantize + upper = math.floor(domain.upper / quantize) * quantize + return ConfigSpace.UniformIntegerHyperparameter( + par, lower=lower, upper=upper, q=quantize, log=True) + elif isinstance(sampler, Uniform): lower = domain.lower upper = domain.upper if quantize: diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index aee5fd82d..a5b81f2ed 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -400,8 +400,9 @@ class HyperOptSearch(Searcher): if isinstance(domain, Float): if isinstance(sampler, LogUniform): if quantize: - return hpo.hp.qloguniform(par, domain.lower, - domain.upper, quantize) + return hpo.hp.qloguniform(par, np.log(domain.lower), + np.log(domain.upper), + quantize) return hpo.hp.loguniform(par, np.log(domain.lower), np.log(domain.upper)) elif isinstance(sampler, Uniform): @@ -416,12 +417,21 @@ class HyperOptSearch(Searcher): return hpo.hp.normal(par, sampler.mean, sampler.sd) elif isinstance(domain, Integer): - if isinstance(sampler, Uniform): + if isinstance(sampler, LogUniform): if quantize: - logger.warning( - "HyperOpt does not support quantization for " - "integer values. Reverting back to 'randint'.") - return hpo.hp.randint(par, domain.lower, high=domain.upper) + return hpo.base.pyll.scope.int( + hpo.hp.qloguniform(par, np.log(domain.lower), + np.log(domain.upper), quantize)) + return hpo.base.pyll.scope.int( + hpo.hp.qloguniform(par, np.log(domain.lower), + np.log(domain.upper), 1.0)) + elif isinstance(sampler, Uniform): + if quantize: + return hpo.base.pyll.scope.int( + hpo.hp.quniform(par, domain.lower, domain.upper, + quantize)) + return hpo.hp.uniformint( + par, domain.lower, high=domain.upper) elif isinstance(domain, Categorical): if isinstance(sampler, Uniform): return hpo.hp.choice(par, [ diff --git a/python/ray/tune/suggest/nevergrad.py b/python/ray/tune/suggest/nevergrad.py index f5da80b00..8df7269ae 100644 --- a/python/ray/tune/suggest/nevergrad.py +++ b/python/ray/tune/suggest/nevergrad.py @@ -310,16 +310,23 @@ class NevergradSearch(Searcher): exponent=sampler.base) return ng.p.Scalar(lower=domain.lower, upper=domain.upper) - if isinstance(domain, Integer): + elif isinstance(domain, Integer): + if isinstance(sampler, LogUniform): + return ng.p.Log( + lower=domain.lower, + upper=domain.upper, + exponent=sampler.base).set_integer_casting() return ng.p.Scalar( lower=domain.lower, upper=domain.upper).set_integer_casting() - if isinstance(domain, Categorical): + elif isinstance(domain, Categorical): return ng.p.Choice(choices=domain.categories) - raise ValueError("SkOpt does not support parameters of type " - "`{}`".format(type(domain).__name__)) + raise ValueError("Nevergrad does not support parameters of type " + "`{}` with samplers of type `{}`".format( + type(domain).__name__, + type(domain.sampler).__name__)) # Parameter name is e.g. "a/b/c" for nested dicts space = { diff --git a/python/ray/tune/suggest/skopt.py b/python/ray/tune/suggest/skopt.py index 574be4f35..7c4f337af 100644 --- a/python/ray/tune/suggest/skopt.py +++ b/python/ray/tune/suggest/skopt.py @@ -4,7 +4,8 @@ import pickle from typing import Dict, List, Optional, Tuple, Union from ray.tune.result import DEFAULT_METRIC -from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized +from ray.tune.sample import Categorical, Domain, Float, Integer, Quantized, \ + LogUniform from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE, \ UNDEFINED_METRIC_MODE, UNDEFINED_SEARCH_SPACE from ray.tune.suggest.variant_generator import parse_spec_vars @@ -334,24 +335,26 @@ class SkOptSearch(Searcher): sampler = sampler.get_sampler() if isinstance(domain, Float): - if domain.sampler is not None: - logger.warning( - "SkOpt does not support specific sampling methods." - " The {} sampler will be dropped.".format(sampler)) - return domain.lower, domain.upper + if isinstance(domain.sampler, LogUniform): + return sko.space.Real( + domain.lower, domain.upper, prior="log-uniform") + return sko.space.Real( + domain.lower, domain.upper, prior="uniform") - if isinstance(domain, Integer): - if domain.sampler is not None: - logger.warning( - "SkOpt does not support specific sampling methods." - " The {} sampler will be dropped.".format(sampler)) - return domain.lower, domain.upper + elif isinstance(domain, Integer): + if isinstance(domain.sampler, LogUniform): + return sko.space.Integer( + domain.lower, domain.upper, prior="log-uniform") + return sko.space.Integer( + domain.lower, domain.upper, prior="uniform") - if isinstance(domain, Categorical): + elif isinstance(domain, Categorical): return domain.categories raise ValueError("SkOpt does not support parameters of type " - "`{}`".format(type(domain).__name__)) + "`{}` with samplers of type `{}`".format( + type(domain).__name__, + type(domain.sampler).__name__)) # Parameter name is e.g. "a/b/c" for nested dicts space = { diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index 8a06be5d0..4c1d1a1cb 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -26,7 +26,9 @@ class SearchSpaceTest(unittest.TestCase): "qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-5), "choice": tune.choice([2, 3, 4]), "randint": tune.randint(-9, 15), + "lograndint": tune.lograndint(1, 10), "qrandint": tune.qrandint(-21, 12, 3), + "qlograndint": tune.qlograndint(2, 20, 2), "randn": tune.randn(10, 2), "qrandn": tune.qrandn(10, 2, 0.2), } @@ -58,10 +60,21 @@ class SearchSpaceTest(unittest.TestCase): self.assertGreaterEqual(out["randint"], -9) self.assertLess(out["randint"], 15) + self.assertTrue(isinstance(out["randint"], int)) + + self.assertGreaterEqual(out["lograndint"], 1) + self.assertLess(out["lograndint"], 10) + self.assertTrue(isinstance(out["lograndint"], int)) self.assertGreaterEqual(out["qrandint"], -21) self.assertLessEqual(out["qrandint"], 12) self.assertEqual(out["qrandint"] % 3, 0) + self.assertTrue(isinstance(out["qrandint"], int)) + + self.assertGreaterEqual(out["qlograndint"], 2) + self.assertLessEqual(out["qlograndint"], 20) + self.assertEqual(out["qlograndint"] % 2, 0) + self.assertTrue(isinstance(out["qlograndint"], int)) # Very improbable self.assertGreater(out["randn"], 0) @@ -417,7 +430,7 @@ class SearchSpaceTest(unittest.TestCase): config = { "a": tune.sample.Categorical([2, 3, 4]).uniform(), "b": { - "x": tune.sample.Integer(-15, -10).quantized(2), + "x": tune.sample.Integer(-15, -10), "y": 4, "z": tune.sample.Float(1e-4, 1e-2).loguniform() } @@ -426,7 +439,7 @@ class SearchSpaceTest(unittest.TestCase): hyperopt_config = { "a": hp.choice("a", [2, 3, 4]), "b": { - "x": hp.randint("x", -15, -10), + "x": hp.uniformint("x", -15, -10), "y": 4, "z": hp.loguniform("z", np.log(1e-4), np.log(1e-2)) } @@ -625,17 +638,22 @@ class SearchSpaceTest(unittest.TestCase): def testConvertSkOpt(self): from ray.tune.suggest.skopt import SkOptSearch + from skopt.space import Real, Integer config = { "a": tune.sample.Categorical([2, 3, 4]).uniform(), "b": { - "x": tune.sample.Integer(0, 5).quantized(2), + "x": tune.sample.Integer(0, 5), "y": 4, "z": tune.sample.Float(1e-4, 1e-2).loguniform() } } converted_config = SkOptSearch.convert_search_space(config) - skopt_config = {"a": [2, 3, 4], "b/x": (0, 5), "b/z": (1e-4, 1e-2)} + skopt_config = { + "a": [2, 3, 4], + "b/x": Integer(0, 5), + "b/z": Real(1e-4, 1e-2, prior="log-uniform") + } searcher1 = SkOptSearch(space=converted_config, metric="a", mode="max") searcher2 = SkOptSearch(space=skopt_config, metric="a", mode="max") From d95c8b8a418ef35154ffedecbad1812fcc171db9 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 23 Dec 2020 09:33:43 -0800 Subject: [PATCH 78/88] [core][new scheduler] Move tasks from ready to dispatch to waiting on argument eviction (#13048) * Add index for tasks to dispatch * Task dependency manager interface * Unsubscribe dependencies and tests * NodeManager * Revert "Add index for tasks to dispatch" This reverts commit c6ccb9aa306e00f80d34b991055e4e83872595ea. * tmp * Move back to waiting if args not ready * update --- src/ray/raylet/node_manager.cc | 89 ++++++------- .../raylet/scheduling/cluster_task_manager.cc | 29 ++++- .../raylet/scheduling/cluster_task_manager.h | 22 +++- .../scheduling/cluster_task_manager_test.cc | 121 +++++++++++++++--- src/ray/raylet/task_dependency_manager.cc | 6 + src/ray/raylet/task_dependency_manager.h | 22 +++- 6 files changed, 218 insertions(+), 71 deletions(-) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c8dcae7f9..7f207c4fb 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -205,16 +205,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self std::shared_ptr(new ClusterResourceScheduler( self_node_id_.Binary(), local_resources.GetTotalResources().GetResourceMap())); - std::function fulfills_dependencies_func = - [this](const Task &task) { - bool args_ready = task_dependency_manager_.SubscribeGetDependencies( - task.GetTaskSpecification().TaskId(), task.GetDependencies()); - if (args_ready) { - task_dependency_manager_.UnsubscribeGetDependencies( - task.GetTaskSpecification().TaskId()); - } - return args_ready; - }; auto get_node_info_func = [this](const NodeID &node_id) { return gcs_client_->Nodes().Get(node_id); @@ -228,8 +218,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self PublishInfeasibleTaskError(task); }; cluster_task_manager_ = std::shared_ptr(new ClusterTaskManager( - self_node_id_, new_resource_scheduler_, fulfills_dependencies_func, - is_owner_alive, get_node_info_func, announce_infeasible_task)); + self_node_id_, new_resource_scheduler_, task_dependency_manager_, is_owner_alive, + get_node_info_func, announce_infeasible_task)); placement_group_resource_manager_ = std::make_shared(new_resource_scheduler_); } else { @@ -2644,44 +2634,49 @@ void NodeManager::HandleObjectMissing(const ObjectID &object_id) { } RAY_LOG(DEBUG) << result.str(); - // Transition any tasks that were in the runnable state and are dependent on - // this object to the waiting state. - if (!waiting_task_ids.empty()) { - std::unordered_set waiting_task_id_set(waiting_task_ids.begin(), - waiting_task_ids.end()); + // We don't need to do anything if the new scheduler is enabled because tasks + // will get moved back to waiting once they reach the front of the dispatch + // queue. + if (!new_scheduler_enabled_) { + // Transition any tasks that were in the runnable state and are dependent on + // this object to the waiting state. + if (!waiting_task_ids.empty()) { + std::unordered_set waiting_task_id_set(waiting_task_ids.begin(), + waiting_task_ids.end()); - // NOTE(zhijunfu): For direct actors, the worker is initially assigned actor - // creation task ID, which will not be reset after the task finishes. And later tasks - // of this actor will reuse this task ID to require objects from plasma with - // FetchOrReconstruct, since direct actor task IDs are not known to raylet. - // To support actor reconstruction for direct actor, raylet marks actor creation task - // as completed and removes it from `local_queues_` when it receives `TaskDone` - // message from worker. This is necessary because the actor creation task will be - // re-submitted during reconstruction, if the task is not removed previously, the new - // submitted task will be marked as duplicate and thus ignored. - // So here we check for direct actor creation task explicitly to allow this case. - auto iter = waiting_task_id_set.begin(); - while (iter != waiting_task_id_set.end()) { - if (IsActorCreationTask(*iter)) { - RAY_LOG(DEBUG) << "Ignoring direct actor creation task " << *iter - << " when handling object missing for " << object_id; - iter = waiting_task_id_set.erase(iter); - } else { - ++iter; + // NOTE(zhijunfu): For direct actors, the worker is initially assigned actor + // creation task ID, which will not be reset after the task finishes. And later + // tasks of this actor will reuse this task ID to require objects from plasma with + // FetchOrReconstruct, since direct actor task IDs are not known to raylet. + // To support actor reconstruction for direct actor, raylet marks actor creation + // task as completed and removes it from `local_queues_` when it receives `TaskDone` + // message from worker. This is necessary because the actor creation task will be + // re-submitted during reconstruction, if the task is not removed previously, the + // new submitted task will be marked as duplicate and thus ignored. So here we check + // for direct actor creation task explicitly to allow this case. + auto iter = waiting_task_id_set.begin(); + while (iter != waiting_task_id_set.end()) { + if (IsActorCreationTask(*iter)) { + RAY_LOG(DEBUG) << "Ignoring direct actor creation task " << *iter + << " when handling object missing for " << object_id; + iter = waiting_task_id_set.erase(iter); + } else { + ++iter; + } } - } - // First filter out any tasks that can't be transitioned to READY. These - // are running workers or drivers, now blocked in a get. - local_queues_.FilterState(waiting_task_id_set, TaskState::RUNNING); - local_queues_.FilterState(waiting_task_id_set, TaskState::DRIVER); - // Transition the tasks back to the waiting state. They will be made - // runnable once the deleted object becomes available again. - local_queues_.MoveTasks(waiting_task_id_set, TaskState::READY, TaskState::WAITING); - RAY_CHECK(waiting_task_id_set.empty()); - // Moving ready tasks to waiting may have changed the load, making space for placing - // new tasks locally. - ScheduleTasks(cluster_resource_map_); + // First filter out any tasks that can't be transitioned to READY. These + // are running workers or drivers, now blocked in a get. + local_queues_.FilterState(waiting_task_id_set, TaskState::RUNNING); + local_queues_.FilterState(waiting_task_id_set, TaskState::DRIVER); + // Transition the tasks back to the waiting state. They will be made + // runnable once the deleted object becomes available again. + local_queues_.MoveTasks(waiting_task_id_set, TaskState::READY, TaskState::WAITING); + RAY_CHECK(waiting_task_id_set.empty()); + // Moving ready tasks to waiting may have changed the load, making space for placing + // new tasks locally. + ScheduleTasks(cluster_resource_map_); + } } } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 12715430e..ab3b46ea7 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -14,13 +14,13 @@ const int kMaxPendingActorsToReport = 20; ClusterTaskManager::ClusterTaskManager( const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, - std::function fulfills_dependencies_func, + TaskDependencyManagerInterface &task_dependency_manager, std::function is_owner_alive, NodeInfoGetter get_node_info, std::function announce_infeasible_task) : self_node_id_(self_node_id), cluster_resource_scheduler_(cluster_resource_scheduler), - fulfills_dependencies_func_(fulfills_dependencies_func), + task_dependency_manager_(task_dependency_manager), is_owner_alive_(is_owner_alive), get_node_info_(get_node_info), announce_infeasible_task_(announce_infeasible_task), @@ -102,7 +102,8 @@ bool ClusterTaskManager::WaitForTaskArgsRequests(Work work) { auto object_ids = task.GetTaskSpecification().GetDependencies(); bool can_dispatch = true; if (object_ids.size() > 0) { - bool args_ready = fulfills_dependencies_func_(task); + bool args_ready = task_dependency_manager_.SubscribeGetDependencies( + task.GetTaskSpecification().TaskId(), task.GetDependencies()); if (args_ready) { RAY_LOG(DEBUG) << "Args already ready, task can be dispatched " << task.GetTaskSpecification().TaskId(); @@ -138,6 +139,16 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( auto &task = std::get<0>(work); auto &spec = task.GetTaskSpecification(); + // An argument was evicted since this task was added to the dispatch + // queue. Move it back to the waiting queue. The caller is responsible + // for notifying us when the task is unblocked again. + if (!spec.GetDependencies().empty() && + !task_dependency_manager_.IsTaskReady(spec.TaskId())) { + waiting_tasks_[spec.TaskId()] = std::move(*work_it); + work_it = dispatch_queue.erase(work_it); + continue; + } + std::shared_ptr worker = worker_pool.PopWorker(spec); if (!worker) { // No worker available, we won't be able to schedule any kind of task. @@ -152,6 +163,9 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( RAY_LOG(WARNING) << "Task: " << task.GetTaskSpecification().TaskId() << "'s caller is no longer running. Cancelling task."; worker_pool.PushWorker(worker); + if (!spec.GetDependencies().empty()) { + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + } work_it = dispatch_queue.erase(work_it); } else { bool worker_leased; @@ -164,6 +178,9 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( worker_pool.PushWorker(worker); } if (remove) { + if (!spec.GetDependencies().empty()) { + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + } work_it = dispatch_queue.erase(work_it); } else { break; @@ -295,6 +312,9 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { if (task.GetTaskSpecification().TaskId() == task_id) { RemoveFromBacklogTracker(task); ReplyCancelled(*work_it); + if (!task.GetTaskSpecification().GetDependencies().empty()) { + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(task_id)); + } work_queue.erase(work_it); if (work_queue.empty()) { tasks_to_dispatch_.erase(shapes_it); @@ -326,6 +346,9 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { const auto &task = std::get<0>(iter->second); RemoveFromBacklogTracker(task); ReplyCancelled(iter->second); + if (!task.GetTaskSpecification().GetDependencies().empty()) { + task_dependency_manager_.UnsubscribeGetDependencies(task_id); + } waiting_tasks_.erase(iter); return true; } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 3e3ff2e44..61cfce031 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -5,6 +5,7 @@ #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" #include "ray/raylet/scheduling/cluster_resource_scheduler.h" +#include "ray/raylet/task_dependency_manager.h" #include "ray/raylet/worker.h" #include "ray/raylet/worker_pool.h" #include "ray/rpc/grpc_client.h" @@ -46,14 +47,13 @@ class ClusterTaskManager { /// \param self_node_id: ID of local node. /// \param cluster_resource_scheduler: The resource scheduler which contains /// the state of the cluster. - /// \param fulfills_dependencies_func: Returns true if all of a task's - /// dependencies are fulfilled. + /// \param task_dependency_manager_ Used to fetch task's dependencies. /// \param is_owner_alive: A callback which returns if the owner process is alive /// (according to our ownership model). /// \param gcs_client: A gcs client. ClusterTaskManager(const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, - std::function fulfills_dependencies_func, + TaskDependencyManagerInterface &task_dependency_manager_, std::function is_owner_alive, NodeInfoGetter get_node_info, std::function announce_infeasible_task); @@ -145,8 +145,8 @@ class ClusterTaskManager { const NodeID &self_node_id_; std::shared_ptr cluster_resource_scheduler_; - /// Function to make task dependencies to be local. - std::function fulfills_dependencies_func_; + /// Class to make task dependencies to be local. + TaskDependencyManagerInterface &task_dependency_manager_; /// Function to check if the owner is alive on a given node. std::function is_owner_alive_; /// Function to get the node information of a given node id. @@ -163,10 +163,20 @@ class ClusterTaskManager { /// Queue of lease requests that should be scheduled onto workers. /// Tasks move from scheduled | waiting -> dispatch. + /// Tasks can also move from dispatch -> waiting if one of their arguments is + /// evicted. + /// All tasks in this map that have dependencies should be registered with + /// the dependency manager, in case a dependency gets evicted while the task + /// is still queued. std::unordered_map> tasks_to_dispatch_; /// Tasks waiting for arguments to be transferred locally. /// Tasks move from waiting -> dispatch. + /// Tasks can also move from dispatch -> waiting if one of their arguments is + /// evicted. + /// All tasks in this map that have dependencies should be registered with + /// the dependency manager, so that they can be moved to dispatch once their + /// dependencies are local. absl::flat_hash_map waiting_tasks_; /// Queue of lease requests that are infeasible. @@ -192,6 +202,8 @@ class ClusterTaskManager { void AddToBacklogTracker(const Task &task); void RemoveFromBacklogTracker(const Task &task); + + friend class ClusterTaskManagerTest; }; } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 24018dbc8..3f33cbf06 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -39,6 +39,8 @@ namespace ray { namespace raylet { +using ::testing::_; + class MockWorkerPool : public WorkerPoolInterface { public: std::shared_ptr PopWorker(const TaskSpecification &task_spec) { @@ -92,21 +94,34 @@ Task CreateTask(const std::unordered_map &required_resource return Task(spec_builder.Build(), TaskExecutionSpecification(execution_spec_message)); } +class MockTaskDependencyManager : public TaskDependencyManagerInterface { + public: + bool SubscribeGetDependencies( + const TaskID &task_id, const std::vector &required_objects) { + RAY_CHECK(subscribed_tasks.insert(task_id).second); + return task_ready_; + } + + bool UnsubscribeGetDependencies(const TaskID &task_id) { + return subscribed_tasks.erase(task_id); + } + + bool IsTaskReady(const TaskID &task_id) const { return task_ready_; } + + bool task_ready_ = true; + + std::unordered_set subscribed_tasks; +}; + class ClusterTaskManagerTest : public ::testing::Test { public: ClusterTaskManagerTest() : id_(NodeID::FromRandom()), scheduler_(CreateSingleNodeScheduler(id_.Binary())), - fulfills_dependencies_calls_(0), - dependencies_fulfilled_(true), is_owner_alive_(true), node_info_calls_(0), announce_infeasible_task_calls_(0), - task_manager_(id_, scheduler_, - [this](const Task &_task) { - fulfills_dependencies_calls_++; - return dependencies_fulfilled_; - }, + task_manager_(id_, scheduler_, dependency_manager_, [this](const WorkerID &worker_id, const NodeID &node_id) { return is_owner_alive_; }, @@ -132,20 +147,26 @@ class ClusterTaskManagerTest : public ::testing::Test { node_info_[id] = info; } + void AssertNoLeaks() { + ASSERT_TRUE(task_manager_.tasks_to_schedule_.empty()); + ASSERT_TRUE(task_manager_.tasks_to_dispatch_.empty()); + ASSERT_TRUE(task_manager_.waiting_tasks_.empty()); + ASSERT_TRUE(task_manager_.infeasible_tasks_.empty()); + ASSERT_TRUE(dependency_manager_.subscribed_tasks.empty()); + } + NodeID id_; std::shared_ptr scheduler_; MockWorkerPool pool_; std::unordered_map> leased_workers_; - int fulfills_dependencies_calls_; - bool dependencies_fulfilled_; - bool is_owner_alive_; int node_info_calls_; int announce_infeasible_task_calls_; std::unordered_map> node_info_; + MockTaskDependencyManager dependency_manager_; ClusterTaskManager task_manager_; }; @@ -178,8 +199,9 @@ TEST_F(ClusterTaskManagerTest, BasicTest) { ASSERT_TRUE(callback_occurred); ASSERT_EQ(leased_workers_.size(), 1); ASSERT_EQ(pool_.workers.size(), 0); - ASSERT_EQ(fulfills_dependencies_calls_, 0); ASSERT_EQ(node_info_calls_, 0); + + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, NoFeasibleNodeTest) { @@ -202,7 +224,6 @@ TEST_F(ClusterTaskManagerTest, NoFeasibleNodeTest) { ASSERT_EQ(leased_workers_.size(), 0); // Worker is unused. ASSERT_EQ(pool_.workers.size(), 1); - ASSERT_EQ(fulfills_dependencies_calls_, 0); ASSERT_EQ(node_info_calls_, 0); } @@ -227,11 +248,14 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { }; /* Blocked on dependencies */ + dependency_manager_.task_ready_ = false; auto task = CreateTask({{ray::kCPU_ResourceLabel, 5}}, 1); - dependencies_fulfilled_ = false; + std::unordered_set expected_subscribed_tasks = { + task.GetTaskSpecification().TaskId()}; task_manager_.QueueTask(task, &reply, callback); task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); ASSERT_EQ(num_callbacks, 0); ASSERT_EQ(leased_workers_.size(), 0); @@ -242,18 +266,20 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { task_manager_.QueueTask(task2, &reply, callback); task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); ASSERT_EQ(num_callbacks, 1); ASSERT_EQ(leased_workers_.size(), 1); ASSERT_EQ(pool_.workers.size(), 1); /* First task is unblocked now, but resources are no longer available */ + dependency_manager_.task_ready_ = true; auto id = task.GetTaskSpecification().TaskId(); std::vector unblocked = {id}; - dependencies_fulfilled_ = true; task_manager_.TasksUnblocked(unblocked); task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); ASSERT_EQ(num_callbacks, 1); ASSERT_EQ(leased_workers_.size(), 1); @@ -265,11 +291,13 @@ TEST_F(ClusterTaskManagerTest, ResourceTakenWhileResolving) { task_manager_.SchedulePendingTasks(); task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_TRUE(dependency_manager_.subscribed_tasks.empty()); // Task2 is now done so task can run. ASSERT_EQ(num_callbacks, 2); ASSERT_EQ(leased_workers_.size(), 1); ASSERT_EQ(pool_.workers.size(), 0); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TestSpillAfterAssigned) { @@ -319,6 +347,7 @@ TEST_F(ClusterTaskManagerTest, TestSpillAfterAssigned) { // The second task was spilled. ASSERT_EQ(spillback_reply.retry_at_raylet_address().raylet_id(), remote_node_id.Binary()); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { @@ -375,6 +404,7 @@ TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { ASSERT_FALSE(callback_called); ASSERT_EQ(pool_.workers.size(), 0); ASSERT_EQ(leased_workers_.size(), 1); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TaskCancelInfeasibleTask) { @@ -412,6 +442,7 @@ TEST_F(ClusterTaskManagerTest, TaskCancelInfeasibleTask) { ASSERT_TRUE(reply.canceled()); ASSERT_EQ(leased_workers_.size(), 0); ASSERT_EQ(pool_.workers.size(), 1); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, HeartbeatTest) { @@ -552,7 +583,6 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { ASSERT_FALSE(callback_occurred); ASSERT_EQ(leased_workers_.size(), 0); ASSERT_EQ(pool_.workers.size(), 1); - ASSERT_EQ(fulfills_dependencies_calls_, 0); ASSERT_EQ(node_info_calls_, 0); auto data = std::make_shared(); @@ -578,6 +608,7 @@ TEST_F(ClusterTaskManagerTest, BacklogReportTest) { ASSERT_EQ(shape1.backlog_size(), 0); ASSERT_EQ(shape1.num_infeasible_requests_queued(), 0); ASSERT_EQ(shape1.num_ready_requests_queued(), 0); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, OwnerDeadTest) { @@ -611,6 +642,7 @@ TEST_F(ClusterTaskManagerTest, OwnerDeadTest) { ASSERT_FALSE(callback_occurred); ASSERT_EQ(leased_workers_.size(), 0); ASSERT_EQ(pool_.workers.size(), 1); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TestInfeasibleTaskWarning) { @@ -653,6 +685,7 @@ TEST_F(ClusterTaskManagerTest, TestInfeasibleTaskWarning) { ASSERT_EQ(pool_.workers.size(), 1); // Make sure the spillback callback is called. ASSERT_EQ(reply.retry_at_raylet_address().raylet_id(), remote_node_id.Binary()); + AssertNoLeaks(); } TEST_F(ClusterTaskManagerTest, TestMultipleInfeasibleTasksWarnOnce) { @@ -719,6 +752,64 @@ TEST_F(ClusterTaskManagerTest, TestAnyPendingTasks) { &pending_actor_creations, &pending_tasks)); } +TEST_F(ClusterTaskManagerTest, ArgumentEvicted) { + /* + Test the task's dependencies becoming local, then one of the arguments is + evicted. The task should go from waiting -> dispatch -> waiting. + */ + std::shared_ptr worker = + std::make_shared(WorkerID::FromRandom(), 1234); + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + + rpc::RequestWorkerLeaseReply reply; + int num_callbacks = 0; + int *num_callbacks_ptr = &num_callbacks; + auto callback = [num_callbacks_ptr]() { + (*num_callbacks_ptr) = *num_callbacks_ptr + 1; + }; + + /* Blocked on dependencies */ + dependency_manager_.task_ready_ = false; + auto task = CreateTask({{ray::kCPU_ResourceLabel, 5}}, 2); + std::unordered_set expected_subscribed_tasks = { + task.GetTaskSpecification().TaskId()}; + task_manager_.QueueTask(task, &reply, callback); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); + ASSERT_EQ(num_callbacks, 0); + ASSERT_EQ(leased_workers_.size(), 0); + + /* Task is unblocked now */ + dependency_manager_.task_ready_ = true; + pool_.workers.clear(); + auto id = task.GetTaskSpecification().TaskId(); + task_manager_.TasksUnblocked({id}); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); + ASSERT_EQ(num_callbacks, 0); + ASSERT_EQ(leased_workers_.size(), 0); + + /* Task argument gets evicted */ + dependency_manager_.task_ready_ = false; + pool_.PushWorker(std::dynamic_pointer_cast(worker)); + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(dependency_manager_.subscribed_tasks, expected_subscribed_tasks); + ASSERT_EQ(num_callbacks, 0); + ASSERT_EQ(leased_workers_.size(), 0); + + /* Worker available and arguments available */ + task_manager_.TasksUnblocked({id}); + dependency_manager_.task_ready_ = true; + task_manager_.SchedulePendingTasks(); + task_manager_.DispatchScheduledTasksToWorkers(pool_, leased_workers_); + ASSERT_EQ(num_callbacks, 1); + ASSERT_EQ(leased_workers_.size(), 1); + AssertNoLeaks(); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index f2b0ab959..74c3d8c7a 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -211,6 +211,12 @@ bool TaskDependencyManager::SubscribeGetDependencies( return (task_entry.num_missing_get_dependencies == 0); } +bool TaskDependencyManager::IsTaskReady(const TaskID &task_id) const { + auto task_entry = task_dependencies_.find(task_id); + RAY_CHECK(task_entry != task_dependencies_.end()); + return task_entry->second.num_missing_get_dependencies == 0; +} + void TaskDependencyManager::SubscribeWaitDependencies( const WorkerID &worker_id, const std::vector &required_objects) { diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 75654698f..eb2c53ee9 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -29,6 +29,18 @@ using rpc::TaskLeaseData; class ReconstructionPolicy; +/// Used for unit-testing the ClusterTaskManager, which calls these methods for +/// locally queued tasks that have dependencies. +class TaskDependencyManagerInterface { + public: + virtual bool SubscribeGetDependencies( + const TaskID &task_id, + const std::vector &required_objects) = 0; + virtual bool IsTaskReady(const TaskID &task_id) const = 0; + virtual bool UnsubscribeGetDependencies(const TaskID &task_id) = 0; + virtual ~TaskDependencyManagerInterface() {} +}; + /// \class TaskDependencyManager /// /// Responsible for managing object dependencies for tasks. The caller can @@ -39,7 +51,7 @@ class ReconstructionPolicy; /// made available locally, either by object transfer from a remote node or /// reconstruction. The task manager will also cancel these objects if they are /// no longer needed by any task. -class TaskDependencyManager { +class TaskDependencyManager : public TaskDependencyManagerInterface { public: /// Create a task dependency manager. TaskDependencyManager(ObjectManagerInterface &object_manager, @@ -70,6 +82,14 @@ class TaskDependencyManager { bool SubscribeGetDependencies( const TaskID &task_id, const std::vector &required_objects); + /// Check whether a task is ready to run. The task ID must + /// have been previously subscribed by the caller. + /// + /// \param task_id The ID of the task to check. + /// \return Whether all of the dependencies for the task are + /// local. + bool IsTaskReady(const TaskID &task_id) const; + /// Subscribe to object depedencies required by the worker. This should be called for /// ray.wait calls during task execution. /// From 8df94e33e0c10d23dde4776e97fbd5f5bd828101 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Wed, 23 Dec 2020 12:02:55 -0800 Subject: [PATCH 79/88] [Autoscaler] New output log format (#12772) --- dashboard/modules/reporter/reporter_head.py | 12 +- dashboard/tests/test_dashboard.py | 9 +- python/ray/autoscaler/_private/autoscaler.py | 168 ++++--- python/ray/autoscaler/_private/commands.py | 12 + .../autoscaler/_private/legacy_info_string.py | 35 ++ .../ray/autoscaler/_private/load_metrics.py | 92 +++- .../_private/resource_demand_scheduler.py | 22 +- python/ray/autoscaler/_private/util.py | 175 ++++++- python/ray/monitor.py | 34 +- python/ray/tests/test_autoscaler.py | 60 ++- .../tests/test_resource_demand_scheduler.py | 442 ++++++++++++++++-- 11 files changed, 907 insertions(+), 154 deletions(-) create mode 100644 python/ray/autoscaler/_private/legacy_info_string.py diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index 2fdd001d4..8faef274d 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -13,6 +13,7 @@ import ray.new_dashboard.utils as dashboard_utils import ray._private.services import ray.utils from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS, + DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR) from ray.core.generated import reporter_pb2 from ray.core.generated import reporter_pb2_grpc @@ -113,13 +114,20 @@ class ReportHead(dashboard_utils.DashboardHeadModule): """ aioredis_client = self._dashboard_head.aioredis_client - status = await aioredis_client.hget(DEBUG_AUTOSCALING_STATUS, "value") + legacy_status = await aioredis_client.hget( + DEBUG_AUTOSCALING_STATUS_LEGACY, "value") + formatted_status_string = await aioredis_client.hget( + DEBUG_AUTOSCALING_STATUS, "value") + formatted_status = json.loads(formatted_status_string.decode() + ) if formatted_status_string else {} error = await aioredis_client.hget(DEBUG_AUTOSCALING_ERROR, "value") return dashboard_utils.rest_response( success=True, message="Got cluster status.", - autoscaling_status=status.decode() if status else None, + autoscaling_status=legacy_status.decode() + if legacy_status else None, autoscaling_error=error.decode() if error else None, + cluster_status=formatted_status if formatted_status else None, ) async def run(self, server): diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py index 32836c0ae..4bd3e0300 100644 --- a/dashboard/tests/test_dashboard.py +++ b/dashboard/tests/test_dashboard.py @@ -19,7 +19,7 @@ from ray import ray_constants from ray.test_utils import (format_web_url, wait_for_condition, wait_until_server_available, run_string_as_driver, wait_until_succeeded_without_exception) -from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS, +from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS_LEGACY, DEBUG_AUTOSCALING_ERROR) import ray.new_dashboard.consts as dashboard_consts import ray.new_dashboard.utils as dashboard_utils @@ -458,11 +458,14 @@ def test_get_cluster_status(ray_start_with_dashboard): def get_cluster_status(): response = requests.get(f"{webui_url}/api/cluster_status") response.raise_for_status() + print(response.json()) assert response.json()["result"] assert "autoscalingStatus" in response.json()["data"] assert response.json()["data"]["autoscalingStatus"] is None assert "autoscalingError" in response.json()["data"] assert response.json()["data"]["autoscalingError"] is None + assert "clusterStatus" in response.json()["data"] + assert "loadMetricsReport" in response.json()["data"]["clusterStatus"] wait_until_succeeded_without_exception(get_cluster_status, (requests.RequestException, )) @@ -478,7 +481,7 @@ def test_get_cluster_status(ray_start_with_dashboard): port=int(address[1]), password=ray_constants.REDIS_DEFAULT_PASSWORD) - client.hset(DEBUG_AUTOSCALING_STATUS, "value", "hello") + client.hset(DEBUG_AUTOSCALING_STATUS_LEGACY, "value", "hello") client.hset(DEBUG_AUTOSCALING_ERROR, "value", "world") response = requests.get(f"{webui_url}/api/cluster_status") @@ -488,6 +491,8 @@ def test_get_cluster_status(ray_start_with_dashboard): assert response.json()["data"]["autoscalingStatus"] == "hello" assert "autoscalingError" in response.json()["data"] assert response.json()["data"]["autoscalingError"] == "world" + assert "clusterStatus" in response.json()["data"] + assert "loadMetricsReport" in response.json()["data"]["clusterStatus"] def test_immutable_types(): diff --git a/python/ray/autoscaler/_private/autoscaler.py b/python/ray/autoscaler/_private/autoscaler.py index 64167b4cb..56c8fa634 100644 --- a/python/ray/autoscaler/_private/autoscaler.py +++ b/python/ray/autoscaler/_private/autoscaler.py @@ -1,4 +1,4 @@ -from collections import defaultdict, namedtuple +from collections import defaultdict, namedtuple, Counter from typing import Any, Optional, Dict, List from urllib3.exceptions import MaxRetryError import copy @@ -16,8 +16,10 @@ from ray.experimental.internal_kv import _internal_kv_put, \ from ray.autoscaler.tags import ( TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG, TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND, - TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE, NODE_KIND_WORKER, - NODE_KIND_UNMANAGED, NODE_KIND_HEAD) + TAG_RAY_USER_NODE_TYPE, STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH, + STATUS_SYNCING_FILES, STATUS_SETTING_UP, STATUS_UP_TO_DATE, + NODE_KIND_WORKER, NODE_KIND_UNMANAGED, NODE_KIND_HEAD) +from ray.autoscaler._private.legacy_info_string import legacy_log_info_string from ray.autoscaler._private.providers import _get_node_provider from ray.autoscaler._private.updater import NodeUpdaterThread from ray.autoscaler._private.node_launcher import NodeLauncher @@ -25,8 +27,8 @@ from ray.autoscaler._private.resource_demand_scheduler import \ get_bin_pack_residual, ResourceDemandScheduler, NodeType, NodeID, NodeIP, \ ResourceDict from ray.autoscaler._private.util import ConcurrentCounter, validate_config, \ - with_head_node_ip, hash_launch_conf, hash_runtime_conf, add_prefix, \ - DEBUG_AUTOSCALING_STATUS, DEBUG_AUTOSCALING_ERROR + with_head_node_ip, hash_launch_conf, hash_runtime_conf, \ + DEBUG_AUTOSCALING_ERROR, format_info_string from ray.autoscaler._private.constants import \ AUTOSCALER_MAX_NUM_FAILURES, AUTOSCALER_MAX_LAUNCH_BATCH, \ AUTOSCALER_MAX_CONCURRENT_LAUNCHES, AUTOSCALER_UPDATE_INTERVAL_S, \ @@ -41,20 +43,23 @@ UpdateInstructions = namedtuple( "UpdateInstructions", ["node_id", "init_commands", "start_ray_commands", "docker_config"]) +AutoscalerSummary = namedtuple( + "AutoscalerSummary", + ["active_nodes", "pending_nodes", "pending_launches", "failed_nodes"]) + class StandardAutoscaler: """The autoscaling control loop for a Ray cluster. There are two ways to start an autoscaling cluster: manually by running - `ray start --head --autoscaling-config=/path/to/config.yaml` on a - instance that has permission to launch other instances, or you can also use - `ray up /path/to/config.yaml` from your laptop, which will - configure the right AWS/Cloud roles automatically. - - StandardAutoscaler's `update` method is periodically called by `monitor.py` - to add and remove nodes as necessary. Currently, load-based autoscaling is - not implemented, so all this class does is try to maintain a constant - cluster size. + `ray start --head --autoscaling-config=/path/to/config.yaml` on a instance + that has permission to launch other instances, or you can also use `ray up + /path/to/config.yaml` from your laptop, which will configure the right + AWS/Cloud roles automatically. See the documentation for a full definition + of autoscaling behavior: + https://docs.ray.io/en/master/cluster/autoscaling.html + StandardAutoscaler's `update` method is periodically called in + `monitor.py`'s monitoring loop. StandardAutoscaler is also used to bootstrap clusters (by adding workers until the cluster size that can handle the resource demand is met). @@ -120,9 +125,6 @@ class StandardAutoscaler: for local_path in self.config["file_mounts"].values(): assert os.path.exists(local_path) - # List of resource bundles the user is requesting of the cluster. - self.resource_demand_vector = [] - logger.info("StandardAutoscaler: {}".format(self.config)) def update(self): @@ -161,7 +163,6 @@ class StandardAutoscaler: self.provider.internal_ip(node_id) for node_id in self.all_workers() ]) - self.log_info_string(nodes) # Terminate any idle or out of date nodes last_used = self.load_metrics.last_used_time_by_ip @@ -175,7 +176,7 @@ class StandardAutoscaler: sorted_node_ids = self._sort_based_on_last_used(nodes, last_used) # Don't terminate nodes needed by request_resources() nodes_allowed_to_terminate: Dict[NodeID, bool] = {} - if self.resource_demand_vector: + if self.load_metrics.get_resource_requests(): nodes_allowed_to_terminate = self._get_nodes_allowed_to_terminate( sorted_node_ids) @@ -201,7 +202,6 @@ class StandardAutoscaler: if nodes_to_terminate: self.provider.terminate_nodes(nodes_to_terminate) nodes = self.workers() - self.log_info_string(nodes) # Terminate nodes if there are too many nodes_to_terminate = [] @@ -216,8 +216,6 @@ class StandardAutoscaler: self.provider.terminate_nodes(nodes_to_terminate) nodes = self.workers() - self.log_info_string(nodes) - to_launch = self.resource_demand_scheduler.get_nodes_to_launch( self.provider.non_terminated_nodes(tag_filters={}), self.pending_launches.breakdown(), @@ -225,7 +223,7 @@ class StandardAutoscaler: self.load_metrics.get_resource_utilization(), self.load_metrics.get_pending_placement_groups(), self.load_metrics.get_static_node_resources_by_ip(), - ensure_min_cluster_size=self.resource_demand_vector) + ensure_min_cluster_size=self.load_metrics.get_resource_requests()) for node_type, count in to_launch.items(): self.launch_new_node(count, node_type=node_type) @@ -255,7 +253,6 @@ class StandardAutoscaler: self.provider.terminate_nodes(nodes_to_terminate) nodes = self.workers() - self.log_info_string(nodes) # Update nodes with out-of-date files. # TODO(edoakes): Spawning these threads directly seems to cause @@ -281,6 +278,9 @@ class StandardAutoscaler: for node_id in nodes: self.recover_if_needed(node_id, now) + logger.info(self.info_string()) + legacy_log_info_string(self, nodes) + def _sort_based_on_last_used(self, nodes: List[NodeID], last_used: Dict[str, float]) -> List[NodeID]: """Sort the nodes based on the last time they were used. @@ -361,7 +361,7 @@ class StandardAutoscaler: used_resource_requests: List[ResourceDict] _, used_resource_requests = \ get_bin_pack_residual(max_node_resources, - self.resource_demand_vector) + self.load_metrics.get_resource_requests()) # Remove the first entry (the head node). max_node_resources.pop(0) # Remove the first entry (the head node). @@ -533,15 +533,17 @@ class StandardAutoscaler: if not self.can_update(node_id): return key = self.provider.internal_ip(node_id) - if key not in self.load_metrics.last_heartbeat_time_by_ip: - self.load_metrics.last_heartbeat_time_by_ip[key] = now - last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key] - delta = now - last_heartbeat_time - if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S: - return + + if key in self.load_metrics.last_heartbeat_time_by_ip: + last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[ + key] + delta = now - last_heartbeat_time + if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S: + return + logger.warning("StandardAutoscaler: " - "{}: No heartbeat in {}s, " - "restarting Ray to recover...".format(node_id, delta)) + "{}: No recent heartbeat, " + "restarting Ray to recover...".format(node_id)) updater = NodeUpdaterThread( node_id=node_id, provider_config=self.config["provider"], @@ -678,43 +680,6 @@ class StandardAutoscaler: return self.provider.non_terminated_nodes( tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_UNMANAGED}) - def log_info_string(self, nodes): - tmp = "Cluster status: " - tmp += self.info_string(nodes) - tmp += "\n" - tmp += self.load_metrics.info_string() - tmp += "\n" - tmp += self.resource_demand_scheduler.debug_string( - nodes, self.pending_launches.breakdown(), - self.load_metrics.get_resource_utilization()) - if _internal_kv_initialized(): - _internal_kv_put(DEBUG_AUTOSCALING_STATUS, tmp, overwrite=True) - if self.prefix_cluster_info: - tmp = add_prefix(tmp, self.config["cluster_name"]) - logger.debug(tmp) - - def info_string(self, nodes): - suffix = "" - if self.updaters: - suffix += " ({} updating)".format(len(self.updaters)) - if self.num_failed_updates: - suffix += " ({} failed to update)".format( - len(self.num_failed_updates)) - - return "{} nodes{}".format(len(nodes), suffix) - - def request_resources(self, resources: List[dict]): - """Called by monitor to request resources. - - Args: - resources: A list of resource bundles. - """ - if resources: - logger.info( - "StandardAutoscaler: resource_requests={}".format(resources)) - assert isinstance(resources, list), resources - self.resource_demand_vector = resources - def kill_workers(self): logger.error("StandardAutoscaler: kill_workers triggered") nodes = self.workers() @@ -722,3 +687,66 @@ class StandardAutoscaler: self.provider.terminate_nodes(nodes) logger.error("StandardAutoscaler: terminated {} node(s)".format( len(nodes))) + + def summary(self): + """Summarizes the active, pending, and failed node launches. + + An active node is a node whose raylet is actively reporting heartbeats. + A pending node is non-active node whose node tag is uninitialized, + waiting for ssh, syncing files, or setting up. + If a node is not pending or active, it is failed. + + Returns: + AutoscalerSummary: The summary. + """ + all_node_ids = self.provider.non_terminated_nodes(tag_filters={}) + + active_nodes = Counter() + pending_nodes = [] + failed_nodes = [] + + for node_id in all_node_ids: + ip = self.provider.internal_ip(node_id) + node_tags = self.provider.node_tags(node_id) + if node_tags[TAG_RAY_NODE_KIND] == NODE_KIND_UNMANAGED: + continue + node_type = node_tags[TAG_RAY_USER_NODE_TYPE] + + # TODO (Alex): If a node's raylet has died, it shouldn't be marked + # as active. + is_active = self.load_metrics.is_active(ip) + if is_active: + active_nodes[node_type] += 1 + else: + status = node_tags[TAG_RAY_NODE_STATUS] + pending_states = [ + STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH, + STATUS_SYNCING_FILES, STATUS_SETTING_UP + ] + is_pending = status in pending_states + if is_pending: + pending_nodes.append((ip, node_type)) + else: + # TODO (Alex): Failed nodes are now immediately killed, so + # this list will almost always be empty. We should ideally + # keep a cache of recently failed nodes and their startup + # logs. + failed_nodes.append((ip, node_type)) + + # The concurrent counter leaves some 0 counts in, so we need to + # manually filter those out. + pending_launches = {} + for node_type, count in self.pending_launches.breakdown().items(): + if count: + pending_launches[node_type] = count + + return AutoscalerSummary( + active_nodes=active_nodes, + pending_nodes=pending_nodes, + pending_launches=pending_launches, + failed_nodes=failed_nodes) + + def info_string(self): + lm_summary = self.load_metrics.summary() + autoscaler_summary = self.summary() + return "\n" + format_info_string(lm_summary, autoscaler_summary) diff --git a/python/ray/autoscaler/_private/commands.py b/python/ray/autoscaler/_private/commands.py index 0c7e3abbd..247ba0d69 100644 --- a/python/ray/autoscaler/_private/commands.py +++ b/python/ray/autoscaler/_private/commands.py @@ -43,6 +43,10 @@ from ray.worker import global_worker # type: ignore from ray.util.debug import log_once import ray.autoscaler._private.subprocess_output_util as cmd_output_util +from ray.autoscaler._private.load_metrics import LoadMetricsSummary +from ray.autoscaler._private.autoscaler import AutoscalerSummary +from ray.autoscaler._private.util import format_info_string, \ + format_info_string_no_node_types logger = logging.getLogger(__name__) @@ -94,6 +98,14 @@ def debug_status() -> str: status = "No cluster status." else: status = status.decode("utf-8") + as_dict = json.loads(status) + lm_summary = LoadMetricsSummary(**as_dict["load_metrics_report"]) + if "autoscaler_report" in as_dict: + autoscaler_summary = AutoscalerSummary( + **as_dict["autoscaler_report"]) + status = format_info_string(lm_summary, autoscaler_summary) + else: + status = format_info_string_no_node_types(lm_summary) if error: status += "\n" status += error.decode("utf-8") diff --git a/python/ray/autoscaler/_private/legacy_info_string.py b/python/ray/autoscaler/_private/legacy_info_string.py new file mode 100644 index 000000000..99791efd7 --- /dev/null +++ b/python/ray/autoscaler/_private/legacy_info_string.py @@ -0,0 +1,35 @@ +import logging +from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS_LEGACY +from ray.experimental.internal_kv import _internal_kv_put, \ + _internal_kv_initialized +"""This file provides legacy support for the old info string in order to +ensure the dashboard's `api/cluster_status` does not break backwards +compatibilty. +""" + +logger = logging.getLogger(__name__) + + +def legacy_log_info_string(autoscaler, nodes): + tmp = "Cluster status: " + tmp += info_string(autoscaler, nodes) + tmp += "\n" + tmp += autoscaler.load_metrics.info_string() + tmp += "\n" + tmp += autoscaler.resource_demand_scheduler.debug_string( + nodes, autoscaler.pending_launches.breakdown(), + autoscaler.load_metrics.get_resource_utilization()) + if _internal_kv_initialized(): + _internal_kv_put(DEBUG_AUTOSCALING_STATUS_LEGACY, tmp, overwrite=True) + logger.debug(tmp) + + +def info_string(autoscaler, nodes): + suffix = "" + if autoscaler.updaters: + suffix += " ({} updating)".format(len(autoscaler.updaters)) + if autoscaler.num_failed_updates: + suffix += " ({} failed to update)".format( + len(autoscaler.num_failed_updates)) + + return "{} nodes{}".format(len(nodes), suffix) diff --git a/python/ray/autoscaler/_private/load_metrics.py b/python/ray/autoscaler/_private/load_metrics.py index b688fe617..dc3178015 100644 --- a/python/ray/autoscaler/_private/load_metrics.py +++ b/python/ray/autoscaler/_private/load_metrics.py @@ -1,16 +1,26 @@ +from collections import namedtuple +from functools import reduce import logging import time from typing import Dict, List import numpy as np import ray._private.services as services -from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES +from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES,\ + AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE +from ray.autoscaler._private.util import add_resources, freq_of_dicts from ray.gcs_utils import PlacementGroupTableData from ray.autoscaler._private.resource_demand_scheduler import \ NodeIP, ResourceDict +from ray.core.generated.common_pb2 import PlacementStrategy logger = logging.getLogger(__name__) +LoadMetricsSummary = namedtuple("LoadMetricsSummary", [ + "head_ip", "usage", "resource_demand", "pg_demand", "request_demand", + "node_types" +]) + class LoadMetrics: """Container for cluster load metrics. @@ -31,6 +41,7 @@ class LoadMetrics: self.waiting_bundles = [] self.infeasible_bundles = [] self.pending_placement_groups = [] + self.resource_requests = [] def update(self, ip: str, @@ -72,9 +83,12 @@ class LoadMetrics: def mark_active(self, ip): assert ip is not None, "IP should be known at this time" - logger.info("Node {} is newly setup, treating as active".format(ip)) + logger.debug("Node {} is newly setup, treating as active".format(ip)) self.last_heartbeat_time_by_ip[ip] = time.time() + def is_active(self, ip): + return ip in self.last_heartbeat_time_by_ip + def prune_active_ips(self, active_ips): active_ips = set(active_ips) active_ips.add(self.local_ip) @@ -155,12 +169,82 @@ class LoadMetrics: return resources_used, resources_total - def get_resource_demand_vector(self): - return self.waiting_bundles + self.infeasible_bundles + def get_resource_demand_vector(self, clip=True): + if clip: + # Bound the total number of bundles to + # 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE. This guarantees the resource + # demand scheduler bin packing algorithm takes a reasonable amount + # of time to run. + return ( + self. + waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE] + + self. + infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE] + ) + else: + return self.waiting_bundles + self.infeasible_bundles + + def get_resource_requests(self): + return self.resource_requests def get_pending_placement_groups(self): return self.pending_placement_groups + def summary(self): + available_resources = reduce(add_resources, + self.dynamic_resources_by_ip.values() + ) if self.dynamic_resources_by_ip else {} + total_resources = reduce(add_resources, + self.static_resources_by_ip.values() + ) if self.static_resources_by_ip else {} + usage_dict = {} + for key in total_resources: + total = total_resources[key] + usage_dict[key] = (total - available_resources[key], total) + + summarized_demand_vector = freq_of_dicts( + self.get_resource_demand_vector(clip=False)) + summarized_resource_requests = freq_of_dicts( + self.get_resource_requests()) + + def placement_group_serializer(pg): + bundles = tuple( + frozenset(bundle.unit_resources.items()) + for bundle in pg.bundles) + return (bundles, pg.strategy) + + def placement_group_deserializer(pg_tuple): + # We marshal this as a dictionary so that we can easily json.dumps + # it later. + # TODO (Alex): Would there be a benefit to properly + # marshalling this (into a protobuf)? + bundles = list(map(dict, pg_tuple[0])) + return { + "bundles": freq_of_dicts(bundles), + "strategy": PlacementStrategy.Name(pg_tuple[1]) + } + + summarized_placement_groups = freq_of_dicts( + self.get_pending_placement_groups(), + serializer=placement_group_serializer, + deserializer=placement_group_deserializer) + nodes_summary = freq_of_dicts(self.static_resources_by_ip.values()) + + return LoadMetricsSummary( + head_ip=self.local_ip, + usage=usage_dict, + resource_demand=summarized_demand_vector, + pg_demand=summarized_placement_groups, + request_demand=summarized_resource_requests, + node_types=nodes_summary) + + def set_resource_requests(self, requested_resources): + if requested_resources is not None: + assert isinstance(requested_resources, list), requested_resources + self.resource_requests = [ + request for request in requested_resources if len(request) > 0 + ] + def info_string(self): return " - " + "\n - ".join( ["{}: {}".format(k, v) for k, v in sorted(self._info().items())]) diff --git a/python/ray/autoscaler/_private/resource_demand_scheduler.py b/python/ray/autoscaler/_private/resource_demand_scheduler.py index aba8cff2d..d838c6be1 100644 --- a/python/ray/autoscaler/_private/resource_demand_scheduler.py +++ b/python/ray/autoscaler/_private/resource_demand_scheduler.py @@ -149,8 +149,8 @@ class ResourceDemandScheduler: node_resources, node_type_counts = self.calculate_node_resources( nodes, launching_nodes, unused_resources_by_ip) - logger.info("Cluster resources: {}".format(node_resources)) - logger.info("Node counts: {}".format(node_type_counts)) + logger.debug("Cluster resources: {}".format(node_resources)) + logger.debug("Node counts: {}".format(node_type_counts)) # Step 2: add nodes to add to satisfy min_workers for each type (node_resources, node_type_counts, @@ -160,7 +160,7 @@ class ResourceDemandScheduler: self.max_workers, self.head_node_type, ensure_min_cluster_size) # Step 3: add nodes for strict spread groups - logger.info(f"Placement group demands: {pending_placement_groups}") + logger.debug(f"Placement group demands: {pending_placement_groups}") placement_group_demand_vector, strict_spreads = \ placement_groups_to_resource_demands(pending_placement_groups) resource_demands.extend(placement_group_demand_vector) @@ -187,8 +187,8 @@ class ResourceDemandScheduler: # groups unfulfilled, _ = get_bin_pack_residual(node_resources, resource_demands) - logger.info("Resource demands: {}".format(resource_demands)) - logger.info("Unfulfilled demands: {}".format(unfulfilled)) + logger.debug("Resource demands: {}".format(resource_demands)) + logger.debug("Unfulfilled demands: {}".format(unfulfilled)) # Add 1 to account for the head node. max_to_add = self.max_workers + 1 - sum(node_type_counts.values()) nodes_to_add_based_on_demand = get_nodes_for( @@ -211,7 +211,7 @@ class ResourceDemandScheduler: total_nodes_to_add, unused_resources_by_ip.keys(), nodes, launching_nodes, adjusted_min_workers) - logger.info("Node requests: {}".format(total_nodes_to_add)) + logger.debug("Node requests: {}".format(total_nodes_to_add)) return total_nodes_to_add def _legacy_worker_node_to_launch( @@ -615,8 +615,14 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict], # starts up because placement groups are scheduled via custom # resources. This will behave properly with the current utilization # score heuristic, but it's a little dangerous and misleading. - logger.info( - "No feasible node type to add for {}".format(resources)) + logger.warning( + f"The autoscaler could not find a node type to satisfy the" + f"request: {resources}. If this request is related to " + f"placement groups the resource request will resolve itself, " + f"otherwise please specify a node type with the necessary " + f"resource " + f"https://docs.ray.io/en/master/cluster/autoscaling.html#multiple-node-type-autoscaling." # noqa: E501 + ) break utilization_scores = sorted(utilization_scores, reverse=True) diff --git a/python/ray/autoscaler/_private/util.py b/python/ray/autoscaler/_private/util.py index b4066df95..1ab7c2e68 100644 --- a/python/ray/autoscaler/_private/util.py +++ b/python/ray/autoscaler/_private/util.py @@ -1,13 +1,15 @@ import collections +from datetime import datetime import logging import hashlib import json import jsonschema import os import threading -from typing import Any, Dict +from typing import Any, Dict, List import ray +import ray.ray_constants import ray._private.services as services from ray.autoscaler._private.providers import _get_default_config from ray.autoscaler._private.docker import validate_docker_config @@ -20,6 +22,7 @@ RAY_SCHEMA_PATH = os.path.join( # Internal kv keys for storing debug status. DEBUG_AUTOSCALING_ERROR = "__autoscaling_error" DEBUG_AUTOSCALING_STATUS = "__autoscaling_status" +DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy" logger = logging.getLogger(__name__) @@ -246,6 +249,47 @@ def hash_runtime_conf(file_mounts, return (_hash_cache[conf_str], file_mounts_contents_hash) +def add_resources(dict1: Dict[str, float], + dict2: Dict[str, float]) -> Dict[str, float]: + """Add the values in two dictionaries. + + Returns: + dict: A new dictionary (inputs remain unmodified). + """ + new_dict = dict1.copy() + for k, v in dict2.items(): + new_dict[k] = v + new_dict.get(k, 0) + return new_dict + + +def freq_of_dicts(dicts: List[Dict], + serializer=lambda d: frozenset(d.items()), + deserializer=dict): + """Count a list of dictionaries (or unhashable types). + + This is somewhat annoying because mutable data structures aren't hashable, + and set/dict keys must be hashable. + + Args: + dicts (List[D]): A list of dictionaries to be counted. + serializer (D -> S): A custom serailization function. The output type S + must be hashable. The default serializer converts a dictionary into + a frozenset of KV pairs. + deserializer (S -> U): A custom deserialization function. See the + serializer for information about type S. For dictionaries U := D. + + Returns: + List[Tuple[U, int]]: Returns a list of tuples. Each entry in the list + is a tuple containing a unique entry from `dicts` and its + corresponding frequency count. + """ + freqs = collections.Counter(map(lambda d: serializer(d), dicts)) + as_list = [] + for as_set, count in freqs.items(): + as_list.append((deserializer(as_set), count)) + return as_list + + def add_prefix(info_string, prefix): """Prefixes each line of info_string, except the first, by prefix.""" lines = info_string.split("\n") @@ -255,3 +299,132 @@ def add_prefix(info_string, prefix): prefixed_lines.append(prefixed_line) prefixed_info_string = "\n".join(prefixed_lines) return prefixed_info_string + + +def format_pg(pg): + strategy = pg["strategy"] + bundles = pg["bundles"] + shape_strs = [] + for bundle, count in bundles: + shape_strs.append(f"{bundle} * {count}") + bundles_str = ", ".join(shape_strs) + return f"{bundles_str} ({strategy})" + + +def get_usage_report(lm_summary): + usage_lines = [] + for resource, (used, total) in lm_summary.usage.items(): + line = f" {used}/{total} {resource}" + if resource in ["memory", "object_store_memory"]: + to_GiB = ray.ray_constants.MEMORY_RESOURCE_UNIT_BYTES / 2**30 + used *= to_GiB + total *= to_GiB + line = f" {used:.2f}/{total:.3f} GiB {resource}" + usage_lines.append(line) + usage_report = "\n".join(usage_lines) + return usage_report + + +def get_demand_report(lm_summary): + demand_lines = [] + for bundle, count in lm_summary.resource_demand: + line = f" {bundle}: {count}+ pending tasks/actors" + demand_lines.append(line) + for entry in lm_summary.pg_demand: + pg, count = entry + pg_str = format_pg(pg) + line = f" {pg_str}: {count}+ pending placement groups" + demand_lines.append(line) + for bundle, count in lm_summary.request_demand: + line = f" {bundle}: {count}+ from request_resources()" + demand_lines.append(line) + if len(demand_lines) > 0: + demand_report = "\n".join(demand_lines) + else: + demand_report = " (no resource demands)" + return demand_report + + +def format_info_string(lm_summary, autoscaler_summary, time=None): + if time is None: + time = datetime.now() + header = "=" * 8 + f" Autoscaler status: {time} " + "=" * 8 + separator = "-" * len(header) + available_node_report_lines = [] + for node_type, count in autoscaler_summary.active_nodes.items(): + line = f" {count} {node_type}" + available_node_report_lines.append(line) + available_node_report = "\n".join(available_node_report_lines) + + pending_lines = [] + for node_type, count in autoscaler_summary.pending_launches.items(): + line = f" {node_type}, {count} launching" + pending_lines.append(line) + for ip, node_type in autoscaler_summary.pending_nodes: + line = f" {ip}: {node_type}, setting up" + pending_lines.append(line) + if pending_lines: + pending_report = "\n".join(pending_lines) + else: + pending_report = " (no pending nodes)" + + failure_lines = [] + for ip, node_type in autoscaler_summary.failed_nodes: + line = f" {ip}: {node_type}" + failure_report = "Recent failures:\n" + if failure_lines: + failure_report += "\n".join(failure_lines) + else: + failure_report += " (no failures)" + + usage_report = get_usage_report(lm_summary) + demand_report = get_demand_report(lm_summary) + + formatted_output = f"""{header} +Node status +{separator} +Healthy: +{available_node_report} +Pending: +{pending_report} +{failure_report} + +Resources +{separator} + +Usage: +{usage_report} + +Demands: +{demand_report}""" + return formatted_output + + +def format_info_string_no_node_types(lm_summary, time=None): + if time is None: + time = datetime.now() + header = "=" * 8 + f" Cluster status: {time} " + "=" * 8 + separator = "-" * len(header) + + node_lines = [] + for node_type, count in lm_summary.node_types: + line = f" {count} node(s) with resources: {node_type}" + node_lines.append(line) + node_report = "\n".join(node_lines) + + usage_report = get_usage_report(lm_summary) + demand_report = get_demand_report(lm_summary) + + formatted_output = f"""{header} +Node status +{separator} +{node_report} + +Resources +{separator} +Usage: +{usage_report} + +Demands: +{demand_report}""" + return formatted_output diff --git a/python/ray/monitor.py b/python/ray/monitor.py index f650e151a..aa819c7d3 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -15,11 +15,14 @@ from ray.autoscaler._private.constants import AUTOSCALER_UPDATE_INTERVAL_S from ray.autoscaler._private.load_metrics import LoadMetrics from ray.autoscaler._private.constants import \ AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE +from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants from ray.ray_logging import setup_component_logger from ray._raylet import GlobalStateAccessor +from ray.experimental.internal_kv import _internal_kv_put, \ + _internal_kv_initialized import redis @@ -65,11 +68,7 @@ def parse_resource_demands(resource_load_by_shape): except Exception: logger.exception("Failed to parse resource demands.") - # Bound the total number of bundles to 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE. - # This guarantees the resource demand scheduler bin packing algorithm takes - # a reasonable amount of time to run. - return waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE], \ - infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE] + return waiting_bundles, infeasible_bundles class Monitor: @@ -184,14 +183,8 @@ class Monitor: data: a resource request as JSON, e.g. {"CPU": 1} """ - if not self.autoscaler: - return - - try: - self.autoscaler.request_resources(json.loads(data)) - except Exception: - # We don't want this to kill the monitor. - traceback.print_exc() + resource_request = json.loads(data) + self.load_metrics.set_resource_requests(resource_request) def process_messages(self, max_messages=10000): """Process all messages ready in the subscription channels. @@ -257,12 +250,23 @@ class Monitor: # Handle messages from the subscription channels. while True: + self.update_raylet_map() + self.update_load_metrics() + status = { + "load_metrics_report": self.load_metrics.summary()._asdict() + } + # Process autoscaling actions if self.autoscaler: # Only used to update the load metrics for the autoscaler. - self.update_raylet_map() - self.update_load_metrics() self.autoscaler.update() + status[ + "autoscaler_report"] = self.autoscaler.summary()._asdict() + + as_json = json.dumps(status) + if _internal_kv_initialized(): + _internal_kv_put( + DEBUG_AUTOSCALING_STATUS, as_json, overwrite=True) # Process a round of messages. self.process_messages() diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 72f361fe2..628c1b191 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -54,8 +54,11 @@ class MockProcessRunner: self.calls = [] self.fail_cmds = fail_cmds or [] self.call_response = {} + self.ready_to_run = threading.Event() + self.ready_to_run.set() def check_call(self, cmd, *args, **kwargs): + self.ready_to_run.wait() for token in self.fail_cmds: if token in str(cmd): raise CalledProcessError(1, token, @@ -165,22 +168,28 @@ class MockProvider(NodeProvider): ] def is_running(self, node_id): - return self.mock_nodes[node_id].state == "running" + with self.lock: + return self.mock_nodes[node_id].state == "running" def is_terminated(self, node_id): - return self.mock_nodes[node_id].state in ["stopped", "terminated"] + with self.lock: + return self.mock_nodes[node_id].state in ["stopped", "terminated"] def node_tags(self, node_id): - return self.mock_nodes[node_id].tags + with self.lock: + return self.mock_nodes[node_id].tags def internal_ip(self, node_id): - return self.mock_nodes[node_id].internal_ip + with self.lock: + return self.mock_nodes[node_id].internal_ip def external_ip(self, node_id): - return self.mock_nodes[node_id].external_ip + with self.lock: + return self.mock_nodes[node_id].external_ip - def create_node(self, node_config, tags, count): - self.ready_to_create.wait() + def create_node(self, node_config, tags, count, _skip_wait=False): + if not _skip_wait: + self.ready_to_create.wait() if self.fail_creates: return with self.lock: @@ -200,7 +209,8 @@ class MockProvider(NodeProvider): self.next_id += 1 def set_node_tags(self, node_id, tags): - self.mock_nodes[node_id].tags.update(tags) + with self.lock: + self.mock_nodes[node_id].tags.update(tags) def terminate_node(self, node_id): with self.lock: @@ -534,7 +544,11 @@ class AutoscalingTest(unittest.TestCase): config["max_workers"] = 5 config_path = self.write_config(config) self.provider = MockProvider() - self.provider.create_node({}, {TAG_RAY_NODE_KIND: "worker"}, 10) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: "worker", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_WORKER + }, 10) runner = MockProcessRunner() runner.respond_to_call("json .Config.Env", ["[]" for i in range(10)]) autoscaler = StandardAutoscaler( @@ -562,6 +576,7 @@ class AutoscalingTest(unittest.TestCase): lm = LoadMetrics() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD }, 1) lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {}) @@ -658,11 +673,15 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() # 1 head node. self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() # still 1 head node because request_resources fits in the headnode. self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}] + [{"CPU": 2}] * 9) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 1 + }] + [{ + "CPU": 2 + }] * 9) autoscaler.update() self.waitForNodes(2) # Adds a single worker to get its resources. autoscaler.update() @@ -767,7 +786,8 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() self.provider.create_node({}, { TAG_RAY_NODE_KIND: "head", - TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE }, 1) head_ip = self.provider.non_terminated_node_ips( tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0] @@ -817,7 +837,11 @@ class AutoscalingTest(unittest.TestCase): config_path = self.write_config(config) self.provider = MockProvider() - self.provider.create_node({}, {TAG_RAY_NODE_KIND: "head"}, 1) + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: "head", + TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE + }, 1) head_ip = self.provider.non_terminated_node_ips( tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0] @@ -975,6 +999,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD }, 1) lm = LoadMetrics() @@ -1096,6 +1121,14 @@ class AutoscalingTest(unittest.TestCase): assert len(self.provider.non_terminated_nodes({})) < 2 def testConfiguresOutdatedNodes(self): + from ray.autoscaler._private.cli_logger import cli_logger + + def do_nothing(*args, **kwargs): + pass + + cli_logger._print = type(cli_logger._print)(do_nothing, + type(cli_logger)) + config_path = self.write_config(SMALL_CLUSTER) self.provider = MockProvider() runner = MockProcessRunner() @@ -1133,6 +1166,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(6)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD }, 1) lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {}) diff --git a/python/ray/tests/test_resource_demand_scheduler.py b/python/ray/tests/test_resource_demand_scheduler.py index 2093f1e14..a4bfe7393 100644 --- a/python/ray/tests/test_resource_demand_scheduler.py +++ b/python/ray/tests/test_resource_demand_scheduler.py @@ -1,4 +1,5 @@ import pytest +from datetime import datetime import time import yaml import tempfile @@ -8,13 +9,16 @@ import copy import ray from ray.autoscaler._private.util import \ - rewrite_legacy_yaml_to_available_node_types + rewrite_legacy_yaml_to_available_node_types, format_info_string, \ + format_info_string_no_node_types from ray.tests.test_autoscaler import SMALL_CLUSTER, MockProvider, \ MockProcessRunner from ray.autoscaler._private.providers import (_NODE_PROVIDERS, _clear_provider_cache) -from ray.autoscaler._private.autoscaler import StandardAutoscaler -from ray.autoscaler._private.load_metrics import LoadMetrics +from ray.autoscaler._private.autoscaler import StandardAutoscaler, \ + AutoscalerSummary +from ray.autoscaler._private.load_metrics import LoadMetrics, \ + LoadMetricsSummary from ray.autoscaler._private.commands import get_or_create_head_node from ray.autoscaler._private.resource_demand_scheduler import \ _utilization_score, _add_min_workers_nodes, \ @@ -24,6 +28,7 @@ from ray.core.generated.common_pb2 import Bundle, PlacementStrategy from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE, TAG_RAY_NODE_KIND, \ NODE_KIND_WORKER, TAG_RAY_NODE_STATUS, \ STATUS_UP_TO_DATE, STATUS_UNINITIALIZED, \ + STATUS_UPDATE_FAILED, \ NODE_KIND_HEAD, NODE_TYPE_LEGACY_WORKER, \ NODE_TYPE_LEGACY_HEAD from ray.test_utils import same_elements @@ -368,6 +373,7 @@ def test_get_nodes_to_launch_with_min_workers(): provider.create_node({}, { TAG_RAY_USER_NODE_TYPE: "p2.8xlarge", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_NODE_KIND: NODE_KIND_HEAD }, 1) @@ -390,9 +396,13 @@ def test_get_nodes_to_launch_with_min_workers_and_bin_packing(): provider, new_types, 10, head_node_type="p2.8xlarge") provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_USER_NODE_TYPE: "p2.8xlarge" + }, 1) + provider.create_node({}, { + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "p2.8xlarge" }, 1) - provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 1) nodes = provider.non_terminated_nodes({}) @@ -424,7 +434,10 @@ def test_get_nodes_to_launch_limits(): scheduler = ResourceDemandScheduler( provider, TYPES_A, 3, head_node_type="p2.8xlarge") - provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 2) + provider.create_node({}, { + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_USER_NODE_TYPE: "p2.8xlarge" + }, 2) nodes = provider.non_terminated_nodes({}) @@ -442,7 +455,10 @@ def test_calculate_node_resources(): scheduler = ResourceDemandScheduler( provider, TYPES_A, 10, head_node_type="p2.8xlarge") - provider.create_node({}, {TAG_RAY_USER_NODE_TYPE: "p2.8xlarge"}, 2) + provider.create_node({}, { + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_USER_NODE_TYPE: "p2.8xlarge" + }, 2) nodes = provider.non_terminated_nodes({}) @@ -1059,6 +1075,86 @@ class LoadMetricsTest(unittest.TestCase): pending_placement_groups=pending_placement_groups) assert lm.get_pending_placement_groups() == pending_placement_groups + def testSummary(self): + lm = LoadMetrics(local_ip="1.1.1.1") + assert lm.summary() is not None + pending_placement_groups = [ + PlacementGroupTableData( + state=PlacementGroupTableData.RESCHEDULING, + strategy=PlacementStrategy.PACK, + bundles=([Bundle(unit_resources={"GPU": 2})] * 2)), + PlacementGroupTableData( + state=PlacementGroupTableData.RESCHEDULING, + strategy=PlacementStrategy.PACK, + bundles=([Bundle(unit_resources={"GPU": 2})] * 2)), + ] + lm.update("1.1.1.1", {"CPU": 64}, {"CPU": 2}, {}) + lm.update("1.1.1.2", { + "CPU": 64, + "GPU": 8, + "accelerator_type:V100": 1 + }, { + "CPU": 0, + "GPU": 1, + "accelerator_type:V100": 1 + }, {}) + lm.update("1.1.1.3", { + "CPU": 64, + "GPU": 8, + "accelerator_type:V100": 1 + }, { + "CPU": 0, + "GPU": 0, + "accelerator_type:V100": 0.92 + }, {}) + lm.update( + "1.1.1.4", {"CPU": 2}, {"CPU": 2}, {}, + waiting_bundles=[{ + "GPU": 2 + }] * 10, + infeasible_bundles=[{ + "CPU": 16 + }, { + "GPU": 2 + }, { + "CPU": 16, + "GPU": 2 + }], + pending_placement_groups=pending_placement_groups) + + lm.set_resource_requests([{"CPU": 64}, {"GPU": 8}, {"GPU": 8}]) + + summary = lm.summary() + + assert summary.head_ip == "1.1.1.1" + + assert summary.usage["CPU"] == (190, 194) + assert summary.usage["GPU"] == (15, 16) + assert summary.usage["accelerator_type:V100"][1] == 2, \ + "Not comparing the usage value due to floating point error." + + assert ({"GPU": 2}, 11) in summary.resource_demand + assert ({"CPU": 16}, 1) in summary.resource_demand + assert ({"CPU": 16, "GPU": 2}, 1) in summary.resource_demand + assert len(summary.resource_demand) == 3 + + assert ({ + "bundles": [({ + "GPU": 2 + }, 2)], + "strategy": "PACK" + }, 2) in summary.pg_demand + assert len(summary.pg_demand) == 1 + + assert ({"GPU": 8}, 2) in summary.request_demand + assert ({"CPU": 64}, 1) in summary.request_demand + assert len(summary.request_demand) == 2 + + # TODO (Alex): This set of nodes won't be very useful in practice + # because the node:xxx.xxx.xxx.xxx resources means that no 2 nodes + # should ever have the same set of resources. + assert len(summary.node_types) == 3 + class AutoscalingTest(unittest.TestCase): def setUp(self): @@ -1157,6 +1253,87 @@ class AutoscalingTest(unittest.TestCase): self.provider.mock_nodes[0].tags.get(TAG_RAY_USER_NODE_TYPE), "empty_node") + def testSummary(self): + config = copy.deepcopy(MULTI_WORKER_CLUSTER) + config["available_node_types"]["m4.large"]["min_workers"] = 2 + config["max_workers"] = 10 + config["docker"] = {} + config_path = self.write_config(config) + self.provider = MockProvider() + runner = MockProcessRunner() + self.provider.create_node({}, { + TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_USER_NODE_TYPE: "empty_node", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE + }, 1) + head_ip = self.provider.non_terminated_node_ips({})[0] + lm = LoadMetrics(local_ip=head_ip) + autoscaler = StandardAutoscaler( + config_path, + lm, + max_failures=0, + max_launch_batch=1, + max_concurrent_launches=10, + process_runner=runner, + update_interval_s=0) + assert len(self.provider.non_terminated_nodes({})) == 1 + autoscaler.update() + self.waitForNodes(3) + + for ip in self.provider.non_terminated_node_ips({}): + lm.update(ip, {"CPU": 2}, {"CPU": 0}, {}) + + lm.update(head_ip, {"CPU": 16}, {"CPU": 1}, {}) + autoscaler.update() + + while True: + if len( + self.provider.non_terminated_nodes({ + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE + })) == 3: + break + + # After this section, the p2.xlarge is now in the setup process. + runner.ready_to_run.clear() + + lm.update( + head_ip, {"CPU": 16}, {"CPU": 1}, {}, waiting_bundles=[{ + "GPU": 1 + }]) + + autoscaler.update() + self.waitForNodes(4) + + self.provider.ready_to_create.clear() + lm.set_resource_requests([{"CPU": 64}] * 2) + autoscaler.update() + + self.provider.create_node( + {}, { + TAG_RAY_NODE_KIND: NODE_KIND_WORKER, + TAG_RAY_USER_NODE_TYPE: "m4.4xlarge", + TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED + }, + 1, + _skip_wait=True) + self.waitForNodes(5) + + print(f"Head ip: {head_ip}") + summary = autoscaler.summary() + + assert summary.active_nodes["m4.large"] == 2 + assert summary.active_nodes["empty_node"] == 1 + assert len(summary.active_nodes) == 2, summary.active_nodes + + assert summary.pending_nodes == [("172.0.0.3", "p2.xlarge")] + assert summary.pending_launches == {"m4.16xlarge": 2} + + assert summary.failed_nodes == [("172.0.0.4", "m4.4xlarge")] + + # Make sure we return something (and don't throw exceptions). Let's not + # get bogged down with a full cli test here. + assert len(autoscaler.info_string()) > 1 + def testScaleUpMinSanity(self): config = copy.deepcopy(MULTI_WORKER_CLUSTER) config["available_node_types"]["m4.large"]["min_workers"] = \ @@ -1166,6 +1343,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) autoscaler = StandardAutoscaler( @@ -1191,6 +1369,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) autoscaler = StandardAutoscaler( @@ -1217,6 +1396,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: "head", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "m4.4xlarge" }, 1) head_ip = self.provider.non_terminated_node_ips({})[0] @@ -1227,7 +1407,7 @@ class AutoscalingTest(unittest.TestCase): max_failures=0, process_runner=runner, update_interval_s=0) - + head_ip = self.provider.non_terminated_node_ips({})[0] assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) @@ -1285,6 +1465,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) lm = LoadMetrics("172.0.0.0") @@ -1304,11 +1485,15 @@ class AutoscalingTest(unittest.TestCase): } == {"p2.8xlarge", "m4.large"} self.provider.create_node({}, { TAG_RAY_USER_NODE_TYPE: "p2.8xlarge", - TAG_RAY_NODE_KIND: NODE_KIND_WORKER + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_NODE_KIND: NODE_KIND_WORKER, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE }, 2) self.provider.create_node({}, { TAG_RAY_USER_NODE_TYPE: "m4.16xlarge", - TAG_RAY_NODE_KIND: NODE_KIND_WORKER + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_NODE_KIND: NODE_KIND_WORKER, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE }, 2) assert len(self.provider.non_terminated_nodes({})) == 7 # Make sure that after idle_timeout_minutes we don't kill idle @@ -1339,7 +1524,9 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() self.provider.create_node({}, { TAG_RAY_NODE_KIND: "head", - TAG_RAY_USER_NODE_TYPE: "p2.xlarge" + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_USER_NODE_TYPE: "p2.xlarge", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE }, 1) head_ip = self.provider.non_terminated_node_ips({})[0] self.provider.finish_starting_nodes() @@ -1377,7 +1564,9 @@ class AutoscalingTest(unittest.TestCase): self.provider = MockProvider() self.provider.create_node({}, { TAG_RAY_USER_NODE_TYPE: "p2.8xlarge", - TAG_RAY_NODE_KIND: "head" + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_NODE_KIND: "head", + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE }, 1) runner = MockProcessRunner() autoscaler = StandardAutoscaler( @@ -1391,16 +1580,16 @@ class AutoscalingTest(unittest.TestCase): # These requests fit on the head node. autoscaler.update() self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() self.waitForNodes(1) assert len(self.provider.mock_nodes) == 1 - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}]) autoscaler.update() self.waitForNodes(1) # This request requires an additional worker node. - autoscaler.request_resources([{"GPU": 8}] * 2) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}] * 2) autoscaler.update() self.waitForNodes(2) assert self.provider.mock_nodes[1].node_type == "p2.8xlarge" @@ -1415,6 +1604,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(6)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) autoscaler = StandardAutoscaler( @@ -1426,15 +1616,15 @@ class AutoscalingTest(unittest.TestCase): assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) assert self.provider.mock_nodes[1].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}]) autoscaler.update() self.waitForNodes(3) assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" - autoscaler.request_resources([{"CPU": 32}] * 4) + autoscaler.load_metrics.set_resource_requests([{"CPU": 32}] * 4) autoscaler.update() self.waitForNodes(5) @@ -1450,6 +1640,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() runner.respond_to_call("json .Config.Env", ["[]" for i in range(2)]) self.provider.create_node({}, { + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_NODE_KIND: NODE_KIND_HEAD, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) @@ -1462,11 +1653,11 @@ class AutoscalingTest(unittest.TestCase): assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) assert self.provider.mock_nodes[1].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}]) autoscaler.update() self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" @@ -1495,6 +1686,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) lm = LoadMetrics("172.0.0.0") @@ -1541,6 +1733,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(4)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) lm = LoadMetrics("172.0.0.0") @@ -1554,15 +1747,15 @@ class AutoscalingTest(unittest.TestCase): assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) assert self.provider.mock_nodes[1].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}]) autoscaler.update() self.waitForNodes(3) assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" - autoscaler.request_resources([{"GPU": 1}] * 9) + autoscaler.load_metrics.set_resource_requests([{"GPU": 1}] * 9) autoscaler.update() self.waitForNodes(4) assert self.provider.mock_nodes[3].node_type == "p2.xlarge" @@ -1601,6 +1794,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(5)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) autoscaler = StandardAutoscaler( @@ -1612,21 +1806,21 @@ class AutoscalingTest(unittest.TestCase): assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) assert self.provider.mock_nodes[1].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}]) autoscaler.update() self.waitForNodes(3) assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" - autoscaler.request_resources([{"GPU": 1}] * 9) + autoscaler.load_metrics.set_resource_requests([{"GPU": 1}] * 9) autoscaler.update() self.waitForNodes(4) assert self.provider.mock_nodes[3].node_type == "p2.xlarge" autoscaler.update() # Fill up m4, p2.8, p2 and request 2 more CPUs - autoscaler.request_resources([{ + autoscaler.load_metrics.set_resource_requests([{ "CPU": 2 }, { "CPU": 16 @@ -1674,6 +1868,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) autoscaler = StandardAutoscaler( @@ -1702,6 +1897,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) autoscaler = StandardAutoscaler( @@ -1713,11 +1909,11 @@ class AutoscalingTest(unittest.TestCase): assert len(self.provider.non_terminated_nodes({})) == 1 autoscaler.update() self.waitForNodes(1) - autoscaler.request_resources([{"CPU": 1}]) + autoscaler.load_metrics.set_resource_requests([{"CPU": 1}]) autoscaler.update() self.waitForNodes(2) assert self.provider.mock_nodes[1].node_type == "m4.large" - autoscaler.request_resources([{"GPU": 8}]) + autoscaler.load_metrics.set_resource_requests([{"GPU": 8}]) autoscaler.update() self.waitForNodes(3) assert self.provider.mock_nodes[2].node_type == "p2.8xlarge" @@ -1750,6 +1946,7 @@ class AutoscalingTest(unittest.TestCase): runner = MockProcessRunner() self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) lm = LoadMetrics("172.0.0.0") @@ -1762,7 +1959,10 @@ class AutoscalingTest(unittest.TestCase): update_interval_s=0) autoscaler.update() self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }]) autoscaler.update() self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) non_terminated_nodes = autoscaler.provider.non_terminated_nodes({}) @@ -1783,10 +1983,16 @@ class AutoscalingTest(unittest.TestCase): autoscaler.update() # this fits on request_resources()! self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}] * 2) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }] * 2) autoscaler.update() self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }]) lm.update( node_ip, config["available_node_types"]["def_worker"]["resources"], {}, {}, @@ -1859,6 +2065,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(3)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) lm = LoadMetrics("172.0.0.0") @@ -1868,7 +2075,10 @@ class AutoscalingTest(unittest.TestCase): max_failures=0, process_runner=runner, update_interval_s=0) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }]) autoscaler.update() # 1 min worker for both min_worker and request_resources() self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) @@ -1887,16 +2097,25 @@ class AutoscalingTest(unittest.TestCase): "CPU": 0.2, "WORKER": 1.0 }]) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}] * 2) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }] * 2) autoscaler.update() # 2 requested_resource, 1 min worker, 1 free node -> 2 nodes total self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}]) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }]) autoscaler.update() # Still 2 because the second one is not connected and hence # request_resources occupies the connected node. self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([{"CPU": 0.2, "WORKER": 1.0}] * 3) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 0.2, + "WORKER": 1.0 + }] * 3) lm.update( node_ip, config["available_node_types"]["def_worker"]["resources"], {}, {}, @@ -1906,7 +2125,7 @@ class AutoscalingTest(unittest.TestCase): }] * 3) autoscaler.update() self.waitForNodes(3, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) - autoscaler.request_resources([]) + autoscaler.load_metrics.set_resource_requests([]) lm.update("172.0.0.2", config["available_node_types"]["def_worker"]["resources"], @@ -1919,6 +2138,8 @@ class AutoscalingTest(unittest.TestCase): lm.update(node_ip, config["available_node_types"]["def_worker"]["resources"], {}, {}) + print("============ Should scale down from here =============", + node_id) autoscaler.update() self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) # If node {node_id} was terminated any time then it's state will be set @@ -1958,6 +2179,7 @@ class AutoscalingTest(unittest.TestCase): runner.respond_to_call("json .Config.Env", ["[]" for i in range(2)]) self.provider.create_node({}, { TAG_RAY_NODE_KIND: NODE_KIND_HEAD, + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, TAG_RAY_USER_NODE_TYPE: "empty_node" }, 1) lm = LoadMetrics("172.0.0.0") @@ -1967,7 +2189,10 @@ class AutoscalingTest(unittest.TestCase): max_failures=0, process_runner=runner, update_interval_s=0) - autoscaler.request_resources([{"CPU": 2, "WORKER": 1.0}] * 2) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 2, + "WORKER": 1.0 + }] * 2) autoscaler.update() # 2 min worker for both min_worker and request_resources(), not 3. self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}) @@ -1998,12 +2223,14 @@ class AutoscalingTest(unittest.TestCase): "max_workers": 3, } }) + config["idle_timeout_minutes"] = 0 config_path = self.write_config(config) self.provider = MockProvider() self.provider.create_node({}, { TAG_RAY_NODE_KIND: "head", - TAG_RAY_USER_NODE_TYPE: "empty_node" + TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE, + TAG_RAY_USER_NODE_TYPE: "empty_node", }, 1) runner = MockProcessRunner() @@ -2023,7 +2250,10 @@ class AutoscalingTest(unittest.TestCase): waiting_bundles=[{ "CPU": 2 }]) - autoscaler.request_resources([{"CPU": 2, "GPU": 1}] * 2) + autoscaler.load_metrics.set_resource_requests([{ + "CPU": 2, + "GPU": 1 + }] * 2) autoscaler.update() # 1 head, 1 worker. self.waitForNodes(2) @@ -2041,6 +2271,140 @@ class AutoscalingTest(unittest.TestCase): self.waitForNodes(2) +def format_pg(pg): + strategy = pg["strategy"] + bundles = pg["bundles"] + shape_strs = [] + for bundle, count in bundles: + shape_strs.append(f"{bundle} * {count}") + bundles_str = ", ".join(shape_strs) + return f"{bundles_str} ({strategy})" + + +def test_info_string(): + lm_summary = LoadMetricsSummary( + head_ip="0.0.0.0", + usage={ + "CPU": (530, 544), + "GPU": (2, 2), + "AcceleratorType:V100": (0, 2), + "memory": (0, 1583.19), + "object_store_memory": (0, 471.02) + }, + resource_demand=[({ + "CPU": 1 + }, 150)], + pg_demand=[({ + "bundles": [({ + "CPU": 4 + }, 5)], + "strategy": "PACK" + }, 420)], + request_demand=[({ + "CPU": 16 + }, 100)], + node_types=[]) + autoscaler_summary = AutoscalerSummary( + active_nodes={ + "p3.2xlarge": 2, + "m4.4xlarge": 20 + }, + pending_nodes=[("1.2.3.4", "m4.4xlarge"), ("1.2.3.5", "m4.4xlarge")], + pending_launches={"m4.4xlarge": 2}, + failed_nodes=[("1.2.3.6", "p3.2xlarge")]) + + expected = """ +======== Autoscaler status: 2020-12-28 01:02:03 ======== +Node status +-------------------------------------------------------- +Healthy: + 2 p3.2xlarge + 20 m4.4xlarge +Pending: + m4.4xlarge, 2 launching + 1.2.3.4: m4.4xlarge, setting up + 1.2.3.5: m4.4xlarge, setting up +Recent failures: + (no failures) + +Resources +-------------------------------------------------------- + +Usage: + 530/544 CPU + 2/2 GPU + 0/2 AcceleratorType:V100 + 0.00/77.304 GiB memory + 0.00/22.999 GiB object_store_memory + +Demands: + {'CPU': 1}: 150+ pending tasks/actors + {'CPU': 4} * 5 (PACK): 420+ pending placement groups + {'CPU': 16}: 100+ from request_resources() +""".strip() + + actual = format_info_string( + lm_summary, + autoscaler_summary, + time=datetime(year=2020, month=12, day=28, hour=1, minute=2, second=3)) + print(actual) + assert expected == actual + + +def test_info_string_no_node_type(): + lm_summary = LoadMetricsSummary( + head_ip="0.0.0.0", + usage={ + "CPU": (530, 544), + "GPU": (2, 2), + "AcceleratorType:V100": (0, 2), + "memory": (0, 1583.19), + "object_store_memory": (0, 471.02) + }, + resource_demand=[({ + "CPU": 1 + }, 150)], + pg_demand=[({ + "bundles": [({ + "CPU": 4 + }, 5)], + "strategy": "PACK" + }, 420)], + request_demand=[({ + "CPU": 16 + }, 100)], + node_types=[({ + "CPU": 16 + }, 1)]) + + expected = """ +======== Cluster status: 2020-12-28 01:02:03 ======== +Node status +----------------------------------------------------- + 1 node(s) with resources: {'CPU': 16} + +Resources +----------------------------------------------------- +Usage: + 530/544 CPU + 2/2 GPU + 0/2 AcceleratorType:V100 + 0.00/77.304 GiB memory + 0.00/22.999 GiB object_store_memory + +Demands: + {'CPU': 1}: 150+ pending tasks/actors + {'CPU': 4} * 5 (PACK): 420+ pending placement groups + {'CPU': 16}: 100+ from request_resources() +""".strip() + + actual = format_info_string_no_node_types( + lm_summary, + time=datetime(year=2020, month=12, day=28, hour=1, minute=2, second=3)) + print(actual) + assert expected == actual + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__])) From 668ea0bc261d85d4ac2d6474974b02e3ff136bc0 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 23 Dec 2020 16:37:46 -0500 Subject: [PATCH 80/88] Fix typo RMSProp -> RMSprop (#13063) --- rllib/agents/impala/vtrace_torch_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index a42707e4d..22077202e 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -246,7 +246,7 @@ def choose_optimizer(policy, config): return torch.optim.Adam( params=policy.model.parameters(), lr=policy.cur_lr) else: - return torch.optim.RMSProp( + return torch.optim.RMSprop( params=policy.model.parameters(), lr=policy.cur_lr, weight_decay=config["decay"], From 3cc213ddf654d9eaa7b77e6aa8ae6cc5c35aaf34 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 23 Dec 2020 18:00:02 -0600 Subject: [PATCH 81/88] [serve] Centralize HTTP-related logic in HTTPState (#13020) --- python/ray/serve/api.py | 14 +- python/ray/serve/config.py | 11 +- python/ray/serve/controller.py | 225 ++++++++++++++++----------------- 3 files changed, 123 insertions(+), 127 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 3e4b53b28..5e8de83e7 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -17,7 +17,8 @@ from ray.serve.handle import RayServeHandle, RayServeSyncHandle from ray.serve.utils import (block_until_http_ready, format_actor_name, get_random_letters, logger, get_conda_env_dir) from ray.serve.exceptions import RayServeException -from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata +from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata, + HTTPConfig) from ray.serve.env import CondaEnv from ray.serve.router import RequestMetadata, Router from ray.actor import ActorHandle @@ -93,8 +94,7 @@ class Client: self._controller_name = controller_name self._detached = detached self._shutdown = False - self._http_host, self._http_port = ray.get( - controller.get_http_config.remote()) + self._http_config = ray.get(controller.get_http_config.remote()) self._sync_proxied_router = None self._async_proxied_router = None @@ -237,8 +237,8 @@ class Client: num_cpus=0, resources={ node_id: 0.01 }).remote( - "http://{}:{}/-/routes".format(self._http_host, - self._http_port), + "http://{}:{}/-/routes".format(self._http_config.host, + self._http_config.port), check_ready=check_ready, timeout=HTTP_PROXY_TIMEOUT) futures.append(future) @@ -559,9 +559,7 @@ def start(detached: bool = False, max_task_retries=-1, ).remote( controller_name, - http_host, - http_port, - http_middlewares, + HTTPConfig(http_host, http_port, http_middlewares), detached=detached) if http_host is not None: diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 0a8070d9e..104d6da7c 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -2,8 +2,8 @@ import inspect from pydantic import BaseModel, PositiveInt, validator from ray.serve.constants import ASYNC_CONCURRENCY -from typing import Optional, Dict, Any -from dataclasses import dataclass +from typing import Optional, Dict, Any, List +from dataclasses import dataclass, field def _callable_accepts_batch(func_or_class): @@ -191,3 +191,10 @@ class ReplicaConfig: raise TypeError( "resources in ray_actor_options must be a dictionary.") self.resource_dict.update(custom_resources) + + +@dataclass +class HTTPConfig: + host: str = field(init=True) + port: int = field(init=True) + middlewares: List[Any] = field(init=True) diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 4a4b754ff..62ccb87b2 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -20,7 +20,7 @@ from ray.serve.kv_store import RayInternalKVStore from ray.serve.exceptions import RayServeException from ray.serve.utils import (format_actor_name, get_random_letters, logger, try_schedule_resources_on_nodes, get_all_node_ids) -from ray.serve.config import BackendConfig, ReplicaConfig +from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig from ray.serve.long_poll import LongPollHost from ray.actor import ActorHandle @@ -80,6 +80,77 @@ class TrafficPolicy: return f"" +class HTTPState: + def __init__(self, controller_name: str, detached: bool, + config: HTTPConfig): + self._controller_name = controller_name + self._detached = detached + self._config = config + self._proxy_actors: Dict[NodeId, ActorHandle] = dict() + + # Will populate self.proxy_actors with existing actors. + self._start_proxies_if_needed() + + def get_config(self): + return self._config + + def get_http_proxy_handles(self) -> Dict[NodeId, ActorHandle]: + return self._proxy_actors + + def update(self): + self._start_proxies_if_needed() + self._stop_proxies_if_needed() + + def _start_proxies_if_needed(self) -> None: + """Start a proxy on every node if it doesn't already exist.""" + if self._config.host is None: + return + + for node_id, node_resource in get_all_node_ids(): + if node_id in self._proxy_actors: + continue + + name = format_actor_name(SERVE_PROXY_NAME, self._controller_name, + node_id) + try: + proxy = ray.get_actor(name) + except ValueError: + logger.info("Starting HTTP proxy with name '{}' on node '{}' " + "listening on '{}:{}'".format( + name, node_id, self._config.host, + self._config.port)) + proxy = HTTPProxyActor.options( + name=name, + lifetime="detached" if self._detached else None, + max_concurrency=ASYNC_CONCURRENCY, + max_restarts=-1, + max_task_retries=-1, + resources={ + node_resource: 0.01 + }, + ).remote( + self._config.host, + self._config.port, + controller_name=self._controller_name, + http_middlewares=self._config.middlewares) + + self._proxy_actors[node_id] = proxy + + def _stop_proxies_if_needed(self) -> bool: + """Removes proxy actors from any nodes that no longer exist.""" + all_node_ids = {node_id for node_id, _ in get_all_node_ids()} + to_stop = [] + for node_id in self._proxy_actors: + if node_id not in all_node_ids: + logger.info("Removing HTTP proxy on removed node '{}'.".format( + node_id)) + to_stop.append(node_id) + + for node_id in to_stop: + proxy = self._proxy_actors.pop(node_id) + ray.kill(proxy, no_restart=True) + + class BackendInfo(BaseModel): # TODO(architkulkarni): Add type hint for worker_class after upgrading # cloudpickle and adding types to RayServeWrappedReplica @@ -93,6 +164,32 @@ class BackendInfo(BaseModel): arbitrary_types_allowed = True +class BackendState: + def __init__(self, checkpoint: bytes = None): + self.backends: Dict[BackendTag, BackendInfo] = dict() + + if checkpoint is not None: + self.backends = pickle.loads(checkpoint) + + def checkpoint(self): + return pickle.dumps(self.backends) + + def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: + return { + tag: info.backend_config + for tag, info in self.backends.items() + } + + def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: + return self.backends.get(backend_tag) + + def add_backend(self, + backend_tag: BackendTag, + backend_info: BackendInfo, + goal_id: GoalId = 0) -> None: + self.backends[backend_tag] = backend_info + + class EndpointState: def __init__(self, checkpoint: bytes = None): self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() @@ -124,38 +221,11 @@ class EndpointState: return endpoints -class BackendState: - def __init__(self, checkpoint: bytes = None): - self.backends: Dict[BackendTag, BackendInfo] = dict() - - if checkpoint is not None: - self.backends = pickle.loads(checkpoint) - - def checkpoint(self): - return pickle.dumps(self.backends) - - def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: - return { - tag: info.backend_config - for tag, info in self.backends.items() - } - - def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: - return self.backends.get(backend_tag) - - def add_backend(self, - backend_tag: BackendTag, - backend_info: BackendInfo, - goal_id: GoalId = 0) -> None: - self.backends[backend_tag] = backend_info - - @dataclass class ActorStateReconciler: controller_name: str = field(init=True) detached: bool = field(init=True) - http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict) backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field( default_factory=lambda: defaultdict(dict)) backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field( @@ -184,9 +254,6 @@ class ActorStateReconciler: # TODO(edoakes): consider removing this and just using the names. - def http_proxy_handles(self) -> List[ActorHandle]: - return list(self.http_proxy_cache.values()) - def get_replica_handles(self) -> List[ActorHandle]: return list( chain.from_iterable([ @@ -389,70 +456,7 @@ class ActorStateReconciler: asyncio.sleep(1) - def _start_http_proxies_if_needed(self, http_host: str, http_port: str, - http_middlewares: List[Any]) -> None: - """Start an HTTP proxy on every node if it doesn't already exist.""" - if http_host is None: - return - - for node_id, node_resource in get_all_node_ids(): - if node_id in self.http_proxy_cache: - continue - - name = format_actor_name(SERVE_PROXY_NAME, self.controller_name, - node_id) - try: - proxy = ray.get_actor(name) - except ValueError: - logger.info("Starting HTTP proxy with name '{}' on node '{}' " - "listening on '{}:{}'".format( - name, node_id, http_host, http_port)) - proxy = HTTPProxyActor.options( - name=name, - lifetime="detached" if self.detached else None, - max_concurrency=ASYNC_CONCURRENCY, - max_restarts=-1, - max_task_retries=-1, - resources={ - node_resource: 0.01 - }, - ).remote( - http_host, - http_port, - controller_name=self.controller_name, - http_middlewares=http_middlewares) - - self.http_proxy_cache[node_id] = proxy - - def _stop_http_proxies_if_needed(self) -> bool: - """Removes HTTP proxy actors from any nodes that no longer exist. - - Returns whether or not any actors were removed (a checkpoint should - be taken). - """ - actor_stopped = False - all_node_ids = {node_id for node_id, _ in get_all_node_ids()} - to_stop = [] - for node_id in self.http_proxy_cache: - if node_id not in all_node_ids: - logger.info("Removing HTTP proxy on removed node '{}'.".format( - node_id)) - to_stop.append(node_id) - - for node_id in to_stop: - proxy = self.http_proxy_cache.pop(node_id) - ray.kill(proxy, no_restart=True) - actor_stopped = True - - return actor_stopped - def _recover_actor_handles(self) -> None: - # Refresh the RouterCache - for node_id in self.http_proxy_cache.keys(): - name = format_actor_name(SERVE_PROXY_NAME, self.controller_name, - node_id) - self.http_proxy_cache[node_id] = ray.get_actor(name) - # Fetch actor handles for all of the backend replicas in the system. # All of these backend_replicas are guaranteed to already exist because # they would not be written to a checkpoint in self.backend_replicas @@ -526,9 +530,7 @@ class ServeController: async def __init__(self, controller_name: str, - http_host: str, - http_port: str, - http_middlewares: List[Any], + http_config: HTTPConfig, detached: bool = False): # Used to read/write checkpoints. self.kv_store = RayInternalKVStore(namespace=controller_name) @@ -544,20 +546,14 @@ class ServeController: # at any given time. self.write_lock = asyncio.Lock() - self.http_host = http_host - self.http_port = http_port - self.http_middlewares = http_middlewares - - # If starting the actor for the first time, starts up the other system - # components. If recovering, fetches their actor handles. - self.actor_reconciler._start_http_proxies_if_needed( - self.http_host, self.http_port, self.http_middlewares) - # Map of awaiting results # TODO(ilr): Checkpoint this once this becomes asynchronous self.inflight_results: Dict[UUID, asyncio.Event] = dict() self._serializable_inflight_results: Dict[UUID, FutureResult] = dict() + # HTTP state doesn't currently require a checkpoint. + self.http_state = HTTPState(controller_name, detached, http_config) + checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY) if checkpoint_bytes is None: logger.debug("No checkpoint found") @@ -650,9 +646,9 @@ class ServeController: return await ( self.long_poll_host.listen_for_change(keys_to_snapshot_ids)) - def get_http_proxies(self) -> Dict[str, ActorHandle]: + def get_http_proxies(self) -> Dict[NodeId, ActorHandle]: """Returns a dictionary of node ID to http_proxy actor handles.""" - return self.actor_reconciler.http_proxy_cache + return self.http_state.get_http_proxy_handles() def _checkpoint(self) -> None: """Checkpoint internal state and write it to the KV store.""" @@ -737,12 +733,7 @@ class ServeController: while True: await self.do_autoscale() async with self.write_lock: - self.actor_reconciler._start_http_proxies_if_needed( - self.http_host, self.http_port, self.http_middlewares) - checkpoint_required = self.actor_reconciler.\ - _stop_http_proxies_if_needed() - if checkpoint_required: - self._checkpoint() + self.http_state.update() await asyncio.sleep(CONTROL_LOOP_PERIOD_S) @@ -1057,13 +1048,13 @@ class ServeController: def get_http_config(self): """Return the HTTP proxy configuration.""" - return self.http_host, self.http_port + return self.http_state.get_config() async def shutdown(self) -> None: """Shuts down the serve instance completely.""" async with self.write_lock: - for http_proxy in self.actor_reconciler.http_proxy_handles(): - ray.kill(http_proxy, no_restart=True) + for proxy in self.http_state.get_http_proxy_handles().values(): + ray.kill(proxy, no_restart=True) for replica in self.actor_reconciler.get_replica_handles(): ray.kill(replica, no_restart=True) self.kv_store.delete(CHECKPOINT_KEY) From 4461f9980a53f6418d88bbb2a788107201410ff3 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 23 Dec 2020 18:36:00 -0800 Subject: [PATCH 82/88] Refactor TaskDependencyManager, allow passing bundles of objects to ObjectManager (#13006) * New dependency manager * Switch raylet to new DependencyManager * PullManager accepts bundles * Cleanup, remove old task dependency manager * x * PullManager unit tests * lint * Unit tests * Rename * lint * test * Update src/ray/raylet/dependency_manager.cc Co-authored-by: SangBin Cho * Update src/ray/raylet/dependency_manager.cc Co-authored-by: SangBin Cho * x * lint Co-authored-by: SangBin Cho --- BUILD.bazel | 4 +- src/ray/common/common_protocol.h | 22 + src/ray/object_manager/object_manager.cc | 56 +- src/ray/object_manager/object_manager.h | 25 +- src/ray/object_manager/pull_manager.cc | 121 ++-- src/ray/object_manager/pull_manager.h | 74 ++- .../test/object_manager_stress_test.cc | 21 +- .../object_manager/test/pull_manager_test.cc | 157 ++++- src/ray/raylet/dependency_manager.cc | 311 ++++++++++ src/ray/raylet/dependency_manager.h | 270 +++++++++ src/ray/raylet/dependency_manager_test.cc | 372 ++++++++++++ src/ray/raylet/node_manager.cc | 135 ++--- src/ray/raylet/node_manager.h | 19 +- .../raylet/scheduling/cluster_task_manager.cc | 15 +- .../raylet/scheduling/cluster_task_manager.h | 4 +- .../scheduling/cluster_task_manager_test.cc | 6 +- src/ray/raylet/task_dependency_manager.cc | 474 --------------- src/ray/raylet/task_dependency_manager.h | 260 -------- .../raylet/task_dependency_manager_test.cc | 559 ------------------ 19 files changed, 1339 insertions(+), 1566 deletions(-) create mode 100644 src/ray/raylet/dependency_manager.cc create mode 100644 src/ray/raylet/dependency_manager.h create mode 100644 src/ray/raylet/dependency_manager_test.cc delete mode 100644 src/ray/raylet/task_dependency_manager.cc delete mode 100644 src/ray/raylet/task_dependency_manager.h delete mode 100644 src/ray/raylet/task_dependency_manager_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 8782dbdf8..669d11a67 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -956,8 +956,8 @@ cc_test( ) cc_test( - name = "task_dependency_manager_test", - srcs = ["src/ray/raylet/task_dependency_manager_test.cc"], + name = "dependency_manager_test", + srcs = ["src/ray/raylet/dependency_manager_test.cc"], copts = COPTS, deps = [ ":raylet_lib", diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 8ff224922..4ac5b5403 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -20,6 +20,7 @@ #include "ray/common/id.h" #include "ray/util/logging.h" +#include "src/ray/protobuf/common.pb.h" /// Convert an unique ID to a flatbuffer string. /// @@ -201,3 +202,24 @@ to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_set &id } return fbb.CreateVector(results); } + +static inline ray::rpc::ObjectReference ObjectIdToRef( + const ray::ObjectID &object_id, const ray::rpc::Address owner_address) { + ray::rpc::ObjectReference ref; + ref.set_object_id(object_id.Binary()); + ref.mutable_owner_address()->CopyFrom(owner_address); + return ref; +} + +static inline ray::ObjectID ObjectRefToId(const ray::rpc::ObjectReference &object_ref) { + return ray::ObjectID::FromBinary(object_ref.object_id()); +} + +static inline std::vector ObjectRefsToIds( + const std::vector &object_refs) { + std::vector object_ids; + for (const auto &ref : object_refs) { + object_ids.push_back(ObjectRefToId(ref)); + } + return object_ids; +} diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 760909dc0..8ae4d723e 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -103,8 +103,11 @@ ObjectManager::ObjectManager(asio::io_service &main_service, const NodeID &self_ [this](const object_manager::protocol::ObjectInfoT &object_info) { HandleObjectAdded(object_info); }); - store_notification_->SubscribeObjDeleted( - [this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); }); + store_notification_->SubscribeObjDeleted([this](const ObjectID &oid) { + // TODO(swang): We may want to force the pull manager to fetch this object + // again, in case it was needed by an active pull request. + NotifyDirectoryObjectDeleted(oid); + }); // Start object manager rpc server and send & receive request threads StartRpcService(); @@ -169,10 +172,6 @@ void ObjectManager::HandleObjectAdded( } unfulfilled_push_requests_.erase(iter); } - - // The object is local, so we no longer need to Pull it from a remote - // manager. Cancel any outstanding Pull requests for this object. - CancelPull(object_id); } void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) { @@ -198,13 +197,9 @@ ray::Status ObjectManager::SubscribeObjDeleted( return ray::Status::OK(); } -ray::Status ObjectManager::Pull(const ObjectID &object_id, - const rpc::Address &owner_address) { - if (!pull_manager_->Pull(object_id, owner_address)) { - // If we don't need to pull, the object is either already local or this is a duplicate - // request. - return Status::OK(); - } +uint64_t ObjectManager::Pull(const std::vector &object_refs) { + std::vector objects_to_locate; + auto request_id = pull_manager_->Pull(object_refs, &objects_to_locate); const auto &callback = [this](const ObjectID &object_id, const std::unordered_set &client_ids, @@ -212,12 +207,25 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id, pull_manager_->OnLocationChange(object_id, client_ids, spilled_url); }; - // Subscribe to object notifications. A notification will be received every - // time the set of node IDs for the object changes. Notifications will also - // be received if the list of locations is empty. The set of node IDs has - // no ordering guarantee between notifications. - return object_directory_->SubscribeObjectLocations(object_directory_pull_callback_id_, - object_id, owner_address, callback); + for (const auto &ref : objects_to_locate) { + // Subscribe to object notifications. A notification will be received every + // time the set of node IDs for the object changes. Notifications will also + // be received if the list of locations is empty. The set of node IDs has + // no ordering guarantee between notifications. + auto object_id = ObjectRefToId(ref); + RAY_CHECK_OK(object_directory_->SubscribeObjectLocations( + object_directory_pull_callback_id_, object_id, ref.owner_address(), callback)); + } + + return request_id; +} + +void ObjectManager::CancelPull(uint64_t request_id) { + const auto objects_to_cancel = pull_manager_->CancelPull(request_id); + for (const auto &object_id : objects_to_cancel) { + RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations( + object_directory_pull_callback_id_, object_id)); + } } void ObjectManager::SendPullRequest(const ObjectID &object_id, const NodeID &client_id) { @@ -426,16 +434,6 @@ void ObjectManager::SendObjectChunk(const UniqueID &push_id, const ObjectID &obj buffer_pool_.ReleaseGetChunk(object_id, chunk_info.chunk_index); } -void ObjectManager::CancelPull(const ObjectID &object_id) { - if (!pull_manager_->CancelPull(object_id)) { - // We weren't tracking a pull request for this object, so there is nothing to cancel. - return; - } - - RAY_CHECK_OK(object_directory_->UnsubscribeObjectLocations( - object_directory_pull_callback_id_, object_id)); -} - ray::Status ObjectManager::Wait( const std::vector &object_ids, const std::unordered_map &owner_addresses, int64_t timeout_ms, diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index fdca6a190..3e793e21c 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -43,6 +43,7 @@ #include "ray/object_manager/push_manager.h" #include "ray/rpc/object_manager/object_manager_client.h" #include "ray/rpc/object_manager/object_manager_server.h" +#include "src/ray/protobuf/common.pb.h" namespace ray { @@ -95,9 +96,8 @@ class ObjectStoreRunner { class ObjectManagerInterface { public: - virtual ray::Status Pull(const ObjectID &object_id, - const rpc::Address &owner_address) = 0; - virtual void CancelPull(const ObjectID &object_id) = 0; + virtual uint64_t Pull(const std::vector &object_refs) = 0; + virtual void CancelPull(uint64_t request_id) = 0; virtual ~ObjectManagerInterface(){}; }; @@ -238,18 +238,19 @@ class ObjectManager : public ObjectManagerInterface, /// \return Void. void Push(const ObjectID &object_id, const NodeID &node_id); - /// Pull an object from NodeID. + /// Pull a bundle of objects. This will attempt to make all objects in the + /// bundle local until the request is canceled with the returned ID. /// - /// \param object_id The object's object id. - /// \return Status of whether the pull request successfully initiated. - ray::Status Pull(const ObjectID &object_id, const rpc::Address &owner_address) override; + /// \param object_refs The bundle of objects that must be made local. + /// \return A request ID that can be used to cancel the request. + uint64_t Pull(const std::vector &object_refs) override; - /// Cancels all requests (Push/Pull) associated with the given ObjectID. This - /// method is idempotent. + /// Cancels the pull request with the given ID. This cancels any fetches for + /// objects that were passed to the original pull request, if no other pull + /// request requires them. /// - /// \param object_id The ObjectID. - /// \return Void. - void CancelPull(const ObjectID &object_id) override; + /// \param pull_request_id The request to cancel. + void CancelPull(uint64_t pull_request_id) override; /// Callback definition for wait. using WaitCallback = std::function &found, diff --git a/src/ray/object_manager/pull_manager.cc b/src/ray/object_manager/pull_manager.cc index 8ced0f51b..c9fa13177 100644 --- a/src/ray/object_manager/pull_manager.cc +++ b/src/ray/object_manager/pull_manager.cc @@ -1,5 +1,7 @@ #include "ray/object_manager/pull_manager.h" +#include "ray/common/common_protocol.h" + namespace ray { PullManager::PullManager( @@ -15,29 +17,56 @@ PullManager::PullManager( pull_timeout_ms_(pull_timeout_ms), gen_(std::chrono::high_resolution_clock::now().time_since_epoch().count()) {} -bool PullManager::Pull(const ObjectID &object_id, const rpc::Address &owner_address) { - RAY_LOG(DEBUG) << "Pull " - << " of object " << object_id; - // Check if object is already local. - if (object_is_local_(object_id)) { - RAY_LOG(DEBUG) << object_id << " attempted to pull an object that's already local."; - return false; - } - if (pull_requests_.find(object_id) != pull_requests_.end()) { - RAY_LOG(DEBUG) << object_id << " has inflight pull_requests, skipping."; - return false; +uint64_t PullManager::Pull(const std::vector &object_ref_bundle, + std::vector *objects_to_locate) { + auto bundle_it = pull_request_bundles_.emplace(next_req_id_++, object_ref_bundle).first; + RAY_LOG(DEBUG) << "Start pull request " << bundle_it->first; + + for (const auto &ref : object_ref_bundle) { + auto obj_id = ObjectRefToId(ref); + auto it = object_pull_requests_.find(obj_id); + if (it == object_pull_requests_.end()) { + RAY_LOG(DEBUG) << "Pull of object " << obj_id; + // We don't have an active pull for this object yet. Ask the caller to + // send us notifications about the object's location. + objects_to_locate->push_back(ref); + it = object_pull_requests_ + .emplace(obj_id, ObjectPullRequest(get_time_() + pull_timeout_ms_ / 1000)) + .first; + } + it->second.bundle_request_ids.insert(bundle_it->first); } - pull_requests_.emplace(object_id, PullRequest(get_time_() + pull_timeout_ms_ / 1000)); - return true; + return bundle_it->first; +} + +std::vector PullManager::CancelPull(uint64_t request_id) { + std::vector objects_to_cancel; + RAY_LOG(DEBUG) << "Cancel pull request " << request_id; + auto bundle_it = pull_request_bundles_.find(request_id); + RAY_CHECK(bundle_it != pull_request_bundles_.end()); + + for (const auto &ref : bundle_it->second) { + auto obj_id = ObjectRefToId(ref); + auto it = object_pull_requests_.find(obj_id); + RAY_CHECK(it != object_pull_requests_.end()); + RAY_CHECK(it->second.bundle_request_ids.erase(request_id)); + if (it->second.bundle_request_ids.empty()) { + object_pull_requests_.erase(it); + objects_to_cancel.push_back(obj_id); + } + } + + pull_request_bundles_.erase(bundle_it); + return objects_to_cancel; } void PullManager::OnLocationChange(const ObjectID &object_id, const std::unordered_set &client_ids, const std::string &spilled_url) { // Exit if the Pull request has already been fulfilled or canceled. - auto it = pull_requests_.find(object_id); - if (it == pull_requests_.end()) { + auto it = object_pull_requests_.find(object_id); + if (it == object_pull_requests_.end()) { return; } // Reset the list of clients that are now expected to have the object. @@ -45,28 +74,50 @@ void PullManager::OnLocationChange(const ObjectID &object_id, // we may end up sending a duplicate request to the same client as // before. it->second.client_locations = std::vector(client_ids.begin(), client_ids.end()); - if (!spilled_url.empty()) { - RAY_LOG(DEBUG) << "OnLocationChange " << spilled_url << " num clients " - << client_ids.size(); + it->second.spilled_url = spilled_url; + RAY_LOG(DEBUG) << "OnLocationChange " << spilled_url << " num clients " + << client_ids.size(); + + TryToMakeObjectLocal(object_id); +} + +void PullManager::TryToMakeObjectLocal(const ObjectID &object_id) { + if (object_is_local_(object_id)) { + return; + } + auto it = object_pull_requests_.find(object_id); + if (it == object_pull_requests_.end()) { + return; + } + + auto &request = it->second; + if (!request.spilled_url.empty()) { // Try to restore the spilled object. - restore_spilled_object_(object_id, spilled_url, + restore_spilled_object_(object_id, request.spilled_url, [this, object_id](const ray::Status &status) { // Fall back to fetching from another object manager. if (!status.ok()) { - TryPull(object_id); + PullFromRandomLocation(object_id); } }); } else { // New object locations were found, so begin trying to pull from a // client. This will be called every time a new client location // appears. - TryPull(object_id); + PullFromRandomLocation(object_id); } + + const auto time = get_time_(); + auto retry_timeout_len = (pull_timeout_ms_ / 1000.) * (1UL << request.num_retries); + request.next_pull_time = time + retry_timeout_len; + + // Bound the retry time at 10 * 1024 seconds. + request.num_retries = std::min(request.num_retries + 1, 10); } -void PullManager::TryPull(const ObjectID &object_id) { - auto it = pull_requests_.find(object_id); - if (it == pull_requests_.end()) { +void PullManager::PullFromRandomLocation(const ObjectID &object_id) { + auto it = object_pull_requests_.find(object_id); + if (it == object_pull_requests_.end()) { return; } @@ -111,36 +162,20 @@ void PullManager::TryPull(const ObjectID &object_id) { RAY_LOG(DEBUG) << "Sending pull request from " << self_node_id_ << " to " << node_id << " of object " << object_id; - const auto time = get_time_(); - auto &request = it->second; - auto retry_timeout_len = (pull_timeout_ms_ / 1000.) * (1UL << request.num_retries); - request.next_pull_time = time + retry_timeout_len; send_pull_request_(object_id, node_id); } -bool PullManager::CancelPull(const ObjectID &object_id) { - auto it = pull_requests_.find(object_id); - if (it == pull_requests_.end()) { - return false; - } - - pull_requests_.erase(it); - return true; -} - void PullManager::Tick() { - for (auto &pair : pull_requests_) { + for (auto &pair : object_pull_requests_) { const auto &object_id = pair.first; auto &request = pair.second; const auto time = get_time_(); if (time >= request.next_pull_time) { - TryPull(object_id); - // Bound the retry time at 10 * 1024 seconds. - request.num_retries = std::min(request.num_retries + 1, 10); + TryToMakeObjectLocal(object_id); } } } -int PullManager::NumActiveRequests() const { return pull_requests_.size(); } +int PullManager::NumActiveRequests() const { return object_pull_requests_.size(); } } // namespace ray diff --git a/src/ray/object_manager/pull_manager.h b/src/ray/object_manager/pull_manager.h index f312af17a..33710a18e 100644 --- a/src/ray/object_manager/pull_manager.h +++ b/src/ray/object_manager/pull_manager.h @@ -42,33 +42,32 @@ class PullManager { const RestoreSpilledObjectCallback restore_spilled_object, const std::function get_time, int pull_timeout_ms); - /// Begin a new pull request if necessary. + /// Begin a new pull request for a bundle of objects. /// - /// \param object_id The object id to pull. - /// \param owner_address The owner of the object. - /// - /// \return True if a new pull request was necessary. If true, the caller should - /// subscribe to new locations of the object, and call OnLocationChange when necessary. - bool Pull(const ObjectID &object_id, const rpc::Address &owner_address); + /// \param object_refs The bundle of objects that must be made local. + /// \param objects_to_locate The objects whose new locations the caller + /// should subscribe to, and call OnLocationChange for. + /// \return A request ID that can be used to cancel the request. + uint64_t Pull(const std::vector &object_ref_bundle, + std::vector *objects_to_locate); /// Called when the available locations for a given object change. /// /// \param object_id The ID of the object which is now available in a new location. /// \param client_ids The new set of nodes that the object is available on. Not - /// necessarily a super or subset of the previously available nodes. \param spilled_url - /// The location of the object if it was spilled. If non-empty, the object may no longer - /// be on any node. + /// necessarily a super or subset of the previously available nodes. + /// \param spilled_url The location of the object if it was spilled. If + /// non-empty, the object may no longer be on any node. void OnLocationChange(const ObjectID &object_id, const std::unordered_set &client_ids, const std::string &spilled_url); - /// Cancel an existing pull request if necessary. + /// Cancel an existing pull request. /// - /// \param object_id The object id that no longer needs to be pulled. - /// - /// \return True if a pull was cancelled. If there was no pending pull request for the - /// object this method may return false. - bool CancelPull(const ObjectID &object_id); + /// \param request_id The request ID returned by Pull that should be canceled. + /// \return The objects for which the caller should stop subscribing to + /// locations. + std::vector CancelPull(uint64_t request_id); /// Called when the retry timer fires. If this fires, the pull manager may try to pull /// existing objects from other nodes if necessary. @@ -79,14 +78,32 @@ class PullManager { private: /// A helper structure for tracking information about each ongoing object pull. - struct PullRequest { - PullRequest(double first_retry_time) - : client_locations(), next_pull_time(first_retry_time), num_retries(0) {} + struct ObjectPullRequest { + ObjectPullRequest(double first_retry_time) + : client_locations(), + spilled_url(), + next_pull_time(first_retry_time), + num_retries(0), + bundle_request_ids() {} std::vector client_locations; + std::string spilled_url; double next_pull_time; uint8_t num_retries; + absl::flat_hash_set bundle_request_ids; }; + /// Try to make an object local, by restoring the object from external + /// storage or by fetching the object from one of its expected client + /// locations. This does nothing if the object is not needed by any pull + /// request or if it is already local. This also sets a timeout for when to + /// make the next attempt to make the object local. + void TryToMakeObjectLocal(const ObjectID &object_id); + + /// Try to Pull an object from one of its expected client locations. If there + /// are more client locations to try after this attempt, then this method + /// will try each of the other clients in succession. + void PullFromRandomLocation(const ObjectID &object_id); + /// See the constructor's arguments. NodeID self_node_id_; const std::function object_is_local_; @@ -95,22 +112,17 @@ class PullManager { const std::function get_time_; uint64_t pull_timeout_ms_; + /// The next ID to assign to a bundle pull request, so that the caller can + /// cancel. Start at 1 because 0 means null. + uint64_t next_req_id_ = 1; + + std::unordered_map> pull_request_bundles_; + /// The objects that this object manager is currently trying to fetch from /// remote object managers. - std::unordered_map pull_requests_; + std::unordered_map object_pull_requests_; /// Internally maintained random number generator. std::mt19937_64 gen_; - - /// Try to Pull an object from one of its expected client locations. If there - /// are more client locations to try after this attempt, then this method - /// will try each of the other clients in succession, with a timeout between - /// each attempt. If the object is received or if the Pull is Canceled before - /// the timeout, then no more Pull requests for this object will be sent - /// to other node managers until TryPull is called again. - /// - /// \param object_id The object's object id. - /// \return Void. - void TryPull(const ObjectID &object_id); }; } // namespace ray diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 018bc357b..6a55d6180 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -18,6 +18,7 @@ #include #include "gtest/gtest.h" +#include "ray/common/common_protocol.h" #include "ray/common/status.h" #include "ray/common/test_util.h" #include "ray/gcs/gcs_client/service_based_gcs_client.h" @@ -338,8 +339,6 @@ class StressTestObjectManager : public TestObjectManagerBase { NodeID node_id_1 = gcs_client_1->Nodes().GetSelfId(); NodeID node_id_2 = gcs_client_2->Nodes().GetSelfId(); - ray::Status status = ray::Status::OK(); - if (transfer_pattern == TransferPattern::BIDIRECTIONAL_PULL || transfer_pattern == TransferPattern::BIDIRECTIONAL_PUSH || transfer_pattern == TransferPattern::BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE) { @@ -374,21 +373,25 @@ class StressTestObjectManager : public TestObjectManagerBase { case TransferPattern::PULL_A_B: { for (int i = -1; ++i < num_trials;) { ObjectID oid1 = WriteDataToClient(client1, data_size); - status = server2->object_manager_.Pull(oid1, rpc::Address()); + static_cast( + server2->object_manager_.Pull({ObjectIdToRef(oid1, rpc::Address())})); } } break; case TransferPattern::PULL_B_A: { for (int i = -1; ++i < num_trials;) { ObjectID oid2 = WriteDataToClient(client2, data_size); - status = server1->object_manager_.Pull(oid2, rpc::Address()); + static_cast( + server1->object_manager_.Pull({ObjectIdToRef(oid2, rpc::Address())})); } } break; case TransferPattern::BIDIRECTIONAL_PULL: { for (int i = -1; ++i < num_trials;) { ObjectID oid1 = WriteDataToClient(client1, data_size); - status = server2->object_manager_.Pull(oid1, rpc::Address()); + static_cast( + server2->object_manager_.Pull({ObjectIdToRef(oid1, rpc::Address())})); ObjectID oid2 = WriteDataToClient(client2, data_size); - status = server1->object_manager_.Pull(oid2, rpc::Address()); + static_cast( + server1->object_manager_.Pull({ObjectIdToRef(oid2, rpc::Address())})); } } break; case TransferPattern::BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE: { @@ -397,9 +400,11 @@ class StressTestObjectManager : public TestObjectManagerBase { std::uniform_int_distribution<> dis(1, 50); for (int i = -1; ++i < num_trials;) { ObjectID oid1 = WriteDataToClient(client1, data_size + dis(gen)); - status = server2->object_manager_.Pull(oid1, rpc::Address()); + static_cast( + server2->object_manager_.Pull({ObjectIdToRef(oid1, rpc::Address())})); ObjectID oid2 = WriteDataToClient(client2, data_size + dis(gen)); - status = server1->object_manager_.Pull(oid2, rpc::Address()); + static_cast( + server1->object_manager_.Pull({ObjectIdToRef(oid2, rpc::Address())})); } } break; default: { diff --git a/src/ray/object_manager/test/pull_manager_test.cc b/src/ray/object_manager/test/pull_manager_test.cc index fb7b1c1c2..21e41f874 100644 --- a/src/ray/object_manager/test/pull_manager_test.cc +++ b/src/ray/object_manager/test/pull_manager_test.cc @@ -1,11 +1,15 @@ #include "ray/object_manager/pull_manager.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/common_protocol.h" #include "ray/common/test_util.h" namespace ray { +using ::testing::ElementsAre; + class PullManagerTest : public ::testing::Test { public: PullManagerTest() @@ -33,28 +37,42 @@ class PullManagerTest : public ::testing::Test { PullManager pull_manager_; }; +std::vector CreateObjectRefs(int num_objs) { + std::vector refs; + for (int i = 0; i < num_objs; i++) { + ObjectID obj = ObjectID::FromRandom(); + rpc::ObjectReference ref; + ref.set_object_id(obj.Binary()); + refs.push_back(ref); + } + return refs; +} + TEST_F(PullManagerTest, TestStaleSubscription) { - ObjectID obj1 = ObjectID::FromRandom(); - rpc::Address addr1; + auto refs = CreateObjectRefs(1); + auto oid = ObjectRefsToIds(refs)[0]; ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); - pull_manager_.Pull(obj1, addr1); + std::vector objects_to_locate; + auto req_id = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_EQ(ObjectRefsToIds(objects_to_locate), ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 1); std::unordered_set client_ids; - pull_manager_.OnLocationChange(obj1, client_ids, ""); + pull_manager_.OnLocationChange(oid, client_ids, ""); // There are no client ids to pull from. ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 0); - pull_manager_.CancelPull(obj1); + auto objects_to_cancel = pull_manager_.CancelPull(req_id); + ASSERT_EQ(objects_to_cancel, ObjectRefsToIds(refs)); ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 0); ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); client_ids.insert(NodeID::FromRandom()); - pull_manager_.OnLocationChange(obj1, client_ids, ""); + pull_manager_.OnLocationChange(oid, client_ids, ""); // Now we're getting a notification about an object that was already cancelled. ASSERT_EQ(num_send_pull_request_calls_, 0); @@ -63,10 +81,13 @@ TEST_F(PullManagerTest, TestStaleSubscription) { } TEST_F(PullManagerTest, TestRestoreSpilledObject) { - ObjectID obj1 = ObjectID::FromRandom(); + auto refs = CreateObjectRefs(1); + auto obj1 = ObjectRefsToIds(refs)[0]; rpc::Address addr1; ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); - pull_manager_.Pull(obj1, addr1); + std::vector objects_to_locate; + auto req_id = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_EQ(ObjectRefsToIds(objects_to_locate), ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 1); std::unordered_set client_ids; @@ -84,15 +105,25 @@ TEST_F(PullManagerTest, TestRestoreSpilledObject) { ASSERT_EQ(num_send_pull_request_calls_, 0); ASSERT_EQ(num_restore_spilled_object_calls_, 2); - pull_manager_.CancelPull(obj1); + // Don't restore an object if it's local. + object_is_local_ = true; + num_restore_spilled_object_calls_ = 0; + pull_manager_.OnLocationChange(obj1, client_ids, "remote_url/foo/bar"); + ASSERT_EQ(num_restore_spilled_object_calls_, 0); + + auto objects_to_cancel = pull_manager_.CancelPull(req_id); + ASSERT_EQ(objects_to_cancel, ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); } TEST_F(PullManagerTest, TestManyUpdates) { - ObjectID obj1 = ObjectID::FromRandom(); + auto refs = CreateObjectRefs(1); + auto obj1 = ObjectRefsToIds(refs)[0]; rpc::Address addr1; ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); - pull_manager_.Pull(obj1, addr1); + std::vector objects_to_locate; + auto req_id = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_EQ(ObjectRefsToIds(objects_to_locate), ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 1); std::unordered_set client_ids; @@ -105,15 +136,19 @@ TEST_F(PullManagerTest, TestManyUpdates) { ASSERT_EQ(num_send_pull_request_calls_, 100); ASSERT_EQ(num_restore_spilled_object_calls_, 0); - pull_manager_.CancelPull(obj1); + auto objects_to_cancel = pull_manager_.CancelPull(req_id); + ASSERT_EQ(objects_to_cancel, ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); } TEST_F(PullManagerTest, TestRetryTimer) { - ObjectID obj1 = ObjectID::FromRandom(); + auto refs = CreateObjectRefs(1); + auto obj1 = ObjectRefsToIds(refs)[0]; rpc::Address addr1; ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); - pull_manager_.Pull(obj1, addr1); + std::vector objects_to_locate; + auto req_id = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_EQ(ObjectRefsToIds(objects_to_locate), ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 1); std::unordered_set client_ids; @@ -143,26 +178,102 @@ TEST_F(PullManagerTest, TestRetryTimer) { ASSERT_EQ(num_send_pull_request_calls_, 1 + 7 + 127); ASSERT_EQ(num_restore_spilled_object_calls_, 0); - pull_manager_.CancelPull(obj1); + // Don't retry an object if it's local. + object_is_local_ = true; + num_send_pull_request_calls_ = 0; + for (; fake_time_ <= 127 * 10; fake_time_ += 1.) { + pull_manager_.Tick(); + } + ASSERT_EQ(num_send_pull_request_calls_, 0); + + auto objects_to_cancel = pull_manager_.CancelPull(req_id); + ASSERT_EQ(objects_to_cancel, ObjectRefsToIds(refs)); ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); } TEST_F(PullManagerTest, TestBasic) { - ObjectID obj1 = ObjectID::FromRandom(); - rpc::Address addr1; + auto refs = CreateObjectRefs(3); + auto oids = ObjectRefsToIds(refs); ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); - pull_manager_.Pull(obj1, addr1); - ASSERT_EQ(pull_manager_.NumActiveRequests(), 1); + std::vector objects_to_locate; + auto req_id = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_EQ(ObjectRefsToIds(objects_to_locate), oids); + ASSERT_EQ(pull_manager_.NumActiveRequests(), oids.size()); std::unordered_set client_ids; client_ids.insert(NodeID::FromRandom()); - pull_manager_.OnLocationChange(obj1, client_ids, ""); + for (size_t i = 0; i < oids.size(); i++) { + pull_manager_.OnLocationChange(oids[i], client_ids, ""); + ASSERT_EQ(num_send_pull_request_calls_, i + 1); + ASSERT_EQ(num_restore_spilled_object_calls_, 0); + } - ASSERT_EQ(num_send_pull_request_calls_, 1); - ASSERT_EQ(num_restore_spilled_object_calls_, 0); + // Don't pull an object if it's local. + object_is_local_ = true; + num_send_pull_request_calls_ = 0; + for (size_t i = 0; i < oids.size(); i++) { + pull_manager_.OnLocationChange(oids[i], client_ids, ""); + } + ASSERT_EQ(num_send_pull_request_calls_, 0); - pull_manager_.CancelPull(obj1); + auto objects_to_cancel = pull_manager_.CancelPull(req_id); + ASSERT_EQ(objects_to_cancel, oids); ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); + + // Don't pull a remote object if we've canceled. + object_is_local_ = false; + num_send_pull_request_calls_ = 0; + for (size_t i = 0; i < oids.size(); i++) { + pull_manager_.OnLocationChange(oids[i], client_ids, ""); + } + ASSERT_EQ(num_send_pull_request_calls_, 0); +} + +TEST_F(PullManagerTest, TestDeduplicateBundles) { + auto refs = CreateObjectRefs(3); + auto oids = ObjectRefsToIds(refs); + ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); + std::vector objects_to_locate; + auto req_id1 = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_EQ(ObjectRefsToIds(objects_to_locate), oids); + ASSERT_EQ(pull_manager_.NumActiveRequests(), oids.size()); + + objects_to_locate.clear(); + auto req_id2 = pull_manager_.Pull(refs, &objects_to_locate); + ASSERT_TRUE(objects_to_locate.empty()); + + std::unordered_set client_ids; + client_ids.insert(NodeID::FromRandom()); + for (size_t i = 0; i < oids.size(); i++) { + pull_manager_.OnLocationChange(oids[i], client_ids, ""); + ASSERT_EQ(num_send_pull_request_calls_, i + 1); + ASSERT_EQ(num_restore_spilled_object_calls_, 0); + } + + // Cancel one request. + auto objects_to_cancel = pull_manager_.CancelPull(req_id1); + ASSERT_TRUE(objects_to_cancel.empty()); + // Objects should still be pulled because the other request is still open. + ASSERT_EQ(pull_manager_.NumActiveRequests(), oids.size()); + num_send_pull_request_calls_ = 0; + for (size_t i = 0; i < oids.size(); i++) { + pull_manager_.OnLocationChange(oids[i], client_ids, ""); + ASSERT_EQ(num_send_pull_request_calls_, i + 1); + ASSERT_EQ(num_restore_spilled_object_calls_, 0); + } + + // Cancel the other request. + objects_to_cancel = pull_manager_.CancelPull(req_id2); + ASSERT_EQ(objects_to_cancel, oids); + ASSERT_EQ(pull_manager_.NumActiveRequests(), 0); + + // Don't pull a remote object if we've canceled. + object_is_local_ = false; + num_send_pull_request_calls_ = 0; + for (size_t i = 0; i < oids.size(); i++) { + pull_manager_.OnLocationChange(oids[i], client_ids, ""); + } + ASSERT_EQ(num_send_pull_request_calls_, 0); } } // namespace ray diff --git a/src/ray/raylet/dependency_manager.cc b/src/ray/raylet/dependency_manager.cc new file mode 100644 index 000000000..988893bea --- /dev/null +++ b/src/ray/raylet/dependency_manager.cc @@ -0,0 +1,311 @@ +#include "ray/raylet/dependency_manager.h" + +namespace ray { + +namespace raylet { + +bool DependencyManager::CheckObjectLocal(const ObjectID &object_id) const { + return local_objects_.count(object_id) == 1; +} + +bool DependencyManager::GetOwnerAddress(const ObjectID &object_id, + rpc::Address *owner_address) const { + auto obj = required_objects_.find(object_id); + if (obj == required_objects_.end()) { + return false; + } + + *owner_address = obj->second.owner_address; + return !owner_address->worker_id().empty(); +} + +void DependencyManager::RemoveObjectIfNotNeeded( + absl::flat_hash_map::iterator + required_object_it) { + const auto &object_id = required_object_it->first; + if (required_object_it->second.Empty()) { + RAY_LOG(DEBUG) << "Object " << object_id << " no longer needed"; + if (required_object_it->second.wait_request_id > 0) { + RAY_LOG(DEBUG) << "Canceling pull for wait request of object " << object_id + << " request: " << required_object_it->second.wait_request_id; + object_manager_.CancelPull(required_object_it->second.wait_request_id); + } + if (!local_objects_.count(object_id)) { + reconstruction_policy_.Cancel(object_id); + } + required_objects_.erase(required_object_it); + } +} + +absl::flat_hash_map::iterator +DependencyManager::GetOrInsertRequiredObject(const ObjectID &object_id, + const rpc::ObjectReference &ref) { + auto it = required_objects_.find(object_id); + if (it == required_objects_.end()) { + it = required_objects_.emplace(object_id, ref).first; + if (local_objects_.count(object_id) == 0) { + reconstruction_policy_.ListenAndMaybeReconstruct(object_id, ref.owner_address()); + } + } + return it; +} + +void DependencyManager::StartOrUpdateWaitRequest( + const WorkerID &worker_id, + const std::vector &required_objects) { + RAY_LOG(DEBUG) << "Starting wait request for worker " << worker_id; + auto &wait_request = wait_requests_[worker_id]; + for (const auto &ref : required_objects) { + const auto obj_id = ObjectRefToId(ref); + if (local_objects_.count(obj_id)) { + // Object is already local. No need to fetch it. + continue; + } + + if (wait_request.insert(obj_id).second) { + RAY_LOG(DEBUG) << "Worker " << worker_id << " called ray.wait on non-local object " + << obj_id; + auto it = GetOrInsertRequiredObject(obj_id, ref); + it->second.dependent_wait_requests.insert(worker_id); + if (it->second.wait_request_id == 0) { + it->second.wait_request_id = object_manager_.Pull({ref}); + RAY_LOG(DEBUG) << "Started pull for wait request for object " << obj_id + << " request: " << it->second.wait_request_id; + } + } + } + + // No new objects to wait on. Delete the empty entry that was created. + if (wait_request.empty()) { + wait_requests_.erase(worker_id); + } +} + +void DependencyManager::CancelWaitRequest(const WorkerID &worker_id) { + RAY_LOG(DEBUG) << "Canceling wait request for worker " << worker_id; + auto it = wait_requests_.find(worker_id); + if (it == wait_requests_.end()) { + return; + } + + for (const auto &obj_id : it->second) { + auto it = required_objects_.find(obj_id); + RAY_CHECK(it != required_objects_.end()); + it->second.dependent_wait_requests.erase(worker_id); + RemoveObjectIfNotNeeded(it); + } + + wait_requests_.erase(it); +} + +void DependencyManager::StartOrUpdateGetRequest( + const WorkerID &worker_id, + const std::vector &required_objects) { + RAY_LOG(DEBUG) << "Starting get request for worker " << worker_id; + auto &get_request = get_requests_[worker_id]; + bool modified = false; + for (const auto &ref : required_objects) { + const auto obj_id = ObjectRefToId(ref); + if (get_request.first.insert(obj_id).second) { + RAY_LOG(DEBUG) << "Worker " << worker_id << " called ray.get on object " << obj_id; + auto it = GetOrInsertRequiredObject(obj_id, ref); + it->second.dependent_get_requests.insert(worker_id); + modified = true; + } + } + + if (modified) { + std::vector refs; + for (auto &obj_id : get_request.first) { + auto it = required_objects_.find(obj_id); + RAY_CHECK(it != required_objects_.end()); + refs.push_back(ObjectIdToRef(obj_id, it->second.owner_address)); + } + // Pull the new dependencies before canceling the old request, in case some + // of the old dependencies are still being fetched. + uint64_t new_request_id = object_manager_.Pull(refs); + if (get_request.second != 0) { + RAY_LOG(DEBUG) << "Canceling pull for get request from worker " << worker_id + << " request: " << get_request.second; + object_manager_.CancelPull(get_request.second); + } + get_request.second = new_request_id; + RAY_LOG(DEBUG) << "Started pull for get request from worker " << worker_id + << " request: " << get_request.second; + } +} + +void DependencyManager::CancelGetRequest(const WorkerID &worker_id) { + RAY_LOG(DEBUG) << "Canceling get request for worker " << worker_id; + auto it = get_requests_.find(worker_id); + if (it == get_requests_.end()) { + return; + } + + RAY_LOG(DEBUG) << "Canceling pull for get request from worker " << worker_id + << " request: " << it->second.second; + object_manager_.CancelPull(it->second.second); + + for (const auto &obj_id : it->second.first) { + auto it = required_objects_.find(obj_id); + RAY_CHECK(it != required_objects_.end()); + it->second.dependent_get_requests.erase(worker_id); + RemoveObjectIfNotNeeded(it); + } + + get_requests_.erase(it); +} + +/// Request dependencies for a queued task. +bool DependencyManager::RequestTaskDependencies( + const TaskID &task_id, const std::vector &required_objects) { + RAY_LOG(DEBUG) << "Adding dependencies for task " << task_id; + auto inserted = queued_task_requests_.emplace(task_id, required_objects); + RAY_CHECK(inserted.second) << "Task depedencies can be requested only once per task."; + auto &task_entry = inserted.first->second; + + for (const auto &ref : required_objects) { + const auto obj_id = ObjectRefToId(ref); + RAY_LOG(DEBUG) << "Task " << task_id << " blocked on object " << obj_id; + + auto it = GetOrInsertRequiredObject(obj_id, ref); + it->second.dependent_tasks.insert(task_id); + + if (local_objects_.count(obj_id)) { + task_entry.num_missing_dependencies--; + } + } + + if (!required_objects.empty()) { + task_entry.pull_request_id = object_manager_.Pull(required_objects); + RAY_LOG(DEBUG) << "Started pull for dependencies of task " << task_id + << " request: " << task_entry.pull_request_id; + } + + return task_entry.num_missing_dependencies == 0; +} + +bool DependencyManager::IsTaskReady(const TaskID &task_id) const { + auto task_entry = queued_task_requests_.find(task_id); + RAY_CHECK(task_entry != queued_task_requests_.end()); + return task_entry->second.num_missing_dependencies == 0; +} + +void DependencyManager::RemoveTaskDependencies(const TaskID &task_id) { + RAY_LOG(DEBUG) << "Removing dependencies for task " << task_id; + auto task_entry = queued_task_requests_.find(task_id); + RAY_CHECK(task_entry != queued_task_requests_.end()) + << "Can't remove dependencies of tasks that are not queued."; + + if (task_entry->second.pull_request_id > 0) { + RAY_LOG(DEBUG) << "Canceling pull for dependencies of task " << task_id + << " request: " << task_entry->second.pull_request_id; + object_manager_.CancelPull(task_entry->second.pull_request_id); + } + + for (const auto &obj_id : task_entry->second.dependencies) { + auto it = required_objects_.find(obj_id); + RAY_CHECK(it != required_objects_.end()); + it->second.dependent_tasks.erase(task_id); + RemoveObjectIfNotNeeded(it); + } + + queued_task_requests_.erase(task_entry); +} + +std::vector DependencyManager::HandleObjectMissing( + const ray::ObjectID &object_id) { + RAY_CHECK(local_objects_.erase(object_id)) + << "Evicted object was not local " << object_id; + + // Find any tasks that are dependent on the missing object. + std::vector waiting_task_ids; + auto object_entry = required_objects_.find(object_id); + if (object_entry != required_objects_.end()) { + for (auto &dependent_task_id : object_entry->second.dependent_tasks) { + auto it = queued_task_requests_.find(dependent_task_id); + RAY_CHECK(it != queued_task_requests_.end()); + auto &task_entry = it->second; + // If the dependent task had all of its arguments ready, it was ready to + // run but must be switched to waiting since one of its arguments is now + // missing. + if (task_entry.num_missing_dependencies == 0) { + waiting_task_ids.push_back(dependent_task_id); + // During normal execution we should be able to include the check + // RAY_CHECK(pending_tasks_.count(dependent_task_id) == 1); + // However, this invariant will not hold during unit test execution. + } + task_entry.num_missing_dependencies++; + } + + // The object is missing and needed so wait for a possible failure again. + reconstruction_policy_.ListenAndMaybeReconstruct(object_entry->first, + object_entry->second.owner_address); + } + + // Process callbacks for all of the tasks dependent on the object that are + // now ready to run. + return waiting_task_ids; +} + +std::vector DependencyManager::HandleObjectLocal(const ray::ObjectID &object_id) { + // Add the object to the table of locally available objects. + auto inserted = local_objects_.insert(object_id); + RAY_CHECK(inserted.second) << "Local object was already local " << object_id; + + // Find all tasks and workers that depend on the newly available object. + std::vector ready_task_ids; + auto object_entry = required_objects_.find(object_id); + if (object_entry != required_objects_.end()) { + // Loop through all tasks that depend on the newly available object. + for (const auto &dependent_task_id : object_entry->second.dependent_tasks) { + auto it = queued_task_requests_.find(dependent_task_id); + RAY_CHECK(it != queued_task_requests_.end()); + auto &task_entry = it->second; + task_entry.num_missing_dependencies--; + // If the dependent task now has all of its arguments ready, it's ready + // to run. + if (task_entry.num_missing_dependencies == 0) { + ready_task_ids.push_back(dependent_task_id); + } + } + + // Remove the dependency from all workers that called `ray.wait` on the + // newly available object. + for (const auto &worker_id : object_entry->second.dependent_wait_requests) { + auto worker_it = wait_requests_.find(worker_id); + RAY_CHECK(worker_it != wait_requests_.end()); + RAY_CHECK(worker_it->second.erase(object_id) > 0); + if (worker_it->second.empty()) { + wait_requests_.erase(worker_it); + } + } + // Clear all workers that called `ray.wait` on this object, since the + // `ray.wait` calls can now return the object as ready. + object_entry->second.dependent_wait_requests.clear(); + if (object_entry->second.wait_request_id > 0) { + RAY_LOG(DEBUG) << "Canceling pull for wait request of object " << object_id + << " request: " << object_entry->second.wait_request_id; + object_manager_.CancelPull(object_entry->second.wait_request_id); + object_entry->second.wait_request_id = 0; + } + reconstruction_policy_.Cancel(object_entry->first); + RemoveObjectIfNotNeeded(object_entry); + } + + return ready_task_ids; +} + +std::string DependencyManager::DebugString() const { + std::stringstream result; + result << "TaskDependencyManager:"; + result << "\n- task deps map size: " << queued_task_requests_.size(); + result << "\n- get req map size: " << get_requests_.size(); + result << "\n- wait req map size: " << wait_requests_.size(); + result << "\n- local objects map size: " << local_objects_.size(); + return result.str(); +} + +} // namespace raylet + +} // namespace ray diff --git a/src/ray/raylet/dependency_manager.h b/src/ray/raylet/dependency_manager.h new file mode 100644 index 000000000..1e7ddfcb1 --- /dev/null +++ b/src/ray/raylet/dependency_manager.h @@ -0,0 +1,270 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// clang-format off +#include "ray/common/common_protocol.h" +#include "ray/common/id.h" +#include "ray/common/task/task.h" +#include "ray/object_manager/object_manager.h" +#include "ray/raylet/reconstruction_policy.h" +// clang-format on + +namespace ray { + +namespace raylet { + +using rpc::TaskLeaseData; + +class ReconstructionPolicy; + +/// Used for unit-testing the ClusterTaskManager, which requests dependencies +/// for queued tasks. +class TaskDependencyManagerInterface { + public: + virtual bool RequestTaskDependencies( + const TaskID &task_id, + const std::vector &required_objects) = 0; + virtual bool IsTaskReady(const TaskID &task_id) const = 0; + virtual void RemoveTaskDependencies(const TaskID &task_id) = 0; + virtual ~TaskDependencyManagerInterface(){}; +}; + +/// \class DependencyManager +/// +/// Responsible for managing object dependencies for local workers calling +/// `ray.get` or `ray.wait` and arguments of queued tasks. The caller can +/// request object dependencies for a task or worker. The task manager will +/// determine which object dependencies are remote and will request that these +/// objects be made available locally, either via the object manager or by +/// storing an error if the object is lost. +class DependencyManager : public TaskDependencyManagerInterface { + public: + /// Create a task dependency manager. + DependencyManager(ObjectManagerInterface &object_manager, + ReconstructionPolicyInterface &reconstruction_policy) + : object_manager_(object_manager), reconstruction_policy_(reconstruction_policy) {} + + /// Check whether an object is locally available. + /// + /// \param object_id The object to check for. + /// \return Whether the object is local. + bool CheckObjectLocal(const ObjectID &object_id) const; + + /// Get the address of the owner of this object. An address will only be + /// returned if the caller previously specified that this object is required + /// on this node, through a call to SubscribeGetDependencies or + /// SubscribeWaitDependencies. + /// + /// \param[in] object_id The object whose owner to get. + /// \param[out] owner_address The address of the object's owner, if + /// available. + /// \return True if we have owner information for the object. + bool GetOwnerAddress(const ObjectID &object_id, rpc::Address *owner_address) const; + + /// Start or update a worker's `ray.wait` request. This will attempt to make + /// any remote objects local, including previously requested objects. The + /// `ray.wait` request will stay active until the objects are made local or + /// the request for this worker is canceled, whichever occurs first. + /// + /// This method may be called multiple times per worker on the same objects. + /// + /// \param worker_id The ID of the worker that called `ray.wait`. + /// \param required_objects The objects required by the worker. + /// \return Void. + void StartOrUpdateWaitRequest( + const WorkerID &worker_id, + const std::vector &required_objects); + + /// Cancel a worker's `ray.wait` request. We will no longer attempt to fetch + /// any objects that this worker requested previously, if no other task or + /// worker requires them. + /// + /// \param worker_id The ID of the worker whose `ray.wait` request we should + /// cancel. + /// \return Void. + void CancelWaitRequest(const WorkerID &worker_id); + + /// Start or update a worker's `ray.get` request. This will attempt to make + /// any remote objects local, including previously requested objects. The + /// `ray.get` request will stay active until the request for this worker is + /// canceled. + /// + /// This method may be called multiple times per worker on the same objects. + /// + /// \param worker_id The ID of the worker that called `ray.wait`. + /// \param required_objects The objects required by the worker. + /// \return Void. + void StartOrUpdateGetRequest(const WorkerID &worker_id, + const std::vector &required_objects); + + /// Cancel a worker's `ray.get` request. We will no longer attempt to fetch + /// any objects that this worker requested previously, if no other task or + /// worker requires them. + /// + /// \param worker_id The ID of the worker whose `ray.get` request we should + /// cancel. + /// \return Void. + void CancelGetRequest(const WorkerID &worker_id); + + /// Request dependencies for a queued task. This will attempt to make any + /// remote objects local until the caller cancels the task's dependencies. + /// + /// This method can only be called once per task, until the task has been + /// canceled. + /// + /// \param task_id The task that requires the objects. + /// \param required_objects The objects required by the task. + /// \return Void. + bool RequestTaskDependencies(const TaskID &task_id, + const std::vector &required_objects); + + /// Check whether a task is ready to run. The task ID must have been + /// previously added by the caller. + /// + /// \param task_id The ID of the task to check. + /// \return Whether all of the dependencies for the task are + /// local. + bool IsTaskReady(const TaskID &task_id) const; + + /// Cancel a task's dependencies. We will no longer attempt to fetch any + /// remote dependencies, if no other task or worker requires them. + /// + /// This method can only be called on a task whose dependencies were added. + /// + /// \param task_id The task that requires the objects. + /// \param required_objects The objects required by the task. + /// \return Void. + void RemoveTaskDependencies(const TaskID &task_id); + + /// Handle an object becoming locally available. + /// + /// \param object_id The object ID of the object to mark as locally + /// available. + /// \return A list of task IDs. This contains all added tasks that now have + /// all of their dependencies fulfilled. + std::vector HandleObjectLocal(const ray::ObjectID &object_id); + + /// Handle an object that is no longer locally available. + /// + /// \param object_id The object ID of the object that was previously locally + /// available. + /// \return A list of task IDs. This contains all added tasks that previously + /// had all of their dependencies fulfilled, but are now missing this object + /// dependency. + std::vector HandleObjectMissing(const ray::ObjectID &object_id); + + /// Returns debug string for class. + /// + /// \return string. + std::string DebugString() const; + + private: + /// Metadata for an object that is needed by at least one executing worker + /// and/or one queued task. + struct ObjectDependencies { + ObjectDependencies(const rpc::ObjectReference &ref) + : owner_address(ref.owner_address()) {} + /// The tasks that depend on this object, either because the object is a task argument + /// or because the task called `ray.get` on the object. + std::unordered_set dependent_tasks; + /// The workers that depend on this object because they called `ray.get` on the + /// object. + std::unordered_set dependent_get_requests; + /// The workers that depend on this object because they called `ray.wait` on the + /// object. + std::unordered_set dependent_wait_requests; + /// If this object is required by at least one worker that called `ray.wait`, this is + /// the pull request ID. + uint64_t wait_request_id = 0; + /// The address of the worker that owns this object. + rpc::Address owner_address; + + bool Empty() const { + return dependent_tasks.empty() && dependent_get_requests.empty() && + dependent_wait_requests.empty(); + } + }; + + /// A struct to represent the object dependencies of a task. + struct TaskDependencies { + TaskDependencies(const std::vector &deps) + : num_missing_dependencies(deps.size()) { + const auto dep_ids = ObjectRefsToIds(deps); + dependencies.insert(dep_ids.begin(), dep_ids.end()); + } + /// The objects that the task depends on. These are the arguments to the + /// task. These must all be simultaneously local before the task is ready + /// to execute. Objects are removed from this set once + /// UnsubscribeGetDependencies is called. + absl::flat_hash_set dependencies; + /// The number of object arguments that are not available locally. This + /// must be zero before the task is ready to execute. + size_t num_missing_dependencies; + /// Used to identify the pull request for the dependencies to the object + /// manager. + uint64_t pull_request_id = 0; + }; + + /// Stop tracking this object, if it is no longer needed by any worker or + /// queued task. + void RemoveObjectIfNotNeeded( + absl::flat_hash_map::iterator required_object_it); + + /// Start tracking an object that is needed by a worker and/or queued task. + absl::flat_hash_map::iterator GetOrInsertRequiredObject( + const ObjectID &object_id, const rpc::ObjectReference &ref); + + /// The object manager, used to fetch required objects from remote nodes. + ObjectManagerInterface &object_manager_; + /// The reconstruction policy, used to reconstruct required objects that no + /// longer exist on any live nodes. + /// TODO(swang): This class is no longer needed for reconstruction, since the + /// object's owner handles reconstruction. We use this class as a timer to + /// detect the owner's failure. Remove this class and move the timer logic + /// into this class. + ReconstructionPolicyInterface &reconstruction_policy_; + + /// A map from the ID of a queued task to metadata about whether the task's + /// dependencies are all local or not. + absl::flat_hash_map queued_task_requests_; + + /// A map from worker ID to the set of objects that the worker called + /// `ray.get` on and a pull request ID for these objects. The pull request ID + /// should be used to cancel the pull request in the object manager once the + /// worker cancels the `ray.get` request. + absl::flat_hash_map, uint64_t>> + get_requests_; + + /// A map from worker ID to the set of objects that the worker called + /// `ray.wait` on. Objects are removed from the set once they are made local, + /// or the worker cancels the `ray.wait` request. + absl::flat_hash_map> wait_requests_; + + /// Deduplicated pool of objects required by all queued tasks and workers. + /// Objects are removed from this set once there are no more tasks or workers + /// that require it. + absl::flat_hash_map required_objects_; + + /// The set of locally available objects. This is used to determine which + /// tasks are ready to run and which `ray.wait` requests can be finished. + std::unordered_set local_objects_; + + friend class DependencyManagerTest; +}; + +} // namespace raylet + +} // namespace ray diff --git a/src/ray/raylet/dependency_manager_test.cc b/src/ray/raylet/dependency_manager_test.cc new file mode 100644 index 000000000..c6d0ab2ee --- /dev/null +++ b/src/ray/raylet/dependency_manager_test.cc @@ -0,0 +1,372 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/raylet/dependency_manager.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ray/common/task/task_util.h" +#include "ray/common/test_util.h" + +namespace ray { + +namespace raylet { + +using ::testing::_; +using ::testing::InSequence; +using ::testing::Return; + +class MockObjectManager : public ObjectManagerInterface { + public: + uint64_t Pull(const std::vector &object_refs) { + active_requests.insert(req_id); + return req_id++; + } + + void CancelPull(uint64_t request_id) { ASSERT_TRUE(active_requests.erase(request_id)); } + + uint64_t req_id = 1; + std::unordered_set active_requests; +}; + +class MockReconstructionPolicy : public ReconstructionPolicyInterface { + public: + MOCK_METHOD2(ListenAndMaybeReconstruct, + void(const ObjectID &object_id, const rpc::Address &owner_address)); + MOCK_METHOD1(Cancel, void(const ObjectID &object_id)); +}; + +class DependencyManagerTest : public ::testing::Test { + public: + DependencyManagerTest() + : object_manager_mock_(), + reconstruction_policy_mock_(), + dependency_manager_(object_manager_mock_, reconstruction_policy_mock_) {} + + void AssertNoLeaks() { + ASSERT_TRUE(dependency_manager_.required_objects_.empty()); + ASSERT_TRUE(dependency_manager_.queued_task_requests_.empty()); + ASSERT_TRUE(dependency_manager_.get_requests_.empty()); + ASSERT_TRUE(dependency_manager_.wait_requests_.empty()); + // All pull requests are canceled. + ASSERT_TRUE(object_manager_mock_.active_requests.empty()); + } + + MockObjectManager object_manager_mock_; + MockReconstructionPolicy reconstruction_policy_mock_; + DependencyManager dependency_manager_; +}; + +/// Test requesting the dependencies for a task. The dependency manager should +/// return the task ID as ready once all of its arguments are local. +TEST_F(DependencyManagerTest, TestSimpleTask) { + // Create a task with 3 arguments. + int num_arguments = 3; + std::vector arguments; + for (int i = 0; i < num_arguments; i++) { + arguments.push_back(ObjectID::FromRandom()); + } + TaskID task_id = RandomTaskId(); + // No objects have been registered in the task dependency manager, so all + // arguments should be remote. + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); + } + bool ready = + dependency_manager_.RequestTaskDependencies(task_id, ObjectIdsToRefs(arguments)); + ASSERT_FALSE(ready); + ASSERT_EQ(object_manager_mock_.active_requests.size(), 1); + ASSERT_FALSE(dependency_manager_.IsTaskReady(task_id)); + + // For each argument, tell the task dependency manager that the argument is + // local. All arguments should be canceled as they become available locally. + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); + } + auto ready_task_ids = dependency_manager_.HandleObjectLocal(arguments[0]); + ASSERT_TRUE(ready_task_ids.empty()); + ASSERT_FALSE(dependency_manager_.IsTaskReady(task_id)); + ready_task_ids = dependency_manager_.HandleObjectLocal(arguments[1]); + ASSERT_TRUE(ready_task_ids.empty()); + ASSERT_FALSE(dependency_manager_.IsTaskReady(task_id)); + // The task is ready to run. + ready_task_ids = dependency_manager_.HandleObjectLocal(arguments[2]); + ASSERT_EQ(ready_task_ids.size(), 1); + ASSERT_EQ(ready_task_ids.front(), task_id); + ASSERT_TRUE(dependency_manager_.IsTaskReady(task_id)); + + // Remove the task. + dependency_manager_.RemoveTaskDependencies(task_id); + AssertNoLeaks(); +} + +/// Test multiple tasks that depend on the same object. The dependency manager +/// should return all task IDs as ready once the object is local. +TEST_F(DependencyManagerTest, TestMultipleTasks) { + // Create 3 tasks that are dependent on the same object. + ObjectID argument_id = ObjectID::FromRandom(); + std::vector dependent_tasks; + int num_dependent_tasks = 3; + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); + for (int i = 0; i < num_dependent_tasks; i++) { + TaskID task_id = RandomTaskId(); + dependent_tasks.push_back(task_id); + bool ready = dependency_manager_.RequestTaskDependencies( + task_id, ObjectIdsToRefs({argument_id})); + ASSERT_FALSE(ready); + ASSERT_FALSE(dependency_manager_.IsTaskReady(task_id)); + // The object should be requested from the object manager once for each task. + ASSERT_EQ(object_manager_mock_.active_requests.size(), i + 1); + } + + // Tell the task dependency manager that the object is local. + EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); + auto ready_task_ids = dependency_manager_.HandleObjectLocal(argument_id); + // Check that all tasks are now ready to run. + std::unordered_set added_tasks(dependent_tasks.begin(), dependent_tasks.end()); + for (auto &id : ready_task_ids) { + ASSERT_TRUE(added_tasks.erase(id)); + ASSERT_TRUE(dependency_manager_.IsTaskReady(id)); + } + ASSERT_TRUE(added_tasks.empty()); + + for (auto &id : dependent_tasks) { + dependency_manager_.RemoveTaskDependencies(id); + } + AssertNoLeaks(); +} + +/// Test task with multiple dependencies. The dependency manager should return +/// the task ID as ready once all dependencies are local. If a dependency is +/// later evicted, the dependency manager should return the task ID as waiting. +TEST_F(DependencyManagerTest, TestTaskArgEviction) { + // Add a task with 3 arguments. + int num_arguments = 3; + std::vector arguments; + for (int i = 0; i < num_arguments; i++) { + arguments.push_back(ObjectID::FromRandom()); + } + TaskID task_id = RandomTaskId(); + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); + } + bool ready = + dependency_manager_.RequestTaskDependencies(task_id, ObjectIdsToRefs(arguments)); + ASSERT_FALSE(ready); + ASSERT_FALSE(dependency_manager_.IsTaskReady(task_id)); + + // Tell the task dependency manager that each of the arguments is now + // available. + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); + } + for (size_t i = 0; i < arguments.size(); i++) { + std::vector ready_tasks; + ready_tasks = dependency_manager_.HandleObjectLocal(arguments[i]); + if (i == arguments.size() - 1) { + ASSERT_EQ(ready_tasks.size(), 1); + ASSERT_EQ(ready_tasks.front(), task_id); + } else { + ASSERT_TRUE(ready_tasks.empty()); + } + } + ASSERT_TRUE(dependency_manager_.IsTaskReady(task_id)); + + // Simulate each of the arguments getting evicted. Each object should now be + // considered remote. + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); + } + for (size_t i = 0; i < arguments.size(); i++) { + std::vector waiting_tasks; + waiting_tasks = dependency_manager_.HandleObjectMissing(arguments[i]); + if (i == 0) { + // The first eviction should cause the task to go back to the waiting + // state. + ASSERT_EQ(waiting_tasks.size(), 1); + ASSERT_EQ(waiting_tasks.front(), task_id); + } else { + // The subsequent evictions shouldn't cause any more tasks to go back to + // the waiting state. + ASSERT_TRUE(waiting_tasks.empty()); + } + ASSERT_FALSE(dependency_manager_.IsTaskReady(task_id)); + } + + // Tell the task dependency manager that each of the arguments is available + // again. + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); + } + for (size_t i = 0; i < arguments.size(); i++) { + std::vector ready_tasks; + ready_tasks = dependency_manager_.HandleObjectLocal(arguments[i]); + if (i == arguments.size() - 1) { + ASSERT_EQ(ready_tasks.size(), 1); + ASSERT_EQ(ready_tasks.front(), task_id); + } else { + ASSERT_TRUE(ready_tasks.empty()); + } + } + ASSERT_TRUE(dependency_manager_.IsTaskReady(task_id)); + + dependency_manager_.RemoveTaskDependencies(task_id); + AssertNoLeaks(); +} + +/// Test `ray.get`. Worker calls ray.get on {oid1}, then {oid1, oid2}, then +/// {oid1, oid2, oid3}. +TEST_F(DependencyManagerTest, TestGet) { + WorkerID worker_id = WorkerID::FromRandom(); + int num_arguments = 3; + std::vector arguments; + for (int i = 0; i < num_arguments; i++) { + // Add the new argument to the list of dependencies to subscribe to. + ObjectID argument_id = ObjectID::FromRandom(); + arguments.push_back(argument_id); + // Subscribe to the task's dependencies. All arguments except the last are + // duplicates of previous subscription calls. Each argument should only be + // requested from the node manager once. + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); + auto prev_pull_reqs = object_manager_mock_.active_requests; + dependency_manager_.StartOrUpdateGetRequest(worker_id, ObjectIdsToRefs(arguments)); + // Previous pull request for this worker should be canceled upon each new + // bundle. + ASSERT_EQ(object_manager_mock_.active_requests.size(), 1); + ASSERT_NE(object_manager_mock_.active_requests, prev_pull_reqs); + } + + // Nothing happens if the same bundle is requested. + auto prev_pull_reqs = object_manager_mock_.active_requests; + dependency_manager_.StartOrUpdateGetRequest(worker_id, ObjectIdsToRefs(arguments)); + ASSERT_EQ(object_manager_mock_.active_requests, prev_pull_reqs); + + // All arguments should be canceled as they become available locally. + for (const auto &argument_id : arguments) { + EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); + } + + // Cancel the pull request once the worker cancels the `ray.get`. + dependency_manager_.CancelGetRequest(worker_id); + AssertNoLeaks(); +} + +/// Test that when one of the objects becomes local after a `ray.wait` call, +/// all requests to remote nodes associated with the object are canceled. +TEST_F(DependencyManagerTest, TestWait) { + // Generate a random worker and objects to wait on. + WorkerID worker_id = WorkerID::FromRandom(); + int num_objects = 3; + std::vector oids; + for (int i = 0; i < num_objects; i++) { + oids.push_back(ObjectID::FromRandom()); + } + // Simulate a worker calling `ray.wait` on some objects. + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_, _)) + .Times(num_objects); + dependency_manager_.StartOrUpdateWaitRequest(worker_id, ObjectIdsToRefs(oids)); + ASSERT_EQ(object_manager_mock_.active_requests.size(), num_objects); + + for (int i = 0; i < num_objects; i++) { + // Object is local. + EXPECT_CALL(reconstruction_policy_mock_, Cancel(oids[i])); + auto ready_task_ids = dependency_manager_.HandleObjectLocal(oids[i]); + + // Local object gets evicted. The `ray.wait` call should not be + // reactivated. + auto waiting_task_ids = dependency_manager_.HandleObjectMissing(oids[i]); + ASSERT_TRUE(waiting_task_ids.empty()); + ASSERT_EQ(object_manager_mock_.active_requests.size(), num_objects - i - 1); + } + AssertNoLeaks(); +} + +/// Test that when no objects are locally available, a `ray.wait` call makes +/// the correct requests to remote nodes and correctly cancels the requests +/// when the `ray.wait` call is canceled. +TEST_F(DependencyManagerTest, TestWaitThenCancel) { + // Generate a random worker and objects to wait on. + WorkerID worker_id = WorkerID::FromRandom(); + int num_objects = 3; + std::vector oids; + for (int i = 0; i < num_objects; i++) { + oids.push_back(ObjectID::FromRandom()); + } + // Simulate a worker calling `ray.wait` on some objects. + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_, _)) + .Times(num_objects); + dependency_manager_.StartOrUpdateWaitRequest(worker_id, ObjectIdsToRefs(oids)); + ASSERT_EQ(object_manager_mock_.active_requests.size(), num_objects); + auto prev_pull_reqs = object_manager_mock_.active_requests; + // Check that it's okay to call `ray.wait` on the same objects again. No new + // calls should be made to try and make the objects local. + dependency_manager_.StartOrUpdateWaitRequest(worker_id, ObjectIdsToRefs(oids)); + ASSERT_EQ(object_manager_mock_.active_requests, prev_pull_reqs); + // Cancel the worker's `ray.wait`. + EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(num_objects); + dependency_manager_.CancelWaitRequest(worker_id); + AssertNoLeaks(); +} + +/// Test that when one of the objects is already local at the time of the +/// `ray.wait` call, the `ray.wait` call does not trigger any requests to +/// remote nodes for that object. +TEST_F(DependencyManagerTest, TestWaitObjectLocal) { + // Generate a random worker and objects to wait on. + WorkerID worker_id = WorkerID::FromRandom(); + int num_objects = 3; + std::vector oids; + for (int i = 0; i < num_objects; i++) { + oids.push_back(ObjectID::FromRandom()); + } + // Simulate one of the objects becoming local. The later `ray.wait` call + // should have no effect because the object is already local. + const ObjectID local_object_id = std::move(oids.back()); + auto ready_task_ids = dependency_manager_.HandleObjectLocal(local_object_id); + ASSERT_TRUE(ready_task_ids.empty()); + + // Simulate a worker calling `ray.wait` on the objects. It should only make + // requests for the objects that are not local. + for (const auto &object_id : oids) { + if (object_id != local_object_id) { + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(object_id, _)); + } + } + dependency_manager_.StartOrUpdateWaitRequest(worker_id, ObjectIdsToRefs(oids)); + ASSERT_EQ(object_manager_mock_.active_requests.size(), num_objects - 1); + // Simulate the local object getting evicted. The `ray.wait` call should not + // be reactivated. + auto waiting_task_ids = dependency_manager_.HandleObjectMissing(local_object_id); + ASSERT_TRUE(waiting_task_ids.empty()); + ASSERT_EQ(object_manager_mock_.active_requests.size(), num_objects - 1); + // Cancel the worker's `ray.wait`. + for (const auto &object_id : oids) { + if (object_id != local_object_id) { + EXPECT_CALL(reconstruction_policy_mock_, Cancel(object_id)); + } + } + dependency_manager_.CancelWaitRequest(worker_id); + AssertNoLeaks(); +} + +} // namespace raylet + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 7f207c4fb..bcfc4e32f 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -158,7 +158,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self }, RayConfig::instance().object_timeout_milliseconds(), self_node_id_, gcs_client_, object_directory_), - task_dependency_manager_(object_manager, reconstruction_policy_), + dependency_manager_(object_manager, reconstruction_policy_), node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), agent_manager_service_handler_( @@ -218,7 +218,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, const NodeID &self PublishInfeasibleTaskError(task); }; cluster_task_manager_ = std::shared_ptr(new ClusterTaskManager( - self_node_id_, new_resource_scheduler_, task_dependency_manager_, is_owner_alive, + self_node_id_, new_resource_scheduler_, dependency_manager_, is_owner_alive, get_node_info_func, announce_infeasible_task)); placement_group_resource_manager_ = std::make_shared(new_resource_scheduler_); @@ -382,7 +382,7 @@ void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job for (const auto &worker : workers) { if (!worker->IsDetachedActor()) { // Clean up any open ray.wait calls that the worker made. - task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); + dependency_manager_.CancelWaitRequest(worker->WorkerId()); // Mark the worker as dead so further messages from it are ignored // (except DisconnectClient). worker->MarkDead(); @@ -390,18 +390,6 @@ void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job KillWorker(worker); } } - - if (!new_scheduler_enabled_) { - // Remove all tasks for this job from the scheduling queues, mark - // the results for these tasks as not required, cancel any attempts - // at reconstruction. Note that at this time the workers are likely - // alive because of the delay in killing workers. - auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id); - task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); - // NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must - // call it last. - local_queues_.RemoveTasks(tasks_to_remove); - } } void NodeManager::Heartbeat() { @@ -1034,7 +1022,7 @@ void NodeManager::ResourceUsageAdded(const NodeID &node_id, if (state != TaskState::INFEASIBLE) { // Don't unsubscribe for infeasible tasks because we never subscribed in // the first place. - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(task_id)); + dependency_manager_.RemoveTaskDependencies(task_id); } // Attempt to forward the task. If this fails to forward the task, // the task will be resubmit locally. @@ -1403,7 +1391,7 @@ void NodeManager::ProcessDisconnectClientMessage( AsyncResolveObjectsFinish(client, task_id, true); } // Clean up any open ray.wait calls that the worker made. - task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); + dependency_manager_.CancelWaitRequest(worker->WorkerId()); } // Erase any lease metadata. @@ -1425,9 +1413,7 @@ void NodeManager::ProcessDisconnectClientMessage( // If the worker was an actor, it'll be cleaned by GCS. if (actor_id.IsNil()) { Task task; - if (local_queues_.RemoveTask(task_id, &task)) { - TreatTaskAsFailed(task, ErrorType::WORKER_DIED); - } + static_cast(local_queues_.RemoveTask(task_id, &task)); } if (!intentional_disconnect) { @@ -1501,14 +1487,14 @@ void NodeManager::ProcessFetchOrReconstructMessage( const auto refs = FlatbufferToObjectReference(*message->object_ids(), *message->owner_addresses()); if (message->fetch_only()) { - for (const auto &ref : refs) { - ObjectID object_id = ObjectID::FromBinary(ref.object_id()); - // If only a fetch is required, then do not subscribe to the - // dependencies to the task dependency manager. - if (!task_dependency_manager_.CheckObjectLocal(object_id)) { - // Fetch the object if it's not already local. - RAY_CHECK_OK(object_manager_.Pull(object_id, ref.owner_address())); - } + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (!worker) { + worker = worker_pool_.GetRegisteredDriver(client); + } + if (worker) { + // This will start a fetch for the objects that gets canceled once the + // objects are local, or if the worker dies. + dependency_manager_.StartOrUpdateWaitRequest(worker->WorkerId(), refs); } } else { // The values are needed. Add all requested objects to the list to @@ -1544,7 +1530,7 @@ void NodeManager::ProcessWaitRequestMessage( bool resolve_objects = false; for (auto const &object_id : object_ids) { - if (!task_dependency_manager_.CheckObjectLocal(object_id)) { + if (!dependency_manager_.CheckObjectLocal(object_id)) { // At least one object requires resolution. resolve_objects = true; } @@ -1904,11 +1890,7 @@ void NodeManager::HandleCancelWorkerLease(const rpc::CancelWorkerLeaseRequest &r bool canceled; if (new_scheduler_enabled_) { canceled = cluster_task_manager_->CancelTask(task_id); - if (canceled) { - // We have not yet granted the worker lease. Cancel it now. - task_dependency_manager_.TaskCanceled(task_id); - task_dependency_manager_.UnsubscribeGetDependencies(task_id); - } else { + if (!canceled) { // There are 2 cases here. // 1. We haven't received the lease request yet. It's the caller's job to // retry the cancellation once we've received the request. @@ -1925,8 +1907,9 @@ void NodeManager::HandleCancelWorkerLease(const rpc::CancelWorkerLeaseRequest &r if (removed_task.OnDispatch()) { // We have not yet granted the worker lease. Cancel it now. removed_task.OnCancellation()(); - task_dependency_manager_.TaskCanceled(task_id); - task_dependency_manager_.UnsubscribeGetDependencies(task_id); + if (removed_task_state == TaskState::WAITING) { + dependency_manager_.RemoveTaskDependencies(task_id); + } } else { // We already granted the worker lease and sent the reply. Re-queue the // task and wait for the requester to return the leased worker. @@ -2035,7 +2018,6 @@ void NodeManager::ScheduleTasks( // submission vs. registering remaining queued placeable tasks here. std::unordered_set move_task_set; for (const auto &task : local_queues_.GetTasks(TaskState::PLACEABLE)) { - task_dependency_manager_.TaskPending(task); move_task_set.insert(task.GetTaskSpecification().TaskId()); PublishInfeasibleTaskError(task); // Assert that this placeable task is not feasible locally (necessary but not @@ -2051,37 +2033,6 @@ void NodeManager::ScheduleTasks( RAY_CHECK(local_queues_.GetTasks(TaskState::PLACEABLE).size() == 0); } -void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { - const TaskSpecification &spec = task.GetTaskSpecification(); - RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " - << ErrorType_Name(error_type) << "."; - // Loop over the return IDs (except the dummy ID) and store a fake object in - // the object store. - int64_t num_returns = spec.NumReturns(); - if (spec.IsActorCreationTask()) { - // TODO(rkn): We subtract 1 to avoid the dummy ID. However, this leaks - // information about the TaskSpecification implementation. - num_returns -= 1; - } - // Determine which IDs should be marked as failed. - std::vector objects_to_fail; - for (int64_t i = 0; i < num_returns; i++) { - rpc::ObjectReference ref; - ref.set_object_id(spec.ReturnId(i).Binary()); - ref.mutable_owner_address()->CopyFrom(spec.CallerAddress()); - objects_to_fail.push_back(ref); - } - const JobID job_id = task.GetTaskSpecification().JobId(); - MarkObjectsAsFailed(error_type, objects_to_fail, job_id); - task_dependency_manager_.TaskCanceled(spec.TaskId()); - // Notify the task dependency manager that we no longer need this task's - // object dependencies. TODO(swang): Ideally, we would check the return value - // here. However, we don't know at this point if the task was in the WAITING - // or READY queue before, in which case we would not have been subscribed to - // its dependencies. - task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId()); -} - void NodeManager::MarkObjectsAsFailed( const ErrorType &error_type, const std::vector objects_to_fail, const JobID &job_id) { @@ -2189,7 +2140,7 @@ void NodeManager::HandleDirectCallTaskUnblocked( // First, always release task dependencies. This ensures we don't leak resources even // if we don't need to unblock the worker below. - task_dependency_manager_.UnsubscribeGetDependencies(task_id); + dependency_manager_.CancelGetRequest(worker->WorkerId()); if (new_scheduler_enabled_) { // Important: avoid double unblocking if the unblock RPC finishes after task end. @@ -2281,15 +2232,10 @@ void NodeManager::AsyncResolveObjects( // fetched and/or restarted as necessary, until the objects become local // or are unsubscribed. if (ray_get) { - // TODO(ekl) using the assigned task id is a hack to handle unsubscription for - // HandleDirectCallUnblocked. - auto &task_id = mark_worker_blocked ? current_task_id : worker->GetAssignedTaskId(); - if (!task_id.IsNil()) { - task_dependency_manager_.SubscribeGetDependencies(task_id, required_object_refs); - } + dependency_manager_.StartOrUpdateGetRequest(worker->WorkerId(), required_object_refs); } else { - task_dependency_manager_.SubscribeWaitDependencies(worker->WorkerId(), - required_object_refs); + dependency_manager_.StartOrUpdateWaitRequest(worker->WorkerId(), + required_object_refs); } } @@ -2341,13 +2287,13 @@ void NodeManager::AsyncResolveObjectsFinish( worker = worker_pool_.GetRegisteredDriver(client); } + RAY_CHECK(worker); // Unsubscribe from any `ray.get` objects that the task was blocked on. Any // fetch or reconstruction operations to make the objects local are canceled. // `ray.wait` calls will stay active until the objects become local, or the // task/actor that called `ray.wait` exits. - task_dependency_manager_.UnsubscribeGetDependencies(current_task_id); + dependency_manager_.CancelGetRequest(worker->WorkerId()); // Mark the task as unblocked. - RAY_CHECK(worker); if (was_blocked) { worker->RemoveBlockedTaskId(current_task_id); local_queues_.RemoveBlockedTaskId(current_task_id); @@ -2358,7 +2304,7 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { // TODO(atumanov): add task lookup hashmap and change EnqueuePlaceableTask to take // a vector of TaskIDs. Trigger MoveTask internally. // Subscribe to the task's dependencies. - bool args_ready = task_dependency_manager_.SubscribeGetDependencies( + bool args_ready = dependency_manager_.RequestTaskDependencies( task.GetTaskSpecification().TaskId(), task.GetDependencies()); // Enqueue the task. If all dependencies are available, then the task is queued // in the READY state, else the WAITING state. @@ -2369,10 +2315,6 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { } else { local_queues_.QueueTasks({task}, TaskState::WAITING); } - // Mark the task as pending. Once the task has finished execution, or once it - // has been forwarded to another node, the task must be marked as canceled in - // the TaskDependencyManager. - task_dependency_manager_.TaskPending(task); } void NodeManager::AssignTask(const std::shared_ptr &worker, @@ -2470,12 +2412,11 @@ bool NodeManager::FinishAssignedTask(const std::shared_ptr &wor } else { // If this was a non-actor task, then cancel any ray.wait calls that were // made during the task execution. - task_dependency_manager_.UnsubscribeWaitDependencies(worker.WorkerId()); + dependency_manager_.CancelWaitRequest(worker.WorkerId()); } // Notify the task dependency manager that this task has finished execution. - task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId()); - task_dependency_manager_.TaskCanceled(task_id); + dependency_manager_.CancelGetRequest(worker.WorkerId()); if (!spec.IsActorCreationTask()) { // Unset the worker's assigned task. We keep the assigned task ID for @@ -2507,8 +2448,7 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id, const ObjectID &required_object_id) { // Get the owner's address. rpc::Address owner_addr; - bool has_owner = - task_dependency_manager_.GetOwnerAddress(required_object_id, &owner_addr); + bool has_owner = dependency_manager_.GetOwnerAddress(required_object_id, &owner_addr); if (has_owner) { if (!RayConfig::instance().object_pinning_enabled()) { // LRU eviction is enabled. The object may still be in scope, but we @@ -2573,7 +2513,7 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id, void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // Notify the task dependency manager that this object is local. - const auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(object_id); + const auto ready_task_ids = dependency_manager_.HandleObjectLocal(object_id); RAY_LOG(DEBUG) << "Object local " << object_id << ", " << " on " << self_node_id_ << ", " << ready_task_ids.size() << " tasks ready"; @@ -2621,7 +2561,7 @@ bool NodeManager::IsActorCreationTask(const TaskID &task_id) { void NodeManager::HandleObjectMissing(const ObjectID &object_id) { // Notify the task dependency manager that this object is no longer local. - const auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(object_id); + const auto waiting_task_ids = dependency_manager_.HandleObjectMissing(object_id); std::stringstream result; result << "Object missing " << object_id << ", " << " on " << self_node_id_ << ", " << waiting_task_ids.size() @@ -2689,10 +2629,6 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, const NodeID &node_man RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " << node_manager_id; - // Mark the failed task as pending to let other raylets know that we still - // have the task. TaskDependencyManager::TaskPending() is assumed to be - // idempotent. - task_dependency_manager_.TaskPending(task); // The task is not for an actor and may therefore be placed on another // node immediately. Send it to the scheduling policy to be placed again. local_queues_.QueueTasks({task}, TaskState::PLACEABLE); @@ -2743,7 +2679,7 @@ void NodeManager::FinishAssignTask(const std::shared_ptr &worke local_queues_.QueueTasks({assigned_task}, TaskState::RUNNING); // Notify the task dependency manager that we no longer need this task's // object dependencies. - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + dependency_manager_.RemoveTaskDependencies(spec.TaskId()); } else { RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; // We failed to send the task to the worker, so disconnect the worker. @@ -2770,7 +2706,7 @@ void NodeManager::ProcessSubscribePlasmaReady( auto message = flatbuffers::GetRoot(message_data); ObjectID id = from_flatbuf(*message->object_id()); - if (task_dependency_manager_.CheckObjectLocal(id)) { + if (dependency_manager_.CheckObjectLocal(id)) { // Object is already local, so we directly fire the callback to tell the core worker // that the plasma object is ready. rpc::PlasmaObjectReadyRequest request; @@ -2797,8 +2733,7 @@ void NodeManager::ProcessSubscribePlasmaReady( // is local at this time but when the core worker was notified, the object is // is evicted. The core worker should be able to handle evicted object in this // case. - task_dependency_manager_.SubscribeWaitDependencies(associated_worker->WorkerId(), - refs); + dependency_manager_.StartOrUpdateWaitRequest(associated_worker->WorkerId(), refs); // Add this worker to the listeners for the object ID. { @@ -2864,7 +2799,7 @@ std::string NodeManager::DebugString() const { result << "\n" << worker_pool_.DebugString(); result << "\n" << local_queues_.DebugString(); result << "\n" << reconstruction_policy_.DebugString(); - result << "\n" << task_dependency_manager_.DebugString(); + result << "\n" << dependency_manager_.DebugString(); { absl::MutexLock guard(&plasma_object_notification_lock_); result << "\nnum async plasma notifications: " diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index a734ebaec..63f741af0 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -35,7 +35,7 @@ #include "ray/raylet/scheduling_policy.h" #include "ray/raylet/scheduling_queue.h" #include "ray/raylet/reconstruction_policy.h" -#include "ray/raylet/task_dependency_manager.h" +#include "ray/raylet/dependency_manager.h" #include "ray/raylet/worker_pool.h" #include "ray/rpc/worker/core_worker_client_pool.h" #include "ray/util/ordered_set.h" @@ -239,18 +239,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param task The task in question. /// \return Void. void EnqueuePlaceableTask(const Task &task); - /// This will treat a task removed from the local queue as if it had been - /// executed and failed. This is done by looping over the task return IDs and - /// for each ID storing an object that represents a failure in the object - /// store. When clients retrieve these objects, they will raise - /// application-level exceptions. State for the task will be cleaned up as if - /// it were any other task that had been assigned, executed, and removed from - /// the local queue. - /// - /// \param task The task to fail. - /// \param error_type The type of the error that caused this task to fail. - /// \return Void. - void TreatTaskAsFailed(const Task &task, const ErrorType &error_type); /// Mark the specified objects as failed with the given error type. /// /// \param error_type The type of the error that caused this task to fail. @@ -707,8 +695,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler { SchedulingPolicy scheduling_policy_; /// The reconstruction policy for deciding when to re-execute a task. ReconstructionPolicy reconstruction_policy_; - /// A manager to make waiting tasks's missing object dependencies available. - TaskDependencyManager task_dependency_manager_; + /// A manager to resolve objects needed by queued tasks and workers that + /// called `ray.get` or `ray.wait`. + DependencyManager dependency_manager_; std::unique_ptr agent_manager_; diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index ab3b46ea7..fc03a2a77 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -102,7 +102,7 @@ bool ClusterTaskManager::WaitForTaskArgsRequests(Work work) { auto object_ids = task.GetTaskSpecification().GetDependencies(); bool can_dispatch = true; if (object_ids.size() > 0) { - bool args_ready = task_dependency_manager_.SubscribeGetDependencies( + bool args_ready = task_dependency_manager_.RequestTaskDependencies( task.GetTaskSpecification().TaskId(), task.GetDependencies()); if (args_ready) { RAY_LOG(DEBUG) << "Args already ready, task can be dispatched " @@ -164,7 +164,8 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( << "'s caller is no longer running. Cancelling task."; worker_pool.PushWorker(worker); if (!spec.GetDependencies().empty()) { - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + task_dependency_manager_.RemoveTaskDependencies( + task.GetTaskSpecification().TaskId()); } work_it = dispatch_queue.erase(work_it); } else { @@ -179,7 +180,8 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( } if (remove) { if (!spec.GetDependencies().empty()) { - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + task_dependency_manager_.RemoveTaskDependencies( + task.GetTaskSpecification().TaskId()); } work_it = dispatch_queue.erase(work_it); } else { @@ -313,7 +315,8 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { RemoveFromBacklogTracker(task); ReplyCancelled(*work_it); if (!task.GetTaskSpecification().GetDependencies().empty()) { - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(task_id)); + task_dependency_manager_.RemoveTaskDependencies( + task.GetTaskSpecification().TaskId()); } work_queue.erase(work_it); if (work_queue.empty()) { @@ -347,9 +350,11 @@ bool ClusterTaskManager::CancelTask(const TaskID &task_id) { RemoveFromBacklogTracker(task); ReplyCancelled(iter->second); if (!task.GetTaskSpecification().GetDependencies().empty()) { - task_dependency_manager_.UnsubscribeGetDependencies(task_id); + task_dependency_manager_.RemoveTaskDependencies(task_id); } waiting_tasks_.erase(iter); + + task_dependency_manager_.RemoveTaskDependencies(task_id); return true; } diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index 61cfce031..269954eb9 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -4,8 +4,8 @@ #include "absl/container/flat_hash_set.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" +#include "ray/raylet/dependency_manager.h" #include "ray/raylet/scheduling/cluster_resource_scheduler.h" -#include "ray/raylet/task_dependency_manager.h" #include "ray/raylet/worker.h" #include "ray/raylet/worker_pool.h" #include "ray/rpc/grpc_client.h" @@ -53,7 +53,7 @@ class ClusterTaskManager { /// \param gcs_client: A gcs client. ClusterTaskManager(const NodeID &self_node_id, std::shared_ptr cluster_resource_scheduler, - TaskDependencyManagerInterface &task_dependency_manager_, + TaskDependencyManagerInterface &task_dependency_manager, std::function is_owner_alive, NodeInfoGetter get_node_info, std::function announce_infeasible_task); diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 3f33cbf06..e2ca4c5b7 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -96,14 +96,14 @@ Task CreateTask(const std::unordered_map &required_resource class MockTaskDependencyManager : public TaskDependencyManagerInterface { public: - bool SubscribeGetDependencies( + bool RequestTaskDependencies( const TaskID &task_id, const std::vector &required_objects) { RAY_CHECK(subscribed_tasks.insert(task_id).second); return task_ready_; } - bool UnsubscribeGetDependencies(const TaskID &task_id) { - return subscribed_tasks.erase(task_id); + void RemoveTaskDependencies(const TaskID &task_id) { + RAY_CHECK(subscribed_tasks.erase(task_id)); } bool IsTaskReady(const TaskID &task_id) const { return task_ready_; } diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc deleted file mode 100644 index 74c3d8c7a..000000000 --- a/src/ray/raylet/task_dependency_manager.cc +++ /dev/null @@ -1,474 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/raylet/task_dependency_manager.h" - -#include "absl/time/clock.h" -#include "ray/stats/stats.h" - -namespace ray { - -namespace raylet { - -TaskDependencyManager::TaskDependencyManager( - ObjectManagerInterface &object_manager, - ReconstructionPolicyInterface &reconstruction_policy) - : object_manager_(object_manager), reconstruction_policy_(reconstruction_policy) {} - -bool TaskDependencyManager::CheckObjectLocal(const ObjectID &object_id) const { - return local_objects_.count(object_id) == 1; -} - -bool TaskDependencyManager::CheckObjectRequired(const ObjectID &object_id, - rpc::Address *owner_address) const { - const TaskID task_id = object_id.TaskId(); - auto task_entry = required_tasks_.find(task_id); - // If there are no subscribed tasks that are dependent on the object, then do - // nothing. - if (task_entry == required_tasks_.end()) { - return false; - } - if (task_entry->second.count(object_id) == 0) { - return false; - } - // If the object is already local, then the dependency is fulfilled. Do - // nothing. - if (local_objects_.count(object_id) == 1) { - return false; - } - // If the task that creates the object is pending execution, then the - // dependency will be fulfilled locally. Do nothing. - if (pending_tasks_.count(task_id) == 1) { - return false; - } - if (owner_address != nullptr) { - *owner_address = task_entry->second.at(object_id).owner_address; - } - return true; -} - -void TaskDependencyManager::HandleRemoteDependencyRequired(const ObjectID &object_id) { - rpc::Address owner_address; - bool required = CheckObjectRequired(object_id, &owner_address); - // If the object is required, then try to make the object available locally. - if (required) { - auto inserted = required_objects_.insert(object_id); - if (inserted.second) { - // If we haven't already, request the object manager to pull it from a - // remote node. - RAY_CHECK_OK(object_manager_.Pull(object_id, owner_address)); - reconstruction_policy_.ListenAndMaybeReconstruct(object_id, owner_address); - } - } -} - -void TaskDependencyManager::HandleRemoteDependencyCanceled(const ObjectID &object_id) { - bool required = CheckObjectRequired(object_id, nullptr); - // If the object is no longer required, then cancel the object. - if (!required) { - auto it = required_objects_.find(object_id); - if (it != required_objects_.end()) { - object_manager_.CancelPull(object_id); - reconstruction_policy_.Cancel(object_id); - required_objects_.erase(it); - } - } -} - -std::vector TaskDependencyManager::HandleObjectLocal( - const ray::ObjectID &object_id) { - // Add the object to the table of locally available objects. - auto inserted = local_objects_.insert(object_id); - RAY_CHECK(inserted.second) << object_id; - - // Find all tasks and workers that depend on the newly available object. - std::vector ready_task_ids; - auto creating_task_entry = required_tasks_.find(object_id.TaskId()); - if (creating_task_entry != required_tasks_.end()) { - auto object_entry = creating_task_entry->second.find(object_id); - if (object_entry != creating_task_entry->second.end()) { - // Loop through all tasks that depend on the newly available object. - for (const auto &dependent_task_id : object_entry->second.dependent_tasks) { - auto &task_entry = task_dependencies_[dependent_task_id]; - task_entry.num_missing_get_dependencies--; - // If the dependent task now has all of its arguments ready, it's ready - // to run. - if (task_entry.num_missing_get_dependencies == 0) { - ready_task_ids.push_back(dependent_task_id); - } - } - // Remove the dependency from all workers that called `ray.wait` on the - // newly available object. - for (const auto &worker_id : object_entry->second.dependent_workers) { - RAY_CHECK(worker_dependencies_[worker_id].erase(object_id) > 0); - } - // Clear all workers that called `ray.wait` on this object, since the - // `ray.wait` calls can now return the object as ready. - object_entry->second.dependent_workers.clear(); - - // If there are no more tasks or workers dependent on the local object or - // the task that created it, then remove the entry completely. - if (object_entry->second.Empty()) { - creating_task_entry->second.erase(object_entry); - if (creating_task_entry->second.empty()) { - required_tasks_.erase(creating_task_entry); - } - } - } - } - - // The object is now local, so cancel any in-progress operations to make the - // object local. - HandleRemoteDependencyCanceled(object_id); - - return ready_task_ids; -} - -std::vector TaskDependencyManager::HandleObjectMissing( - const ray::ObjectID &object_id) { - // Remove the object from the table of locally available objects. - auto erased = local_objects_.erase(object_id); - RAY_CHECK(erased == 1); - - // Find any tasks that are dependent on the missing object. - std::vector waiting_task_ids; - TaskID creating_task_id = object_id.TaskId(); - auto creating_task_entry = required_tasks_.find(creating_task_id); - if (creating_task_entry != required_tasks_.end()) { - auto object_entry = creating_task_entry->second.find(object_id); - if (object_entry != creating_task_entry->second.end()) { - for (auto &dependent_task_id : object_entry->second.dependent_tasks) { - auto &task_entry = task_dependencies_[dependent_task_id]; - // If the dependent task had all of its arguments ready, it was ready to - // run but must be switched to waiting since one of its arguments is now - // missing. - if (task_entry.num_missing_get_dependencies == 0) { - waiting_task_ids.push_back(dependent_task_id); - // During normal execution we should be able to include the check - // RAY_CHECK(pending_tasks_.count(dependent_task_id) == 1); - // However, this invariant will not hold during unit test execution. - } - task_entry.num_missing_get_dependencies++; - } - } - } - // The object is no longer local. Try to make the object local if necessary. - HandleRemoteDependencyRequired(object_id); - // Process callbacks for all of the tasks dependent on the object that are - // now ready to run. - return waiting_task_ids; -} - -bool TaskDependencyManager::SubscribeGetDependencies( - const TaskID &task_id, const std::vector &required_objects) { - auto &task_entry = task_dependencies_[task_id]; - - // Record the task's dependencies. - for (const auto &object : required_objects) { - const auto &object_id = ObjectID::FromBinary(object.object_id()); - auto inserted = task_entry.get_dependencies.insert(object_id); - if (inserted.second) { - RAY_LOG(DEBUG) << "Task " << task_id << " blocked on object " << object_id; - // Get the ID of the task that creates the dependency. - TaskID creating_task_id = object_id.TaskId(); - // Determine whether the dependency can be fulfilled by the local node. - if (local_objects_.count(object_id) == 0) { - // The object is not local. - task_entry.num_missing_get_dependencies++; - } - - auto it = required_tasks_[creating_task_id].find(object_id); - if (it == required_tasks_[creating_task_id].end()) { - it = required_tasks_[creating_task_id] - .emplace(object_id, ObjectDependencies(object)) - .first; - } - // Add the subscribed task to the mapping from object ID to list of - // dependent tasks. - it->second.dependent_tasks.insert(task_id); - } - } - - // These dependencies are required by the given task. Try to make them local - // if necessary. - for (const auto &object : required_objects) { - const auto &object_id = ObjectID::FromBinary(object.object_id()); - HandleRemoteDependencyRequired(object_id); - } - - // Return whether all dependencies are local. - return (task_entry.num_missing_get_dependencies == 0); -} - -bool TaskDependencyManager::IsTaskReady(const TaskID &task_id) const { - auto task_entry = task_dependencies_.find(task_id); - RAY_CHECK(task_entry != task_dependencies_.end()); - return task_entry->second.num_missing_get_dependencies == 0; -} - -void TaskDependencyManager::SubscribeWaitDependencies( - const WorkerID &worker_id, - const std::vector &required_objects) { - auto &worker_entry = worker_dependencies_[worker_id]; - - // Record the worker's dependencies. - for (const auto &object : required_objects) { - const auto &object_id = ObjectID::FromBinary(object.object_id()); - if (local_objects_.count(object_id) == 0) { - RAY_LOG(DEBUG) << "Worker " << worker_id << " called ray.wait on remote object " - << object_id; - // Only add the dependency if the object is not local. If the object is - // local, then the `ray.wait` call can already return it. - auto inserted = worker_entry.insert(object_id); - if (inserted.second) { - // Get the ID of the task that creates the dependency. - TaskID creating_task_id = object_id.TaskId(); - auto it = required_tasks_[creating_task_id].find(object_id); - if (it == required_tasks_[creating_task_id].end()) { - it = required_tasks_[creating_task_id] - .emplace(object_id, ObjectDependencies(object)) - .first; - } - // Add the subscribed worker to the mapping from object ID to list of - // dependent workers. - it->second.dependent_workers.insert(worker_id); - } - } - } - - // These dependencies are required by the given worker. Try to make them - // local if necessary. - for (const auto &object : required_objects) { - const auto &object_id = ObjectID::FromBinary(object.object_id()); - HandleRemoteDependencyRequired(object_id); - } -} - -bool TaskDependencyManager::UnsubscribeGetDependencies(const TaskID &task_id) { - RAY_LOG(DEBUG) << "Task " << task_id << " no longer blocked"; - // Remove the task from the table of subscribed tasks. - auto it = task_dependencies_.find(task_id); - if (it == task_dependencies_.end()) { - return false; - } - const TaskDependencies task_entry = std::move(it->second); - task_dependencies_.erase(it); - - // Remove the task's dependencies. - for (const auto &object_id : task_entry.get_dependencies) { - // Get the ID of the task that creates the dependency. - TaskID creating_task_id = object_id.TaskId(); - auto creating_task_entry = required_tasks_.find(creating_task_id); - // Remove the task from the list of tasks that are dependent on this - // object. - auto it = creating_task_entry->second.find(object_id); - RAY_CHECK(it != creating_task_entry->second.end()); - RAY_CHECK(it->second.dependent_tasks.erase(task_id) > 0); - // If nothing else depends on the object, then erase the object entry. - if (it->second.Empty()) { - creating_task_entry->second.erase(it); - // Remove the task that creates this object if there are no more object - // dependencies created by the task. - if (creating_task_entry->second.empty()) { - required_tasks_.erase(creating_task_entry); - } - } - } - - // These dependencies are no longer required by the given task. Cancel any - // in-progress operations to make them local. - for (const auto &object_id : task_entry.get_dependencies) { - HandleRemoteDependencyCanceled(object_id); - } - - return true; -} - -void TaskDependencyManager::UnsubscribeWaitDependencies(const WorkerID &worker_id) { - RAY_LOG(DEBUG) << "Worker " << worker_id << " no longer blocked"; - // Remove the task from the table of subscribed tasks. - auto it = worker_dependencies_.find(worker_id); - if (it == worker_dependencies_.end()) { - return; - } - const WorkerDependencies worker_entry = std::move(it->second); - worker_dependencies_.erase(it); - - // Remove the task's dependencies. - for (const auto &object_id : worker_entry) { - // Get the ID of the task that creates the dependency. - TaskID creating_task_id = object_id.TaskId(); - auto creating_task_entry = required_tasks_.find(creating_task_id); - // Remove the worker from the list of workers that are dependent on this - // object. - auto it = creating_task_entry->second.find(object_id); - RAY_CHECK(it != creating_task_entry->second.end()); - RAY_CHECK(it->second.dependent_workers.erase(worker_id) > 0); - // If nothing else depends on the object, then erase the object entry. - if (it->second.Empty()) { - creating_task_entry->second.erase(it); - // Remove the task that creates this object if there are no more object - // dependencies created by the task. - if (creating_task_entry->second.empty()) { - required_tasks_.erase(creating_task_entry); - } - } - } - - // These dependencies are no longer required by the given task. Cancel any - // in-progress operations to make them local. - for (const auto &object_id : worker_entry) { - HandleRemoteDependencyCanceled(object_id); - } -} - -void TaskDependencyManager::TaskPending(const Task &task) { - // Direct tasks are not tracked by the raylet. - // NOTE(zhijunfu): Direct tasks are not tracked by the raylet, - // but we still need raylet to reconstruct the actors. - // For direct actor creation task: - // - Initially the caller leases a worker from raylet and - // then pushes actor creation task directly to the worker, - // thus it doesn't need task lease. And actually if we - // acquire a lease in this case and forget to cancel it, - // the lease would never expire which will prevent the - // actor from being restarted; - // - When a direct actor is restarted, raylet resubmits - // the task, and the task can be forwarded to another raylet, - // and eventually assigned to a worker. In this case we need - // the task lease to make sure there's only one raylet can - // resubmit the task. - // - // We can use `OnDispatch` to differeniate whether this task is - // a worker lease request. - // For direct actor creation task: - // - when it's submitted by core worker, we guarantee that - // we always request a new worker lease, in that case - // `OnDispatch` is overridden to an actual callback. - // - when it's resubmitted by raylet because of reconstruction, - // `OnDispatch` will not be overridden and thus is nullptr. - if (task.GetTaskSpecification().IsActorCreationTask() && task.OnDispatch() == nullptr) { - // This is an actor creation task, and it's being restarted, - // in this case we still need the task lease. Note that we don't - // require task lease for direct actor creation task. - } else { - return; - } - - TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Task execution " << task_id << " pending"; - - // Record that the task is pending execution. - auto inserted = pending_tasks_.insert(task_id); - if (inserted.second) { - // This is the first time we've heard that this task is pending. Find any - // subscribed tasks that are dependent on objects created by the pending - // task. - auto remote_task_entry = required_tasks_.find(task_id); - if (remote_task_entry != required_tasks_.end()) { - for (const auto &object_entry : remote_task_entry->second) { - // This object created by the pending task will appear locally once the - // task completes execution. Cancel any in-progress operations to make - // the object local. - HandleRemoteDependencyCanceled(object_entry.first); - } - } - } -} - -void TaskDependencyManager::TaskCanceled(const TaskID &task_id) { - RAY_LOG(DEBUG) << "Task execution " << task_id << " canceled"; - // Record that the task is no longer pending execution. - auto it = pending_tasks_.find(task_id); - if (it == pending_tasks_.end()) { - return; - } - pending_tasks_.erase(it); - - // Find any subscribed tasks that are dependent on objects created by the - // canceled task. - auto remote_task_entry = required_tasks_.find(task_id); - if (remote_task_entry != required_tasks_.end()) { - for (const auto &object_entry : remote_task_entry->second) { - // This object created by the task will no longer appear locally since - // the task is canceled. Try to make the object local if necessary. - HandleRemoteDependencyRequired(object_entry.first); - } - } -} - -void TaskDependencyManager::RemoveTasksAndRelatedObjects( - const std::unordered_set &task_ids) { - // Collect a list of all the unique objects that these tasks were subscribed - // to. - std::unordered_set required_objects; - for (auto it = task_ids.begin(); it != task_ids.end(); it++) { - auto task_it = task_dependencies_.find(*it); - if (task_it != task_dependencies_.end()) { - // Add the objects that this task was subscribed to. - required_objects.insert(task_it->second.get_dependencies.begin(), - task_it->second.get_dependencies.end()); - } - // The task no longer depends on anything. - task_dependencies_.erase(*it); - // The task is no longer pending execution. - pending_tasks_.erase(*it); - } - - // Cancel all of the objects that were required by the removed tasks. - for (const auto &object_id : required_objects) { - TaskID creating_task_id = object_id.TaskId(); - required_tasks_.erase(creating_task_id); - HandleRemoteDependencyCanceled(object_id); - } - - // Make sure that the tasks in task_ids no longer have tasks dependent on - // them. - for (const auto &task_id : task_ids) { - RAY_CHECK(required_tasks_.find(task_id) == required_tasks_.end()) - << "RemoveTasksAndRelatedObjects was called on " << task_id - << ", but another task depends on it that was not included in the argument"; - } -} - -std::string TaskDependencyManager::DebugString() const { - std::stringstream result; - result << "TaskDependencyManager:"; - result << "\n- task dep map size: " << task_dependencies_.size(); - result << "\n- task req map size: " << required_tasks_.size(); - result << "\n- req objects map size: " << required_objects_.size(); - result << "\n- local objects map size: " << local_objects_.size(); - result << "\n- pending tasks map size: " << pending_tasks_.size(); - return result.str(); -} - -bool TaskDependencyManager::GetOwnerAddress(const ObjectID &object_id, - rpc::Address *owner_address) const { - const auto creating_task_entry = required_tasks_.find(object_id.TaskId()); - if (creating_task_entry == required_tasks_.end()) { - return false; - } - - const auto it = creating_task_entry->second.find(object_id); - if (it == creating_task_entry->second.end()) { - return false; - } - - *owner_address = it->second.owner_address; - return !owner_address->worker_id().empty(); -} - -} // namespace raylet - -} // namespace ray diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h deleted file mode 100644 index eb2c53ee9..000000000 --- a/src/ray/raylet/task_dependency_manager.h +++ /dev/null @@ -1,260 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -// clang-format off -#include "ray/common/id.h" -#include "ray/common/task/task.h" -#include "ray/object_manager/object_manager.h" -#include "ray/raylet/reconstruction_policy.h" -// clang-format on - -namespace ray { - -namespace raylet { - -using rpc::TaskLeaseData; - -class ReconstructionPolicy; - -/// Used for unit-testing the ClusterTaskManager, which calls these methods for -/// locally queued tasks that have dependencies. -class TaskDependencyManagerInterface { - public: - virtual bool SubscribeGetDependencies( - const TaskID &task_id, - const std::vector &required_objects) = 0; - virtual bool IsTaskReady(const TaskID &task_id) const = 0; - virtual bool UnsubscribeGetDependencies(const TaskID &task_id) = 0; - virtual ~TaskDependencyManagerInterface() {} -}; - -/// \class TaskDependencyManager -/// -/// Responsible for managing object dependencies for tasks. The caller can -/// subscribe to object dependencies for a task. The task manager will -/// determine which object dependencies are remote. These are the objects that -/// are neither in the local object store, nor will they be created by a -/// locally queued task. The task manager will request that these objects be -/// made available locally, either by object transfer from a remote node or -/// reconstruction. The task manager will also cancel these objects if they are -/// no longer needed by any task. -class TaskDependencyManager : public TaskDependencyManagerInterface { - public: - /// Create a task dependency manager. - TaskDependencyManager(ObjectManagerInterface &object_manager, - ReconstructionPolicyInterface &reconstruction_policy); - - /// Check whether an object is locally available. - /// - /// \param object_id The object to check for. - /// \return Whether the object is local. - bool CheckObjectLocal(const ObjectID &object_id) const; - - /// Subscribe to object depedencies required by the task and check whether - /// all dependencies are fulfilled. This should be called for task arguments and - /// `ray.get` calls during task execution. - /// - /// The TaskDependencyManager will track the task's dependencies - /// until UnsubscribeGetDependencies is called on the same task ID. If any - /// dependencies are remote, then they will be requested. When the last - /// remote dependency later appears locally via a call to HandleObjectLocal, - /// the subscribed task will be returned by the HandleObjectLocal call, - /// signifying that it is ready to run. This method may be called multiple - /// times per task. - /// - /// \param task_id The ID of the task whose dependencies to subscribe to. - /// \param required_objects The objects required by the task. - /// \return Whether all of the given dependencies for the given task are - /// local. - bool SubscribeGetDependencies( - const TaskID &task_id, const std::vector &required_objects); - - /// Check whether a task is ready to run. The task ID must - /// have been previously subscribed by the caller. - /// - /// \param task_id The ID of the task to check. - /// \return Whether all of the dependencies for the task are - /// local. - bool IsTaskReady(const TaskID &task_id) const; - - /// Subscribe to object depedencies required by the worker. This should be called for - /// ray.wait calls during task execution. - /// - /// The TaskDependencyManager will track all remote dependencies until the - /// dependencies are local, or until UnsubscribeWaitDependencies is called - /// with the same worker ID, whichever occurs first. Remote dependencies will - /// be requested. This method may be called multiple times per worker on the - /// same objects. - /// - /// \param worker_id The ID of the worker that called `ray.wait`. - /// \param required_objects The objects required by the worker. - /// \return Void. - void SubscribeWaitDependencies( - const WorkerID &worker_id, - const std::vector &required_objects); - - /// Unsubscribe from the object dependencies required by this task through the task - /// arguments or `ray.get`. If the objects were remote and are no longer required by any - /// subscribed task, then they will be canceled. - /// - /// \param task_id The ID of the task whose dependencies we should unsubscribe from. - /// \return Whether the task was subscribed before. - bool UnsubscribeGetDependencies(const TaskID &task_id); - - /// Unsubscribe from the object dependencies required by this worker through `ray.wait`. - /// If the objects were remote and are no longer required by any subscribed task, then - /// they will be canceled. - /// - /// \param worker_id The ID of the worker whose dependencies we should unsubscribe from. - /// \return The objects that the worker was waiting on. - void UnsubscribeWaitDependencies(const WorkerID &worker_id); - - /// Mark that the given task is pending execution. Any objects that it creates - /// are now considered to be pending creation. If there are any subscribed - /// tasks that depend on these objects, then the objects will be canceled. - /// - /// \param task The task that is pending execution. - void TaskPending(const Task &task); - - /// Mark that the given task is no longer pending execution. Any objects that - /// it creates that are not already local are now considered to be remote. If - /// there are any subscribed tasks that depend on these objects, then the - /// objects will be requested. - /// - /// \param task_id The ID of the task to cancel. - void TaskCanceled(const TaskID &task_id); - - /// Handle an object becoming locally available. If there are any subscribed - /// tasks that depend on this object, then the object will be canceled. - /// - /// \param object_id The object ID of the object to mark as locally - /// available. - /// \return A list of task IDs. This contains all subscribed tasks that now - /// have all of their dependencies fulfilled, once this object was made - /// local. - std::vector HandleObjectLocal(const ray::ObjectID &object_id); - - /// Handle an object that is no longer locally available. If there are any - /// subscribed tasks that depend on this object, then the object will be - /// requested. - /// - /// \param object_id The object ID of the object that was previously locally - /// available. - /// \return A list of task IDs. This contains all subscribed tasks that - /// previously had all of their dependencies fulfilled, but are now missing - /// this object dependency. - std::vector HandleObjectMissing(const ray::ObjectID &object_id); - - /// Remove all of the tasks specified. These tasks will no longer be - /// considered pending and the objects they depend on will no longer be - /// required. - /// - /// \param task_ids The collection of task IDs. For a given task in this set, - /// all tasks that depend on the task must also be included in the set. - void RemoveTasksAndRelatedObjects(const std::unordered_set &task_ids); - - /// Returns debug string for class. - /// - /// \return string. - std::string DebugString() const; - - /// Get the address of the owner of this object. An address will only be - /// returned if the caller previously specified that this object is required - /// on this node, through a call to SubscribeGetDependencies or - /// SubscribeWaitDependencies. - /// - /// \param[in] object_id The object whose owner to get. - /// \param[out] owner_address The address of the object's owner, if - /// available. - /// \return True if we have owner information for the object. - bool GetOwnerAddress(const ObjectID &object_id, rpc::Address *owner_address) const; - - private: - struct ObjectDependencies { - ObjectDependencies(const rpc::ObjectReference &ref) - : owner_address(ref.owner_address()) {} - /// The tasks that depend on this object, either because the object is a task argument - /// or because the task called `ray.get` on the object. - std::unordered_set dependent_tasks; - /// The workers that depend on this object because they called `ray.wait` on the - /// object. - std::unordered_set dependent_workers; - /// The address of the worker that owns this object. - rpc::Address owner_address; - - bool Empty() const { return dependent_tasks.empty() && dependent_workers.empty(); } - }; - - /// A struct to represent the object dependencies of a task. - struct TaskDependencies { - /// The objects that the task depends on. These are either the arguments to - /// the task or objects that the task calls `ray.get` on. These must be - /// local before the task is ready to execute. Objects are removed from - /// this set once UnsubscribeGetDependencies is called. - std::unordered_set get_dependencies; - /// The number of object arguments that are not available locally. This - /// must be zero before the task is ready to execute. - int64_t num_missing_get_dependencies; - }; - - /// The objects that the worker is fetching. These are objects that a task that executed - /// or is executing on the worker called `ray.wait` on that are not yet local. An object - /// will be automatically removed from this set once it becomes local. - using WorkerDependencies = std::unordered_set; - - /// Check whether the given object needs to be made available through object - /// transfer or reconstruction. These are objects for which: (1) there is a - /// subscribed task dependent on it, (2) the object is not local, and (3) the - /// task that creates the object is not pending execution locally. - bool CheckObjectRequired(const ObjectID &object_id, rpc::Address *owner_address) const; - /// If the given object is required, then request that the object be made - /// available through object transfer or reconstruction. - void HandleRemoteDependencyRequired(const ObjectID &object_id); - /// If the given object is no longer required, then cancel any in-progress - /// operations to make the object available through object transfer or - /// reconstruction. - void HandleRemoteDependencyCanceled(const ObjectID &object_id); - - /// The object manager, used to fetch required objects from remote nodes. - ObjectManagerInterface &object_manager_; - /// The reconstruction policy, used to reconstruct required objects that no - /// longer exist on any live nodes. - ReconstructionPolicyInterface &reconstruction_policy_; - /// A mapping from task ID of each subscribed task to its list of object - /// dependencies, either task arguments or objects passed into `ray.get`. - std::unordered_map task_dependencies_; - /// A mapping from worker ID to each object that the worker called `ray.wait` on. - std::unordered_map worker_dependencies_; - /// All tasks whose outputs are required by a subscribed task. This is a - /// mapping from task ID to information about the objects that the task - /// creates, either by return value or by `ray.put`. For each object, we - /// store the IDs of the subscribed tasks that are dependent on the object. - std::unordered_map> - required_tasks_; - /// Objects that are required by a subscribed task, are not local, and are - /// not created by a pending task. For these objects, there are pending - /// operations to make the object available. - std::unordered_set required_objects_; - /// The set of locally available objects. - std::unordered_set local_objects_; - /// The set of tasks that are pending execution. Any objects created by these - /// tasks that are not already local are pending creation. - std::unordered_set pending_tasks_; -}; - -} // namespace raylet - -} // namespace ray diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc deleted file mode 100644 index d65b0aced..000000000 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ /dev/null @@ -1,559 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ray/raylet/task_dependency_manager.h" - -#include -#include - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "ray/common/task/task_util.h" -#include "ray/common/test_util.h" - -namespace ray { - -namespace raylet { - -using ::testing::_; - -const static JobID kDefaultJobId = JobID::FromInt(1); - -const static TaskID kDefaultDriverTaskId = TaskID::ForDriverTask(kDefaultJobId); - -class MockObjectManager : public ObjectManagerInterface { - public: - MOCK_METHOD2(Pull, - ray::Status(const ObjectID &object_id, const rpc::Address &owner_address)); - MOCK_METHOD1(CancelPull, void(const ObjectID &object_id)); -}; - -class MockReconstructionPolicy : public ReconstructionPolicyInterface { - public: - MOCK_METHOD2(ListenAndMaybeReconstruct, - void(const ObjectID &object_id, const rpc::Address &owner_address)); - MOCK_METHOD1(Cancel, void(const ObjectID &object_id)); -}; - -class TaskDependencyManagerTest : public ::testing::Test { - public: - TaskDependencyManagerTest() - : object_manager_mock_(), - reconstruction_policy_mock_(), - task_dependency_manager_(object_manager_mock_, reconstruction_policy_mock_) {} - - protected: - MockObjectManager object_manager_mock_; - MockReconstructionPolicy reconstruction_policy_mock_; - TaskDependencyManager task_dependency_manager_; -}; - -static inline Task ExampleTask(const std::vector &arguments, - uint64_t num_returns) { - TaskSpecBuilder builder; - rpc::Address address; - builder.SetCommonTaskSpec(RandomTaskId(), "example_task", Language::PYTHON, - FunctionDescriptorBuilder::BuildPython("", "", "", ""), - JobID::Nil(), RandomTaskId(), 0, RandomTaskId(), address, - num_returns, {}, {}, - std::make_pair(PlacementGroupID::Nil(), -1), true, ""); - builder.SetActorCreationTaskSpec(ActorID::Nil(), 1, 1, {}, 1, false, "", false); - for (const auto &arg : arguments) { - builder.AddArg(TaskArgByReference(arg, rpc::Address())); - } - rpc::TaskExecutionSpec execution_spec_message; - execution_spec_message.set_num_forwards(1); - return Task(builder.Build(), TaskExecutionSpecification(execution_spec_message)); -} - -std::vector MakeTaskChain(int chain_size, - const std::vector &initial_arguments, - int64_t num_returns) { - std::vector task_chain; - std::vector arguments = initial_arguments; - for (int i = 0; i < chain_size; i++) { - auto task = ExampleTask(arguments, num_returns); - task_chain.push_back(task); - arguments.clear(); - for (size_t j = 0; j < task.GetTaskSpecification().NumReturns(); j++) { - arguments.push_back(task.GetTaskSpecification().ReturnId(j)); - } - } - return task_chain; -} - -TEST_F(TaskDependencyManagerTest, TestSimpleTask) { - // Create a task with 3 arguments. - int num_arguments = 3; - std::vector arguments; - for (int i = 0; i < num_arguments; i++) { - arguments.push_back(ObjectID::FromRandom()); - } - TaskID task_id = RandomTaskId(); - // No objects have been registered in the task dependency manager, so all - // arguments should be remote. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, Pull(argument_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); - } - // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeGetDependencies( - task_id, ObjectIdsToRefs(arguments)); - ASSERT_FALSE(ready); - - // All arguments should be canceled as they become available locally. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, CancelPull(argument_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); - } - // For each argument except the last, tell the task dependency manager that - // the argument is local. - int i = 0; - for (; i < num_arguments - 1; i++) { - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(arguments[i]); - ASSERT_TRUE(ready_task_ids.empty()); - } - // Tell the task dependency manager that the last argument is local. Now the - // task should be ready to run. - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(arguments[i]); - ASSERT_EQ(ready_task_ids.size(), 1); - ASSERT_EQ(ready_task_ids.front(), task_id); -} - -TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribeGetDependencies) { - // Create a task with 3 arguments. - TaskID task_id = RandomTaskId(); - int num_arguments = 3; - std::vector arguments; - for (int i = 0; i < num_arguments; i++) { - // Add the new argument to the list of dependencies to subscribe to. - ObjectID argument_id = ObjectID::FromRandom(); - arguments.push_back(argument_id); - // Subscribe to the task's dependencies. All arguments except the last are - // duplicates of previous subscription calls. Each argument should only be - // requested from the node manager once. - EXPECT_CALL(object_manager_mock_, Pull(argument_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); - bool ready = task_dependency_manager_.SubscribeGetDependencies( - task_id, ObjectIdsToRefs(arguments)); - ASSERT_FALSE(ready); - } - - // All arguments should be canceled as they become available locally. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, CancelPull(argument_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); - } - // For each argument except the last, tell the task dependency manager that - // the argument is local. - int i = 0; - for (; i < num_arguments - 1; i++) { - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(arguments[i]); - ASSERT_TRUE(ready_task_ids.empty()); - } - // Tell the task dependency manager that the last argument is local. Now the - // task should be ready to run. - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(arguments[i]); - ASSERT_EQ(ready_task_ids.size(), 1); - ASSERT_EQ(ready_task_ids.front(), task_id); -} - -TEST_F(TaskDependencyManagerTest, TestMultipleTasks) { - // Create 3 tasks that are dependent on the same object. - ObjectID argument_id = ObjectID::FromRandom(); - std::vector dependent_tasks; - int num_dependent_tasks = 3; - // The object should only be requested from the object manager once for all - // three tasks. - EXPECT_CALL(object_manager_mock_, Pull(argument_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); - for (int i = 0; i < num_dependent_tasks; i++) { - TaskID task_id = RandomTaskId(); - dependent_tasks.push_back(task_id); - // Subscribe to each of the task's dependencies. - bool ready = task_dependency_manager_.SubscribeGetDependencies( - task_id, ObjectIdsToRefs({argument_id})); - ASSERT_FALSE(ready); - } - - // Tell the task dependency manager that the object is local. - EXPECT_CALL(object_manager_mock_, CancelPull(argument_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(argument_id); - // Check that all tasks are now ready to run. - ASSERT_EQ(ready_task_ids.size(), dependent_tasks.size()); - for (const auto &task_id : ready_task_ids) { - ASSERT_NE(std::find(dependent_tasks.begin(), dependent_tasks.end(), task_id), - dependent_tasks.end()); - } -} - -TEST_F(TaskDependencyManagerTest, TestTaskChain) { - // Create 3 tasks, each dependent on the previous. The first task has no - // arguments. - int num_tasks = 3; - auto tasks = MakeTaskChain(num_tasks, {}, 1); - int num_ready_tasks = 1; - int i = 0; - // No objects should be remote or canceled since each task depends on a - // locally queued task. - EXPECT_CALL(object_manager_mock_, Pull(_, _)).Times(0); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_, _)).Times(0); - EXPECT_CALL(object_manager_mock_, CancelPull(_)).Times(0); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(0); - for (const auto &task : tasks) { - // Subscribe to each of the tasks' arguments. - const auto &arguments = task.GetDependencies(); - bool ready = task_dependency_manager_.SubscribeGetDependencies( - task.GetTaskSpecification().TaskId(), arguments); - if (i < num_ready_tasks) { - // The first task should be ready to run since it has no arguments. - ASSERT_TRUE(ready); - } else { - // All remaining tasks depend on the previous task. - ASSERT_FALSE(ready); - } - - // Mark each task as pending. - task_dependency_manager_.TaskPending(task); - - i++; - } - - // Simulate executing each task. Each task's completion should make the next - // task runnable. - while (!tasks.empty()) { - auto task = tasks.front(); - tasks.erase(tasks.begin()); - TaskID task_id = task.GetTaskSpecification().TaskId(); - auto return_id = task.GetTaskSpecification().ReturnId(0); - - task_dependency_manager_.UnsubscribeGetDependencies(task_id); - // Simulate the object notifications for the task's return values. - auto ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); - if (tasks.empty()) { - // If there are no more tasks, then there should be no more tasks that - // become ready to run. - ASSERT_TRUE(ready_tasks.empty()); - } else { - // If there are more tasks to run, then the next task in the chain should - // now be ready to run. - ASSERT_EQ(ready_tasks.size(), 1); - ASSERT_EQ(ready_tasks.front(), tasks.front().GetTaskSpecification().TaskId()); - } - // Simulate the task finishing execution. - task_dependency_manager_.TaskCanceled(task_id); - } -} - -TEST_F(TaskDependencyManagerTest, TestDependentPut) { - // Create a task with 3 arguments. - auto task1 = ExampleTask({}, 0); - ObjectID put_id = - ObjectID::FromIndex(task1.GetTaskSpecification().TaskId(), /*index=*/1); - auto task2 = ExampleTask({put_id}, 0); - - // No objects have been registered in the task dependency manager, so the put - // object should be remote. - EXPECT_CALL(object_manager_mock_, Pull(put_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(put_id, _)); - // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeGetDependencies( - task2.GetTaskSpecification().TaskId(), ObjectIdsToRefs({put_id})); - ASSERT_FALSE(ready); - - // The put object should be considered local as soon as the task that creates - // it is pending execution. - EXPECT_CALL(object_manager_mock_, CancelPull(put_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(put_id)); - task_dependency_manager_.TaskPending(task1); -} - -TEST_F(TaskDependencyManagerTest, TestTaskForwarding) { - // Create 2 tasks, one dependent on the other. The first has no arguments. - int num_tasks = 2; - auto tasks = MakeTaskChain(num_tasks, {}, 1); - for (const auto &task : tasks) { - // Subscribe to each of the tasks' arguments. - const auto &arguments = task.GetDependencies(); - static_cast(task_dependency_manager_.SubscribeGetDependencies( - task.GetTaskSpecification().TaskId(), arguments)); - task_dependency_manager_.TaskPending(task); - } - - // Get the first task. - const auto task = tasks.front(); - TaskID task_id = task.GetTaskSpecification().TaskId(); - ObjectID return_id = task.GetTaskSpecification().ReturnId(0); - // Simulate forwarding the first task to a remote node. - task_dependency_manager_.UnsubscribeGetDependencies(task_id); - // The object returned by the first task should be considered remote once we - // cancel the forwarded task, since the second task depends on it. - EXPECT_CALL(object_manager_mock_, Pull(return_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(return_id, _)); - task_dependency_manager_.TaskCanceled(task_id); - - // Simulate the task executing on a remote node and its return value - // appearing locally. - EXPECT_CALL(object_manager_mock_, CancelPull(return_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(return_id)); - auto ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); - // Check that the task that we kept is now ready to run. - ASSERT_EQ(ready_tasks.size(), 1); - ASSERT_EQ(ready_tasks.front(), tasks.back().GetTaskSpecification().TaskId()); -} - -TEST_F(TaskDependencyManagerTest, TestEviction) { - // Create a task with 3 arguments. - int num_arguments = 3; - std::vector arguments; - for (int i = 0; i < num_arguments; i++) { - arguments.push_back(ObjectID::FromRandom()); - } - TaskID task_id = RandomTaskId(); - // No objects have been registered in the task dependency manager, so all - // arguments should be remote. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, Pull(argument_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); - } - // Subscribe to the task's dependencies. - bool ready = task_dependency_manager_.SubscribeGetDependencies( - task_id, ObjectIdsToRefs(arguments)); - ASSERT_FALSE(ready); - - // Tell the task dependency manager that each of the arguments is now - // available. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, CancelPull(argument_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); - } - for (size_t i = 0; i < arguments.size(); i++) { - std::vector ready_tasks; - ready_tasks = task_dependency_manager_.HandleObjectLocal(arguments[i]); - if (i == arguments.size() - 1) { - ASSERT_EQ(ready_tasks.size(), 1); - ASSERT_EQ(ready_tasks.front(), task_id); - } else { - ASSERT_TRUE(ready_tasks.empty()); - } - } - - // Simulate each of the arguments getting evicted. Each object should now be - // considered remote. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, Pull(argument_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id, _)); - } - for (size_t i = 0; i < arguments.size(); i++) { - std::vector waiting_tasks; - waiting_tasks = task_dependency_manager_.HandleObjectMissing(arguments[i]); - if (i == 0) { - // The first eviction should cause the task to go back to the waiting - // state. - ASSERT_EQ(waiting_tasks.size(), 1); - ASSERT_EQ(waiting_tasks.front(), task_id); - } else { - // The subsequent evictions shouldn't cause any more tasks to go back to - // the waiting state. - ASSERT_TRUE(waiting_tasks.empty()); - } - } - - // Tell the task dependency manager that each of the arguments is available - // again. - for (const auto &argument_id : arguments) { - EXPECT_CALL(object_manager_mock_, CancelPull(argument_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(argument_id)); - } - for (size_t i = 0; i < arguments.size(); i++) { - std::vector ready_tasks; - ready_tasks = task_dependency_manager_.HandleObjectLocal(arguments[i]); - if (i == arguments.size() - 1) { - ASSERT_EQ(ready_tasks.size(), 1); - ASSERT_EQ(ready_tasks.front(), task_id); - } else { - ASSERT_TRUE(ready_tasks.empty()); - } - } -} - -TEST_F(TaskDependencyManagerTest, TestRemoveTasksAndRelatedObjects) { - // Create 3 tasks, each dependent on the previous. The first task has no - // arguments. - int num_tasks = 3; - auto tasks = MakeTaskChain(num_tasks, {}, 1); - // No objects should be remote or canceled since each task depends on a - // locally queued task. - EXPECT_CALL(object_manager_mock_, Pull(_, _)).Times(0); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_, _)).Times(0); - EXPECT_CALL(object_manager_mock_, CancelPull(_)).Times(0); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(0); - for (const auto &task : tasks) { - // Subscribe to each of the tasks' arguments. - const auto &arguments = task.GetDependencies(); - task_dependency_manager_.SubscribeGetDependencies( - task.GetTaskSpecification().TaskId(), arguments); - // Mark each task as pending. - task_dependency_manager_.TaskPending(task); - } - - // Simulate executing the first task. This should make the second task - // runnable. - auto task = tasks.front(); - TaskID task_id = task.GetTaskSpecification().TaskId(); - auto return_id = task.GetTaskSpecification().ReturnId(0); - task_dependency_manager_.UnsubscribeGetDependencies(task_id); - // Simulate the object notifications for the task's return values. - auto ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); - // The second task should be ready to run. - ASSERT_EQ(ready_tasks.size(), 1); - // Simulate the task finishing execution. - task_dependency_manager_.TaskCanceled(task_id); - - // Remove all tasks from the manager except the first task, which already - // finished executing. - std::unordered_set task_ids; - for (const auto &task : tasks) { - task_ids.insert(task.GetTaskSpecification().TaskId()); - } - task_ids.erase(task_id); - task_dependency_manager_.RemoveTasksAndRelatedObjects(task_ids); - // Simulate evicting the return value of the first task. Make sure that this - // does not return the second task, which should have been removed. - auto waiting_tasks = task_dependency_manager_.HandleObjectMissing(return_id); - ASSERT_TRUE(waiting_tasks.empty()); - - // Simulate the object notifications for the second task's return values. - // Make sure that this does not return the third task, which should have been - // removed. - return_id = tasks[1].GetTaskSpecification().ReturnId(0); - ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); - ASSERT_TRUE(ready_tasks.empty()); -} - -/// Test that when no objects are locally available, a `ray.wait` call makes -/// the correct requests to remote nodes and correctly cancels the requests -/// when the `ray.wait` call is canceled. -TEST_F(TaskDependencyManagerTest, TestWaitDependencies) { - // Generate a random worker and objects to wait on. - WorkerID worker_id = WorkerID::FromRandom(); - int num_objects = 3; - std::vector wait_object_ids; - for (int i = 0; i < num_objects; i++) { - wait_object_ids.push_back(ObjectID::FromRandom()); - } - // Simulate a worker calling `ray.wait` on some objects. - EXPECT_CALL(object_manager_mock_, Pull(_, _)).Times(num_objects); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_, _)) - .Times(num_objects); - task_dependency_manager_.SubscribeWaitDependencies(worker_id, - ObjectIdsToRefs(wait_object_ids)); - // Check that it's okay to call `ray.wait` on the same objects again. No new - // calls should be made to try and make the objects local. - task_dependency_manager_.SubscribeWaitDependencies(worker_id, - ObjectIdsToRefs(wait_object_ids)); - // Cancel the worker's `ray.wait`. calls. - EXPECT_CALL(object_manager_mock_, CancelPull(_)).Times(num_objects); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(num_objects); - task_dependency_manager_.UnsubscribeWaitDependencies(worker_id); -} - -/// Test that when one of the objects is already local at the time of the -/// `ray.wait` call, the `ray.wait` call does not trigger any requests to -/// remote nodes for that object. -TEST_F(TaskDependencyManagerTest, TestWaitDependenciesObjectLocal) { - // Generate a random worker and objects to wait on. - WorkerID worker_id = WorkerID::FromRandom(); - int num_objects = 3; - std::vector wait_object_ids; - for (int i = 0; i < num_objects; i++) { - wait_object_ids.push_back(ObjectID::FromRandom()); - } - // Simulate one of the objects becoming local. The later `ray.wait` call - // should have no effect because the object is already local. - const ObjectID local_object_id = std::move(wait_object_ids.back()); - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(local_object_id); - ASSERT_TRUE(ready_task_ids.empty()); - - // Simulate a worker calling `ray.wait` on the objects. It should only make - // requests for the objects that are not local. - for (const auto &object_id : wait_object_ids) { - if (object_id != local_object_id) { - EXPECT_CALL(object_manager_mock_, Pull(object_id, _)); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(object_id, _)); - } - } - task_dependency_manager_.SubscribeWaitDependencies(worker_id, - ObjectIdsToRefs(wait_object_ids)); - // Simulate the local object getting evicted. The `ray.wait` call should not - // be reactivated. - auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(local_object_id); - ASSERT_TRUE(waiting_task_ids.empty()); - // Simulate a worker calling `ray.wait` on the objects. It should only make - // requests for the objects that are not local. - for (const auto &object_id : wait_object_ids) { - if (object_id != local_object_id) { - EXPECT_CALL(object_manager_mock_, CancelPull(object_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(object_id)); - } - } - task_dependency_manager_.UnsubscribeWaitDependencies(worker_id); -} - -/// Test that when one of the objects becomes local after a `ray.wait` call, -/// all requests to remote nodes associated with the object are canceled. -TEST_F(TaskDependencyManagerTest, TestWaitDependenciesHandleObjectLocal) { - // Generate a random worker and objects to wait on. - WorkerID worker_id = WorkerID::FromRandom(); - int num_objects = 3; - std::vector wait_object_ids; - for (int i = 0; i < num_objects; i++) { - wait_object_ids.push_back(ObjectID::FromRandom()); - } - // Simulate a worker calling `ray.wait` on some objects. - EXPECT_CALL(object_manager_mock_, Pull(_, _)).Times(num_objects); - EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_, _)) - .Times(num_objects); - task_dependency_manager_.SubscribeWaitDependencies(worker_id, - ObjectIdsToRefs(wait_object_ids)); - // Simulate one of the objects becoming local while the `ray.wait` calls is - // active. The `ray.wait` call should be canceled. - const ObjectID local_object_id = std::move(wait_object_ids.back()); - wait_object_ids.pop_back(); - EXPECT_CALL(object_manager_mock_, CancelPull(local_object_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(local_object_id)); - auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(local_object_id); - ASSERT_TRUE(ready_task_ids.empty()); - // Simulate the local object getting evicted. The `ray.wait` call should not - // be reactivated. - auto waiting_task_ids = task_dependency_manager_.HandleObjectMissing(local_object_id); - ASSERT_TRUE(waiting_task_ids.empty()); - // Cancel the worker's `ray.wait` calls. Only the objects that are still not - // local should be canceled. - for (const auto &object_id : wait_object_ids) { - EXPECT_CALL(object_manager_mock_, CancelPull(object_id)); - EXPECT_CALL(reconstruction_policy_mock_, Cancel(object_id)); - } - task_dependency_manager_.UnsubscribeWaitDependencies(worker_id); -} - -} // namespace raylet - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} From b11bd22111d6609cb1fe37a2fe5b399262a9fb01 Mon Sep 17 00:00:00 2001 From: Sumanth Ratna Date: Wed, 23 Dec 2020 22:09:23 -0500 Subject: [PATCH 83/88] [docs] Fix args + kwargs instead of docstrings (#13068) * functools wraps * Fix typo (functoools -> functools) --- python/ray/_private/client_mode_hook.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/_private/client_mode_hook.py b/python/ray/_private/client_mode_hook.py index 4fbc568c8..9029b05c0 100644 --- a/python/ray/_private/client_mode_hook.py +++ b/python/ray/_private/client_mode_hook.py @@ -1,5 +1,6 @@ import os from contextlib import contextmanager +from functools import wraps client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1" @@ -38,6 +39,7 @@ def client_mode_hook(func): """ from ray.experimental.client import ray + @wraps(func) def wrapper(*args, **kwargs): global _client_hook_enabled if client_mode_enabled and _client_hook_enabled: From 81bfee79bcbce9d6ff7e94bf88bcce4b65602c38 Mon Sep 17 00:00:00 2001 From: Max Fitton Date: Wed, 23 Dec 2020 20:51:50 -0800 Subject: [PATCH 84/88] Fix OS X Wheel Build - Update brew cask install (#13062) Co-authored-by: Richard Liaw --- .travis.yml | 2 +- ci/travis/ci.sh | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 8173ec1ac..9a4398525 100644 --- a/.travis.yml +++ b/.travis.yml @@ -176,7 +176,7 @@ matrix: - . ./ci/travis/ci.sh init RAY_CI_MACOS_WHEELS_AFFECTED,RAY_CI_JAVA_AFFECTED,RAY_CI_STREAMING_JAVA_AFFECTED before_script: - brew tap adoptopenjdk/openjdk - - brew cask install adoptopenjdk8 + - brew install --cask adoptopenjdk8 - export JAVA_HOME=/Library/Java/JavaVirtualMachines/adoptopenjdk-8.jdk/Contents/Home - java -version - . ./ci/travis/ci.sh build diff --git a/ci/travis/ci.sh b/ci/travis/ci.sh index 9a8c0ecbf..3638b3af3 100755 --- a/ci/travis/ci.sh +++ b/ci/travis/ci.sh @@ -300,7 +300,8 @@ build_wheels() { ;; darwin*) # This command should be kept in sync with ray/python/README-building-wheels.md. - suppress_output "${WORKSPACE_DIR}"/python/build-wheel-macos.sh + # Remove suppress_output for now to avoid timeout + "${WORKSPACE_DIR}"/python/build-wheel-macos.sh ;; msys*) keep_alive "${WORKSPACE_DIR}"/python/build-wheel-windows.sh From 85f1716a1f011501c9d5f8bf2043c0c3f356ba95 Mon Sep 17 00:00:00 2001 From: ZhuSenlin Date: Thu, 24 Dec 2020 14:59:14 +0800 Subject: [PATCH 85/88] speed up local mode object store get (#13052) Co-authored-by: senlin.zsl --- .../main/java/io/ray/runtime/object/LocalModeObjectStore.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java index 4614100ae..fed5459b4 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/LocalModeObjectStore.java @@ -21,7 +21,7 @@ public class LocalModeObjectStore extends ObjectStore { private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeObjectStore.class); - private static final int GET_CHECK_INTERVAL_MS = 100; + private static final int GET_CHECK_INTERVAL_MS = 1; private final Map pool = new ConcurrentHashMap<>(); private final List> objectPutCallbacks = new ArrayList<>(); From a2d121520051de562244705d163cc6629beba6a2 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Thu, 24 Dec 2020 06:30:33 -0800 Subject: [PATCH 86/88] [RLlib] Execution Annotation (#13036) --- rllib/execution/common.py | 8 +++-- rllib/execution/concurrency_ops.py | 15 ++++++---- rllib/execution/learner_thread.py | 13 ++++---- rllib/execution/metric_ops.py | 18 +++++------ rllib/execution/minibatch_buffer.py | 13 ++++++-- rllib/execution/multi_gpu_learner.py | 31 +++++++++---------- rllib/execution/replay_buffer.py | 45 ++++++++++++++-------------- rllib/execution/replay_ops.py | 20 +++++++------ rllib/execution/segment_tree.py | 22 ++++++++------ rllib/execution/train_ops.py | 23 +++++++------- rllib/execution/tree_agg.py | 16 +++++----- 11 files changed, 126 insertions(+), 98 deletions(-) diff --git a/rllib/execution/common.py b/rllib/execution/common.py index 9e15f8b36..b12e557d5 100644 --- a/rllib/execution/common.py +++ b/rllib/execution/common.py @@ -1,5 +1,7 @@ from ray.util.iter import LocalIterator from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils.typing import Dict, SampleBatchType +from ray.util.iter_metrics import MetricsContext # Counters for training progress (keys for metrics.counters). STEPS_SAMPLED_COUNTER = "num_steps_sampled" @@ -23,19 +25,19 @@ LEARNER_INFO = "learner" # Asserts that an object is a type of SampleBatch. -def _check_sample_batch_type(batch): +def _check_sample_batch_type(batch: SampleBatchType) -> None: if not isinstance(batch, (SampleBatch, MultiAgentBatch)): raise ValueError("Expected either SampleBatch or MultiAgentBatch, " "got {}: {}".format(type(batch), batch)) # Returns pipeline global vars that should be periodically sent to each worker. -def _get_global_vars(): +def _get_global_vars() -> Dict: metrics = LocalIterator.get_metrics() return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]} -def _get_shared_metrics(): +def _get_shared_metrics() -> MetricsContext: """Return shared metrics for the training workflow. This only applies if this trainer has an execution plan.""" diff --git a/rllib/execution/concurrency_ops.py b/rllib/execution/concurrency_ops.py index 7b057852e..cfe326ed9 100644 --- a/rllib/execution/concurrency_ops.py +++ b/rllib/execution/concurrency_ops.py @@ -1,15 +1,17 @@ -from typing import List +from typing import List, Optional, Any import queue from ray.util.iter import LocalIterator, _NextValueNotReady from ray.util.iter_metrics import SharedMetrics +from ray.rllib.utils.typing import SampleBatchType def Concurrently(ops: List[LocalIterator], *, - mode="round_robin", - output_indexes=None, - round_robin_weights=None): + mode: str = "round_robin", + output_indexes: Optional[List[int]] = None, + round_robin_weights: Optional[List[int]] = None + ) -> LocalIterator[SampleBatchType]: """Operator that runs the given parent iterators concurrently. Args: @@ -91,7 +93,7 @@ class Enqueue: type(output_queue))) self.queue = output_queue - def __call__(self, x): + def __call__(self, x: Any) -> Any: try: self.queue.put_nowait(x) except queue.Full: @@ -99,7 +101,8 @@ class Enqueue: return x -def Dequeue(input_queue: queue.Queue, check=lambda: True): +def Dequeue(input_queue: queue.Queue, + check=lambda: True) -> LocalIterator[SampleBatchType]: """Dequeue data items from a queue.Queue instance. The dequeue is non-blocking, so Dequeue operations can executed with diff --git a/rllib/execution/learner_thread.py b/rllib/execution/learner_thread.py index 9e905148d..8f5350fa1 100644 --- a/rllib/execution/learner_thread.py +++ b/rllib/execution/learner_thread.py @@ -1,3 +1,4 @@ +from typing import Dict import threading import copy @@ -8,6 +9,7 @@ from ray.rllib.execution.minibatch_buffer import MinibatchBuffer from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat +from ray.rllib.evaluation.rollout_worker import RolloutWorker tf1, tf, tfv = try_import_tf() @@ -21,8 +23,9 @@ class LearnerThread(threading.Thread): improves overall throughput. """ - def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter, - learner_queue_size, learner_queue_timeout): + def __init__(self, local_worker: RolloutWorker, minibatch_buffer_size: int, + num_sgd_iter: int, learner_queue_size: int, + learner_queue_timeout: int): """Initialize the learner thread. Args: @@ -57,14 +60,14 @@ class LearnerThread(threading.Thread): self.stopped = False self.num_steps = 0 - def run(self): + def run(self) -> None: # Switch on eager mode if configured. if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]: tf1.enable_eager_execution() while not self.stopped: self.step() - def step(self): + def step(self) -> None: with self.queue_timer: batch, _ = self.minibatch_buffer.get() @@ -77,7 +80,7 @@ class LearnerThread(threading.Thread): self.outqueue.put((batch.count, self.stats)) self.learner_queue_size.push(self.inqueue.qsize()) - def add_learner_metrics(self, result): + def add_learner_metrics(self, result: Dict) -> Dict: """Add internal metrics to a trainer result dict.""" def timer_to_ms(timer): diff --git a/rllib/execution/metric_ops.py b/rllib/execution/metric_ops.py index 374ff047d..70ae38e3f 100644 --- a/rllib/execution/metric_ops.py +++ b/rllib/execution/metric_ops.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Dict import time from ray.util.iter import LocalIterator @@ -59,9 +59,9 @@ class CollectMetrics: """ def __init__(self, - workers, - min_history=100, - timeout_seconds=180, + workers: WorkerSet, + min_history: int = 100, + timeout_seconds: int = 180, selected_workers: List["ActorHandle"] = None): self.workers = workers self.episode_history = [] @@ -70,7 +70,7 @@ class CollectMetrics: self.timeout_seconds = timeout_seconds self.selected_workers = selected_workers - def __call__(self, _): + def __call__(self, _: Any) -> Dict: # Collect worker metrics. episodes, self.to_be_collected = collect_episodes( self.workers.local_worker(), @@ -124,11 +124,11 @@ class OncePerTimeInterval: 5.00001 # will be greater than 5 seconds """ - def __init__(self, delay): + def __init__(self, delay: int): self.delay = delay self.last_called = 0 - def __call__(self, item): + def __call__(self, item: Any) -> bool: if self.delay <= 0.0: return True now = time.time() @@ -151,11 +151,11 @@ class OncePerTimestepsElapsed: # will only return after 1000 steps have elapsed """ - def __init__(self, delay_steps): + def __init__(self, delay_steps: int): self.delay_steps = delay_steps self.last_called = 0 - def __call__(self, item): + def __call__(self, item: Any) -> bool: if self.delay_steps <= 0: return True metrics = _get_shared_metrics() diff --git a/rllib/execution/minibatch_buffer.py b/rllib/execution/minibatch_buffer.py index 4cd41fcc7..54b5c4a2c 100644 --- a/rllib/execution/minibatch_buffer.py +++ b/rllib/execution/minibatch_buffer.py @@ -1,10 +1,19 @@ +from typing import Any, Tuple +import queue + + class MinibatchBuffer: """Ring buffer of recent data batches for minibatch SGD. This is for use with AsyncSamplesOptimizer. """ - def __init__(self, inqueue, size, timeout, num_passes, init_num_passes=1): + def __init__(self, + inqueue: queue.Queue, + size: int, + timeout: float, + num_passes: int, + init_num_passes: int = 1): """Initialize a minibatch buffer. Args: @@ -23,7 +32,7 @@ class MinibatchBuffer: self.ttl = [0] * size self.idx = 0 - def get(self): + def get(self) -> Tuple[Any, bool]: """Get a new batch from the internal ring buffer. Returns: diff --git a/rllib/execution/multi_gpu_learner.py b/rllib/execution/multi_gpu_learner.py index 4c450c948..f6455dc98 100644 --- a/rllib/execution/multi_gpu_learner.py +++ b/rllib/execution/multi_gpu_learner.py @@ -12,6 +12,7 @@ from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.timer import TimerStat +from ray.rllib.evaluation.rollout_worker import RolloutWorker tf1, tf, tfv = try_import_tf() @@ -25,17 +26,17 @@ class TFMultiGPULearner(LearnerThread): """ def __init__(self, - local_worker, - num_gpus=1, - lr=0.0005, - train_batch_size=500, - num_data_loader_buffers=1, - minibatch_buffer_size=1, - num_sgd_iter=1, - learner_queue_size=16, - learner_queue_timeout=300, - num_data_load_threads=16, - _fake_gpus=False): + local_worker: RolloutWorker, + num_gpus: int = 1, + lr: float = 0.0005, + train_batch_size: int = 500, + num_data_loader_buffers: int = 1, + minibatch_buffer_size: int = 1, + num_sgd_iter: int = 1, + learner_queue_size: int = 16, + learner_queue_timeout: int = 300, + num_data_load_threads: int = 16, + _fake_gpus: bool = False): """Initialize a multi-gpu learner thread. Args: @@ -121,7 +122,7 @@ class TFMultiGPULearner(LearnerThread): learner_queue_timeout, num_sgd_iter) @override(LearnerThread) - def step(self): + def step(self) -> None: assert self.loader_thread.is_alive() with self.load_wait_timer: opt, released = self.minibatch_buffer.get() @@ -139,7 +140,7 @@ class TFMultiGPULearner(LearnerThread): class _LoaderThread(threading.Thread): - def __init__(self, learner, share_stats): + def __init__(self, learner: LearnerThread, share_stats: bool): threading.Thread.__init__(self) self.learner = learner self.daemon = True @@ -150,11 +151,11 @@ class _LoaderThread(threading.Thread): self.queue_timer = TimerStat() self.load_timer = TimerStat() - def run(self): + def run(self) -> None: while True: self._step() - def _step(self): + def _step(self) -> None: s = self.learner with self.queue_timer: batch = s.inqueue.get() diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 8d98e1397..79f4882eb 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -3,7 +3,7 @@ import logging import numpy as np import platform import random -from typing import List +from typing import List, Dict # Import ray before psutil will make sure we use psutil's bundled version import ray # noqa F401 @@ -64,11 +64,11 @@ class ReplayBuffer: self._evicted_hit_stats = WindowStat("evicted_hit", 1000) self._est_size_bytes = 0 - def __len__(self): + def __len__(self) -> int: return len(self._storage) @DeveloperAPI - def add(self, item: SampleBatchType, weight: float): + def add(self, item: SampleBatchType, weight: float) -> None: warn_replay_buffer_size( item=item, num_items=self._maxsize / item.count) assert item.count > 0, item @@ -116,7 +116,7 @@ class ReplayBuffer: return self._encode_sample(idxes) @DeveloperAPI - def stats(self, debug=False): + def stats(self, debug=False) -> dict: data = { "added_count": self._num_timesteps_added, "sampled_count": self._num_timesteps_sampled, @@ -156,7 +156,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._prio_change_stats = WindowStat("reprio", 1000) @DeveloperAPI - def add(self, item: SampleBatchType, weight: float): + def add(self, item: SampleBatchType, weight: float) -> None: idx = self._next_idx super(PrioritizedReplayBuffer, self).add(item, weight) if weight is None: @@ -164,7 +164,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._it_sum[idx] = weight**self._alpha self._it_min[idx] = weight**self._alpha - def _sample_proportional(self, num_items: int): + def _sample_proportional(self, num_items: int) -> List[int]: res = [] for _ in range(num_items): # TODO(szymon): should we ensure no repeats? @@ -215,7 +215,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): return batch @DeveloperAPI - def update_priorities(self, idxes, priorities): + def update_priorities(self, idxes: List[int], + priorities: List[float]) -> None: """Update priorities of sampled transitions. sets priority of transition at index idxes[i] in buffer @@ -242,7 +243,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._max_priority = max(self._max_priority, priority) @DeveloperAPI - def stats(self, debug=False): + def stats(self, debug: bool = False) -> Dict: parent = ReplayBuffer.stats(self, debug) if debug: parent.update(self._prio_change_stats.stats()) @@ -260,15 +261,15 @@ class LocalReplayBuffer(ParallelIteratorWorker): may be created to increase parallelism.""" def __init__(self, - num_shards=1, - learning_starts=1000, - buffer_size=10000, - replay_batch_size=1, - prioritized_replay_alpha=0.6, - prioritized_replay_beta=0.4, - prioritized_replay_eps=1e-6, - replay_mode="independent", - replay_sequence_length=1): + num_shards: int = 1, + learning_starts: int = 1000, + buffer_size: int = 10000, + replay_batch_size: int = 1, + prioritized_replay_alpha: float = 0.6, + prioritized_replay_beta: float = 0.4, + prioritized_replay_eps: float = 1e-6, + replay_mode: str = "independent", + replay_sequence_length: int = 1): self.replay_starts = learning_starts // num_shards self.buffer_size = buffer_size // num_shards self.replay_batch_size = replay_batch_size @@ -318,10 +319,10 @@ class LocalReplayBuffer(ParallelIteratorWorker): global _local_replay_buffer return _local_replay_buffer - def get_host(self): + def get_host(self) -> str: return platform.node() - def add_batch(self, batch): + def add_batch(self, batch: SampleBatchType) -> None: # Make a copy so the replay buffer doesn't pin plasma memory. batch = batch.copy() # Handle everything as if multiagent @@ -342,7 +343,7 @@ class LocalReplayBuffer(ParallelIteratorWorker): self.replay_buffers[policy_id].add(s, weight=weight) self.num_added += batch.count - def replay(self): + def replay(self) -> SampleBatchType: if self._fake_batch: fake_batch = SampleBatch(self._fake_batch) return MultiAgentBatch({ @@ -364,7 +365,7 @@ class LocalReplayBuffer(ParallelIteratorWorker): beta=self.prioritized_replay_beta) return MultiAgentBatch(samples, self.replay_batch_size) - def update_priorities(self, prio_dict): + def update_priorities(self, prio_dict: Dict) -> None: with self.update_priorities_timer: for policy_id, (batch_indexes, td_errors) in prio_dict.items(): new_priorities = ( @@ -372,7 +373,7 @@ class LocalReplayBuffer(ParallelIteratorWorker): self.replay_buffers[policy_id].update_priorities( batch_indexes, new_priorities) - def stats(self, debug=False): + def stats(self, debug: bool = False) -> Dict: stat = { "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), "replay_time_ms": round(1000 * self.replay_timer.mean, 3), diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py index 9ed25e9e9..7bfc3a1b9 100644 --- a/rllib/execution/replay_ops.py +++ b/rllib/execution/replay_ops.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any, Optional import random from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady @@ -55,7 +55,7 @@ class StoreToReplayBuffer: def Replay(*, local_buffer: LocalReplayBuffer = None, actors: List["ActorHandle"] = None, - num_async=4): + num_async: int = 4) -> LocalIterator[SampleBatchType]: """Replay experiences from the given buffer or actors. This should be combined with the StoreToReplayActors operation using the @@ -99,10 +99,10 @@ def Replay(*, class WaitUntilTimestepsElapsed: """Callable that returns True once a given number of timesteps are hit.""" - def __init__(self, target_num_timesteps): + def __init__(self, target_num_timesteps: int): self.target_num_timesteps = target_num_timesteps - def __call__(self, item): + def __call__(self, item: Any) -> bool: metrics = _get_shared_metrics() ts = metrics.counters[STEPS_SAMPLED_COUNTER] return ts > self.target_num_timesteps @@ -112,7 +112,9 @@ class WaitUntilTimestepsElapsed: class SimpleReplayBuffer: """Simple replay buffer that operates over batches.""" - def __init__(self, num_slots, replay_proportion: float = None): + def __init__(self, + num_slots: int, + replay_proportion: Optional[float] = None): """Initialize SimpleReplayBuffer. Args: @@ -122,7 +124,7 @@ class SimpleReplayBuffer: self.replay_batches = [] self.replay_index = 0 - def add_batch(self, sample_batch): + def add_batch(self, sample_batch: SampleBatchType) -> None: warn_replay_buffer_size(item=sample_batch, num_items=self.num_slots) if self.num_slots > 0: if len(self.replay_batches) < self.num_slots: @@ -132,7 +134,7 @@ class SimpleReplayBuffer: self.replay_index += 1 self.replay_index %= self.num_slots - def replay(self): + def replay(self) -> SampleBatchType: return random.choice(self.replay_batches) @@ -145,7 +147,7 @@ class MixInReplay: number of replay slots. """ - def __init__(self, num_slots, replay_proportion: float): + def __init__(self, num_slots: int, replay_proportion: float): """Initialize MixInReplay. Args: @@ -171,7 +173,7 @@ class MixInReplay: self.replay_buffer = SimpleReplayBuffer(num_slots) self.replay_proportion = replay_proportion - def __call__(self, sample_batch): + def __call__(self, sample_batch: SampleBatchType) -> List[SampleBatchType]: # Put in replay buffer if enabled. self.replay_buffer.add_batch(sample_batch) diff --git a/rllib/execution/segment_tree.py b/rllib/execution/segment_tree.py index e436f3a5a..ead3e3188 100644 --- a/rllib/execution/segment_tree.py +++ b/rllib/execution/segment_tree.py @@ -1,4 +1,5 @@ import operator +from typing import Any, Optional class SegmentTree: @@ -28,7 +29,10 @@ class SegmentTree: `tree[0]` accesses `internal_array[4]` in the above example. """ - def __init__(self, capacity, operation, neutral_element=None): + def __init__(self, + capacity: int, + operation: Any, + neutral_element: Optional[Any] = None): """Initializes a Segment Tree object. Args: @@ -52,7 +56,7 @@ class SegmentTree: self.value = [self.neutral_element for _ in range(2 * capacity)] self.operation = operation - def reduce(self, start=0, end=None): + def reduce(self, start: int = 0, end: Optional[int] = None) -> Any: """Applies `self.operation` to subsequence of our values. Subsequence is contiguous, includes `start` and excludes `end`. @@ -122,7 +126,7 @@ class SegmentTree: return result - def __setitem__(self, idx, val): + def __setitem__(self, idx: int, val: float) -> None: """ Inserts/overwrites a value in/into the tree. @@ -147,7 +151,7 @@ class SegmentTree: self.value[update_idx + 1]) idx = idx >> 1 # Divide by 2 (faster than division). - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Any: assert 0 <= idx < self.capacity return self.value[idx + self.capacity] @@ -155,15 +159,15 @@ class SegmentTree: class SumSegmentTree(SegmentTree): """A SegmentTree with the reduction `operation`=operator.add.""" - def __init__(self, capacity): + def __init__(self, capacity: int): super(SumSegmentTree, self).__init__( capacity=capacity, operation=operator.add) - def sum(self, start=0, end=None): + def sum(self, start: int = 0, end: Optional[Any] = None) -> Any: """Returns the sum over a sub-segment of the tree.""" return self.reduce(start, end) - def find_prefixsum_idx(self, prefixsum): + def find_prefixsum_idx(self, prefixsum: float) -> int: """Finds highest i, for which: sum(arr[0]+..+arr[i - i]) <= prefixsum. Args: @@ -188,9 +192,9 @@ class SumSegmentTree(SegmentTree): class MinSegmentTree(SegmentTree): - def __init__(self, capacity): + def __init__(self, capacity: int): super(MinSegmentTree, self).__init__(capacity=capacity, operation=min) - def min(self, start=0, end=None): + def min(self, start: int = 0, end: Optional[Any] = None) -> Any: """Returns min(arr[start], ..., arr[end])""" return self.reduce(start, end) diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index f99fd0a0c..e2411ed32 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -2,7 +2,7 @@ from collections import defaultdict import logging import numpy as np import math -from typing import List +from typing import List, Tuple, Any import ray from ray.rllib.evaluation.metrics import get_learner_stats, LEARNER_STATS_KEY @@ -18,7 +18,7 @@ from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.sgd import do_minibatch_sgd, averaged -from ray.rllib.utils.typing import PolicyID, SampleBatchType +from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients tf1, tf, tfv = try_import_tf() @@ -242,10 +242,10 @@ class ComputeGradients: Updates the LEARNER_INFO info field in the local iterator context. """ - def __init__(self, workers): + def __init__(self, workers: WorkerSet): self.workers = workers - def __call__(self, samples: SampleBatchType): + def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]: _check_sample_batch_type(samples) metrics = _get_shared_metrics() with metrics.timers[COMPUTE_GRADS_TIMER]: @@ -283,7 +283,7 @@ class ApplyGradients: self.policies = policies or workers.local_worker().policies_to_train self.update_all = update_all - def __call__(self, item): + def __call__(self, item: Tuple[ModelGradients, int]) -> None: if not isinstance(item, tuple) or len(item) != 2: raise ValueError( "Input must be a tuple of (grad_dict, count), got {}".format( @@ -333,7 +333,8 @@ class AverageGradients: {"var_0": ..., ...}, 1600 # averaged grads, summed batch count """ - def __call__(self, gradients): + def __call__(self, gradients: List[Tuple[ModelGradients, int]] + ) -> Tuple[ModelGradients, int]: acc = None sum_count = 0 for grad, count in gradients: @@ -366,10 +367,10 @@ class UpdateTargetNetwork: """ def __init__(self, - workers, - target_update_freq, - by_steps_trained=False, - policies=frozenset([])): + workers: WorkerSet, + target_update_freq: int, + by_steps_trained: bool = False, + policies: List[PolicyID] = frozenset([])): self.workers = workers self.target_update_freq = target_update_freq self.policies = (policies or workers.local_worker().policies_to_train) @@ -378,7 +379,7 @@ class UpdateTargetNetwork: else: self.metric = STEPS_SAMPLED_COUNTER - def __call__(self, _): + def __call__(self, _: Any) -> None: metrics = _get_shared_metrics() cur_ts = metrics.counters[self.metric] last_update = metrics.counters[LAST_TARGET_UPDATE_TS] diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py index 69e06a4b0..b04bee783 100644 --- a/rllib/execution/tree_agg.py +++ b/rllib/execution/tree_agg.py @@ -1,6 +1,6 @@ import logging import platform -from typing import List +from typing import List, Dict, Any import ray from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ @@ -9,8 +9,9 @@ from ray.rllib.execution.replay_ops import MixInReplay from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches from ray.rllib.utils.actors import create_colocated from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \ - from_actors -from ray.rllib.utils.typing import SampleBatchType + from_actors, LocalIterator +from ray.rllib.utils.typing import SampleBatchType, ModelWeights +from ray.rllib.evaluation.worker_set import WorkerSet logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ class Aggregator(ParallelIteratorWorker): work to be offloaded to these actors instead of run in the learner. """ - def __init__(self, config: dict, + def __init__(self, config: Dict, rollout_group: "ParallelIterator[SampleBatchType]"): self.weights = None self.global_vars = None @@ -60,15 +61,16 @@ class Aggregator(ParallelIteratorWorker): super().__init__(generator, repeat=False) - def get_host(self): + def get_host(self) -> str: return platform.node() - def set_weights(self, weights, global_vars): + def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None: self.weights = weights self.global_vars = global_vars -def gather_experiences_tree_aggregation(workers, config): +def gather_experiences_tree_aggregation(workers: WorkerSet, + config: Dict) -> "LocalIterator[Any]": """Tree aggregation version of gather_experiences_directly().""" rollouts = ParallelRollouts(workers, mode="raw") From 4bcd47567183b2fc5e618443e187d7fe7715e8f1 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Thu, 24 Dec 2020 06:31:35 -0800 Subject: [PATCH 87/88] [RLlib] Improved Documentation for PPO, DDPG, and SAC (#12943) --- rllib/agents/ddpg/README.md | 22 +++++++++++++++++++++- rllib/agents/ppo/README.md | 26 ++++++++++++++++++++------ rllib/agents/sac/README.md | 14 ++++++++++---- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/rllib/agents/ddpg/README.md b/rllib/agents/ddpg/README.md index 93c32b0a2..5d4f10b80 100644 --- a/rllib/agents/ddpg/README.md +++ b/rllib/agents/ddpg/README.md @@ -1 +1,21 @@ -Implementation of deep deterministic policy gradients (https://arxiv.org/abs/1509.02971), including an Ape-X variant. +# Deep Deterministic Policy Gradient (DDPG) + +## Overview + +[DDPG](https://arxiv.org/abs/1509.02971) is a model-free off-policy RL algorithm that works well for environments in the continuous-action domain. DDPG employs two networks, a critic Q-network and an actor network. For stable training, DDPG also opts to use target networks to compute labels for the critic's loss function. + +For the critic network, the loss function is the L2 loss between critic output and critic target values. The critic target values are usually computed with a one-step bootstrap from the critic and actor target networks. On the other hand, the actor seeks to maximize the critic Q-values in its loss function. This is done by sampling backpropragable actions (via the reparameterization trick) from the actor and evaluating the critic, with frozen weights, on the generated state-action pairs. Like most off-policy algorithms, DDPG employs a replay buffer, which it samples batches from to compute gradients for the actor and critic networks. + +## Documentation & Implementation: + +1) Deep Deterministic Policy Gradient (DDPG) and Twin Delayed DDPG (TD3) + + **[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#ddpg)** + + **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ddpg/ddpg.py)** + +2) Ape-X variant of DDPG (Prioritized Experience Replay) + + **[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#apex)** + + **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ddpg/ddpg.py)** diff --git a/rllib/agents/ppo/README.md b/rllib/agents/ppo/README.md index 1a11124f5..4095f3550 100644 --- a/rllib/agents/ppo/README.md +++ b/rllib/agents/ppo/README.md @@ -1,7 +1,22 @@ -Proximal Policy Optimization (PPO) -================================== +# Proximal Policy Optimization (PPO) -Implementations of: +## Overview + +[PPO](https://arxiv.org/abs/1707.06347) is a model-free on-policy RL algorithm that works well for both discrete and continuous action space environments. PPO utilizes an actor-critic framework, where there are two networks, an actor (policy network) and critic network (value function). + +There are two formulations of PPO, which are both implemented in RLlib. The first formulation of PPO imitates the prior paper [TRPO](https://arxiv.org/abs/1502.05477) without the complexity of second-order optimization. In this formulation, for every iteration, an old version of an actor-network is saved and the agent seeks to optimize the RL objective while staying close to the old policy. This makes sure that the agent does not destabilize during training. In the second formulation, To mitigate destructive large policy updates, an issue discovered for vanilla policy gradient methods, PPO introduces the surrogate objective, which clips large action probability ratios between the current and old policy. Clipping has been shown in the paper to significantly improve training stability and speed. + +## Distributed PPO Algorithms + +PPO is a core algorithm in RLlib due to its ability to scale well with the number of nodes. In RLlib, we provide various implementation of distributed PPO, with different underlying execution plans, as shown below. + +Distributed baseline PPO is a synchronous distributed RL algorithm. Data collection nodes, which represent the old policy, gather data synchronously to create a large pool of on-policy data from which the agent performs minibatch gradient descent on. + +On the other hand, Asychronous PPO (APPO) opts to imitate IMPALA as its distributed execution plan. Data collection nodes gather data asynchronously, which are collected in a circular replay buffer. A target network and doubly-importance sampled surrogate objective is introduced to enforce training stability in the asynchronous data-collection setting. + +Lastly, Decentralized Distributed PPO (DDPPO) removes the assumption that gradient-updates must be done on a central node. Instead, gradients are computed remotely on each data collection node and all-reduced at each mini-batch using torch distributed. This allows each worker’s GPU to be used both for sampling and for training. + +## Documentation & Implementation: 1) Proximal Policy Optimization (PPO). @@ -9,15 +24,14 @@ Implementations of: **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py)** -2) Asynchronous Proximal Policy Optimization (APPO). +2) [Asynchronous Proximal Policy Optimization (APPO)](https://arxiv.org/abs/1912.00167). **[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#appo)** **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/appo.py)** -3) Decentralized Distributed Proximal Policy Optimization (DDPPO) +3) [Decentralized Distributed Proximal Policy Optimization (DDPPO)](https://arxiv.org/abs/1911.00357) **[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#decentralized-distributed-proximal-policy-optimization-dd-ppo)** **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ddppo.py)** - diff --git a/rllib/agents/sac/README.md b/rllib/agents/sac/README.md index 8aa0c4c45..13ce9a644 100644 --- a/rllib/agents/sac/README.md +++ b/rllib/agents/sac/README.md @@ -1,10 +1,16 @@ -Soft Actor Critic (SAC) -======================= +# Soft Actor Critic (SAC) -Implementations of: +## Overview -Soft Actor-Critic Algorithm (SAC) and a discrete action extension. +[SAC](https://arxiv.org/abs/1801.01290) is a SOTA model-free off-policy RL algorithm that performs remarkably well on continuous-control domains. SAC employs an actor-critic framework and combats high sample complexity and training stability via learning based on a maximum-entropy framework. Unlike the standard RL objective which aims to maximize sum of reward into the future, SAC seeks to optimize sum of rewards as well as expected entropy over the current policy. In addition to optimizing over an actor and critic with entropy-based objectives, SAC also optimizes for the entropy coeffcient. + +## Documentation & Implementation: + +[Soft Actor-Critic Algorithm (SAC)](https://arxiv.org/abs/1801.01290) with also [discrete-action support](https://arxiv.org/abs/1910.07207). **[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#sac)** **[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac.py)** + + + From 2059a2090da953df153cf89b477a2f9f4945b7a9 Mon Sep 17 00:00:00 2001 From: Alind Khare Date: Thu, 24 Dec 2020 12:32:52 -0500 Subject: [PATCH 88/88] [C++ API] Added reference counting to ObjectRef (#13058) * Added reference counting to ObjectRef * Addressed the comments --- cpp/include/ray/api/object_ref.h | 12 +++++ cpp/src/ray/test/cluster/cluster_mode_test.cc | 46 ++++++++----------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/cpp/include/ray/api/object_ref.h b/cpp/include/ray/api/object_ref.h index 71500202b..c4ceff90c 100644 --- a/cpp/include/ray/api/object_ref.h +++ b/cpp/include/ray/api/object_ref.h @@ -16,6 +16,7 @@ template class ObjectRef { public: ObjectRef(); + ~ObjectRef(); ObjectRef(const ObjectID &id); @@ -46,6 +47,17 @@ ObjectRef::ObjectRef() {} template ObjectRef::ObjectRef(const ObjectID &id) { id_ = id; + if (CoreWorkerProcess::IsInitialized()) { + auto &core_worker = CoreWorkerProcess::GetCoreWorker(); + core_worker.AddLocalReference(id_); + } +} +template +ObjectRef::~ObjectRef() { + if (CoreWorkerProcess::IsInitialized()) { + auto &core_worker = CoreWorkerProcess::GetCoreWorker(); + core_worker.RemoveLocalReference(id_); + } } template diff --git a/cpp/src/ray/test/cluster/cluster_mode_test.cc b/cpp/src/ray/test/cluster/cluster_mode_test.cc index 68ac2ec02..780fb0d30 100644 --- a/cpp/src/ray/test/cluster/cluster_mode_test.cc +++ b/cpp/src/ray/test/cluster/cluster_mode_test.cc @@ -91,16 +91,13 @@ TEST(RayClusterModeTest, FullTest) { auto r5 = Ray::Task(Plus, r4, r3).Remote(); auto r6 = Ray::Task(Plus, r4, 10).Remote(); - ///// TODO(ameer/guyang): All the commented code lines below should be - ///// uncommented once reference counting is added. Currently the objects - ///// are leaking from the object store. int result5 = *(Ray::Get(r5)); - // int result4 = *(Ray::Get(r4)); + int result4 = *(Ray::Get(r4)); int result6 = *(Ray::Get(r6)); - // int result3 = *(Ray::Get(r3)); + int result3 = *(Ray::Get(r3)); EXPECT_EQ(result0, 1); - // EXPECT_EQ(result3, 1); - // EXPECT_EQ(result4, 2); + EXPECT_EQ(result3, 1); + EXPECT_EQ(result4, 2); EXPECT_EQ(result5, 3); EXPECT_EQ(result6, 12); @@ -114,37 +111,34 @@ TEST(RayClusterModeTest, FullTest) { int result7 = *(Ray::Get(r7)); int result8 = *(Ray::Get(r8)); int result9 = *(Ray::Get(r9)); - // int result10 = *(Ray::Get(r10)); + int result10 = *(Ray::Get(r10)); EXPECT_EQ(result7, 15); EXPECT_EQ(result8, 16); EXPECT_EQ(result9, 19); - // EXPECT_EQ(result10, 27); + EXPECT_EQ(result10, 27); /// create actor and task function remote call with args passed by reference - // ActorHandle actor5 = Ray::Actor(Counter::FactoryCreate, r10, 0).Remote(); - ActorHandle actor5 = Ray::Actor(Counter::FactoryCreate, 27, 0).Remote(); - // auto r11 = actor5.Task(&Counter::Add, r0).Remote(); - auto r11 = actor5.Task(&Counter::Add, 1).Remote(); - // auto r12 = actor5.Task(&Counter::Add, r11).Remote(); + ActorHandle actor5 = Ray::Actor(Counter::FactoryCreate, r10, 0).Remote(); + + auto r11 = actor5.Task(&Counter::Add, r0).Remote(); + auto r12 = actor5.Task(&Counter::Add, r11).Remote(); auto r13 = actor5.Task(&Counter::Add, r10).Remote(); auto r14 = actor5.Task(&Counter::Add, r13).Remote(); - // auto r15 = Ray::Task(Plus, r0, r11).Remote(); - auto r15 = Ray::Task(Plus, 1, r11).Remote(); + auto r15 = Ray::Task(Plus, r0, r11).Remote(); auto r16 = Ray::Task(Plus1, r15).Remote(); - // int result12 = *(Ray::Get(r12)); + int result12 = *(Ray::Get(r12)); int result14 = *(Ray::Get(r14)); - // int result11 = *(Ray::Get(r11)); - // int result13 = *(Ray::Get(r13)); + int result11 = *(Ray::Get(r11)); + int result13 = *(Ray::Get(r13)); int result16 = *(Ray::Get(r16)); - // int result15 = *(Ray::Get(r15)); + int result15 = *(Ray::Get(r15)); - // EXPECT_EQ(result11, 28); - // EXPECT_EQ(result12, 56); - // EXPECT_EQ(result13, 83); - // EXPECT_EQ(result14, 166); - EXPECT_EQ(result14, 110); - // EXPECT_EQ(result15, 29); + EXPECT_EQ(result11, 28); + EXPECT_EQ(result12, 56); + EXPECT_EQ(result13, 83); + EXPECT_EQ(result14, 166); + EXPECT_EQ(result15, 29); EXPECT_EQ(result16, 30); Ray::Shutdown();