diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 6d6308e40..55224aacb 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -17,7 +17,6 @@ For the sake of example, let's maximize this objective function: Function API ------------ - Here is a simple example of using the function API. You can report intermediate metrics by simply calling ``tune.report`` within the provided function. .. code-block:: python @@ -28,7 +27,7 @@ Here is a simple example of using the function API. You can report intermediate for x in range(20): intermediate_score = objective(x, config["a"], config["b"]) - tune.report(value=intermediate_score) # This sends the score to Tune. + tune.report(score=intermediate_score) # This sends the score to Tune. analysis = tune.run( trainable, @@ -41,7 +40,56 @@ Here is a simple example of using the function API. You can report intermediate Tune will run this function on a separate thread in a Ray actor process. -.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using parallel hyperparameter tuning, check out our :ref:PyTorch user guide and Tune's :ref:distributed pytorch integrations. +.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using parallel hyperparameter tuning, check out our :ref:`PyTorch ` user guide and Tune's :ref:`distributed pytorch integrations `. + +Function API return and yield values +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Instead of using ``tune.report()``, you can also use Python's ``yield`` +statement to report metrics to Ray Tune: + + +.. code-block:: python + + def trainable(config): + # config (dict): A dict of hyperparameters. + + for x in range(20): + intermediate_score = objective(x, config["a"], config["b"]) + + yield {"score": intermediate_score} # This sends the score to Tune. 
+ + analysis = tune.run( + trainable, + config={"a": 2, "b": 4} + ) + + print("best config: ", analysis.get_best_config(metric="score", mode="max")) + +If you yield a dictionary object, this will work just as ``tune.report()``. +If you yield a number, it will be reported to Ray Tune with the key ``_metric``, i.e. +as if you had called ``tune.report(_metric=value)``. + +Ray Tune supports the same functionality for return values if you only +report metrics at the end of each run: + +.. code-block:: python + + def trainable(config): + # config (dict): A dict of hyperparameters. + + final_score = 0 + for x in range(20): + final_score = objective(x, config["a"], config["b"]) + + return {"score": final_score} # This sends the score to Tune. + + analysis = tune.run( + trainable, + config={"a": 2, "b": 4} + ) + + print("best config: ", analysis.get_best_config(metric="score", mode="max")) + + .. _tune-function-checkpointing: diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 84bc82772..9f1550916 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -6,6 +6,8 @@ import shutil import threading import traceback import uuid +from functools import partial +from numbers import Number from ray.tune.registry import parameter_registry from six.moves import queue @@ -130,7 +132,7 @@ class StatusReporter: self._last_checkpoint = None self._fresh_checkpoint = False - def __call__(self, **kwargs): + def __call__(self, _metric=None, **kwargs): """Report updated training status. Pass in `done=True` when the training job is completed. 
@@ -151,6 +153,9 @@ class StatusReporter: "StatusReporter._start() must be called before the first " "report __call__ is made to ensure correct runtime metrics.") + if _metric: + kwargs["_metric"] = _metric + # time per iteration is recorded directly in the reporter to ensure # any delays in logging results aren't counted report_time = time.time() @@ -280,7 +285,7 @@ class FunctionRunner(Trainable): self._restore_tmpdir = None self.temp_checkpoint_dir = None - def _trainable_func(self): + def _trainable_func(self, config, reporter, checkpoint_dir): """Subclasses can override this to set the trainable func.""" raise NotImplementedError @@ -495,11 +500,32 @@ def wrap_function(train_func, warn=True): def _trainable_func(self, config, reporter, checkpoint_dir): if not use_checkpoint and not use_reporter: - output = train_func(config) + fn = partial(train_func, config) elif use_checkpoint: - output = train_func(config, checkpoint_dir=checkpoint_dir) + fn = partial(train_func, config, checkpoint_dir=checkpoint_dir) else: - output = train_func(config, reporter) + fn = partial(train_func, config, reporter) + + def handle_output(output): + if not output: + return + elif isinstance(output, dict): + reporter(**output) + elif isinstance(output, Number): + reporter(_metric=output) + else: + raise ValueError( + "Invalid return or yield value. Either return/yield " + "a single number or a dictionary object in your " + "trainable function.") + + output = None + if inspect.isgeneratorfunction(train_func): + for output in fn(): + handle_output(output) + else: + output = fn() + handle_output(output) # If train_func returns, we need to notify the main event loop # of the last result while avoiding double logging. 
This is done diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index 4a4169bd9..70bbbffd6 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -55,7 +55,7 @@ def shutdown(): _session = None -def report(**kwargs): +def report(_metric=None, **kwargs): """Logs all keyword arguments. .. code-block:: python @@ -71,12 +71,13 @@ def report(**kwargs): analysis = tune.run(run_me) Args: + _metric: Optional default anonymous metric for ``tune.report(value)`` **kwargs: Any key value pair to be logged by Tune. Any of these metrics can be used for early stopping or optimization. """ _session = get_session() if _session: - return _session(**kwargs) + return _session(_metric, **kwargs) def make_checkpoint_dir(step=None): diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 10f781457..200938120 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -467,3 +467,53 @@ class FunctionApiTest(unittest.TestCase): self.assertEquals(trial_1.last_result["cp"], "DIR") self.assertEquals(trial_2.last_result["metric"], 500_000) self.assertEquals(trial_2.last_result["cp"], "DIR") + + def test_return_anonymous(self): + def train(config): + return config["a"] + + trial_1, trial_2 = tune.run( + train, config={ + "a": tune.grid_search([4, 8]) + }).trials + + self.assertEquals(trial_1.last_result["_metric"], 4) + self.assertEquals(trial_2.last_result["_metric"], 8) + + def test_return_specific(self): + def train(config): + return {"m": config["a"]} + + trial_1, trial_2 = tune.run( + train, config={ + "a": tune.grid_search([4, 8]) + }).trials + + self.assertEquals(trial_1.last_result["m"], 4) + self.assertEquals(trial_2.last_result["m"], 8) + + def test_yield_anonymous(self): + def train(config): + for i in range(10): + yield config["a"] + i + + trial_1, trial_2 = tune.run( + train, config={ + "a": tune.grid_search([4, 8]) + }).trials + + 
self.assertEquals(trial_1.last_result["_metric"], 4 + 9) + self.assertEquals(trial_2.last_result["_metric"], 8 + 9) + + def test_yield_specific(self): + def train(config): + for i in range(10): + yield {"m": config["a"] + i} + + trial_1, trial_2 = tune.run( + train, config={ + "a": tune.grid_search([4, 8]) + }).trials + + self.assertEquals(trial_1.last_result["m"], 4 + 9) + self.assertEquals(trial_2.last_result["m"], 8 + 9) diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 70ffc07ad..2ffb6bd6f 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -48,6 +48,8 @@ END_TO_END_COMMAND = """ import ray from ray import tune +reporter = tune.progress_reporter.CLIReporter(metric_columns=["done"]) + def f(config): return {"done": True} @@ -71,7 +73,7 @@ tune.run_experiments({ "c": tune.grid_search(list(range(10))), }, }, -}, verbose=1)""" +}, verbose=1, progress_reporter=reporter)""" EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING) +---------------+----------+-------+-----+ @@ -81,40 +83,40 @@ EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING) +---------------+----------+-------+-----+""" EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED) -+---------------+------------+-------+-----+-----+-----+ -| Trial name | status | loc | a | b | c | -|---------------+------------+-------+-----+-----+-----| -| f_xxxxx_00000 | TERMINATED | | 0 | | | -| f_xxxxx_00001 | TERMINATED | | 1 | | | -| f_xxxxx_00002 | TERMINATED | | 2 | | | -| f_xxxxx_00003 | TERMINATED | | 3 | | | -| f_xxxxx_00004 | TERMINATED | | 4 | | | -| f_xxxxx_00005 | TERMINATED | | 5 | | | -| f_xxxxx_00006 | TERMINATED | | 6 | | | -| f_xxxxx_00007 | TERMINATED | | 7 | | | -| f_xxxxx_00008 | TERMINATED | | 8 | | | -| f_xxxxx_00009 | TERMINATED | | 9 | | | -| f_xxxxx_00010 | TERMINATED | | | 0 | | -| f_xxxxx_00011 | TERMINATED | | | 1 | | -| 
f_xxxxx_00012 | TERMINATED | | | 2 | | -| f_xxxxx_00013 | TERMINATED | | | 3 | | -| f_xxxxx_00014 | TERMINATED | | | 4 | | -| f_xxxxx_00015 | TERMINATED | | | 5 | | -| f_xxxxx_00016 | TERMINATED | | | 6 | | -| f_xxxxx_00017 | TERMINATED | | | 7 | | -| f_xxxxx_00018 | TERMINATED | | | 8 | | -| f_xxxxx_00019 | TERMINATED | | | 9 | | -| f_xxxxx_00020 | TERMINATED | | | | 0 | -| f_xxxxx_00021 | TERMINATED | | | | 1 | -| f_xxxxx_00022 | TERMINATED | | | | 2 | -| f_xxxxx_00023 | TERMINATED | | | | 3 | -| f_xxxxx_00024 | TERMINATED | | | | 4 | -| f_xxxxx_00025 | TERMINATED | | | | 5 | -| f_xxxxx_00026 | TERMINATED | | | | 6 | -| f_xxxxx_00027 | TERMINATED | | | | 7 | -| f_xxxxx_00028 | TERMINATED | | | | 8 | -| f_xxxxx_00029 | TERMINATED | | | | 9 | -+---------------+------------+-------+-----+-----+-----+""" ++---------------+------------+-------+-----+-----+-----+--------+ +| Trial name | status | loc | a | b | c | done | +|---------------+------------+-------+-----+-----+-----+--------| +| f_xxxxx_00000 | TERMINATED | | 0 | | | True | +| f_xxxxx_00001 | TERMINATED | | 1 | | | True | +| f_xxxxx_00002 | TERMINATED | | 2 | | | True | +| f_xxxxx_00003 | TERMINATED | | 3 | | | True | +| f_xxxxx_00004 | TERMINATED | | 4 | | | True | +| f_xxxxx_00005 | TERMINATED | | 5 | | | True | +| f_xxxxx_00006 | TERMINATED | | 6 | | | True | +| f_xxxxx_00007 | TERMINATED | | 7 | | | True | +| f_xxxxx_00008 | TERMINATED | | 8 | | | True | +| f_xxxxx_00009 | TERMINATED | | 9 | | | True | +| f_xxxxx_00010 | TERMINATED | | | 0 | | True | +| f_xxxxx_00011 | TERMINATED | | | 1 | | True | +| f_xxxxx_00012 | TERMINATED | | | 2 | | True | +| f_xxxxx_00013 | TERMINATED | | | 3 | | True | +| f_xxxxx_00014 | TERMINATED | | | 4 | | True | +| f_xxxxx_00015 | TERMINATED | | | 5 | | True | +| f_xxxxx_00016 | TERMINATED | | | 6 | | True | +| f_xxxxx_00017 | TERMINATED | | | 7 | | True | +| f_xxxxx_00018 | TERMINATED | | | 8 | | True | +| f_xxxxx_00019 | TERMINATED | | | 9 | | True | +| f_xxxxx_00020 | 
TERMINATED | | | | 0 | True | +| f_xxxxx_00021 | TERMINATED | | | | 1 | True | +| f_xxxxx_00022 | TERMINATED | | | | 2 | True | +| f_xxxxx_00023 | TERMINATED | | | | 3 | True | +| f_xxxxx_00024 | TERMINATED | | | | 4 | True | +| f_xxxxx_00025 | TERMINATED | | | | 5 | True | +| f_xxxxx_00026 | TERMINATED | | | | 6 | True | +| f_xxxxx_00027 | TERMINATED | | | | 7 | True | +| f_xxxxx_00028 | TERMINATED | | | | 8 | True | +| f_xxxxx_00029 | TERMINATED | | | | 9 | True | ++---------------+------------+-------+-----+-----+-----+--------+""" EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED) +---------------+------------+-------+-----+-----+-----+