[tune] Add xgboost_ray integration (#12572)

This commit is contained in:
Kai Fricke
2020-12-04 22:59:20 +01:00
committed by GitHub
parent 219c445648
commit 1c0d10f67e
+26 -7
View File
@@ -54,7 +54,8 @@ class TuneReportCallback(TuneCallback):
metrics = [metrics]
self._metrics = metrics
def __call__(self, env):
def _get_report_dict(self, env):
# Only one worker should report to Tune
result_dict = dict(env.evaluation_result_list)
if not self._metrics:
report_dict = result_dict
@@ -66,6 +67,10 @@ class TuneReportCallback(TuneCallback):
else:
metric = key
report_dict[key] = result_dict[metric]
return report_dict
def __call__(self, env):
    """Forward the dict built by ``_get_report_dict`` to ``tune.report``.

    Each entry of the dict is passed to ``tune.report`` as a keyword
    argument.
    """
    tune.report(**self._get_report_dict(env))
@@ -81,15 +86,24 @@ class _TuneCheckpointCallback(TuneCallback):
Args:
filename (str): Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
frequency (int): How often to save checkpoints. By default, a
checkpoint is saved every five iterations.
"""
def __init__(self, filename: str = "checkpoint"):
def __init__(self, filename: str = "checkpoint", frequency: int = 5):
self._filename = filename
self._frequency = frequency
@staticmethod
def _create_checkpoint(env, filename: str, frequency: int):
    """Save the current model as a Tune checkpoint on scheduled iterations.

    A checkpoint is written only when ``env.iteration`` is a multiple of
    ``frequency`` (iteration 0 included); all other iterations are skipped.

    Args:
        env: Object exposing ``iteration`` (used as the checkpoint step)
            and ``model`` (whose ``save_model`` is called).
        filename (str): Name of the model file inside the checkpoint
            directory.
        frequency (int): Checkpoint interval in iterations. Zero or
            negative values disable checkpointing.
    """
    # A zero frequency would raise ZeroDivisionError in the modulo below,
    # and a negative one makes the remainder non-positive, so the skip test
    # would never fire and every iteration would be checkpointed.
    if frequency <= 0:
        return
    if env.iteration % frequency > 0:
        return
    with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
        env.model.save_model(os.path.join(checkpoint_dir, filename))
def __call__(self, env):
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
env.model.save_model(os.path.join(checkpoint_dir, self._filename))
self._create_checkpoint(env, self._filename, self._frequency)
class TuneReportCheckpointCallback(TuneCallback):
@@ -108,6 +122,8 @@ class TuneReportCheckpointCallback(TuneCallback):
directory. Defaults to "checkpoint". If this is None,
all metrics will be reported to Tune under their default names as
obtained from XGBoost.
frequency (int): How often to save checkpoints. By default, a
checkpoint is saved every five iterations.
Example:
@@ -132,12 +148,15 @@ class TuneReportCheckpointCallback(TuneCallback):
{"loss": "eval-logloss"}, "xgboost.mdl")])
"""
_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callbacks_cls = TuneReportCallback
def __init__(self,
metrics: Union[None, str, List[str], Dict[str, str]] = None,
filename: str = "checkpoint"):
self._checkpoint = _TuneCheckpointCallback(filename)
self._report = TuneReportCallback(metrics)
filename: str = "checkpoint",
frequency: int = 5):
self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
self._report = self._report_callbacks_cls(metrics)
def __call__(self, env):
self._checkpoint(env)