diff --git a/python/ray/autoscaler/_private/util.py b/python/ray/autoscaler/_private/util.py index 2bd1e13e9..32758dec6 100644 --- a/python/ray/autoscaler/_private/util.py +++ b/python/ray/autoscaler/_private/util.py @@ -86,6 +86,14 @@ def validate_config(config: Dict[str, Any]) -> None: raise ValueError( "`head_node_type` must be one of `available_node_types`.") + sum_min_workers = sum( + config["available_node_types"][node_type].get("min_workers", 0) + for node_type in config["available_node_types"]) + if sum_min_workers > config["max_workers"]: + raise ValueError( + "The specified global `max_workers` is smaller than the " + "sum of `min_workers` of all the available node types.") + def prepare_config(config): with_defaults = fillout_defaults(config) diff --git a/python/ray/tests/test_autoscaler_yaml.py b/python/ray/tests/test_autoscaler_yaml.py index b712c8955..e5220771f 100644 --- a/python/ray/tests/test_autoscaler_yaml.py +++ b/python/ray/tests/test_autoscaler_yaml.py @@ -45,6 +45,31 @@ class AutoscalingConfigTest(unittest.TestCase): except Exception: self.fail("Config did not pass validation test!") + def testValidateDefaultConfigMinMaxWorkers(self): + aws_config_path = os.path.join( + RAY_PATH, "autoscaler/aws/example-multi-node-type.yaml") + with open(aws_config_path) as f: + config = yaml.safe_load(f) + config = prepare_config(config) + for node_type in config["available_node_types"]: + config["available_node_types"][node_type]["resources"] = config[ + "available_node_types"][node_type].get("resources", {}) + try: + validate_config(config) + except Exception: + self.fail("Config did not pass validation test!") + + config["max_workers"] = 0 # the sum of min_workers is 1. + with pytest.raises(ValueError): + validate_config(config) + + # make sure edge case of exactly 1 passes too. + config["max_workers"] = 1 + try: + validate_config(config) + except Exception: + self.fail("Config did not pass validation test!") + @pytest.mark.skipif( sys.platform.startswith("win"), reason="TODO(ameer): fails on Windows.")