[core] Better support multi-nic environments by respecting user-provided IP (#8512)

This commit is contained in:
Xianyang Liu
2020-06-26 03:03:12 +08:00
committed by GitHub
parent 46962f5db1
commit 0bfcc2e5ba
4 changed files with 28 additions and 4 deletions
+2 -1
View File
@@ -244,7 +244,8 @@ class Node:
self._ray_params.num_cpus, self._ray_params.num_gpus,
self._ray_params.memory, self._ray_params.object_store_memory,
self._ray_params.resources,
self._ray_params.redis_max_memory).resolve(is_head=self.head)
self._ray_params.redis_max_memory).resolve(
is_head=self.head, node_ip_address=self.node_ip_address)
return self._resource_spec
@property
+12 -3
View File
@@ -121,8 +121,14 @@ class ResourceSpec(
return resources
def resolve(self, is_head):
"""Returns a copy with values filled out with system defaults."""
def resolve(self, is_head, node_ip_address=None):
"""Returns a copy with values filled out with system defaults.
Args:
is_head (bool): Whether this is the head node.
node_ip_address (str): The IP address of the node that we are on.
This is used to automatically create a node id resource.
"""
resources = (self.resources or {}).copy()
assert "CPU" not in resources, resources
@@ -130,9 +136,12 @@ class ResourceSpec(
assert "memory" not in resources, resources
assert "object_store_memory" not in resources, resources
if node_ip_address is None:
node_ip_address = ray.services.get_node_ip_address()
# Automatically create a node id resource on each node. This is
# queryable with ray.state.node_ids() and ray.state.current_node_id().
resources[NODE_ID_PREFIX + ray.services.get_node_ip_address()] = 1.0
resources[NODE_ID_PREFIX + node_ip_address] = 1.0
num_cpus = self.num_cpus
if num_cpus is None:
+3
View File
@@ -324,6 +324,9 @@ def get_node_ip_address(address="8.8.8.8:53"):
Returns:
The IP address of the current node.
"""
if ray.worker._global_node is not None:
return ray.worker._global_node.node_ip_address
ip_address, port = address.split(":")
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
+11
View File
@@ -8,6 +8,8 @@ import time
import numpy as np
import pytest
from unittest.mock import MagicMock, patch
import ray
import ray.cluster_utils
import ray.test_utils
@@ -673,6 +675,15 @@ def test_internal_config_when_connecting(ray_start_cluster):
ray.get(oid)
def test_get_correct_node_ip():
with patch("ray.worker") as worker_mock:
node_mock = MagicMock()
node_mock.node_ip_address = "10.0.0.111"
worker_mock._global_node = node_mock
found_ip = ray.services.get_node_ip_address()
assert found_ip == "10.0.0.111"
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))