mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:40:09 +08:00
[tune] Fix TF checkpointing example (#4043)
Closes #3912, closes #3963.
This commit is contained in:
@@ -199,11 +199,13 @@ class TrainMNIST(Trainable):
|
||||
return {"mean_accuracy": train_accuracy}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
return self.saver.save(
|
||||
prefix = self.saver.save(
|
||||
self.sess, checkpoint_dir + "/save", global_step=self.iterations)
|
||||
return {"prefix": prefix}
|
||||
|
||||
def _restore(self, path):
|
||||
return self.saver.restore(self.sess, path)
|
||||
def _restore(self, ckpt_data):
|
||||
prefix = ckpt_data["prefix"]
|
||||
return self.saver.restore(self.sess, prefix)
|
||||
|
||||
|
||||
# !!! Example of using the ray.tune Python API !!!
|
||||
@@ -229,7 +231,7 @@ if __name__ == '__main__':
|
||||
}
|
||||
|
||||
if args.smoke_test:
|
||||
mnist_spec['stop']['training_iteration'] = 2
|
||||
mnist_spec['stop']['training_iteration'] = 20
|
||||
mnist_spec['num_samples'] = 2
|
||||
|
||||
ray.init()
|
||||
|
||||
Reference in New Issue
Block a user