mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
[tune] Support yield and return statements (#10857)
* Support `yield` and `return` statements in Tune trainable functions
* Support an anonymous metric with `tune.report(value)`
* Raise on an invalid return/yield value
* Fix the end-to-end reporter test
This commit is contained in:
@@ -6,6 +6,8 @@ import shutil
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
from functools import partial
|
||||
from numbers import Number
|
||||
|
||||
from ray.tune.registry import parameter_registry
|
||||
from six.moves import queue
|
||||
@@ -130,7 +132,7 @@ class StatusReporter:
|
||||
self._last_checkpoint = None
|
||||
self._fresh_checkpoint = False
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
def __call__(self, _metric=None, **kwargs):
|
||||
"""Report updated training status.
|
||||
|
||||
Pass in `done=True` when the training job is completed.
|
||||
@@ -151,6 +153,9 @@ class StatusReporter:
|
||||
"StatusReporter._start() must be called before the first "
|
||||
"report __call__ is made to ensure correct runtime metrics.")
|
||||
|
||||
if _metric:
|
||||
kwargs["_metric"] = _metric
|
||||
|
||||
# time per iteration is recorded directly in the reporter to ensure
|
||||
# any delays in logging results aren't counted
|
||||
report_time = time.time()
|
||||
@@ -280,7 +285,7 @@ class FunctionRunner(Trainable):
|
||||
self._restore_tmpdir = None
|
||||
self.temp_checkpoint_dir = None
|
||||
|
||||
def _trainable_func(self):
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
"""Subclasses can override this to set the trainable func."""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -495,11 +500,32 @@ def wrap_function(train_func, warn=True):
|
||||
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
if not use_checkpoint and not use_reporter:
|
||||
output = train_func(config)
|
||||
fn = partial(train_func, config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint_dir=checkpoint_dir)
|
||||
fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
fn = partial(train_func, config, reporter)
|
||||
|
||||
def handle_output(output):
|
||||
if not output:
|
||||
return
|
||||
elif isinstance(output, dict):
|
||||
reporter(**output)
|
||||
elif isinstance(output, Number):
|
||||
reporter(_metric=output)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid return or yield value. Either return/yield "
|
||||
"a single number or a dictionary object in your "
|
||||
"trainable function.")
|
||||
|
||||
output = None
|
||||
if inspect.isgeneratorfunction(train_func):
|
||||
for output in fn():
|
||||
handle_output(output)
|
||||
else:
|
||||
output = fn()
|
||||
handle_output(output)
|
||||
|
||||
# If train_func returns, we need to notify the main event loop
|
||||
# of the last result while avoiding double logging. This is done
|
||||
|
||||
Reference in New Issue
Block a user