mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 08:07:54 +08:00
[tune] Add support for function-based stopping condition (#5754)
This commit is contained in:
committed by
Richard Liaw
parent
b03147e7bf
commit
a4659a8f8b
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
@@ -87,11 +88,19 @@ class Experiment(object):
|
||||
_raise_deprecation_note(
|
||||
"sync_function", "sync_to_driver", soft=False)
|
||||
|
||||
stop = stop or {}
|
||||
if not isinstance(stop, dict) and not callable(stop):
|
||||
raise ValueError("Invalid stop criteria: {}. Must be a callable "
|
||||
"or dict".format(stop))
|
||||
if callable(stop) and len(inspect.getargspec(stop).args) != 2:
|
||||
raise ValueError("Invalid stop criteria: {}. Callable criteria "
|
||||
"must take exactly 2 parameters.".format(stop))
|
||||
|
||||
config = config or {}
|
||||
run_identifier = Experiment._register_if_needed(run)
|
||||
spec = {
|
||||
"run": run_identifier,
|
||||
"stop": stop or {},
|
||||
"stop": stop,
|
||||
"config": config,
|
||||
"resources_per_trial": resources_per_trial,
|
||||
"num_samples": num_samples,
|
||||
|
||||
@@ -453,6 +453,17 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
[trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 7)
|
||||
|
||||
def testStoppingFunction(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
def stop(trial_id, result):
|
||||
return result["test"] > 6
|
||||
|
||||
[trial] = tune.run(train, stop=stop).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 8)
|
||||
|
||||
def testEarlyReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
|
||||
@@ -284,6 +284,9 @@ class Trial(object):
|
||||
if result.get(DONE):
|
||||
return True
|
||||
|
||||
if callable(self.stopping_criterion):
|
||||
return self.stopping_criterion(self.trial_id, result)
|
||||
|
||||
for criteria, stop_value in self.stopping_criterion.items():
|
||||
if criteria not in result:
|
||||
raise TuneError(
|
||||
|
||||
@@ -80,9 +80,11 @@ def run(run_or_experiment,
|
||||
If Experiment, then Tune will execute training based on
|
||||
Experiment.spec.
|
||||
name (str): Name of experiment.
|
||||
stop (dict): The stopping criteria. The keys may be any field in
|
||||
the return result of 'train()', whichever is reached first.
|
||||
Defaults to empty dict.
|
||||
stop (dict|func): The stopping criteria. If dict, the keys may be
|
||||
any field in the return result of 'train()', whichever is
|
||||
reached first. If function, it must take (trial_id, result) as
|
||||
arguments and return a boolean (True if trial should be stopped,
|
||||
False otherwise).
|
||||
config (dict): Algorithm-specific configuration for Tune variant
|
||||
generation (e.g. env, hyperparams). Defaults to empty dict.
|
||||
Custom search algorithms may ignore this.
|
||||
|
||||
Reference in New Issue
Block a user