mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 15:55:01 +08:00
[Autoscaler] Pass custom resources to "ray start" multi instance autoscaling (#9986)
This commit is contained in:
@@ -12,10 +12,10 @@ import yaml
|
||||
from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
_internal_kv_initialized
|
||||
from ray.autoscaler.node_provider import get_node_provider
|
||||
from ray.autoscaler.tags import (TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_FILE_MOUNTS_CONTENTS,
|
||||
TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
|
||||
STATUS_UP_TO_DATE, NODE_TYPE_WORKER)
|
||||
from ray.autoscaler.tags import (
|
||||
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_TYPE,
|
||||
TAG_RAY_INSTANCE_TYPE, STATUS_UP_TO_DATE, NODE_TYPE_WORKER)
|
||||
from ray.autoscaler.updater import NodeUpdaterThread
|
||||
from ray.autoscaler.node_launcher import NodeLauncher
|
||||
from ray.autoscaler.resource_demand_scheduler import ResourceDemandScheduler
|
||||
@@ -95,7 +95,9 @@ class StandardAutoscaler:
|
||||
provider=self.provider,
|
||||
queue=self.launch_queue,
|
||||
index=i,
|
||||
pending=self.pending_launches)
|
||||
pending=self.pending_launches,
|
||||
instance_types=self.instance_types,
|
||||
)
|
||||
node_launcher.daemon = True
|
||||
node_launcher.start()
|
||||
|
||||
@@ -243,10 +245,11 @@ class StandardAutoscaler:
|
||||
for node_id, commands, ray_start in (self.should_update(node_id)
|
||||
for node_id in nodes):
|
||||
if node_id is not None:
|
||||
resources = self._node_resources(node_id)
|
||||
T.append(
|
||||
threading.Thread(
|
||||
target=self.spawn_updater,
|
||||
args=(node_id, commands, ray_start)))
|
||||
args=(node_id, commands, ray_start, resources)))
|
||||
for t in T:
|
||||
t.start()
|
||||
for t in T:
|
||||
@@ -256,6 +259,14 @@ class StandardAutoscaler:
|
||||
for node_id in nodes:
|
||||
self.recover_if_needed(node_id, now)
|
||||
|
||||
def _node_resources(self, node_id):
    """Return the resources configured for a node's instance type.

    The instance type is read from the node's ``TAG_RAY_INSTANCE_TYPE``
    tag and looked up in ``self.instance_types``.

    Args:
        node_id: Identifier of the node, as understood by
            ``self.provider``.

    Returns:
        dict: The ``"resources"`` entry of the matching instance type, or
            an empty dict when the node carries no instance-type tag, the
            type is not present in ``self.instance_types``, or the type
            declares no resources.
    """
    instance_type = self.provider.node_tags(node_id).get(
        TAG_RAY_INSTANCE_TYPE)
    if not instance_type:
        return {}
    # A node may be tagged with a type that is absent from the current
    # mapping (or the mapping may be unset); treat that as "no resources"
    # rather than raising KeyError/TypeError.
    # NOTE(review): presumably this is preferable to failing the caller --
    # confirm that silently skipping resources is acceptable here.
    type_config = (self.instance_types or {}).get(instance_type) or {}
    return type_config.get("resources", {})
