mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 09:44:50 +08:00
[tune] migrate xgboost callback api (#12745)
* Migrate to new-style xgboost callbacks
* Fix flaky progress reporter test
* Fix import error
* Take last value (not first)
This commit is contained in:
@@ -1,14 +1,29 @@
|
||||
from typing import Dict, List, Union
|
||||
from collections import OrderedDict
|
||||
from ray import tune
|
||||
|
||||
import os
|
||||
|
||||
from ray.tune.utils import flatten_dict
|
||||
from xgboost.core import Booster
|
||||
|
||||
class TuneCallback:
|
||||
try:
|
||||
from xgboost.callback import TrainingCallback
|
||||
except ImportError:
|
||||
|
||||
class TrainingCallback:
|
||||
pass
|
||||
|
||||
|
||||
class TuneCallback(TrainingCallback):
|
||||
"""Base class for Tune's XGBoost callbacks."""
|
||||
pass
|
||||
|
||||
def __call__(self, env):
|
||||
"""Compatibility with xgboost<1.3"""
|
||||
return self.after_iteration(env.model, env.iteration,
|
||||
env.evaluation_result_list)
|
||||
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -54,9 +69,15 @@ class TuneReportCallback(TuneCallback):
|
||||
metrics = [metrics]
|
||||
self._metrics = metrics
|
||||
|
||||
def _get_report_dict(self, env):
|
||||
# Only one worker should report to Tune
|
||||
result_dict = dict(env.evaluation_result_list)
|
||||
def _get_report_dict(self, evals_log):
|
||||
if isinstance(evals_log, OrderedDict):
|
||||
# xgboost>=1.3
|
||||
result_dict = flatten_dict(evals_log, delimiter="-")
|
||||
for k in list(result_dict):
|
||||
result_dict[k] = result_dict[k][-1]
|
||||
else:
|
||||
# xgboost<1.3
|
||||
result_dict = dict(evals_log)
|
||||
if not self._metrics:
|
||||
report_dict = result_dict
|
||||
else:
|
||||
@@ -69,8 +90,9 @@ class TuneReportCallback(TuneCallback):
|
||||
report_dict[key] = result_dict[metric]
|
||||
return report_dict
|
||||
|
||||
def __call__(self, env):
|
||||
report_dict = self._get_report_dict(env)
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
|
||||
report_dict = self._get_report_dict(evals_log)
|
||||
tune.report(**report_dict)
|
||||
|
||||
|
||||
@@ -96,14 +118,15 @@ class _TuneCheckpointCallback(TuneCallback):
|
||||
self._frequency = frequency
|
||||
|
||||
@staticmethod
|
||||
def _create_checkpoint(env, filename: str, frequency: int):
|
||||
if env.iteration % frequency > 0:
|
||||
def _create_checkpoint(model: Booster, epoch: int, filename: str,
|
||||
frequency: int):
|
||||
if epoch % frequency > 0:
|
||||
return
|
||||
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
|
||||
env.model.save_model(os.path.join(checkpoint_dir, filename))
|
||||
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
model.save_model(os.path.join(checkpoint_dir, filename))
|
||||
|
||||
def __call__(self, env):
|
||||
self._create_checkpoint(env, self._filename, self._frequency)
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
self._create_checkpoint(model, epoch, self._filename, self._frequency)
|
||||
|
||||
|
||||
class TuneReportCheckpointCallback(TuneCallback):
|
||||
@@ -158,6 +181,6 @@ class TuneReportCheckpointCallback(TuneCallback):
|
||||
self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
|
||||
self._report = self._report_callbacks_cls(metrics)
|
||||
|
||||
def __call__(self, env):
|
||||
self._checkpoint(env)
|
||||
self._report(env)
|
||||
def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
|
||||
self._checkpoint.after_iteration(model, epoch, evals_log)
|
||||
self._report.after_iteration(model, epoch, evals_log)
|
||||
|
||||
@@ -180,14 +180,18 @@ VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------+----------+
|
||||
VERBOSE_CMD = """from ray import tune
|
||||
import random
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def train(config):
|
||||
if config["do"] == "complete":
|
||||
time.sleep(0.1)
|
||||
tune.report(acc=5, done=True)
|
||||
elif config["do"] == "once":
|
||||
time.sleep(0.5)
|
||||
tune.report(6)
|
||||
else:
|
||||
time.sleep(1.0)
|
||||
tune.report(acc=7)
|
||||
tune.report(acc=8)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user