diff --git a/python/ray/tune/examples/tune_mnist_ray_hyperband.py b/python/ray/tune/examples/tune_mnist_ray_hyperband.py index bce19deca..589b53bfe 100755 --- a/python/ray/tune/examples/tune_mnist_ray_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_ray_hyperband.py @@ -199,11 +199,13 @@ class TrainMNIST(Trainable): return {"mean_accuracy": train_accuracy} def _save(self, checkpoint_dir): - return self.saver.save( + prefix = self.saver.save( self.sess, checkpoint_dir + "/save", global_step=self.iterations) + return {"prefix": prefix} - def _restore(self, path): - return self.saver.restore(self.sess, path) + def _restore(self, ckpt_data): + prefix = ckpt_data["prefix"] + return self.saver.restore(self.sess, prefix) # !!! Example of using the ray.tune Python API !!! @@ -229,7 +231,7 @@ if __name__ == '__main__': } if args.smoke_test: - mnist_spec['stop']['training_iteration'] = 2 + mnist_spec['stop']['training_iteration'] = 20 mnist_spec['num_samples'] = 2 ray.init()