diff --git a/dashboard/consts.py b/dashboard/consts.py index 3139dcc14..d1fde937b 100644 --- a/dashboard/consts.py +++ b/dashboard/consts.py @@ -21,6 +21,7 @@ AIOHTTP_CACHE_DISABLE_ENVIRONMENT_KEY = "RAY_DASHBOARD_NO_CACHE" # Named signals SIGNAL_NODE_INFO_FETCHED = "node_info_fetched" SIGNAL_NODE_SUMMARY_FETCHED = "node_summary_fetched" +SIGNAL_JOB_INFO_FETCHED = "job_info_fetched" SIGNAL_WORKER_INFO_FETCHED = "worker_info_fetched" # Default value for datacenter (the default value in protobuf) DEFAULT_LANGUAGE = "PYTHON" diff --git a/dashboard/datacenter.py b/dashboard/datacenter.py index c6b05d9e8..319354068 100644 --- a/dashboard/datacenter.py +++ b/dashboard/datacenter.py @@ -10,6 +10,7 @@ logger = logging.getLogger(__name__) class GlobalSignals: node_info_fetched = Signal(dashboard_consts.SIGNAL_NODE_INFO_FETCHED) node_summary_fetched = Signal(dashboard_consts.SIGNAL_NODE_SUMMARY_FETCHED) + job_info_fetched = Signal(dashboard_consts.SIGNAL_JOB_INFO_FETCHED) worker_info_fetched = Signal(dashboard_consts.SIGNAL_WORKER_INFO_FETCHED) @@ -22,6 +23,8 @@ class DataSource: # {actor id hex(str): actor table data(dict of ActorTableData # in gcs.proto)} actors = Dict() + # {job id hex(str): job table data(dict of JobTableData in gcs.proto)} + jobs = Dict() # {node id hex(str): dashboard agent [http port(int), grpc port(int)]} agents = Dict() # {node id hex(str): gcs node info(dict of GcsNodeInfo in gcs.proto)} diff --git a/dashboard/modules/job/__init__.py b/dashboard/modules/job/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dashboard/modules/job/job_consts.py b/dashboard/modules/job/job_consts.py new file mode 100644 index 000000000..101756ac3 --- /dev/null +++ b/dashboard/modules/job/job_consts.py @@ -0,0 +1,2 @@ +JOB_CHANNEL = "JOB" +RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS = 2 diff --git a/dashboard/modules/job/job_head.py b/dashboard/modules/job/job_head.py new file mode 100644 index 000000000..671482b82 --- /dev/null +++ 
# ---------------------------------------------------------------------------
# dashboard/modules/job/job_head.py
# ---------------------------------------------------------------------------
import logging
import asyncio

import aiohttp.web
from aioredis.pubsub import Receiver

import ray
import ray.gcs_utils
import ray.new_dashboard.modules.job.job_consts as job_consts
import ray.new_dashboard.utils as dashboard_utils
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.new_dashboard.datacenter import (
    DataSource,
    DataOrganizer,
    GlobalSignals,
)

logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable


def job_table_data_to_dict(message):
    """Convert a job table message into a dict.

    Delegates to ``dashboard_utils.message_to_dict``, passing ``jobId`` and
    ``rayletId`` as the keys to decode and asking for default-valued fields
    to be included in the result.
    """
    decode_keys = {"jobId", "rayletId"}
    return dashboard_utils.message_to_dict(
        message, decode_keys, including_default_value_fields=True)


