[tune] Add support for function-based stopping condition (#5754)

This commit is contained in:
Ujval Misra
2019-09-23 18:39:00 -07:00
committed by Richard Liaw
parent b03147e7bf
commit a4659a8f8b
5 changed files with 65 additions and 5 deletions
+10 -1
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
import logging
import os
import six
@@ -87,11 +88,19 @@ class Experiment(object):
_raise_deprecation_note(
"sync_function", "sync_to_driver", soft=False)
stop = stop or {}
if not isinstance(stop, dict) and not callable(stop):
raise ValueError("Invalid stop criteria: {}. Must be a callable "
"or dict".format(stop))
if callable(stop) and len(inspect.getargspec(stop).args) != 2:
raise ValueError("Invalid stop criteria: {}. Callable criteria "
"must take exactly 2 parameters.".format(stop))
config = config or {}
run_identifier = Experiment._register_if_needed(run)
spec = {
"run": run_identifier,
"stop": stop or {},
"stop": stop,
"config": config,
"resources_per_trial": resources_per_trial,
"num_samples": num_samples,
@@ -453,6 +453,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
[trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
self.assertEqual(trial.last_result["training_iteration"], 7)
def testStoppingFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
def stop(trial_id, result):
return result["test"] > 6
[trial] = tune.run(train, stop=stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
+3
View File
@@ -284,6 +284,9 @@ class Trial(object):
if result.get(DONE):
return True
if callable(self.stopping_criterion):
return self.stopping_criterion(self.trial_id, result)
for criteria, stop_value in self.stopping_criterion.items():
if criteria not in result:
raise TuneError(
+5 -3
View File
@@ -80,9 +80,11 @@ def run(run_or_experiment,
If Experiment, then Tune will execute training based on
Experiment.spec.
name (str): Name of experiment.
stop (dict): The stopping criteria. The keys may be any field in
the return result of 'train()', whichever is reached first.
Defaults to empty dict.
stop (dict|func): The stopping criteria. If dict, the keys may be
any field in the return result of 'train()', whichever is
reached first. If function, it must take (trial_id, result) as
arguments and return a boolean (True if trial should be stopped,
False otherwise).
config (dict): Algorithm-specific configuration for Tune variant
generation (e.g. env, hyperparams). Defaults to empty dict.
Custom search algorithms may ignore this.