diff --git a/python/ray/dashboard/metrics_exporter/actions.py b/python/ray/dashboard/metrics_exporter/actions.py new file mode 100644 index 000000000..a741f9d74 --- /dev/null +++ b/python/ray/dashboard/metrics_exporter/actions.py @@ -0,0 +1,30 @@ +import logging +import typing + +from ray.dashboard.metrics_exporter.schema import ActionType, KillAction + +logger = logging.getLogger(__name__) + + +class ActionHandler: + def __init__(self, dashboard_controller): + self.dashboard_controller = dashboard_controller + + def handle_kill_action(self, action: dict): + kill_action = KillAction.parse_obj(action) + self.dashboard_controller.kill_actor( + kill_action.actor_id, kill_action.ip_address, kill_action.port) + + def handle_actions(self, actions: typing.List[dict]): + for action in actions: + action_type = action.get("type", None) + + if action_type == ActionType.KILL_ACTOR: + self.handle_kill_action(action) + else: + logger.warning("Action type {} has been received, but " + "action handler doesn't know how to handle " + "them. It will skip processing the request. " + "Please raise an issue if you see this problem." + .format(action_type)) + continue diff --git a/python/ray/dashboard/metrics_exporter/api.py b/python/ray/dashboard/metrics_exporter/api.py index d6c24de89..2e7fcb5f7 100644 --- a/python/ray/dashboard/metrics_exporter/api.py +++ b/python/ray/dashboard/metrics_exporter/api.py @@ -1,3 +1,5 @@ +import json + try: import requests # `requests` is not part of stdlib. 
except ImportError: @@ -9,7 +11,7 @@ from ray.dashboard.metrics_exporter.schema import AuthRequest, AuthResponse from ray.dashboard.metrics_exporter.schema import IngestRequest, IngestResponse -def authentication_request(url, cluster_id): +def authentication_request(url, cluster_id) -> AuthResponse: auth_requeset = AuthRequest(cluster_id=cluster_id) response = requests.post(url, data=auth_requeset.json()) response.raise_for_status() @@ -17,7 +19,8 @@ def authentication_request(url, cluster_id): def ingest_request(url, cluster_id, access_token, ray_config, node_info, - raylet_info, tune_info, tune_availability): + raylet_info, tune_info, + tune_availability) -> IngestResponse: ingest_request = IngestRequest( cluster_id=cluster_id, access_token=access_token, @@ -28,4 +31,4 @@ def ingest_request(url, cluster_id, access_token, ray_config, node_info, tune_availability=tune_availability) response = requests.post(url, data=ingest_request.json()) response.raise_for_status() - return IngestResponse.parse_obj(response.json()) + return IngestResponse.parse_obj(json.loads(response.json())) diff --git a/python/ray/dashboard/metrics_exporter/client.py b/python/ray/dashboard/metrics_exporter/client.py index 5bdc59b21..06e024538 100644 --- a/python/ray/dashboard/metrics_exporter/client.py +++ b/python/ray/dashboard/metrics_exporter/client.py @@ -4,6 +4,7 @@ import traceback import time from ray.dashboard.metrics_exporter import api +from ray.dashboard.metrics_exporter.actions import ActionHandler logger = logging.getLogger(__name__) @@ -102,6 +103,7 @@ class Exporter(threading.Thread): self.dashboard_id = dashboard_id self.dashboard_controller = dashboard_controller + self.action_handler = ActionHandler(dashboard_controller) self.export_address = "{}/ingest".format(address) self.update_frequency = update_frequency self._access_token = None @@ -118,10 +120,11 @@ class Exporter(threading.Thread): def export(self, ray_config, node_info, raylet_info, tune_info, tune_availability): - 
api.ingest_request(self.export_address, self.dashboard_id, - self.access_token, ray_config, node_info, - raylet_info, tune_info, tune_availability) - # TODO(sang): Add piggybacking response handler. + ingest_response = api.ingest_request( + self.export_address, self.dashboard_id, self.access_token, + ray_config, node_info, raylet_info, tune_info, tune_availability) + actions = ingest_response.actions + self.action_handler.handle_actions(actions) def run(self): assert self.access_token is not None, ( diff --git a/python/ray/dashboard/metrics_exporter/schema.py b/python/ray/dashboard/metrics_exporter/schema.py index ff20bd013..32b07ea5b 100644 --- a/python/ray/dashboard/metrics_exporter/schema.py +++ b/python/ray/dashboard/metrics_exporter/schema.py @@ -1,19 +1,47 @@ import json +from collections import namedtuple + class ValidationError(Exception): pass +Field = namedtuple("Field", ["required", "default", "type"]) + + class BaseModel: """Base class to define schema. - This will raise ValidationError if - - Number of given kwargs are bigger than needed. - - Number of given kwargs are smaller than needed. + Model schema should be defined in class variable `__schema__` + within a child class. `__schema__` should be a dictionary that contains + `field`: `Field( + required=required: bool, + default=default: Any, + type=type: type + )` + See the example below for more details. - This doesn't - - Validate types. + The class can have unexpected behavior if you don't follow the + schema pattern properly. + + Example: + class A(BaseModel): + __schema__ = { + "field_name": Field( + required=[True|False], + default=[default], + type=[type] + ), + "cluster_id": Field( + required=True, + default="1234", + type=str + ), + } + + Raises: + ValidationError: Raised if a given arg doesn't satisfy the schema. """ def __init__(self, **kwargs): @@ -30,41 +58,62 @@ class BaseModel: @classmethod def parse_obj(cls, obj): + # Validation. 
assert type(obj) == dict, ("It can only parse dict type object.") - required_args = cls.__slots__ - given_args = obj.keys() - - # Check if given_args have args that is not required. - for arg in given_args: - if arg not in required_args: - raise ValidationError( - "Given argument has a key {}, which is not required " - "by this schema: {}".format(arg, required_args)) - - # Check if given args have all required args. - if len(required_args) != len(given_args): - raise ValidationError("Given args: {} doesn't have all the " - "necessary args for this schema: {}".format( - given_args, required_args)) + for field, schema in cls.__schema__.items(): + required, default, arg_type = schema + if field not in obj: + if required: + raise ValidationError("{} is required, but doesn't " + "exist in a given object {}".format( + field, obj)) + else: + # Set default value if the field is optional + obj[field] = default return cls(**obj) class IngestRequest(BaseModel): - __slots__ = [ - "cluster_id", "access_token", "ray_config", "node_info", "raylet_info", - "tune_info", "tune_availability" - ] + __schema__ = { + "cluster_id": Field(required=True, default=None, type=str), + "access_token": Field(required=True, default=None, type=str), + "ray_config": Field(required=True, default=None, type=tuple), + "node_info": Field(required=True, default=None, type=dict), + "raylet_info": Field(required=True, default=None, type=dict), + "tune_info": Field(required=True, default=None, type=dict), + "tune_availability": Field(required=True, default=None, type=dict) + } -# TODO(sang): Add piggybacked response. 
class IngestResponse(BaseModel): - pass + __schema__ = { + "succeed": Field(required=True, default=None, type=bool), + "actions": Field(required=False, default=[], type=list) + } class AuthRequest(BaseModel): - __slots__ = ["cluster_id"] + __schema__ = {"cluster_id": Field(required=True, default=None, type=str)} class AuthResponse(BaseModel): - __slots__ = ["dashboard_url", "access_token"] + __schema__ = { + "dashboard_url": Field(required=True, default=None, type=str), + "access_token": Field(required=True, default=None, type=str) + } + + +# Enum is not used because action types will be received +# through a network communication, and it will be string. +class ActionType: + KILL_ACTOR = "KILL_ACTOR" + + +class KillAction(BaseModel): + __schema__ = { + "type": Field(required=False, default=ActionType.KILL_ACTOR, type=str), + "actor_id": Field(required=True, default=None, type=str), + "ip_address": Field(required=True, default=None, type=str), + "port": Field(required=True, default=None, type=int) + } diff --git a/python/ray/tests/test_metrics_export.py b/python/ray/tests/test_metrics_export.py index 336ce7c3c..91d880f98 100644 --- a/python/ray/tests/test_metrics_export.py +++ b/python/ray/tests/test_metrics_export.py @@ -3,10 +3,11 @@ import requests from unittest.mock import patch +from ray.dashboard.metrics_exporter.actions import ActionHandler from ray.dashboard.metrics_exporter.client import MetricsExportClient from ray.dashboard.metrics_exporter.client import Exporter from ray.dashboard.metrics_exporter.schema import (AuthResponse, BaseModel, - ValidationError) + ValidationError, Field) MOCK_DASHBOARD_ID = "1234" MOCK_DASHBOARD_ADDRESS = "127.0.0.1:9081" @@ -20,6 +21,11 @@ def _setup_client_and_exporter(controller): return exporter, client +""" +Test Exporter +""" + + @patch("ray.dashboard.dashboard.DashboardController") def test_verify_exporter_cannot_run_without_access_token(mock_controller): exporter, client = _setup_client_and_exporter(mock_controller) @@ 
-28,6 +34,11 @@ def test_verify_exporter_cannot_run_without_access_token(mock_controller): exporter.run() +""" +Test Client +""" + + @patch("ray.dashboard.dashboard.DashboardController") @patch( "ray.dashboard.metrics_exporter.api.authentication_request", @@ -129,8 +140,13 @@ BaseModel Test def test_base_model(): + DEFAULT_VALUE = "default" + class A(BaseModel): - __slots__ = ["a", "b"] + __schema__ = { + "a": Field(required=True, default=None, type=str), + "b": Field(required=False, default=DEFAULT_VALUE, type=str) + } # Test the correct case. obj = {"a": "1", "b": "1"} @@ -152,20 +168,89 @@ def test_base_model(): with pytest.raises(AssertionError): a = A.parse_obj(obj) - # Test when fields are not sufficient. - obj = {"a": "1"} + # Test when required fields are not provided. + obj = {"b": "1"} with pytest.raises(ValidationError): a = A.parse_obj(obj) - # Test when fields are more than expected. - obj = {"a": "1", "b": "1", "c": "1"} + # Test optional fields are set to default when fields are not given. + obj = {"a": "1"} + a = A.parse_obj(obj) + assert a.b == DEFAULT_VALUE + + # Test when fields that are not defined in the schema are given. 
+ # It should be ignored by validation (the value is still set) + obj = {"a": "a", "b": "b", "c": "c"} + a = A.parse_obj(obj) + assert a.a == "a" + assert a.b == "b" + assert a.c == "c" + + +""" +Test Action Handler +""" + + +def _get_mock_kill_action(): + return { + "type": "KILL_ACTOR", + "actor_id": "1234", + "ip_address": "1234", + "port": 30 + } + + +@patch("ray.dashboard.dashboard.DashboardController") +def test_handle_kill_action(mock_controller): + action_handler = ActionHandler(mock_controller) + kill_action = _get_mock_kill_action() + action_handler.handle_kill_action(kill_action) + assert mock_controller.kill_actor.call_count == 1 + + +@patch("ray.dashboard.dashboard.DashboardController") +def test_handle_kill_action_invalid_dict(mock_controller): + action_handler = ActionHandler(mock_controller) + kill_action = {"type": "KILL_ACTOR", "ip_address": "1234", "port": 30} + with pytest.raises(ValidationError): - a = A.parse_obj(obj) + action_handler.handle_kill_action(kill_action) + + +@patch("ray.dashboard.dashboard.DashboardController") +def test_handle_actions_many_kill_actor(mock_controller): + action_handler = ActionHandler(mock_controller) + # 10 actions required. 
+ actions = [_get_mock_kill_action() for _ in range(10)] + + action_handler.handle_actions(actions) + assert mock_controller.kill_actor.call_count == 10 + + +@patch("ray.dashboard.dashboard.DashboardController") +def test_handle_actions_kill_actor_and_mixed_type(mock_controller): + action_handler = ActionHandler(mock_controller) + wrong_type_action = {"type": "NON_EXIST"} + actions = [ + _get_mock_kill_action(), wrong_type_action, + _get_mock_kill_action() + ] + + action_handler.handle_actions(actions) + assert mock_controller.kill_actor.call_count == 2 + + +@patch("ray.dashboard.dashboard.DashboardController") +def test_handle_actions_only_wrong_type(mock_controller): + action_handler = ActionHandler(mock_controller) + wrong_type_action = {"type": "NON_EXIST"} + actions = [wrong_type_action for _ in range(10)] + + action_handler.handle_actions(actions) + assert mock_controller.kill_actor.call_count == 0 if __name__ == "__main__": import sys - import os - os.environ["LC_ALL"] = "en_US.UTF-8" - os.environ["LANG"] = "en_US.UTF-8" sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_webui.py b/python/ray/tests/test_webui.py index 2e284650d..42a683ec4 100644 --- a/python/ray/tests/test_webui.py +++ b/python/ray/tests/test_webui.py @@ -25,8 +25,20 @@ def test_get_webui(shutdown_only): break except requests.exceptions.ConnectionError: if time.time() > start_time + 30: + error_log = None + out_log = None + with open( + "{}/logs/dashboard.out".format( + addresses["session_dir"]), "r") as f: + out_log = f.read() + with open( + "{}/logs/dashboard.err".format( + addresses["session_dir"]), "r") as f: + error_log = f.read() raise Exception( - "Timed out while waiting for dashboard to start.") + "Timed out while waiting for dashboard to start. " + "Dashboard output log: {}\n" + "Dashboard error log: {}\n".format(out_log, error_log)) assert node_info["error"] is None assert node_info["result"] is not None assert isinstance(node_info["timestamp"], float)