[Dashboard] Start the new dashboard (#10131)

* Use new dashboard if environment var RAY_USE_NEW_DASHBOARD exists; new dashboard startup

* Make fake client/build/static directory for dashboard

* Add test_dashboard.py for new dashboard

* Travis CI enable new dashboard test

* Update new dashboard

* Agent manager service

* Add agent manager

* Register agent to agent manager

* Add a new line to the end of agent_manager.cc

* Fix merge; Fix lint

* Update dashboard/agent.py

Co-authored-by: SangBin Cho <rkooo567@gmail.com>

* Update dashboard/head.py

Co-authored-by: SangBin Cho <rkooo567@gmail.com>

* Fix bug

* Add tests for dashboard

* Fix

* Remove const from Process::Kill() & Fix bugs

* Revert error check of execute_after

* Raise exception from DashboardAgent.run

* Add more tests.

* Fix compile on Linux

* Use dict comprehension instead of dict(generator)

* Fix lint

* Fix windows compile

* Fix lint

* Test Windows CI

* Revert "Test Windows CI"

This reverts commit 945e01051ec95cff5fcc1c0bc37045b46e7ad9a6.

* Fix ParseWindowsCommandLine bug

* Update src/ray/util/util.cc

Co-authored-by: Robert Nishihara <robertnishihara@gmail.com>

Co-authored-by: 刘宝 <po.lb@antfin.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Robert Nishihara <robertnishihara@gmail.com>
This commit is contained in:
fyrestone
2020-08-25 04:24:23 +08:00
committed by GitHub
parent 832f5cdccb
commit 05c103af94
35 changed files with 1079 additions and 103 deletions
+3
View File
@@ -353,6 +353,9 @@ script:
# ray dashboard tests
- if [ "$RAY_CI_DASHBOARD_AFFECTED" == "1" ]; then ./ci/keep_alive bazel test python/ray/dashboard/...; fi
# ray new dashboard tests
- if [ "$RAY_CI_DASHBOARD_AFFECTED" == "1" ]; then ./ci/keep_alive bazel test python/ray/new_dashboard/...; fi
# ray operator tests
- (cd deploy/ray-operator && export CC=gcc && suppress_output go build && suppress_output go test ./...)
+30
View File
@@ -175,6 +175,30 @@ cc_library(
],
)
# Agent manager.
cc_grpc_library(
name = "agent_manager_cc_grpc",
srcs = ["//src/ray/protobuf:agent_manager_proto"],
grpc_only = True,
deps = ["//src/ray/protobuf:agent_manager_cc_proto"],
)
cc_library(
name = "agent_manager_rpc",
hdrs = glob([
"src/ray/rpc/agent_manager/*.h",
]),
copts = COPTS,
strip_include_prefix = "src",
deps = [
":agent_manager_cc_grpc",
":grpc_common_lib",
":ray_common",
"@boost//:asio",
"@com_github_grpc_grpc//:grpc++",
],
)
# === End of rpc definitions ===
# === Begin of plasma definitions ===
@@ -415,6 +439,7 @@ cc_library(
copts = COPTS,
strip_include_prefix = "src",
deps = [
":agent_manager_rpc",
":gcs",
":gcs_pub_sub_lib",
":gcs_service_rpc",
@@ -506,6 +531,7 @@ cc_library(
strip_include_prefix = "src",
visibility = ["//streaming:__subpackages__"],
deps = [
":agent_manager_rpc",
":gcs",
":node_manager_fbs",
":node_manager_rpc",
@@ -549,6 +575,7 @@ cc_library(
strip_include_prefix = "src",
visibility = ["//streaming:__subpackages__"],
deps = [
":agent_manager_rpc",
":node_manager_fbs",
":node_manager_rpc",
":ray_common",
@@ -1472,6 +1499,7 @@ cc_library(
copts = COPTS,
strip_include_prefix = "src",
deps = [
":agent_manager_rpc",
":hiredis",
":node_manager_fbs",
":node_manager_rpc",
@@ -1811,9 +1839,11 @@ cc_binary(
filegroup(
name = "all_py_proto",
srcs = [
"//src/ray/protobuf:agent_manager_py_proto",
"//src/ray/protobuf:common_py_proto",
"//src/ray/protobuf:core_worker_py_proto",
"//src/ray/protobuf:gcs_py_proto",
"//src/ray/protobuf:gcs_service_py_proto",
"//src/ray/protobuf:node_manager_py_proto",
"//src/ray/protobuf:reporter_py_proto",
],
+2
View File
@@ -92,6 +92,8 @@ if __name__ == "__main__":
RAY_CI_MACOS_WHEELS_AFFECTED = 1
elif changed_file.startswith("python/ray/dashboard"):
RAY_CI_DASHBOARD_AFFECTED = 1
elif changed_file.startswith("dashboard"):
RAY_CI_DASHBOARD_AFFECTED = 1
elif changed_file.startswith("python/"):
RAY_CI_TUNE_AFFECTED = 1
RAY_CI_SGD_AFFECTED = 1
+13
View File
@@ -0,0 +1,13 @@
# This is a dummy test dependency that causes the tests below to be
# re-run if any of these files changes.
py_library(
name = "dashboard_lib",
srcs = glob(["**/*.py"],exclude=["tests/*"]),
)
py_test(
name = "test_dashboard",
size = "small",
srcs = glob(["tests/*.py"]),
deps = [":dashboard_lib"]
)
+94 -28
View File
@@ -4,10 +4,14 @@ import logging
import logging.handlers
import os
import sys
import socket
import json
import traceback
import aiohttp
import aioredis
import aiohttp.web
import aiohttp_cors
from aiohttp import hdrs
from grpc.experimental import aio as aiogrpc
import ray
@@ -16,9 +20,12 @@ import ray.new_dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants
import ray.services
import ray.utils
from ray.core.generated import agent_manager_pb2
from ray.core.generated import agent_manager_pb2_grpc
import psutil
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
aiogrpc.init_grpc_aio()
@@ -33,11 +40,8 @@ class DashboardAgent(object):
object_store_name=None,
raylet_name=None):
"""Initialize the DashboardAgent object."""
self._agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
ip, port = redis_address.split(":")
# Public attributes are accessible for all agent modules.
self.redis_address = (ip, int(port))
self.redis_address = dashboard_utils.address_tuple(redis_address)
self.redis_password = redis_password
self.temp_dir = temp_dir
self.log_dir = log_dir
@@ -46,39 +50,29 @@ class DashboardAgent(object):
self.raylet_name = raylet_name
self.ip = ray.services.get_node_ip_address()
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
listen_address = "[::]:0"
logger.info("Dashboard agent listen at: %s", listen_address)
self.port = self.server.add_insecure_port(listen_address)
self.grpc_port = self.server.add_insecure_port("[::]:0")
logger.info("Dashboard agent grpc address: %s:%s", self.ip,
self.grpc_port)
self.aioredis_client = None
self.aiogrpc_raylet_channel = aiogrpc.insecure_channel("{}:{}".format(
self.ip, self.node_manager_port))
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.http_session = None
def _load_modules(self):
"""Load dashboard agent modules."""
modules = []
for cls in self._agent_cls_list:
logger.info("Load %s: %s",
agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
for cls in agent_cls_list:
logger.info("Loading %s: %s",
dashboard_utils.DashboardAgentModule.__name__, cls)
c = cls(self)
dashboard_utils.ClassMethodRouteTable.bind(c)
modules.append(c)
logger.info("Load {} modules.".format(len(modules)))
logger.info("Loaded {} modules.".format(len(modules)))
return modules
async def run(self):
# Create an aioredis client for all modules.
self.aioredis_client = await aioredis.create_redis_pool(
address=self.redis_address, password=self.redis_password)
# Start a grpc asyncio server.
await self.server.start()
# Write the dashboard agent port to redis.
await self.aioredis_client.set(
"{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
self.ip), self.port)
async def _check_parent():
"""Check if raylet is dead."""
curr_proc = psutil.Process()
@@ -92,10 +86,83 @@ class DashboardAgent(object):
dashboard_consts.
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
check_parent_task = asyncio.create_task(_check_parent())
# Create an aioredis client for all modules.
try:
self.aioredis_client = await dashboard_utils.get_aioredis_client(
self.redis_address, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
except (socket.gaierror, ConnectionRefusedError):
logger.error(
"Dashboard agent exiting: "
"Failed to connect to redis at %s", self.redis_address)
sys.exit(-1)
# Create a http session for all modules.
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
# Start a grpc asyncio server.
await self.server.start()
modules = self._load_modules()
await asyncio.gather(_check_parent(),
# Http server should be initialized after all modules loaded.
app = aiohttp.web.Application()
app.add_routes(routes=routes.bound_routes())
# Enable CORS on all routes.
cors = aiohttp_cors.setup(
app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True,
expose_headers="*",
allow_methods="*",
allow_headers=("Content-Type", "X-Header"),
)
})
for route in list(app.router.routes()):
cors.add(route)
runner = aiohttp.web.AppRunner(app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, self.ip, 0)
await site.start()
http_host, http_port = site._server.sockets[0].getsockname()
logger.info("Dashboard agent http address: %s:%s", http_host,
http_port)
# Dump registered http routes.
dump_routes = [
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
]
for r in dump_routes:
logger.info(r)
logger.info("Registered %s routes.", len(dump_routes))
# Write the dashboard agent port to redis.
await self.aioredis_client.set(
"{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
self.ip), json.dumps([http_port, self.grpc_port]))
# Register agent to agent manager.
raylet_stub = agent_manager_pb2_grpc.AgentManagerServiceStub(
self.aiogrpc_raylet_channel)
await raylet_stub.RegisterAgent(
agent_manager_pb2.RegisterAgentRequest(
agent_pid=os.getpid(),
agent_port=self.grpc_port,
agent_ip_address=self.ip))
await asyncio.gather(check_parent_task,
*(m.run(self.server) for m in modules))
await self.server.wait_for_termination()
# Wait for finish signal.
await runner.cleanup()
if __name__ == "__main__":
@@ -215,8 +282,7 @@ if __name__ == "__main__":
raylet_name=args.raylet_name)
loop = asyncio.get_event_loop()
loop.create_task(agent.run())
loop.run_forever()
loop.run_until_complete(agent.run())
except Exception as e:
# Something went wrong, so push an error to all drivers.
redis_client = ray.services.create_redis_client(
View File
+5 -1
View File
@@ -1,16 +1,20 @@
DASHBOARD_LOG_FILENAME = "dashboard.log"
DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:"
DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log"
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2
MAX_COUNT_OF_GCS_RPC_ERROR = 10
RETRY_REDIS_CONNECTION_TIMES = 10
UPDATE_NODES_INTERVAL_SECONDS = 5
CONNECT_GCS_INTERVAL_SECONDS = 2
CONNECT_REDIS_INTERNAL_SECONDS = 2
PURGE_DATA_INTERVAL_SECONDS = 60 * 10
REDIS_KEY_DASHBOARD = "dashboard"
REDIS_KEY_DASHBOARD_RPC = "dashboard_rpc"
REDIS_KEY_GCS_SERVER_ADDRESS = "GcsServerAddress"
REPORT_METRICS_TIMEOUT_SECONDS = 2
REPORT_METRICS_INTERVAL_SECONDS = 10
# Named signals
SIGNAL_NODE_INFO_FETCHED = "node_info_fetched"
# Default param for RotatingFileHandler
LOGGING_ROTATE_BYTES = 100 * 1000 # maxBytes
LOGGING_ROTATE_BYTES = 100 * 1000 * 1000 # maxBytes
LOGGING_ROTATE_BACKUP_COUNT = 5 # backupCount
+12 -37
View File
@@ -13,11 +13,7 @@ import logging
import logging.handlers
import os
import traceback
import uuid
import aioredis
import ray
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.head as dashboard_head
import ray.new_dashboard.utils as dashboard_utils
@@ -32,7 +28,7 @@ logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
def setup_static_dir(app):
def setup_static_dir():
build_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "client/build")
module_name = os.path.basename(os.path.dirname(__file__))
@@ -47,7 +43,7 @@ def setup_static_dir(app):
"&& npm run build)".format(module_name), build_dir)
static_dir = os.path.join(build_dir, "static")
app.router.add_static("/static", static_dir, follow_symlinks=True)
routes.static("/static", static_dir, follow_symlinks=True)
return build_dir
@@ -62,29 +58,18 @@ class Dashboard:
host(str): Host address of dashboard aiohttp server.
port(int): Port number of dashboard aiohttp server.
redis_address(str): GCS address of a Ray cluster
temp_dir (str): The temporary directory used for log files and
information for this Ray session.
redis_password(str): Redis password to access GCS
"""
def __init__(self,
host,
port,
redis_address,
temp_dir,
redis_password=None):
self.host = host
self.port = port
self.temp_dir = temp_dir
self.dashboard_id = str(uuid.uuid4())
def __init__(self, host, port, redis_address, redis_password=None):
self.dashboard_head = dashboard_head.DashboardHead(
redis_address=redis_address, redis_password=redis_password)
self.app = aiohttp.web.Application()
self.app.add_routes(routes=routes.routes())
http_host=host,
http_port=port,
redis_address=redis_address,
redis_password=redis_password)
# Setup Dashboard Routes
build_dir = setup_static_dir(self.app)
build_dir = setup_static_dir()
logger.info("Setup static dir for dashboard: %s", build_dir)
dashboard_utils.ClassMethodRouteTable.bind(self)
@@ -103,17 +88,7 @@ class Dashboard:
"client/build/favicon.ico"))
async def run(self):
coroutines = [
self.dashboard_head.run(),
aiohttp.web._run_app(self.app, host=self.host, port=self.port)
]
ip = ray.services.get_node_ip_address()
aioredis_client = await aioredis.create_redis_pool(
address=self.dashboard_head.redis_address,
password=self.dashboard_head.redis_password)
await aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD,
ip + ":" + str(self.port))
await asyncio.gather(*coroutines)
await self.dashboard_head.run()
if __name__ == "__main__":
@@ -158,9 +133,10 @@ if __name__ == "__main__":
"--logging-filename",
required=False,
type=str,
default="",
default=dashboard_consts.DASHBOARD_LOG_FILENAME,
help="Specify the name of log file, "
"log to stdout if set empty, default is \"\"")
"log to stdout if set empty, default is \"{}\"".format(
dashboard_consts.DASHBOARD_LOG_FILENAME))
parser.add_argument(
"--logging-rotate-bytes",
required=False,
@@ -221,7 +197,6 @@ if __name__ == "__main__":
args.host,
args.port,
args.redis_address,
temp_dir,
redis_password=args.redis_password)
loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run())
+65 -15
View File
@@ -1,9 +1,12 @@
import sys
import socket
import json
import asyncio
import logging
import aiohttp
import aioredis
import aiohttp.web
from aiohttp import hdrs
from grpc.experimental import aio as aiogrpc
import ray.services
@@ -25,22 +28,25 @@ def gcs_node_info_to_dict(message):
class DashboardHead:
def __init__(self, redis_address, redis_password):
# Scan and import head modules for collecting http routes.
self._head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
ip, port = redis_address.split(":")
def __init__(self, http_host, http_port, redis_address, redis_password):
# NodeInfoGcsService
self._gcs_node_info_stub = None
self._gcs_rpc_error_counter = 0
# Public attributes are accessible for all head modules.
self.redis_address = (ip, int(port))
self.http_host = http_host
self.http_port = http_port
self.redis_address = dashboard_utils.address_tuple(redis_address)
self.redis_password = redis_password
self.aioredis_client = None
self.aiogrpc_gcs_channel = None
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.http_session = None
self.ip = ray.services.get_node_ip_address()
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
self.grpc_port = self.server.add_insecure_port("[::]:0")
logger.info("Dashboard head grpc address: %s:%s", self.ip,
self.grpc_port)
logger.info("Dashboard head http address: %s:%s", self.http_host,
self.http_port)
async def _get_nodes(self):
"""Read the client table.
@@ -83,7 +89,7 @@ class DashboardHead:
node_ip)
agent_port = await self.aioredis_client.get(key)
if agent_port:
agents[node_ip] = agent_port
agents[node_ip] = json.loads(agent_port)
for ip in agents.keys() - set(node_ips):
agents.pop(ip, None)
@@ -112,18 +118,34 @@ class DashboardHead:
def _load_modules(self):
"""Load dashboard head modules."""
modules = []
for cls in self._head_cls_list:
logger.info("Load %s: %s",
head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
for cls in head_cls_list:
logger.info("Loading %s: %s",
dashboard_utils.DashboardHeadModule.__name__, cls)
c = cls(self)
dashboard_utils.ClassMethodRouteTable.bind(c)
modules.append(c)
logger.info("Loaded {} modules.".format(len(modules)))
return modules
async def run(self):
# Create an aioredis client for all modules.
self.aioredis_client = await aioredis.create_redis_pool(
address=self.redis_address, password=self.redis_password)
try:
self.aioredis_client = await dashboard_utils.get_aioredis_client(
self.redis_address, self.redis_password,
dashboard_consts.CONNECT_REDIS_INTERNAL_SECONDS,
dashboard_consts.RETRY_REDIS_CONNECTION_TIMES)
except (socket.gaierror, ConnectionError):
logger.error(
"Dashboard head exiting: "
"Failed to connect to redis at %s", self.redis_address)
sys.exit(-1)
# Create a http session for all modules.
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
# Waiting for GCS is ready.
while True:
try:
@@ -140,10 +162,21 @@ class DashboardHead:
else:
self.aiogrpc_gcs_channel = channel
break
# Create a NodeInfoGcsServiceStub.
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
self.aiogrpc_gcs_channel)
# Start a grpc asyncio server.
await self.server.start()
# Write the dashboard head port to redis.
await self.aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD,
self.ip + ":" + str(self.http_port))
await self.aioredis_client.set(
dashboard_consts.REDIS_KEY_DASHBOARD_RPC,
self.ip + ":" + str(self.grpc_port))
async def _async_notify():
"""Notify signals from queue."""
while True:
@@ -164,7 +197,24 @@ class DashboardHead:
logger.exception(e)
modules = self._load_modules()
# Http server should be initialized after all modules loaded.
app = aiohttp.web.Application()
app.add_routes(routes=routes.bound_routes())
web_server = aiohttp.web._run_app(
app, host=self.http_host, port=self.http_port)
# Dump registered http routes.
dump_routes = [
r for r in app.router.routes() if r.method != hdrs.METH_HEAD
]
for r in dump_routes:
logger.info(r)
logger.info("Registered %s routes.", len(dump_routes))
# Freeze signal after all modules loaded.
dashboard_utils.SignalManager.freeze()
await asyncio.gather(self._update_nodes(), _async_notify(),
_purge_data(), *(m.run() for m in modules))
_purge_data(), web_server,
*(m.run(self.server) for m in modules))
await self.server.wait_for_termination()
@@ -21,6 +21,13 @@ import psutil
logger = logging.getLogger(__name__)
try:
import gpustat.core as gpustat
except ImportError:
gpustat = None
logger.warning(
"Install gpustat with 'pip install gpustat' to enable GPU monitoring.")
def recursive_asdict(o):
if isinstance(o, tuple) and hasattr(o, "_asdict"):
@@ -81,10 +88,35 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
return reporter_pb2.GetProfilingStatsReply(
profiling_stats=profiling_stats, stdout=stdout, stderr=stderr)
async def ReportMetrics(self, request, context):
# TODO(sang): Process metrics here.
return reporter_pb2.ReportMetricsReply()
@staticmethod
def _get_cpu_percent():
return psutil.cpu_percent()
@staticmethod
def _get_gpu_usage():
    """Collect one dict of stats per GPU reported by gpustat.

    Returns:
        A list with one dict per GPU, built from each GPU's ``entry``
        mapping. The list is empty when gpustat could not be imported or
        when the query fails.
    """
    if gpustat is None:
        return []

    try:
        detected_gpus = gpustat.new_query().gpus
    except Exception as e:
        logger.debug(
            "gpustat failed to retrieve GPU information: {}".format(e))
        detected_gpus = []

    # The entry keys contain periods, which throws off javascript, so
    # every "." in a key is replaced with "_".
    return [{
        field.replace(".", "_"): value
        for field, value in gpu.entry.items()
    } for gpu in detected_gpus]
@staticmethod
def _get_boot_time():
return psutil.boot_time()
@@ -173,6 +205,7 @@ class ReporterAgent(dashboard_utils.DashboardAgentModule,
"bootTime": self._get_boot_time(),
"loadAvg": self._get_load_avg(),
"disk": self._get_disk_usage(),
"gpus": self._get_gpu_usage(),
"net": netstats,
"cmdline": self._get_raylet_cmdline(),
}
+7 -7
View File
@@ -29,8 +29,8 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
async def _update_stubs(self, change):
if change.new:
ip, port = next(iter(change.new.items()))
channel = aiogrpc.insecure_channel("{}:{}".format(ip, int(port)))
ip, ports = next(iter(change.new.items()))
channel = aiogrpc.insecure_channel("{}:{}".format(ip, ports[1]))
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub
if change.old:
@@ -77,15 +77,15 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
message="Profiling info fetched.",
profiling_info=json.loads(profiling_stats.profiling_stats))
async def run(self):
p = self._dashboard_head.aioredis_client
mpsc = Receiver()
async def run(self, server):
aioredis_client = self._dashboard_head.aioredis_client
receiver = Receiver()
reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
await p.psubscribe(mpsc.pattern(reporter_key))
await aioredis_client.psubscribe(receiver.pattern(reporter_key))
logger.info("Subscribed to {}".format(reporter_key))
async for sender, msg in mpsc.iter():
async for sender, msg in receiver.iter():
try:
_, data = msg
data = json.loads(ray.utils.decode(data))
View File
+24
View File
@@ -0,0 +1,24 @@
import logging
import aiohttp.web
import ray.new_dashboard.utils as dashboard_utils
import ray.new_dashboard.modules.test.test_utils as test_utils
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
class HeadAgent(dashboard_utils.DashboardAgentModule):
    """Test-only agent module that exposes an HTTP fetch endpoint."""

    def __init__(self, dashboard_agent):
        super().__init__(dashboard_agent)

    @routes.get("/test/http_get_from_agent")
    async def get_url(self, req) -> aiohttp.web.Response:
        """Fetch the address given in the ``url`` query parameter.

        The request is issued through the agent's shared http session and
        the decoded JSON result is returned as a JSON response.
        """
        target = req.query.get("url")
        session = self._dashboard_agent.http_session
        payload = await test_utils.http_get(session, target)
        return aiohttp.web.json_response(payload)

    async def run(self, server):
        """Do nothing; this module only serves its http route."""
+62
View File
@@ -0,0 +1,62 @@
import logging
import aiohttp.web
import ray.new_dashboard.utils as dashboard_utils
import ray.new_dashboard.modules.test.test_utils as test_utils
from ray.new_dashboard.datacenter import DataSource
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
class TestHead(dashboard_utils.DashboardHeadModule):
    """Test-only head module exposing datacenter contents over HTTP.

    It also records every agent it is notified about through the
    ``DataSource.agents`` change signal so tests can compare the notified
    view with the stored one.
    """

    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)
        # ip -> ports, maintained from DataSource.agents change signals.
        self._notified_agents = {}
        DataSource.agents.signal.append(self._update_notified_agents)

    async def _update_notified_agents(self, change):
        """Mirror one agents change into ``self._notified_agents``.

        Args:
            change: Object with ``old`` and ``new`` attributes; each is
                either falsy or a mapping whose first item is
                ``(ip, ports)``.
        """
        # Apply the removal before the addition. Both attributes are
        # checked independently, so a change may carry both for the same
        # ip; handling ``new`` first would store the fresh ports and then
        # immediately delete them.
        if change.old:
            ip, _ = next(iter(change.old.items()))
            # Tolerate an ip we were never notified about.
            self._notified_agents.pop(ip, None)
        if change.new:
            ip, ports = next(iter(change.new.items()))
            self._notified_agents[ip] = ports

    @routes.get("/test/dump")
    async def dump(self, req) -> aiohttp.web.Response:
        """Dump datacenter contents.

        Without a ``key`` query parameter every public attribute of
        ``DataSource`` is returned; with one, only that attribute is.
        An unknown or private key yields an unsuccessful rest response
        instead of an unhandled exception.
        """
        key = req.query.get("key")
        if key is None:
            all_data = {
                k: dict(v)
                for k, v in DataSource.__dict__.items()
                if not k.startswith("_")
            }
            return await dashboard_utils.rest_response(
                success=True,
                message="Fetch all data from datacenter success.",
                **all_data)

        source = DataSource.__dict__.get(key)
        if key.startswith("_") or source is None:
            # dict(None) would raise; report the bad key to the caller.
            return await dashboard_utils.rest_response(
                success=False,
                message="Unknown datacenter key: {}".format(key))
        return await dashboard_utils.rest_response(
            success=True,
            message="Fetch {} from datacenter success.".format(key),
            **{key: dict(source)})

    @routes.get("/test/notified_agents")
    async def get_notified_agents(self, req) -> aiohttp.web.Response:
        """Return the agents recorded from change notifications."""
        return await dashboard_utils.rest_response(
            success=True,
            message="Fetch notified agents success.",
            **self._notified_agents)

    @routes.get("/test/http_get")
    async def get_url(self, req) -> aiohttp.web.Response:
        """Fetch the ``url`` query parameter via the head's http session."""
        url = req.query.get("url")
        result = await test_utils.http_get(self._dashboard_head.http_session,
                                           url)
        return aiohttp.web.json_response(result)

    async def run(self, server):
        """Do nothing; this module only serves its http routes."""
+11
View File
@@ -0,0 +1,11 @@
import logging
import async_timeout
logger = logging.getLogger(__name__)
async def http_get(http_session, url, timeout_seconds=60):
    """Issue a GET request to ``url`` and return the decoded JSON body.

    Args:
        http_session: Session whose ``get(url)`` returns an async context
            manager yielding the response (the modules in this package
            pass their shared ``aiohttp.ClientSession``).
        url: Address to request.
        timeout_seconds: Upper bound, in seconds, applied to the whole
            request-and-decode operation.

    Returns:
        The value produced by ``response.json()``.
    """
    # The synchronous ``with`` form of async_timeout.timeout is deprecated
    # in newer async-timeout releases; the async form is the supported one.
    async with async_timeout.timeout(timeout_seconds):
        async with http_session.get(url) as response:
            return await response.json()
+1
View File
@@ -0,0 +1 @@
from ray.tests.conftest import * # noqa
+224
View File
@@ -0,0 +1,224 @@
import os
import json
import time
import logging
import ray
import psutil
import pytest
import redis
import requests
from ray import ray_constants
from ray.test_utils import wait_for_condition, wait_until_server_available
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.modules
os.environ["RAY_USE_NEW_DASHBOARD"] = "1"
logger = logging.getLogger(__name__)
def cleanup_test_files():
    """Delete the deliberately broken module used by the bad-import test.

    Removal is best effort: any failure (for example the file not
    existing) is ignored.
    """
    modules_dir = ray.new_dashboard.modules.__path__[0]
    bad_module = os.path.join(modules_dir, "test_for_bad_import.py")
    logger.info("Remove test file: %s", bad_module)
    try:
        os.remove(bad_module)
    except Exception:
        # Nothing to clean up, or the file cannot be removed; carry on.
        pass
def prepare_test_files():
    """Create a module with invalid syntax inside the dashboard modules dir.

    The file content (``>>>``) is not valid Python, so importing it fails.
    """
    modules_dir = ray.new_dashboard.modules.__path__[0]
    bad_module = os.path.join(modules_dir, "test_for_bad_import.py")
    logger.info("Prepare test file: %s", bad_module)
    with open(bad_module, "w") as handle:
        handle.write(">>>")
cleanup_test_files()
@pytest.mark.parametrize(
    "ray_start_with_dashboard", [{
        "_internal_config": json.dumps({
            "agent_register_timeout_ms": 5000
        })
    }],
    indirect=True)
def test_basic(ray_start_with_dashboard):
    """Check dashboard/agent process management and the redis keys they set.

    Starts a Ray cluster with a dashboard server running, then verifies
    that: the dashboard process is alive and no reporter process exists;
    the agent (a child of the raylet) is restarted repeatedly while a
    module with a bad import is present; a healthy agent stays alive with
    a stable pid; and the dashboard / agent address keys exist in redis.
    """
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    address_info = ray_start_with_dashboard
    address = address_info["redis_address"]
    address = address.split(":")
    assert len(address) == 2

    client = redis.StrictRedis(
        host=address[0],
        port=int(address[1]),
        password=ray_constants.REDIS_DEFAULT_PASSWORD)

    # A live server process that is idle (blocked waiting for I/O) is
    # reported as sleeping rather than running, so accept both states.
    alive_statuses = (psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING)

    all_processes = ray.worker._global_node.all_processes
    assert ray_constants.PROCESS_TYPE_DASHBOARD in all_processes
    assert ray_constants.PROCESS_TYPE_REPORTER not in all_processes
    dashboard_proc_info = all_processes[ray_constants.PROCESS_TYPE_DASHBOARD][
        0]
    dashboard_proc = psutil.Process(dashboard_proc_info.process.pid)
    assert dashboard_proc.status() in alive_statuses
    raylet_proc_info = all_processes[ray_constants.PROCESS_TYPE_RAYLET][0]
    raylet_proc = psutil.Process(raylet_proc_info.process.pid)

    def _search_agent(processes):
        """Return the first process running the agent script, else None."""
        for p in processes:
            try:
                for c in p.cmdline():
                    if "new_dashboard/agent.py" in c:
                        return p
            except Exception:
                # The process may have exited while being inspected.
                pass
        return None

    # Test for bad imports, the agent should be restarted.
    logger.info("Test for bad imports.")
    agent_proc = _search_agent(raylet_proc.children())
    prepare_test_files()
    agent_pids = set()
    try:
        assert agent_proc is not None
        agent_proc.kill()
        agent_proc.wait()
        # The agent will be restarted for imports failure.
        for _ in range(40):
            agent_proc = _search_agent(raylet_proc.children())
            if agent_proc:
                agent_pids.add(agent_proc.pid)
            time.sleep(0.1)
    finally:
        cleanup_test_files()
    assert len(agent_pids) > 1, agent_pids

    # Kill whatever agent is left so the next one starts with clean modules.
    agent_proc = _search_agent(raylet_proc.children())
    if agent_proc:
        agent_proc.kill()
        agent_proc.wait()

    logger.info("Test agent register is OK.")
    wait_for_condition(lambda: _search_agent(raylet_proc.children()))
    assert dashboard_proc.status() in alive_statuses
    agent_proc = _search_agent(raylet_proc.children())
    assert agent_proc is not None
    agent_pid = agent_proc.pid

    # Check if agent register is OK: the pid must stay the same.
    for _ in range(5):
        logger.info("Check agent is alive.")
        agent_proc = _search_agent(raylet_proc.children())
        # Fail with a clear assertion instead of an AttributeError on None.
        assert agent_proc is not None
        assert agent_proc.pid == agent_pid
        time.sleep(1)

    # Check redis keys are set.
    logger.info("Check redis keys are set.")
    dashboard_address = client.get(dashboard_consts.REDIS_KEY_DASHBOARD)
    assert dashboard_address is not None
    dashboard_rpc_address = client.get(
        dashboard_consts.REDIS_KEY_DASHBOARD_RPC)
    assert dashboard_rpc_address is not None
    key = "{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
                        address[0])
    agent_ports = client.get(key)
    assert agent_ports is not None
