mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 01:43:50 +08:00
[Dashboard ]Action Implementation. (#7826)
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from ray.dashboard.metrics_exporter.schema import ActionType, KillAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionHandler:
|
||||
def __init__(self, dashboard_controller):
|
||||
self.dashboard_controller = dashboard_controller
|
||||
|
||||
def handle_kill_action(self, action: dict):
|
||||
kill_action = KillAction.parse_obj(action)
|
||||
self.dashboard_controller.kill_actor(
|
||||
kill_action.actor_id, kill_action.ip_address, kill_action.port)
|
||||
|
||||
def handle_actions(self, actions: typing.List[dict]):
|
||||
for action in actions:
|
||||
action_type = action.get("type", None)
|
||||
|
||||
if action_type == ActionType.KILL_ACTOR:
|
||||
self.handle_kill_action(action)
|
||||
else:
|
||||
logger.warning("Action type {} has been received, but "
|
||||
"action handler doesn't know how to handle "
|
||||
"them. It will skip processing the request. "
|
||||
"Plesae raise an issue if you see this problem."
|
||||
.format(action_type))
|
||||
continue
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
try:
|
||||
import requests # `requests` is not part of stdlib.
|
||||
except ImportError:
|
||||
@@ -9,7 +11,7 @@ from ray.dashboard.metrics_exporter.schema import AuthRequest, AuthResponse
|
||||
from ray.dashboard.metrics_exporter.schema import IngestRequest, IngestResponse
|
||||
|
||||
|
||||
def authentication_request(url, cluster_id):
|
||||
def authentication_request(url, cluster_id) -> AuthResponse:
|
||||
auth_requeset = AuthRequest(cluster_id=cluster_id)
|
||||
response = requests.post(url, data=auth_requeset.json())
|
||||
response.raise_for_status()
|
||||
@@ -17,7 +19,8 @@ def authentication_request(url, cluster_id):
|
||||
|
||||
|
||||
def ingest_request(url, cluster_id, access_token, ray_config, node_info,
|
||||
raylet_info, tune_info, tune_availability):
|
||||
raylet_info, tune_info,
|
||||
tune_availability) -> IngestResponse:
|
||||
ingest_request = IngestRequest(
|
||||
cluster_id=cluster_id,
|
||||
access_token=access_token,
|
||||
@@ -28,4 +31,4 @@ def ingest_request(url, cluster_id, access_token, ray_config, node_info,
|
||||
tune_availability=tune_availability)
|
||||
response = requests.post(url, data=ingest_request.json())
|
||||
response.raise_for_status()
|
||||
return IngestResponse.parse_obj(response.json())
|
||||
return IngestResponse.parse_obj(json.loads(response.json()))
|
||||
|
||||
@@ -4,6 +4,7 @@ import traceback
|
||||
import time
|
||||
|
||||
from ray.dashboard.metrics_exporter import api
|
||||
from ray.dashboard.metrics_exporter.actions import ActionHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -102,6 +103,7 @@ class Exporter(threading.Thread):
|
||||
|
||||
self.dashboard_id = dashboard_id
|
||||
self.dashboard_controller = dashboard_controller
|
||||
self.action_handler = ActionHandler(dashboard_controller)
|
||||
self.export_address = "{}/ingest".format(address)
|
||||
self.update_frequency = update_frequency
|
||||
self._access_token = None
|
||||
@@ -118,10 +120,11 @@ class Exporter(threading.Thread):
|
||||
|
||||
def export(self, ray_config, node_info, raylet_info, tune_info,
|
||||
tune_availability):
|
||||
api.ingest_request(self.export_address, self.dashboard_id,
|
||||
self.access_token, ray_config, node_info,
|
||||
raylet_info, tune_info, tune_availability)
|
||||
# TODO(sang): Add piggybacking response handler.
|
||||
ingest_response = api.ingest_request(
|
||||
self.export_address, self.dashboard_id, self.access_token,
|
||||
ray_config, node_info, raylet_info, tune_info, tune_availability)
|
||||
actions = ingest_response.actions
|
||||
self.action_handler.handle_actions(actions)
|
||||
|
||||
def run(self):
|
||||
assert self.access_token is not None, (
|
||||
|
||||
@@ -1,19 +1,47 @@
|
||||
import json
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
Field = namedtuple("Field", ["required", "default", "type"])
|
||||
|
||||
|
||||
class BaseModel:
|
||||
"""Base class to define schema.
|
||||
|
||||
This will raise ValidationError if
|
||||
- Number of given kwargs are bigger than needed.
|
||||
- Number of given kwargs are smaller than needed.
|
||||
Model schema should be defined in class variable `__schema__`
|
||||
within a child class. `__schema__` should be a dictionary that contains
|
||||
`field`: `Field(
|
||||
required=required: bool,
|
||||
default=default: Any,
|
||||
type=type: type
|
||||
)`
|
||||
See the example below for more details.
|
||||
|
||||
This doesn't
|
||||
- Validate types.
|
||||
The class can have unexpected behavior if you don't follow the
|
||||
schema pattern properly.
|
||||
|
||||
Example:
|
||||
class A(BaseModel):
|
||||
__schema__ = {
|
||||
"field_name": Field(
|
||||
required=[True|False],
|
||||
default=[default],
|
||||
type=[type]
|
||||
),
|
||||
"cluster_id": Field(
|
||||
required=True,
|
||||
default="1234",
|
||||
type=str
|
||||
),
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValidationError: Raised if a given arg doesn't satisfy the schema.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -30,41 +58,62 @@ class BaseModel:
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls, obj):
|
||||
# Validation.
|
||||
assert type(obj) == dict, ("It can only parse dict type object.")
|
||||
required_args = cls.__slots__
|
||||
given_args = obj.keys()
|
||||
|
||||
# Check if given_args have args that is not required.
|
||||
for arg in given_args:
|
||||
if arg not in required_args:
|
||||
raise ValidationError(
|
||||
"Given argument has a key {}, which is not required "
|
||||
"by this schema: {}".format(arg, required_args))
|
||||
|
||||
# Check if given args have all required args.
|
||||
if len(required_args) != len(given_args):
|
||||
raise ValidationError("Given args: {} doesn't have all the "
|
||||
"necessary args for this schema: {}".format(
|
||||
given_args, required_args))
|
||||
for field, schema in cls.__schema__.items():
|
||||
required, default, arg_type = schema
|
||||
if field not in obj:
|
||||
if required:
|
||||
raise ValidationError("{} is required, but doesn't "
|
||||
"exist in a given object {}".format(
|
||||
field, obj))
|
||||
else:
|
||||
# Set default value if the field is optional
|
||||
obj[field] = default
|
||||
|
||||
return cls(**obj)
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
__slots__ = [
|
||||
"cluster_id", "access_token", "ray_config", "node_info", "raylet_info",
|
||||
"tune_info", "tune_availability"
|
||||
]
|
||||
__schema__ = {
|
||||
"cluster_id": Field(required=True, default=None, type=str),
|
||||
"access_token": Field(required=True, default=None, type=str),
|
||||
"ray_config": Field(required=True, default=None, type=tuple),
|
||||
"node_info": Field(required=True, default=None, type=dict),
|
||||
"raylet_info": Field(required=True, default=None, type=dict),
|
||||
"tune_info": Field(required=True, default=None, type=dict),
|
||||
"tune_availability": Field(required=True, default=None, type=dict)
|
||||
}
|
||||
|
||||
|
||||
# TODO(sang): Add piggybacked response.
|
||||
class IngestResponse(BaseModel):
|
||||
pass
|
||||
__schema__ = {
|
||||
"succeed": Field(required=True, default=None, type=bool),
|
||||
"actions": Field(required=False, default=[], type=list)
|
||||
}
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
__slots__ = ["cluster_id"]
|
||||
__schema__ = {"cluster_id": Field(required=True, default=None, type=str)}
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
__slots__ = ["dashboard_url", "access_token"]
|
||||
__schema__ = {
|
||||
"dashboard_url": Field(required=True, default=None, type=str),
|
||||
"access_token": Field(required=True, default=None, type=str)
|
||||
}
|
||||
|
||||
|
||||
# Enum is not used because action types will be received
|
||||
# through a network communication, and it will be string.
|
||||
class ActionType:
|
||||
KILL_ACTOR = "KILL_ACTOR"
|
||||
|
||||
|
||||
class KillAction(BaseModel):
|
||||
__schema__ = {
|
||||
"type": Field(required=False, default=ActionType.KILL_ACTOR, type=str),
|
||||
"actor_id": Field(required=True, default=None, type=str),
|
||||
"ip_address": Field(required=True, default=None, type=str),
|
||||
"port": Field(required=True, default=None, type=int)
|
||||
}
|
||||
|
||||
@@ -3,10 +3,11 @@ import requests
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from ray.dashboard.metrics_exporter.actions import ActionHandler
|
||||
from ray.dashboard.metrics_exporter.client import MetricsExportClient
|
||||
from ray.dashboard.metrics_exporter.client import Exporter
|
||||
from ray.dashboard.metrics_exporter.schema import (AuthResponse, BaseModel,
|
||||
ValidationError)
|
||||
ValidationError, Field)
|
||||
|
||||
MOCK_DASHBOARD_ID = "1234"
|
||||
MOCK_DASHBOARD_ADDRESS = "127.0.0.1:9081"
|
||||
@@ -20,6 +21,11 @@ def _setup_client_and_exporter(controller):
|
||||
return exporter, client
|
||||
|
||||
|
||||
"""
|
||||
Test Exporter
|
||||
"""
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
def test_verify_exporter_cannot_run_without_access_token(mock_controller):
|
||||
exporter, client = _setup_client_and_exporter(mock_controller)
|
||||
@@ -28,6 +34,11 @@ def test_verify_exporter_cannot_run_without_access_token(mock_controller):
|
||||
exporter.run()
|
||||
|
||||
|
||||
"""
|
||||
Test Client
|
||||
"""
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
@patch(
|
||||
"ray.dashboard.metrics_exporter.api.authentication_request",
|
||||
@@ -129,8 +140,13 @@ BaseModel Test
|
||||
|
||||
|
||||
def test_base_model():
|
||||
DEFAULT_VALUE = "default"
|
||||
|
||||
class A(BaseModel):
|
||||
__slots__ = ["a", "b"]
|
||||
__schema__ = {
|
||||
"a": Field(required=True, default=None, type=str),
|
||||
"b": Field(required=False, default=DEFAULT_VALUE, type=str)
|
||||
}
|
||||
|
||||
# Test the correct case.
|
||||
obj = {"a": "1", "b": "1"}
|
||||
@@ -152,20 +168,89 @@ def test_base_model():
|
||||
with pytest.raises(AssertionError):
|
||||
a = A.parse_obj(obj)
|
||||
|
||||
# Test when fields are not sufficient.
|
||||
obj = {"a": "1"}
|
||||
# Test when required fields are not provided.
|
||||
obj = {"b": "1"}
|
||||
with pytest.raises(ValidationError):
|
||||
a = A.parse_obj(obj)
|
||||
|
||||
# Test when fields are more than expected.
|
||||
obj = {"a": "1", "b": "1", "c": "1"}
|
||||
# Test optional fields are set to default when fields are not given.
|
||||
obj = {"a": "1"}
|
||||
a = A.parse_obj(obj)
|
||||
assert a.b == DEFAULT_VALUE
|
||||
|
||||
# Test when fields that are not defined in the schema is given.
|
||||
# It should be ignoered
|
||||
obj = {"a": "a", "b": "b", "c": "c"}
|
||||
a = A.parse_obj(obj)
|
||||
assert a.a == "a"
|
||||
assert a.b == "b"
|
||||
assert a.c == "c"
|
||||
|
||||
|
||||
"""
|
||||
Test Action Handler
|
||||
"""
|
||||
|
||||
|
||||
def _get_mock_kill_action():
|
||||
return {
|
||||
"type": "KILL_ACTOR",
|
||||
"actor_id": "1234",
|
||||
"ip_address": "1234",
|
||||
"port": 30
|
||||
}
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
def test_handle_kill_action(mock_controller):
|
||||
action_handler = ActionHandler(mock_controller)
|
||||
kill_action = _get_mock_kill_action()
|
||||
action_handler.handle_kill_action(kill_action)
|
||||
assert mock_controller.kill_actor.call_count == 1
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
def test_handle_kill_action_invalid_dict(mock_controller):
|
||||
action_handler = ActionHandler(mock_controller)
|
||||
kill_action = {"type": "KILL_ACTOR", "ip_address": "1234", "port": 30}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
a = A.parse_obj(obj)
|
||||
action_handler.handle_kill_action(kill_action)
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
def test_handle_actions_many_kill_actor(mock_controller):
|
||||
action_handler = ActionHandler(mock_controller)
|
||||
# 10 actions required.
|
||||
actions = [_get_mock_kill_action() for _ in range(10)]
|
||||
|
||||
action_handler.handle_actions(actions)
|
||||
assert mock_controller.kill_actor.call_count == 10
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
def test_handle_actions_kill_actor_and_mixed_type(mock_controller):
|
||||
action_handler = ActionHandler(mock_controller)
|
||||
wrong_type_action = {"type": "NON_EXIST"}
|
||||
actions = [
|
||||
_get_mock_kill_action(), wrong_type_action,
|
||||
_get_mock_kill_action()
|
||||
]
|
||||
|
||||
action_handler.handle_actions(actions)
|
||||
assert mock_controller.kill_actor.call_count == 2
|
||||
|
||||
|
||||
@patch("ray.dashboard.dashboard.DashboardController")
|
||||
def test_handle_actions_only_wrong_type(mock_controller):
|
||||
action_handler = ActionHandler(mock_controller)
|
||||
wrong_type_action = {"type": "NON_EXIST"}
|
||||
actions = [wrong_type_action for _ in range(10)]
|
||||
|
||||
action_handler.handle_actions(actions)
|
||||
assert mock_controller.kill_actor.call_count == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import os
|
||||
os.environ["LC_ALL"] = "en_US.UTF-8"
|
||||
os.environ["LANG"] = "en_US.UTF-8"
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -25,8 +25,20 @@ def test_get_webui(shutdown_only):
|
||||
break
|
||||
except requests.exceptions.ConnectionError:
|
||||
if time.time() > start_time + 30:
|
||||
error_log = None
|
||||
out_log = None
|
||||
with open(
|
||||
"{}/logs/dashboard.out".format(
|
||||
addresses["session_dir"]), "r") as f:
|
||||
out_log = f.read()
|
||||
with open(
|
||||
"{}/logs/dashboard.err".format(
|
||||
addresses["session_dir"]), "r") as f:
|
||||
error_log = f.read()
|
||||
raise Exception(
|
||||
"Timed out while waiting for dashboard to start.")
|
||||
"Timed out while waiting for dashboard to start. "
|
||||
"Dashboard output log: {}\n"
|
||||
"Dashboard error log: {}\n".format(out_log, error_log))
|
||||
assert node_info["error"] is None
|
||||
assert node_info["result"] is not None
|
||||
assert isinstance(node_info["timestamp"], float)
|
||||
|
||||
Reference in New Issue
Block a user