diff --git a/python/ray/tune/examples/cifar10_pytorch.py b/python/ray/tune/examples/cifar10_pytorch.py index 5089ef364..1b9b75afb 100644 --- a/python/ray/tune/examples/cifar10_pytorch.py +++ b/python/ray/tune/examples/cifar10_pytorch.py @@ -195,8 +195,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2): config=config, num_samples=num_samples, scheduler=scheduler, - progress_reporter=reporter, - checkpoint_at_end=True) + progress_reporter=reporter) best_trial = result.get_best_trial("loss", "min", "last") print("Best trial config: {}".format(best_trial.config)) diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 2537e30ee..ec2bffb31 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -120,13 +120,12 @@ class Experiment: restore=None): config = config or {} - if callable(run) and detect_checkpoint_function(run): if checkpoint_at_end: - raise ValueError( - "'checkpoint_at_end' cannot be used with a " - "checkpointable function. You can specify and register " - "checkpoints within your trainable function.") + raise ValueError("'checkpoint_at_end' cannot be used with a " + "checkpointable function. You can specify " + "and register checkpoints within " + "your trainable function.") if checkpoint_freq: raise ValueError( "'checkpoint_freq' cannot be used with a " diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 0e20849ad..04a70ae9c 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -300,6 +300,15 @@ class FunctionApiTest(unittest.TestCase): ray.shutdown() _register_all() # re-register the evicted objects + def testCheckpointError(self): + def train(config, checkpoint_dir=False): + pass + + with self.assertRaises(ValueError): + tune.run(train, checkpoint_freq=1) + with self.assertRaises(ValueError): + tune.run(train, checkpoint_at_end=True) + def testCheckpointFunctionAtEnd(self): def train(config, checkpoint_dir=False): for i in range(10): diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index f0ce9b0ef..f1182dae6 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -264,10 +264,9 @@ def run(run_or_experiment, for i, exp in enumerate(experiments): if not isinstance(exp, Experiment): - run_identifier = Experiment.register_if_needed(exp) experiments[i] = Experiment( name=name, - run=run_identifier, + run=exp, stop=stop, config=config, resources_per_trial=resources_per_trial,