[tune] Support yield and return statements (#10857)

* Support `yield` and `return` statements in Tune trainable functions

* Support anonymous metric with ``tune.report(value)``

* Raise on invalid return/yield value

* Fix end to end reporter test
This commit is contained in:
Kai Fricke
2020-09-18 04:18:35 +01:00
committed by GitHub
parent 5cbc411e38
commit 508cfa3540
5 changed files with 172 additions and 45 deletions
+31 -5
View File
@@ -6,6 +6,8 @@ import shutil
import threading
import traceback
import uuid
from functools import partial
from numbers import Number
from ray.tune.registry import parameter_registry
from six.moves import queue
@@ -130,7 +132,7 @@ class StatusReporter:
self._last_checkpoint = None
self._fresh_checkpoint = False
def __call__(self, **kwargs):
def __call__(self, _metric=None, **kwargs):
"""Report updated training status.
Pass in `done=True` when the training job is completed.
@@ -151,6 +153,9 @@ class StatusReporter:
"StatusReporter._start() must be called before the first "
"report __call__ is made to ensure correct runtime metrics.")
if _metric:
kwargs["_metric"] = _metric
# time per iteration is recorded directly in the reporter to ensure
# any delays in logging results aren't counted
report_time = time.time()
@@ -280,7 +285,7 @@ class FunctionRunner(Trainable):
self._restore_tmpdir = None
self.temp_checkpoint_dir = None
def _trainable_func(self):
def _trainable_func(self, config, reporter, checkpoint_dir):
"""Subclasses can override this to set the trainable func."""
raise NotImplementedError
@@ -495,11 +500,32 @@ def wrap_function(train_func, warn=True):
def _trainable_func(self, config, reporter, checkpoint_dir):
if not use_checkpoint and not use_reporter:
output = train_func(config)
fn = partial(train_func, config)
elif use_checkpoint:
output = train_func(config, checkpoint_dir=checkpoint_dir)
fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
else:
output = train_func(config, reporter)
fn = partial(train_func, config, reporter)
def handle_output(output):
if not output:
return
elif isinstance(output, dict):
reporter(**output)
elif isinstance(output, Number):
reporter(_metric=output)
else:
raise ValueError(
"Invalid return or yield value. Either return/yield "
"a single number or a dictionary object in your "
"trainable function.")
output = None
if inspect.isgeneratorfunction(train_func):
for output in fn():
handle_output(output)
else:
output = fn()
handle_output(output)
# If train_func returns, we need to notify the main event loop
# of the last result while avoiding double logging. This is done