mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 03:21:06 +08:00
[cluster] On Prem Server First PR (#9663)
* on prem server first commit * minor fix * verify error on autoscaling in on prem mode * lint * lint * Tests complete * add tests to check for backward compatibility * Fixing comments and autoscaling * minor fixes * coordinating server mode * tests * lint * remove unnecessary import * Resolving Comments * separating coordinator and local node provider Co-authored-by: Ameer Haj Ali <ameerhajali@Ameers-MacBook-Pro.local>
This commit is contained in:
@@ -0,0 +1,104 @@
|
||||
import json
|
||||
import logging
|
||||
from http.client import RemoteDisconnected
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CoordinatorSenderNodeProvider(NodeProvider):
    """NodeProvider for automatically managed private/local clusters.

    The cluster management is handled by a remote coordinating server.
    The server listens on <coordinator_address>, therefore, the address
    should be provided in the provider section in the cluster config.
    The server receives HTTP requests from this class and uses
    LocalNodeProvider to get their responses.

    Every public method builds a ``{"type": <method name>, "args": [...]}``
    request, sends it to the coordinator and, for the query methods, returns
    the decoded JSON reply.
    """

    def __init__(self, provider_config, cluster_name):
        """Store the coordinator address found in `provider_config`.

        Args:
            provider_config (dict): Provider section of the cluster config;
                must contain the key "coordinator_address".
            cluster_name (str): Name used to tag/filter this cluster's nodes.

        Raises:
            KeyError: If "coordinator_address" is missing.
        """
        NodeProvider.__init__(self, provider_config, cluster_name)
        self.coordinator_address = provider_config["coordinator_address"]

    def _get_http_response(self, request):
        """Send `request` to the coordinator and return its JSON reply.

        Args:
            request (dict): JSON-serializable request with "type" and "args".

        Returns:
            The JSON-decoded body of the coordinator's response.

        Raises:
            ImportError: If the `requests` library is not installed.
            RemoteDisconnected, requests.exceptions.ConnectionError: If the
                coordinator cannot be reached. Both are logged and re-raised.
        """
        headers = {
            "Content-Type": "application/json",
        }
        request_message = json.dumps(request).encode()
        http_coordinator_address = "http://" + self.coordinator_address

        # Import in its own try block. Previously the connection-error
        # `except` clause referenced a name that was only bound by an import
        # inside the same `try`, so a missing `requests` package surfaced as
        # UnboundLocalError instead of reaching the ImportError handler.
        try:
            import requests  # `requests` is not part of stdlib.
        except ImportError:
            logger.exception("Couldn't import `requests` library. "
                             "Be sure to install it on the client side.")
            raise

        try:
            r = requests.get(
                http_coordinator_address,
                data=request_message,
                headers=headers,
                timeout=None,
            )
        except (RemoteDisconnected, requests.exceptions.ConnectionError):
            logger.exception("Could not connect to: " +
                             http_coordinator_address +
                             ". Did you run python coordinator_server.py" +
                             " --ips <list_of_node_ips> --port <PORT>?")
            raise

        response = r.json()
        return response

    def non_terminated_nodes(self, tag_filters):
        """Return the coordinator's reply for this cluster's live nodes.

        Note: `tag_filters` is updated in place with the cluster name.
        """
        # Only get the non terminated nodes associated with this cluster name.
        tag_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name
        request = {"type": "non_terminated_nodes", "args": (tag_filters, )}
        return self._get_http_response(request)

    def is_running(self, node_id):
        """Forward an `is_running` query for `node_id` to the coordinator."""
        request = {"type": "is_running", "args": (node_id, )}
        return self._get_http_response(request)

    def is_terminated(self, node_id):
        """Forward an `is_terminated` query for `node_id`."""
        request = {"type": "is_terminated", "args": (node_id, )}
        return self._get_http_response(request)

    def node_tags(self, node_id):
        """Forward a `node_tags` query for `node_id` to the coordinator."""
        request = {"type": "node_tags", "args": (node_id, )}
        return self._get_http_response(request)

    def external_ip(self, node_id):
        """Forward an `external_ip` query for `node_id` to the coordinator."""
        request = {"type": "external_ip", "args": (node_id, )}
        response = self._get_http_response(request)
        return response

    def internal_ip(self, node_id):
        """Forward an `internal_ip` query for `node_id` to the coordinator."""
        request = {"type": "internal_ip", "args": (node_id, )}
        response = self._get_http_response(request)
        return response

    def create_node(self, node_config, tags, count):
        """Send a `create_node` request to the coordinator.

        Note: `tags` is updated in place with the cluster name.
        """
        # Tag the newly created node with this cluster name. Helps to get
        # the right nodes when calling non_terminated_nodes.
        tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name
        request = {
            "type": "create_node",
            "args": (node_config, tags, count),
        }
        self._get_http_response(request)

    def set_node_tags(self, node_id, tags):
        """Send a `set_node_tags` request for `node_id` to the coordinator."""
        request = {"type": "set_node_tags", "args": (node_id, tags)}
        self._get_http_response(request)

    def terminate_node(self, node_id):
        """Send a `terminate_node` request for `node_id`."""
        request = {"type": "terminate_node", "args": (node_id, )}
        self._get_http_response(request)

    def terminate_nodes(self, node_ids):
        """Send a `terminate_nodes` request for all of `node_ids`."""
        request = {"type": "terminate_nodes", "args": (node_ids, )}
        self._get_http_response(request)
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Web server that runs on local/private clusters to coordinate and manage
|
||||
different clusters for multiple users. It receives node provider function calls
|
||||
through HTTP requests from remote CoordinatorSenderNodeProvider and runs them
|
||||
locally in LocalNodeProvider. To start the webserver the user runs:
|
||||
`python coordinator_server.py --ips <comma separated ips> --port <PORT>`."""
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
||||
import json
|
||||
import socket
|
||||
|
||||
from ray.autoscaler.local.node_provider import LocalNodeProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def runner_handler(node_provider):
    """Build the HTTP request handler class bound to `node_provider`.

    Args:
        node_provider: Object whose methods are invoked by name for every
            incoming request (the coordinator passes a LocalNodeProvider).

    Returns:
        A SimpleHTTPRequestHandler subclass suitable for HTTPServer.
    """

    class Handler(SimpleHTTPRequestHandler):
        """A custom handler for OnPremCoordinatorServer.

        Handles all requests and responses coming into and from the
        remote CoordinatorSenderNodeProvider.
        """

        def _do_header(self, response_code=200, headers=None):
            """Sends the header portion of the HTTP response.

            Args:
                response_code (int): Standard HTTP response code
                headers (list[tuples]): Standard HTTP response headers
            """
            if headers is None:
                headers = [("Content-type", "application/json")]

            self.send_response(response_code)
            for key, value in headers:
                self.send_header(key, value)
            self.end_headers()

        def do_HEAD(self):
            """HTTP HEAD handler method."""
            self._do_header()

        def do_GET(self):
            """Processes requests from remote CoordinatorSenderNodeProvider.

            The request body is a JSON object with a "type" (the name of the
            node provider method to call) and "args" (its positional
            arguments). The method's return value is sent back as JSON.
            """
            content_length = self.headers["content-length"]
            if not content_length:
                # Without a body there is nothing to dispatch. Answer
                # explicitly instead of leaving the client with no response.
                self._do_header(response_code=400)
                error = {"error": "Request body with content-length required."}
                self.wfile.write(json.dumps(error).encode())
                return

            raw_data = self.rfile.read(int(content_length)).decode("utf-8")
            logger.info("OnPremCoordinatorServer received request: " +
                        str(raw_data))
            request = json.loads(raw_data)
            # NOTE(review): the method name comes straight from the network,
            # so any attribute of node_provider can be invoked. Only expose
            # this server to trusted clients.
            response = getattr(node_provider,
                               request["type"])(*request["args"])
            # Log the computed response (the request was logged above).
            logger.info("OnPremCoordinatorServer response content: " +
                        str(response))
            response_code = 200
            message = json.dumps(response)
            self._do_header(response_code=response_code)
            self.wfile.write(message.encode())

    return Handler
|
||||
|
||||
|
||||
class OnPremCoordinatorServer(threading.Thread):
    """Thread that owns an HTTPServer answering CoordinatorSenderNodeProvider.

    Incoming requests from the remote CoordinatorSenderNodeProvider are
    translated into LocalNodeProvider function calls.
    """

    def __init__(self, list_of_node_ips, host, port):
        """Create the HTTPServer and immediately start serving.

        Args:
            list_of_node_ips (list[str]): Node ips managed by the coordinator.
            host (str): Address the HTTP server binds to.
            port (int): Port the HTTP server binds to.
        """
        endpoint = host + ":" + str(port)
        logger.info("Running on prem coordinator server on address " +
                    endpoint)
        super().__init__()
        self._port = port
        self._list_of_node_ips = list_of_node_ips

        # cluster_name=None puts LocalNodeProvider in coordinator mode.
        provider = LocalNodeProvider(
            {"list_of_node_ips": list_of_node_ips}, cluster_name=None)
        handler_cls = runner_handler(provider)
        self._server = HTTPServer((host, self._port), handler_cls)
        self.start()

    def run(self):
        """Thread body: serve requests until shutdown() is called."""
        self._server.serve_forever()

    def shutdown(self):
        """Stop serving and close the listening socket."""
        server = self._server
        server.shutdown()
        server.server_close()
|
||||
|
||||
|
||||
def main():
    """Parse --ips/--port and start the coordinator on this host's address."""
    arg_parser = argparse.ArgumentParser(
        description="Please provide a list of node ips and port.")
    arg_parser.add_argument(
        "--ips", required=True, help="Comma separated list of node ips.")
    arg_parser.add_argument(
        "--port",
        type=int,
        required=True,
        help="The port on which the coordinator listens.")
    cli_args = arg_parser.parse_args()

    node_ips = cli_args.ips.split(",")
    local_address = socket.gethostbyname(socket.gethostname())
    # The constructor starts the server thread itself.
    OnPremCoordinatorServer(
        list_of_node_ips=node_ips,
        host=local_address,
        port=cli_args.port,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -38,6 +38,10 @@ provider:
|
||||
type: local
|
||||
head_ip: YOUR_HEAD_NODE_HOSTNAME
|
||||
worker_ips: [WORKER_NODE_1_HOSTNAME, WORKER_NODE_2_HOSTNAME, ... ]
|
||||
# Optional when running automatic cluster management on prem. If you use a coordinator server,
|
||||
# then you can launch multiple autoscaling clusters on the same set of machines, and the coordinator
|
||||
# will assign individual nodes to clusters as needed.
|
||||
# coordinator_address: "<host>:<port>"
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
|
||||
@@ -7,8 +7,11 @@ import logging
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.local.config import bootstrap_local
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \
|
||||
NODE_TYPE_HEAD
|
||||
from ray.autoscaler.tags import (
|
||||
TAG_RAY_NODE_TYPE,
|
||||
NODE_TYPE_WORKER,
|
||||
NODE_TYPE_HEAD,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,8 +48,8 @@ class ClusterState:
|
||||
"state": "terminated",
|
||||
}
|
||||
else:
|
||||
assert workers[worker_ip]["tags"][
|
||||
TAG_RAY_NODE_TYPE] == NODE_TYPE_WORKER
|
||||
assert (workers[worker_ip]["tags"][TAG_RAY_NODE_TYPE]
|
||||
== NODE_TYPE_WORKER)
|
||||
if provider_config["head_ip"] not in workers:
|
||||
workers[provider_config["head_ip"]] = {
|
||||
"tags": {
|
||||
@@ -55,8 +58,16 @@ class ClusterState:
|
||||
"state": "terminated",
|
||||
}
|
||||
else:
|
||||
assert workers[provider_config["head_ip"]]["tags"][
|
||||
TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD
|
||||
assert (workers[provider_config["head_ip"]]["tags"][
|
||||
TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD)
|
||||
# Relevant when a user reduces the number of workers
|
||||
# without changing the headnode.
|
||||
list_of_node_ips = list(provider_config["worker_ips"])
|
||||
list_of_node_ips.append(provider_config["head_ip"])
|
||||
for worker_ip in list(workers):
|
||||
if worker_ip not in list_of_node_ips:
|
||||
del workers[worker_ip]
|
||||
|
||||
assert len(workers) == len(provider_config["worker_ips"]) + 1
|
||||
with open(self.save_path, "w") as f:
|
||||
logger.debug("ClusterState: "
|
||||
@@ -83,17 +94,82 @@ class ClusterState:
|
||||
f.write(json.dumps(workers))
|
||||
|
||||
|
||||
class OnPremCoordinatorState(ClusterState):
    """Generates & updates the state file of CoordinatorSenderNodeProvider.

    Unlike ClusterState, which generates a cluster specific file with
    predefined head and worker ips, OnPremCoordinatorState overwrites
    ClusterState's __init__ function to generate and manage a unified
    file of the status of all the nodes for multiple clusters.
    """

    def __init__(self, lock_path, save_path, list_of_node_ips):
        """Load, reconcile and rewrite the unified state file.

        Args:
            lock_path (str): Path of the file lock guarding the state file.
            save_path (str): Path of the JSON state file.
            list_of_node_ips (list[str]): Node ips the coordinator manages.
        """
        self.lock = RLock()
        self.file_lock = FileLock(lock_path)
        self.save_path = save_path

        with self.lock:
            with self.file_lock:
                if os.path.exists(self.save_path):
                    # Read through a context manager so the handle is closed
                    # before the same file is reopened for writing below.
                    with open(self.save_path) as f:
                        nodes = json.loads(f.read())
                else:
                    nodes = {}
                logger.info(
                    "OnPremCoordinatorState: "
                    "Loaded on prem coordinator state: {}".format(nodes))

                # Filter removed node ips.
                for node_ip in list(nodes):
                    if node_ip not in list_of_node_ips:
                        del nodes[node_ip]

                # Newly listed ips start out untagged and terminated.
                for node_ip in list_of_node_ips:
                    if node_ip not in nodes:
                        nodes[node_ip] = {
                            "tags": {},
                            "state": "terminated",
                        }
                assert len(nodes) == len(list_of_node_ips)
                with open(self.save_path, "w") as f:
                    logger.info(
                        "OnPremCoordinatorState: "
                        "Writing on prem coordinator state: {}".format(nodes))
                    f.write(json.dumps(nodes))
|
||||
|
||||
|
||||
class LocalNodeProvider(NodeProvider):
|
||||
"""NodeProvider for private/local clusters.
|
||||
|
||||
`node_id` is overloaded to also be `node_ip` in this class.
|
||||
|
||||
When `cluster_name` is provided, it manages a single cluster in a cluster
|
||||
specific state file. But when `cluster_name` is None, it manages multiple
|
||||
clusters in a unified state file that requires each node to be tagged with
|
||||
TAG_RAY_CLUSTER_NAME in create and non_terminated_nodes function calls to
|
||||
associate each node with the right cluster.
|
||||
|
||||
The current use case of managing multiple clusters is by
|
||||
OnPremCoordinatorServer which receives node provider HTTP requests
|
||||
from CoordinatorSenderNodeProvider and uses LocalNodeProvider to get
|
||||
the responses.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
self.state = ClusterState("/tmp/cluster-{}.lock".format(cluster_name),
|
||||
"/tmp/cluster-{}.state".format(cluster_name),
|
||||
provider_config)
|
||||
|
||||
if cluster_name:
|
||||
self.state = ClusterState(
|
||||
"/tmp/cluster-{}.lock".format(cluster_name),
|
||||
"/tmp/cluster-{}.state".format(cluster_name),
|
||||
provider_config,
|
||||
)
|
||||
self.use_coordinator = False
|
||||
else:
|
||||
# LocalNodeProvider with a coordinator server.
|
||||
self.state = OnPremCoordinatorState(
|
||||
"/tmp/coordinator.lock", "/tmp/coordinator.state",
|
||||
provider_config["list_of_node_ips"])
|
||||
self.use_coordinator = True
|
||||
|
||||
def non_terminated_nodes(self, tag_filters):
|
||||
workers = self.state.get()
|
||||
@@ -132,16 +208,20 @@ class LocalNodeProvider(NodeProvider):
|
||||
self.state.put(node_id, info)
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
"""Creates min(count, currently available) nodes."""
|
||||
node_type = tags[TAG_RAY_NODE_TYPE]
|
||||
with self.state.file_lock:
|
||||
workers = self.state.get()
|
||||
for node_id, info in workers.items():
|
||||
if (info["state"] == "terminated"
|
||||
and info["tags"][TAG_RAY_NODE_TYPE] == node_type):
|
||||
and (self.use_coordinator
|
||||
or info["tags"][TAG_RAY_NODE_TYPE] == node_type)):
|
||||
info["tags"] = tags
|
||||
info["state"] = "running"
|
||||
self.state.put(node_id, info)
|
||||
return
|
||||
count = count - 1
|
||||
if count == 0:
|
||||
return
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
workers = self.state.get()
|
||||
|
||||
@@ -26,8 +26,13 @@ def import_azure(provider_config):
|
||||
|
||||
|
||||
def import_local(provider_config):
    """Return the node provider class to use for `type: local` clusters.

    Args:
        provider_config (dict): Provider section of the cluster config.

    Returns:
        CoordinatorSenderNodeProvider when "coordinator_address" is present
        in `provider_config`, otherwise LocalNodeProvider.
    """
    # The pre-change unconditional `return LocalNodeProvider` is dropped; it
    # made the coordinator branch below unreachable.
    if "coordinator_address" in provider_config:
        from ray.autoscaler.local.coordinator_node_provider import (
            CoordinatorSenderNodeProvider)
        return CoordinatorSenderNodeProvider
    else:
        from ray.autoscaler.local.node_provider import LocalNodeProvider
        return LocalNodeProvider
|
||||
|
||||
|
||||
def import_kubernetes(provider_config):
|
||||
|
||||
@@ -190,6 +190,13 @@ py_test(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_coordinator_server",
|
||||
size = "small",
|
||||
srcs = SRCS + ["test_coordinator_server.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_autoscaler_aws",
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
import os
|
||||
import unittest
|
||||
import socket
|
||||
import json
|
||||
|
||||
from ray.autoscaler.local.coordinator_server import OnPremCoordinatorServer
|
||||
from ray.autoscaler.node_provider import NODE_PROVIDERS, get_node_provider
|
||||
from ray.autoscaler.local.node_provider import LocalNodeProvider
|
||||
from ray.autoscaler.local.coordinator_node_provider import (
|
||||
CoordinatorSenderNodeProvider)
|
||||
from ray.autoscaler.tags import (TAG_RAY_NODE_TYPE, TAG_RAY_CLUSTER_NAME,
|
||||
TAG_RAY_NODE_NAME, NODE_TYPE_WORKER,
|
||||
NODE_TYPE_HEAD)
|
||||
import pytest
|
||||
|
||||
|
||||
class OnPremCoordinatorServerTest(unittest.TestCase):
    """Tests for the on prem coordinator server and its node providers."""

    @staticmethod
    def _read_state(path):
        """Return the parsed JSON state file at `path`, closing the file."""
        with open(path) as f:
            return json.loads(f.read())

    def setUp(self):
        """Start a coordinator server managing two fake node ips."""
        self.list_of_node_ips = ["0.0.0.0:1", "0.0.0.0:2"]
        self.host, self.port = socket.gethostbyname(socket.gethostname()), 1234
        self.server = OnPremCoordinatorServer(
            list_of_node_ips=self.list_of_node_ips,
            host=self.host,
            port=self.port,
        )
        self.coordinator_address = self.host + ":" + str(self.port)

    def tearDown(self):
        """Stop the server and delete the coordinator state file."""
        self.server.shutdown()
        state_save_path = "/tmp/coordinator.state"
        if os.path.exists(state_save_path):
            os.remove(state_save_path)

    def testImportingCorrectClass(self):
        """Check correct import when coordinator_address is in config yaml."""

        provider_config = {"coordinator_address": "fake_address:1234"}
        coordinator_node_provider = NODE_PROVIDERS.get("local")(
            provider_config)
        assert coordinator_node_provider is CoordinatorSenderNodeProvider
        local_node_provider = NODE_PROVIDERS.get("local")({})
        assert local_node_provider is LocalNodeProvider

    def testClusterStateInit(self):
        """Check ClusterState __init__ func generates correct state file.

        Test the general use case and if num_workers increase/decrease.
        """

        cluster_config = {
            "cluster_name": "random_name",
            "min_workers": 0,
            "max_workers": 0,
            "initial_workers": 0,
            "provider": {
                "type": "local",
                "head_ip": "0.0.0.0:2",
                "worker_ips": ["0.0.0.0:1"]
            },
        }
        provider_config = cluster_config["provider"]
        node_provider = get_node_provider(provider_config,
                                          cluster_config["cluster_name"])
        assert isinstance(node_provider, LocalNodeProvider)
        expected_workers = {}
        expected_workers[provider_config["head_ip"]] = {
            "tags": {
                TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD
            },
            "state": "terminated",
        }
        expected_workers[provider_config["worker_ips"][0]] = {
            "tags": {
                TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
            },
            "state": "terminated",
        }

        state_save_path = "/tmp/cluster-{}.state".format(
            cluster_config["cluster_name"])
        assert os.path.exists(state_save_path)
        workers = self._read_state(state_save_path)
        assert workers == expected_workers

        # Test removing workers updates the cluster state.
        del expected_workers[provider_config["worker_ips"][0]]
        removed_ip = provider_config["worker_ips"].pop()
        node_provider = get_node_provider(provider_config,
                                          cluster_config["cluster_name"])
        workers = self._read_state(state_save_path)
        assert workers == expected_workers

        # Test adding back workers updates the cluster state.
        expected_workers[removed_ip] = {
            "tags": {
                TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
            },
            "state": "terminated",
        }
        provider_config["worker_ips"].append(removed_ip)
        node_provider = get_node_provider(provider_config,
                                          cluster_config["cluster_name"])
        workers = self._read_state(state_save_path)
        assert workers == expected_workers

    def testOnPremCoordinatorStateInit(self):
        """If OnPremCoordinatorState __init__ generates correct state file.

        Test the general use case and if the coordinator server crashes or
        updates the list of node ips with more/less nodes.
        """

        expected_nodes = {}
        for ip in self.list_of_node_ips:
            expected_nodes[ip] = {
                "tags": {},
                "state": "terminated",
            }

        state_save_path = "/tmp/coordinator.state"
        assert os.path.exists(state_save_path)
        nodes = self._read_state(state_save_path)
        assert nodes == expected_nodes

        # Test removing workers updates the cluster state.
        del expected_nodes[self.list_of_node_ips[1]]
        self.server.shutdown()
        self.server = OnPremCoordinatorServer(
            list_of_node_ips=self.list_of_node_ips[0:1],
            host=self.host,
            port=self.port,
        )
        nodes = self._read_state(state_save_path)
        assert nodes == expected_nodes

        # Test adding back workers updates the cluster state.
        expected_nodes[self.list_of_node_ips[1]] = {
            "tags": {},
            "state": "terminated",
        }
        self.server.shutdown()
        self.server = OnPremCoordinatorServer(
            list_of_node_ips=self.list_of_node_ips,
            host=self.host,
            port=self.port,
        )
        nodes = self._read_state(state_save_path)
        assert nodes == expected_nodes

    def testCoordinatorSenderNodeProvider(self):
        """Integration test of CoordinatorSenderNodeProvider."""
        cluster_config = {
            "cluster_name": "random_name",
            "min_workers": 0,
            "max_workers": 0,
            "initial_workers": 0,
            "provider": {
                "type": "local",
                "coordinator_address": self.coordinator_address,
            },
            "head_node": {},
            "worker_nodes": {},
        }
        provider_config = cluster_config["provider"]
        node_provider_1 = get_node_provider(provider_config,
                                            cluster_config["cluster_name"])
        assert isinstance(node_provider_1, CoordinatorSenderNodeProvider)

        assert not node_provider_1.non_terminated_nodes({})
        assert not node_provider_1.is_running(self.list_of_node_ips[0])
        assert node_provider_1.is_terminated(self.list_of_node_ips[0])
        assert not node_provider_1.node_tags(self.list_of_node_ips[0])
        head_node_tags = {
            TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD,
        }
        assert not node_provider_1.non_terminated_nodes(head_node_tags)
        head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
            cluster_config["cluster_name"])
        node_provider_1.create_node(cluster_config["head_node"],
                                    head_node_tags, 1)
        assert node_provider_1.non_terminated_nodes(
            {}) == [self.list_of_node_ips[0]]
        head_node_tags[TAG_RAY_CLUSTER_NAME] = cluster_config["cluster_name"]
        assert node_provider_1.node_tags(
            self.list_of_node_ips[0]) == head_node_tags
        assert node_provider_1.is_running(self.list_of_node_ips[0])
        assert not node_provider_1.is_terminated(self.list_of_node_ips[0])

        # Add another cluster.
        cluster_config["cluster_name"] = "random_name_2"
        provider_config = cluster_config["provider"]
        node_provider_2 = get_node_provider(provider_config,
                                            cluster_config["cluster_name"])
        assert not node_provider_2.non_terminated_nodes({})
        assert not node_provider_2.is_running(self.list_of_node_ips[1])
        assert node_provider_2.is_terminated(self.list_of_node_ips[1])
        assert not node_provider_2.node_tags(self.list_of_node_ips[1])
        assert not node_provider_2.non_terminated_nodes(head_node_tags)
        head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
            cluster_config["cluster_name"])
        node_provider_2.create_node(cluster_config["head_node"],
                                    head_node_tags, 1)
        assert node_provider_2.non_terminated_nodes(
            {}) == [self.list_of_node_ips[1]]
        head_node_tags[TAG_RAY_CLUSTER_NAME] = cluster_config["cluster_name"]
        assert node_provider_2.node_tags(
            self.list_of_node_ips[1]) == head_node_tags
        assert node_provider_2.is_running(self.list_of_node_ips[1])
        assert not node_provider_2.is_terminated(self.list_of_node_ips[1])

        # Add another cluster (should fail because we only have two nodes).
        cluster_config["cluster_name"] = "random_name_3"
        provider_config = cluster_config["provider"]
        node_provider_3 = get_node_provider(provider_config,
                                            cluster_config["cluster_name"])
        assert not node_provider_3.non_terminated_nodes(head_node_tags)
        head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format(
            cluster_config["cluster_name"])
        node_provider_3.create_node(cluster_config["head_node"],
                                    head_node_tags, 1)
        assert not node_provider_3.non_terminated_nodes({})

        # Terminate all nodes.
        node_provider_1.terminate_node(self.list_of_node_ips[0])
        assert not node_provider_1.non_terminated_nodes({})
        node_provider_2.terminate_node(self.list_of_node_ips[1])
        assert not node_provider_2.non_terminated_nodes({})

        # Check if now we can create more clusters/nodes.
        node_provider_3.create_node(cluster_config["head_node"],
                                    head_node_tags, 1)
        worker_node_tags = {
            TAG_RAY_NODE_NAME: "ray-{}-worker".format(
                cluster_config["cluster_name"]),
            TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
        }
        node_provider_3.create_node(cluster_config["worker_nodes"],
                                    worker_node_tags, 1)
        assert node_provider_3.non_terminated_nodes(
            {}) == self.list_of_node_ips
        worker_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER}
        assert node_provider_3.non_terminated_nodes(worker_filter) == [
            self.list_of_node_ips[1]
        ]
        head_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD}
        assert node_provider_3.non_terminated_nodes(head_filter) == [
            self.list_of_node_ips[0]
        ]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user