mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:41:11 +08:00
[tune] strict metric checking (#10972)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import numpy as np
|
||||
import sklearn.datasets
|
||||
import sklearn.metrics
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
@@ -6,7 +5,7 @@ from sklearn.model_selection import train_test_split
|
||||
import xgboost as xgb
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.xgboost import TuneReportCallback
|
||||
from ray.tune.integration.xgboost import TuneReportCheckpointCallback
|
||||
|
||||
|
||||
def train_breast_cancer(config):
|
||||
@@ -19,39 +18,44 @@ def train_breast_cancer(config):
|
||||
train_set = xgb.DMatrix(train_x, label=train_y)
|
||||
test_set = xgb.DMatrix(test_x, label=test_y)
|
||||
# Train the classifier
|
||||
bst = xgb.train(
|
||||
xgb.train(
|
||||
config,
|
||||
train_set,
|
||||
evals=[(test_set, "eval")],
|
||||
verbose_eval=False,
|
||||
callbacks=[TuneReportCallback()])
|
||||
# Predict labels for the test set
|
||||
preds = bst.predict(test_set)
|
||||
pred_labels = np.rint(preds)
|
||||
# Return prediction accuracy
|
||||
accuracy = sklearn.metrics.accuracy_score(test_y, pred_labels)
|
||||
tune.report(mean_accuracy=accuracy, done=True)
|
||||
callbacks=[TuneReportCheckpointCallback(filename="model.xgb")])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"objective": "binary:logistic",
|
||||
"eval_metric": ["logloss", "error"],
|
||||
"max_depth": tune.randint(1, 9),
|
||||
"min_child_weight": tune.choice([1, 2, 3]),
|
||||
"subsample": tune.uniform(0.5, 1.0),
|
||||
"eta": tune.loguniform(1e-4, 1e-1),
|
||||
"eval_metric": ["auc", "ams@0", "logloss"]
|
||||
"eta": tune.loguniform(1e-4, 1e-1)
|
||||
}
|
||||
# The ASHAScheduler stops bad performing configurations early
|
||||
scheduler = ASHAScheduler(
|
||||
metric="eval-logloss", # The `eval` prefix is defined in xgb.train
|
||||
mode="min", # Retain configurations with a low logloss
|
||||
max_t=11, # 10 training iterations + 1 final evaluation
|
||||
grace_period=1, # Number of minimum iterations for each trial
|
||||
reduction_factor=2) # How aggressively to stop trials
|
||||
tune.run(
|
||||
train_breast_cancer, # your training function
|
||||
max_t=10, # 10 training iterations
|
||||
grace_period=1,
|
||||
reduction_factor=2)
|
||||
|
||||
analysis = tune.run(
|
||||
train_breast_cancer,
|
||||
metric="eval-logloss",
|
||||
mode="min",
|
||||
resources_per_trial={"cpu": 1}, # You can add "gpu": 0.1 here
|
||||
config=config,
|
||||
num_samples=10, # number of parameter configurations to try
|
||||
num_samples=10,
|
||||
scheduler=scheduler)
|
||||
|
||||
# Load the best model checkpoint
|
||||
import os
|
||||
best_bst = xgb.Booster()
|
||||
best_bst.load_model(os.path.join(analysis.best_checkpoint, "model.xgb"))
|
||||
accuracy = 1. - analysis.best_result["eval-error"]
|
||||
print(f"Best model parameters: {analysis.best_config}")
|
||||
print(f"Best model total accuracy: {accuracy:.4f}")
|
||||
|
||||
# You could now do further predictions with
|
||||
# best_bst.predict(...)
|
||||
|
||||
@@ -278,7 +278,7 @@ class TuneReporterBase(ProgressReporter):
|
||||
continue
|
||||
if not best_metric or \
|
||||
t.last_result[metric] * metric_op > best_metric:
|
||||
best_metric = t.last_result[metric]
|
||||
best_metric = t.last_result[metric] * metric_op
|
||||
best_trial = t
|
||||
return best_trial, metric
|
||||
|
||||
|
||||
@@ -11,6 +11,12 @@ class TrialScheduler:
|
||||
PAUSE = "PAUSE" #: Status for pausing trial execution
|
||||
STOP = "STOP" #: Status for stopping trial execution
|
||||
|
||||
_metric = None
|
||||
|
||||
@property
def metric(self):
    """Return the metric name stored on this scheduler.

    This is the class-level default ``None`` until a value has been
    assigned to ``self._metric``.
    """
    return self._metric
