mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 23:57:45 +08:00
[tune][Dashboard] Added Tune Dashboard (#6911)
This commit is contained in:
+3187
-3047
File diff suppressed because it is too large
Load Diff
@@ -191,3 +191,41 @@ export const launchKillActor = (
|
||||
ip_address: actorIpAddress,
|
||||
port: actorPort
|
||||
});
|
||||
|
||||
/**
 * One Tune trial record as served by the dashboard's `/api/tune_info`
 * endpoint.
 *
 * The backend builds each record from a row of the experiment dataframe:
 * columns prefixed with `config/` are moved (prefix stripped, values
 * stringified) into `params`, columns that are not standard bookkeeping
 * fields are stringified into `metrics`, and `status`, `start_time`,
 * `trial_id` and `job_id` are added by the server.
 */
export interface TuneTrial {
  date: string;
  // NOTE(review): the backend does not stringify this column, so it may
  // actually arrive as a number — confirm before relying on string methods.
  episodes_total: string;
  experiment_id: string;
  experiment_tag: string;
  hostname: string;
  iterations_since_restore: number;
  /** Directory holding the trial's logs. */
  logdir: string;
  node_ip: string;
  pid: number;
  time_since_restore: number;
  time_this_iter_s: number;
  time_total_s: number;
  timestamp: number;
  timesteps_since_restore: number;
  timesteps_total: number;
  training_iteration: number;
  /** Creation time of `logdir` (seconds, rounded to 3 decimals) as a string. */
  start_time: string;
  /** `"TERMINATED"` when the trial reported `done`, otherwise `"RUNNING"`. */
  status: string;
  trial_id: string;
  /** Name of the experiment sub-directory the trial was found in. */
  job_id: string;
  /** Trial configuration values, keyed by name without the `config/` prefix. */
  params: { [key: string]: string };
  /** Remaining (non-bookkeeping) result columns, stringified. */
  metrics: { [key: string]: string };
}
|
||||
|
||||
/**
 * Payload of `/api/tune_info`: every known trial keyed by its trial id.
 */
export interface TuneJobResponse {
  trial_records: { [key: string]: TuneTrial };
}
|
||||
|
||||
/** Requests the current Tune trial records from `/api/tune_info`. */
export const getTuneInfo = () => get<TuneJobResponse>("/api/tune_info", {});
|
||||
|
||||
/**
 * Payload of `/api/tune_availability`.
 */
export interface TuneAvailabilityResponse {
  /** Whether the backend has Tune data to show. */
  available: boolean;
}
|
||||
|
||||
/** Asks `/api/tune_availability` whether Tune data can be displayed. */
export const getTuneAvailability = () =>
  get<TuneAvailabilityResponse>("/api/tune_availability", {});
