diff --git a/python/ray/tune/integration/xgboost.py b/python/ray/tune/integration/xgboost.py index 730883c25..fa40fa60a 100644 --- a/python/ray/tune/integration/xgboost.py +++ b/python/ray/tune/integration/xgboost.py @@ -54,7 +54,8 @@ class TuneReportCallback(TuneCallback): metrics = [metrics] self._metrics = metrics - def __call__(self, env): + def _get_report_dict(self, env): + # Only one worker should report to Tune result_dict = dict(env.evaluation_result_list) if not self._metrics: report_dict = result_dict @@ -66,6 +67,10 @@ class TuneReportCallback(TuneCallback): else: metric = key report_dict[key] = result_dict[metric] + return report_dict + + def __call__(self, env): + report_dict = self._get_report_dict(env) tune.report(**report_dict) @@ -81,15 +86,24 @@ class _TuneCheckpointCallback(TuneCallback): Args: filename (str): Filename of the checkpoint within the checkpoint directory. Defaults to "checkpoint". + frequency (int): How often to save checkpoints. By default, a + checkpoint is saved every five iterations. """ - def __init__(self, filename: str = "checkpoint"): + def __init__(self, filename: str = "checkpoint", frequency: int = 5): self._filename = filename + self._frequency = frequency + + @staticmethod + def _create_checkpoint(env, filename: str, frequency: int): + if env.iteration % frequency > 0: + return + with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir: + env.model.save_model(os.path.join(checkpoint_dir, filename)) def __call__(self, env): - with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir: - env.model.save_model(os.path.join(checkpoint_dir, self._filename)) + self._create_checkpoint(env, self._filename, self._frequency) class TuneReportCheckpointCallback(TuneCallback): @@ -108,6 +122,8 @@ class TuneReportCheckpointCallback(TuneCallback): directory. Defaults to "checkpoint". If this is None, all metrics will be reported to Tune under their default names as obtained from XGBoost. 
+ frequency (int): How often to save checkpoints. By default, a + checkpoint is saved every five iterations. Example: @@ -132,12 +148,15 @@ class TuneReportCheckpointCallback(TuneCallback): {"loss": "eval-logloss"}, "xgboost.mdl)]) """ + _checkpoint_callback_cls = _TuneCheckpointCallback + _report_callbacks_cls = TuneReportCallback def __init__(self, metrics: Union[None, str, List[str], Dict[str, str]] = None, - filename: str = "checkpoint"): - self._checkpoint = _TuneCheckpointCallback(filename) - self._report = TuneReportCallback(metrics) + filename: str = "checkpoint", + frequency: int = 5): + self._checkpoint = self._checkpoint_callback_cls(filename, frequency) + self._report = self._report_callbacks_cls(metrics) def __call__(self, env): self._checkpoint(env)