[tune] demo exporting trained models in pbt examples (#6533)

This commit is contained in:
Yuhao Yang
2019-12-25 17:14:49 -08:00
committed by Richard Liaw
parent 93e8c85e72
commit df4533c649
3 changed files with 39 additions and 6 deletions
@@ -18,6 +18,7 @@ import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.util import validate_save_restore
from ray.tune.trial import ExportFormat
# __tutorial_imports_end__
@@ -51,6 +52,14 @@ class PytorchTrainble(tune.Trainable):
def _restore(self, checkpoint_path):
    """Reload the model weights from a previously written checkpoint.

    Args:
        checkpoint_path: Location handed to ``torch.load``; the object it
            yields is passed to ``self.model.load_state_dict``.
    """
    saved_weights = torch.load(checkpoint_path)
    self.model.load_state_dict(saved_weights)
def _export_model(self, export_formats, export_dir):
    """Write the model's weights into ``export_dir``.

    Only the exact request ``[ExportFormat.MODEL]`` is handled.

    Args:
        export_formats: Requested formats; must equal
            ``[ExportFormat.MODEL]``.
        export_dir: Directory that receives ``exported_convnet.pt``.

    Returns:
        Dict mapping the requested format to the path of the saved file.

    Raises:
        ValueError: If ``export_formats`` is anything else.
    """
    # Reject unsupported requests up front so the happy path stays flat.
    if export_formats != [ExportFormat.MODEL]:
        raise ValueError("unexpected formats: " + str(export_formats))

    target = os.path.join(export_dir, "exported_convnet.pt")
    weights = self.model.state_dict()
    torch.save(weights, target)
    return {export_formats[0]: target}
def reset_config(self, new_config):
for param_group in self.optimizer.param_groups:
if "lr" in new_config:
@@ -76,7 +85,6 @@ if __name__ == "__main__":
# check if PytorchTrainble will save/restore correctly before execution
validate_save_restore(PytorchTrainble)
# Repeat the check with use_object_store=True (presumably exercising the
# in-memory checkpoint path -- confirm against validate_save_restore).
validate_save_restore(PytorchTrainble, use_object_store=True)
# Reached only if both validation calls above returned without raising.
print("Success!")
# __pbt_begin__
scheduler = PopulationBasedTraining(
@@ -90,18 +98,30 @@ if __name__ == "__main__":
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
# __pbt_end__
# __tune_begin__
class Stopper:
    """Stopping criterion whose state is shared by every trial.

    ``stop`` is meant to be passed as a callable (``stop=stopper.stop``).
    As soon as one reported result exceeds the accuracy target, the
    instance remembers it and every later call returns True, regardless
    of which trial the call is for.

    Args:
        max_iter: Iteration cap applied to each result. When None (the
            default) the cap is derived from the script's ``args``:
            5 if ``args.smoke_test`` else 100.
        target_accuracy: ``mean_accuracy`` value that must be strictly
            exceeded to trigger the shared stop flag.
    """

    def __init__(self, max_iter=None, target_accuracy=0.96):
        self.should_stop = False
        self.max_iter = max_iter
        self.target_accuracy = target_accuracy

    def stop(self, trial_id, result):
        """Return True if the trial reporting ``result`` should stop.

        Args:
            trial_id: Identifier of the reporting trial (unused; the
                decision depends only on ``result`` and shared state).
            result: Mapping that provides ``mean_accuracy`` and
                ``training_iteration``.

        Returns:
            True when the shared flag is set or the iteration cap has
            been reached, False otherwise.
        """
        max_iter = self.max_iter
        if max_iter is None:
            # Resolved lazily so the class can be defined before the
            # command-line arguments are inspected.
            max_iter = 5 if args.smoke_test else 100

        accuracy_reached = result["mean_accuracy"] > self.target_accuracy
        if not self.should_stop and accuracy_reached:
            self.should_stop = True
        return self.should_stop or result["training_iteration"] >= max_iter
stopper = Stopper()
analysis = tune.run(
PytorchTrainble,
name="pbt_test",
scheduler=scheduler,
reuse_actors=True,
verbose=1,
stop={
"training_iteration": 5 if args.smoke_test else 100,
},
stop=stopper.stop,
export_formats=[ExportFormat.MODEL],
num_samples=4,
config={
"lr": tune.uniform(0.001, 1),
@@ -7,6 +7,7 @@ from __future__ import print_function
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.trial import ExportFormat
import argparse
import os
@@ -285,6 +286,17 @@ class PytorchTrainable(tune.Trainable):
self.config = new_config
return True
def _export_model(self, export_formats, export_dir):
    """Save both networks' weights as one file inside ``export_dir``.

    Only the exact request ``[ExportFormat.MODEL]`` is handled.

    Args:
        export_formats: Requested formats; must equal
            ``[ExportFormat.MODEL]``.
        export_dir: Directory that receives the ``exported_models`` file.

    Returns:
        Dict mapping ``ExportFormat.MODEL`` to the path of the saved file.

    Raises:
        ValueError: If ``export_formats`` is anything else.
    """
    # Reject unsupported requests up front so the happy path stays flat.
    if export_formats != [ExportFormat.MODEL]:
        raise ValueError("unexpected formats: " + str(export_formats))

    target = os.path.join(export_dir, "exported_models")
    # Discriminator and generator are stored under separate keys so a
    # reader can pull out either one from the single saved file.
    bundle = {
        "netDmodel": self.netD.state_dict(),
        "netGmodel": self.netG.state_dict(),
    }
    torch.save(bundle, target)
    return {ExportFormat.MODEL: target}
# __Trainable_end__
@@ -343,6 +355,7 @@ if __name__ == "__main__":
"training_iteration": tune_iter,
},
num_samples=8,
export_formats=[ExportFormat.MODEL],
config={
"netG_lr": tune.sample_from(
lambda spec: random.choice([0.0001, 0.0002, 0.0005])),
@@ -357,7 +370,7 @@ if __name__ == "__main__":
img_list = []
fixed_noise = torch.randn(64, nz, 1, 1)
for d in logdirs:
netG_path = d + "/checkpoint_" + str(tune_iter) + "/checkpoint"
netG_path = os.path.join(d, "exported_models")
loadedG = Generator()
loadedG.load_state_dict(torch.load(netG_path)["netGmodel"])
with torch.no_grad():