diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index c9e18fcf7..2a6bcf2eb 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -106,7 +106,7 @@ All results reported by the trainable will be logged locally to a unique directo Trial Parallelism ~~~~~~~~~~~~~~~~~ -Tune automatically N concurrent trials, where N is the number of CPUs (cores) on your machine. By default, Tune assumes that each trial will only require 1 CPU. You can override this with ``resources_per_trial``: +Tune automatically runs N concurrent trials, where N is the number of CPUs (cores) on your machine. By default, Tune assumes that each trial will only require 1 CPU. You can override this with ``resources_per_trial``: .. code-block:: python @@ -474,6 +474,41 @@ You often will want to compute a large object (e.g., training data, model weight tune.run(f) +Custom Stopping Criteria +------------------------ + +You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function. + +If a dictionary is passed in, the keys may be any field in the return result of ``tune.track.log`` in the Function API or ``train()`` (including the results from ``_train`` and auto-filled metrics). + +In the example below, each trial will be stopped either when it completes 10 iterations OR when it reaches a mean accuracy of 0.98. Note that ``training_iteration`` is a metric auto-filled by Tune. + +.. code-block:: python + + tune.run( + my_trainable, + stop={"training_iteration": 10, "mean_accuracy": 0.98} + ) + +For more flexibility, you can pass in a function instead. If a function is passed in, it must take ``(trial_id, result)`` as arguments and return a boolean (``True`` if trial should be stopped and ``False`` otherwise). + +You can use this to stop all trials after the criteria are fulfilled by any individual trial: + +..
code-block:: python + + class Stopper: + def __init__(self): + self.should_stop = False + + def stop(self, trial_id, result): + if not self.should_stop and result['foo'] > 10: + self.should_stop = True + return self.should_stop + + stopper = Stopper() + tune.run(my_trainable, stop=stopper.stop) + +Note that in the above example the trials will not all stop immediately; each trial stops once its current iteration is complete. Auto-Filled Results ------------------- diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 1af9b043f..baa958191 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import copy +import inspect import logging import os import six @@ -87,11 +88,19 @@ class Experiment(object): _raise_deprecation_note( "sync_function", "sync_to_driver", soft=False) + stop = stop or {} + if not isinstance(stop, dict) and not callable(stop): + raise ValueError("Invalid stop criteria: {}. Must be a callable " + "or dict".format(stop)) + if callable(stop) and len(inspect.getargspec(stop).args) != 2: + raise ValueError("Invalid stop criteria: {}.
Callable criteria " + "must take exactly 2 parameters.".format(stop)) + config = config or {} run_identifier = Experiment._register_if_needed(run) spec = { "run": run_identifier, - "stop": stop or {}, + "stop": stop, "config": config, "resources_per_trial": resources_per_trial, "num_samples": num_samples, diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index f37fa327a..f1785f0dd 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -453,6 +453,17 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = tune.run(train, stop={"test/test1/test2": 6}).trials self.assertEqual(trial.last_result["training_iteration"], 7) + def testStoppingFunction(self): + def train(config, reporter): + for i in range(10): + reporter(test=i) + + def stop(trial_id, result): + return result["test"] > 6 + + [trial] = tune.run(train, stop=stop).trials + self.assertEqual(trial.last_result["training_iteration"], 8) + def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index b0de00bd3..725ce04fb 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -284,6 +284,9 @@ class Trial(object): if result.get(DONE): return True + if callable(self.stopping_criterion): + return self.stopping_criterion(self.trial_id, result) + for criteria, stop_value in self.stopping_criterion.items(): if criteria not in result: raise TuneError( diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 881751823..73747c449 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -80,9 +80,11 @@ def run(run_or_experiment, If Experiment, then Tune will execute training based on Experiment.spec. name (str): Name of experiment. - stop (dict): The stopping criteria. The keys may be any field in - the return result of 'train()', whichever is reached first. 
- Defaults to empty dict. + stop (dict|func): The stopping criteria. If dict, the keys may be + any field in the return result of 'train()', whichever is + reached first. If function, it must take (trial_id, result) as + arguments and return a boolean (True if trial should be stopped, + False otherwise). config (dict): Algorithm-specific configuration for Tune variant generation (e.g. env, hyperparams). Defaults to empty dict. Custom search algorithms may ignore this.