class JobHead(dashboard_utils.DashboardHeadModule):
    """Dashboard head module that serves job information over HTTP.

    Job data is kept in ``DataSource.jobs``; it is loaded once from GCS and
    then updated from the job pub/sub channel (see ``_update_jobs``).
    """

    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        # JobInfoGcsServiceStub; created in run().
        self._gcs_job_info_stub = None

    @routes.get("/jobs")
    @dashboard_utils.aiohttp_cache
    async def get_all_jobs(self, req) -> aiohttp.web.Response:
        """Handle ``GET /jobs``.

        With ``?view=summary`` the response carries every value currently in
        ``DataSource.jobs``. Any other (or missing) ``view`` yields a
        ``success=False`` response.
        """
        view = req.query.get("view")
        if view != "summary":
            return dashboard_utils.rest_response(
                success=False, message=f"Unknown view {view}")
        return dashboard_utils.rest_response(
            success=True,
            message="All job summary fetched.",
            summary=list(DataSource.jobs.values()))

    @routes.get("/jobs/{job_id}")
    @dashboard_utils.aiohttp_cache
    async def get_job(self, req) -> aiohttp.web.Response:
        """Handle ``GET /jobs/{job_id}``.

        Without a ``view`` query parameter the response contains the job
        info, the job's actors and the job's workers. A job id that is not
        in ``DataSource.jobs`` still yields a successful response whose
        ``jobInfo`` is an empty dict. Any explicit ``view`` value yields a
        ``success=False`` response.
        """
        job_id = req.match_info.get("job_id")
        view = req.query.get("view")
        if view is not None:
            return dashboard_utils.rest_response(
                success=False, message=f"Unknown view {view}")

        job_detail = {
            "jobInfo": DataSource.jobs.get(job_id, {}),
            "jobActors": await DataOrganizer.get_job_actors(job_id),
            "jobWorkers": DataSource.job_workers.get(job_id, []),
        }
        # Signal receivers are handed the detail dict before it is returned.
        await GlobalSignals.job_info_fetched.send(job_detail)
        return dashboard_utils.rest_response(
            success=True, message="Job detail fetched.", detail=job_detail)

    async def _update_jobs(self):
        """Keep ``DataSource.jobs`` up to date.

        Three phases, in order:

        1. Pattern-subscribe to ``"<JOB_CHANNEL>:*"`` on the dashboard's
           aioredis client.
        2. Fetch all job info from GCS, retrying every
           ``RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS`` until a reply with
           status code 0 arrives, then reset ``DataSource.jobs`` with it.
        3. Consume channel messages forever, storing the jobs that report
           ``is_submitted_from_dashboard()`` and logging the others.
        """
        # Subscribe job channel.
        # NOTE(review): subscribing before the snapshot presumably avoids
        # missing updates published while the snapshot is in flight; confirm.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = f"{job_consts.JOB_CHANNEL}:*"
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        # Get all job info.
        while True:
            try:
                logger.info("Getting all job info from GCS.")
                request = gcs_service_pb2.GetAllJobInfoRequest()
                reply = await self._gcs_job_info_stub.GetAllJobInfo(
                    request, timeout=5)
                if reply.status.code != 0:
                    # Raised inside the try on purpose: the handler below
                    # logs it and the loop retries after the sleep.
                    raise Exception(
                        f"Failed to GetAllJobInfo: {reply.status.message}")
                jobs = {}
                for job_table_data in reply.job_info_list:
                    data = job_table_data_to_dict(job_table_data)
                    jobs[data["jobId"]] = data
                # Update jobs.
                DataSource.jobs.reset(jobs)
                logger.info("Received %d job info from GCS.", len(jobs))
                break
            except Exception:
                logger.exception("Error Getting all job info from GCS.")
                await asyncio.sleep(
                    job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)

        # Receive jobs from channel.
        async for sender, msg in receiver.iter():
            try:
                _, data = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(data)
                message = ray.gcs_utils.JobTableData.FromString(
                    pubsub_message.data)
                job_id = ray._raylet.JobID(message.job_id)
                if job_id.is_submitted_from_dashboard():
                    job_table_data = job_table_data_to_dict(message)
                    # Update jobs (keyed by the "jobId" value of the dict).
                    DataSource.jobs[job_table_data["jobId"]] = job_table_data
                else:
                    logger.info(
                        "Ignore job %s which is not submitted from dashboard.",
                        job_id.hex())
            except Exception:
                # One bad message must not stop the receive loop.
                logger.exception("Error receiving job info.")

    async def run(self, server):
        """Create the GCS job-info stub and run the update loop."""
        self._gcs_job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
            self._dashboard_head.aiogrpc_gcs_channel)

        await asyncio.gather(self._update_jobs())


