[tune] Add error case for member functions passed as stopping c… (#5823)

This commit is contained in:
Ujval Misra
2019-10-03 09:49:03 -07:00
committed by Richard Liaw
parent 2fb7d7846f
commit 9df6eda84f
2 changed files with 37 additions and 3 deletions
+8 -3
View File
@@ -7,6 +7,7 @@ import inspect
import logging
import os
import six
import types
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable
@@ -92,9 +93,13 @@ class Experiment(object):
if not isinstance(stop, dict) and not callable(stop):
raise ValueError("Invalid stop criteria: {}. Must be a callable "
"or dict".format(stop))
if callable(stop) and len(inspect.getargspec(stop).args) != 2:
raise ValueError("Invalid stop criteria: {}. Callable criteria "
"must take exactly 2 parameters.".format(stop))
if callable(stop):
nargs = len(inspect.getargspec(stop).args)
is_method = isinstance(stop, types.MethodType)
if (is_method and nargs != 3) or (not is_method and nargs != 2):
raise ValueError(
"Invalid stop criteria: {}. Callable "
"criteria must take exactly 2 parameters.".format(stop))
config = config or {}
run_identifier = Experiment._register_if_needed(run)
@@ -464,6 +464,35 @@ class TrainableFunctionApiTest(unittest.TestCase):
[trial] = tune.run(train, stop=stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testStoppingMemberFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class Stopper:
def stop(self, trial_id, result):
return result["test"] > 6
[trial] = tune.run(train, stop=Stopper().stop).trials
self.assertEqual(trial.last_result["training_iteration"], 8)
def testBadStoppingFunction(self):
def train(config, reporter):
for i in range(10):
reporter(test=i)
class Stopper:
def stop(self, result):
return result["test"] > 6
def stop(result):
return result["test"] > 6
with self.assertRaises(ValueError):
tune.run(train, stop=Stopper().stop)
with self.assertRaises(ValueError):
tune.run(train, stop=stop)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)