From 05c103af9480d459ce7d1eb246c059e17e77e5f1 Mon Sep 17 00:00:00 2001 From: fyrestone Date: Tue, 25 Aug 2020 04:24:23 +0800 Subject: [PATCH] [Dashboard] Start the new dashboard (#10131) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use new dashboard if environment var RAY_USE_NEW_DASHBOARD exists; new dashboard startup * Make fake client/build/static directory for dashboard * Add test_dashboard.py for new dashboard * Travis CI enable new dashboard test * Update new dashboard * Agent manager service * Add agent manager * Register agent to agent manager * Add a new line to the end of agent_manager.cc * Fix merge; Fix lint * Update dashboard/agent.py Co-authored-by: SangBin Cho * Update dashboard/head.py Co-authored-by: SangBin Cho * Fix bug * Add tests for dashboard * Fix * Remove const from Process::Kill() & Fix bugs * Revert error check of execute_after * Raise exception from DashboardAgent.run * Add more tests. * Fix compile on Linux * Use dict comprehension instead of dict(generator) * Fix lint * Fix windows compile * Fix lint * Test Windows CI * Revert "Test Windows CI" This reverts commit 945e01051ec95cff5fcc1c0bc37045b46e7ad9a6. * Fix ParseWindowsCommandLine bug * Update src/ray/util/util.cc Co-authored-by: Robert Nishihara Co-authored-by: 刘宝 Co-authored-by: SangBin Cho Co-authored-by: Robert Nishihara --- .travis.yml | 3 + BUILD.bazel | 30 +++ ci/travis/determine_tests_to_run.py | 2 + dashboard/BUILD | 13 + dashboard/agent.py | 122 +++++++--- dashboard/client/build/.fakebuild | 0 dashboard/client/build/static/.fakestatic | 0 dashboard/consts.py | 6 +- dashboard/dashboard.py | 49 +--- dashboard/head.py | 80 +++++-- dashboard/modules/reporter/reporter_agent.py | 33 +++ dashboard/modules/reporter/reporter_head.py | 14 +- dashboard/modules/test/__init__.py | 0 dashboard/modules/test/test_agent.py | 24 ++ dashboard/modules/test/test_head.py | 62 +++++ dashboard/modules/test/test_utils.py | 11 + dashboard/tests/conftest.py | 1 + dashboard/tests/test_dashboard.py | 224 ++++++++++++++++++ dashboard/utils.py | 53 ++++- python/ray/node.py | 10 +- python/ray/services.py | 28 ++- python/setup.py | 3 +- src/ray/common/ray_config_def.h | 6 + src/ray/gcs/gcs_server/gcs_actor_scheduler.cc | 12 +- src/ray/protobuf/BUILD | 23 ++ src/ray/protobuf/agent_manager.proto | 38 +++ src/ray/raylet/agent_manager.cc | 96 ++++++++ src/ray/raylet/agent_manager.h | 75 ++++++ src/ray/raylet/main.cc | 7 + src/ray/raylet/node_manager.cc | 14 ++ src/ray/raylet/node_manager.h | 9 + .../rpc/agent_manager/agent_manager_client.h | 50 ++++ .../rpc/agent_manager/agent_manager_server.h | 73 ++++++ src/ray/util/asio_util.h | 6 +- src/ray/util/util.cc | 5 + 35 files changed, 1079 insertions(+), 103 deletions(-) create mode 100644 dashboard/BUILD create mode 100644 dashboard/client/build/.fakebuild create mode 100644 dashboard/client/build/static/.fakestatic create mode 100644 dashboard/modules/test/__init__.py create mode 100644 dashboard/modules/test/test_agent.py create mode 100644 dashboard/modules/test/test_head.py create mode 100644 dashboard/modules/test/test_utils.py create mode 100644 dashboard/tests/conftest.py create mode 100644 dashboard/tests/test_dashboard.py create mode 100644 src/ray/protobuf/agent_manager.proto create mode 100644 src/ray/raylet/agent_manager.cc create mode 100644 src/ray/raylet/agent_manager.h create mode 100644 src/ray/rpc/agent_manager/agent_manager_client.h create mode 100644 src/ray/rpc/agent_manager/agent_manager_server.h diff --git a/.travis.yml b/.travis.yml index 0e17dd5f8..28b93f025 100644 --- a/.travis.yml +++ b/.travis.yml @@ -353,6 +353,9 @@ script: # ray dashboard tests - if [ "$RAY_CI_DASHBOARD_AFFECTED" == "1" ]; then ./ci/keep_alive bazel test python/ray/dashboard/...; fi + # ray new dashboard tests + - if [ "$RAY_CI_DASHBOARD_AFFECTED" == "1" ]; then ./ci/keep_alive bazel test python/ray/new_dashboard/...; fi + # ray operator tests - (cd deploy/ray-operator && export CC=gcc && suppress_output go build && suppress_output go test ./...) diff --git a/BUILD.bazel b/BUILD.bazel index 7116f1154..1d5861afe 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -175,6 +175,30 @@ cc_library( ], ) +# Agent manager. +cc_grpc_library( + name = "agent_manager_cc_grpc", + srcs = ["//src/ray/protobuf:agent_manager_proto"], + grpc_only = True, + deps = ["//src/ray/protobuf:agent_manager_cc_proto"], +) + +cc_library( + name = "agent_manager_rpc", + hdrs = glob([ + "src/ray/rpc/agent_manager/*.h", + ]), + copts = COPTS, + strip_include_prefix = "src", + deps = [ + ":agent_manager_cc_grpc", + ":grpc_common_lib", + ":ray_common", + "@boost//:asio", + "@com_github_grpc_grpc//:grpc++", + ], +) + # === End of rpc definitions === # === Begin of plasma definitions === @@ -415,6 +439,7 @@ cc_library( copts = COPTS, strip_include_prefix = "src", deps = [ + ":agent_manager_rpc", ":gcs", ":gcs_pub_sub_lib", ":gcs_service_rpc", @@ -506,6 +531,7 @@ cc_library( strip_include_prefix = "src", visibility = ["//streaming:__subpackages__"], deps = [ + ":agent_manager_rpc", ":gcs", ":node_manager_fbs", ":node_manager_rpc", @@ -549,6 +575,7 @@ cc_library( strip_include_prefix = "src", visibility = ["//streaming:__subpackages__"], deps = [ + ":agent_manager_rpc", ":node_manager_fbs", ":node_manager_rpc", ":ray_common", @@ -1472,6 +1499,7 @@ cc_library( copts = COPTS, strip_include_prefix = "src", deps = [ + ":agent_manager_rpc", ":hiredis", ":node_manager_fbs", ":node_manager_rpc", @@ -1811,9 +1839,11 @@ cc_binary( filegroup( name = "all_py_proto", srcs = [ + "//src/ray/protobuf:agent_manager_py_proto", "//src/ray/protobuf:common_py_proto", "//src/ray/protobuf:core_worker_py_proto", "//src/ray/protobuf:gcs_py_proto", + "//src/ray/protobuf:gcs_service_py_proto", "//src/ray/protobuf:node_manager_py_proto", "//src/ray/protobuf:reporter_py_proto", ], diff --git a/ci/travis/determine_tests_to_run.py b/ci/travis/determine_tests_to_run.py index be3997a3e..93dbf30c3 100644 --- a/ci/travis/determine_tests_to_run.py +++ b/ci/travis/determine_tests_to_run.py @@ -92,6 +92,8 @@ if __name__ == "__main__": RAY_CI_MACOS_WHEELS_AFFECTED = 1 elif changed_file.startswith("python/ray/dashboard"): RAY_CI_DASHBOARD_AFFECTED = 1 + elif changed_file.startswith("dashboard"): + RAY_CI_DASHBOARD_AFFECTED = 1 elif changed_file.startswith("python/"): RAY_CI_TUNE_AFFECTED = 1 RAY_CI_SGD_AFFECTED = 1 diff --git a/dashboard/BUILD b/dashboard/BUILD new file mode 100644 index 000000000..15ed537e6 --- /dev/null +++ b/dashboard/BUILD @@ -0,0 +1,13 @@ +# This is a dummy test dependency that causes the above tests to be +# re-run if any of these files changes. +py_library( + name = "dashboard_lib", + srcs = glob(["**/*.py"],exclude=["tests/*"]), +) + +py_test( + name = "test_dashboard", + size = "small", + srcs = glob(["tests/*.py"]), + deps = [":dashboard_lib"] +) diff --git a/dashboard/agent.py b/dashboard/agent.py index 818770f45..f0fce2f91 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -4,10 +4,14 @@ import logging import logging.handlers import os import sys +import socket +import json import traceback import aiohttp -import aioredis +import aiohttp.web +import aiohttp_cors +from aiohttp import hdrs from grpc.experimental import aio as aiogrpc import ray @@ -16,9 +20,12 @@ import ray.new_dashboard.utils as dashboard_utils import ray.ray_constants as ray_constants import ray.services import ray.utils +from ray.core.generated import agent_manager_pb2 +from ray.core.generated import agent_manager_pb2_grpc import psutil logger = logging.getLogger(__name__) +routes = dashboard_utils.ClassMethodRouteTable aiogrpc.init_grpc_aio() @@ -33,11 +40,8 @@ class DashboardAgent(object): object_store_name=None, raylet_name=None): """Initialize the DashboardAgent object.""" - self._agent_cls_list = dashboard_utils.get_all_modules( - dashboard_utils.DashboardAgentModule) - ip, port = redis_address.split(":") # Public attributes are accessible for all agent modules. - self.redis_address = (ip, int(port)) + self.redis_address = dashboard_utils.address_tuple(redis_address) self.redis_password = redis_password self.temp_dir = temp_dir self.log_dir = log_dir @@ -46,39 +50,29 @@ class DashboardAgent(object): self.raylet_name = raylet_name self.ip = ray.services.get_node_ip_address() self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) - listen_address = "[::]:0" - logger.info("Dashboard agent listen at: %s", listen_address) - self.port = self.server.add_insecure_port(listen_address) + self.grpc_port = self.server.add_insecure_port("[::]:0") + logger.info("Dashboard agent grpc address: %s:%s", self.ip, + self.grpc_port) self.aioredis_client = None self.aiogrpc_raylet_channel = aiogrpc.insecure_channel("{}:{}".format( self.ip, self.node_manager_port)) - self.http_session = aiohttp.ClientSession( - loop=asyncio.get_event_loop()) + self.http_session = None def _load_modules(self): """Load dashboard agent modules.""" modules = [] - for cls in self._agent_cls_list: - logger.info("Load %s: %s", + agent_cls_list = dashboard_utils.get_all_modules( + dashboard_utils.DashboardAgentModule) + for cls in agent_cls_list: + logger.info("Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls) c = cls(self) + dashboard_utils.ClassMethodRouteTable.bind(c) modules.append(c) - logger.info("Load {} modules.".format(len(modules))) + logger.info("Loaded {} modules.".format(len(modules))) return modules async def run(self): - # Create an aioredis client for all modules. - self.aioredis_client = await aioredis.create_redis_pool( - address=self.redis_address, password=self.redis_password) - - # Start a grpc asyncio server. - await self.server.start() - - # Write the dashboard agent port to redis. - await self.aioredis_client.set( - "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX, - self.ip), self.port) - async def _check_parent(): """Check if raylet is dead.""" curr_proc = psutil.Process() @@ -92,10 +86,83 @@ class DashboardAgent(object): dashboard_consts. DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS) + check_parent_task = asyncio.create_task(_check_parent()) + + # Create an aioredis client for all modules. + try: + self.aioredis_client = await dashboard_utils.get_aioredis_client( + self.redis_address, self.redis_password, + dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS, + dashboard_consts.RETRY_REDIS_CONNECTION_TIMES) + except (socket.gaierror, ConnectionRefusedError): + logger.error( + "Dashboard agent exiting: " + "Failed to connect to redis at %s", self.redis_address) + sys.exit(-1) + + # Create a http session for all modules. + self.http_session = aiohttp.ClientSession( + loop=asyncio.get_event_loop()) + + # Start a grpc asyncio server. + await self.server.start() + modules = self._load_modules() - await asyncio.gather(_check_parent(), + + # Http server should be initialized after all modules loaded. + app = aiohttp.web.Application() + app.add_routes(routes=routes.bound_routes()) + + # Enable CORS on all routes. + cors = aiohttp_cors.setup( + app, + defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_methods="*", + allow_headers=("Content-Type", "X-Header"), + ) + }) + for route in list(app.router.routes()): + cors.add(route) + + runner = aiohttp.web.AppRunner(app) + await runner.setup() + site = aiohttp.web.TCPSite(runner, self.ip, 0) + await site.start() + http_host, http_port = site._server.sockets[0].getsockname() + logger.info("Dashboard agent http address: %s:%s", http_host, + http_port) + + # Dump registered http routes. + dump_routes = [ + r for r in app.router.routes() if r.method != hdrs.METH_HEAD + ] + for r in dump_routes: + logger.info(r) + logger.info("Registered %s routes.", len(dump_routes)) + + # Write the dashboard agent port to redis. + await self.aioredis_client.set( + "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX, + self.ip), json.dumps([http_port, self.grpc_port])) + + # Register agent to agent manager. + raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub( + self.aiogrpc_raylet_channel) + + await raylet_stub.RegisterAgent( + agent_manager_pb2.RegisterAgentRequest( + agent_pid=os.getpid(), + agent_port=self.grpc_port, + agent_ip_address=self.ip)) + + await asyncio.gather(check_parent_task, *(m.run(self.server) for m in modules)) await self.server.wait_for_termination() + # Wait for finish signal. + await runner.cleanup() if __name__ == "__main__": @@ -215,8 +282,7 @@ if __name__ == "__main__": raylet_name=args.raylet_name) loop = asyncio.get_event_loop() - loop.create_task(agent.run()) - loop.run_forever() + loop.run_until_complete(agent.run()) except Exception as e: # Something went wrong, so push an error to all drivers. redis_client = ray.services.create_redis_client( diff --git a/dashboard/client/build/.fakebuild b/dashboard/client/build/.fakebuild new file mode 100644 index 000000000..e69de29bb diff --git a/dashboard/client/build/static/.fakestatic b/dashboard/client/build/static/.fakestatic new file mode 100644 index 000000000..e69de29bb diff --git a/dashboard/consts.py b/dashboard/consts.py index a03cb2a90..5b321800a 100644 --- a/dashboard/consts.py +++ b/dashboard/consts.py @@ -1,16 +1,20 @@ +DASHBOARD_LOG_FILENAME = "dashboard.log" DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:" DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log" DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2 MAX_COUNT_OF_GCS_RPC_ERROR = 10 +RETRY_REDIS_CONNECTION_TIMES = 10 UPDATE_NODES_INTERVAL_SECONDS = 5 CONNECT_GCS_INTERVAL_SECONDS = 2 +CONNECT_REDIS_INTERNAL_SECONDS = 2 PURGE_DATA_INTERVAL_SECONDS = 60 * 10 REDIS_KEY_DASHBOARD = "dashboard" +REDIS_KEY_DASHBOARD_RPC = "dashboard_rpc" REDIS_KEY_GCS_SERVER_ADDRESS = "GcsServerAddress" REPORT_METRICS_TIMEOUT_SECONDS = 2 REPORT_METRICS_INTERVAL_SECONDS = 10 # Named signals SIGNAL_NODE_INFO_FETCHED = "node_info_fetched" # Default param for RotatingFileHandler -LOGGING_ROTATE_BYTES = 100 * 1000 # maxBytes +LOGGING_ROTATE_BYTES = 100 * 1000 * 1000 # maxBytes LOGGING_ROTATE_BACKUP_COUNT = 5 # backupCount diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py index 10705427f..f8f19ef21 100644 --- a/dashboard/dashboard.py +++ b/dashboard/dashboard.py @@ -13,11 +13,7 @@ import logging import logging.handlers import os import traceback -import uuid -import aioredis - -import ray import ray.new_dashboard.consts as dashboard_consts import ray.new_dashboard.head as dashboard_head import ray.new_dashboard.utils as dashboard_utils @@ -32,7 +28,7 @@ logger = logging.getLogger(__name__) routes = dashboard_utils.ClassMethodRouteTable -def setup_static_dir(app): +def setup_static_dir(): build_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "client/build") module_name = os.path.basename(os.path.dirname(__file__)) @@ -47,7 +43,7 @@ def setup_static_dir(app): "&& npm run build)".format(module_name), build_dir) static_dir = os.path.join(build_dir, "static") - app.router.add_static("/static", static_dir, follow_symlinks=True) + routes.static("/static", static_dir, follow_symlinks=True) return build_dir @@ -62,29 +58,18 @@ class Dashboard: host(str): Host address of dashboard aiohttp server. port(int): Port number of dashboard aiohttp server. redis_address(str): GCS address of a Ray cluster - temp_dir (str): The temporary directory used for log files and - information for this Ray session. redis_password(str): Redis password to access GCS """ - def __init__(self, - host, - port, - redis_address, - temp_dir, - redis_password=None): - self.host = host - self.port = port - self.temp_dir = temp_dir - self.dashboard_id = str(uuid.uuid4()) + def __init__(self, host, port, redis_address, redis_password=None): self.dashboard_head = dashboard_head.DashboardHead( - redis_address=redis_address, redis_password=redis_password) - - self.app = aiohttp.web.Application() - self.app.add_routes(routes=routes.routes()) + http_host=host, + http_port=port, + redis_address=redis_address, + redis_password=redis_password) # Setup Dashboard Routes - build_dir = setup_static_dir(self.app) + build_dir = setup_static_dir() logger.info("Setup static dir for dashboard: %s", build_dir) dashboard_utils.ClassMethodRouteTable.bind(self) @@ -103,17 +88,7 @@ class Dashboard: "client/build/favicon.ico")) async def run(self): - coroutines = [ - self.dashboard_head.run(), - aiohttp.web._run_app(self.app, host=self.host, port=self.port) - ] - ip = ray.services.get_node_ip_address() - aioredis_client = await aioredis.create_redis_pool( - address=self.dashboard_head.redis_address, - password=self.dashboard_head.redis_password) - await aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD, - ip + ":" + str(self.port)) - await asyncio.gather(*coroutines) + await self.dashboard_head.run() if __name__ == "__main__": @@ -158,9 +133,10 @@ if __name__ == "__main__": "--logging-filename", required=False, type=str, - default="", + default=dashboard_consts.DASHBOARD_LOG_FILENAME, help="Specify the name of log file, " - "log to stdout if set empty, default is \"\"") + "log to stdout if set empty, default is \"{}\"".format( + dashboard_consts.DASHBOARD_LOG_FILENAME)) parser.add_argument( "--logging-rotate-bytes", required=False, @@ -221,7 +197,6 @@ if __name__ == "__main__": args.host, args.port, args.redis_address, - temp_dir, redis_password=args.redis_password) loop = asyncio.get_event_loop() loop.run_until_complete(dashboard.run()) diff --git a/dashboard/head.py b/dashboard/head.py index cb7cb2130..7b40ef013 100644 --- a/dashboard/head.py +++ b/dashboard/head.py @@ -1,9 +1,12 @@ import sys +import socket +import json import asyncio import logging import aiohttp -import aioredis +import aiohttp.web +from aiohttp import hdrs from grpc.experimental import aio as aiogrpc import ray.services @@ -25,22 +28,25 @@ def gcs_node_info_to_dict(message): class DashboardHead: - def __init__(self, redis_address, redis_password): - # Scan and import head modules for collecting http routes. - self._head_cls_list = dashboard_utils.get_all_modules( - dashboard_utils.DashboardHeadModule) - ip, port = redis_address.split(":") + def __init__(self, http_host, http_port, redis_address, redis_password): # NodeInfoGcsService self._gcs_node_info_stub = None self._gcs_rpc_error_counter = 0 # Public attributes are accessible for all head modules. - self.redis_address = (ip, int(port)) + self.http_host = http_host + self.http_port = http_port + self.redis_address = dashboard_utils.address_tuple(redis_address) self.redis_password = redis_password self.aioredis_client = None self.aiogrpc_gcs_channel = None - self.http_session = aiohttp.ClientSession( - loop=asyncio.get_event_loop()) + self.http_session = None self.ip = ray.services.get_node_ip_address() + self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), )) + self.grpc_port = self.server.add_insecure_port("[::]:0") + logger.info("Dashboard head grpc address: %s:%s", self.ip, + self.grpc_port) + logger.info("Dashboard head http address: %s:%s", self.http_host, + self.http_port) async def _get_nodes(self): """Read the client table. @@ -83,7 +89,7 @@ class DashboardHead: node_ip) agent_port = await self.aioredis_client.get(key) if agent_port: - agents[node_ip] = agent_port + agents[node_ip] = json.loads(agent_port) for ip in agents.keys() - set(node_ips): agents.pop(ip, None) @@ -112,18 +118,34 @@ class DashboardHead: def _load_modules(self): """Load dashboard head modules.""" modules = [] - for cls in self._head_cls_list: - logger.info("Load %s: %s", + head_cls_list = dashboard_utils.get_all_modules( + dashboard_utils.DashboardHeadModule) + for cls in head_cls_list: + logger.info("Loading %s: %s", dashboard_utils.DashboardHeadModule.__name__, cls) c = cls(self) dashboard_utils.ClassMethodRouteTable.bind(c) modules.append(c) + logger.info("Loaded {} modules.".format(len(modules))) return modules async def run(self): # Create an aioredis client for all modules. - self.aioredis_client = await aioredis.create_redis_pool( - address=self.redis_address, password=self.redis_password) + try: + self.aioredis_client = await dashboard_utils.get_aioredis_client( + self.redis_address, self.redis_password, + dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS, + dashboard_consts.RETRY_REDIS_CONNECTION_TIMES) + except (socket.gaierror, ConnectionError): + logger.error( + "Dashboard head exiting: " + "Failed to connect to redis at %s", self.redis_address) + sys.exit(-1) + + # Create a http session for all modules. + self.http_session = aiohttp.ClientSession( + loop=asyncio.get_event_loop()) + # Waiting for GCS is ready. while True: try: @@ -140,10 +162,21 @@ class DashboardHead: else: self.aiogrpc_gcs_channel = channel break + # Create a NodeInfoGcsServiceStub. self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub( self.aiogrpc_gcs_channel) + # Start a grpc asyncio server. + await self.server.start() + + # Write the dashboard head port to redis. + await self.aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD, + self.ip + ":" + str(self.http_port)) + await self.aioredis_client.set( + dashboard_consts.REDIS_KEY_DASHBOARD_RPC, + self.ip + ":" + str(self.grpc_port)) + async def _async_notify(): """Notify signals from queue.""" while True: @@ -164,7 +197,24 @@ class DashboardHead: logger.exception(e) modules = self._load_modules() + + # Http server should be initialized after all modules loaded. + app = aiohttp.web.Application() + app.add_routes(routes=routes.bound_routes()) + web_server = aiohttp.web._run_app( + app, host=self.http_host, port=self.http_port) + + # Dump registered http routes. + dump_routes = [ + r for r in app.router.routes() if r.method != hdrs.METH_HEAD + ] + for r in dump_routes: + logger.info(r) + logger.info("Registered %s routes.", len(dump_routes)) + # Freeze signal after all modules loaded. dashboard_utils.SignalManager.freeze() await asyncio.gather(self._update_nodes(), _async_notify(), - _purge_data(), *(m.run() for m in modules)) + _purge_data(), web_server, + *(m.run(self.server) for m in modules)) + await self.server.wait_for_termination() diff --git a/dashboard/modules/reporter/reporter_agent.py b/dashboard/modules/reporter/reporter_agent.py index 40d2d2491..4e3f7cd55 100644 --- a/dashboard/modules/reporter/reporter_agent.py +++ b/dashboard/modules/reporter/reporter_agent.py @@ -21,6 +21,13 @@ import psutil logger = logging.getLogger(__name__) +try: + import gpustat.core as gpustat +except ImportError: + gpustat = None + logger.warning( + "Install gpustat with 'pip install gpustat' to enable GPU monitoring.") + def recursive_asdict(o): if isinstance(o, tuple) and hasattr(o, "_asdict"): @@ -81,10 +88,35 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule, return reporter_pb2.GetProfilingStatsReply( profiling_stats=profiling_stats, stdout=stdout, stderr=stderr) + async def ReportMetrics(self, request, context): + # TODO(sang): Process metrics here. + return reporter_pb2.ReportMetricsReply() + @staticmethod def _get_cpu_percent(): return psutil.cpu_percent() + @staticmethod + def _get_gpu_usage(): + if gpustat is None: + return [] + gpu_utilizations = [] + gpus = [] + try: + gpus = gpustat.new_query().gpus + except Exception as e: + logger.debug( + "gpustat failed to retrieve GPU information: {}".format(e)) + for gpu in gpus: + # Note the keys in this dict have periods which throws + # off javascript so we change .s to _s + gpu_data = { + "_".join(key.split(".")): val + for key, val in gpu.entry.items() + } + gpu_utilizations.append(gpu_data) + return gpu_utilizations + @staticmethod def _get_boot_time(): return psutil.boot_time() @@ -173,6 +205,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule, "bootTime": self._get_boot_time(), "loadAvg": self._get_load_avg(), "disk": self._get_disk_usage(), + "gpus": self._get_gpu_usage(), "net": netstats, "cmdline": self._get_raylet_cmdline(), } diff --git a/dashboard/modules/reporter/reporter_head.py b/dashboard/modules/reporter/reporter_head.py index 959aadb41..fdd262a60 100644 --- a/dashboard/modules/reporter/reporter_head.py +++ b/dashboard/modules/reporter/reporter_head.py @@ -29,8 +29,8 @@ class ReportHead(dashboard_utils.DashboardHeadModule): async def _update_stubs(self, change): if change.new: - ip, port = next(iter(change.new.items())) - channel = aiogrpc.insecure_channel("{}:{}".format(ip, int(port))) + ip, ports = next(iter(change.new.items())) + channel = aiogrpc.insecure_channel("{}:{}".format(ip, ports[1])) stub = reporter_pb2_grpc.ReporterServiceStub(channel) self._stubs[ip] = stub if change.old: @@ -77,15 +77,15 @@ class ReportHead(dashboard_utils.DashboardHeadModule): message="Profiling info fetched.", profiling_info=json.loads(profiling_stats.profiling_stats)) - async def run(self): - p = self._dashboard_head.aioredis_client - mpsc = Receiver() + async def run(self, server): + aioredis_client = self._dashboard_head.aioredis_client + receiver = Receiver() reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX) - await p.psubscribe(mpsc.pattern(reporter_key)) + await aioredis_client.psubscribe(receiver.pattern(reporter_key)) logger.info("Subscribed to {}".format(reporter_key)) - async for sender, msg in mpsc.iter(): + async for sender, msg in receiver.iter(): try: _, data = msg data = json.loads(ray.utils.decode(data)) diff --git a/dashboard/modules/test/__init__.py b/dashboard/modules/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dashboard/modules/test/test_agent.py b/dashboard/modules/test/test_agent.py new file mode 100644 index 000000000..efe9c3285 --- /dev/null +++ b/dashboard/modules/test/test_agent.py @@ -0,0 +1,24 @@ +import logging + +import aiohttp.web + +import ray.new_dashboard.utils as dashboard_utils +import ray.new_dashboard.modules.test.test_utils as test_utils + +logger = logging.getLogger(__name__) +routes = dashboard_utils.ClassMethodRouteTable + + +class HeadAgent(dashboard_utils.DashboardAgentModule): + def __init__(self, dashboard_agent): + super().__init__(dashboard_agent) + + @routes.get("/test/http_get_from_agent") + async def get_url(self, req) -> aiohttp.web.Response: + url = req.query.get("url") + result = await test_utils.http_get(self._dashboard_agent.http_session, + url) + return aiohttp.web.json_response(result) + + async def run(self, server): + pass diff --git a/dashboard/modules/test/test_head.py b/dashboard/modules/test/test_head.py new file mode 100644 index 000000000..61e7635e0 --- /dev/null +++ b/dashboard/modules/test/test_head.py @@ -0,0 +1,62 @@ +import logging + +import aiohttp.web + +import ray.new_dashboard.utils as dashboard_utils +import ray.new_dashboard.modules.test.test_utils as test_utils +from ray.new_dashboard.datacenter import DataSource + +logger = logging.getLogger(__name__) +routes = dashboard_utils.ClassMethodRouteTable + + +class TestHead(dashboard_utils.DashboardHeadModule): + def __init__(self, dashboard_head): + super().__init__(dashboard_head) + self._notified_agents = {} + DataSource.agents.signal.append(self._update_notified_agents) + + async def _update_notified_agents(self, change): + if change.new: + ip, ports = next(iter(change.new.items())) + self._notified_agents[ip] = ports + if change.old: + ip, port = next(iter(change.old.items())) + self._notified_agents.pop(ip) + + @routes.get("/test/dump") + async def dump(self, req) -> aiohttp.web.Response: + key = req.query.get("key") + if key is None: + all_data = { + k: dict(v) + for k, v in DataSource.__dict__.items() + if not k.startswith("_") + } + return await dashboard_utils.rest_response( + success=True, + message="Fetch all data from datacenter success.", + **all_data) + else: + data = dict(DataSource.__dict__.get(key)) + return await dashboard_utils.rest_response( + success=True, + message="Fetch {} from datacenter success.".format(key), + **{key: data}) + + @routes.get("/test/notified_agents") + async def get_notified_agents(self, req) -> aiohttp.web.Response: + return await dashboard_utils.rest_response( + success=True, + message="Fetch notified agents success.", + **self._notified_agents) + + @routes.get("/test/http_get") + async def get_url(self, req) -> aiohttp.web.Response: + url = req.query.get("url") + result = await test_utils.http_get(self._dashboard_head.http_session, + url) + return aiohttp.web.json_response(result) + + async def run(self, server): + pass diff --git a/dashboard/modules/test/test_utils.py b/dashboard/modules/test/test_utils.py new file mode 100644 index 000000000..5315b05e1 --- /dev/null +++ b/dashboard/modules/test/test_utils.py @@ -0,0 +1,11 @@ +import logging + +import async_timeout + +logger = logging.getLogger(__name__) + + +async def http_get(http_session, url, timeout_seconds=60): + with async_timeout.timeout(timeout_seconds): + async with http_session.get(url) as response: + return await response.json() diff --git a/dashboard/tests/conftest.py b/dashboard/tests/conftest.py new file mode 100644 index 000000000..a60ce1007 --- /dev/null +++ b/dashboard/tests/conftest.py @@ -0,0 +1 @@ +from ray.tests.conftest import * # noqa diff --git a/dashboard/tests/test_dashboard.py b/dashboard/tests/test_dashboard.py new file mode 100644 index 000000000..325044ad7 --- /dev/null +++ b/dashboard/tests/test_dashboard.py @@ -0,0 +1,224 @@ +import os +import json +import time +import logging + +import ray +import psutil +import pytest +import redis +import requests + +from ray import ray_constants +from ray.test_utils import wait_for_condition, wait_until_server_available +import ray.new_dashboard.consts as dashboard_consts +import ray.new_dashboard.modules + +os.environ["RAY_USE_NEW_DASHBOARD"] = "1" + +logger = logging.getLogger(__name__) + + +def cleanup_test_files(): + module_path = ray.new_dashboard.modules.__path__[0] + filename = os.path.join(module_path, "test_for_bad_import.py") + logger.info("Remove test file: %s", filename) + try: + os.remove(filename) + except Exception: + pass + + +def prepare_test_files(): + module_path = ray.new_dashboard.modules.__path__[0] + filename = os.path.join(module_path, "test_for_bad_import.py") + logger.info("Prepare test file: %s", filename) + with open(filename, "w") as f: + f.write(">>>") + + +cleanup_test_files() + + +@pytest.mark.parametrize( + "ray_start_with_dashboard", [{ + "_internal_config": json.dumps({ + "agent_register_timeout_ms": 5000 + }) + }], + indirect=True) +def test_basic(ray_start_with_dashboard): + """Dashboard test that starts a Ray cluster with a dashboard server running, + then hits the dashboard API and asserts that it receives sensible data.""" + assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]) + is True) + address_info = ray_start_with_dashboard + address = address_info["redis_address"] + address = address.split(":") + assert len(address) == 2 + + client = redis.StrictRedis( + host=address[0], + port=int(address[1]), + password=ray_constants.REDIS_DEFAULT_PASSWORD) + + all_processes = ray.worker._global_node.all_processes + assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes + assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes + dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][ + 0] + dashboard_proc = psutil.Process(dashboard_proc_info.process.pid) + assert dashboard_proc.status() == psutil.STATUS_RUNNING + raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0] + raylet_proc = psutil.Process(raylet_proc_info.process.pid) + + def _search_agent(processes): + for p in processes: + try: + for c in p.cmdline(): + if "new_dashboard/agent.py" in c: + return p + except Exception: + pass + + # Test for bad imports, the agent should be restarted. + logger.info("Test for bad imports.") + agent_proc = _search_agent(raylet_proc.children()) + prepare_test_files() + agent_pids = set() + try: + assert agent_proc is not None + agent_proc.kill() + agent_proc.wait() + # The agent will be restarted for imports failure. + for x in range(40): + agent_proc = _search_agent(raylet_proc.children()) + if agent_proc: + agent_pids.add(agent_proc.pid) + time.sleep(0.1) + finally: + cleanup_test_files() + assert len(agent_pids) > 1, agent_pids + + agent_proc = _search_agent(raylet_proc.children()) + if agent_proc: + agent_proc.kill() + agent_proc.wait() + + logger.info("Test agent register is OK.") + wait_for_condition(lambda: _search_agent(raylet_proc.children())) + assert dashboard_proc.status() == psutil.STATUS_RUNNING + agent_proc = _search_agent(raylet_proc.children()) + agent_pid = agent_proc.pid + + # Check if agent register is OK. + for x in range(5): + logger.info("Check agent is alive.") + agent_proc = _search_agent(raylet_proc.children()) + assert agent_proc.pid == agent_pid + time.sleep(1) + + # Check redis keys are set. + logger.info("Check redis keys are set.") + dashboard_address = client.get(dashboard_consts.REDIS_KEY_DASHBOARD) + assert dashboard_address is not None + dashboard_rpc_address = client.get( + dashboard_consts.REDIS_KEY_DASHBOARD_RPC) + assert dashboard_rpc_address is not None + key = "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX, + address[0]) + agent_ports = client.get(key) + assert agent_ports is not None + + +def test_nodes_update(ray_start_with_dashboard): + assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]) + is True) + webui_url = ray_start_with_dashboard["webui_url"] + webui_url = webui_url.replace("localhost", "http://127.0.0.1") + + timeout_seconds = 20 + start_time = time.time() + while True: + time.sleep(1) + try: + response = requests.get(webui_url + "/test/dump") + response.raise_for_status() + try: + dump_info = response.json() + except Exception as ex: + logger.info("failed response: {}".format(response.text)) + raise ex + assert dump_info["result"] is True + dump_data = dump_info["data"] + assert len(dump_data["nodes"]) == 1 + assert len(dump_data["agents"]) == 1 + assert len(dump_data["hostnameToIp"]) == 1 + assert len(dump_data["ipToHostname"]) == 1 + assert dump_data["nodes"].keys() == dump_data[ + "ipToHostname"].keys() + + response = requests.get(webui_url + "/test/notified_agents") + response.raise_for_status() + try: + notified_agents = response.json() + except Exception as ex: + logger.info("failed response: {}".format(response.text)) + raise ex + assert notified_agents["result"] is True + notified_agents = notified_agents["data"] + assert len(notified_agents) == 1 + assert notified_agents == dump_data["agents"] + break + except (AssertionError, requests.exceptions.ConnectionError) as e: + logger.info("Retry because of %s", e) + finally: + if time.time() > start_time + timeout_seconds: + raise Exception( + "Timed out while waiting for dashboard to start.") + + +def test_http_get(ray_start_with_dashboard): + assert (wait_until_server_available(ray_start_with_dashboard["webui_url"]) + is True) + webui_url = ray_start_with_dashboard["webui_url"] + webui_url = webui_url.replace("localhost", "http://127.0.0.1") + + target_url = webui_url + "/test/dump" + + timeout_seconds = 20 + start_time = time.time() + while True: + time.sleep(1) + try: + response = requests.get(webui_url + "/test/http_get?url=" + + target_url) + response.raise_for_status() + try: + dump_info = response.json() + except Exception as ex: + logger.info("failed response: {}".format(response.text)) + raise ex + assert dump_info["result"] is True + dump_data = dump_info["data"] + assert len(dump_data["agents"]) == 1 + ip, ports = next(iter(dump_data["agents"].items())) + http_port, grpc_port = ports + + response = requests.get( + "http://{}:{}/test/http_get_from_agent?url={}".format( + ip, http_port, target_url)) + response.raise_for_status() + try: + dump_info = response.json() + except Exception as ex: + logger.info("failed response: {}".format(response.text)) + raise ex + assert dump_info["result"] is True + break + except (AssertionError, requests.exceptions.ConnectionError) as e: + logger.info("Retry because of %s", e) + finally: + if time.time() > start_time + timeout_seconds: + raise Exception( + "Timed out while waiting for dashboard to start.") diff --git a/dashboard/utils.py b/dashboard/utils.py index e8e546f6d..054fe9fd4 100644 --- a/dashboard/utils.py +++ b/dashboard/utils.py @@ -1,4 +1,5 @@ import abc +import socket import asyncio import collections import copy @@ -12,10 +13,14 @@ import pkgutil import traceback from base64 import b64decode from collections.abc import MutableMapping, Mapping +from typing import Any +import aioredis import aiohttp.web from aiohttp import hdrs from aiohttp.frozenlist import FrozenList +from aiohttp.typedefs import PathLike +from aiohttp.web import RouteDef import aiohttp.signals from google.protobuf.json_format import MessageToDict from ray.utils import binary_to_hex @@ -52,9 +57,12 @@ class DashboardHeadModule(abc.ABC): self._dashboard_head = dashboard_head @abc.abstractmethod - async def run(self): + async def run(self, server): """ - Run the module in an asyncio loop. + Run the module in an asyncio loop. A head module can provide + servicers to the server. + + :param server: Asyncio GRPC server. """ @@ -74,6 +82,22 @@ class ClassMethodRouteTable: def routes(cls): return cls._routes + @classmethod + def bound_routes(cls): + bound_items = [] + for r in cls._routes._items: + if isinstance(r, RouteDef): + route_method = getattr(r.handler, "__route_method__") + route_path = getattr(r.handler, "__route_path__") + instance = cls._bind_map[route_method][route_path].instance + if instance is not None: + bound_items.append(r) + else: + bound_items.append(r) + routes = aiohttp.web.RouteTableDef() + routes._items = bound_items + return routes + @classmethod def _register_route(cls, method, path, **kwargs): def _wrapper(handler): @@ -132,6 +156,10 @@ class ClassMethodRouteTable: def view(cls, path, **kwargs): return cls._register_route(hdrs.METH_ANY, path, **kwargs) + @classmethod + def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None: + cls._routes.static(prefix, path, **kwargs) + @classmethod def bind(cls, instance): def predicate(o): @@ -161,6 +189,13 @@ def to_posix_time(dt): return (dt - datetime.datetime(1970, 1, 1)).total_seconds() +def address_tuple(address): + if isinstance(address, tuple): + return address + ip, port = address.split(":") + return ip, int(port) + + class CustomEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, bytes): @@ -338,3 +373,17 @@ class Dict(MutableMapping): for key in self._data.keys() - d.keys(): self.pop(key) self.update(d) + + +async def get_aioredis_client(redis_address, redis_password, + retry_interval_seconds, retry_times): + for x in range(retry_times): + try: + return await aioredis.create_redis_pool( + address=redis_address, password=redis_password) + except (socket.gaierror, ConnectionError) as ex: + logger.error("Connect to Redis failed: %s, retry...", ex) + await asyncio.sleep(retry_interval_seconds) + # Raise exception from create_redis_pool + return await aioredis.create_redis_pool( + address=redis_address, password=redis_password) diff --git a/python/ray/node.py b/python/ray/node.py index 3bf9757c9..553b9e4c9 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -615,8 +615,11 @@ class Node: if we fail to start the dashboard. Otherwise it will print a warning if we fail to start the dashboard. """ - stdout_file, stderr_file = self.get_log_file_handles( - "dashboard", unique=True) + if "RAY_USE_NEW_DASHBOARD" in os.environ: + stdout_file, stderr_file = None, None + else: + stdout_file, stderr_file = self.get_log_file_handles( + "dashboard", unique=True) self._webui_url, process_info = ray.services.start_dashboard( require_dashboard, self._ray_params.dashboard_host, @@ -797,7 +800,8 @@ class Node: self.start_plasma_store() self.start_raylet() - self.start_reporter() + if "RAY_USE_NEW_DASHBOARD" not in os.environ: + self.start_reporter() if self._ray_params.include_log_monitor: self.start_log_monitor() diff --git a/python/ray/services.py b/python/ray/services.py index 9c4e2552b..7acbbf075 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1160,8 +1160,14 @@ def start_dashboard(require_dashboard, raise ValueError( f"The given dashboard port {port} is already in use") + if "RAY_USE_NEW_DASHBOARD" in os.environ: + dashboard_dir = "new_dashboard" + else: + dashboard_dir = "dashboard" + dashboard_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py") + os.path.dirname(os.path.abspath(__file__)), dashboard_dir, + "dashboard.py") command = [ sys.executable, "-u", @@ -1398,6 +1404,23 @@ def start_raylet(redis_address, start_worker_command.append( f"--object-spilling-config={json.dumps(object_spilling_config)}") + # Create agent command + agent_command = [ + sys.executable, + "-u", + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "new_dashboard/agent.py"), + "--redis-address={}".format(redis_address), + "--node-manager-port={}".format(node_manager_port), + "--object-store-name={}".format(plasma_store_name), + "--raylet-name={}".format(raylet_name), + "--temp-dir={}".format(temp_dir), + ] + + if redis_password is not None and len(redis_password) != 0: + agent_command.append("--redis-password={}".format(redis_password)) + command = [ RAYLET_EXECUTABLE, f"--raylet_socket_name={raylet_name}", @@ -1424,6 +1447,9 @@ def start_raylet(redis_address, if start_initial_python_workers_for_first_job: command.append("--num_initial_python_workers_for_first_job={}".format( resource_spec.num_cpus)) + if "RAY_USE_NEW_DASHBOARD" in os.environ: + command.append("--agent_command={}".format( + subprocess.list2cmdline(agent_command))) if config.get("plasma_store_as_thread"): # command related to the plasma store plasma_directory, object_store_memory = determine_plasma_store_config( diff --git a/python/setup.py b/python/setup.py index 2fbded274..d2a7c2872 100644 --- a/python/setup.py +++ b/python/setup.py @@ -133,6 +133,7 @@ extras["all"] = list(set(chain.from_iterable(extras.values()))) # the change in the matching section of requirements.txt install_requires = [ "aiohttp", + "aiohttp_cors", "aioredis", "click >= 7.0", "colorama", @@ -408,7 +409,7 @@ def api_main(program, *args): nonlocal result if excinfo[1].errno != errno.ENOENT: msg = excinfo[1].strerror - logger.error("cannot remove {}: {}" % (path, msg)) + logger.error("cannot remove {}: {}".format(path, msg)) result = 1 for subdir in CLEANABLE_SUBDIRS: diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 6b3629710..9e68c1e79 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -328,6 +328,12 @@ RAY_CONFIG(bool, put_small_object_in_memory_store, false) /// pipelining task submission. RAY_CONFIG(uint32_t, max_tasks_in_flight_per_worker, 1) +/// Interval to restart dashboard agent after the process exit. +RAY_CONFIG(uint32_t, agent_restart_interval_ms, 1000) + +/// Wait timeout for dashboard agent register. +RAY_CONFIG(uint32_t, agent_register_timeout_ms, 30 * 1000) + /// The maximum number of resource shapes included in the resource /// load reported by each raylet. RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100) diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index 8f7618401..5b82bc84c 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -249,9 +249,9 @@ void GcsActorScheduler::LeaseWorkerFromNode(std::shared_ptr actor, void GcsActorScheduler::RetryLeasingWorkerFromNode( std::shared_ptr actor, std::shared_ptr node) { - execute_after(io_context_, - [this, node, actor] { DoRetryLeasingWorkerFromNode(actor, node); }, - RayConfig::instance().gcs_lease_worker_retry_interval_ms()); + RAY_UNUSED(execute_after( + io_context_, [this, node, actor] { DoRetryLeasingWorkerFromNode(actor, node); }, + RayConfig::instance().gcs_lease_worker_retry_interval_ms())); } void GcsActorScheduler::DoRetryLeasingWorkerFromNode( @@ -370,9 +370,9 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr actor, void GcsActorScheduler::RetryCreatingActorOnWorker( std::shared_ptr actor, std::shared_ptr worker) { - execute_after(io_context_, - [this, actor, worker] { DoRetryCreatingActorOnWorker(actor, worker); }, - RayConfig::instance().gcs_create_actor_retry_interval_ms()); + RAY_UNUSED(execute_after( + io_context_, [this, actor, worker] { DoRetryCreatingActorOnWorker(actor, worker); }, + RayConfig::instance().gcs_create_actor_retry_interval_ms())); } void GcsActorScheduler::DoRetryCreatingActorOnWorker( diff --git a/src/ray/protobuf/BUILD b/src/ray/protobuf/BUILD index 1dc7f0eb0..9efff0f4c 100644 --- a/src/ray/protobuf/BUILD +++ b/src/ray/protobuf/BUILD @@ -86,6 +86,11 @@ cc_proto_library( deps = [":gcs_service_proto"], ) +python_grpc_compile( + name = "gcs_service_py_proto", + deps = [":gcs_service_proto"], +) + proto_library( name = "object_manager_proto", srcs = ["object_manager.proto"], @@ -132,3 +137,21 @@ cc_proto_library( name = "event_cc_proto", deps = [":event_proto"], ) + +# Agent manager gRPC lib. +proto_library( + name = "agent_manager_proto", + srcs = ["agent_manager.proto"], + deps = [], +) + +python_grpc_compile( + name = "agent_manager_py_proto", + deps = [":agent_manager_proto"], +) + +cc_proto_library( + name = "agent_manager_cc_proto", + deps = [":agent_manager_proto"], +) + diff --git a/src/ray/protobuf/agent_manager.proto b/src/ray/protobuf/agent_manager.proto new file mode 100644 index 000000000..f573f5376 --- /dev/null +++ b/src/ray/protobuf/agent_manager.proto @@ -0,0 +1,38 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package ray.rpc; + +enum AgentRpcStatus { + // OK. + AGENT_RPC_STATUS_OK = 0; + // Failed. + AGENT_RPC_STATUS_FAILED = 1; +} + +message RegisterAgentRequest { + int32 agent_pid = 1; + int32 agent_port = 2; + string agent_ip_address = 3; +} + +message RegisterAgentReply { + AgentRpcStatus status = 1; +} + +service AgentManagerService { + rpc RegisterAgent(RegisterAgentRequest) returns (RegisterAgentReply); +} diff --git a/src/ray/raylet/agent_manager.cc b/src/ray/raylet/agent_manager.cc new file mode 100644 index 000000000..af4c482d3 --- /dev/null +++ b/src/ray/raylet/agent_manager.cc @@ -0,0 +1,96 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/raylet/agent_manager.h" + +#include + +#include "ray/common/ray_config.h" +#include "ray/util/logging.h" +#include "ray/util/process.h" + +namespace ray { +namespace raylet { + +void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request, + rpc::RegisterAgentReply *reply, + rpc::SendReplyCallback send_reply_callback) { + agent_ip_address_ = request.agent_ip_address(); + agent_port_ = request.agent_port(); + agent_pid_ = request.agent_pid(); + RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << agent_ip_address_ + << ", port: " << agent_port_ << ", pid: " << agent_pid_; + reply->set_status(rpc::AGENT_RPC_STATUS_OK); + send_reply_callback(ray::Status::OK(), nullptr, nullptr); +} + +void AgentManager::StartAgent() { + if (options_.agent_commands.empty()) { + RAY_LOG(INFO) << "Not starting agent, the agent command is empty."; + return; + } + + if (RAY_LOG_ENABLED(DEBUG)) { + std::stringstream stream; + stream << "Starting agent process with command:"; + for (const auto &arg : options_.agent_commands) { + stream << " " << arg; + } + RAY_LOG(DEBUG) << stream.str(); + } + + // Launch the process to create the agent. + std::error_code ec; + std::vector argv; + for (const std::string &arg : options_.agent_commands) { + argv.push_back(arg.c_str()); + } + argv.push_back(NULL); + Process child(argv.data(), nullptr, ec); + if (!child.IsValid() || ec) { + // The worker failed to start. This is a fatal error. + RAY_LOG(FATAL) << "Failed to start agent with return value " << ec << ": " + << ec.message(); + RAY_UNUSED(delay_executor_([this] { StartAgent(); }, + RayConfig::instance().agent_restart_interval_ms())); + return; + } + + std::thread monitor_thread([this, child]() mutable { + RAY_LOG(INFO) << "Monitor agent process with pid " << child.GetId() + << ", register timeout " + << RayConfig::instance().agent_register_timeout_ms() << "ms."; + auto timer = delay_executor_( + [this, child]() mutable { + if (agent_pid_ != child.GetId()) { + RAY_LOG(WARNING) << "Agent process with pid " << child.GetId() + << " has not registered, restart it."; + child.Kill(); + } + }, + RayConfig::instance().agent_register_timeout_ms()); + + int exit_code = child.Wait(); + timer->cancel(); + + RAY_LOG(WARNING) << "Agent process with pid " << child.GetId() + << " exit, return value " << exit_code; + RAY_UNUSED(delay_executor_([this] { StartAgent(); }, + RayConfig::instance().agent_restart_interval_ms())); + }); + monitor_thread.detach(); +} + +} // namespace raylet +} // namespace ray diff --git a/src/ray/raylet/agent_manager.h b/src/ray/raylet/agent_manager.h new file mode 100644 index 000000000..29f19d4c6 --- /dev/null +++ b/src/ray/raylet/agent_manager.h @@ -0,0 +1,75 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "ray/rpc/agent_manager/agent_manager_client.h" +#include "ray/rpc/agent_manager/agent_manager_server.h" +#include "ray/util/process.h" + +namespace ray { +namespace raylet { + +typedef std::function(std::function, + uint32_t delay_ms)> + DelayExecutorFn; + +class AgentManager : public rpc::AgentManagerServiceHandler { + public: + struct Options { + std::vector agent_commands; + }; + + explicit AgentManager(Options options, DelayExecutorFn delay_executor) + : options_(std::move(options)), delay_executor_(std::move(delay_executor)) { + StartAgent(); + } + + void HandleRegisterAgent(const rpc::RegisterAgentRequest &request, + rpc::RegisterAgentReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + private: + void StartAgent(); + + private: + Options options_; + pid_t agent_pid_ = 0; + int agent_port_ = 0; + std::string agent_ip_address_; + DelayExecutorFn delay_executor_; +}; + +class DefaultAgentManagerServiceHandler : public rpc::AgentManagerServiceHandler { + public: + explicit DefaultAgentManagerServiceHandler(std::unique_ptr &delegate) + : delegate_(delegate) {} + + void HandleRegisterAgent(const rpc::RegisterAgentRequest &request, + rpc::RegisterAgentReply *reply, + rpc::SendReplyCallback send_reply_callback) override { + RAY_CHECK(delegate_ != nullptr); + delegate_->HandleRegisterAgent(request, reply, send_reply_callback); + } + + private: + std::unique_ptr &delegate_; +}; + +} // namespace raylet +} // namespace ray diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index ac66bf2c4..16e07ba92 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -44,6 +44,7 @@ DEFINE_string(static_resource_list, "", "The static resource list of this node." DEFINE_string(config_list, "", "The raylet config list of this node."); DEFINE_string(python_worker_command, "", "Python worker command."); DEFINE_string(java_worker_command, "", "Java worker command."); +DEFINE_string(agent_command, "", "Dashboard agent command."); DEFINE_string(redis_password, "", "The password of redis."); DEFINE_string(temp_dir, "", "Temporary directory."); DEFINE_string(session_dir, "", "The path of this ray session directory."); @@ -82,6 +83,7 @@ int main(int argc, char *argv[]) { const std::string config_list = FLAGS_config_list; const std::string python_worker_command = FLAGS_python_worker_command; const std::string java_worker_command = FLAGS_java_worker_command; + const std::string agent_command = FLAGS_agent_command; const std::string redis_password = FLAGS_redis_password; const std::string temp_dir = FLAGS_temp_dir; const std::string session_dir = FLAGS_session_dir; @@ -184,6 +186,11 @@ int main(int argc, char *argv[]) { RAY_CHECK(0) << "Either Python worker command or Java worker command should be " "provided."; } + if (!agent_command.empty()) { + node_manager_config.agent_command = agent_command; + } else { + RAY_LOG(DEBUG) << "Agent command is empty."; + } node_manager_config.heartbeat_period_ms = RayConfig::instance().raylet_heartbeat_timeout_milliseconds(); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 6867b5fd2..094236e91 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -25,6 +25,7 @@ #include "ray/gcs/pb_util.h" #include "ray/raylet/format/node_manager_generated.h" #include "ray/stats/stats.h" +#include "ray/util/asio_util.h" #include "ray/util/sample.h" namespace { @@ -164,6 +165,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, actor_registry_(), node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), + agent_manager_service_handler_( + new DefaultAgentManagerServiceHandler(agent_manager_)), + agent_manager_service_(io_service, *agent_manager_service_handler_), client_call_manager_(io_service), new_scheduler_enabled_(RayConfig::instance().new_scheduler_enabled()) { RAY_LOG(INFO) << "Initializing NodeManager with ID " << self_node_id_; @@ -208,8 +212,18 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manger rpc server. node_manager_server_.RegisterService(node_manager_service_); + node_manager_server_.RegisterService(agent_manager_service_); node_manager_server_.Run(); + AgentManager::Options options; + options.agent_commands = ParseCommandLine(config.agent_command); + agent_manager_.reset( + new AgentManager(std::move(options), + /*delay_executor=*/ + [this](std::function task, uint32_t delay_ms) { + return execute_after(io_service_, task, delay_ms); + })); + RAY_CHECK_OK(SetupPlasmaSubscription()); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 612d7004b..9186df3d0 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -27,6 +27,7 @@ #include "ray/common/task/scheduling_resources.h" #include "ray/object_manager/object_manager.h" #include "ray/raylet/actor_registration.h" +#include "ray/raylet/agent_manager.h" #include "ray/raylet/scheduling/scheduling_ids.h" #include "ray/raylet/scheduling/cluster_resource_scheduler.h" #include "ray/raylet/scheduling/cluster_task_manager.h" @@ -73,6 +74,8 @@ struct NodeManagerConfig { int maximum_startup_concurrency; /// The commands used to start the worker process, grouped by language. WorkerCommandMap worker_commands; + /// The command used to start agent. + std::string agent_command; /// The time between heartbeats in milliseconds. uint64_t heartbeat_period_ms; /// The time between debug dumps in milliseconds, or -1 to disable. @@ -721,12 +724,18 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// restore the actor. std::unordered_map checkpoint_id_to_restore_; + std::unique_ptr agent_manager_; + /// The RPC server. rpc::GrpcServer node_manager_server_; /// The node manager RPC service. rpc::NodeManagerGrpcService node_manager_service_; + /// The agent manager RPC service. + std::unique_ptr agent_manager_service_handler_; + rpc::AgentManagerGrpcService agent_manager_service_; + /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s /// as well as all `CoreWorkerClient`s. rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/rpc/agent_manager/agent_manager_client.h b/src/ray/rpc/agent_manager/agent_manager_client.h new file mode 100644 index 000000000..42a88d492 --- /dev/null +++ b/src/ray/rpc/agent_manager/agent_manager_client.h @@ -0,0 +1,50 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "ray/rpc/client_call.h" +#include "ray/rpc/grpc_client.h" +#include "src/ray/protobuf/agent_manager.grpc.pb.h" + +namespace ray { +namespace rpc { + +/// Client used for communicating with a remote agent manager server. +class AgentManagerClient { + public: + /// Constructor. + /// + /// \param[in] address Address of the agent manager server. + /// \param[in] port Port of the agent manager server. + /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. + AgentManagerClient(const std::string &address, const int port, + ClientCallManager &client_call_manager) { + grpc_client_ = std::unique_ptr>( + new GrpcClient(address, port, client_call_manager)); + }; + + /// Register agent service to the agent manager server + /// + /// \param request The request message + /// \param callback The callback function that handles reply + VOID_RPC_CLIENT_METHOD(AgentManagerService, RegisterAgent, grpc_client_, ) + + private: + /// The RPC client. + std::unique_ptr> grpc_client_; +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/agent_manager/agent_manager_server.h b/src/ray/rpc/agent_manager/agent_manager_server.h new file mode 100644 index 000000000..479c429f6 --- /dev/null +++ b/src/ray/rpc/agent_manager/agent_manager_server.h @@ -0,0 +1,73 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/server_call.h" +#include "src/ray/protobuf/agent_manager.grpc.pb.h" +#include "src/ray/protobuf/agent_manager.pb.h" + +namespace ray { +namespace rpc { + +#define RAY_AGENT_MANAGER_RPC_HANDLERS \ + RPC_SERVICE_HANDLER(AgentManagerService, RegisterAgent) + +/// Implementations of the `AgentManagerGrpcService`, check interface in +/// `src/ray/protobuf/agent_manager.proto`. +class AgentManagerServiceHandler { + public: + virtual ~AgentManagerServiceHandler() = default; + /// Handle a `RegisterAgent` request. + /// The implementation can handle this request asynchronously. When handling is done, + /// the `send_reply_callback` should be called. + /// + /// \param[in] request The request message. + /// \param[out] reply The reply message. + /// \param[in] send_reply_callback The callback to be called when the request is done. + virtual void HandleRegisterAgent(const RegisterAgentRequest &request, + RegisterAgentReply *reply, + SendReplyCallback send_reply_callback) = 0; +}; + +/// The `GrpcService` for `AgentManagerGrpcService`. +class AgentManagerGrpcService : public GrpcService { + public: + /// Construct a `AgentManagerGrpcService`. + /// + /// \param[in] port See `GrpcService`. + /// \param[in] handler The service handler that actually handle the requests. + AgentManagerGrpcService(boost::asio::io_service &io_service, + AgentManagerServiceHandler &service_handler) + : GrpcService(io_service), service_handler_(service_handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector> *server_call_factories) override { + RAY_AGENT_MANAGER_RPC_HANDLERS + } + + private: + /// The grpc async service object. + AgentManagerService::AsyncService service_; + /// The service handler that actually handle the requests. + AgentManagerServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/util/asio_util.h b/src/ray/util/asio_util.h index 1e50f8145..7affcca3d 100644 --- a/src/ray/util/asio_util.h +++ b/src/ray/util/asio_util.h @@ -16,8 +16,9 @@ #include -inline void execute_after(boost::asio::io_context &io_context, - const std::function &fn, uint32_t delay_milliseconds) { +inline std::shared_ptr execute_after( + boost::asio::io_context &io_context, const std::function &fn, + uint32_t delay_milliseconds) { auto timer = std::make_shared(io_context); timer->expires_from_now(boost::posix_time::milliseconds(delay_milliseconds)); timer->async_wait([timer, fn](const boost::system::error_code &error) { @@ -25,4 +26,5 @@ inline void execute_after(boost::asio::io_context &io_context, fn(); } }); + return timer; } diff --git a/src/ray/util/util.cc b/src/ray/util/util.cc index 09de3ed35..bf13fbbd5 100644 --- a/src/ray/util/util.cc +++ b/src/ray/util/util.cc @@ -187,6 +187,11 @@ static std::vector ParsePosixCommandLine(const std::string &s) { /// Python analog: None (would be shlex.split(s, posix=False), but it doesn't unquote) static std::vector ParseWindowsCommandLine(const std::string &s) { RAY_CHECK(s.find('\0') >= s.size()) << "Invalid null character in command line"; + // The if statement below may be incorrect. See: + // https://github.com/ray-project/ray/pull/10131#discussion_r473871563 + if (s.empty()) { + return {}; + } std::vector result; std::string arg, c_str = s + '\0'; std::string::const_iterator i = c_str.begin(), j = c_str.end() - 1;