mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:07:41 +08:00
[Dashboard] Authentication (#7888)
* Change authentication schema. Authentication implementation. * Formatting. * Fix a minor style. * Fix tests. * Removed url validation.
This commit is contained in:
@@ -1129,6 +1129,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.utils.setup_logger(args.logging_level, args.logging_format)
|
||||
|
||||
# TODO(sang): Add a URL validation.
|
||||
metrics_export_address = os.environ.get("METRICS_EXPORT_ADDRESS")
|
||||
|
||||
try:
|
||||
|
||||
@@ -15,20 +15,24 @@ def authentication_request(url, cluster_id) -> AuthResponse:
|
||||
auth_requeset = AuthRequest(cluster_id=cluster_id)
|
||||
response = requests.post(url, data=auth_requeset.json())
|
||||
response.raise_for_status()
|
||||
return AuthResponse.parse_obj(response.json())
|
||||
return AuthResponse.parse_obj(json.loads(response.json()))
|
||||
|
||||
|
||||
def ingest_request(url, cluster_id, access_token, ray_config, node_info,
|
||||
raylet_info, tune_info,
|
||||
tune_availability) -> IngestResponse:
|
||||
def ingest_request(url, access_token, ray_config, node_info, raylet_info,
|
||||
tune_info, tune_availability) -> IngestResponse:
|
||||
ingest_request = IngestRequest(
|
||||
cluster_id=cluster_id,
|
||||
access_token=access_token,
|
||||
ray_config=ray_config,
|
||||
node_info=node_info,
|
||||
raylet_info=raylet_info,
|
||||
tune_info=tune_info,
|
||||
tune_availability=tune_availability)
|
||||
response = requests.post(url, data=ingest_request.json())
|
||||
response = requests.post(
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {access_token}".format(
|
||||
access_token=access_token)
|
||||
},
|
||||
data=ingest_request.json())
|
||||
response.raise_for_status()
|
||||
return IngestResponse.parse_obj(json.loads(response.json()))
|
||||
|
||||
@@ -16,7 +16,8 @@ class MetricsExportClient:
|
||||
multiple threads that export the same metrics.
|
||||
|
||||
Args:
|
||||
address(str): Address to export metrics
|
||||
address(str): Address to export metrics.
|
||||
This should include a web protocol.
|
||||
dashboard_controller(BaseDashboardController): Dashboard controller to
|
||||
run dashboard business logic.
|
||||
dashboard_id(str): Unique dashboard ID.
|
||||
@@ -25,6 +26,7 @@ class MetricsExportClient:
|
||||
|
||||
def __init__(self, address, dashboard_controller, dashboard_id, exporter):
|
||||
self.dashboard_id = dashboard_id
|
||||
self.address = address
|
||||
self.auth_url = "{}/auth".format(address)
|
||||
self.dashboard_controller = dashboard_controller
|
||||
self.exporter = exporter
|
||||
@@ -44,7 +46,9 @@ class MetricsExportClient:
|
||||
"""
|
||||
self.auth_info = api.authentication_request(self.auth_url,
|
||||
self.dashboard_id)
|
||||
self._dashboard_url = self.auth_info.dashboard_url
|
||||
self._dashboard_url = "{address}/dashboard/{access_token}".format(
|
||||
address=self.address,
|
||||
access_token=self.auth_info.access_token_dashboard)
|
||||
self.is_authenticated = True
|
||||
|
||||
@property
|
||||
@@ -77,7 +81,7 @@ class MetricsExportClient:
|
||||
logger.error(error)
|
||||
return False, error
|
||||
|
||||
self.exporter.access_token = self.auth_info.access_token
|
||||
self.exporter.access_token = self.auth_info.access_token_ingest
|
||||
self.exporter.start()
|
||||
self.is_exporting_started = True
|
||||
return True, None
|
||||
@@ -121,8 +125,8 @@ class Exporter(threading.Thread):
|
||||
def export(self, ray_config, node_info, raylet_info, tune_info,
|
||||
tune_availability):
|
||||
ingest_response = api.ingest_request(
|
||||
self.export_address, self.dashboard_id, self.access_token,
|
||||
ray_config, node_info, raylet_info, tune_info, tune_availability)
|
||||
self.export_address, self.access_token, ray_config, node_info,
|
||||
raylet_info, tune_info, tune_availability)
|
||||
actions = ingest_response.actions
|
||||
self.action_handler.handle_actions(actions)
|
||||
|
||||
|
||||
@@ -59,7 +59,8 @@ class BaseModel:
|
||||
@classmethod
|
||||
def parse_obj(cls, obj):
|
||||
# Validation.
|
||||
assert type(obj) == dict, ("It can only parse dict type object.")
|
||||
assert type(obj) == dict, ("It can only parse dict type object, "
|
||||
"but {} type is given.".format(type(obj)))
|
||||
for field, schema in cls.__schema__.items():
|
||||
required, default, arg_type = schema
|
||||
if field not in obj:
|
||||
@@ -76,8 +77,6 @@ class BaseModel:
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
__schema__ = {
|
||||
"cluster_id": Field(required=True, default=None, type=str),
|
||||
"access_token": Field(required=True, default=None, type=str),
|
||||
"ray_config": Field(required=True, default=None, type=tuple),
|
||||
"node_info": Field(required=True, default=None, type=dict),
|
||||
"raylet_info": Field(required=True, default=None, type=dict),
|
||||
@@ -99,8 +98,8 @@ class AuthRequest(BaseModel):
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
__schema__ = {
|
||||
"dashboard_url": Field(required=True, default=None, type=str),
|
||||
"access_token": Field(required=True, default=None, type=str)
|
||||
"access_token_dashboard": Field(required=True, default=None, type=str),
|
||||
"access_token_ingest": Field(required=True, default=None, type=str)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.dashboard.metrics_exporter.schema import (AuthResponse, BaseModel,
|
||||
ValidationError, Field)
|
||||
|
||||
MOCK_DASHBOARD_ID = "1234"
|
||||
MOCK_DASHBOARD_ADDRESS = "127.0.0.1:9081"
|
||||
MOCK_DASHBOARD_ADDRESS = "http://127.0.0.1:9081"
|
||||
MOCK_ACCESS_TOKEN = "1234"
|
||||
|
||||
|
||||
@@ -63,12 +63,14 @@ def test_client_invalid_request_status_returned(auth_request, mock_controller):
|
||||
@patch("ray.dashboard.metrics_exporter.api.authentication_request")
|
||||
def test_authentication(auth_request, mock_controller):
|
||||
auth_request.return_value = AuthResponse(
|
||||
dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN)
|
||||
access_token_dashboard=MOCK_ACCESS_TOKEN,
|
||||
access_token_ingest=MOCK_ACCESS_TOKEN)
|
||||
exporter, client = _setup_client_and_exporter(mock_controller)
|
||||
|
||||
assert client.enabled is False
|
||||
client._authenticate()
|
||||
assert client.dashboard_url == MOCK_DASHBOARD_ADDRESS
|
||||
assert client.dashboard_url == "{address}/dashboard/{access_token}".format(
|
||||
address=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN)
|
||||
assert client.enabled is True
|
||||
|
||||
|
||||
@@ -82,7 +84,8 @@ def test_start_exporting_metrics_without_authentication(
|
||||
are not authenticated.
|
||||
"""
|
||||
auth_request.return_value = AuthResponse(
|
||||
dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN)
|
||||
access_token_dashboard=MOCK_ACCESS_TOKEN,
|
||||
access_token_ingest=MOCK_ACCESS_TOKEN)
|
||||
exporter, client = _setup_client_and_exporter(mock_controller)
|
||||
|
||||
# start_exporting_metrics should succeed.
|
||||
@@ -102,7 +105,8 @@ def test_start_exporting_metrics_with_authentication(auth_request,
|
||||
should not authenticate users.
|
||||
"""
|
||||
auth_request.return_value = AuthResponse(
|
||||
dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN)
|
||||
access_token_dashboard=MOCK_ACCESS_TOKEN,
|
||||
access_token_ingest=MOCK_ACCESS_TOKEN)
|
||||
exporter, client = _setup_client_and_exporter(mock_controller)
|
||||
# Already authenticated.
|
||||
client._authenticate()
|
||||
@@ -121,7 +125,8 @@ def test_start_exporting_metrics_with_authentication(auth_request,
|
||||
@patch("ray.dashboard.metrics_exporter.api.authentication_request")
|
||||
def test_start_exporting_metrics_succeed(auth_request, mock_controller, start):
|
||||
auth_request.return_value = AuthResponse(
|
||||
dashboard_url=MOCK_DASHBOARD_ADDRESS, access_token=MOCK_ACCESS_TOKEN)
|
||||
access_token_dashboard=MOCK_ACCESS_TOKEN,
|
||||
access_token_ingest=MOCK_ACCESS_TOKEN)
|
||||
exporter, client = _setup_client_and_exporter(mock_controller)
|
||||
|
||||
result, error = client.start_exporting_metrics()
|
||||
|
||||
Reference in New Issue
Block a user