mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:53:32 +08:00
[tune] Support yield and return statements (#10857)
* Support `yield` and `return` statements in Tune trainable functions
* Support anonymous metric with ``tune.report(value)``
* Raise on invalid return/yield value
* Fix end-to-end reporter test
This commit is contained in:
@@ -6,6 +6,8 @@ import shutil
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
from functools import partial
|
||||
from numbers import Number
|
||||
|
||||
from ray.tune.registry import parameter_registry
|
||||
from six.moves import queue
|
||||
@@ -130,7 +132,7 @@ class StatusReporter:
|
||||
self._last_checkpoint = None
|
||||
self._fresh_checkpoint = False
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
def __call__(self, _metric=None, **kwargs):
|
||||
"""Report updated training status.
|
||||
|
||||
Pass in `done=True` when the training job is completed.
|
||||
@@ -151,6 +153,9 @@ class StatusReporter:
|
||||
"StatusReporter._start() must be called before the first "
|
||||
"report __call__ is made to ensure correct runtime metrics.")
|
||||
|
||||
if _metric:
|
||||
kwargs["_metric"] = _metric
|
||||
|
||||
# time per iteration is recorded directly in the reporter to ensure
|
||||
# any delays in logging results aren't counted
|
||||
report_time = time.time()
|
||||
@@ -280,7 +285,7 @@ class FunctionRunner(Trainable):
|
||||
self._restore_tmpdir = None
|
||||
self.temp_checkpoint_dir = None
|
||||
|
||||
def _trainable_func(self):
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
"""Subclasses can override this to set the trainable func."""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -495,11 +500,32 @@ def wrap_function(train_func, warn=True):
|
||||
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
if not use_checkpoint and not use_reporter:
|
||||
output = train_func(config)
|
||||
fn = partial(train_func, config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint_dir=checkpoint_dir)
|
||||
fn = partial(train_func, config, checkpoint_dir=checkpoint_dir)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
fn = partial(train_func, config, reporter)
|
||||
|
||||
def handle_output(output):
|
||||
if not output:
|
||||
return
|
||||
elif isinstance(output, dict):
|
||||
reporter(**output)
|
||||
elif isinstance(output, Number):
|
||||
reporter(_metric=output)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid return or yield value. Either return/yield "
|
||||
"a single number or a dictionary object in your "
|
||||
"trainable function.")
|
||||
|
||||
output = None
|
||||
if inspect.isgeneratorfunction(train_func):
|
||||
for output in fn():
|
||||
handle_output(output)
|
||||
else:
|
||||
output = fn()
|
||||
handle_output(output)
|
||||
|
||||
# If train_func returns, we need to notify the main event loop
|
||||
# of the last result while avoiding double logging. This is done
|
||||
|
||||
@@ -55,7 +55,7 @@ def shutdown():
|
||||
_session = None
|
||||
|
||||
|
||||
def report(**kwargs):
|
||||
def report(_metric=None, **kwargs):
|
||||
"""Logs all keyword arguments.
|
||||
|
||||
.. code-block:: python
|
||||
@@ -71,12 +71,13 @@ def report(**kwargs):
|
||||
analysis = tune.run(run_me)
|
||||
|
||||
Args:
|
||||
_metric: Optional default anonymous metric for ``tune.report(value)``
|
||||
**kwargs: Any key value pair to be logged by Tune. Any of these
|
||||
metrics can be used for early stopping or optimization.
|
||||
"""
|
||||
_session = get_session()
|
||||
if _session:
|
||||
return _session(**kwargs)
|
||||
return _session(_metric, **kwargs)
|
||||
|
||||
|
||||
def make_checkpoint_dir(step=None):
|
||||
|
||||
@@ -467,3 +467,53 @@ class FunctionApiTest(unittest.TestCase):
|
||||
self.assertEquals(trial_1.last_result["cp"], "DIR")
|
||||
self.assertEquals(trial_2.last_result["metric"], 500_000)
|
||||
self.assertEquals(trial_2.last_result["cp"], "DIR")
|
||||
|
||||
def test_return_anonymous(self):
|
||||
def train(config):
|
||||
return config["a"]
|
||||
|
||||
trial_1, trial_2 = tune.run(
|
||||
train, config={
|
||||
"a": tune.grid_search([4, 8])
|
||||
}).trials
|
||||
|
||||
self.assertEquals(trial_1.last_result["_metric"], 4)
|
||||
self.assertEquals(trial_2.last_result["_metric"], 8)
|
||||
|
||||
def test_return_specific(self):
|
||||
def train(config):
|
||||
return {"m": config["a"]}
|
||||
|
||||
trial_1, trial_2 = tune.run(
|
||||
train, config={
|
||||
"a": tune.grid_search([4, 8])
|
||||
}).trials
|
||||
|
||||
self.assertEquals(trial_1.last_result["m"], 4)
|
||||
self.assertEquals(trial_2.last_result["m"], 8)
|
||||
|
||||
def test_yield_anonymous(self):
|
||||
def train(config):
|
||||
for i in range(10):
|
||||
yield config["a"] + i
|
||||
|
||||
trial_1, trial_2 = tune.run(
|
||||
train, config={
|
||||
"a": tune.grid_search([4, 8])
|
||||
}).trials
|
||||
|
||||
self.assertEquals(trial_1.last_result["_metric"], 4 + 9)
|
||||
self.assertEquals(trial_2.last_result["_metric"], 8 + 9)
|
||||
|
||||
def test_yield_specific(self):
|
||||
def train(config):
|
||||
for i in range(10):
|
||||
yield {"m": config["a"] + i}
|
||||
|
||||
trial_1, trial_2 = tune.run(
|
||||
train, config={
|
||||
"a": tune.grid_search([4, 8])
|
||||
}).trials
|
||||
|
||||
self.assertEquals(trial_1.last_result["m"], 4 + 9)
|
||||
self.assertEquals(trial_2.last_result["m"], 8 + 9)
|
||||
|
||||
@@ -48,6 +48,8 @@ END_TO_END_COMMAND = """
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"])
|
||||
|
||||
def f(config):
|
||||
return {"done": True}
|
||||
|
||||
@@ -71,7 +73,7 @@ tune.run_experiments({
|
||||
"c": tune.grid_search(list(range(10))),
|
||||
},
|
||||
},
|
||||
}, verbose=1)"""
|
||||
}, verbose=1, progress_reporter=reporter)"""
|
||||
|
||||
EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING)
|
||||
+---------------+----------+-------+-----+
|
||||
@@ -81,40 +83,40 @@ EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING)
|
||||
+---------------+----------+-------+-----+"""
|
||||
|
||||
EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED)
|
||||
+---------------+------------+-------+-----+-----+-----+
|
||||
| Trial name | status | loc | a | b | c |
|
||||
|---------------+------------+-------+-----+-----+-----|
|
||||
| f_xxxxx_00000 | TERMINATED | | 0 | | |
|
||||
| f_xxxxx_00001 | TERMINATED | | 1 | | |
|
||||
| f_xxxxx_00002 | TERMINATED | | 2 | | |
|
||||
| f_xxxxx_00003 | TERMINATED | | 3 | | |
|
||||
| f_xxxxx_00004 | TERMINATED | | 4 | | |
|
||||
| f_xxxxx_00005 | TERMINATED | | 5 | | |
|
||||
| f_xxxxx_00006 | TERMINATED | | 6 | | |
|
||||
| f_xxxxx_00007 | TERMINATED | | 7 | | |
|
||||
| f_xxxxx_00008 | TERMINATED | | 8 | | |
|
||||
| f_xxxxx_00009 | TERMINATED | | 9 | | |
|
||||
| f_xxxxx_00010 | TERMINATED | | | 0 | |
|
||||
| f_xxxxx_00011 | TERMINATED | | | 1 | |
|
||||
| f_xxxxx_00012 | TERMINATED | | | 2 | |
|
||||
| f_xxxxx_00013 | TERMINATED | | | 3 | |
|
||||
| f_xxxxx_00014 | TERMINATED | | | 4 | |
|
||||
| f_xxxxx_00015 | TERMINATED | | | 5 | |
|
||||
| f_xxxxx_00016 | TERMINATED | | | 6 | |
|
||||
| f_xxxxx_00017 | TERMINATED | | | 7 | |
|
||||
| f_xxxxx_00018 | TERMINATED | | | 8 | |
|
||||
| f_xxxxx_00019 | TERMINATED | | | 9 | |
|
||||
| f_xxxxx_00020 | TERMINATED | | | | 0 |
|
||||
| f_xxxxx_00021 | TERMINATED | | | | 1 |
|
||||
| f_xxxxx_00022 | TERMINATED | | | | 2 |
|
||||
| f_xxxxx_00023 | TERMINATED | | | | 3 |
|
||||
| f_xxxxx_00024 | TERMINATED | | | | 4 |
|
||||
| f_xxxxx_00025 | TERMINATED | | | | 5 |
|
||||
| f_xxxxx_00026 | TERMINATED | | | | 6 |
|
||||
| f_xxxxx_00027 | TERMINATED | | | | 7 |
|
||||
| f_xxxxx_00028 | TERMINATED | | | | 8 |
|
||||
| f_xxxxx_00029 | TERMINATED | | | | 9 |
|
||||
+---------------+------------+-------+-----+-----+-----+"""
|
||||
+---------------+------------+-------+-----+-----+-----+--------+
|
||||
| Trial name | status | loc | a | b | c | done |
|
||||
|---------------+------------+-------+-----+-----+-----+--------|
|
||||
| f_xxxxx_00000 | TERMINATED | | 0 | | | True |
|
||||
| f_xxxxx_00001 | TERMINATED | | 1 | | | True |
|
||||
| f_xxxxx_00002 | TERMINATED | | 2 | | | True |
|
||||
| f_xxxxx_00003 | TERMINATED | | 3 | | | True |
|
||||
| f_xxxxx_00004 | TERMINATED | | 4 | | | True |
|
||||
| f_xxxxx_00005 | TERMINATED | | 5 | | | True |
|
||||
| f_xxxxx_00006 | TERMINATED | | 6 | | | True |
|
||||
| f_xxxxx_00007 | TERMINATED | | 7 | | | True |
|
||||
| f_xxxxx_00008 | TERMINATED | | 8 | | | True |
|
||||
| f_xxxxx_00009 | TERMINATED | | 9 | | | True |
|
||||
| f_xxxxx_00010 | TERMINATED | | | 0 | | True |
|
||||
| f_xxxxx_00011 | TERMINATED | | | 1 | | True |
|
||||
| f_xxxxx_00012 | TERMINATED | | | 2 | | True |
|
||||
| f_xxxxx_00013 | TERMINATED | | | 3 | | True |
|
||||
| f_xxxxx_00014 | TERMINATED | | | 4 | | True |
|
||||
| f_xxxxx_00015 | TERMINATED | | | 5 | | True |
|
||||
| f_xxxxx_00016 | TERMINATED | | | 6 | | True |
|
||||
| f_xxxxx_00017 | TERMINATED | | | 7 | | True |
|
||||
| f_xxxxx_00018 | TERMINATED | | | 8 | | True |
|
||||
| f_xxxxx_00019 | TERMINATED | | | 9 | | True |
|
||||
| f_xxxxx_00020 | TERMINATED | | | | 0 | True |
|
||||
| f_xxxxx_00021 | TERMINATED | | | | 1 | True |
|
||||
| f_xxxxx_00022 | TERMINATED | | | | 2 | True |
|
||||
| f_xxxxx_00023 | TERMINATED | | | | 3 | True |
|
||||
| f_xxxxx_00024 | TERMINATED | | | | 4 | True |
|
||||
| f_xxxxx_00025 | TERMINATED | | | | 5 | True |
|
||||
| f_xxxxx_00026 | TERMINATED | | | | 6 | True |
|
||||
| f_xxxxx_00027 | TERMINATED | | | | 7 | True |
|
||||
| f_xxxxx_00028 | TERMINATED | | | | 8 | True |
|
||||
| f_xxxxx_00029 | TERMINATED | | | | 9 | True |
|
||||
+---------------+------------+-------+-----+-----+-----+--------+"""
|
||||
|
||||
EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED)
|
||||
+---------------+------------+-------+-----+-----+-----+
|
||||
|
||||
Reference in New Issue
Block a user