[GCS] global state query node info table from GCS. (#8498)

This commit is contained in:
Lingxuan Zuo
2020-05-28 16:39:13 +08:00
committed by GitHub
parent 675ccbc799
commit e594524ed3
27 changed files with 162 additions and 101 deletions
@@ -4,6 +4,7 @@ import time
import ray
from ray.cluster_utils import Cluster
from ray.test_utils import get_non_head_nodes
num_redis_shards = 5
redis_max_memory = 10**8
@@ -51,7 +52,7 @@ while True:
for _ in range(100):
previous_ids = [f.remote(previous_id) for previous_id in previous_ids]
node_to_kill = cluster.list_all_nodes()[1]
node_to_kill = get_non_head_nodes(cluster)[0]
# Remove the first non-head node.
cluster.remove_node(node_to_kill)
cluster.add_node()
@@ -108,8 +108,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
manager.cleanup();
manager = null;
}
RayConfig.reset();
}
if (null != gcsClient) {
gcsClient.destroy();
gcsClient = null;
}
RayConfig.reset();
LOGGER.info("RayNativeRuntime shutdown");
}
@@ -32,6 +32,7 @@ public class GcsClient {
private RedisClient primary;
private List<RedisClient> shards;
private GlobalStateAccessor globalStateAccessor;
public GcsClient(String redisAddress, String redisPassword) {
primary = new RedisClient(redisAddress, redisPassword);
@@ -49,16 +50,11 @@ public class GcsClient {
shards = shardAddresses.stream().map((byte[] address) -> {
return new RedisClient(new String(address), redisPassword);
}).collect(Collectors.toList());
globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword);
}
public List<NodeInfo> getAllNodeInfo() {
final String prefix = TablePrefix.CLIENT.toString();
final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes());
List<byte[]> results = primary.lrange(key, 0, -1);
if (results == null) {
return new ArrayList<>();
}
List<byte[]> results = globalStateAccessor.getAllNodeInfo();
// This map is used for deduplication of node entries.
Map<UniqueId, NodeInfo> nodes = new HashMap<>();
@@ -191,6 +187,15 @@ public class GcsClient {
return JobId.fromInt(jobCounter);
}
/**
* Destroy global state accessor when ray native runtime will be shutdown.
*/
public void destroy() {
// Only ray shutdown should call gcs client destroy.
LOGGER.debug("Destroying global state accessor.");
GlobalStateAccessor.destroyInstance();
}
private RedisClient getShardClient(BaseId key) {
return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key),
shards.size()));
@@ -24,6 +24,7 @@ public class GlobalStateAccessor {
public static synchronized void destroyInstance() {
if (null != globalStateAccessor) {
globalStateAccessor.destroyGlobalStateAccessor();
globalStateAccessor = null;
}
}
@@ -45,7 +46,8 @@ public class GlobalStateAccessor {
public List<byte[]> getAllJobInfo() {
// Fetch a job list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0);
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Get all job info when global state accessor have been destroyed.");
return this.nativeGetAllJobInfo(globalStateAccessorNativePointer);
}
}
@@ -56,7 +58,8 @@ public class GlobalStateAccessor {
public List<byte[]> getAllNodeInfo() {
// Fetch a node list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0);
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Get all node info when global state accessor have been destroyed.");
return this.nativeGetAllNodeInfo(globalStateAccessorNativePointer);
}
}
+7 -12
View File
@@ -2,8 +2,6 @@ import json
import logging
import time
import redis
import ray
from ray import ray_constants
@@ -33,6 +31,8 @@ class Cluster:
self.worker_nodes = set()
self.redis_address = None
self.connected = False
# Create a new global state accessor for fetching GCS table.
self.global_state = ray.state.GlobalState()
self._shutdown_at_exit = shutdown_at_exit
if not initialize_head and connect:
raise RuntimeError("Cannot connect to uninitialized cluster.")
@@ -96,6 +96,9 @@ class Cluster:
self.redis_password = node_args.get(
"redis_password", ray_constants.REDIS_DEFAULT_PASSWORD)
self.webui_url = self.head_node.webui_url
# Init global state accessor when creating head node.
self.global_state._initialize_global_state(self.redis_address,
self.redis_password)
else:
ray_params.update_if_absent(redis_address=self.redis_address)
# We only need one log monitor per physical node.
@@ -150,13 +153,9 @@ class Cluster:
TimeoutError: An exception is raised if the timeout expires before
the node appears in the client table.
"""
ip_address, port = self.redis_address.split(":")
redis_client = redis.StrictRedis(
host=ip_address, port=int(port), password=self.redis_password)
start_time = time.time()
while time.time() - start_time < timeout:
clients = ray.state._parse_client_table(redis_client)
clients = self.global_state.node_table()
object_store_socket_names = [
client["ObjectStoreSocketName"] for client in clients
]
@@ -183,13 +182,9 @@ class Cluster:
TimeoutError: An exception is raised if we time out while waiting
for nodes to join.
"""
ip_address, port = self.address.split(":")
redis_client = redis.StrictRedis(
host=ip_address, port=int(port), password=self.redis_password)
start_time = time.time()
while time.time() - start_time < timeout:
clients = ray.state._parse_client_table(redis_client)
clients = self.global_state.node_table()
live_clients = [client for client in clients if client["Alive"]]
expected = len(self.list_all_nodes())
@@ -14,6 +14,7 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
c_bool Connect()
void Disconnect()
c_vector[c_string] GetAllJobInfo()
c_vector[c_string] GetAllNodeInfo()
c_vector[c_string] GetAllProfileInfo()
c_vector[c_string] GetAllObjectInfo()
unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id)
@@ -29,6 +29,9 @@ cdef class GlobalStateAccessor:
def get_job_table(self):
return self.inner.get().GetAllJobInfo()
def get_node_table(self):
return self.inner.get().GetAllNodeInfo()
def get_profile_table(self):
return self.inner.get().GetAllProfileInfo()
+4 -5
View File
@@ -171,11 +171,10 @@ def get_address_info_from_redis_helper(redis_address,
node_ip_address,
redis_password=None):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = create_redis_client(redis_address, password=redis_password)
client_table = ray.state._parse_client_table(redis_client)
# Get node table from global state accessor.
global_state = ray.state.GlobalState()
global_state._initialize_global_state(redis_address, redis_password)
client_table = global_state.node_table()
if len(client_table) == 0:
raise RuntimeError(
"Redis has started but no raylets have registered yet.")
+29 -13
View File
@@ -369,19 +369,35 @@ class GlobalState:
ray.ActorID(actor_id_binary))
return results
def client_table(self):
"""Fetch and parse the Redis DB client table.
def node_table(self):
"""Fetch and parse the Gcs node info table.
Returns:
Information about the Ray clients in the cluster.
Information about the node in the cluster.
"""
self._check_connected()
client_table = _parse_client_table(self.redis_client)
for client in client_table:
# These are equivalent and is better for application developers.
client["alive"] = client["Alive"]
return client_table
node_table = self.global_state_accessor.get_node_table()
results = []
for node_info_item in node_table:
item = gcs_utils.GcsNodeInfo.FromString(node_info_item)
node_info = {
"NodeID": ray.utils.binary_to_hex(item.node_id),
"Alive": item.state ==
gcs_utils.GcsNodeInfo.GcsNodeState.Value("ALIVE"),
"NodeManagerAddress": item.node_manager_address,
"NodeManagerHostname": item.node_manager_hostname,
"NodeManagerPort": item.node_manager_port,
"ObjectManagerPort": item.object_manager_port,
"ObjectStoreSocketName": item.object_store_socket_name,
"RayletSocketName": item.raylet_socket_name
}
node_info["alive"] = node_info["Alive"]
node_info["Resources"] = _parse_resource_table(
self.redis_client,
node_info["NodeID"]) if node_info["Alive"] else {}
results.append(node_info)
return results
def job_table(self):
"""Fetch and parse the Redis job table.
@@ -597,7 +613,7 @@ class GlobalState:
self._check_connected()
node_id_to_address = {}
for node_info in self.client_table():
for node_info in self.node_table():
node_id_to_address[node_info["NodeID"]] = "{}:{}".format(
node_info["NodeManagerAddress"],
node_info["ObjectManagerPort"])
@@ -728,7 +744,7 @@ class GlobalState:
self._check_connected()
resources = defaultdict(int)
clients = self.client_table()
clients = self.node_table()
for client in clients:
# Only count resources from latest entries of live clients.
if client["Alive"]:
@@ -740,7 +756,7 @@ class GlobalState:
"""Returns a set of client IDs corresponding to clients still alive."""
return {
client["NodeID"]
for client in self.client_table() if (client["Alive"])
for client in self.node_table() if (client["Alive"])
}
def available_resources(self):
@@ -929,7 +945,7 @@ def nodes():
Returns:
Information about the Ray clients in the cluster.
"""
return state.client_table()
return state.node_table()
def current_node_id():
+14
View File
@@ -288,3 +288,17 @@ def wait_until_server_available(address,
s.close()
return True
return False
def get_other_nodes(cluster, exclude_head=False):
"""Get all nodes except the one that we're connected to."""
return [
node for node in cluster.list_all_nodes() if
node._raylet_socket_name != ray.worker._global_node._raylet_socket_name
and (exclude_head is False or node.head is False)
]
def get_non_head_nodes(cluster):
"""Get all non-head nodes."""
return list(filter(lambda x: x.head is False, cluster.list_all_nodes()))
+4 -2
View File
@@ -105,8 +105,10 @@ def _ray_start_cluster(**kwargs):
remote_nodes = []
for _ in range(num_nodes):
remote_nodes.append(cluster.add_node(**init_kwargs))
if do_init:
ray.init(address=cluster.address)
# We assume driver will connect to the head (first node),
# so ray init will be invoked if do_init is true
if len(remote_nodes) == 1 and do_init:
ray.init(address=cluster.address)
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
+4 -4
View File
@@ -11,7 +11,7 @@ import time
import ray
import ray.test_utils
import ray.cluster_utils
from ray.test_utils import run_string_as_driver
from ray.test_utils import run_string_as_driver, get_non_head_nodes
from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put
@@ -335,7 +335,7 @@ def test_distributed_handle(ray_start_cluster_2_nodes):
# Kill the second plasma store to get rid of the cached objects and
# trigger the corresponding raylet to exit.
cluster.list_all_nodes()[1].kill_plasma_store(wait=True)
get_non_head_nodes(cluster)[0].kill_plasma_store(wait=True)
# Check that the actor did not restore from a checkpoint.
assert not ray.get(counter.test_restore.remote())
@@ -374,7 +374,7 @@ def test_remote_checkpoint_distributed_handle(ray_start_cluster_2_nodes):
# Kill the second plasma store to get rid of the cached objects and
# trigger the corresponding raylet to exit.
cluster.list_all_nodes()[1].kill_plasma_store(wait=True)
get_non_head_nodes(cluster)[0].kill_plasma_store(wait=True)
# Check that the actor restored from a checkpoint.
assert ray.get(counter.test_restore.remote())
@@ -414,7 +414,7 @@ def test_checkpoint_distributed_handle(ray_start_cluster_2_nodes):
# Kill the second plasma store to get rid of the cached objects and
# trigger the corresponding raylet to exit.
cluster.list_all_nodes()[1].kill_plasma_store(wait=True)
get_non_head_nodes(cluster)[0].kill_plasma_store(wait=True)
# Check that the actor restored from a checkpoint.
assert ray.get(counter.test_restore.remote())
+13 -7
View File
@@ -11,9 +11,15 @@ import ray
import ray.ray_constants as ray_constants
import ray.test_utils
import ray.cluster_utils
from ray.test_utils import (relevant_errors, wait_for_condition,
wait_for_errors, wait_for_pid_to_exit,
generate_internal_config_map)
from ray.test_utils import (
relevant_errors,
wait_for_condition,
wait_for_errors,
wait_for_pid_to_exit,
generate_internal_config_map,
get_non_head_nodes,
get_other_nodes,
)
SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM
@@ -305,7 +311,7 @@ def test_actor_restart_on_node_failure(ray_start_cluster):
ray.get(actor.ready.remote())
results = [actor.increase.remote() for _ in range(100)]
# Kill actor node, while the above task is still being executed.
cluster.remove_node(cluster.list_all_nodes()[-1])
cluster.remove_node(get_non_head_nodes(cluster)[-1])
cluster.add_node(num_cpus=1, _internal_config=config)
cluster.wait_for_nodes()
# Check that none of the tasks failed and the actor is restarted.
@@ -821,7 +827,7 @@ def test_decorated_method(ray_start_regular):
@pytest.mark.parametrize(
"ray_start_cluster", [{
"num_cpus": 1,
"num_nodes": 2,
"num_nodes": 3,
}], indirect=True)
def test_ray_wait_dead_actor(ray_start_cluster):
"""Tests that methods completed by dead actors are returned as ready"""
@@ -857,8 +863,8 @@ def test_ray_wait_dead_actor(ray_start_cluster):
except ray.exceptions.RayActorError:
return True
# Kill a node.
cluster.remove_node(cluster.list_all_nodes()[-1])
# Kill a node that must not be driver node or head node.
cluster.remove_node(get_other_nodes(cluster, exclude_head=True)[-1])
# Repeatedly submit tasks and call ray.wait until the exception for the
# dead actor is received.
assert wait_for_condition(actor_dead)
+6 -6
View File
@@ -566,8 +566,10 @@ def test_lifetime_and_transient_resources(ray_start_regular):
def test_custom_label_placement(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=2, resources={"CustomResource1": 2})
cluster.add_node(num_cpus=2, resources={"CustomResource2": 2})
custom_resource1_node = cluster.add_node(
num_cpus=2, resources={"CustomResource1": 2})
custom_resource2_node = cluster.add_node(
num_cpus=2, resources={"CustomResource2": 2})
ray.init(address=cluster.address)
@ray.remote(resources={"CustomResource1": 1})
@@ -580,17 +582,15 @@ def test_custom_label_placement(ray_start_cluster):
def get_location(self):
return ray.worker.global_worker.node.unique_id
node_id = ray.worker.global_worker.node.unique_id
# Create some actors.
actors1 = [ResourceActor1.remote() for _ in range(2)]
actors2 = [ResourceActor2.remote() for _ in range(2)]
locations1 = ray.get([a.get_location.remote() for a in actors1])
locations2 = ray.get([a.get_location.remote() for a in actors2])
for location in locations1:
assert location == node_id
assert location == custom_resource1_node.unique_id
for location in locations2:
assert location != node_id
assert location == custom_resource2_node.unique_id
def test_creating_more_actors_than_resources(shutdown_only):
+7 -12
View File
@@ -251,11 +251,9 @@ def test_zero_cpus(shutdown_only):
def test_zero_cpus_actor(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
cluster.add_node(num_cpus=2)
valid_node = cluster.add_node(num_cpus=2)
ray.init(address=cluster.address)
node_id = ray.worker.global_worker.node.unique_id
@ray.remote
class Foo:
def method(self):
@@ -263,7 +261,7 @@ def test_zero_cpus_actor(ray_start_cluster):
# Make sure tasks and actors run on the remote raylet.
a = Foo.remote()
assert ray.get(a.method.remote()) != node_id
assert valid_node.unique_id == ray.get(a.method.remote())
def test_fractional_resources(shutdown_only):
@@ -446,7 +444,8 @@ def test_multiple_raylets(ray_start_cluster):
def test_custom_resources(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=3, resources={"CustomResource": 0})
cluster.add_node(num_cpus=3, resources={"CustomResource": 1})
custom_resource_node = cluster.add_node(
num_cpus=3, resources={"CustomResource": 1})
ray.init(address=cluster.address)
@ray.remote
@@ -467,12 +466,10 @@ def test_custom_resources(ray_start_cluster):
# The f tasks should be scheduled on both raylets.
assert len(set(ray.get([f.remote() for _ in range(500)]))) == 2
node_id = ray.worker.global_worker.node.unique_id
# The g tasks should be scheduled only on the second raylet.
raylet_ids = set(ray.get([g.remote() for _ in range(50)]))
assert len(raylet_ids) == 1
assert list(raylet_ids)[0] != node_id
assert list(raylet_ids)[0] == custom_resource_node.unique_id
# Make sure that resource bookkeeping works when a task that uses a
# custom resources gets blocked.
@@ -506,7 +503,7 @@ def test_two_custom_resources(ray_start_cluster):
"CustomResource1": 1,
"CustomResource2": 2
})
cluster.add_node(
custom_resource_node = cluster.add_node(
num_cpus=3, resources={
"CustomResource1": 3,
"CustomResource2": 4
@@ -542,12 +539,10 @@ def test_two_custom_resources(ray_start_cluster):
assert len(set(ray.get([f.remote() for _ in range(500)]))) == 2
assert len(set(ray.get([g.remote() for _ in range(500)]))) == 2
node_id = ray.worker.global_worker.node.unique_id
# The h tasks should be scheduled only on the second raylet.
raylet_ids = set(ray.get([h.remote() for _ in range(50)]))
assert len(raylet_ids) == 1
assert list(raylet_ids)[0] != node_id
assert list(raylet_ids)[0] == custom_resource_node.unique_id
# Make sure that tasks with unsatisfied custom resource requirements do
# not get scheduled.
@@ -9,7 +9,7 @@ import pytest
import ray
import ray.ray_constants as ray_constants
from ray.cluster_utils import Cluster
from ray.test_utils import RayTestTimeoutException
from ray.test_utils import RayTestTimeoutException, get_other_nodes
SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM
@@ -90,7 +90,7 @@ def _test_component_failed(cluster, component_type):
# execute. Do this in a loop while submitting tasks between each
# component failure.
time.sleep(0.1)
worker_nodes = cluster.list_all_nodes()[1:]
worker_nodes = get_other_nodes(cluster)
assert len(worker_nodes) > 0
for node in worker_nodes:
process = node.all_processes[component_type][0].process
@@ -119,7 +119,7 @@ def _test_component_failed(cluster, component_type):
def check_components_alive(cluster, component_type, check_component_alive):
"""Check that a given component type is alive on all worker nodes."""
worker_nodes = cluster.list_all_nodes()[1:]
worker_nodes = get_other_nodes(cluster)
assert len(worker_nodes) > 0
for node in worker_nodes:
process = node.all_processes[component_type][0].process
@@ -7,6 +7,7 @@ import pytest
import ray
import ray.ray_constants as ray_constants
from ray.test_utils import get_other_nodes
@pytest.mark.parametrize(
@@ -71,7 +72,7 @@ def test_actor_creation_node_failure(ray_start_cluster):
# Remove a node. Any actor creation tasks that were forwarded to this
# node must be restarted.
cluster.remove_node(cluster.list_all_nodes()[-1])
cluster.remove_node(get_other_nodes(cluster, True)[-1])
@pytest.mark.skipif(
+1 -1
View File
@@ -1062,12 +1062,12 @@ def test_fate_sharing(ray_start_cluster, use_actors, node_failure):
cluster = Cluster()
# Head node with no resources.
cluster.add_node(num_cpus=0, _internal_config=config)
ray.init(address=cluster.address)
# Node to place the parent actor.
node_to_kill = cluster.add_node(num_cpus=1, resources={"parent": 1})
# Node to place the child actor.
cluster.add_node(num_cpus=1, resources={"child": 1})
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote
def sleep():
+3 -3
View File
@@ -9,7 +9,7 @@ import pytest
import ray
import ray.ray_constants as ray_constants
from ray.cluster_utils import Cluster
from ray.test_utils import RayTestTimeoutException
from ray.test_utils import RayTestTimeoutException, get_other_nodes
SIGKILL = signal.SIGKILL if sys.platform != "win32" else signal.SIGTERM
@@ -96,7 +96,7 @@ def _test_component_failed(cluster, component_type):
# execute. Do this in a loop while submitting tasks between each
# component failure.
time.sleep(0.1)
worker_nodes = cluster.list_all_nodes()[1:]
worker_nodes = get_other_nodes(cluster)
assert len(worker_nodes) > 0
for node in worker_nodes:
process = node.all_processes[component_type][0].process
@@ -125,7 +125,7 @@ def _test_component_failed(cluster, component_type):
def check_components_alive(cluster, component_type, check_component_alive):
"""Check that a given component type is alive on all worker nodes."""
worker_nodes = cluster.list_all_nodes()[1:]
worker_nodes = get_other_nodes(cluster)
assert len(worker_nodes) > 0
for node in worker_nodes:
process = node.all_processes[component_type][0].process
@@ -7,6 +7,7 @@ import numpy as np
import pytest
import ray
from ray.test_utils import get_other_nodes
import ray.ray_constants as ray_constants
@@ -45,7 +46,7 @@ def test_object_reconstruction(ray_start_cluster):
# execute. Do this in a loop while submitting tasks between each
# component failure.
time.sleep(0.1)
worker_nodes = cluster.list_all_nodes()[1:]
worker_nodes = get_other_nodes(cluster)
assert len(worker_nodes) > 0
component_type = ray_constants.PROCESS_TYPE_RAYLET
for node in worker_nodes:
@@ -121,7 +122,7 @@ def test_actor_creation_node_failure(ray_start_cluster):
children[i] = Child.remote(death_probability)
# Remove a node. Any actor creation tasks that were forwarded to this
# node must be resubmitted.
cluster.remove_node(cluster.list_all_nodes()[-1])
cluster.remove_node(get_other_nodes(cluster, True)[-1])
@pytest.mark.skipif(
+6 -6
View File
@@ -18,13 +18,13 @@ def test_cached_object(ray_start_cluster):
cluster = Cluster()
# Head node with no resources.
cluster.add_node(num_cpus=0, _internal_config=config)
ray.init(address=cluster.address)
# Node to place the initial object.
node_to_kill = cluster.add_node(
num_cpus=1, resources={"node1": 1}, object_store_memory=10**8)
cluster.add_node(
num_cpus=1, resources={"node2": 1}, object_store_memory=10**8)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote
def large_object():
@@ -61,6 +61,7 @@ def test_reconstruction_cached_dependency(ray_start_cluster,
cluster = Cluster()
# Head node with no resources.
cluster.add_node(num_cpus=0, _internal_config=config)
ray.init(address=cluster.address)
# Node to place the initial object.
node_to_kill = cluster.add_node(
num_cpus=1,
@@ -73,7 +74,6 @@ def test_reconstruction_cached_dependency(ray_start_cluster,
object_store_memory=10**8,
_internal_config=config)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote(max_retries=0)
def large_object():
@@ -123,6 +123,7 @@ def test_basic_reconstruction(ray_start_cluster, reconstruction_enabled):
cluster = Cluster()
# Head node with no resources.
cluster.add_node(num_cpus=0, _internal_config=config)
ray.init(address=cluster.address)
# Node to place the initial object.
node_to_kill = cluster.add_node(
num_cpus=1,
@@ -135,7 +136,6 @@ def test_basic_reconstruction(ray_start_cluster, reconstruction_enabled):
object_store_memory=10**8,
_internal_config=config)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote(max_retries=1 if reconstruction_enabled else 0)
def large_object():
@@ -175,6 +175,7 @@ def test_basic_reconstruction_put(ray_start_cluster, reconstruction_enabled):
cluster = Cluster()
# Head node with no resources.
cluster.add_node(num_cpus=0, _internal_config=config)
ray.init(address=cluster.address)
# Node to place the initial object.
node_to_kill = cluster.add_node(
num_cpus=1,
@@ -187,7 +188,6 @@ def test_basic_reconstruction_put(ray_start_cluster, reconstruction_enabled):
object_store_memory=10**8,
_internal_config=config)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote(max_retries=1 if reconstruction_enabled else 0)
def large_object():
@@ -230,6 +230,7 @@ def test_multiple_downstream_tasks(ray_start_cluster, reconstruction_enabled):
cluster = Cluster()
# Head node with no resources.
cluster.add_node(num_cpus=0, _internal_config=config)
ray.init(address=cluster.address)
# Node to place the initial object.
node_to_kill = cluster.add_node(
num_cpus=1,
@@ -242,7 +243,6 @@ def test_multiple_downstream_tasks(ray_start_cluster, reconstruction_enabled):
object_store_memory=10**8,
_internal_config=config)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote(max_retries=1 if reconstruction_enabled else 0)
def large_object():
@@ -294,10 +294,10 @@ def test_reconstruction_chain(ray_start_cluster, reconstruction_enabled):
# Head node with no resources.
cluster.add_node(
num_cpus=0, _internal_config=config, object_store_memory=10**8)
ray.init(address=cluster.address)
node_to_kill = cluster.add_node(
num_cpus=1, object_store_memory=10**8, _internal_config=config)
cluster.wait_for_nodes()
ray.init(address=cluster.address)
@ray.remote(max_retries=1 if reconstruction_enabled else 0)
def large_object():
+8 -4
View File
@@ -108,7 +108,8 @@ void GcsNodeManager::NodeFailureDetector::ScheduleTick() {
GcsNodeManager::GcsNodeManager(boost::asio::io_service &io_service,
gcs::NodeInfoAccessor &node_info_accessor,
gcs::ErrorInfoAccessor &error_info_accessor,
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub)
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub,
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage)
: node_info_accessor_(node_info_accessor),
error_info_accessor_(error_info_accessor),
node_failure_detector_(new NodeFailureDetector(
@@ -125,7 +126,8 @@ GcsNodeManager::GcsNodeManager(boost::asio::io_service &io_service,
// TODO(Shanly): Remove node resources from resource table.
}
})),
gcs_pub_sub_(gcs_pub_sub) {
gcs_pub_sub_(gcs_pub_sub),
gcs_table_storage_(gcs_table_storage) {
// TODO(Shanly): Load node info list from storage synchronously.
// TODO(Shanly): Load cluster resources from storage synchronously.
}
@@ -144,7 +146,8 @@ void GcsNodeManager::HandleRegisterNode(const rpc::RegisterNodeRequest &request,
request.node_info().SerializeAsString(), nullptr));
GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
};
RAY_CHECK_OK(node_info_accessor_.AsyncRegister(request.node_info(), on_done));
RAY_CHECK_OK(
gcs_table_storage_->NodeTable().Put(node_id, request.node_info(), on_done));
}
void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &request,
@@ -163,7 +166,8 @@ void GcsNodeManager::HandleUnregisterNode(const rpc::UnregisterNodeRequest &requ
node->SerializeAsString(), nullptr));
GCS_RPC_SEND_REPLY(send_reply_callback, reply, status);
};
RAY_CHECK_OK(node_info_accessor_.AsyncUnregister(node_id, on_done));
// Update node state to DEAD instead of deleting it.
RAY_CHECK_OK(gcs_table_storage_->NodeTable().Put(node_id, *node, on_done));
// TODO(Shanly): Remove node resources from resource table.
}
}
+8 -2
View File
@@ -22,6 +22,7 @@
#include <ray/rpc/gcs_server/gcs_rpc_server.h>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "gcs_table_storage.h"
#include "ray/gcs/pubsub/gcs_pub_sub.h"
namespace ray {
@@ -36,12 +37,15 @@ class GcsNodeManager : public rpc::NodeInfoHandler {
///
/// \param io_service The event loop to run the monitor on.
/// \param node_info_accessor The node info accessor.
/// \param error_info_accessor The error info accessor, which is used to report error
/// \param error_info_accessor The error info accessor, which is used to report error.
/// \param gcs_pub_sub GCS message pushlisher.
/// \param gcs_table_storage GCS table external storage accessor.
/// when detecting the death of nodes.
explicit GcsNodeManager(boost::asio::io_service &io_service,
gcs::NodeInfoAccessor &node_info_accessor,
gcs::ErrorInfoAccessor &error_info_accessor,
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub);
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub,
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage);
/// Handle register rpc request come from raylet.
void HandleRegisterNode(const rpc::RegisterNodeRequest &request,
@@ -203,6 +207,8 @@ class GcsNodeManager : public rpc::NodeInfoHandler {
node_removed_listeners_;
/// A publisher for publishing gcs messages.
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub_;
/// Storage for GCS tables.
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage_;
};
} // namespace gcs
+3 -3
View File
@@ -133,9 +133,9 @@ void GcsServer::InitBackendClient() {
void GcsServer::InitGcsNodeManager() {
RAY_CHECK(redis_gcs_client_ != nullptr);
gcs_node_manager_ =
std::make_shared<GcsNodeManager>(main_service_, redis_gcs_client_->Nodes(),
redis_gcs_client_->Errors(), gcs_pub_sub_);
gcs_node_manager_ = std::make_shared<GcsNodeManager>(
main_service_, redis_gcs_client_->Nodes(), redis_gcs_client_->Errors(),
gcs_pub_sub_, gcs_table_storage_);
}
void GcsServer::InitGcsActorManager() {
@@ -26,8 +26,10 @@ class GcsActorSchedulerTest : public ::testing::Test {
raylet_client_ = std::make_shared<GcsServerMocker::MockRayletClient>();
worker_client_ = std::make_shared<GcsServerMocker::MockWorkerClient>();
gcs_pub_sub_ = std::make_shared<GcsServerMocker::MockGcsPubSub>(redis_client_);
gcs_table_storage_ = std::make_shared<gcs::RedisGcsTableStorage>(redis_client_);
gcs_node_manager_ = std::make_shared<gcs::GcsNodeManager>(
io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_);
io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_,
gcs_table_storage_);
gcs_actor_scheduler_ = std::make_shared<GcsServerMocker::MockedGcsActorScheduler>(
io_service_, actor_info_accessor_, *gcs_node_manager_, gcs_pub_sub_,
/*schedule_failure_handler=*/
@@ -57,6 +59,7 @@ class GcsActorSchedulerTest : public ::testing::Test {
std::vector<std::shared_ptr<gcs::GcsActor>> success_actors_;
std::vector<std::shared_ptr<gcs::GcsActor>> failure_actors_;
std::shared_ptr<GcsServerMocker::MockGcsPubSub> gcs_pub_sub_;
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage_;
std::shared_ptr<gcs::RedisClient> redis_client_;
};
@@ -22,6 +22,7 @@ namespace ray {
class GcsNodeManagerTest : public ::testing::Test {
protected:
std::shared_ptr<gcs::GcsPubSub> gcs_pub_sub_;
std::shared_ptr<gcs::GcsTableStorage> gcs_table_storage_;
};
TEST_F(GcsNodeManagerTest, TestManagement) {
@@ -29,7 +30,7 @@ TEST_F(GcsNodeManagerTest, TestManagement) {
auto node_info_accessor = GcsServerMocker::MockedNodeInfoAccessor();
auto error_info_accessor = GcsServerMocker::MockedErrorInfoAccessor();
gcs::GcsNodeManager node_manager(io_service, node_info_accessor, error_info_accessor,
gcs_pub_sub_);
gcs_pub_sub_, gcs_table_storage_);
// Test Add/Get/Remove functionality.
auto node = Mocker::GenNodeInfo();
auto node_id = ClientID::FromBinary(node->node_id());
@@ -46,7 +47,7 @@ TEST_F(GcsNodeManagerTest, TestListener) {
auto node_info_accessor = GcsServerMocker::MockedNodeInfoAccessor();
auto error_info_accessor = GcsServerMocker::MockedErrorInfoAccessor();
gcs::GcsNodeManager node_manager(io_service, node_info_accessor, error_info_accessor,
gcs_pub_sub_);
gcs_pub_sub_, gcs_table_storage_);
// Test AddNodeAddedListener.
int node_count = 1000;
std::vector<std::shared_ptr<rpc::GcsNodeInfo>> added_nodes;
@@ -54,7 +54,8 @@ class GcsObjectManagerTest : public ::testing::Test {
void SetUp() override {
gcs_table_storage_ = std::make_shared<gcs::InMemoryGcsTableStorage>(io_service_);
gcs_node_manager_ = std::make_shared<gcs::GcsNodeManager>(
io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_);
io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_,
gcs_table_storage_);
gcs_object_manager_ = std::make_shared<MockedGcsObjectManager>(
gcs_table_storage_, gcs_pub_sub_, *gcs_node_manager_);
GenTestData();