diff --git a/BUILD.bazel b/BUILD.bazel index 61484ba82..e2cbdd64b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -495,6 +495,7 @@ flatbuffer_py_library( "ConfigTableData.py", "CustomSerializerData.py", "DriverTableData.py", + "EntryType.py", "ErrorTableData.py", "ErrorType.py", "FunctionTableData.py", diff --git a/doc/source/conf.py b/doc/source/conf.py index b67dbe267..e0bd2c6da 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -26,6 +26,7 @@ MOCK_MODULES = [ "ray.core.generated.ActorCheckpointIdData", "ray.core.generated.ClientTableData", "ray.core.generated.DriverTableData", + "ray.core.generated.EntryType", "ray.core.generated.ErrorTableData", "ray.core.generated.ErrorType", "ray.core.generated.GcsTableEntry", diff --git a/doc/source/development.rst b/doc/source/development.rst index 66e666b4d..e4d50327a 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -81,7 +81,7 @@ API. The easiest way to do this is to start or connect to a Ray cluster with ray.worker.global_state.client_table() # Returns current information about the nodes in the cluster, such as: # [{'ClientID': '2a9d2b34ad24a37ed54e4fcd32bf19f915742f5b', - # 'IsInsertion': True, + # 'EntryType': 0, # 'NodeManagerAddress': '1.2.3.4', # 'NodeManagerPort': 43280, # 'ObjectManagerPort': 38062, diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 34799a76c..2d2762d83 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -154,6 +154,7 @@ flatbuffers_generated_files = [ "ConfigTableData.java", "CustomSerializerData.java", "DriverTableData.java", + "EntryType.java", "ErrorTableData.java", "ErrorType.java", "FunctionTableData.java", diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index a627f200a..647b77e33 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -13,6 +13,7 @@ import 
org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.generated.ActorCheckpointIdData; import org.ray.runtime.generated.ClientTableData; +import org.ray.runtime.generated.EntryType; import org.ray.runtime.generated.TablePrefix; import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; @@ -63,7 +64,7 @@ public class GcsClient { ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); - if (data.isInsertion()) { + if (data.entryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. @@ -72,12 +73,24 @@ public class GcsClient { for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); } - NodeInfo nodeInfo = new NodeInfo( clientId, data.nodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); + } else if (data.entryType() == EntryType.RES_CREATEUPDATE){ + Preconditions.checkState(clients.containsKey(clientId)); + NodeInfo nodeInfo = clients.get(clientId); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + } + } else if (data.entryType() == EntryType.RES_DELETE){ + Preconditions.checkState(clients.containsKey(clientId)); + NodeInfo nodeInfo = clients.get(clientId); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + } } else { // Code path of node deletion. 
+ Preconditions.checkState(data.entryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 31937837c..bae62f9b1 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -32,6 +32,7 @@ from ray.includes.libraylet cimport ( from ray.includes.unique_ids cimport ( CActorCheckpointID, CObjectID, + CClientID, ) from ray.includes.task cimport CTaskSpecification from ray.includes.ray_config cimport RayConfig @@ -368,6 +369,9 @@ cdef class RayletClient: check_status(self.client.get().NotifyActorResumedFromCheckpoint( actor_id.native(), checkpoint_id.native())) + def set_resource(self, basestring resource_name, double capacity, ClientID client_id): + check_status(self.client.get().SetResource(resource_name.encode("ascii"), capacity, CClientID.from_binary(client_id.binary()))) + @property def language(self): return Language.from_native(self.client.get().GetLanguage()) diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 425ff2d93..5b811ff0f 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -10,6 +10,7 @@ from .gcs_flush_policy import (set_flushing_policy, GcsFlushPolicy, SimpleGcsFlushPolicy) from .named_actors import get_actor, register_actor from .api import get, wait +from .dynamic_resources import set_resource def TensorFlowVariables(*args, **kwargs): @@ -24,5 +25,5 @@ __all__ = [ "flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard", "_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor", "get", "wait", "set_flushing_policy", "GcsFlushPolicy", - "SimpleGcsFlushPolicy" + "SimpleGcsFlushPolicy", "set_resource" ] diff --git a/python/ray/experimental/dynamic_resources.py b/python/ray/experimental/dynamic_resources.py new file mode 100644 index 000000000..34b2b99e6 --- 
/dev/null +++ b/python/ray/experimental/dynamic_resources.py @@ -0,0 +1,35 @@ +import ray + + +def set_resource(resource_name, capacity, client_id=None): + """ Set a resource to a specified capacity. + + This creates, updates or deletes a custom resource for a target clientId. + If the resource already exists, its capacity is updated to the new value. + If the capacity is set to 0, the resource is deleted. + If ClientID is not specified or set to None, + the resource is created on the local client where the actor is running. + + Args: + resource_name (str): Name of the resource to be created + capacity (int): Capacity of the new resource. Resource is deleted if + capacity is 0. + client_id (str): The ClientId of the node where the resource is to be + set. + + Returns: + None + + Raises: + ValueError: This exception is raised when a negative or non-integer + capacity is specified. + """ + if client_id is not None: + client_id_obj = ray.ClientID(ray.utils.hex_to_binary(client_id)) + else: + client_id_obj = ray.ClientID.nil() + if (capacity < 0) or (capacity != int(capacity)): + raise ValueError( + "Capacity {} must be a non-negative integer.".format(capacity)) + return ray.worker.global_worker.raylet_client.set_resource( + resource_name, capacity, client_id_obj) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 31d4b77c6..51b36dc83 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -13,6 +13,7 @@ import ray.gcs_utils from ray.ray_constants import ID_SIZE from ray import services +from ray.core.generated.EntryType import EntryType from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -54,29 +55,43 @@ def parse_client_table(redis_client): } client_id = ray.utils.binary_to_hex(client.ClientId()) - # If this client is being removed, then it must + if client.EntryType() == EntryType.INSERTION: + ordered_client_ids.append(client_id) + node_info[client_id] = { + "ClientID": 
client_id, + "EntryType": client.EntryType(), + "NodeManagerAddress": decode( + client.NodeManagerAddress(), allow_none=True), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName(), allow_none=True), + "RayletSocketName": decode( + client.RayletSocketName(), allow_none=True), + "Resources": resources + } + + # If this client is being updated, then it must # have previously been inserted, and # it cannot have previously been removed. - if not client.IsInsertion(): - assert client_id in node_info, "Client removed not found!" - assert node_info[client_id]["IsInsertion"], ( - "Unexpected duplicate removal of client.") else: - ordered_client_ids.append(client_id) - - node_info[client_id] = { - "ClientID": client_id, - "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), - "Resources": resources - } + assert client_id in node_info, "Client not found!" 
+ assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( + "Unexpected updation of deleted client.") + res_map = node_info[client_id]["Resources"] + if client.EntryType() == EntryType.RES_CREATEUPDATE: + for res in resources: + res_map[res] = resources[res] + elif client.EntryType() == EntryType.RES_DELETE: + for res in resources: + res_map.pop(res, None) + elif client.EntryType() == EntryType.DELETION: + pass # Do nothing with the resmap if client deletion + else: + raise RuntimeError("Unexpected EntryType {}".format( + client.EntryType())) + node_info[client_id]["Resources"] = res_map + node_info[client_id]["EntryType"] = client.EntryType() # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -757,18 +772,18 @@ class GlobalState(object): resources = defaultdict(int) clients = self.client_table() for client in clients: - # Only count resources from live clients. - if client["IsInsertion"]: + # Only count resources from latest entries of live clients. 
+ if client["EntryType"] != EntryType.DELETION: for key, value in client["Resources"].items(): resources[key] += value - return dict(resources) def _live_client_ids(self): """Returns a set of client IDs corresponding to clients still alive.""" return { client["ClientID"] - for client in self.client_table() if client["IsInsertion"] + for client in self.client_table() + if (client["EntryType"] != EntryType.DELETION) } def available_resources(self): diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index be74b06e5..1b4c5e3cd 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -72,6 +72,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CActorCheckpointID &checkpoint_id) CRayStatus NotifyActorResumedFromCheckpoint( const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) + CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id) CLanguage GetLanguage() const CClientID GetClientID() const CDriverID GetDriverID() const diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 0a7984d69..a7ed3e14a 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,6 +8,7 @@ import time import redis import ray +from ray.core.generated.EntryType import EntryType logger = logging.getLogger(__name__) @@ -175,7 +176,8 @@ class Cluster(object): while time.time() - start_time < timeout: clients = ray.experimental.state.parse_client_table(redis_client) live_clients = [ - client for client in clients if client["IsInsertion"] + client for client in clients + if client["EntryType"] != EntryType.DELETION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py new file mode 100644 index 000000000..6f3983930 --- /dev/null +++ b/python/ray/tests/test_dynres.py @@ -0,0 +1,586 @@ +from __future__ import absolute_import +from __future__ 
import division +from __future__ import print_function + +import logging +import time + +import ray +import ray.tests.cluster_utils +import ray.tests.utils + +logger = logging.getLogger(__name__) + + +def test_dynamic_res_creation(ray_start_regular): + # This test creates a resource locally (without specifying the client_id) + res_name = "test_res" + res_capacity = 1.0 + + @ray.remote + def set_res(resource_name, resource_capacity): + ray.experimental.set_resource(resource_name, resource_capacity) + + ray.get(set_res.remote(res_name, res_capacity)) + + available_res = ray.global_state.available_resources() + cluster_res = ray.global_state.cluster_resources() + + assert available_res[res_name] == res_capacity + assert cluster_res[res_name] == res_capacity + + +def test_dynamic_res_deletion(shutdown_only): + # This test deletes a resource locally (without specifying the client_id) + res_name = "test_res" + res_capacity = 1.0 + + ray.init(num_cpus=1, resources={res_name: res_capacity}) + + @ray.remote + def delete_res(resource_name): + ray.experimental.set_resource(resource_name, 0) + + ray.get(delete_res.remote(res_name)) + + available_res = ray.global_state.available_resources() + cluster_res = ray.global_state.cluster_resources() + + assert res_name not in available_res + assert res_name not in cluster_res + + +def test_dynamic_res_infeasible_rescheduling(ray_start_regular): + # This test launches an infeasible task and then creates a + # resource to make the task feasible. This tests if the + # infeasible tasks get rescheduled when resources are + # created at runtime. 
+ res_name = "test_res" + res_capacity = 1.0 + + @ray.remote + def set_res(resource_name, resource_capacity): + ray.experimental.set_resource(resource_name, resource_capacity) + + def f(): + return 1 + + remote_task = ray.remote(resources={res_name: res_capacity})(f) + oid = remote_task.remote() # This is infeasible + ray.get(set_res.remote(res_name, res_capacity)) # Now should be feasible + + available_res = ray.global_state.available_resources() + assert available_res[res_name] == res_capacity + + successful, unsuccessful = ray.wait([oid], timeout=1) + assert successful # The task completed + + +def test_dynamic_res_updation_clientid(ray_start_cluster): + # This test does a simple resource capacity update + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.global_state.client_table()[1]["ClientID"] + + @ray.remote + def set_res(resource_name, resource_capacity, client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=client_id) + + # Create resource + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + + # Update resource + new_capacity = res_capacity + 1 + ray.get(set_res.remote(res_name, new_capacity, target_clientid)) + + target_client = next(client for client in ray.global_state.client_table() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + + assert res_name in resources + assert resources[res_name] == new_capacity + + +def test_dynamic_res_creation_clientid(ray_start_cluster): + # Creates a resource on a specific client and verifies creation. 
+ cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.global_state.client_table()[1]["ClientID"] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + target_client = next(client for client in ray.global_state.client_table() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + + assert res_name in resources + assert resources[res_name] == res_capacity + + +def test_dynamic_res_creation_clientid_multiple(ray_start_cluster): + # This test creates resources on multiple clients using the clientid + # specifier + cluster = ray_start_cluster + + TIMEOUT = 5 + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + results = [] + for cid in target_clientids: + results.append(set_res.remote(res_name, res_capacity, cid)) + ray.get(results) + + success = False + start_time = time.time() + + while time.time() - start_time < TIMEOUT and not success: + resources_created = [] + for cid in target_clientids: + target_client = next(client + for client in ray.global_state.client_table() + if client["ClientID"] == cid) + resources = target_client["Resources"] + resources_created.append(resources[res_name] == res_capacity) + success = all(resources_created) + assert success + + +def 
test_dynamic_res_deletion_clientid(ray_start_cluster): + # This test deletes a resource on a given client id + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + + for i in range(num_nodes): + # Create resource on all nodes, but later we'll delete it from a + # target node + cluster.add_node(resources={res_name: res_capacity}) + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.global_state.client_table()[1]["ClientID"] + + # Launch the delete task + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + ray.get(delete_res.remote(res_name, target_clientid)) + + target_client = next(client for client in ray.global_state.client_table() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + print(ray.global_state.cluster_resources()) + assert res_name not in resources + + +def test_dynamic_res_creation_scheduler_consistency(ray_start_cluster): + # This makes sure the resource is actually created and the state is + # consistent in the scheduler + # by launching a task which requests the created resource + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node1 + target_clientid = clientids[1] + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + + # Define a task which requires this resource + @ray.remote(resources={res_name: res_capacity}) + def test_func(): + return 1 + + result = test_func.remote() + successful, unsuccessful = 
ray.wait([result], timeout=5) + assert successful # The task completed + + +def test_dynamic_res_deletion_scheduler_consistency(ray_start_cluster): + # This makes sure the resource is actually deleted and the state is + # consistent in the scheduler by launching an infeasible task which + # requests the created resource + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node1 + target_clientid = clientids[1] + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Delete the resource + ray.get(delete_res.remote(res_name, target_clientid)) + + # Define a task which requires this resource. This should not run + @ray.remote(resources={res_name: res_capacity}) + def test_func(): + return 1 + + result = test_func.remote() + successful, unsuccessful = ray.wait([result], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + + +def test_dynamic_res_concurrent_res_increment(ray_start_cluster): + # This test makes sure resource capacity is updated (increment) correctly + # when a task has already acquired some of the resource. 
+ + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + updated_capacity = 10 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Update the resource capacity + ray.get(set_res.remote(res_name, updated_capacity, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + task_2 = test_func._remote(args=[], resources={res_name: 
updated_capacity}) + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert successful # The task completed + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + 1. This should not execute + task_3 = test_func._remote( + args=[], resources={res_name: updated_capacity + 1 + }) # This should be infeasible + successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert ray.global_state.available_resources()[res_name] == updated_capacity + + +def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): + # This test makes sure resource capacity is updated (decremented) + # correctly when a task has already acquired some + # of the resource. + + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + updated_capacity = 2 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till 
signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Decrease the resource capacity + ray.get(set_res.remote(res_name, updated_capacity, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + task_2 = test_func._remote(args=[], resources={res_name: updated_capacity}) + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert successful # The task completed + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + 1. 
This should not execute + task_3 = test_func._remote( + args=[], resources={res_name: updated_capacity + 1 + }) # This should be infeasible + successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert ray.global_state.available_resources()[res_name] == updated_capacity + + +def test_dynamic_res_concurrent_res_delete(ray_start_cluster): + # This test makes sure resource gets deleted correctly when a task has + # already acquired the resource + + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, 
thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Delete the resource + ray.get(delete_res.remote(res_name, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # the deleted resource This should not execute + task_2 = test_func._remote( + args=[], resources={res_name: 1}) # This should be infeasible + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert res_name not in ray.global_state.available_resources() + + +def test_dynamic_res_creation_stress(ray_start_cluster): + # This stress tests creates many resources simultaneously on the same + # client and then checks if the final state is consistent + + cluster = ray_start_cluster + + TIMEOUT = 5 + res_capacity = 1 + num_nodes = 5 + NUM_RES_TO_CREATE = 500 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + results = [ + set_res.remote(str(i), res_capacity, target_clientid) + for i in range(0, NUM_RES_TO_CREATE) + ] + ray.get(results) + + success = False + start_time = time.time() + + while time.time() - start_time < TIMEOUT and not success: + 
resources = ray.global_state.cluster_resources() + all_resources_created = [] + for i in range(0, NUM_RES_TO_CREATE): + all_resources_created.append(str(i) in resources) + success = all(all_resources_created) + assert success diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index d2d225c0a..f7e25a487 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -1188,12 +1188,12 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client ASSERT_EQ(client_id, added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); - ASSERT_EQ(data.is_insertion, is_insertion); + ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); ClientTableDataT cached_client; client->client_table().GetClient(added_id, cached_client); ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.is_insertion, is_insertion); + ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 7acb24d27..7cf250247 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -39,6 +39,14 @@ enum TablePubsub:int { DRIVER, } +// Enum for the entry type in the ClientTable +enum EntryType:int { + INSERTION = 0, + DELETION, + RES_CREATEUPDATE, + RES_DELETE, +} + table Arg { // Object ID for pass-by-reference arguments. Normally there is only one // object ID in this list which represents the object that is being passed. @@ -267,9 +275,8 @@ table ClientTableData { // The port at which the client's object manager is listening for TCP // connections from other object managers. object_manager_port: int; - // True if the message is about the addition of a client and false if it is - // about the deletion of a client. 
- is_insertion: bool; + // Enum to store the entry type in the log + entry_type: EntryType = INSERTION; resources_total_label: [string]; resources_total_capacity: [double]; } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index e0876aa73..dbd39349c 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -363,7 +363,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && entry.second.is_insertion) { + if (!entry.first.is_nil() && (entry.second.entry_type == EntryType::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -373,55 +373,136 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && !entry.second.is_insertion) { + if (!entry.first.is_nil() && entry.second.entry_type == EntryType::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } } +void ClientTable::RegisterResourceCreateUpdatedCallback( + const ClientTableCallback &callback) { + resource_createupdated_callback_ = callback; + // Call the callback for any clients that are cached. + for (const auto &entry : client_cache_) { + if (!entry.first.is_nil() && + (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + resource_createupdated_callback_(client_, entry.first, entry.second); + } + } +} + +void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &callback) { + resource_deleted_callback_ = callback; + // Call the callback for any clients that are cached. 
+ for (const auto &entry : client_cache_) { + if (!entry.first.is_nil() && entry.second.entry_type == EntryType::RES_DELETE) { + resource_deleted_callback_(client_, entry.first, entry.second); + } + } +} + void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientTableDataT &data) { ClientID client_id = ClientID::from_binary(data.client_id); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); - bool is_new; + bool is_notif_new; if (entry == client_cache_.end()) { // If the entry is not in the cache, then the notification is new. - is_new = true; + is_notif_new = true; } else { // If the entry is in the cache, then the notification is new if the client - // was alive and is now dead. - bool was_inserted = entry->second.is_insertion; - bool is_deleted = !data.is_insertion; - is_new = (was_inserted && is_deleted); + // was alive and is now dead or resources have been updated. + bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); + bool is_deleted = (data.entry_type == EntryType::DELETION); + bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)); + is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (!entry->second.is_insertion) { - RAY_CHECK(!data.is_insertion) + if (entry->second.entry_type == EntryType::DELETION) { + RAY_CHECK((data.entry_type == EntryType::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } } // Add the notification to our cache. Notifications are idempotent. 
- client_cache_[client_id] = data; + // If it is a new client or a client removal, add as is + if ((data.entry_type == EntryType::INSERTION) || + (data.entry_type == EntryType::DELETION)) { + RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " + "notification for client id " + << client_id << ". EntryType: " << int(data.entry_type) + << ". Setting the client cache to data."; + client_cache_[client_id] = data; + } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)) { + RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " + "notification for client id " + << client_id << ". EntryType: " << int(data.entry_type) + << ". Updating the client cache with the delta from the log."; + + ClientTableDataT &cache_data = client_cache_[client_id]; + // Iterate over all resources in the new create/update notification + for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { + auto const &resource_name = data.resources_total_label[i]; + auto const &capacity = data.resources_total_capacity[i]; + + // If resource exists in the ClientTableData, update it, else create it + auto existing_resource_label = + std::find(cache_data.resources_total_label.begin(), + cache_data.resources_total_label.end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label.end()) { + auto index = std::distance(cache_data.resources_total_label.begin(), + existing_resource_label); + // Resource already exists, set capacity if updation call.. + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_capacity[index] = capacity; + } + // .. delete if deletion call. 
+ else if (data.entry_type == EntryType::RES_DELETE) { + cache_data.resources_total_label.erase( + cache_data.resources_total_label.begin() + index); + cache_data.resources_total_capacity.erase( + cache_data.resources_total_capacity.begin() + index); + } + } else { + // Resource does not exist, create resource and add capacity if it was a resource + // create call. + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_label.push_back(resource_name); + cache_data.resources_total_capacity.push_back(capacity); + } + } + } + } // If the notification is new, call any registered callbacks. - if (is_new) { - if (data.is_insertion) { + ClientTableDataT &cache_data = client_cache_[client_id]; + if (is_notif_new) { + if (data.entry_type == EntryType::INSERTION) { if (client_added_callback_ != nullptr) { - client_added_callback_(client, client_id, data); + client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else { + } else if (data.entry_type == EntryType::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. removed_clients_.insert(client_id); if (client_removed_callback_ != nullptr) { - client_removed_callback_(client, client_id, data); + client_removed_callback_(client, client_id, cache_data); + } + } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + if (resource_createupdated_callback_ != nullptr) { + resource_createupdated_callback_(client, client_id, cache_data); + } + } else if (data.entry_type == EntryType::RES_DELETE) { + if (resource_deleted_callback_ != nullptr) { + resource_deleted_callback_(client, client_id, cache_data); } } } @@ -449,7 +530,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Construct the data to add to the client table. 
auto data = std::make_shared(local_client_); - data->is_insertion = true; + data->entry_type = EntryType::INSERTION; // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, @@ -467,7 +548,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { for (auto &notification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.is_insertion) { + if (notification.entry_type != EntryType::DELETION) { connected_nodes.emplace(notification.client_id, notification); } else { auto iter = connected_nodes.find(notification.client_id); @@ -498,7 +579,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto data = std::make_shared(local_client_); - data->is_insertion = false; + data->entry_type = EntryType::DELETION; auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { HandleConnected(client, data); @@ -516,7 +597,7 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { auto data = std::make_shared(); data->client_id = dead_client_id.binary(); - data->is_insertion = false; + data->entry_type = EntryType::DELETION; return Append(DriverID::nil(), client_log_key_, data, nullptr); } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index b22910832..056bf7b97 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -677,7 +677,7 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. 
-class ClientTable : private Log { +class ClientTable : public Log { public: using ClientTableCallback = std::function; @@ -729,6 +729,16 @@ class ClientTable : private Log { /// \param callback The callback to register. void RegisterClientRemovedCallback(const ClientTableCallback &callback); + /// Register a callback to call when a resource is created or updated. + /// + /// \param callback The callback to register. + void RegisterResourceCreateUpdatedCallback(const ClientTableCallback &callback); + + /// Register a callback to call when a resource is deleted. + /// + /// \param callback The callback to register. + void RegisterResourceDeletedCallback(const ClientTableCallback &callback); + /// Get a client's information from the cache. The cache only contains /// information for clients that we've heard a notification for. /// @@ -772,16 +782,16 @@ class ClientTable : private Log { /// \return string. std::string DebugString() const; + /// The key at which the log of client information is stored. This key must + /// be kept the same across all instances of the ClientTable, so that all + /// clients append and read from the same key. + ClientID client_log_key_; + private: /// Handle a client table notification. void HandleNotification(AsyncGcsClient *client, const ClientTableDataT &notifications); /// Handle this client's successful connection to the GCS. void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); - - /// The key at which the log of client information is stored. This key must - /// be kept the same across all instances of the ClientTable, so that all - /// clients append and read from the same key. - ClientID client_log_key_; /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. @@ -792,6 +802,10 @@ class ClientTable : private Log { ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. 
ClientTableCallback client_removed_callback_; + /// The callback to call when a resource is created or updated. + ClientTableCallback resource_createupdated_callback_; + /// The callback to call when a resource is deleted. + ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. std::unordered_map client_cache_; /// The set of removed clients. diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 99ed0851c..85157abcd 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -108,7 +108,7 @@ void ObjectDirectory::LookupRemoteConnectionInfo( ClientID result_client_id = ClientID::from_binary(client_data.client_id); if (!result_client_id.is_nil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.is_insertion) { + if (client_data.entry_type == EntryType::INSERTION) { connection_info.ip = client_data.node_manager_address; connection_info.port = static_cast(client_data.object_manager_port); } diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index f673e2251..a5b041f29 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -77,6 +77,8 @@ enum MessageType:int { NotifyActorResumedFromCheckpoint, // A node manager requests to connect to another node manager. ConnectClient, + // Set dynamic custom resource + SetResourceRequest, } table TaskExecutionSpecification { @@ -234,3 +236,12 @@ table ConnectClient { // ID of the connecting client. 
client_id: string; } + +table SetResourceRequest{ + // Name of the resource to be set + resource_name: string; + // Capacity of the resource to be set + capacity: double; + // Client ID where this resource will be set + client_id: string; +} diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 51e035b1b..1e20fe3f4 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -52,7 +52,8 @@ void Monitor::Tick() { const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.binary() == data.client_id && !data.is_insertion) { + if (client_id.binary() == data.client_id && + data.entry_type == EntryType::DELETION) { // The node has been marked dead by itself. marked = true; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index f901c6800..fa4ad4868 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -185,6 +185,22 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); + // Register a callback on the client table for resource create/update requests + auto node_manager_resource_createupdated = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ResourceCreateUpdated(data); + }; + gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( + node_manager_resource_createupdated); + + // Register a callback on the client table for resource delete requests + auto node_manager_resource_deleted = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ResourceDeleted(data); + }; + gcs_client_->client_table().RegisterResourceDeletedCallback( + node_manager_resource_deleted); + // Subscribe to heartbeat batches from the monitor. 
const auto &heartbeat_batch_added = [this]( gcs::AsyncGcsClient *client, const ClientID &id, @@ -461,6 +477,92 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { object_directory_->HandleClientRemoved(client_id); } +void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " + << client_id << ". Updating resource map."; + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); + + const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); + ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] The difference in the resource map is " + << difference_set.ToString(); + + SchedulingResources &cluster_schedres = cluster_resource_map_[client_id]; + + // Update local_available_resources_ and SchedulingResources + for (const auto &resource_pair : difference_set.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + const double &new_resource_capacity = resource_pair.second; + + cluster_schedres.UpdateResource(resource_label, new_resource_capacity); + if (client_id == local_client_id) { + local_available_resources_.AddOrUpdateResource(resource_label, + new_resource_capacity); + } + } + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map."; + + if (client_id == local_client_id) { + // The resource update is on the local node, check if we can reschedule tasks. 
+ TryLocalInfeasibleTaskScheduling(); + } + return; +} + +void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); + RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id + << " with new resources: " << new_res_set.ToString() + << ". Updating resource map."; + + const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); + ResourceSet deleted_set = old_res_set.FindDeletedResources(new_res_set); + RAY_LOG(DEBUG) << "[ResourceDeleted] The difference in the resource map is " + << deleted_set.ToString(); + + SchedulingResources &cluster_schedres = cluster_resource_map_[client_id]; + + // Update local_available_resources_ and SchedulingResources + for (const auto &resource_pair : deleted_set.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + + cluster_schedres.DeleteResource(resource_label); + if (client_id == local_client_id) { + local_available_resources_.DeleteResource(resource_label); + } + } + RAY_LOG(DEBUG) << "[ResourceDeleted] Updated cluster_resource_map."; + return; +} + +void NodeManager::TryLocalInfeasibleTaskScheduling() { + RAY_LOG(DEBUG) << "[LocalResourceUpdateRescheduler] The resource update is on the " + "local node, check if we can reschedule tasks"; + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + SchedulingResources &new_local_resources = cluster_resource_map_[local_client_id]; + + // SpillOver locally to figure out which infeasible tasks can be placed now + std::vector decision = scheduling_policy_.SpillOver(new_local_resources); + + std::unordered_set local_task_ids(decision.begin(), decision.end()); + + // Transition locally placed tasks to waiting 
or ready for dispatch. + if (local_task_ids.size() > 0) { + std::vector tasks = local_queues_.RemoveTasks(local_task_ids); + for (const auto &t : tasks) { + EnqueuePlaceableTask(t); + } + } +} + void NodeManager::HeartbeatAdded(const ClientID &client_id, const HeartbeatTableDataT &heartbeat_data) { // Locate the client id in remote client table and update available resources based on @@ -718,6 +820,9 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::SubmitTask: { ProcessSubmitTaskMessage(message_data); } break; + case protocol::MessageType::SetResourceRequest: { + ProcessSetResourceRequest(client, message_data); + } break; case protocol::MessageType::FetchOrReconstruct: { ProcessFetchOrReconstructMessage(client, message_data); } break; @@ -931,12 +1036,14 @@ void NodeManager::ProcessDisconnectClientMessage( // Return the resources that were being used by this worker. auto const &task_resources = worker->GetTaskResourceIds(); - local_available_resources_.Release(task_resources); + local_available_resources_.ReleaseConstrained( + task_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[client_id].Release(task_resources.ToResourceSet()); worker->ResetTaskResourceIds(); auto const &lifetime_resources = worker->GetLifetimeResourceIds(); - local_available_resources_.Release(lifetime_resources); + local_available_resources_.ReleaseConstrained( + lifetime_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); worker->ResetLifetimeResourceIds(); @@ -1170,6 +1277,59 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl node_manager_client.ProcessMessages(); } +void NodeManager::ProcessSetResourceRequest( + const std::shared_ptr &client, const uint8_t *message_data) { + // Read the SetResource message + auto message = flatbuffers::GetRoot(message_data); + + auto const &resource_name = 
string_from_flatbuf(*message->resource_name()); + double const &capacity = message->capacity(); + bool is_deletion = capacity <= 0; + + ClientID client_id = from_flatbuf(*message->client_id()); + + // If the python arg was null, set client_id to the local client + if (client_id.is_nil()) { + client_id = gcs_client_->client_table().GetLocalClientId(); + } + + if (is_deletion && + cluster_resource_map_[client_id].GetTotalResources().GetResourceMap().count( + resource_name) == 0) { + // Resource does not exist in the cluster resource map, thus nothing to delete. + // Return.. + RAY_LOG(INFO) << "[ProcessDeleteResourceRequest] Trying to delete resource " + << resource_name << ", but it does not exist. Doing nothing.."; + return; + } + + // Add the new resource to a skeleton ClientTableDataT object + ClientTableDataT data; + gcs_client_->client_table().GetClient(client_id, data); + // Replace the resource vectors with the resource deltas from the message. + // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in + // the resources + data.resources_total_label = std::vector{resource_name}; + data.resources_total_capacity = std::vector{capacity}; + // Set the correct flag for entry_type + if (is_deletion) { + data.entry_type = EntryType::RES_DELETE; + } else { + data.entry_type = EntryType::RES_CREATEUPDATE; + } + + // Submit to the client table. This calls the ResourceCreateUpdated callback, which + // updates cluster_resource_map_. 
+ std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (not worker) { + worker = worker_pool_.GetRegisteredDriver(client); + } + auto data_shared_ptr = std::make_shared(data); + auto &client_table = gcs_client_->client_table(); + RAY_CHECK_OK(gcs_client_->client_table().Append( + DriverID::nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); +} + void NodeManager::ScheduleTasks( std::unordered_map &resource_map) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -1761,7 +1921,9 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // Release task's resources. The worker's lifetime resources are still held. auto const &task_resources = worker.GetTaskResourceIds(); - local_available_resources_.Release(task_resources); + const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); + local_available_resources_.ReleaseConstrained( + task_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( task_resources.ToResourceSet()); worker.ResetTaskResourceIds(); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index edd456dbe..8c5973c73 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -120,6 +120,21 @@ class NodeManager { /// \return Void. void ClientRemoved(const ClientTableDataT &client_data); + /// Handler for the addition or update of a resource in the GCS + /// \param client_data Data associated with the client whose resources changed. + /// \return Void. + void ResourceCreateUpdated(const ClientTableDataT &client_data); + + /// Handler for the deletion of a resource in the GCS + /// \param client_data Data associated with the client whose resources changed. + /// \return Void. + void ResourceDeleted(const ClientTableDataT &client_data); + + /// Evaluates the local infeasible queue to check if any tasks can be scheduled. 
+ /// This is called whenever there's an update to the resources on the local client. + /// \return Void. + void TryLocalInfeasibleTaskScheduling(); + /// Send heartbeats to the GCS. void Heartbeat(); @@ -413,6 +428,13 @@ class NodeManager { /// \param task The task that just finished. void UpdateActorFrontier(const Task &task); + /// Process client message of SetResourceRequest + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessSetResourceRequest(const std::shared_ptr &client, + const uint8_t *message_data); + /// Handle the case where an actor is disconnected, determine whether this /// actor needs to be reconstructed and then update actor table. /// This function needs to be called either when actor process dies or when diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 09e9b5fed..0f488089e 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -386,3 +386,13 @@ ray::Status RayletClient::NotifyActorResumedFromCheckpoint( return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); } + +ray::Status RayletClient::SetResource(const std::string &resource_name, + const double capacity, + const ray::ClientID &client_Id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateSetResourceRequest( + fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id)); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index ff66ff462..0bdd076b5 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -165,6 +165,14 @@ class RayletClient { ray::Status NotifyActorResumedFromCheckpoint(const ActorID &actor_id, const ActorCheckpointID &checkpoint_id); + /// Sets a resource with the specified capacity and client id + /// 
\param resource_name Name of the resource to be set + /// \param capacity Capacity of the resource + /// \param client_Id ClientID where the resource is to be set + /// \return ray::Status + ray::Status SetResource(const std::string &resource_name, const double capacity, + const ray::ClientID &client_Id); + Language GetLanguage() const { return language_; } ClientID GetClientID() const { return client_id_; }