[tune] Add support for function-based stopping condition (#5754)

This commit is contained in:
Ujval Misra
2019-09-23 18:39:00 -07:00
committed by Richard Liaw
parent b03147e7bf
commit a4659a8f8b
5 changed files with 65 additions and 5 deletions
+36 -1
View File
@@ -106,7 +106,7 @@ All results reported by the trainable will be logged locally to a unique directo
Trial Parallelism
~~~~~~~~~~~~~~~~~
Tune automatically N concurrent trials, where N is the number of CPUs (cores) on your machine. By default, Tune assumes that each trial will only require 1 CPU. You can override this with ``resources_per_trial``:
Tune automatically runs N concurrent trials, where N is the number of CPUs (cores) on your machine. By default, Tune assumes that each trial will only require 1 CPU. You can override this with ``resources_per_trial``:
.. code-block:: python
@@ -474,6 +474,41 @@ You often will want to compute a large object (e.g., training data, model weight
tune.run(f)
Custom Stopping Criteria
------------------------
You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function.
If a dictionary is passed in, the keys may be any field in the return result of ``tune.track.log`` in the Function API or ``train()`` (including the results from ``_train`` and auto-filled metrics).
In the example below, each trial will be stopped either when it completes 10 iterations OR when it reaches a mean accuracy of 0.98. Note that `training_iteration` is an auto-filled metric by Tune.
.. code-block:: python
tune.run(
my_trainable,
stop={"training_iteration": 10, "mean_accuracy": 0.98}
)
For more flexibility, you can pass in a function instead. If a function is passed in, it must take ``(trial_id, result)`` as arguments and return a boolean (``True`` if trial should be stopped and ``False`` otherwise).
You can use this to stop all trials after the criteria is fulfilled by any individual trial:
.. code-block:: python
class Stopper:
def __init__(self):
self.should_stop = False
def stop(self, trial_id, result):
if not self.should_stop and result['foo'] > 10:
self.should_stop = True
return self.should_stop
stopper = Stopper()
tune.run(my_trainable, stop=stopper.stop)
Note that in the above example all trials will not stop immediately, but will do so once their current iterations are complete.
Auto-Filled Results
-------------------
+10 -1
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
import logging
import os
import six
@@ -87,11 +88,19 @@ class Experiment(object):
_raise_deprecation_note(
"sync_function", "sync_to_driver", soft=False)
stop = stop or {}
if not isinstance(stop, dict) and not callable(stop):
raise ValueError("Invalid stop criteria: {}. Must be a callable "
"or dict".format(stop))
if callable(stop) and len(inspect.getargspec(stop).args) != 2:
raise ValueError("Invalid stop criteria: {}. Callable criteria "
"must take exactly 2 parameters.".format(stop))
config = config or {}
run_identifier = Experiment._register_if_needed(run)
spec = {
"run": run_identifier,
"stop": stop or {},
"stop": stop,
"config": config,
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
@@ -453,6 +453,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
[trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
self.assertEqual(trial.last_result["training_iteration"], 7)
def testStoppingFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
def stop(trial_id, result):
return result["test"] > 6
[trial] = tune.run(train, stop=stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
+3
View File
@@ -284,6 +284,9 @@ class Trial(object):
if result.get(DONE):
return True
if callable(self.stopping_criterion):
return self.stopping_criterion(self.trial_id, result)
for criteria, stop_value in self.stopping_criterion.items():
if criteria not in result:
raise TuneError(
+5 -3
View File
@@ -80,9 +80,11 @@ def run(run_or_experiment,
If Experiment, then Tune will execute training based on
Experiment.spec.
name (str): Name of experiment.
stop (dict): The stopping criteria. The keys may be any field in
the return result of 'train()', whichever is reached first.
Defaults to empty dict.
stop (dict|func): The stopping criteria. If dict, the keys may be
any field in the return result of 'train()', whichever is
reached first. If function, it must take (trial_id, result) as
arguments and return a boolean (True if trial should be stopped,
False otherwise).
config (dict): Algorithm-specific configuration for Tune variant
generation (e.g. env, hyperparams). Defaults to empty dict.
Custom search algorithms may ignore this.