Add external module as a node scaler. (#1703)

* WIP: add external module as a node scaler.

* Fix style.

* Add tests, fix style issues.

* Fix typos.

* Fix test error.

* Fix node provider path.

* Add function to spli pkg from class.

* Add doc.

* Correct documentation.

* Debugging....

* Debugging....

* Add __init__.py to tests.

* add more output for debugging

* Add more test, fix error with import.

* Add a small detail to the documentation.

* Update autoscaler.py
This commit is contained in:
Christian Barra
2018-03-18 00:59:13 +01:00
committed by Eric Liang
parent e3685fca5e
commit 070e27ea7a
5 changed files with 77 additions and 2 deletions
+3 -2
View File
@@ -53,8 +53,9 @@ CLUSTER_CONFIG_SCHEMA = {
# Cloud-provider specific configuration.
"provider": ({
"type": (str, REQUIRED), # e.g. aws
"region": (str, REQUIRED), # e.g. us-east-1
"availability_zone": (str, REQUIRED), # e.g. us-east-1a
"region": (str, OPTIONAL), # e.g. us-east-1
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
"module": (str, OPTIONAL), # module, if using external node provider
}, REQUIRED),
# How Ray will authenticate with newly launched nodes.
+24
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import os
import yaml
@@ -25,6 +26,7 @@ NODE_PROVIDERS = {
"kubernetes": None,
"docker": None,
"local_cluster": None,
"external": None, # Import an external module
}
DEFAULT_CONFIGS = {
@@ -37,8 +39,30 @@ DEFAULT_CONFIGS = {
}
def load_class(path):
"""
Load a class at runtime given a full path.
Example of the path: mypkg.mysubpkg.myclass
"""
class_data = path.split(".")
if len(class_data) < 2:
raise ValueError(
"You need to pass a valid path like mymodule.provider_class"
)
module_path = ".".join(class_data[:-1])
class_str = class_data[-1]
module = importlib.import_module(module_path)
return getattr(module, class_str)
def get_node_provider(provider_config, cluster_name):
if provider_config["type"] == "external":
provider_cls = load_class(path=provider_config["module"])
return provider_cls(provider_config, cluster_name)
importer = NODE_PROVIDERS.get(provider_config["type"])
if importer is None:
raise NotImplementedError(
"Unsupported node provider: {}".format(provider_config["type"]))