def test_nodes_update(ray_start_with_dashboard):
    """Poll the dashboard until it reports exactly one node and one agent.

    Each attempt reads ``/test/dump`` and ``/test/notified_agents`` and
    checks that the dumped data is consistent and that the notified
    agents equal the stored agents. Assertion and connection failures are
    retried once per second until the timeout elapses.
    """
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = webui_url.replace("localhost", "http://127.0.0.1")

    def _get_json(url):
        """GET ``url`` and decode JSON, logging the raw body on failure."""
        response = requests.get(url)
        response.raise_for_status()
        try:
            return response.json()
        except Exception:
            logger.info("failed response: {}".format(response.text))
            raise

    timeout_seconds = 20
    start_time = time.time()
    while True:
        time.sleep(1)
        try:
            dump_info = _get_json(webui_url + "/test/dump")
            assert dump_info["result"] is True
            dump_data = dump_info["data"]
            assert len(dump_data["nodes"]) == 1
            assert len(dump_data["agents"]) == 1
            assert len(dump_data["hostnameToIp"]) == 1
            assert len(dump_data["ipToHostname"]) == 1
            assert dump_data["nodes"].keys() == dump_data[
                "ipToHostname"].keys()

            notified_agents = _get_json(webui_url + "/test/notified_agents")
            assert notified_agents["result"] is True
            notified_agents = notified_agents["data"]
            assert len(notified_agents) == 1
            assert notified_agents == dump_data["agents"]
            break
        except (AssertionError, requests.exceptions.ConnectionError) as e:
            logger.info("Retry because of %s", e)
            # Check the deadline only after a retryable failure. Doing it
            # in a ``finally`` clause would turn a late success into a
            # timeout and hide unexpected exceptions behind one.
            if time.time() > start_time + timeout_seconds:
                raise Exception(
                    "Timed out while waiting for dashboard to start.") from e
