mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 06:48:16 +08:00
[tune] Add support for function-based stopping condition (#5754)
This commit is contained in:
committed by
Richard Liaw
parent
b03147e7bf
commit
a4659a8f8b
@@ -106,7 +106,7 @@ All results reported by the trainable will be logged locally to a unique directo
|
||||
Trial Parallelism
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Tune automatically N concurrent trials, where N is the number of CPUs (cores) on your machine. By default, Tune assumes that each trial will only require 1 CPU. You can override this with ``resources_per_trial``:
|
||||
Tune automatically runs N concurrent trials, where N is the number of CPUs (cores) on your machine. By default, Tune assumes that each trial will only require 1 CPU. You can override this with ``resources_per_trial``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -474,6 +474,41 @@ You often will want to compute a large object (e.g., training data, model weight
|
||||
|
||||
tune.run(f)
|
||||
|
||||
Custom Stopping Criteria
|
||||
------------------------
|
||||
|
||||
You can control when trials are stopped early by passing the ``stop`` argument to ``tune.run``. This argument takes either a dictionary or a function.
|
||||
|
||||
If a dictionary is passed in, the keys may be any field in the return result of ``tune.track.log`` in the Function API or ``train()`` (including the results from ``_train`` and auto-filled metrics).
|
||||
|
||||
In the example below, each trial will be stopped either when it completes 10 iterations OR when it reaches a mean accuracy of 0.98. Note that ``training_iteration`` is a metric that is auto-filled by Tune.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tune.run(
|
||||
my_trainable,
|
||||
stop={"training_iteration": 10, "mean_accuracy": 0.98}
|
||||
)
|
||||
|
||||
For more flexibility, you can pass in a function instead. If a function is passed in, it must take ``(trial_id, result)`` as arguments and return a boolean (``True`` if the trial should be stopped and ``False`` otherwise).
|
||||
|
||||
You can use this to stop all trials once the criterion is fulfilled by any individual trial:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class Stopper:
|
||||
def __init__(self):
|
||||
self.should_stop = False
|
||||
|
||||
def stop(self, trial_id, result):
|
||||
if not self.should_stop and result['foo'] > 10:
|
||||
self.should_stop = True
|
||||
return self.should_stop
|
||||
|
||||
stopper = Stopper()
|
||||
tune.run(my_trainable, stop=stopper.stop)
|
||||
|
||||
Note that in the above example the trials will not all stop immediately; each one will stop once its current iteration is complete.
|
||||
|
||||
Auto-Filled Results
|
||||
-------------------
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
@@ -87,11 +88,19 @@ class Experiment(object):
|
||||
_raise_deprecation_note(
|
||||
"sync_function", "sync_to_driver", soft=False)
|
||||
|
||||
stop = stop or {}
|
||||
if not isinstance(stop, dict) and not callable(stop):
|
||||
raise ValueError("Invalid stop criteria: {}. Must be a callable "
|
||||
"or dict".format(stop))
|
||||
if callable(stop) and len(inspect.getargspec(stop).args) != 2:
|
||||
raise ValueError("Invalid stop criteria: {}. Callable criteria "
|
||||
"must take exactly 2 parameters.".format(stop))
|
||||
|
||||
config = config or {}
|
||||
run_identifier = Experiment._register_if_needed(run)
|
||||
spec = {
|
||||
"run": run_identifier,
|
||||
"stop": stop or {},
|
||||
"stop": stop,
|
||||
"config": config,
|
||||
"resources_per_trial": resources_per_trial,
|
||||
"num_samples": num_samples,
|
||||
|
||||
@@ -453,6 +453,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
[trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 7)
|
||||
|
||||
def testStoppingFunction(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
def stop(trial_id, result):
|
||||
return result["test"] > 6
|
||||
|
||||
[trial] = tune.run(train, stop=stop).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 8)
|
||||
|
||||
def testEarlyReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
|
||||
@@ -284,6 +284,9 @@ class Trial(object):
|
||||
if result.get(DONE):
|
||||
return True
|
||||
|
||||
if callable(self.stopping_criterion):
|
||||
return self.stopping_criterion(self.trial_id, result)
|
||||
|
||||
for criteria, stop_value in self.stopping_criterion.items():
|
||||
if criteria not in result:
|
||||
raise TuneError(
|
||||
|
||||
@@ -80,9 +80,11 @@ def run(run_or_experiment,
|
||||
If Experiment, then Tune will execute training based on
|
||||
Experiment.spec.
|
||||
name (str): Name of experiment.
|
||||
stop (dict): The stopping criteria. The keys may be any field in
|
||||
the return result of 'train()', whichever is reached first.
|
||||
Defaults to empty dict.
|
||||
stop (dict|func): The stopping criteria. If dict, the keys may be
|
||||
any field in the return result of 'train()', whichever is
|
||||
reached first. If function, it must take (trial_id, result) as
|
||||
arguments and return a boolean (True if trial should be stopped,
|
||||
False otherwise).
|
||||
config (dict): Algorithm-specific configuration for Tune variant
|
||||
generation (e.g. env, hyperparams). Defaults to empty dict.
|
||||
Custom search algorithms may ignore this.
|
||||
|
||||
Reference in New Issue
Block a user