From 7cf62a10cde4dd06daa42463cee0430309f467e3 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 15 Feb 2019 00:30:27 -0800 Subject: [PATCH] [tune] Fix TF checkpointing example (#4043) Closes #3912, closes #3963. --- python/ray/tune/examples/tune_mnist_ray_hyperband.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/ray/tune/examples/tune_mnist_ray_hyperband.py b/python/ray/tune/examples/tune_mnist_ray_hyperband.py index bce19deca..589b53bfe 100755 --- a/python/ray/tune/examples/tune_mnist_ray_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_ray_hyperband.py @@ -199,11 +199,13 @@ class TrainMNIST(Trainable): return {"mean_accuracy": train_accuracy} def _save(self, checkpoint_dir): - return self.saver.save( + prefix = self.saver.save( self.sess, checkpoint_dir + "/save", global_step=self.iterations) + return {"prefix": prefix} - def _restore(self, path): - return self.saver.restore(self.sess, path) + def _restore(self, ckpt_data): + prefix = ckpt_data["prefix"] + return self.saver.restore(self.sess, prefix) # !!! Example of using the ray.tune Python API !!! @@ -229,7 +231,7 @@ if __name__ == '__main__': } if args.smoke_test: - mnist_spec['stop']['training_iteration'] = 2 + mnist_spec['stop']['training_iteration'] = 20 mnist_spec['num_samples'] = 2 ray.init()