[Tune] Better error when using checkpoint_freq (#9998)

This commit is contained in:
Amog Kamsetty
2020-08-10 13:52:46 -07:00
committed by GitHub
parent be8e63d477
commit 856d4a0533
4 changed files with 15 additions and 9 deletions
+1 -2
View File
@@ -195,8 +195,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter,
checkpoint_at_end=True)
progress_reporter=reporter)
best_trial = result.get_best_trial("loss", "min", "last")
print("Best trial config: {}".format(best_trial.config))
+4 -5
View File
@@ -120,13 +120,12 @@ class Experiment:
restore=None):
config = config or {}
if callable(run) and detect_checkpoint_function(run):
if checkpoint_at_end:
raise ValueError(
"'checkpoint_at_end' cannot be used with a "
"checkpointable function. You can specify and register "
"checkpoints within your trainable function.")
raise ValueError("'checkpoint_at_end' cannot be used with a "
"checkpointable function. You can specify "
"and register checkpoints within "
"your trainable function.")
if checkpoint_freq:
raise ValueError(
"'checkpoint_freq' cannot be used with a "
@@ -300,6 +300,15 @@ class FunctionApiTest(unittest.TestCase):
ray.shutdown()
_register_all() # re-register the evicted objects
def testCheckpointError(self):
def train(config, checkpoint_dir=False):
pass
with self.assertRaises(ValueError):
tune.run(train, checkpoint_freq=1)
with self.assertRaises(ValueError):
tune.run(train, checkpoint_at_end=True)
def testCheckpointFunctionAtEnd(self):
def train(config, checkpoint_dir=False):
for i in range(10):
+1 -2
View File
@@ -264,10 +264,9 @@ def run(run_or_experiment,
for i, exp in enumerate(experiments):
if not isinstance(exp, Experiment):
run_identifier = Experiment.register_if_needed(exp)
experiments[i] = Experiment(
name=name,
run=run_identifier,
run=exp,
stop=stop,
config=config,
resources_per_trial=resources_per_trial,