mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 15:17:56 +08:00
[tune] Add xgboost_ray integration (#12572)
This commit is contained in:
@@ -54,7 +54,8 @@ class TuneReportCallback(TuneCallback):
|
||||
metrics = [metrics]
|
||||
self._metrics = metrics
|
||||
|
||||
def __call__(self, env):
|
||||
def _get_report_dict(self, env):
|
||||
# Only one worker should report to Tune
|
||||
result_dict = dict(env.evaluation_result_list)
|
||||
if not self._metrics:
|
||||
report_dict = result_dict
|
||||
@@ -66,6 +67,10 @@ class TuneReportCallback(TuneCallback):
|
||||
else:
|
||||
metric = key
|
||||
report_dict[key] = result_dict[metric]
|
||||
return report_dict
|
||||
|
||||
def __call__(self, env):
|
||||
report_dict = self._get_report_dict(env)
|
||||
tune.report(**report_dict)
|
||||
|
||||
|
||||
@@ -81,15 +86,24 @@ class _TuneCheckpointCallback(TuneCallback):
|
||||
Args:
|
||||
filename (str): Filename of the checkpoint within the checkpoint
|
||||
directory. Defaults to "checkpoint".
|
||||
frequency (int): How often to save checkpoints. Per default, a
|
||||
checkpoint is saved every five iterations.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str = "checkpoint"):
|
||||
def __init__(self, filename: str = "checkpoint", frequency: int = 5):
|
||||
self._filename = filename
|
||||
self._frequency = frequency
|
||||
|
||||
@staticmethod
|
||||
def _create_checkpoint(env, filename: str, frequency: int):
|
||||
if env.iteration % frequency > 0:
|
||||
return
|
||||
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
|
||||
env.model.save_model(os.path.join(checkpoint_dir, filename))
|
||||
|
||||
def __call__(self, env):
|
||||
with tune.checkpoint_dir(step=env.iteration) as checkpoint_dir:
|
||||
env.model.save_model(os.path.join(checkpoint_dir, self._filename))
|
||||
self._create_checkpoint(env, self._filename, self._frequency)
|
||||
|
||||
|
||||
class TuneReportCheckpointCallback(TuneCallback):
|
||||
@@ -108,6 +122,8 @@ class TuneReportCheckpointCallback(TuneCallback):
|
||||
directory. Defaults to "checkpoint". If this is None,
|
||||
all metrics will be reported to Tune under their default names as
|
||||
obtained from XGBoost.
|
||||
frequency (int): How often to save checkpoints. Per default, a
|
||||
checkpoint is saved every five iterations.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -132,12 +148,15 @@ class TuneReportCheckpointCallback(TuneCallback):
|
||||
{"loss": "eval-logloss"}, "xgboost.mdl)])
|
||||
|
||||
"""
|
||||
_checkpoint_callback_cls = _TuneCheckpointCallback
|
||||
_report_callbacks_cls = TuneReportCallback
|
||||
|
||||
def __init__(self,
|
||||
metrics: Union[None, str, List[str], Dict[str, str]] = None,
|
||||
filename: str = "checkpoint"):
|
||||
self._checkpoint = _TuneCheckpointCallback(filename)
|
||||
self._report = TuneReportCallback(metrics)
|
||||
filename: str = "checkpoint",
|
||||
frequency: int = 5):
|
||||
self._checkpoint = self._checkpoint_callback_cls(filename, frequency)
|
||||
self._report = self._report_callbacks_cls(metrics)
|
||||
|
||||
def __call__(self, env):
|
||||
self._checkpoint(env)
|
||||
|
||||
Reference in New Issue
Block a user