mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:40:09 +08:00
[tune] demo exporting trained models in pbt examples (#6533)
This commit is contained in:
@@ -18,6 +18,7 @@ import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from ray.tune.util import validate_save_restore
|
||||
from ray.tune.trial import ExportFormat
|
||||
|
||||
# __tutorial_imports_end__
|
||||
|
||||
@@ -51,6 +52,14 @@ class PytorchTrainble(tune.Trainable):
|
||||
def _restore(self, checkpoint_path):
|
||||
self.model.load_state_dict(torch.load(checkpoint_path))
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
if export_formats == [ExportFormat.MODEL]:
|
||||
path = os.path.join(export_dir, "exported_convnet.pt")
|
||||
torch.save(self.model.state_dict(), path)
|
||||
return {export_formats[0]: path}
|
||||
else:
|
||||
raise ValueError("unexpected formats: " + str(export_formats))
|
||||
|
||||
def reset_config(self, new_config):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if "lr" in new_config:
|
||||
@@ -76,7 +85,6 @@ if __name__ == "__main__":
|
||||
# check if PytorchTrainble will save/restore correctly before execution
|
||||
validate_save_restore(PytorchTrainble)
|
||||
validate_save_restore(PytorchTrainble, use_object_store=True)
|
||||
print("Success!")
|
||||
|
||||
# __pbt_begin__
|
||||
scheduler = PopulationBasedTraining(
|
||||
@@ -90,18 +98,30 @@ if __name__ == "__main__":
|
||||
# allow perturbations within this set of categorical values
|
||||
"momentum": [0.8, 0.9, 0.99],
|
||||
})
|
||||
|
||||
# __pbt_end__
|
||||
|
||||
# __tune_begin__
|
||||
class Stopper:
|
||||
def __init__(self):
|
||||
self.should_stop = False
|
||||
|
||||
def stop(self, trial_id, result):
|
||||
max_iter = 5 if args.smoke_test else 100
|
||||
if not self.should_stop and result["mean_accuracy"] > 0.96:
|
||||
self.should_stop = True
|
||||
return self.should_stop or result["training_iteration"] >= max_iter
|
||||
|
||||
stopper = Stopper()
|
||||
|
||||
analysis = tune.run(
|
||||
PytorchTrainble,
|
||||
name="pbt_test",
|
||||
scheduler=scheduler,
|
||||
reuse_actors=True,
|
||||
verbose=1,
|
||||
stop={
|
||||
"training_iteration": 5 if args.smoke_test else 100,
|
||||
},
|
||||
stop=stopper.stop,
|
||||
export_formats=[ExportFormat.MODEL],
|
||||
num_samples=4,
|
||||
config={
|
||||
"lr": tune.uniform(0.001, 1),
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import print_function
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from ray.tune.trial import ExportFormat
|
||||
|
||||
import argparse
|
||||
import os
|
||||
@@ -285,6 +286,17 @@ class PytorchTrainable(tune.Trainable):
|
||||
self.config = new_config
|
||||
return True
|
||||
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
if export_formats == [ExportFormat.MODEL]:
|
||||
path = os.path.join(export_dir, "exported_models")
|
||||
torch.save({
|
||||
"netDmodel": self.netD.state_dict(),
|
||||
"netGmodel": self.netG.state_dict()
|
||||
}, path)
|
||||
return {ExportFormat.MODEL: path}
|
||||
else:
|
||||
raise ValueError("unexpected formats: " + str(export_formats))
|
||||
|
||||
|
||||
# __Trainable_end__
|
||||
|
||||
@@ -343,6 +355,7 @@ if __name__ == "__main__":
|
||||
"training_iteration": tune_iter,
|
||||
},
|
||||
num_samples=8,
|
||||
export_formats=[ExportFormat.MODEL],
|
||||
config={
|
||||
"netG_lr": tune.sample_from(
|
||||
lambda spec: random.choice([0.0001, 0.0002, 0.0005])),
|
||||
@@ -357,7 +370,7 @@ if __name__ == "__main__":
|
||||
img_list = []
|
||||
fixed_noise = torch.randn(64, nz, 1, 1)
|
||||
for d in logdirs:
|
||||
netG_path = d + "/checkpoint_" + str(tune_iter) + "/checkpoint"
|
||||
netG_path = os.path.join(d, "exported_models")
|
||||
loadedG = Generator()
|
||||
loadedG.load_state_dict(torch.load(netG_path)["netGmodel"])
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user