mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 01:27:43 +08:00
[tune] Add error case for member functions passed as stopping c… (#5823)
This commit is contained in:
committed by
Richard Liaw
parent
2fb7d7846f
commit
9df6eda84f
@@ -7,6 +7,7 @@ import inspect
|
||||
import logging
|
||||
import os
|
||||
import six
|
||||
import types
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.registry import register_trainable
|
||||
@@ -92,9 +93,13 @@ class Experiment(object):
|
||||
if not isinstance(stop, dict) and not callable(stop):
|
||||
raise ValueError("Invalid stop criteria: {}. Must be a callable "
|
||||
"or dict".format(stop))
|
||||
if callable(stop) and len(inspect.getargspec(stop).args) != 2:
|
||||
raise ValueError("Invalid stop criteria: {}. Callable criteria "
|
||||
"must take exactly 2 parameters.".format(stop))
|
||||
if callable(stop):
|
||||
nargs = len(inspect.getargspec(stop).args)
|
||||
is_method = isinstance(stop, types.MethodType)
|
||||
if (is_method and nargs != 3) or (not is_method and nargs != 2):
|
||||
raise ValueError(
|
||||
"Invalid stop criteria: {}. Callable "
|
||||
"criteria must take exactly 2 parameters.".format(stop))
|
||||
|
||||
config = config or {}
|
||||
run_identifier = Experiment._register_if_needed(run)
|
||||
|
||||
@@ -464,6 +464,35 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
[trial] = tune.run(train, stop=stop).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 8)
|
||||
|
||||
def testStoppingMemberFunction(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
class Stopper:
|
||||
def stop(self, trial_id, result):
|
||||
return result["test"] > 6
|
||||
|
||||
[trial] = tune.run(train, stop=Stopper().stop).trials
|
||||
self.assertEqual(trial.last_result["training_iteration"], 8)
|
||||
|
||||
def testBadStoppingFunction(self):
|
||||
def train(config, reporter):
|
||||
for i in range(10):
|
||||
reporter(test=i)
|
||||
|
||||
class Stopper:
|
||||
def stop(self, result):
|
||||
return result["test"] > 6
|
||||
|
||||
def stop(result):
|
||||
return result["test"] > 6
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tune.run(train, stop=Stopper().stop)
|
||||
with self.assertRaises(ValueError):
|
||||
tune.run(train, stop=stop)
|
||||
|
||||
def testEarlyReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
|
||||
Reference in New Issue
Block a user