[tune] Fix TF checkpointing example (#4043)

Closes #3912, closes #3963.
This commit is contained in:
Richard Liaw
2019-02-15 00:30:27 -08:00
committed by GitHub
parent 3684e5bc0d
commit 7cf62a10cd
@@ -199,11 +199,13 @@ class TrainMNIST(Trainable):
return {"mean_accuracy": train_accuracy}
def _save(self, checkpoint_dir):
return self.saver.save(
prefix = self.saver.save(
self.sess, checkpoint_dir + "/save", global_step=self.iterations)
return {"prefix": prefix}
def _restore(self, path):
return self.saver.restore(self.sess, path)
def _restore(self, ckpt_data):
prefix = ckpt_data["prefix"]
return self.saver.restore(self.sess, prefix)
# !!! Example of using the ray.tune Python API !!!
@@ -229,7 +231,7 @@ if __name__ == '__main__':
}
if args.smoke_test:
mnist_spec['stop']['training_iteration'] = 2
mnist_spec['stop']['training_iteration'] = 20
mnist_spec['num_samples'] = 2
ray.init()