diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index baa958191..81ad29ac9 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -7,6 +7,7 @@ import inspect import logging import os import six +import types from ray.tune.error import TuneError from ray.tune.registry import register_trainable @@ -92,9 +93,13 @@ class Experiment(object): if not isinstance(stop, dict) and not callable(stop): raise ValueError("Invalid stop criteria: {}. Must be a callable " "or dict".format(stop)) - if callable(stop) and len(inspect.getargspec(stop).args) != 2: - raise ValueError("Invalid stop criteria: {}. Callable criteria " - "must take exactly 2 parameters.".format(stop)) + if callable(stop): + nargs = len(inspect.getargspec(stop).args) + is_method = isinstance(stop, types.MethodType) + if (is_method and nargs != 3) or (not is_method and nargs != 2): + raise ValueError( + "Invalid stop criteria: {}. Callable " + "criteria must take exactly 2 parameters.".format(stop)) config = config or {} run_identifier = Experiment._register_if_needed(run) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index f1785f0dd..926f45b0a 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -464,6 +464,35 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = tune.run(train, stop=stop).trials self.assertEqual(trial.last_result["training_iteration"], 8) + def testStoppingMemberFunction(self): + def train(config, reporter): + for i in range(10): + reporter(test=i) + + class Stopper: + def stop(self, trial_id, result): + return result["test"] > 6 + + [trial] = tune.run(train, stop=Stopper().stop).trials + self.assertEqual(trial.last_result["training_iteration"], 8) + + def testBadStoppingFunction(self): + def train(config, reporter): + for i in range(10): + reporter(test=i) + + class Stopper: + def stop(self, result): + return result["test"] > 6 + + def stop(result): + return result["test"] > 6 + + with self.assertRaises(ValueError): + tune.run(train, stop=Stopper().stop) + with self.assertRaises(ValueError): + tune.run(train, stop=stop) + def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True)