|
||||
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
"""Pass search properties to scheduler.
|
||||
@@ -22,6 +28,10 @@ class TrialScheduler:
|
||||
metric (str): Metric to optimize
|
||||
mode (str): One of ["min", "max"]. Direction to optimize.
|
||||
"""
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if metric:
|
||||
self._metric = metric
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
|
||||
@@ -17,6 +17,12 @@ class SearchAlgorithm:
|
||||
"""
|
||||
_finished = False
|
||||
|
||||
_metric = None
|
||||
|
||||
@property
def metric(self):
    """Return the metric name stored on this search algorithm.

    This is the class-level default ``None`` until a value has been
    assigned to ``self._metric``.
    """
    return self._metric
|
||||
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
"""Pass search properties to search algorithm.
|
||||
@@ -33,6 +39,10 @@ class SearchAlgorithm:
|
||||
mode (str): One of ["min", "max"]. Direction to optimize.
|
||||
config (dict): Tune config dict.
|
||||
"""
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if metric:
|
||||
self._metric = metric
|
||||
return True
|
||||
|
||||
@property
|
||||
|
||||
@@ -70,6 +70,10 @@ class SearchGenerator(SearchAlgorithm):
|
||||
self._total_samples = 0 # int: total samples to evaluate.
|
||||
self._finished = False
|
||||
|
||||
@property
def metric(self):
    """Return the metric of the wrapped searcher.

    The generator keeps no metric of its own; the value is read from
    ``self.searcher.metric`` on every access.
    """
    return self.searcher.metric
|
||||
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
                          config: Dict) -> bool:
    """Forward the search properties to the wrapped searcher.

    Args:
        metric: Passed through unchanged to the searcher.
        mode: Passed through unchanged to the searcher.
        config: Passed through unchanged to the searcher.

    Returns:
        Whatever ``self.searcher.set_search_properties`` returns.
    """
    return self.searcher.set_search_properties(metric, mode, config)
|
||||
|
||||
@@ -1146,6 +1146,54 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
diff = time.time() - start
|
||||
self.assertLess(diff, 9)
|
||||
|
||||
def testMetricCheckingEndToEnd(self):
    """Exercise strict metric checking through ``tune.run``.

    Covers the metric being requested by ``tune.run`` itself, by a
    scheduler only, and by a search algorithm only, and finally checks
    that the check can be switched off through the
    ``TUNE_DISABLE_STRICT_METRIC_CHECKING`` environment variable.

    The environment variable is restored to its previous state on exit
    (also when an assertion fails) so that this test cannot change the
    behaviour of tests that run after it in the same process.
    """
    from ray import tune

    def train(config):
        tune.report(val=4, second=8)

    def train2(config):
        return

    env_key = "TUNE_DISABLE_STRICT_METRIC_CHECKING"
    previous_value = os.environ.get(env_key)
    try:
        os.environ[env_key] = "0"
        # `acc` is not reported, should raise
        with self.assertRaises(TuneError):
            # The trial runner raises a ValueError, but the experiment
            # fails with a TuneError
            tune.run(train, metric="acc")

        # `val` is reported, should not raise
        tune.run(train, metric="val")

        # Run does not report anything, should not raise
        tune.run(train2, metric="val")

        # Only the scheduler requires a metric
        with self.assertRaises(TuneError):
            tune.run(
                train,
                scheduler=AsyncHyperBandScheduler(metric="acc", mode="max"))

        tune.run(
            train, scheduler=AsyncHyperBandScheduler(metric="val", mode="max"))

        # Only the search alg requires a metric
        with self.assertRaises(TuneError):
            tune.run(
                train,
                config={"a": tune.choice([1, 2])},
                search_alg=HyperOptSearch(metric="acc", mode="max"))

        # Metric is passed
        tune.run(
            train,
            config={"a": tune.choice([1, 2])},
            search_alg=HyperOptSearch(metric="val", mode="max"))

        os.environ[env_key] = "1"
        # With strict metric checking disabled, this should not raise
        tune.run(train, metric="acc")
    finally:
        # Undo the environment change so later tests see the original
        # setting instead of the "1" written above.
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value
|
||||
|
||||
|
||||
class ShimCreationTest(unittest.TestCase):
|
||||
def testCreateScheduler(self):
|
||||
|
||||
@@ -252,7 +252,7 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
searcher.set_search_properties("none", "max", invalid_config)
|
||||
|
||||
searcher = BayesOptSearch(metric="a", mode="max")
|
||||
searcher = BayesOptSearch(metric="b", mode="max")
|
||||
analysis = tune.run(
|
||||
_mock_objective, config=config, search_alg=searcher, num_samples=1)
|
||||
trial = analysis.trials[0]
|
||||
|
||||
@@ -26,8 +26,14 @@ from ray.tune.utils import flatten_dict
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
MAX_LEN_IDENTIFIER = int(os.environ.get("MAX_LEN_IDENTIFIER", 130))
|
||||
logger = logging.getLogger(__name__)
|
||||
if "MAX_LEN_IDENTIFIER" in os.environ:
|
||||
logger.error(
|
||||
"The MAX_LEN_IDENTIFIER environment variable is deprecated and will "
|
||||
"be removed in the future. Use TUNE_MAX_LEN_IDENTIFIER instead.")
|
||||
MAX_LEN_IDENTIFIER = int(
|
||||
os.environ.get("TUNE_MAX_LEN_IDENTIFIER",
|
||||
os.environ.get("MAX_LEN_IDENTIFIER", 130)))
|
||||
|
||||
|
||||
def date_str():
|
||||
|
||||
@@ -132,15 +132,20 @@ class TrialRunner:
|
||||
fail_fast=False,
|
||||
verbose=True,
|
||||
checkpoint_period=None,
|
||||
trial_executor=None):
|
||||
trial_executor=None,
|
||||
metric=None):
|
||||
self._search_alg = search_alg or BasicVariantGenerator()
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
self.trial_executor = trial_executor or RayTrialExecutor()
|
||||
|
||||
# For debugging, it may be useful to halt trials after some time has
|
||||
# elapsed. TODO(ekl) consider exposing this in the API.
|
||||
self._global_time_limit = float(
|
||||
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float("inf")))
|
||||
self._metric = metric
|
||||
|
||||
if "TRIALRUNNER_WALLTIME_LIMIT" in os.environ:
|
||||
raise ValueError(
|
||||
"The TRIALRUNNER_WALLTIME_LIMIT environment variable is "
|
||||
"deprecated. "
|
||||
"Use `tune.run(time_budget_s=limit)` instead.")
|
||||
|
||||
self._total_time = 0
|
||||
self._iteration = 0
|
||||
self._has_errored = False
|
||||
@@ -349,11 +354,6 @@ class TrialRunner:
|
||||
|
||||
def is_finished(self):
|
||||
"""Returns whether all trials have finished running."""
|
||||
if self._total_time > self._global_time_limit:
|
||||
logger.warning("Exceeded global time limit {} / {}".format(
|
||||
self._total_time, self._global_time_limit))
|
||||
return True
|
||||
|
||||
trials_done = all(trial.is_finished() for trial in self._trials)
|
||||
return trials_done and self._search_alg.is_finished()
|
||||
|
||||
@@ -527,6 +527,7 @@ class TrialRunner:
|
||||
result = trial.last_result
|
||||
result.update(done=True)
|
||||
|
||||
self._validate_result_metrics(result)
|
||||
self._total_time += result.get(TIME_THIS_ITER_S, 0)
|
||||
|
||||
flat_result = flatten_dict(result)
|
||||
@@ -572,6 +573,43 @@ class TrialRunner:
|
||||
raise
|
||||
self._process_trial_failure(trial, traceback.format_exc())
|
||||
|
||||
def _validate_result_metrics(self, result):
    """Raise if a required metric is missing from a trial result.

    The metrics that are checked, in this order, are the one stored on
    the runner (``self._metric``), the scheduler's ``metric`` and the
    search algorithm's ``metric``. Unset (falsy) metrics are skipped and
    only the first missing one is reported.

    Nothing is checked when the ``TUNE_DISABLE_STRICT_METRIC_CHECKING``
    environment variable is set to 1, or when ``done`` is the only item
    in the result: the latter means that no result was ever received
    and the trial just returned, which is okay.

    A metric counts as present when it is a key of the raw result or of
    the flattened result, so metrics reported inside nested dicts can
    be addressed by their flattened key.

    Args:
        result (dict): Last result of the trial.

    Raises:
        ValueError: If one of the required metrics is not in ``result``.
    """
    if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) == 1:
        return

    # Only `done` present: the trial returned without reporting anything.
    if len(result) <= 1 and "done" in result:
        return

    # The raw result may be nested; also look at the flattened view so a
    # metric that lives in a sub-dict is not reported as missing.
    # NOTE(review): relies on the module-level `flatten_dict` import that
    # the surrounding file already uses on this same result.
    flat_result = flatten_dict(result)

    required = (
        (self._metric, "tune.run()"),
        (self._scheduler_alg.metric, type(self._scheduler_alg).__name__),
        (self._search_alg.metric, type(self._search_alg).__name__),
    )

    for report_metric, location in required:
        if not report_metric:
            continue
        if report_metric in result or report_metric in flat_result:
            continue
        raise ValueError(
            "Trial returned a result which did not include the "
            "specified metric `{}` that `{}` expects. "
            "Make sure your calls to `tune.report()` include the "
            "metric, or set the "
            "TUNE_DISABLE_STRICT_METRIC_CHECKING "
            "environment variable to 1. Result: {}".format(
                report_metric, location, result))
|
||||
|
||||
def _process_trial_save(self, trial):
|
||||
"""Processes a trial save.
|
||||
|
||||
|
||||
@@ -374,7 +374,8 @@ def run(
|
||||
server_port=server_port,
|
||||
verbose=bool(verbose > 1),
|
||||
fail_fast=fail_fast,
|
||||
trial_executor=trial_executor)
|
||||
trial_executor=trial_executor,
|
||||
metric=metric)
|
||||
|
||||
if not runner.resumed:
|
||||
for exp in experiments:
|
||||
|
||||
Reference in New Issue
Block a user