|
||||
|
||||
def reload_config(self, errors_fatal=False):
|
||||
sync_continuously = False
|
||||
if hasattr(self, "config"):
|
||||
@@ -396,7 +407,8 @@ class StandardAutoscaler:
|
||||
|
||||
return (node_id, init_commands, ray_commands)
|
||||
|
||||
def spawn_updater(self, node_id, init_commands, ray_start_commands):
|
||||
def spawn_updater(self, node_id, init_commands, ray_start_commands,
|
||||
node_resources):
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=node_id,
|
||||
provider_config=self.config["provider"],
|
||||
@@ -413,7 +425,8 @@ class StandardAutoscaler:
|
||||
cluster_synced_files=self.config["cluster_synced_files"],
|
||||
process_runner=self.process_runner,
|
||||
use_internal_ip=True,
|
||||
docker_config=self.config.get("docker"))
|
||||
docker_config=self.config.get("docker"),
|
||||
node_resources=node_resources)
|
||||
updater.start()
|
||||
self.updaters[node_id] = updater
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@ available_instance_types:
|
||||
resources: {"CPU": 4}
|
||||
max_workers: 10
|
||||
m4.4xlarge:
|
||||
resources: {"CPU": 16}
|
||||
resources: {"CPU": 16, "Custom1": 1}
|
||||
max_workers: 10
|
||||
p2.xlarge:
|
||||
resources: {"CPU": 4, "GPU": 1}
|
||||
resources: {"CPU": 4, "GPU": 1, "Custom2": 2}
|
||||
max_workers: 4
|
||||
p2.8xlarge:
|
||||
resources: {"CPU": 32, "GPU": 8}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from getpass import getuser
|
||||
from shlex import quote
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Dict
|
||||
import click
|
||||
import hashlib
|
||||
import logging
|
||||
@@ -66,6 +66,35 @@ def set_using_login_shells(val):
|
||||
_config["use_login_shells"] = val
|
||||
|
||||
|
||||
def _with_environment_variables(cmd: str,
                                environment_variables: Dict[str, object]):
    """Prepend environment variables to a shell command.

    Each variable becomes an ``export KEY=VALUE;`` statement placed directly
    in front of ``cmd``. Values are shell-quoted with ``shlex.quote``.

    Args:
        cmd (str): The base command.
        environment_variables (Dict[str, object]): The set of environment
            variables. If an environment variable value is a dict, it will
            automatically be converted to a one line yaml string. Any other
            non-string value (e.g. an int) is converted with ``str()``.
            A falsy value (``None`` or an empty dict) leaves ``cmd``
            unchanged.

    Returns:
        str: ``cmd`` prefixed with one ``export`` statement per variable,
            in the iteration order of ``environment_variables``.
    """
    if not environment_variables:
        return cmd

    def dict_as_one_line_yaml(d):
        # Keys and values are quoted individually; the assembled mapping is
        # quoted once more by the caller.
        # NOTE(review): the result is flow-style YAML, not JSON. Confirm
        # that whatever reads the variable parses it accordingly.
        items = [
            "{}: {}".format(quote(str(key)), quote(str(val)))
            for key, val in d.items()
        ]
        return "{" + ",".join(items) + "}"

    as_strings = []
    for key, val in environment_variables.items():
        if isinstance(val, dict):
            val = dict_as_one_line_yaml(val)
        # shlex.quote only accepts strings, so coerce first; otherwise an
        # int value would raise TypeError.
        as_strings.append("export {}={};".format(key, quote(str(val))))
    return "".join(as_strings) + cmd
