From 9089fab0ef8330126abfa7516e9361842a4bc10c Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Mon, 3 Aug 2020 20:38:44 +0300 Subject: [PATCH] [cluster] On Prem Server First PR (#9663) * on prem server first commit * minor fix * verify error on autoscaling in on prem mode * lint * lint * Tests complete * add tests to check for backward compatibility * Fixing comments and autoscaling * minor fixes * coordinating server mode * tests * lint * remove unnecessary import * Resolving Comments * seperating coordinator and local node provider Co-authored-by: Ameer Haj Ali --- .../local/coordinator_node_provider.py | 104 +++++++ .../autoscaler/local/coordinator_server.py | 118 ++++++++ python/ray/autoscaler/local/example-full.yaml | 4 + python/ray/autoscaler/local/node_provider.py | 102 ++++++- python/ray/autoscaler/node_provider.py | 9 +- python/ray/tests/BUILD | 7 + python/ray/tests/test_coordinator_server.py | 254 ++++++++++++++++++ 7 files changed, 585 insertions(+), 13 deletions(-) create mode 100644 python/ray/autoscaler/local/coordinator_node_provider.py create mode 100644 python/ray/autoscaler/local/coordinator_server.py create mode 100644 python/ray/tests/test_coordinator_server.py diff --git a/python/ray/autoscaler/local/coordinator_node_provider.py b/python/ray/autoscaler/local/coordinator_node_provider.py new file mode 100644 index 000000000..5b3eb115a --- /dev/null +++ b/python/ray/autoscaler/local/coordinator_node_provider.py @@ -0,0 +1,104 @@ +import json +import logging +from http.client import RemoteDisconnected + +from ray.autoscaler.node_provider import NodeProvider +from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME + +logger = logging.getLogger(__name__) + + +class CoordinatorSenderNodeProvider(NodeProvider): + """NodeProvider for automatically managed private/local clusters. + + The cluster management is handled by a remote coordinating server. 
+ The server listens on `coordinator_address`, therefore, the address + should be provided in the provider section in the cluster config. + The server receives HTTP requests from this class and uses + LocalNodeProvider to get their responses. + """ + + def __init__(self, provider_config, cluster_name): + NodeProvider.__init__(self, provider_config, cluster_name) + self.coordinator_address = provider_config["coordinator_address"] + + def _get_http_response(self, request): + headers = { + "Content-Type": "application/json", + } + request_message = json.dumps(request).encode() + http_coordinator_address = "http://" + self.coordinator_address + + try: + import requests # `requests` is not part of stdlib. + from requests.exceptions import ConnectionError + + r = requests.get( + http_coordinator_address, + data=request_message, + headers=headers, + timeout=None, + ) + except (RemoteDisconnected, ConnectionError): + logger.exception("Could not connect to: " + + http_coordinator_address + + ". Did you run python coordinator_server.py" + + " --ips --port ?") + raise + except ImportError: + logger.exception("Couldn't import `requests` library. " + "Be sure to install it on the client side.") + raise + + response = r.json() + return response + + def non_terminated_nodes(self, tag_filters): + # Only get the non terminated nodes associated with this cluster name. 
+ tag_filters[TAG_RAY_CLUSTER_NAME] = self.cluster_name + request = {"type": "non_terminated_nodes", "args": (tag_filters, )} + return self._get_http_response(request) + + def is_running(self, node_id): + request = {"type": "is_running", "args": (node_id, )} + return self._get_http_response(request) + + def is_terminated(self, node_id): + request = {"type": "is_terminated", "args": (node_id, )} + return self._get_http_response(request) + + def node_tags(self, node_id): + request = {"type": "node_tags", "args": (node_id, )} + return self._get_http_response(request) + + def external_ip(self, node_id): + request = {"type": "external_ip", "args": (node_id, )} + response = self._get_http_response(request) + return response + + def internal_ip(self, node_id): + request = {"type": "internal_ip", "args": (node_id, )} + response = self._get_http_response(request) + return response + + def create_node(self, node_config, tags, count): + # Tag the newly created node with this cluster name. Helps to get + # the right nodes when calling non_terminated_nodes. 
+ tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name + request = { + "type": "create_node", + "args": (node_config, tags, count), + } + self._get_http_response(request) + + def set_node_tags(self, node_id, tags): + request = {"type": "set_node_tags", "args": (node_id, tags)} + self._get_http_response(request) + + def terminate_node(self, node_id): + request = {"type": "terminate_node", "args": (node_id, )} + self._get_http_response(request) + + def terminate_nodes(self, node_ids): + request = {"type": "terminate_nodes", "args": (node_ids, )} + self._get_http_response(request) diff --git a/python/ray/autoscaler/local/coordinator_server.py b/python/ray/autoscaler/local/coordinator_server.py new file mode 100644 index 000000000..428fa0bd1 --- /dev/null +++ b/python/ray/autoscaler/local/coordinator_server.py @@ -0,0 +1,118 @@ +"""Web server that runs on local/private clusters to coordinate and manage +different clusters for multiple users. It receives node provider function calls +through HTTP requests from remote CoordinatorSenderNodeProvider and runs them +locally in LocalNodeProvider. To start the webserver the user runs: +`python coordinator_server.py --ips IP1,IP2,... --port PORT`.""" +import argparse +import logging +import threading +from http.server import SimpleHTTPRequestHandler, HTTPServer +import json +import socket + +from ray.autoscaler.local.node_provider import LocalNodeProvider + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def runner_handler(node_provider): + class Handler(SimpleHTTPRequestHandler): + """A custom handler for OnPremCoordinatorServer. + + Handles all requests and responses coming into and from the + remote CoordinatorSenderNodeProvider. + """ + + def _do_header(self, response_code=200, headers=None): + """Sends the header portion of the HTTP response. 
+ + Args: + response_code (int): Standard HTTP response code + headers (list[tuples]): Standard HTTP response headers + """ + if headers is None: + headers = [("Content-type", "application/json")] + + self.send_response(response_code) + for key, value in headers: + self.send_header(key, value) + self.end_headers() + + def do_HEAD(self): + """HTTP HEAD handler method.""" + self._do_header() + + def do_GET(self): + """Processes requests from remote CoordinatorSenderNodeProvider.""" + if self.headers["content-length"]: + raw_data = (self.rfile.read( + int(self.headers["content-length"]))).decode("utf-8") + logger.info("OnPremCoordinatorServer received request: " + + str(raw_data)) + request = json.loads(raw_data) + response = getattr(node_provider, + request["type"])(*request["args"]) + logger.info("OnPremCoordinatorServer response content: " + + str(raw_data)) + response_code = 200 + message = json.dumps(response) + self._do_header(response_code=response_code) + self.wfile.write(message.encode()) + + return Handler + + +class OnPremCoordinatorServer(threading.Thread): + """Initializes HTTPServer and serves CoordinatorSenderNodeProvider forever. + + It handles requests from the remote CoordinatorSenderNodeProvider. The + requests are forwarded to LocalNodeProvider function calls. 
+ """ + + def __init__(self, list_of_node_ips, host, port): + """Initialize HTTPServer and serve forever by invoking self.run().""" + + logger.info("Running on prem coordinator server on address " + host + + ":" + str(port)) + threading.Thread.__init__(self) + self._port = port + self._list_of_node_ips = list_of_node_ips + address = (host, self._port) + config = {"list_of_node_ips": list_of_node_ips} + self._server = HTTPServer( + address, + runner_handler(LocalNodeProvider(config, cluster_name=None)), + ) + self.start() + + def run(self): + self._server.serve_forever() + + def shutdown(self): + """Shutdown the underlying server.""" + self._server.shutdown() + self._server.server_close() + + +def main(): + parser = argparse.ArgumentParser( + description="Please provide a list of node ips and port.") + parser.add_argument( + "--ips", required=True, help="Comma separated list of node ips.") + parser.add_argument( + "--port", + type=int, + required=True, + help="The port on which the coordinator listens.") + args = parser.parse_args() + list_of_node_ips = args.ips.split(",") + OnPremCoordinatorServer( + list_of_node_ips=list_of_node_ips, + host=socket.gethostbyname(socket.gethostname()), + port=args.port, + ) + + +if __name__ == "__main__": + main() diff --git a/python/ray/autoscaler/local/example-full.yaml b/python/ray/autoscaler/local/example-full.yaml index 489fb3e9b..045f2efb4 100644 --- a/python/ray/autoscaler/local/example-full.yaml +++ b/python/ray/autoscaler/local/example-full.yaml @@ -38,6 +38,10 @@ provider: type: local head_ip: YOUR_HEAD_NODE_HOSTNAME worker_ips: [WORKER_NODE_1_HOSTNAME, WORKER_NODE_2_HOSTNAME, ... ] + # Optional when running automatic cluster management on prem. If you use a coordinator server, + # then you can launch multiple autoscaling clusters on the same set of machines, and the coordinator + # will assign individual nodes to clusters as needed. + # coordinator_address: "HOST:PORT" # How Ray will authenticate with newly launched nodes. 
auth: diff --git a/python/ray/autoscaler/local/node_provider.py b/python/ray/autoscaler/local/node_provider.py index 12414b43b..46d47d663 100644 --- a/python/ray/autoscaler/local/node_provider.py +++ b/python/ray/autoscaler/local/node_provider.py @@ -7,8 +7,11 @@ import logging from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.local.config import bootstrap_local -from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, NODE_TYPE_WORKER, \ - NODE_TYPE_HEAD +from ray.autoscaler.tags import ( + TAG_RAY_NODE_TYPE, + NODE_TYPE_WORKER, + NODE_TYPE_HEAD, +) logger = logging.getLogger(__name__) @@ -45,8 +48,8 @@ class ClusterState: "state": "terminated", } else: - assert workers[worker_ip]["tags"][ - TAG_RAY_NODE_TYPE] == NODE_TYPE_WORKER + assert (workers[worker_ip]["tags"][TAG_RAY_NODE_TYPE] + == NODE_TYPE_WORKER) if provider_config["head_ip"] not in workers: workers[provider_config["head_ip"]] = { "tags": { @@ -55,8 +58,16 @@ class ClusterState: "state": "terminated", } else: - assert workers[provider_config["head_ip"]]["tags"][ - TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD + assert (workers[provider_config["head_ip"]]["tags"][ + TAG_RAY_NODE_TYPE] == NODE_TYPE_HEAD) + # Relevant when a user reduces the number of workers + # without changing the headnode. + list_of_node_ips = list(provider_config["worker_ips"]) + list_of_node_ips.append(provider_config["head_ip"]) + for worker_ip in list(workers): + if worker_ip not in list_of_node_ips: + del workers[worker_ip] + assert len(workers) == len(provider_config["worker_ips"]) + 1 with open(self.save_path, "w") as f: logger.debug("ClusterState: " @@ -83,17 +94,82 @@ class ClusterState: f.write(json.dumps(workers)) +class OnPremCoordinatorState(ClusterState): + """Generates & updates the state file of CoordinatorSenderNodeProvider. 
+ + Unlike ClusterState, which generates a cluster specific file with + predefined head and worker ips, OnPremCoordinatorState overwrites + ClusterState's __init__ function to generate and manage a unified + file of the status of all the nodes for multiple clusters. + """ + + def __init__(self, lock_path, save_path, list_of_node_ips): + self.lock = RLock() + self.file_lock = FileLock(lock_path) + self.save_path = save_path + + with self.lock: + with self.file_lock: + if os.path.exists(self.save_path): + nodes = json.loads(open(self.save_path).read()) + else: + nodes = {} + logger.info( + "OnPremCoordinatorState: " + "Loaded on prem coordinator state: {}".format(nodes)) + + # Filter removed node ips. + for node_ip in list(nodes): + if node_ip not in list_of_node_ips: + del nodes[node_ip] + + for node_ip in list_of_node_ips: + if node_ip not in nodes: + nodes[node_ip] = { + "tags": {}, + "state": "terminated", + } + assert len(nodes) == len(list_of_node_ips) + with open(self.save_path, "w") as f: + logger.info( + "OnPremCoordinatorState: " + "Writing on prem coordinator state: {}".format(nodes)) + f.write(json.dumps(nodes)) + + class LocalNodeProvider(NodeProvider): """NodeProvider for private/local clusters. `node_id` is overloaded to also be `node_ip` in this class. + + When `cluster_name` is provided, it manages a single cluster in a cluster + specific state file. But when `cluster_name` is None, it manages multiple + clusters in a unified state file that requires each node to be tagged with + TAG_RAY_CLUSTER_NAME in create and non_terminated_nodes function calls to + associate each node with the right cluster. + + The current use case of managing multiple clusters is by + OnPremCoordinatorServer which receives node provider HTTP requests + from CoordinatorSenderNodeProvider and uses LocalNodeProvider to get + the responses. 
""" def __init__(self, provider_config, cluster_name): NodeProvider.__init__(self, provider_config, cluster_name) - self.state = ClusterState("/tmp/cluster-{}.lock".format(cluster_name), - "/tmp/cluster-{}.state".format(cluster_name), - provider_config) + + if cluster_name: + self.state = ClusterState( + "/tmp/cluster-{}.lock".format(cluster_name), + "/tmp/cluster-{}.state".format(cluster_name), + provider_config, + ) + self.use_coordinator = False + else: + # LocalNodeProvider with a coordinator server. + self.state = OnPremCoordinatorState( + "/tmp/coordinator.lock", "/tmp/coordinator.state", + provider_config["list_of_node_ips"]) + self.use_coordinator = True def non_terminated_nodes(self, tag_filters): workers = self.state.get() @@ -132,16 +208,20 @@ class LocalNodeProvider(NodeProvider): self.state.put(node_id, info) def create_node(self, node_config, tags, count): + """Creates min(count, currently available) nodes.""" node_type = tags[TAG_RAY_NODE_TYPE] with self.state.file_lock: workers = self.state.get() for node_id, info in workers.items(): if (info["state"] == "terminated" - and info["tags"][TAG_RAY_NODE_TYPE] == node_type): + and (self.use_coordinator + or info["tags"][TAG_RAY_NODE_TYPE] == node_type)): info["tags"] = tags info["state"] = "running" self.state.put(node_id, info) - return + count = count - 1 + if count == 0: + return def terminate_node(self, node_id): workers = self.state.get() diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 0f35708fd..7d01bfa10 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -26,8 +26,13 @@ def import_azure(provider_config): def import_local(provider_config): - from ray.autoscaler.local.node_provider import LocalNodeProvider - return LocalNodeProvider + if "coordinator_address" in provider_config: + from ray.autoscaler.local.coordinator_node_provider import ( + CoordinatorSenderNodeProvider) + return 
CoordinatorSenderNodeProvider + else: + from ray.autoscaler.local.node_provider import LocalNodeProvider + return LocalNodeProvider def import_kubernetes(provider_config): diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 02d6c98e1..4de4f950e 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -190,6 +190,13 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_coordinator_server", + size = "small", + srcs = SRCS + ["test_coordinator_server.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) py_test( name = "test_autoscaler_aws", diff --git a/python/ray/tests/test_coordinator_server.py b/python/ray/tests/test_coordinator_server.py new file mode 100644 index 000000000..4b627bef8 --- /dev/null +++ b/python/ray/tests/test_coordinator_server.py @@ -0,0 +1,254 @@ +import os +import unittest +import socket +import json + +from ray.autoscaler.local.coordinator_server import OnPremCoordinatorServer +from ray.autoscaler.node_provider import NODE_PROVIDERS, get_node_provider +from ray.autoscaler.local.node_provider import LocalNodeProvider +from ray.autoscaler.local.coordinator_node_provider import ( + CoordinatorSenderNodeProvider) +from ray.autoscaler.tags import (TAG_RAY_NODE_TYPE, TAG_RAY_CLUSTER_NAME, + TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, + NODE_TYPE_HEAD) +import pytest + + +class OnPremCoordinatorServerTest(unittest.TestCase): + def setUp(self): + self.list_of_node_ips = ["0.0.0.0:1", "0.0.0.0:2"] + self.host, self.port = socket.gethostbyname(socket.gethostname()), 1234 + self.server = OnPremCoordinatorServer( + list_of_node_ips=self.list_of_node_ips, + host=self.host, + port=self.port, + ) + self.coordinator_address = self.host + ":" + str(self.port) + + def tearDown(self): + self.server.shutdown() + state_save_path = "/tmp/coordinator.state" + if os.path.exists(state_save_path): + os.remove(state_save_path) + + def testImportingCorrectClass(self): + """Check correct import when coordinator_address is in config yaml.""" 
+ + provider_config = {"coordinator_address": "fake_address:1234"} + coordinator_node_provider = NODE_PROVIDERS.get("local")( + provider_config) + assert coordinator_node_provider is CoordinatorSenderNodeProvider + local_node_provider = NODE_PROVIDERS.get("local")({}) + assert local_node_provider is LocalNodeProvider + + def testClusterStateInit(self): + """Check ClusterState __init__ func generates correct state file. + + Test the general use case and if num_workers increase/decrease. + """ + + cluster_config = { + "cluster_name": "random_name", + "min_workers": 0, + "max_workers": 0, + "initial_workers": 0, + "provider": { + "type": "local", + "head_ip": "0.0.0.0:2", + "worker_ips": ["0.0.0.0:1"] + }, + } + provider_config = cluster_config["provider"] + node_provider = get_node_provider(provider_config, + cluster_config["cluster_name"]) + assert isinstance(node_provider, LocalNodeProvider) + expected_workers = {} + expected_workers[provider_config["head_ip"]] = { + "tags": { + TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD + }, + "state": "terminated", + } + expected_workers[provider_config["worker_ips"][0]] = { + "tags": { + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER + }, + "state": "terminated", + } + + state_save_path = "/tmp/cluster-{}.state".format( + cluster_config["cluster_name"]) + assert os.path.exists(state_save_path) + workers = json.loads(open(state_save_path).read()) + assert workers == expected_workers + + # Test removing workers updates the cluster state. + del expected_workers[provider_config["worker_ips"][0]] + removed_ip = provider_config["worker_ips"].pop() + node_provider = get_node_provider(provider_config, + cluster_config["cluster_name"]) + workers = json.loads(open(state_save_path).read()) + assert workers == expected_workers + + # Test adding back workers updates the cluster state. 
+ expected_workers[removed_ip] = { + "tags": { + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER + }, + "state": "terminated", + } + provider_config["worker_ips"].append(removed_ip) + node_provider = get_node_provider(provider_config, + cluster_config["cluster_name"]) + workers = json.loads(open(state_save_path).read()) + assert workers == expected_workers + + def testOnPremCoordinatorStateInit(self): + """If OnPremCoordinatorState __init__ generates correct state file. + + Test the general use case and if the coordinator server crashes or + updates the list of node ips with more/less nodes. + """ + + expected_nodes = {} + for ip in self.list_of_node_ips: + expected_nodes[ip] = { + "tags": {}, + "state": "terminated", + } + + state_save_path = "/tmp/coordinator.state" + assert os.path.exists(state_save_path) + nodes = json.loads(open(state_save_path).read()) + assert nodes == expected_nodes + + # Test removing workers updates the cluster state. + del expected_nodes[self.list_of_node_ips[1]] + self.server.shutdown() + self.server = OnPremCoordinatorServer( + list_of_node_ips=self.list_of_node_ips[0:1], + host=self.host, + port=self.port, + ) + nodes = json.loads(open(state_save_path).read()) + assert nodes == expected_nodes + + # Test adding back workers updates the cluster state. 
+ expected_nodes[self.list_of_node_ips[1]] = { + "tags": {}, + "state": "terminated", + } + self.server.shutdown() + self.server = OnPremCoordinatorServer( + list_of_node_ips=self.list_of_node_ips, + host=self.host, + port=self.port, + ) + nodes = json.loads(open(state_save_path).read()) + assert nodes == expected_nodes + + def testCoordinatorSenderNodeProvider(self): + """Integration test of CoordinatorSenderNodeProvider.""" + cluster_config = { + "cluster_name": "random_name", + "min_workers": 0, + "max_workers": 0, + "initial_workers": 0, + "provider": { + "type": "local", + "coordinator_address": self.coordinator_address, + }, + "head_node": {}, + "worker_nodes": {}, + } + provider_config = cluster_config["provider"] + node_provider_1 = get_node_provider(provider_config, + cluster_config["cluster_name"]) + assert isinstance(node_provider_1, CoordinatorSenderNodeProvider) + + assert not node_provider_1.non_terminated_nodes({}) + assert not node_provider_1.is_running(self.list_of_node_ips[0]) + assert node_provider_1.is_terminated(self.list_of_node_ips[0]) + assert not node_provider_1.node_tags(self.list_of_node_ips[0]) + head_node_tags = { + TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD, + } + assert not node_provider_1.non_terminated_nodes(head_node_tags) + head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format( + cluster_config["cluster_name"]) + node_provider_1.create_node(cluster_config["head_node"], + head_node_tags, 1) + assert node_provider_1.non_terminated_nodes( + {}) == [self.list_of_node_ips[0]] + head_node_tags[TAG_RAY_CLUSTER_NAME] = cluster_config["cluster_name"] + assert node_provider_1.node_tags( + self.list_of_node_ips[0]) == head_node_tags + assert node_provider_1.is_running(self.list_of_node_ips[0]) + assert not node_provider_1.is_terminated(self.list_of_node_ips[0]) + + # Add another cluster. 
+ cluster_config["cluster_name"] = "random_name_2" + provider_config = cluster_config["provider"] + node_provider_2 = get_node_provider(provider_config, + cluster_config["cluster_name"]) + assert not node_provider_2.non_terminated_nodes({}) + assert not node_provider_2.is_running(self.list_of_node_ips[1]) + assert node_provider_2.is_terminated(self.list_of_node_ips[1]) + assert not node_provider_2.node_tags(self.list_of_node_ips[1]) + assert not node_provider_2.non_terminated_nodes(head_node_tags) + head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format( + cluster_config["cluster_name"]) + node_provider_2.create_node(cluster_config["head_node"], + head_node_tags, 1) + assert node_provider_2.non_terminated_nodes( + {}) == [self.list_of_node_ips[1]] + head_node_tags[TAG_RAY_CLUSTER_NAME] = cluster_config["cluster_name"] + assert node_provider_2.node_tags( + self.list_of_node_ips[1]) == head_node_tags + assert node_provider_2.is_running(self.list_of_node_ips[1]) + assert not node_provider_2.is_terminated(self.list_of_node_ips[1]) + + # Add another cluster (should fail because we only have two nodes). + cluster_config["cluster_name"] = "random_name_3" + provider_config = cluster_config["provider"] + node_provider_3 = get_node_provider(provider_config, + cluster_config["cluster_name"]) + assert not node_provider_3.non_terminated_nodes(head_node_tags) + head_node_tags[TAG_RAY_NODE_NAME] = "ray-{}-head".format( + cluster_config["cluster_name"]) + node_provider_3.create_node(cluster_config["head_node"], + head_node_tags, 1) + assert not node_provider_3.non_terminated_nodes({}) + + # Terminate all nodes. + node_provider_1.terminate_node(self.list_of_node_ips[0]) + assert not node_provider_1.non_terminated_nodes({}) + node_provider_2.terminate_node(self.list_of_node_ips[1]) + assert not node_provider_2.non_terminated_nodes({}) + + # Check if now we can create more clusters/nodes. 
+ node_provider_3.create_node(cluster_config["head_node"], + head_node_tags, 1) + worker_node_tags = { + TAG_RAY_NODE_NAME: "ray-{}-worker".format( + cluster_config["cluster_name"]), + TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER + } + node_provider_3.create_node(cluster_config["worker_nodes"], + worker_node_tags, 1) + assert node_provider_3.non_terminated_nodes( + {}) == self.list_of_node_ips + worker_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER} + assert node_provider_3.non_terminated_nodes(worker_filter) == [ + self.list_of_node_ips[1] + ] + head_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_HEAD} + assert node_provider_3.non_terminated_nodes(head_filter) == [ + self.list_of_node_ips[0] + ] + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__]))