[tune] Support yield and return statements (#10857)

* Support `yield` and `return` statements in Tune trainable functions

* Support anonymous metric with ``tune.report(value)``

* Raise on invalid return/yield value

* Fix end-to-end reporter test
This commit is contained in:
Kai Fricke
2020-09-18 04:18:35 +01:00
committed by GitHub
parent 5cbc411e38
commit 508cfa3540
5 changed files with 172 additions and 45 deletions
+31 -5
View File
@@ -6,6 +6,8 @@ import shutil
import threading
import traceback
import uuid
from functools import partial
from numbers import Number
from ray.tune.registry import parameter_registry
from six.moves import queue
@@ -130,7 +132,7 @@ class StatusReporter:
self._last_checkpoint = None
self._fresh_checkpoint = False
def __call__(self, **kwargs):
def __call__(self, _metric=None, **kwargs):
"""Report updated training status.
Pass in `done=True` when the training job is completed.
@@ -151,6 +153,9 @@ class StatusReporter:
"StatusReporter._start() must be called before the first "
"report __call__ is made to ensure correct runtime metrics.")
if _metric:
kwargs["_metric"] = _metric
# time per iteration is recorded directly in the reporter to ensure
# any delays in logging results aren't counted
report_time = time.time()
@@ -280,7 +285,7 @@ class FunctionRunner(Trainable):
self._restore_tmpdir = None
self.temp_checkpoint_dir = None
def _trainable_func(self):
def _trainable_func(self, config, reporter, checkpoint_dir):
"""Subclasses can override this to set the trainable func."""
raise NotImplementedError
@@ -495,11 +500,32 @@ def wrap_function(train_func, warn=True):
def _trainable_func(self, config, reporter, checkpoint_dir):
if not use_checkpoint and not use_reporter:
output = train_func(config)
fn = partial(train_func, config)
elif use_checkpoint:
output = train_func(config, checkpoint_dir=checkpoint_dir)
fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
else:
output = train_func(config, reporter)
fn = partial(train_func, config, reporter)
def handle_output(output):
    """Forward a value returned or yielded by the trainable to ``reporter``.

    Args:
        output: The value produced by the trainable function. A number is
            reported as the anonymous ``_metric``; a non-empty dict is
            reported as keyword arguments; ``None`` or any other empty
            value is ignored.

    Raises:
        ValueError: If ``output`` is a non-empty value that is neither a
            number nor a dict.
    """
    # Numbers are checked before the truthiness test: a metric of ``0`` or
    # ``0.0`` is a legitimate result and must not be treated as "no output".
    # NOTE(review): StatusReporter.__call__ guards with ``if _metric:``,
    # which is also a truthiness test -- it needs ``is not None`` as well
    # for a zero value to end up under the ``_metric`` key.
    if isinstance(output, Number):
        reporter(_metric=output)
    elif not output:
        # Plain ``return`` (None), an empty dict, or another empty value.
        return
    elif isinstance(output, dict):
        reporter(**output)
    else:
        raise ValueError(
            "Invalid return or yield value. Either return/yield "
            "a single number or a dictionary object in your "
            "trainable function.")
output = None
if inspect.isgeneratorfunction(train_func):
for output in fn():
handle_output(output)
else:
output = fn()
handle_output(output)
# If train_func returns, we need to notify the main event loop
# of the last result while avoiding double logging. This is done
+3 -2
View File
@@ -55,7 +55,7 @@ def shutdown():
_session = None
def report(**kwargs):
def report(_metric=None, **kwargs):
"""Logs all keyword arguments.
.. code-block:: python
@@ -71,12 +71,13 @@ def report(**kwargs):
analysis = tune.run(run_me)
Args:
_metric: Optional default anonymous metric for ``tune.report(value)``
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
"""
_session = get_session()
if _session:
return _session(**kwargs)
return _session(_metric, **kwargs)
def make_checkpoint_dir(step=None):
@@ -467,3 +467,53 @@ class FunctionApiTest(unittest.TestCase):
self.assertEquals(trial_1.last_result["cp"], "DIR")
self.assertEquals(trial_2.last_result["metric"], 500_000)
self.assertEquals(trial_2.last_result["cp"], "DIR")
def test_return_anonymous(self):
    """A bare number returned by the trainable shows up as ``_metric``."""

    def train(config):
        return config["a"]

    trial_1, trial_2 = tune.run(
        train, config={
            "a": tune.grid_search([4, 8])
        }).trials

    # ``assertEquals`` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(trial_1.last_result["_metric"], 4)
    self.assertEqual(trial_2.last_result["_metric"], 8)
def test_return_specific(self):
    """A dict returned by the trainable is reported under its own keys."""

    def train(config):
        return {"m": config["a"]}

    trial_1, trial_2 = tune.run(
        train, config={
            "a": tune.grid_search([4, 8])
        }).trials

    # ``assertEquals`` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(trial_1.last_result["m"], 4)
    self.assertEqual(trial_2.last_result["m"], 8)
def test_yield_anonymous(self):
    """The last number yielded by the trainable ends up as ``_metric``."""

    def train(config):
        for i in range(10):
            yield config["a"] + i

    trial_1, trial_2 = tune.run(
        train, config={
            "a": tune.grid_search([4, 8])
        }).trials

    # The final yield happens at i == 9, hence the ``+ 9``.
    # ``assertEquals`` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(trial_1.last_result["_metric"], 4 + 9)
    self.assertEqual(trial_2.last_result["_metric"], 8 + 9)
def test_yield_specific(self):
    """The last dict yielded by the trainable is reported under its keys."""

    def train(config):
        for i in range(10):
            yield {"m": config["a"] + i}

    trial_1, trial_2 = tune.run(
        train, config={
            "a": tune.grid_search([4, 8])
        }).trials

    # The final yield happens at i == 9, hence the ``+ 9``.
    # ``assertEquals`` is a deprecated alias (removed in Python 3.12).
    self.assertEqual(trial_1.last_result["m"], 4 + 9)
    self.assertEqual(trial_2.last_result["m"], 8 + 9)
+37 -35
View File
@@ -48,6 +48,8 @@ END_TO_END_COMMAND = """
import ray
from ray import tune
reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"])
def f(config):
return {"done": True}
@@ -71,7 +73,7 @@ tune.run_experiments({
"c": tune.grid_search(list(range(10))),
},
},
}, verbose=1)"""
}, verbose=1, progress_reporter=reporter)"""
EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING)
+---------------+----------+-------+-----+
@@ -81,40 +83,40 @@ EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING)
+---------------+----------+-------+-----+"""
EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED)
+---------------+------------+-------+-----+-----+-----+
| Trial name | status | loc | a | b | c |
|---------------+------------+-------+-----+-----+-----|
| f_xxxxx_00000 | TERMINATED | | 0 | | |
| f_xxxxx_00001 | TERMINATED | | 1 | | |
| f_xxxxx_00002 | TERMINATED | | 2 | | |
| f_xxxxx_00003 | TERMINATED | | 3 | | |
| f_xxxxx_00004 | TERMINATED | | 4 | | |
| f_xxxxx_00005 | TERMINATED | | 5 | | |
| f_xxxxx_00006 | TERMINATED | | 6 | | |
| f_xxxxx_00007 | TERMINATED | | 7 | | |
| f_xxxxx_00008 | TERMINATED | | 8 | | |
| f_xxxxx_00009 | TERMINATED | | 9 | | |
| f_xxxxx_00010 | TERMINATED | | | 0 | |
| f_xxxxx_00011 | TERMINATED | | | 1 | |
| f_xxxxx_00012 | TERMINATED | | | 2 | |
| f_xxxxx_00013 | TERMINATED | | | 3 | |
| f_xxxxx_00014 | TERMINATED | | | 4 | |
| f_xxxxx_00015 | TERMINATED | | | 5 | |
| f_xxxxx_00016 | TERMINATED | | | 6 | |
| f_xxxxx_00017 | TERMINATED | | | 7 | |
| f_xxxxx_00018 | TERMINATED | | | 8 | |
| f_xxxxx_00019 | TERMINATED | | | 9 | |
| f_xxxxx_00020 | TERMINATED | | | | 0 |
| f_xxxxx_00021 | TERMINATED | | | | 1 |
| f_xxxxx_00022 | TERMINATED | | | | 2 |
| f_xxxxx_00023 | TERMINATED | | | | 3 |
| f_xxxxx_00024 | TERMINATED | | | | 4 |
| f_xxxxx_00025 | TERMINATED | | | | 5 |
| f_xxxxx_00026 | TERMINATED | | | | 6 |
| f_xxxxx_00027 | TERMINATED | | | | 7 |
| f_xxxxx_00028 | TERMINATED | | | | 8 |
| f_xxxxx_00029 | TERMINATED | | | | 9 |
+---------------+------------+-------+-----+-----+-----+"""
+---------------+------------+-------+-----+-----+-----+--------+
| Trial name | status | loc | a | b | c | done |
|---------------+------------+-------+-----+-----+-----+--------|
| f_xxxxx_00000 | TERMINATED | | 0 | | | True |
| f_xxxxx_00001 | TERMINATED | | 1 | | | True |
| f_xxxxx_00002 | TERMINATED | | 2 | | | True |
| f_xxxxx_00003 | TERMINATED | | 3 | | | True |
| f_xxxxx_00004 | TERMINATED | | 4 | | | True |
| f_xxxxx_00005 | TERMINATED | | 5 | | | True |
| f_xxxxx_00006 | TERMINATED | | 6 | | | True |
| f_xxxxx_00007 | TERMINATED | | 7 | | | True |
| f_xxxxx_00008 | TERMINATED | | 8 | | | True |
| f_xxxxx_00009 | TERMINATED | | 9 | | | True |
| f_xxxxx_00010 | TERMINATED | | | 0 | | True |
| f_xxxxx_00011 | TERMINATED | | | 1 | | True |
| f_xxxxx_00012 | TERMINATED | | | 2 | | True |
| f_xxxxx_00013 | TERMINATED | | | 3 | | True |
| f_xxxxx_00014 | TERMINATED | | | 4 | | True |
| f_xxxxx_00015 | TERMINATED | | | 5 | | True |
| f_xxxxx_00016 | TERMINATED | | | 6 | | True |
| f_xxxxx_00017 | TERMINATED | | | 7 | | True |
| f_xxxxx_00018 | TERMINATED | | | 8 | | True |
| f_xxxxx_00019 | TERMINATED | | | 9 | | True |
| f_xxxxx_00020 | TERMINATED | | | | 0 | True |
| f_xxxxx_00021 | TERMINATED | | | | 1 | True |
| f_xxxxx_00022 | TERMINATED | | | | 2 | True |
| f_xxxxx_00023 | TERMINATED | | | | 3 | True |
| f_xxxxx_00024 | TERMINATED | | | | 4 | True |
| f_xxxxx_00025 | TERMINATED | | | | 5 | True |
| f_xxxxx_00026 | TERMINATED | | | | 6 | True |
| f_xxxxx_00027 | TERMINATED | | | | 7 | True |
| f_xxxxx_00028 | TERMINATED | | | | 8 | True |
| f_xxxxx_00029 | TERMINATED | | | | 9 | True |
+---------------+------------+-------+-----+-----+-----+--------+"""
EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED)
+---------------+------------+-------+-----+-----+-----+