|
||||
|
||||
|
||||
def _with_interactive(cmd):
|
||||
force_interactive = ("true && source ~/.bashrc && "
|
||||
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
|
||||
@@ -85,6 +114,7 @@ class CommandRunnerInterface:
|
||||
exit_on_fail: bool = False,
|
||||
port_forward: List[Tuple[int, int]] = None,
|
||||
with_output: bool = False,
|
||||
environment_variables: Dict[str, object] = None,
|
||||
run_env: str = "auto",
|
||||
ssh_options_override_ssh_key: str = "",
|
||||
) -> str:
|
||||
@@ -100,6 +130,8 @@ class CommandRunnerInterface:
|
||||
port_forward (list): List of (local, remote) ports to forward, or
|
||||
a single tuple.
|
||||
with_output (bool): Whether to return output.
|
||||
environment_variables (Dict[str, str | int | Dict[str, str]]):
|
||||
Environment variables that `cmd` should be run with.
|
||||
run_env (str): Options: docker/host/auto. Used in
|
||||
DockerCommandRunner to determine the run environment.
|
||||
ssh_options_override_ssh_key (str): if provided, overwrites
|
||||
@@ -147,6 +179,7 @@ class KubernetesCommandRunner(CommandRunnerInterface):
|
||||
exit_on_fail=False,
|
||||
port_forward=None,
|
||||
with_output=False,
|
||||
environment_variables: Dict[str, object] = None,
|
||||
run_env="auto", # Unused argument.
|
||||
ssh_options_override_ssh_key="", # Unused argument.
|
||||
):
|
||||
@@ -180,6 +213,9 @@ class KubernetesCommandRunner(CommandRunnerInterface):
|
||||
self.node_id,
|
||||
"--",
|
||||
]
|
||||
cmd = _with_interactive(cmd)
|
||||
if environment_variables:
|
||||
cmd = _with_environment_variables(cmd, environment_variables)
|
||||
final_cmd += _with_interactive(cmd)
|
||||
logger.info(self.log_prefix + "Running {}".format(final_cmd))
|
||||
try:
|
||||
@@ -434,6 +470,7 @@ class SSHCommandRunner(CommandRunnerInterface):
|
||||
exit_on_fail=False,
|
||||
port_forward=None,
|
||||
with_output=False,
|
||||
environment_variables: Dict[str, object] = None,
|
||||
run_env="auto", # Unused argument.
|
||||
ssh_options_override_ssh_key="",
|
||||
):
|
||||
@@ -472,6 +509,8 @@ class SSHCommandRunner(CommandRunnerInterface):
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip)
|
||||
]
|
||||
if cmd:
|
||||
if environment_variables:
|
||||
cmd = _with_environment_variables(cmd, environment_variables)
|
||||
if is_using_login_shells():
|
||||
final_cmd += _with_interactive(cmd)
|
||||
else:
|
||||
@@ -544,12 +583,16 @@ class DockerCommandRunner(SSHCommandRunner):
|
||||
exit_on_fail=False,
|
||||
port_forward=None,
|
||||
with_output=False,
|
||||
environment_variables: Dict[str, object] = None,
|
||||
run_env="auto",
|
||||
ssh_options_override_ssh_key="",
|
||||
):
|
||||
if run_env == "auto":
|
||||
run_env = "host" if cmd.find("docker") == 0 else "docker"
|
||||
|
||||
if environment_variables:
|
||||
cmd = _with_environment_variables(cmd, environment_variables)
|
||||
|
||||
if run_env == "docker":
|
||||
cmd = self._docker_expand_user(cmd, any_char=True)
|
||||
cmd = " ".join(_with_interactive(cmd))
|
||||
|
||||
@@ -13,10 +13,18 @@ logger = logging.getLogger(__name__)
|
||||
class NodeLauncher(threading.Thread):
|
||||
"""Launches nodes asynchronously in the background."""
|
||||
|
||||
def __init__(self, provider, queue, pending, index=None, *args, **kwargs):
|
||||
def __init__(self,
|
||||
provider,
|
||||
queue,
|
||||
pending,
|
||||
instance_types=None,
|
||||
index=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
self.queue = queue
|
||||
self.pending = pending
|
||||
self.provider = provider
|
||||
self.instance_types = instance_types
|
||||
self.index = str(index) if index is not None else ""
|
||||
super(NodeLauncher, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from ray.autoscaler.log_timer import LogTimer
|
||||
import ray.autoscaler.subprocess_output_util as cmd_output_util
|
||||
|
||||
from ray.autoscaler.cli_logger import cli_logger
|
||||
from ray import ray_constants
|
||||
import colorful as cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,6 +39,7 @@ class NodeUpdater:
|
||||
ray_start_commands,
|
||||
runtime_hash,
|
||||
file_mounts_contents_hash,
|
||||
node_resources=None,
|
||||
cluster_synced_files=None,
|
||||
process_runner=subprocess,
|
||||
use_internal_ip=False,
|
||||
@@ -61,6 +63,7 @@ class NodeUpdater:
|
||||
self.initialization_commands = initialization_commands
|
||||
self.setup_commands = setup_commands
|
||||
self.ray_start_commands = ray_start_commands
|
||||
self.node_resources = node_resources
|
||||
self.runtime_hash = runtime_hash
|
||||
self.file_mounts_contents_hash = file_mounts_contents_hash
|
||||
self.cluster_synced_files = cluster_synced_files
|
||||
@@ -341,9 +344,17 @@ class NodeUpdater:
|
||||
with LogTimer(
|
||||
self.log_prefix + "Ray start commands", show_status=True):
|
||||
for cmd in self.ray_start_commands:
|
||||
if self.node_resources:
|
||||
env_vars = {
|
||||
ray_constants.RESOURCES_ENVIRONMENT_VARIABLE: self.
|
||||
node_resources
|
||||
}
|
||||
else:
|
||||
env_vars = {}
|
||||
try:
|
||||
cmd_output_util.set_output_redirected(False)
|
||||
self.cmd_runner.run(cmd)
|
||||
cmd_output_util.set_output_redirected(True)
|
||||
self.cmd_runner.run(
|
||||
cmd, environment_variables=env_vars)
|
||||
cmd_output_util.set_output_redirected(True)
|
||||
except ProcessRunnerError as e:
|
||||
if e.msg_type == "ssh_command_failed":
|
||||
|
||||
+1
-1
@@ -272,7 +272,7 @@ class Node:
|
||||
return result
|
||||
|
||||
env_resources = {}
|
||||
env_string = os.getenv("RAY_OVERRIDE_RESOURCES")
|
||||
env_string = os.getenv(ray_constants.RESOURCES_ENVIRONMENT_VARIABLE)
|
||||
if env_string:
|
||||
env_resources = json.loads(env_string)
|
||||
|
||||
|
||||
@@ -131,6 +131,8 @@ RAYLET_CONNECTION_ERROR = "raylet_connection_error"
|
||||
# Used in gpu detection
|
||||
RESOURCE_CONSTRAINT_PREFIX = "GPUType:"
|
||||
|
||||
RESOURCES_ENVIRONMENT_VARIABLE = "RAY_OVERRIDE_RESOURCES"
|
||||
|
||||
# Abort autoscaling if more than this number of errors are encountered. This
|
||||
# is a safety feature to prevent e.g. runaway node launches.
|
||||
AUTOSCALER_MAX_NUM_FAILURES = env_integer("AUTOSCALER_MAX_NUM_FAILURES", 5)
|
||||
|
||||
@@ -64,6 +64,7 @@ py_test_module_list(
|
||||
"test_autoscaler.py",
|
||||
"test_autoscaler_yaml.py",
|
||||
"test_component_failures.py",
|
||||
"test_command_runner.py",
|
||||
"test_coordinator_server.py",
|
||||
"test_dask_scheduler.py",
|
||||
"test_debug_tools.py",
|
||||
|
||||
@@ -51,18 +51,31 @@ class MockProcessRunner:
|
||||
self.check_call(cmd)
|
||||
return "command-output".encode()
|
||||
|
||||
def assert_has_call(self, ip, pattern):
|
||||
def assert_has_call(self, ip, pattern=None, exact=None):
|
||||
assert pattern or exact, \
|
||||
"Must specify either a pattern or exact match."
|
||||
out = ""
|
||||
for cmd in self.calls:
|
||||
msg = " ".join(cmd)
|
||||
if ip in msg:
|
||||
out += msg
|
||||
out += "\n"
|
||||
if pattern in out:
|
||||
return True
|
||||
if pattern is not None:
|
||||
for cmd in self.calls:
|
||||
msg = " ".join(cmd)
|
||||
if ip in msg:
|
||||
out += msg
|
||||
out += "\n"
|
||||
if pattern in out:
|
||||
return True
|
||||
else:
|
||||
raise Exception("Did not find [{}] in [{}] for {}".format(
|
||||
pattern, out, ip))
|
||||
else:
|
||||
raise Exception("Did not find [{}] in [{}] for {}".format(
|
||||
pattern, out, ip))
|
||||
for cmd in self.calls:
|
||||
msg = " ".join(cmd)
|
||||
if ip in msg:
|
||||
out += msg
|
||||
out += "\n"
|
||||
if cmd == exact:
|
||||
return True
|
||||
raise Exception("Did not find {} in {} for {}".format(
|
||||
exact, out, ip))
|
||||
|
||||
def assert_not_has_call(self, ip, pattern):
|
||||
out = ""
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
import pytest
|
||||
from ray.tests.test_autoscaler import MockProvider, MockProcessRunner
|
||||
from ray.autoscaler.command_runner import SSHCommandRunner, \
|
||||
_with_environment_variables, DockerCommandRunner
|
||||
from getpass import getuser
|
||||
import hashlib
|
||||
|
||||
auth_config = {
|
||||
"ssh_user": "ray",
|
||||
"ssh_private_key": "8265.pem",
|
||||
}
|
||||
|
||||
|
||||
def test_environment_variable_encoder_strings():
    """String values are exported, with quoting applied where needed."""
    rendered = _with_environment_variables(
        "echo hello",
        {"var1": "quote between this \" and this", "var2": "123"},
    )
    # var1 contains spaces and a double quote, so it is single-quoted;
    # var2 is emitted bare.
    assert rendered == """export var1='quote between this " and this';export var2=123;echo hello"""  # noqa: E501
|
||||
|
||||
|
||||
def test_environment_variable_encoder_dict():
    """A dict value is rendered as a quoted one-line mapping."""
    rendered = _with_environment_variables(
        "echo hello",
        {"value1": "string1", "value2": {"a": "b", "c": 2}},
    )
    assert rendered == """export value1=string1;export value2='{a: b,c: 2}';echo hello"""  # noqa: E501
|
||||
|
||||
|
||||
def test_ssh_command_runner():
    """SSHCommandRunner.run emits the expected ssh argv, env vars included."""
    process_runner = MockProcessRunner()
    provider = MockProvider()
    provider.create_node({}, {}, 1)

    cluster_name = "cluster"
    # The control path is derived from truncated md5 digests of the local
    # user name and the cluster name.
    user_digest = hashlib.md5(getuser().encode()).hexdigest()[:10]
    cluster_digest = hashlib.md5(cluster_name.encode()).hexdigest()[:10]
    ssh_control_path = "/tmp/ray_ssh_{}/{}".format(user_digest,
                                                   cluster_digest)

    cmd_runner = SSHCommandRunner(
        log_prefix="prefix",
        node_id=0,
        provider=provider,
        auth_config=auth_config,
        cluster_name=cluster_name,
        process_runner=process_runner,
        use_internal_ip=False,
    )

    cmd_runner.run(
        "echo helloo",
        port_forward=[(8265, 8265)],
        environment_variables={
            "var1": "quote between this \" and this",
            "var2": "123",
        })

    ssh_options = [
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
        "IdentitiesOnly=yes",
        "ExitOnForwardFailure=yes",
        "ServerAliveInterval=5",
        "ServerAliveCountMax=3",
        "ControlMaster=auto",
        "ControlPath={}/%C".format(ssh_control_path),
        "ControlPersist=10s",
        "ConnectTimeout=120s",
    ]
    expected = ["ssh", "-tt", "-L", "8265:localhost:8265", "-i", "8265.pem"]
    for option in ssh_options:
        expected += ["-o", option]
    expected += [
        "ray@1.2.3.4",
        "bash",
        "--login",
        "-c",
        "-i",
        """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && export var1='"'"'quote between this " and this'"'"';export var2=123;echo helloo'"""  # noqa: E501
    ]

    # Compare element by element first: a mismatch here is far easier to
    # read than a failure of the whole-list check below.
    for actual_arg, expected_arg in zip(process_runner.calls[0], expected):
        assert actual_arg == expected_arg
    process_runner.assert_has_call("1.2.3.4", exact=expected)
|
||||
|
||||
|
||||
def test_docker_command_runner():
    """DockerCommandRunner.run wraps the command in `docker exec` over ssh."""
    process_runner = MockProcessRunner()
    provider = MockProvider()
    provider.create_node({}, {}, 1)

    cluster_name = "cluster"
    # The control path is derived from truncated md5 digests of the local
    # user name and the cluster name.
    user_digest = hashlib.md5(getuser().encode()).hexdigest()[:10]
    cluster_digest = hashlib.md5(cluster_name.encode()).hexdigest()[:10]
    ssh_control_path = "/tmp/ray_ssh_{}/{}".format(user_digest,
                                                   cluster_digest)

    cmd_runner = DockerCommandRunner(
        log_prefix="prefix",
        node_id=0,
        provider=provider,
        auth_config=auth_config,
        cluster_name=cluster_name,
        process_runner=process_runner,
        use_internal_ip=False,
        docker_config={"container_name": "container"},
    )
    # Constructing the runner already issues a docker check; verify it and
    # reset the recorded calls so calls[0] below is the command under test.
    process_runner.assert_has_call("1.2.3.4", "command -v docker")
    process_runner.clear_history()

    cmd_runner.run(
        "echo hello",
        environment_variables={
            "var1": "quote between this \" and this",
            "var2": "123",
        })

    # The expected string nests several layers of shell quoting. It looks
    # absurd, but the exact escaping of environment variables is the
    # behavior being pinned here.
    cmd = """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && docker exec -it container /bin/bash -c '"'"'bash --login -c -i '"'"'"'"'"'"'"'"'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && export var1='"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'quote between this " and this'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"';export var2=123;echo hello'"'"'"'"'"'"'"'"''"'"' '"""  # noqa: E501

    ssh_options = [
        "StrictHostKeyChecking=no",
        "UserKnownHostsFile=/dev/null",
        "IdentitiesOnly=yes",
        "ExitOnForwardFailure=yes",
        "ServerAliveInterval=5",
        "ServerAliveCountMax=3",
        "ControlMaster=auto",
        "ControlPath={}/%C".format(ssh_control_path),
        "ControlPersist=10s",
        "ConnectTimeout=120s",
    ]
    expected = ["ssh", "-tt", "-i", "8265.pem"]
    for option in ssh_options:
        expected += ["-o", option]
    expected += ["ray@1.2.3.4", "bash", "--login", "-c", "-i", cmd]

    # Compare element by element first: a mismatch here is far easier to
    # read than a failure of the whole-list check below.
    for actual_arg, expected_arg in zip(process_runner.calls[0], expected):
        assert actual_arg == expected_arg
    process_runner.assert_has_call("1.2.3.4", exact=expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -14,6 +14,8 @@ from ray.autoscaler.node_provider import NODE_PROVIDERS
|
||||
from ray.autoscaler.resource_demand_scheduler import _utilization_score, \
|
||||
get_bin_pack_residual, get_instances_for
|
||||
|
||||
from time import sleep
|
||||
|
||||
TYPES_A = {
|
||||
"m4.large": {
|
||||
"resources": {
|
||||
@@ -226,6 +228,46 @@ class AutoscalingTest(unittest.TestCase):
|
||||
assert self.provider.mock_nodes[2].instance_type == "m4.16xlarge"
|
||||
assert self.provider.mock_nodes[3].instance_type == "m4.16xlarge"
|
||||
|
||||
def testResourcePassing(self):
    """Instance-type resources show up in the commands sent to each node."""
    config = MULTI_WORKER_CLUSTER.copy()
    config.update(min_workers=0, max_workers=50)
    config_path = self.write_config(config)
    self.provider = MockProvider(default_instance_type="m4.large")
    runner = MockProcessRunner()
    autoscaler = StandardAutoscaler(
        config_path,
        LoadMetrics(),
        max_failures=0,
        process_runner=runner,
        update_interval_s=0)
    assert len(self.provider.non_terminated_nodes({})) == 0
    autoscaler.update()
    self.waitForNodes(0)

    # Each request should add one node of the listed instance type.
    scale_ups = [
        ({"CPU": 1}, 1, "m4.large"),
        ({"GPU": 8}, 2, "p2.8xlarge"),
    ]
    for demand, node_count, instance_type in scale_ups:
        autoscaler.request_resources([demand])
        autoscaler.update()
        self.waitForNodes(node_count)
        newest = self.provider.mock_nodes[node_count - 1]
        assert newest.instance_type == instance_type

    # TODO (Alex): Autoscaler creates the node during one update then
    # starts the updater in the next update. The sleep is largely
    # unavoidable because the updater runs in its own thread and we have no
    # good way of ensuring that the commands are sent in time.
    autoscaler.update()
    sleep(0.1)

    # Each fragment is checked on its own because the order in which the
    # resource dict is serialized is not guaranteed.
    expected_fragments = [
        ("172.0.0.0", "RAY_OVERRIDE_RESOURCES="),
        ("172.0.0.0", "CPU: 2"),
        ("172.0.0.1", "RAY_OVERRIDE_RESOURCES="),
        ("172.0.0.1", "CPU: 32"),
        ("172.0.0.1", "GPU: 8"),
    ]
    for ip, fragment in expected_fragments:
        runner.assert_has_call(ip, fragment)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
Reference in New Issue
Block a user