def test_http_get(ray_start_with_dashboard):
    """Check that both the head and the agent can perform outgoing GETs.

    The head's ``/test/http_get`` endpoint is asked to fetch
    ``/test/dump``; the agent address found in that dump is then used to
    ask the agent's ``/test/http_get_from_agent`` endpoint to fetch the
    same url. Assertion and connection failures are retried once per
    second until the timeout elapses.
    """
    assert (wait_until_server_available(ray_start_with_dashboard["webui_url"])
            is True)
    webui_url = ray_start_with_dashboard["webui_url"]
    webui_url = webui_url.replace("localhost", "http://127.0.0.1")

    target_url = webui_url + "/test/dump"

    def _get_json(url):
        """GET ``url`` and decode JSON, logging the raw body on failure."""
        response = requests.get(url)
        response.raise_for_status()
        try:
            return response.json()
        except Exception:
            logger.info("failed response: {}".format(response.text))
            raise

    timeout_seconds = 20
    start_time = time.time()
    while True:
        time.sleep(1)
        try:
            dump_info = _get_json(webui_url + "/test/http_get?url=" +
                                  target_url)
            assert dump_info["result"] is True
            dump_data = dump_info["data"]
            assert len(dump_data["agents"]) == 1
            ip, ports = next(iter(dump_data["agents"].items()))
            # Each agent entry is a [http_port, grpc_port] pair.
            http_port, grpc_port = ports

            dump_info = _get_json(
                "http://{}:{}/test/http_get_from_agent?url={}".format(
                    ip, http_port, target_url))
            assert dump_info["result"] is True
            break
        except (AssertionError, requests.exceptions.ConnectionError) as e:
            logger.info("Retry because of %s", e)
            # Check the deadline only after a retryable failure. Doing it
            # in a ``finally`` clause would turn a late success into a
            # timeout and hide unexpected exceptions behind one.
            if time.time() > start_time + timeout_seconds:
                raise Exception(
                    "Timed out while waiting for dashboard to start.") from e
