[Dashboard ]Action Implementation. (#7826)

This commit is contained in:
SangBin Cho
2020-04-02 18:02:37 -07:00
committed by GitHub
parent a3181816b3
commit 1d532d1cb8
6 changed files with 228 additions and 46 deletions
@@ -0,0 +1,30 @@
import logging
import typing
from ray.dashboard.metrics_exporter.schema import ActionType, KillAction
logger = logging.getLogger(__name__)
class ActionHandler:
def __init__(self, dashboard_controller):
self.dashboard_controller = dashboard_controller
def handle_kill_action(self, action: dict):
kill_action = KillAction.parse_obj(action)
self.dashboard_controller.kill_actor(
kill_action.actor_id, kill_action.ip_address, kill_action.port)
def handle_actions(self, actions: typing.List[dict]):
for action in actions:
action_type = action.get("type", None)
if action_type == ActionType.KILL_ACTOR:
self.handle_kill_action(action)
else:
logger.warning("Action type {} has been received, but "
"action handler doesn't know how to handle "
"them. It will skip processing the request. "
"Plesae raise an issue if you see this problem."
.format(action_type))
continue
+6 -3
View File
@@ -1,3 +1,5 @@
import json
try:
import requests # `requests` is not part of stdlib.
except ImportError:
@@ -9,7 +11,7 @@ from ray.dashboard.metrics_exporter.schema import AuthRequest, AuthResponse
from ray.dashboard.metrics_exporter.schema import IngestRequest, IngestResponse
def authentication_request(url, cluster_id):
def authentication_request(url, cluster_id) -> AuthResponse:
auth_requeset = AuthRequest(cluster_id=cluster_id)
response = requests.post(url, data=auth_requeset.json())
response.raise_for_status()
@@ -17,7 +19,8 @@ def authentication_request(url, cluster_id):
def ingest_request(url, cluster_id, access_token, ray_config, node_info,
raylet_info, tune_info, tune_availability):
raylet_info, tune_info,
tune_availability) -> IngestResponse:
ingest_request = IngestRequest(
cluster_id=cluster_id,
access_token=access_token,
@@ -28,4 +31,4 @@ def ingest_request(url, cluster_id, access_token, ray_config, node_info,
tune_availability=tune_availability)
response = requests.post(url, data=ingest_request.json())
response.raise_for_status()
return IngestResponse.parse_obj(response.json())
return IngestResponse.parse_obj(json.loads(response.json()))
@@ -4,6 +4,7 @@ import traceback
import time
from ray.dashboard.metrics_exporter import api
from ray.dashboard.metrics_exporter.actions import ActionHandler
logger = logging.getLogger(__name__)
@@ -102,6 +103,7 @@ class Exporter(threading.Thread):
self.dashboard_id = dashboard_id
self.dashboard_controller = dashboard_controller
self.action_handler = ActionHandler(dashboard_controller)
self.export_address = "{}/ingest".format(address)
self.update_frequency = update_frequency
self._access_token = None
@@ -118,10 +120,11 @@ class Exporter(threading.Thread):
def export(self, ray_config, node_info, raylet_info, tune_info,
tune_availability):
api.ingest_request(self.export_address, self.dashboard_id,
self.access_token, ray_config, node_info,
raylet_info, tune_info, tune_availability)
# TODO(sang): Add piggybacking response handler.
ingest_response = api.ingest_request(
self.export_address, self.dashboard_id, self.access_token,
ray_config, node_info, raylet_info, tune_info, tune_availability)
actions = ingest_response.actions
self.action_handler.handle_actions(actions)
def run(self):
assert self.access_token is not None, (
+77 -28
View File
@@ -1,19 +1,47 @@
import json
from collections import namedtuple
class ValidationError(Exception):
pass
Field = namedtuple("Field", ["required", "default", "type"])
class BaseModel:
"""Base class to define schema.
This will raise ValidationError if
- Number of given kwargs are bigger than needed.
- Number of given kwargs are smaller than needed.
Model schema should be defined in class variable `__schema__`
within a child class. `__schema__` should be a dictionary that contains
`field`: `Field(
required=required: bool,
default=default: Any,
type=type: type
)`
See the example below for more details.
This doesn't
- Validate types.
The class can have unexpected behavior if you don't follow the
schema pattern properly.
Example:
class A(BaseModel):
__schema__ = {
"field_name": Field(
required=[True|False],
default=[default],
type=[type]
),
"cluster_id": Field(
required=True,
default="1234",
type=str
),
}
Raises:
ValidationError: Raised if a given arg doesn't satisfy the schema.
"""
def __init__(self, **kwargs):
@@ -30,41 +58,62 @@ class BaseModel:
@classmethod
def parse_obj(cls, obj):
# Validation.
assert type(obj) == dict, ("It can only parse dict type object.")
required_args = cls.__slots__
given_args = obj.keys()
# Check if given_args have args that is not required.
for arg in given_args:
if arg not in required_args:
raise ValidationError(
"Given argument has a key {}, which is not required "
"by this schema: {}".format(arg, required_args))
# Check if given args have all required args.
if len(required_args) != len(given_args):
raise ValidationError("Given args: {} doesn't have all the "
"necessary args for this schema: {}".format(
given_args, required_args))
for field, schema in cls.__schema__.items():
required, default, arg_type = schema
if field not in obj:
if required:
raise ValidationError("{} is required, but doesn't "
"exist in a given object {}".format(
field, obj))
else:
# Set default value if the field is optional
obj[field] = default
return cls(**obj)
class IngestRequest(BaseModel):
__slots__ = [
"cluster_id", "access_token", "ray_config", "node_info", "raylet_info",
"tune_info", "tune_availability"
]
__schema__ = {
"cluster_id": Field(required=True, default=None, type=str),
"access_token": Field(required=True, default=None, type=str),
"ray_config": Field(required=True, default=None, type=tuple),
"node_info": Field(required=True, default=None, type=dict),
"raylet_info": Field(required=True, default=None, type=dict),
"tune_info": Field(required=True, default=None, type=dict),
"tune_availability": Field(required=True, default=None, type=dict)
}
# TODO(sang): Add piggybacked response.
class IngestResponse(BaseModel):
pass
__schema__ = {
"succeed": Field(required=True, default=None, type=bool),
"actions": Field(required=False, default=[], type=list)
}
class AuthRequest(BaseModel):
__slots__ = ["cluster_id"]
__schema__ = {"cluster_id": Field(required=True, default=None, type=str)}
class AuthResponse(BaseModel):
__slots__ = ["dashboard_url", "access_token"]
__schema__ = {
"dashboard_url": Field(required=True, default=None, type=str),
"access_token": Field(required=True, default=None, type=str)
}
# Enum is not used because action types will be received
# through a network communication, and it will be string.
class ActionType:
KILL_ACTOR = "KILL_ACTOR"
class KillAction(BaseModel):
__schema__ = {
"type": Field(required=False, default=ActionType.KILL_ACTOR, type=str),
"actor_id": Field(required=True, default=None, type=str),
"ip_address": Field(required=True, default=None, type=str),
"port": Field(required=True, default=None, type=int)
}
+95 -10
View File
@@ -3,10 +3,11 @@ import requests
from unittest.mock import patch
from ray.dashboard.metrics_exporter.actions import ActionHandler
from ray.dashboard.metrics_exporter.client import MetricsExportClient
from ray.dashboard.metrics_exporter.client import Exporter
from ray.dashboard.metrics_exporter.schema import (AuthResponse, BaseModel,
ValidationError)
ValidationError, Field)
MOCK_DASHBOARD_ID = "1234"
MOCK_DASHBOARD_ADDRESS = "127.0.0.1:9081"
@@ -20,6 +21,11 @@ def _setup_client_and_exporter(controller):
return exporter, client
"""
Test Exporter
"""
@patch("ray.dashboard.dashboard.DashboardController")
def test_verify_exporter_cannot_run_without_access_token(mock_controller):
exporter, client = _setup_client_and_exporter(mock_controller)
@@ -28,6 +34,11 @@ def test_verify_exporter_cannot_run_without_access_token(mock_controller):
exporter.run()
"""
Test Client
"""
@patch("ray.dashboard.dashboard.DashboardController")
@patch(
"ray.dashboard.metrics_exporter.api.authentication_request",
@@ -129,8 +140,13 @@ BaseModel Test
def test_base_model():
DEFAULT_VALUE = "default"
class A(BaseModel):
__slots__ = ["a", "b"]
__schema__ = {
"a": Field(required=True, default=None, type=str),
"b": Field(required=False, default=DEFAULT_VALUE, type=str)
}
# Test the correct case.
obj = {"a": "1", "b": "1"}
@@ -152,20 +168,89 @@ def test_base_model():
with pytest.raises(AssertionError):
a = A.parse_obj(obj)
# Test when fields are not sufficient.
obj = {"a": "1"}
# Test when required fields are not provided.
obj = {"b": "1"}
with pytest.raises(ValidationError):
a = A.parse_obj(obj)
# Test when fields are more than expected.
obj = {"a": "1", "b": "1", "c": "1"}
# Test optional fields are set to default when fields are not given.
obj = {"a": "1"}
a = A.parse_obj(obj)
assert a.b == DEFAULT_VALUE
# Test when fields that are not defined in the schema is given.
# It should be ignoered
obj = {"a": "a", "b": "b", "c": "c"}
a = A.parse_obj(obj)
assert a.a == "a"
assert a.b == "b"
assert a.c == "c"
"""
Test Action Handler
"""
def _get_mock_kill_action():
return {
"type": "KILL_ACTOR",
"actor_id": "1234",
"ip_address": "1234",
"port": 30
}
@patch("ray.dashboard.dashboard.DashboardController")
def test_handle_kill_action(mock_controller):
action_handler = ActionHandler(mock_controller)
kill_action = _get_mock_kill_action()
action_handler.handle_kill_action(kill_action)
assert mock_controller.kill_actor.call_count == 1
@patch("ray.dashboard.dashboard.DashboardController")
def test_handle_kill_action_invalid_dict(mock_controller):
action_handler = ActionHandler(mock_controller)
kill_action = {"type": "KILL_ACTOR", "ip_address": "1234", "port": 30}
with pytest.raises(ValidationError):
a = A.parse_obj(obj)
action_handler.handle_kill_action(kill_action)
@patch("ray.dashboard.dashboard.DashboardController")
def test_handle_actions_many_kill_actor(mock_controller):
action_handler = ActionHandler(mock_controller)
# 10 actions required.
actions = [_get_mock_kill_action() for _ in range(10)]
action_handler.handle_actions(actions)
assert mock_controller.kill_actor.call_count == 10
@patch("ray.dashboard.dashboard.DashboardController")
def test_handle_actions_kill_actor_and_mixed_type(mock_controller):
action_handler = ActionHandler(mock_controller)
wrong_type_action = {"type": "NON_EXIST"}
actions = [
_get_mock_kill_action(), wrong_type_action,
_get_mock_kill_action()
]
action_handler.handle_actions(actions)
assert mock_controller.kill_actor.call_count == 2
@patch("ray.dashboard.dashboard.DashboardController")
def test_handle_actions_only_wrong_type(mock_controller):
action_handler = ActionHandler(mock_controller)
wrong_type_action = {"type": "NON_EXIST"}
actions = [wrong_type_action for _ in range(10)]
action_handler.handle_actions(actions)
assert mock_controller.kill_actor.call_count == 0
if __name__ == "__main__":
import sys
import os
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LANG"] = "en_US.UTF-8"
sys.exit(pytest.main(["-v", __file__]))
+13 -1
View File
@@ -25,8 +25,20 @@ def test_get_webui(shutdown_only):
break
except requests.exceptions.ConnectionError:
if time.time() > start_time + 30:
error_log = None
out_log = None
with open(
"{}/logs/dashboard.out".format(
addresses["session_dir"]), "r") as f:
out_log = f.read()
with open(
"{}/logs/dashboard.err".format(
addresses["session_dir"]), "r") as f:
error_log = f.read()
raise Exception(
"Timed out while waiting for dashboard to start.")
"Timed out while waiting for dashboard to start. "
"Dashboard output log: {}\n"
"Dashboard error log: {}\n".format(out_log, error_log))
assert node_info["error"] is None
assert node_info["result"] is not None
assert isinstance(node_info["timestamp"], float)