# ---------------------------------------------------------------------------
# dashboard/modules/job/tests/test_job.py
# ---------------------------------------------------------------------------
import os
import sys
import time
import logging
import requests
import traceback

import ray
from ray.utils import hex_to_binary
from ray.new_dashboard.tests.conftest import *  # noqa
from ray.test_utils import (
    format_web_url,
    wait_until_server_available,
)
import pytest

os.environ["RAY_USE_NEW_DASHBOARD"] = "1"

logger = logging.getLogger(__name__)


def test_get_job_info(disable_aiohttp_cache, ray_start_with_dashboard):
    """Check the ``/jobs`` summary and ``/jobs/{job_id}`` detail endpoints.

    Starts one actor, then polls the dashboard once per second until both
    endpoints return the expected structure or the timeout expires.
    """

    @ray.remote
    class Actor:
        def getpid(self):
            return os.getpid()

    actor = Actor.remote()
    actor_pid = ray.get(actor.getpid.remote())
    actor_id = actor._actor_id.hex()

    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = format_web_url(webui_url)

    ip = ray._private.services.get_node_ip_address()

    def _check():
        # Summary view: exactly one job, the current driver.
        resp = requests.get(f"{webui_url}/jobs?view=summary")
        resp.raise_for_status()
        result = resp.json()
        assert result["result"] is True, resp.text
        job_summary = result["data"]["summary"]
        assert len(job_summary) == 1
        one_job = job_summary[0]
        assert "jobId" in one_job
        job_id = one_job["jobId"]
        assert ray._raylet.JobID(hex_to_binary(one_job["jobId"]))
        assert "driverIpAddress" in one_job
        assert one_job["driverIpAddress"] == ip
        assert "driverPid" in one_job
        assert one_job["driverPid"] == str(os.getpid())
        assert "config" in one_job
        assert isinstance(one_job["config"], dict)
        assert "isDead" in one_job
        assert one_job["isDead"] is False
        assert "timestamp" in one_job
        one_job_summary_keys = one_job.keys()

        # Detail view: job info, actors and workers of that job.
        resp = requests.get(f"{webui_url}/jobs/{job_id}")
        resp.raise_for_status()
        result = resp.json()
        assert result["result"] is True, resp.text
        job_detail = result["data"]["detail"]
        assert "jobInfo" in job_detail
        assert len(one_job_summary_keys - job_detail["jobInfo"].keys()) == 0
        assert "jobActors" in job_detail
        job_actors = job_detail["jobActors"]
        assert len(job_actors) == 1
        one_job_actor = job_actors[actor_id]
        assert "taskSpec" in one_job_actor
        assert isinstance(one_job_actor["taskSpec"], dict)
        assert "functionDescriptor" in one_job_actor["taskSpec"]
        assert isinstance(
            one_job_actor["taskSpec"]["functionDescriptor"], dict)
        assert "pid" in one_job_actor
        assert one_job_actor["pid"] == actor_pid
        check_actor_keys = [
            "name", "timestamp", "address", "actorId", "jobId", "state"
        ]
        for k in check_actor_keys:
            assert k in one_job_actor
        assert "jobWorkers" in job_detail
        job_workers = job_detail["jobWorkers"]
        assert len(job_workers) == 1
        one_job_worker = job_workers[0]
        check_worker_keys = [
            "cmdline", "pid", "cpuTimes", "memoryInfo", "cpuPercent",
            "coreWorkerStats", "language", "jobId"
        ]
        for k in check_worker_keys:
            assert k in one_job_worker

    timeout_seconds = 5
    deadline = time.time() + timeout_seconds
    while True:
        time.sleep(1)
        try:
            _check()
            break
        except (AssertionError, KeyError, IndexError) as ex:
            last_ex = ex
        # The deadline is checked only after a failed attempt. Doing it in a
        # `finally:` clause would also run on the successful `break` above
        # and turn a late success into a spurious timeout.
        if time.time() > deadline:
            ex_stack = "".join(
                traceback.format_exception(
                    type(last_ex), last_ex, last_ex.__traceback__))
            raise Exception(
                f"Timed out while testing, {ex_stack}") from last_ex


if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))