+51 -2
View File
@@ -1,4 +1,5 @@
import abc
import socket
import asyncio
import collections
import copy
@@ -12,10 +13,14 @@ import pkgutil
import traceback
from base64 import b64decode
from collections.abc import MutableMapping, Mapping
from typing import Any
import aioredis
import aiohttp.web
from aiohttp import hdrs
from aiohttp.frozenlist import FrozenList
from aiohttp.typedefs import PathLike
from aiohttp.web import RouteDef
import aiohttp.signals
from google.protobuf.json_format import MessageToDict
from ray.utils import binary_to_hex
@@ -52,9 +57,12 @@ class DashboardHeadModule(abc.ABC):
self._dashboard_head = dashboard_head
@abc.abstractmethod
async def run(self):
async def run(self, server):
"""
Run the module in an asyncio loop.
Run the module in an asyncio loop. A head module can provide
servicers to the server.
:param server: Asyncio GRPC server.
"""
@@ -74,6 +82,22 @@ class ClassMethodRouteTable:
def routes(cls):
return cls._routes
@classmethod
def bound_routes(cls):
    """Return a new route table holding only the routes that are usable.

    A ``RouteDef`` entry is kept only when the bind map records a
    non-None instance for its handler's route method and path. Entries
    of any other type are always kept.
    """

    def _is_usable(item):
        if not isinstance(item, RouteDef):
            return True
        handler = item.handler
        method = getattr(handler, "__route_method__")
        path = getattr(handler, "__route_path__")
        return cls._bind_map[method][path].instance is not None

    kept = [item for item in cls._routes._items if _is_usable(item)]
    table = aiohttp.web.RouteTableDef()
    table._items = kept
    return table