|
||||
|
||||
@@ -6,13 +6,14 @@ import Tabs from "@material-ui/core/Tabs";
|
||||
import Typography from "@material-ui/core/Typography";
|
||||
import React from "react";
|
||||
import { connect } from "react-redux";
|
||||
import { getNodeInfo, getRayletInfo } from "../../api";
|
||||
import { getNodeInfo, getRayletInfo, getTuneAvailability } from "../../api";
|
||||
import { StoreState } from "../../store";
|
||||
import LastUpdated from "./LastUpdated";
|
||||
import LogicalView from "./logical-view/LogicalView";
|
||||
import NodeInfo from "./node-info/NodeInfo";
|
||||
import RayConfig from "./ray-config/RayConfig";
|
||||
import { dashboardActions } from "./state";
|
||||
import Tune from "./Tune";
|
||||
|
||||
const styles = (theme: Theme) =>
|
||||
createStyles({
|
||||
@@ -31,7 +32,8 @@ const styles = (theme: Theme) =>
|
||||
});
|
||||
|
||||
const mapStateToProps = (state: StoreState) => ({
|
||||
tab: state.dashboard.tab
|
||||
tab: state.dashboard.tab,
|
||||
tuneAvailability: state.dashboard.tuneAvailability
|
||||
});
|
||||
|
||||
const mapDispatchToProps = dashboardActions;
|
||||
@@ -43,11 +45,13 @@ class Dashboard extends React.Component<
|
||||
> {
|
||||
refreshNodeAndRayletInfo = async () => {
|
||||
try {
|
||||
const [nodeInfo, rayletInfo] = await Promise.all([
|
||||
const [nodeInfo, rayletInfo, tuneAvailability] = await Promise.all([
|
||||
getNodeInfo(),
|
||||
getRayletInfo()
|
||||
getRayletInfo(),
|
||||
getTuneAvailability()
|
||||
]);
|
||||
this.props.setNodeAndRayletInfo({ nodeInfo, rayletInfo });
|
||||
this.props.setTuneAvailability({ tuneAvailability });
|
||||
this.props.setError(null);
|
||||
} catch (error) {
|
||||
this.props.setError(error.toString());
|
||||
@@ -65,12 +69,19 @@ class Dashboard extends React.Component<
|
||||
};
|
||||
|
||||
render() {
|
||||
const { classes, tab } = this.props;
|
||||
const { classes, tab, tuneAvailability } = this.props;
|
||||
const tabs = [
|
||||
{ label: "Machine view", component: NodeInfo },
|
||||
{ label: "Logical view", component: LogicalView },
|
||||
{ label: "Ray config", component: RayConfig }
|
||||
{ label: "Ray config", component: RayConfig },
|
||||
{ label: "Tune", component: Tune }
|
||||
];
|
||||
|
||||
// if Tune information is not available, remove the Tune tab from the dashboard
|
||||
if (!tuneAvailability) {
|
||||
tabs.splice(3);
|
||||
}
|
||||
|
||||
const SelectedComponent = tabs[tab].component;
|
||||
return (
|
||||
<div className={classes.root}>
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
import { Theme } from "@material-ui/core/styles/createMuiTheme";
|
||||
import createStyles from "@material-ui/core/styles/createStyles";
|
||||
import withStyles, { WithStyles } from "@material-ui/core/styles/withStyles";
|
||||
import React from "react";
|
||||
import { connect } from "react-redux";
|
||||
import { getTuneInfo } from "../../api";
|
||||
import { StoreState } from "../../store";
|
||||
import { dashboardActions } from "./state";
|
||||
import Typography from "@material-ui/core/Typography";
|
||||
import WarningRoundedIcon from "@material-ui/icons/WarningRounded";
|
||||
import Table from "@material-ui/core/Table";
|
||||
import TableBody from "@material-ui/core/TableBody";
|
||||
import TableCell from "@material-ui/core/TableCell";
|
||||
import TableHead from "@material-ui/core/TableHead";
|
||||
import TableRow from "@material-ui/core/TableRow";
|
||||
|
||||
/** Theme-derived JSS styles for the Tune tab. */
const styles = (theme: Theme) =>
  createStyles({
    // Outer container: pad the tab and space out its direct children.
    root: {
      padding: theme.spacing(2),
      "& > :not(:first-child)": {
        marginTop: theme.spacing(2)
      }
    },
    table: {
      marginTop: theme.spacing(1)
    },
    // Compact, centred cells; the last cell gets the same right padding as
    // every other cell.
    cell: {
      padding: theme.spacing(1),
      textAlign: "center",
      "&:last-child": {
        paddingRight: theme.spacing(1)
      }
    },
    // "Experimental" notice shown above the table.
    warning: {
      fontSize: "0.8125rem",
      marginBottom: theme.spacing(2)
    },
    warningIcon: {
      fontSize: "1.25em",
      verticalAlign: "text-bottom"
    }
  });
|
||||
|
||||
/** Exposes the Tune trial data held in the dashboard slice as a prop. */
const mapStateToProps = (state: StoreState) => ({
  tuneInfo: state.dashboard.tuneInfo
});
|
||||
|
||||
// All dashboard action creators are injected as props (bound by `connect`).
const mapDispatchToProps = dashboardActions;
|
||||
|
||||
/**
 * Dashboard tab listing Tune trials.
 *
 * While mounted, the component polls `getTuneInfo()` and stores the result
 * via `setTuneInfo`; the next poll is scheduled one second after the
 * previous one settles. The table columns (parameters and metrics) are
 * derived from the first trial in the response.
 */
class Tune extends React.Component<
  WithStyles<typeof styles> &
    ReturnType<typeof mapStateToProps> &
    typeof mapDispatchToProps
> {
  /** Handle of the pending poll timer (0 when none has been scheduled). */
  timeout: number = 0;

  /**
   * Set on unmount so that a request still in flight at that moment does
   * not re-arm the poll timer after `clearTimeout` has already run.
   */
  private unmounted = false;

  /** Fetches the trial records once and schedules the next poll. */
  refreshTuneInfo = async () => {
    try {
      const tuneInfo = await getTuneInfo();
      this.props.setTuneInfo({ tuneInfo });
    } catch (error) {
      this.props.setError(String(error));
    } finally {
      if (!this.unmounted) {
        this.timeout = window.setTimeout(this.refreshTuneInfo, 1000);
      }
    }
  };

  async componentDidMount() {
    this.unmounted = false;
    await this.refreshTuneInfo();
  }

  componentWillUnmount() {
    this.unmounted = true;
    window.clearTimeout(this.timeout);
  }

  render() {
    const { classes, tuneInfo } = this.props;

    // The backend answers with an empty object when Tune is not installed,
    // so `trial_records` may be missing even though the type says otherwise.
    const trialRecords = tuneInfo && tuneInfo.trial_records;
    if (!trialRecords) {
      return null;
    }

    const trialIds = Object.keys(trialRecords);
    if (trialIds.length === 0) {
      return null;
    }

    // Column headers come from the first trial.
    const firstTrial = trialRecords[trialIds[0]];
    const paramNames = Object.keys(firstTrial.params || {}).filter(
      name => name !== "args"
    );
    const metricNames = Object.keys(firstTrial.metrics || {});

    return (
      <div className={classes.root}>
        <Typography className={classes.warning} color="textSecondary">
          <WarningRoundedIcon className={classes.warningIcon} /> Note: This tab
          is experimental.
        </Typography>
        <Table className={classes.table}>
          <TableHead>
            <TableRow>
              <TableCell className={classes.cell}>Trial ID</TableCell>
              <TableCell className={classes.cell}>Job ID</TableCell>
              <TableCell className={classes.cell}>Start Time</TableCell>
              {paramNames.map(name => (
                <TableCell className={classes.cell} key={name}>
                  {name}
                </TableCell>
              ))}
              <TableCell className={classes.cell}>Status</TableCell>
              {metricNames.map(name => (
                <TableCell className={classes.cell} key={name}>
                  {name}
                </TableCell>
              ))}
            </TableRow>
          </TableHead>
          <TableBody>
            {trialIds.map(trialId => {
              const trial = trialRecords[trialId];
              const params = trial.params || {};
              return (
                <TableRow key={trialId}>
                  <TableCell className={classes.cell}>
                    {trial.trial_id}
                  </TableCell>
                  <TableCell className={classes.cell}>{trial.job_id}</TableCell>
                  <TableCell className={classes.cell}>
                    {trial.start_time}
                  </TableCell>
                  {paramNames.map(name => (
                    <TableCell className={classes.cell} key={name}>
                      {params[name]}
                    </TableCell>
                  ))}
                  <TableCell className={classes.cell}>{trial.status}</TableCell>
                  {trial.metrics &&
                    metricNames.map(name => (
                      <TableCell className={classes.cell} key={name}>
                        {trial.metrics[name]}
                      </TableCell>
                    ))}
                </TableRow>
              );
            })}
          </TableBody>
        </Table>
      </div>
    );
  }
}
|
||||
|
||||
// Inject the JSS classes first, then connect the styled component to the
// redux store.
export default connect(
  mapStateToProps,
  mapDispatchToProps
)(withStyles(styles)(Tune));
|
||||
@@ -2,7 +2,9 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
import {
|
||||
NodeInfoResponse,
|
||||
RayConfigResponse,
|
||||
RayletInfoResponse
|
||||
RayletInfoResponse,
|
||||
TuneJobResponse,
|
||||
TuneAvailabilityResponse
|
||||
} from "../../api";
|
||||
|
||||
const name = "dashboard";
|
||||
@@ -12,6 +14,8 @@ interface State {
|
||||
rayConfig: RayConfigResponse | null;
|
||||
nodeInfo: NodeInfoResponse | null;
|
||||
rayletInfo: RayletInfoResponse | null;
|
||||
tuneInfo: TuneJobResponse | null;
|
||||
tuneAvailability: boolean;
|
||||
lastUpdatedAt: number | null;
|
||||
error: string | null;
|
||||
}
|
||||
@@ -21,6 +25,8 @@ const initialState: State = {
|
||||
rayConfig: null,
|
||||
nodeInfo: null,
|
||||
rayletInfo: null,
|
||||
tuneInfo: null,
|
||||
tuneAvailability: false,
|
||||
lastUpdatedAt: null,
|
||||
error: null
|
||||
};
|
||||
@@ -46,6 +52,24 @@ const slice = createSlice({
|
||||
state.rayletInfo = action.payload.rayletInfo;
|
||||
state.lastUpdatedAt = Date.now();
|
||||
},
|
||||
setTuneInfo: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
tuneInfo: TuneJobResponse;
|
||||
}>
|
||||
) => {
|
||||
state.tuneInfo = action.payload.tuneInfo;
|
||||
state.lastUpdatedAt = Date.now();
|
||||
},
|
||||
setTuneAvailability: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
tuneAvailability: TuneAvailabilityResponse;
|
||||
}>
|
||||
) => {
|
||||
state.tuneAvailability = action.payload.tuneAvailability["available"];
|
||||
state.lastUpdatedAt = Date.now();
|
||||
},
|
||||
setError: (state, action: PayloadAction<string | null>) => {
|
||||
state.error = action.payload;
|
||||
}
|
||||
|
||||
@@ -34,6 +34,12 @@ from ray.core.generated import core_worker_pb2
|
||||
from ray.core.generated import core_worker_pb2_grpc
|
||||
import ray.ray_constants as ray_constants
|
||||
|
||||
try:
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune import Analysis
|
||||
except ImportError:
|
||||
Analysis = None
|
||||
|
||||
# Logger for this module. It should be configured at the entry point
|
||||
# into the program using Ray. Ray provides a default configuration at
|
||||
# entry/init points.
|
||||
@@ -114,6 +120,8 @@ class Dashboard(object):
|
||||
|
||||
self.node_stats = NodeStats(redis_address, redis_password)
|
||||
self.raylet_stats = RayletStats(redis_address, redis_password)
|
||||
if Analysis is not None:
|
||||
self.tune_stats = TuneCollector(DEFAULT_RESULTS_DIR, 2.0)
|
||||
|
||||
# Setting the environment variable RAY_DASHBOARD_DEV=1 disables some
|
||||
# security checks in the dashboard server to ease development while
|
||||
@@ -258,6 +266,20 @@ class Dashboard(object):
|
||||
result = {"nodes": D, "actors": actor_tree}
|
||||
return await json_response(result=result)
|
||||
|
||||
async def tune_info(req) -> aiohttp.web.Response:
|
||||
if Analysis is not None:
|
||||
D = self.tune_stats.get_stats()
|
||||
else:
|
||||
D = {}
|
||||
return await json_response(result=D)
|
||||
|
||||
async def tune_availability(req) -> aiohttp.web.Response:
|
||||
if Analysis is not None:
|
||||
D = self.tune_stats.get_availability()
|
||||
else:
|
||||
D = {"available": False}
|
||||
return await json_response(result=D)
|
||||
|
||||
async def launch_profiling(req) -> aiohttp.web.Response:
|
||||
node_id = req.query.get("node_id")
|
||||
pid = int(req.query.get("pid"))
|
||||
@@ -317,6 +339,8 @@ class Dashboard(object):
|
||||
self.app.router.add_get("/api/ray_config", ray_config)
|
||||
self.app.router.add_get("/api/node_info", node_info)
|
||||
self.app.router.add_get("/api/raylet_info", raylet_info)
|
||||
self.app.router.add_get("/api/tune_info", tune_info)
|
||||
self.app.router.add_get("/api/tune_availability", tune_availability)
|
||||
self.app.router.add_get("/api/launch_profiling", launch_profiling)
|
||||
self.app.router.add_get("/api/check_profiling_status",
|
||||
check_profiling_status)
|
||||
@@ -337,6 +361,8 @@ class Dashboard(object):
|
||||
self.log_dashboard_url()
|
||||
self.node_stats.start()
|
||||
self.raylet_stats.start()
|
||||
if Analysis is not None:
|
||||
self.tune_stats.start()
|
||||
aiohttp.web.run_app(self.app, host=self.host, port=self.port)
|
||||
|
||||
|
||||
@@ -710,6 +736,130 @@ class RayletStats(threading.Thread):
|
||||
self.update_nodes()
|
||||
|
||||
|
||||
class TuneCollector(threading.Thread):
    """Worker thread that periodically loads Tune trial data from disk.

    Every ``reload_interval`` seconds the thread scans the sub-directories
    of ``logdir``, loads each one with ``Analysis`` and publishes the
    cleaned per-trial records for the dashboard's HTTP handlers.

    Args:
        logdir (str): Directory path to save the status information of
            jobs and trials.
        reload_interval (float): Interval (in seconds) between reloads of
            the data from the logs.
    """

    def __init__(self, logdir, reload_interval):
        super().__init__()
        self._logdir = logdir
        self._trial_records = {}
        # Guards _trial_records and _available, which are written by this
        # thread and read through get_stats() / get_availability().
        self._data_lock = threading.Lock()
        self._reload_interval = reload_interval
        self._available = False

    def get_stats(self):
        """Return a deep copy of the latest trial records.

        Returns:
            dict: ``{"trial_records": {trial_id: record, ...}}``.
        """
        with self._data_lock:
            return {"trial_records": copy.deepcopy(self._trial_records)}

    def get_availability(self):
        """Return whether any non-empty Tune experiment has been seen.

        Returns:
            dict: ``{"available": bool}``.
        """
        with self._data_lock:
            return {"available": self._available}

    def run(self):
        """Collect forever, sleeping ``reload_interval`` between scans."""
        # Local import keeps this block self-contained.
        import logging
        log = logging.getLogger(__name__)

        while True:
            try:
                self.collect()
            except Exception:
                # A failed scan (unreadable directory, malformed results,
                # ...) must not kill the thread; retry on the next tick.
                log.exception("Failed to collect Tune data from %s",
                              self._logdir)
            time.sleep(self._reload_interval)

    def collect(self):
        """
        Collects and cleans data on the running Tune experiment from the
        Tune logs so that users can see this information in the front-end
        client.

        The scan runs without holding the data lock; the finished result is
        swapped in under the lock, so readers never observe a partial scan
        and are never blocked on disk I/O.
        """
        trial_records = {}
        found_trials = False

        # The results directory may not have been created yet.
        if os.path.isdir(self._logdir):
            sub_dirs = os.listdir(self._logdir)
            job_names = filter(
                lambda d: os.path.isdir(os.path.join(self._logdir, d)),
                sub_dirs)

            # search through all the sub_directories in log directory
            for job_name in job_names:
                analysis = Analysis(str(os.path.join(self._logdir, job_name)))
                df = analysis.dataframe()
                if len(df) == 0:
                    continue

                found_trials = True

                # make sure that data will convert to JSON without error
                df["trial_id"] = df["trial_id"].astype(str)
                df = df.fillna(0)

                # convert df to python dict
                df = df.set_index("trial_id")
                trial_data = df.to_dict(orient="index")

                # clean data and update the records for this scan
                if len(trial_data) > 0:
                    trial_data = self.clean_trials(trial_data, job_name)
                    trial_records.update(trial_data)

        with self._data_lock:
            self._trial_records = trial_records
            # Availability is sticky: once trials were seen it stays True.
            if found_trials:
                self._available = True

    def clean_trials(self, trial_details, job_name):
        """Reshape raw per-trial rows into the form the front-end expects.

        Column roles are decided from the first trial: float columns are
        rounded to 3 decimals, ``config/*`` columns are moved into
        ``params`` (prefix stripped, stringified), and columns that are
        neither config nor standard bookkeeping are moved into ``metrics``
        (stringified). ``done`` is replaced by ``status``, and
        ``start_time``, ``trial_id`` and ``job_id`` are added.

        Args:
            trial_details (dict): Mapping of trial id to a dict of that
                trial's columns; each must contain ``logdir``. Modified
                in place.
            job_name (str): Stored on every trial as ``job_id``.

        Returns:
            dict: The same ``trial_details`` mapping, cleaned.
        """
        if not trial_details:
            return trial_details

        first_trial = trial_details[list(trial_details.keys())[0]]
        config_keys = []
        float_keys = []
        metric_keys = []

        # list of static attributes for trial
        default_names = {
            "logdir", "time_this_iter_s", "done", "episodes_total",
            "training_iteration", "timestamp", "timesteps_total",
            "experiment_id", "date", "time_total_s", "pid", "hostname",
            "node_ip", "time_since_restore", "timesteps_since_restore",
            "iterations_since_restore", "experiment_tag"
        }

        # filter attributes into floats, metrics, and config variables
        for key, value in first_trial.items():
            if isinstance(value, float):
                float_keys.append(key)
            if str(key).startswith("config/"):
                config_keys.append(key)
            elif key not in default_names:
                metric_keys.append(key)

        # clean data into a form that front-end client can handle
        for trial, details in trial_details.items():
            details["start_time"] = str(
                round(os.path.getctime(details["logdir"]), 3))
            details["params"] = {}
            details["metrics"] = {}

            # round all floats (a column that is float in the first trial
            # is not guaranteed to be float in every other trial)
            for key in float_keys:
                if isinstance(details.get(key), float):
                    details[key] = round(details[key], 3)

            # group together config attributes
            for key in config_keys:
                new_name = key[7:]  # strip the "config/" prefix
                details["params"][new_name] = str(details[key])
                details.pop(key)

            # group together metric attributes
            for key in metric_keys:
                details["metrics"][key] = str(details[key])
                details.pop(key)

            # a trial that has not reported "done" yet counts as running
            if details.pop("done", False):
                details["status"] = "TERMINATED"
            else:
                details["status"] = "RUNNING"

            details["trial_id"] = trial
            details["job_id"] = job_name

        return trial_details
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse Redis server for the "
|
||||
|
||||
Reference in New Issue
Block a user