mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:19:38 +08:00
[autoscaler] Instance type assert fix for non-AWS node providers (#9223)
This commit is contained in:
@@ -206,7 +206,7 @@ class StandardAutoscaler:
|
||||
self.max_concurrent_launches - num_pending)
|
||||
|
||||
num_launches = min(max_allowed, target_workers - num_workers)
|
||||
self.launch_new_node(num_launches, instance_type=None)
|
||||
self.launch_new_node(num_launches)
|
||||
nodes = self.workers()
|
||||
self.log_info_string(nodes, target_workers)
|
||||
elif self.load_metrics.num_workers_connected() >= target_workers:
|
||||
@@ -400,14 +400,13 @@ class StandardAutoscaler:
|
||||
return False
|
||||
return True
|
||||
|
||||
def launch_new_node(self, count, instance_type):
|
||||
def launch_new_node(self, count, instance_type=None):
|
||||
logger.info(
|
||||
"StandardAutoscaler: Queue {} new nodes for launch".format(count))
|
||||
# Try to fill in the default instance type so we can tag it properly.
|
||||
if not instance_type:
|
||||
instance_type = self.provider.get_instance_type(
|
||||
self.config["worker_nodes"])
|
||||
assert instance_type is not None
|
||||
self.pending_launches.inc(instance_type, count)
|
||||
config = copy.deepcopy(self.config)
|
||||
self.launch_queue.put((config, count, instance_type))
|
||||
|
||||
@@ -78,7 +78,7 @@ class MockProcessRunner:
|
||||
|
||||
|
||||
class MockProvider(NodeProvider):
|
||||
def __init__(self, cache_stopped=False):
|
||||
def __init__(self, cache_stopped=False, default_instance_type=None):
|
||||
self.mock_nodes = {}
|
||||
self.next_id = 0
|
||||
self.throw = False
|
||||
@@ -86,6 +86,7 @@ class MockProvider(NodeProvider):
|
||||
self.ready_to_create = threading.Event()
|
||||
self.ready_to_create.set()
|
||||
self.cache_stopped = cache_stopped
|
||||
self.default_instance_type = default_instance_type
|
||||
|
||||
def non_terminated_nodes(self, tag_filters):
|
||||
if self.throw:
|
||||
@@ -140,7 +141,7 @@ class MockProvider(NodeProvider):
|
||||
node_config, tags, count, instance_type=instance_type)
|
||||
|
||||
def get_instance_type(self, node_config):
|
||||
return "m4.large"
|
||||
return self.default_instance_type
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
self.mock_nodes[node_id].tags.update(tags)
|
||||
|
||||
@@ -181,7 +181,7 @@ class AutoscalingTest(unittest.TestCase):
|
||||
|
||||
def testScaleUpMinSanity(self):
|
||||
config_path = self.write_config(MULTI_WORKER_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
self.provider = MockProvider(default_instance_type="m4.large")
|
||||
runner = MockProcessRunner()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path,
|
||||
@@ -200,7 +200,7 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config["min_workers"] = 0
|
||||
config["max_workers"] = 50
|
||||
config_path = self.write_config(config)
|
||||
self.provider = MockProvider()
|
||||
self.provider = MockProvider(default_instance_type="m4.large")
|
||||
runner = MockProcessRunner()
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path,
|
||||
|
||||
Reference in New Issue
Block a user