@classmethod
def _register_route(cls, method, path, **kwargs):
def _wrapper(handler):
@@ -132,6 +156,10 @@ class ClassMethodRouteTable:
def view(cls, path, **kwargs):
return cls._register_route(hdrs.METH_ANY, path, **kwargs)
@classmethod
def static(cls, prefix: str, path: PathLike, **kwargs: Any) -> None:
    """Register a static route on the shared route table.

    All arguments are forwarded unchanged to ``cls._routes.static``.
    """
    route_table = cls._routes
    route_table.static(prefix, path, **kwargs)
@classmethod
def bind(cls, instance):
def predicate(o):
@@ -161,6 +189,13 @@ def to_posix_time(dt):
return (dt - datetime.datetime(1970, 1, 1)).total_seconds()
def address_tuple(address):
    """Convert an address into an ``(ip, port)`` tuple.

    Args:
        address: Either a tuple, which is returned unchanged, or a string
            of the form ``"ip:port"``.

    Returns:
        The input tuple, or ``(ip, port)`` with ``port`` converted to int.

    Raises:
        ValueError: If the string contains no ``:`` separator or the port
            part is not an integer.
    """
    if isinstance(address, tuple):
        return address
    # Split on the last colon only, so a host part that itself contains
    # colons (e.g. an IPv6 literal) does not break the two-way unpacking.
    ip, port = address.rsplit(":", 1)
    return ip, int(port)
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
@@ -338,3 +373,17 @@ class Dict(MutableMapping):
for key in self._data.keys() - d.keys():
self.pop(key)
self.update(d)
async def get_aioredis_client(redis_address, redis_password,
                              retry_interval_seconds, retry_times):
    """Create an aioredis connection pool, retrying on connection errors.

    Up to ``retry_times`` attempts are made; after each one that fails with
    ``socket.gaierror`` or ``ConnectionError`` the error is logged and the
    coroutine sleeps for ``retry_interval_seconds``. If all of them fail,
    one last attempt is made whose exception, if any, propagates to the
    caller.

    Returns:
        The object returned by ``aioredis.create_redis_pool``.
    """

    async def _connect():
        return await aioredis.create_redis_pool(
            address=redis_address, password=redis_password)

    for _ in range(retry_times):
        try:
            return await _connect()
        except (socket.gaierror, ConnectionError) as ex:
            logger.error("Connect to Redis failed: %s, retry...", ex)
            await asyncio.sleep(retry_interval_seconds)
    # Last attempt is unguarded so the caller sees the underlying exception.
    return await _connect()
