diff --git a/doc/source/autoscaling.rst b/doc/source/autoscaling.rst index ec886e0f0..94b4b5813 100644 --- a/doc/source/autoscaling.rst +++ b/doc/source/autoscaling.rst @@ -157,6 +157,21 @@ with GPU worker nodes instead. MarketType: spot InstanceType: p2.xlarge + +External Node Provider +-------------------------- + +Ray also supports external node providers (check the `node_provider.py <https://github.com/ray-project/ray/tree/master/python/ray/autoscaler/node_provider.py>`__ implementation). +You can specify the external node provider using the YAML config: + +.. code-block:: yaml + + provider: + type: external + module: mypackage.myclass + +The module needs to be in the format ``package.provider_class`` or ``package.sub_package.provider_class``. + Additional Cloud providers -------------------------- diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index fb1c6389b..8132d73be 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -53,8 +53,9 @@ CLUSTER_CONFIG_SCHEMA = { # Cloud-provider specific configuration. "provider": ({ "type": (str, REQUIRED), # e.g. aws - "region": (str, REQUIRED), # e.g. us-east-1 - "availability_zone": (str, REQUIRED), # e.g. us-east-1a + "region": (str, OPTIONAL), # e.g. us-east-1 + "availability_zone": (str, OPTIONAL), # e.g. us-east-1a + "module": (str, OPTIONAL), # module, if using external node provider }, REQUIRED), # How Ray will authenticate with newly launched nodes. 
diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py
index 082bb5d43..b3f2796f4 100644
--- a/python/ray/autoscaler/node_provider.py
+++ b/python/ray/autoscaler/node_provider.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import importlib
 import os
 import yaml
 
@@ -25,6 +26,7 @@ NODE_PROVIDERS = {
     "kubernetes": None,
     "docker": None,
     "local_cluster": None,
+    "external": None,  # Import an external module
 }
 
 DEFAULT_CONFIGS = {
@@ -37,8 +39,46 @@ DEFAULT_CONFIGS = {
 }
 
 
+def load_class(path):
+    """Load a class at runtime given its full dotted path.
+
+    Example of the path: mypkg.mysubpkg.myclass
+
+    Args:
+        path (str): Dotted path whose last component is the class name
+            and whose leading components name the module to import.
+
+    Returns:
+        The attribute of the imported module named by the last component.
+
+    Raises:
+        ValueError: If the path has no module part or no class part.
+        ImportError: If the module part cannot be imported.
+        AttributeError: If the module has no attribute with that name.
+    """
+    # rpartition leaves module_path empty when the path has no dot, so one
+    # check rejects "myclass", ".myclass" and "mymodule." alike.
+    module_path, _, class_str = path.rpartition(".")
+    if not module_path or not class_str:
+        raise ValueError(
+            "You need to pass a valid path like mymodule.provider_class")
+    module = importlib.import_module(module_path)
+    return getattr(module, class_str)
+
+
 def get_node_provider(provider_config, cluster_name):
+    if provider_config["type"] == "external":
+        # "module" is optional in the cluster config schema, so report a
+        # missing entry explicitly instead of failing with a bare KeyError.
+        module_path = provider_config.get("module")
+        if not module_path:
+            raise ValueError(
+                "The external node provider requires a `module` entry "
+                "in the provider config, e.g. mypackage.myclass")
+        provider_cls = load_class(path=module_path)
+        return provider_cls(provider_config, cluster_name)
+
     importer = NODE_PROVIDERS.get(provider_config["type"])
     if importer is None:
         raise NotImplementedError(
             "Unsupported node provider: {}".format(provider_config["type"]))
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/autoscaler_test.py b/test/autoscaler_test.py
index f5193ca0e..fb04ba823 100644
--- a/test/autoscaler_test.py
+++ b/test/autoscaler_test.py
@@ -507,6 +507,50 @@ class AutoscalingTest(unittest.TestCase):
         autoscaler.update()
         self.waitFor(lambda: len(runner.calls) > num_calls)
 
+    def testExternalNodeScaler(self):
+        config = SMALL_CLUSTER.copy()
+        config["provider"] = {
+            "type": "external",
+            "module": "ray.autoscaler.node_provider.NodeProvider",
+        }
+        config_path = self.write_config(config)
+        autoscaler = StandardAutoscaler(
+            config_path, LoadMetrics(), max_failures=0, update_interval_s=0)
+        self.assertIsInstance(autoscaler.provider, NodeProvider)
+
+    def testExternalNodeScalerWrongImport(self):
+        config = SMALL_CLUSTER.copy()
+        config["provider"] = {
+            "type": "external",
+            "module": "mymodule.provider_class",
+        }
+        invalid_provider = self.write_config(config)
+        self.assertRaises(
+            ImportError,
+            lambda: StandardAutoscaler(
+                invalid_provider, LoadMetrics(), update_interval_s=0))
+
+    def testExternalNodeScalerWrongModuleFormat(self):
+        config = SMALL_CLUSTER.copy()
+        config["provider"] = {
+            "type": "external",
+            "module": "does-not-exist",
+        }
+        invalid_provider = self.write_config(config)
+        self.assertRaises(
+            ValueError,
+            lambda: StandardAutoscaler(
+                invalid_provider, LoadMetrics(), update_interval_s=0))
+
+    def testExternalNodeScalerMissingModule(self):
+        config = SMALL_CLUSTER.copy()
+        config["provider"] = {"type": "external"}
+        invalid_provider = self.write_config(config)
+        self.assertRaises(
+            ValueError,
+            lambda: StandardAutoscaler(
+                invalid_provider, LoadMetrics(), update_interval_s=0))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)