diff --git a/python/ray/autoscaler/_private/util.py b/python/ray/autoscaler/_private/util.py index 2bd1e13e9..32758dec6 100644 --- a/python/ray/autoscaler/_private/util.py +++ b/python/ray/autoscaler/_private/util.py @@ -86,6 +86,14 @@ def validate_config(config: Dict[str, Any]) -> None: raise ValueError( "`head_node_type` must be one of `available_node_types`.") + sum_min_workers = sum( + config["available_node_types"][node_type].get("min_workers", 0) + for node_type in config["available_node_types"]) + if sum_min_workers > config["max_workers"]: + raise ValueError( + "The specified global `max_workers` is smaller than the " + "sum of `min_workers` of all the available node types.") + def prepare_config(config): with_defaults = fillout_defaults(config) diff --git a/python/ray/tests/test_autoscaler_yaml.py b/python/ray/tests/test_autoscaler_yaml.py index b712c8955..10edbb8fe 100644 --- a/python/ray/tests/test_autoscaler_yaml.py +++ b/python/ray/tests/test_autoscaler_yaml.py @@ -46,8 +46,34 @@ class AutoscalingConfigTest(unittest.TestCase): self.fail("Config did not pass validation test!") @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="TODO(ameer): fails on Windows.") + sys.platform.startswith("win"), reason="Fails on Windows.") + def testValidateDefaultConfigMinMaxWorkers(self): + aws_config_path = os.path.join( + RAY_PATH, "autoscaler/aws/example-multi-node-type.yaml") + with open(aws_config_path) as f: + config = yaml.safe_load(f) + config = prepare_config(config) + for node_type in config["available_node_types"]: + config["available_node_types"][node_type]["resources"] = config[ + "available_node_types"][node_type].get("resources", {}) + try: + validate_config(config) + except Exception: + self.fail("Config did not pass validation test!") + + config["max_workers"] = 0 # the sum of min_workers is 1. + with pytest.raises(ValueError): + validate_config(config) + + # make sure edge case of exactly 1 passes too. + config["max_workers"] = 1 + try: + validate_config(config) + except Exception: + self.fail("Config did not pass validation test!") + + @pytest.mark.skipif( + sys.platform.startswith("win"), reason="Fails on Windows.") def testValidateDefaultConfigAWSMultiNodeTypes(self): aws_config_path = os.path.join( RAY_PATH, "autoscaler/aws/example-multi-node-type.yaml")