[tune] strict metric checking (#10972)

This commit is contained in:
Kai Fricke
2020-09-24 18:00:48 +01:00
committed by GitHub
parent 5e6b887f2d
commit d9c4dea7cf
12 changed files with 275 additions and 161 deletions
+25 -21
View File
@@ -1,4 +1,3 @@
import numpy as np
import sklearn.datasets
import sklearn.metrics
from ray.tune.schedulers import ASHAScheduler
@@ -6,7 +5,7 @@ from sklearn.model_selection import train_test_split
import xgboost as xgb
from ray import tune
from ray.tune.integration.xgboost import TuneReportCallback
from ray.tune.integration.xgboost import TuneReportCheckpointCallback
def train_breast_cancer(config):
@@ -19,39 +18,44 @@ def train_breast_cancer(config):
train_set = xgb.DMatrix(train_x, label=train_y)
test_set = xgb.DMatrix(test_x, label=test_y)
# Train the classifier
bst = xgb.train(
xgb.train(
config,
train_set,
evals=[(test_set, "eval")],
verbose_eval=False,
callbacks=[TuneReportCallback()])
# Predict labels for the test set
preds = bst.predict(test_set)
pred_labels = np.rint(preds)
# Return prediction accuracy
accuracy = sklearn.metrics.accuracy_score(test_y, pred_labels)
tune.report(mean_accuracy=accuracy, done=True)
callbacks=[TuneReportCheckpointCallback(filename="model.xgb")])
if __name__ == "__main__":
config = {
"objective": "binary:logistic",
"eval_metric": ["logloss", "error"],
"max_depth": tune.randint(1, 9),
"min_child_weight": tune.choice([1, 2, 3]),
"subsample": tune.uniform(0.5, 1.0),
"eta": tune.loguniform(1e-4, 1e-1),
"eval_metric": ["auc", "ams@0", "logloss"]
"eta": tune.loguniform(1e-4, 1e-1)
}
# The ASHAScheduler stops bad performing configurations early
scheduler = ASHAScheduler(
metric="eval-logloss", # The `eval` prefix is defined in xgb.train
mode="min", # Retain configurations with a low logloss
max_t=11, # 10 training iterations + 1 final evaluation
grace_period=1, # Number of minimum iterations for each trial
reduction_factor=2) # How aggressively to stop trials
tune.run(
train_breast_cancer, # your training function
max_t=10, # 10 training iterations
grace_period=1,
reduction_factor=2)
analysis = tune.run(
train_breast_cancer,
metric="eval-logloss",
mode="min",
resources_per_trial={"cpu": 1}, # You can add "gpu": 0.1 here
config=config,
num_samples=10, # number of parameter configurations to try
num_samples=10,
scheduler=scheduler)
# Load the best model checkpoint
import os
best_bst = xgb.Booster()
best_bst.load_model(os.path.join(analysis.best_checkpoint, "model.xgb"))
accuracy = 1. - analysis.best_result["eval-error"]
print(f"Best model parameters: {analysis.best_config}")
print(f"Best model total accuracy: {accuracy:.4f}")
# You could now do further predictions with
# best_bst.predict(...)
+1 -1
View File
@@ -278,7 +278,7 @@ class TuneReporterBase(ProgressReporter):
continue
if not best_metric or \
t.last_result[metric] * metric_op > best_metric:
best_metric = t.last_result[metric]
best_metric = t.last_result[metric] * metric_op
best_trial = t
return best_trial, metric
@@ -11,6 +11,12 @@ class TrialScheduler:
PAUSE = "PAUSE" #: Status for pausing trial execution
STOP = "STOP" #: Status for stopping trial execution
_metric = None
@property
def metric(self):
return self._metric
def set_search_properties(self, metric: Optional[str],
mode: Optional[str]) -> bool:
"""Pass search properties to scheduler.
@@ -22,6 +28,10 @@ class TrialScheduler:
metric (str): Metric to optimize
mode (str): One of ["min", "max"]. Direction to optimize.
"""
if self._metric and metric:
return False
if metric:
self._metric = metric
return True
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
+10
View File
@@ -17,6 +17,12 @@ class SearchAlgorithm:
"""
_finished = False
_metric = None
@property
def metric(self):
return self._metric
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
"""Pass search properties to search algorithm.
@@ -33,6 +39,10 @@ class SearchAlgorithm:
mode (str): One of ["min", "max"]. Direction to optimize.
config (dict): Tune config dict.
"""
if self._metric and metric:
return False
if metric:
self._metric = metric
return True
@property
@@ -70,6 +70,10 @@ class SearchGenerator(SearchAlgorithm):
self._total_samples = 0 # int: total samples to evaluate.
self._finished = False
@property
def metric(self):
return self.searcher.metric
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
return self.searcher.set_search_properties(metric, mode, config)
+48
View File
@@ -1146,6 +1146,54 @@ class TrainableFunctionApiTest(unittest.TestCase):
diff = time.time() - start
self.assertLess(diff, 9)
def testMetricCheckingEndToEnd(self):
from ray import tune
def train(config):
tune.report(val=4, second=8)
def train2(config):
return
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "0"
# `acc` is not reported, should raise
with self.assertRaises(TuneError):
# The trial runner raises a ValueError, but the experiment fails
# with a TuneError
tune.run(train, metric="acc")
# `val` is reported, should not raise
tune.run(train, metric="val")
# Run does not report anything, should not raise
tune.run(train2, metric="val")
# Only the scheduler requires a metric
with self.assertRaises(TuneError):
tune.run(
train,
scheduler=AsyncHyperBandScheduler(metric="acc", mode="max"))
tune.run(
train, scheduler=AsyncHyperBandScheduler(metric="val", mode="max"))
# Only the search alg requires a metric
with self.assertRaises(TuneError):
tune.run(
train,
config={"a": tune.choice([1, 2])},
search_alg=HyperOptSearch(metric="acc", mode="max"))
# Metric is passed
tune.run(
train,
config={"a": tune.choice([1, 2])},
search_alg=HyperOptSearch(metric="val", mode="max"))
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
# With strict metric checking disabled, this should not raise
tune.run(train, metric="acc")
class ShimCreationTest(unittest.TestCase):
def testCreateScheduler(self):
+1 -1
View File
@@ -252,7 +252,7 @@ class SearchSpaceTest(unittest.TestCase):
with self.assertRaises(ValueError):
searcher.set_search_properties("none", "max", invalid_config)
searcher = BayesOptSearch(metric="a", mode="max")
searcher = BayesOptSearch(metric="b", mode="max")
analysis = tune.run(
_mock_objective, config=config, search_alg=searcher, num_samples=1)
trial = analysis.trials[0]
+7 -1
View File
@@ -26,8 +26,14 @@ from ray.tune.utils import flatten_dict
from ray.utils import binary_to_hex, hex_to_binary
DEBUG_PRINT_INTERVAL = 5
MAX_LEN_IDENTIFIER = int(os.environ.get("MAX_LEN_IDENTIFIER", 130))
logger = logging.getLogger(__name__)
if "MAX_LEN_IDENTIFIER" in os.environ:
logger.error(
"The MAX_LEN_IDENTIFIER environment variable is deprecated and will "
"be removed in the future. Use TUNE_MAX_LEN_IDENTIFIER instead.")
MAX_LEN_IDENTIFIER = int(
os.environ.get("TUNE_MAX_LEN_IDENTIFIER",
os.environ.get("MAX_LEN_IDENTIFIER", 130)))
def date_str():
+48 -10
View File
@@ -132,15 +132,20 @@ class TrialRunner:
fail_fast=False,
verbose=True,
checkpoint_period=None,
trial_executor=None):
trial_executor=None,
metric=None):
self._search_alg = search_alg or BasicVariantGenerator()
self._scheduler_alg = scheduler or FIFOScheduler()
self.trial_executor = trial_executor or RayTrialExecutor()
# For debugging, it may be useful to halt trials after some time has
# elapsed. TODO(ekl) consider exposing this in the API.
self._global_time_limit = float(
os.environ.get("TRIALRUNNER_WALLTIME_LIMIT", float("inf")))
self._metric = metric
if "TRIALRUNNER_WALLTIME_LIMIT" in os.environ:
raise ValueError(
"The TRIALRUNNER_WALLTIME_LIMIT environment variable is "
"deprecated. "
"Use `tune.run(time_budget_s=limit)` instead.")
self._total_time = 0
self._iteration = 0
self._has_errored = False
@@ -349,11 +354,6 @@ class TrialRunner:
def is_finished(self):
"""Returns whether all trials have finished running."""
if self._total_time > self._global_time_limit:
logger.warning("Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
return True
trials_done = all(trial.is_finished() for trial in self._trials)
return trials_done and self._search_alg.is_finished()
@@ -527,6 +527,7 @@ class TrialRunner:
result = trial.last_result
result.update(done=True)
self._validate_result_metrics(result)
self._total_time += result.get(TIME_THIS_ITER_S, 0)
flat_result = flatten_dict(result)
@@ -572,6 +573,43 @@ class TrialRunner:
raise
self._process_trial_failure(trial, traceback.format_exc())
def _validate_result_metrics(self, result):
"""
Check if any of the required metrics was not reported
in the last result. If the only item is `done=True`, this
means that no result was ever received and the trial just
returned. This is also okay and will not raise an error.
"""
if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING",
0)) != 1 and (len(result) > 1
or "done" not in result):
base_metric = self._metric
scheduler_metric = self._scheduler_alg.metric
search_metric = self._search_alg.metric
if base_metric and base_metric not in result:
report_metric = base_metric
location = "tune.run()"
elif scheduler_metric and scheduler_metric not in result:
report_metric = scheduler_metric
location = type(self._scheduler_alg).__name__
elif search_metric and search_metric not in result:
report_metric = search_metric
location = type(self._search_alg).__name__
else:
report_metric = None
location = None
if report_metric:
raise ValueError(
"Trial returned a result which did not include the "
"specified metric `{}` that `{}` expects. "
"Make sure your calls to `tune.report()` include the "
"metric, or set the "
"TUNE_DISABLE_STRICT_METRIC_CHECKING "
"environment variable to 1. Result: {}".format(
report_metric, location, result))
def _process_trial_save(self, trial):
"""Processes a trial save.
+2 -1
View File
@@ -374,7 +374,8 @@ def run(
server_port=server_port,
verbose=bool(verbose > 1),
fail_fast=fail_fast,
trial_executor=trial_executor)
trial_executor=trial_executor,
metric=metric)
if not runner.resumed:
for exp in experiments: