From 9df6eda84ff8c4038808226e5dfb28fec3464657 Mon Sep 17 00:00:00 2001 From: Ujval Misra Date: Thu, 3 Oct 2019 09:49:03 -0700 Subject: [PATCH] =?UTF-8?q?[tune]=20Add=20error=20case=20for=20member=20fu?= =?UTF-8?q?nctions=20passed=20as=20stopping=20c=E2=80=A6=20(#5823)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/ray/tune/experiment.py | 11 +++++--- python/ray/tune/tests/test_trial_runner.py | 29 ++++++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index baa958191..81ad29ac9 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -7,6 +7,7 @@ import inspect import logging import os import six +import types from ray.tune.error import TuneError from ray.tune.registry import register_trainable @@ -92,9 +93,13 @@ class Experiment(object): if not isinstance(stop, dict) and not callable(stop): raise ValueError("Invalid stop criteria: {}. Must be a callable " "or dict".format(stop)) - if callable(stop) and len(inspect.getargspec(stop).args) != 2: - raise ValueError("Invalid stop criteria: {}. Callable criteria " - "must take exactly 2 parameters.".format(stop)) + if callable(stop): + nargs = len(inspect.getargspec(stop).args) + is_method = isinstance(stop, types.MethodType) + if (is_method and nargs != 3) or (not is_method and nargs != 2): + raise ValueError( + "Invalid stop criteria: {}. Callable " + "criteria must take exactly 2 parameters.".format(stop)) config = config or {} run_identifier = Experiment._register_if_needed(run) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index f1785f0dd..926f45b0a 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -464,6 +464,35 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = tune.run(train, stop=stop).trials self.assertEqual(trial.last_result["training_iteration"], 8) + def testStoppingMemberFunction(self): + def train(config, reporter): + for i in range(10): + reporter(test=i) + + class Stopper: + def stop(self, trial_id, result): + return result["test"] > 6 + + [trial] = tune.run(train, stop=Stopper().stop).trials + self.assertEqual(trial.last_result["training_iteration"], 8) + + def testBadStoppingFunction(self): + def train(config, reporter): + for i in range(10): + reporter(test=i) + + class Stopper: + def stop(self, result): + return result["test"] > 6 + + def stop(result): + return result["test"] > 6 + + with self.assertRaises(ValueError): + tune.run(train, stop=Stopper().stop) + with self.assertRaises(ValueError): + tune.run(train, stop=stop) + def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True)