+7 -3
View File
@@ -615,8 +615,11 @@ class Node:
if we fail to start the dashboard. Otherwise it will print
a warning if we fail to start the dashboard.
"""
stdout_file, stderr_file = self.get_log_file_handles(
"dashboard", unique=True)
if "RAY_USE_NEW_DASHBOARD" in os.environ:
stdout_file, stderr_file = None, None
else:
stdout_file, stderr_file = self.get_log_file_handles(
"dashboard", unique=True)
self._webui_url, process_info = ray.services.start_dashboard(
require_dashboard,
self._ray_params.dashboard_host,
@@ -797,7 +800,8 @@ class Node:
self.start_plasma_store()
self.start_raylet()
self.start_reporter()
if "RAY_USE_NEW_DASHBOARD" not in os.environ:
self.start_reporter()
if self._ray_params.include_log_monitor:
self.start_log_monitor()
+27 -1
View File
@@ -1160,8 +1160,14 @@ def start_dashboard(require_dashboard,
raise ValueError(
f"The given dashboard port {port} is already in use")
if "RAY_USE_NEW_DASHBOARD" in os.environ:
dashboard_dir = "new_dashboard"
else:
dashboard_dir = "dashboard"
dashboard_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "dashboard/dashboard.py")
os.path.dirname(os.path.abspath(__file__)), dashboard_dir,
"dashboard.py")
command = [
sys.executable,
"-u",
@@ -1398,6 +1404,23 @@ def start_raylet(redis_address,
start_worker_command.append(
f"--object-spilling-config={json.dumps(object_spilling_config)}")
# Create agent command
agent_command = [
sys.executable,
"-u",
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"new_dashboard/agent.py"),
"--redis-address={}".format(redis_address),
"--node-manager-port={}".format(node_manager_port),
"--object-store-name={}".format(plasma_store_name),
"--raylet-name={}".format(raylet_name),
"--temp-dir={}".format(temp_dir),
]
if redis_password is not None and len(redis_password) != 0:
agent_command.append("--redis-password={}".format(redis_password))
command = [
RAYLET_EXECUTABLE,
f"--raylet_socket_name={raylet_name}",
@@ -1424,6 +1447,9 @@ def start_raylet(redis_address,
if start_initial_python_workers_for_first_job:
command.append("--num_initial_python_workers_for_first_job={}".format(
resource_spec.num_cpus))
if "RAY_USE_NEW_DASHBOARD" in os.environ:
command.append("--agent_command={}".format(
subprocess.list2cmdline(agent_command)))
if config.get("plasma_store_as_thread"):
# command related to the plasma store
plasma_directory, object_store_memory = determine_plasma_store_config(
+2 -1
View File
@@ -133,6 +133,7 @@ extras["all"] = list(set(chain.from_iterable(extras.values())))
# the change in the matching section of requirements.txt
install_requires = [
"aiohttp",
"aiohttp_cors",
"aioredis",
"click >= 7.0",
"colorama",
@@ -408,7 +409,7 @@ def api_main(program, *args):
nonlocal result
if excinfo[1].errno != errno.ENOENT:
msg = excinfo[1].strerror
logger.error("cannot remove {}: {}" % (path, msg))
logger.error("cannot remove {}: {}".format(path, msg))
result = 1
for subdir in CLEANABLE_SUBDIRS:
+6
View File
@@ -328,6 +328,12 @@ RAY_CONFIG(bool, put_small_object_in_memory_store, false)
/// pipelining task submission.
RAY_CONFIG(uint32_t, max_tasks_in_flight_per_worker, 1)
/// Interval to wait before restarting the dashboard agent after its process exits.
RAY_CONFIG(uint32_t, agent_restart_interval_ms, 1000)
/// Timeout to wait for the dashboard agent to register.
RAY_CONFIG(uint32_t, agent_register_timeout_ms, 30 * 1000)
/// The maximum number of resource shapes included in the resource
/// load reported by each raylet.
RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100)
@@ -249,9 +249,9 @@ void GcsActorScheduler::LeaseWorkerFromNode(std::shared_ptr<GcsActor> actor,
void GcsActorScheduler::RetryLeasingWorkerFromNode(
std::shared_ptr<GcsActor> actor, std::shared_ptr<rpc::GcsNodeInfo> node) {
execute_after(io_context_,
[this, node, actor] { DoRetryLeasingWorkerFromNode(actor, node); },
RayConfig::instance().gcs_lease_worker_retry_interval_ms());
RAY_UNUSED(execute_after(
io_context_, [this, node, actor] { DoRetryLeasingWorkerFromNode(actor, node); },
RayConfig::instance().gcs_lease_worker_retry_interval_ms()));
}
void GcsActorScheduler::DoRetryLeasingWorkerFromNode(
@@ -370,9 +370,9 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr<GcsActor> actor,
void GcsActorScheduler::RetryCreatingActorOnWorker(
std::shared_ptr<GcsActor> actor, std::shared_ptr<GcsLeasedWorker> worker) {
execute_after(io_context_,
[this, actor, worker] { DoRetryCreatingActorOnWorker(actor, worker); },
RayConfig::instance().gcs_create_actor_retry_interval_ms());
RAY_UNUSED(execute_after(
io_context_, [this, actor, worker] { DoRetryCreatingActorOnWorker(actor, worker); },
RayConfig::instance().gcs_create_actor_retry_interval_ms()));
}
void GcsActorScheduler::DoRetryCreatingActorOnWorker(
+23
View File
@@ -86,6 +86,11 @@ cc_proto_library(
deps = [":gcs_service_proto"],
)
# Python gRPC bindings generated from :gcs_service_proto.
python_grpc_compile(
name = "gcs_service_py_proto",
deps = [":gcs_service_proto"],
)
proto_library(
name = "object_manager_proto",
srcs = ["object_manager.proto"],
@@ -132,3 +137,21 @@ cc_proto_library(
name = "event_cc_proto",
deps = [":event_proto"],
)
# Agent manager gRPC lib.
# Proto definition; it has no dependencies on other proto targets.
proto_library(
name = "agent_manager_proto",
srcs = ["agent_manager.proto"],
deps = [],
)
# Python gRPC bindings generated from :agent_manager_proto.
python_grpc_compile(
name = "agent_manager_py_proto",
deps = [":agent_manager_proto"],
)
# C++ protobuf bindings generated from :agent_manager_proto.
cc_proto_library(
name = "agent_manager_cc_proto",
deps = [":agent_manager_proto"],
)
+38
View File
@@ -0,0 +1,38 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package ray.rpc;
// Status code carried in agent manager RPC replies.
enum AgentRpcStatus {
// OK.
AGENT_RPC_STATUS_OK = 0;
// Failed.
AGENT_RPC_STATUS_FAILED = 1;
}
// Request sent to register an agent with the agent manager.
message RegisterAgentRequest {
// Process id reported for the agent.
int32 agent_pid = 1;
// Port reported for the agent.
int32 agent_port = 2;
// IP address reported for the agent.
string agent_ip_address = 3;
}
// Reply to a RegisterAgentRequest.
message RegisterAgentReply {
// Outcome of the registration.
AgentRpcStatus status = 1;
}
// Service through which an agent registers itself.
service AgentManagerService {
rpc RegisterAgent(RegisterAgentRequest) returns (RegisterAgentReply);
}
+96
View File
@@ -0,0 +1,96 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/raylet/agent_manager.h"
#include <thread>
#include "ray/common/ray_config.h"
#include "ray/util/logging.h"
#include "ray/util/process.h"
namespace ray {
namespace raylet {
/// Handle a `RegisterAgent` request: remember the pid, port and IP address
/// carried by the request, log them, and reply with AGENT_RPC_STATUS_OK.
void AgentManager::HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
                                       rpc::RegisterAgentReply *reply,
                                       rpc::SendReplyCallback send_reply_callback) {
  // Record what the agent reported about itself.
  agent_pid_ = request.agent_pid();
  agent_port_ = request.agent_port();
  agent_ip_address_ = request.agent_ip_address();

  RAY_LOG(INFO) << "HandleRegisterAgent, ip: " << agent_ip_address_
                << ", port: " << agent_port_ << ", pid: " << agent_pid_;

  // Registration always succeeds; reply immediately.
  reply->set_status(rpc::AGENT_RPC_STATUS_OK);
  send_reply_callback(ray::Status::OK(), nullptr, nullptr);
}
// Start the agent process from options_.agent_commands and monitor it.
//
// Does nothing when the command is empty. Otherwise the process is launched
// and a detached thread waits for it to exit; once it exits, StartAgent() is
// scheduled again through delay_executor_ after agent_restart_interval_ms.
// A timer armed with agent_register_timeout_ms kills the child if the
// registered pid (agent_pid_) does not match the child's pid by then.
void AgentManager::StartAgent() {
if (options_.agent_commands.empty()) {
RAY_LOG(INFO) << "Not starting agent, the agent command is empty.";
return;
}
// Only build the command string when DEBUG logging is enabled.
if (RAY_LOG_ENABLED(DEBUG)) {
std::stringstream stream;
stream << "Starting agent process with command:";
for (const auto &arg : options_.agent_commands) {
stream << " " << arg;
}
RAY_LOG(DEBUG) << stream.str();
}
// Launch the process to create the agent.
std::error_code ec;
// Build a NULL-terminated argv; the pointers refer to the strings owned by
// options_.agent_commands.
std::vector<const char *> argv;
for (const std::string &arg : options_.agent_commands) {
argv.push_back(arg.c_str());
}
argv.push_back(NULL);
Process child(argv.data(), nullptr, ec);
if (!child.IsValid() || ec) {
// The agent failed to start. This is a fatal error.
// NOTE(review): if RAY_LOG(FATAL) terminates the process, the retry
// scheduled below is unreachable -- confirm which behavior is intended.
RAY_LOG(FATAL) << "Failed to start agent with return value " << ec << ": "
<< ec.message();
RAY_UNUSED(delay_executor_([this] { StartAgent(); },
RayConfig::instance().agent_restart_interval_ms()));
return;
}
// The monitor thread captures its own copy of `child` and `this`.
std::thread monitor_thread([this, child]() mutable {
RAY_LOG(INFO) << "Monitor agent process with pid " << child.GetId()
<< ", register timeout "
<< RayConfig::instance().agent_register_timeout_ms() << "ms.";
// Registration watchdog: when the timeout fires and the pid recorded by
// HandleRegisterAgent is not this child's pid, kill the child.
auto timer = delay_executor_(
[this, child]() mutable {
if (agent_pid_ != child.GetId()) {
RAY_LOG(WARNING) << "Agent process with pid " << child.GetId()
<< " has not registered, restart it.";
child.Kill();
}
},
RayConfig::instance().agent_register_timeout_ms());
// Blocks this thread until the agent process exits.
int exit_code = child.Wait();
// The process is gone, so the registration watchdog is no longer needed.
timer->cancel();
RAY_LOG(WARNING) << "Agent process with pid " << child.GetId()
<< " exit, return value " << exit_code;
// Schedule a restart of the agent after the configured interval.
RAY_UNUSED(delay_executor_([this] { StartAgent(); },
RayConfig::instance().agent_restart_interval_ms()));
});
// Detached: the thread is never joined.
monitor_thread.detach();
}
} // namespace raylet
} // namespace ray
+75
View File
@@ -0,0 +1,75 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "ray/rpc/agent_manager/agent_manager_client.h"
#include "ray/rpc/agent_manager/agent_manager_server.h"
#include "ray/util/process.h"
namespace ray {
namespace raylet {
/// Callable that schedules a task to run after `delay_ms` milliseconds and
/// returns the timer driving it, so the caller can cancel the task.
using DelayExecutorFn = std::function<std::shared_ptr<boost::asio::deadline_timer>(
    std::function<void()>, uint32_t delay_ms)>;
/// Starts the agent process, restarts it when it exits, and records the
/// registration information the agent sends via `RegisterAgent`.
class AgentManager : public rpc::AgentManagerServiceHandler {
public:
struct Options {
/// Command line (program followed by its arguments) used to start the agent.
/// When empty, no agent is started.
std::vector<std::string> agent_commands;
};
/// Construct the manager and immediately try to start the agent.
///
/// \param[in] options Agent start options.
/// \param[in] delay_executor Used to schedule delayed tasks (agent restart
/// and the registration timeout check).
explicit AgentManager(Options options, DelayExecutorFn delay_executor)
: options_(std::move(options)), delay_executor_(std::move(delay_executor)) {
StartAgent();
}
/// Handle a `RegisterAgent` request; see the definition in agent_manager.cc.
void HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
rpc::RegisterAgentReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
private:
/// Launch the agent process and monitor it.
void StartAgent();
private:
/// Options given at construction.
Options options_;
/// Pid reported by the last registered agent; 0 until an agent registers.
pid_t agent_pid_ = 0;
/// Port reported by the last registered agent; 0 until an agent registers.
int agent_port_ = 0;
/// IP address reported by the last registered agent.
std::string agent_ip_address_;
/// Schedules delayed tasks and returns the timer driving them.
DelayExecutorFn delay_executor_;
};
/// Service handler that forwards `RegisterAgent` requests to an `AgentManager`
/// reached through a reference to the owner's `std::unique_ptr`.
class DefaultAgentManagerServiceHandler : public rpc::AgentManagerServiceHandler {
public:
/// \param[in] delegate Reference to the pointer holding the `AgentManager`.
/// Only the reference is stored, so the pointer may be (re)assigned after
/// this handler has been constructed; it must outlive this handler.
explicit DefaultAgentManagerServiceHandler(std::unique_ptr<AgentManager> &delegate)
: delegate_(delegate) {}
/// Forward the request to the delegate; fails the RAY_CHECK if the delegate
/// pointer is still null when a request arrives.
void HandleRegisterAgent(const rpc::RegisterAgentRequest &request,
rpc::RegisterAgentReply *reply,
rpc::SendReplyCallback send_reply_callback) override {
RAY_CHECK(delegate_ != nullptr);
delegate_->HandleRegisterAgent(request, reply, send_reply_callback);
}
private:
/// Not owned: refers to the unique_ptr held elsewhere.
std::unique_ptr<AgentManager> &delegate_;
};
} // namespace raylet
} // namespace ray
+7
View File
@@ -44,6 +44,7 @@ DEFINE_string(static_resource_list, "", "The static resource list of this node."
DEFINE_string(config_list, "", "The raylet config list of this node.");
DEFINE_string(python_worker_command, "", "Python worker command.");
DEFINE_string(java_worker_command, "", "Java worker command.");
DEFINE_string(agent_command, "", "Dashboard agent command.");
DEFINE_string(redis_password, "", "The password of redis.");
DEFINE_string(temp_dir, "", "Temporary directory.");
DEFINE_string(session_dir, "", "The path of this ray session directory.");
@@ -82,6 +83,7 @@ int main(int argc, char *argv[]) {
const std::string config_list = FLAGS_config_list;
const std::string python_worker_command = FLAGS_python_worker_command;
const std::string java_worker_command = FLAGS_java_worker_command;
const std::string agent_command = FLAGS_agent_command;
const std::string redis_password = FLAGS_redis_password;
const std::string temp_dir = FLAGS_temp_dir;
const std::string session_dir = FLAGS_session_dir;
@@ -184,6 +186,11 @@ int main(int argc, char *argv[]) {
RAY_CHECK(0) << "Either Python worker command or Java worker command should be "
"provided.";
}
if (!agent_command.empty()) {
node_manager_config.agent_command = agent_command;
} else {
RAY_LOG(DEBUG) << "Agent command is empty.";
}
node_manager_config.heartbeat_period_ms =
RayConfig::instance().raylet_heartbeat_timeout_milliseconds();
+14
View File
@@ -25,6 +25,7 @@
#include "ray/gcs/pb_util.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/stats/stats.h"
#include "ray/util/asio_util.h"
#include "ray/util/sample.h"
namespace {
@@ -164,6 +165,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
actor_registry_(),
node_manager_server_("NodeManager", config.node_manager_port),
node_manager_service_(io_service, *this),
agent_manager_service_handler_(
new DefaultAgentManagerServiceHandler(agent_manager_)),
agent_manager_service_(io_service, *agent_manager_service_handler_),
client_call_manager_(io_service),
new_scheduler_enabled_(RayConfig::instance().new_scheduler_enabled()) {
RAY_LOG(INFO) << "Initializing NodeManager with ID " << self_node_id_;
@@ -208,8 +212,18 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
RAY_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
// Run the node manager rpc server.
node_manager_server_.RegisterService(node_manager_service_);
node_manager_server_.RegisterService(agent_manager_service_);
node_manager_server_.Run();
AgentManager::Options options;
options.agent_commands = ParseCommandLine(config.agent_command);
agent_manager_.reset(
new AgentManager(std::move(options),
/*delay_executor=*/
[this](std::function<void()> task, uint32_t delay_ms) {
return execute_after(io_service_, task, delay_ms);
}));
RAY_CHECK_OK(SetupPlasmaSubscription());
}
+9
View File
@@ -27,6 +27,7 @@
#include "ray/common/task/scheduling_resources.h"
#include "ray/object_manager/object_manager.h"
#include "ray/raylet/actor_registration.h"
#include "ray/raylet/agent_manager.h"
#include "ray/raylet/scheduling/scheduling_ids.h"
#include "ray/raylet/scheduling/cluster_resource_scheduler.h"
#include "ray/raylet/scheduling/cluster_task_manager.h"
@@ -73,6 +74,8 @@ struct NodeManagerConfig {
int maximum_startup_concurrency;
/// The commands used to start the worker process, grouped by language.
WorkerCommandMap worker_commands;
/// The command used to start agent.
std::string agent_command;
/// The time between heartbeats in milliseconds.
uint64_t heartbeat_period_ms;
/// The time between debug dumps in milliseconds, or -1 to disable.
@@ -721,12 +724,18 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// restore the actor.
std::unordered_map<ActorID, ActorCheckpointID> checkpoint_id_to_restore_;
std::unique_ptr<AgentManager> agent_manager_;
/// The RPC server.
rpc::GrpcServer node_manager_server_;
/// The node manager RPC service.
rpc::NodeManagerGrpcService node_manager_service_;
/// The agent manager RPC service.
std::unique_ptr<rpc::AgentManagerServiceHandler> agent_manager_service_handler_;
rpc::AgentManagerGrpcService agent_manager_service_;
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s
/// as well as all `CoreWorkerClient`s.
rpc::ClientCallManager client_call_manager_;
@@ -0,0 +1,50 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "ray/rpc/client_call.h"
#include "ray/rpc/grpc_client.h"
#include "src/ray/protobuf/agent_manager.grpc.pb.h"
namespace ray {
namespace rpc {
/// Client used for communicating with a remote agent manager server.
class AgentManagerClient {
public:
/// Constructor.
///
/// \param[in] address Address of the agent manager server.
/// \param[in] port Port of the agent manager server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
AgentManagerClient(const std::string &address, const int port,
ClientCallManager &client_call_manager) {
grpc_client_ = std::unique_ptr<GrpcClient<AgentManagerService>>(
new GrpcClient<AgentManagerService>(address, port, client_call_manager));
};
/// Register agent service to the agent manager server
///
/// \param request The request message
/// \param callback The callback function that handles reply
VOID_RPC_CLIENT_METHOD(AgentManagerService, RegisterAgent, grpc_client_, )
private:
/// The RPC client.
std::unique_ptr<GrpcClient<AgentManagerService>> grpc_client_;
};
} // namespace rpc
} // namespace ray
@@ -0,0 +1,73 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/server_call.h"
#include "src/ray/protobuf/agent_manager.grpc.pb.h"
#include "src/ray/protobuf/agent_manager.pb.h"
namespace ray {
namespace rpc {
/// List of RPC handlers of `AgentManagerService`, one RPC_SERVICE_HANDLER
/// entry per RPC. Expanded inside
/// `AgentManagerGrpcService::InitServerCallFactories`.
#define RAY_AGENT_MANAGER_RPC_HANDLERS \
RPC_SERVICE_HANDLER(AgentManagerService, RegisterAgent)
/// Interface for handling `AgentManagerService` requests; the service is
/// declared in `src/ray/protobuf/agent_manager.proto`.
class AgentManagerServiceHandler {
public:
virtual ~AgentManagerServiceHandler() = default;
/// Handle a `RegisterAgent` request.
/// The implementation can handle this request asynchronously. When handling is done,
/// the `send_reply_callback` should be called.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
virtual void HandleRegisterAgent(const RegisterAgentRequest &request,
RegisterAgentReply *reply,
SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcService` for `AgentManagerService`.
class AgentManagerGrpcService : public GrpcService {
public:
/// Construct an `AgentManagerGrpcService`.
///
/// \param[in] io_service Passed to the `GrpcService` base; see `GrpcService`.
/// \param[in] service_handler The service handler that actually handles the
/// requests. Stored by reference, so it must outlive this service.
AgentManagerGrpcService(boost::asio::io_service &io_service,
AgentManagerServiceHandler &service_handler)
: GrpcService(io_service), service_handler_(service_handler){};
protected:
/// Return the underlying gRPC async service object.
grpc::Service &GetGrpcService() override { return service_; }
/// Create the server call factories by expanding
/// RAY_AGENT_MANAGER_RPC_HANDLERS.
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories) override {
RAY_AGENT_MANAGER_RPC_HANDLERS
}
private:
/// The grpc async service object.
AgentManagerService::AsyncService service_;
/// The service handler that actually handles the requests.
AgentManagerServiceHandler &service_handler_;
};
} // namespace rpc
} // namespace ray
+4 -2
View File
@@ -16,8 +16,9 @@
#include <boost/asio.hpp>
inline void execute_after(boost::asio::io_context &io_context,
const std::function<void()> &fn, uint32_t delay_milliseconds) {
inline std::shared_ptr<boost::asio::deadline_timer> execute_after(
boost::asio::io_context &io_context, const std::function<void()> &fn,
uint32_t delay_milliseconds) {
auto timer = std::make_shared<boost::asio::deadline_timer>(io_context);
timer->expires_from_now(boost::posix_time::milliseconds(delay_milliseconds));
timer->async_wait([timer, fn](const boost::system::error_code &error) {
@@ -25,4 +26,5 @@ inline void execute_after(boost::asio::io_context &io_context,
fn();
}
});
return timer;
}
+5
View File
@@ -187,6 +187,11 @@ static std::vector<std::string> ParsePosixCommandLine(const std::string &s) {
/// Python analog: None (would be shlex.split(s, posix=False), but it doesn't unquote)
static std::vector<std::string> ParseWindowsCommandLine(const std::string &s) {
RAY_CHECK(s.find('\0') >= s.size()) << "Invalid null character in command line";
// The if statement below may be incorrect. See:
// https://github.com/ray-project/ray/pull/10131#discussion_r473871563
if (s.empty()) {
return {};
}
std::vector<std::string> result;
std::string arg, c_str = s + '\0';
std::string::const_iterator i = c_str.begin(), j = c_str.end() - 1;