diff --git a/doc/source/tune-advanced-tutorial.rst b/doc/source/tune-advanced-tutorial.rst index f48050190..5b7613555 100644 --- a/doc/source/tune-advanced-tutorial.rst +++ b/doc/source/tune-advanced-tutorial.rst @@ -214,7 +214,7 @@ We specify inception score as the metric and start the tuning: :start-after: __tune_begin__ :end-before: __tune_end__ -The trained Generator models can be loaded from checkpoints, and generate images +The trained Generator models can be loaded from the log directory, and generate images from noise signals. .. image:: images/tune_advanced_dcgan_generated.gif diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index 489aa3504..ea37caa16 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -18,6 +18,7 @@ import ray from ray import tune from ray.tune.schedulers import PopulationBasedTraining from ray.tune.util import validate_save_restore +from ray.tune.trial import ExportFormat # __tutorial_imports_end__ @@ -51,6 +52,14 @@ class PytorchTrainble(tune.Trainable): def _restore(self, checkpoint_path): self.model.load_state_dict(torch.load(checkpoint_path)) + def _export_model(self, export_formats, export_dir): + if export_formats == [ExportFormat.MODEL]: + path = os.path.join(export_dir, "exported_convnet.pt") + torch.save(self.model.state_dict(), path) + return {export_formats[0]: path} + else: + raise ValueError("unexpected formats: " + str(export_formats)) + def reset_config(self, new_config): for param_group in self.optimizer.param_groups: if "lr" in new_config: @@ -76,7 +85,6 @@ if __name__ == "__main__": # check if PytorchTrainble will save/restore correctly before execution validate_save_restore(PytorchTrainble) validate_save_restore(PytorchTrainble, use_object_store=True) - print("Success!") # __pbt_begin__ scheduler = PopulationBasedTraining( @@ -90,18 +98,30 @@ if __name__ == "__main__": # allow perturbations within
this set of categorical values "momentum": [0.8, 0.9, 0.99], }) + # __pbt_end__ # __tune_begin__ + class Stopper: + def __init__(self): + self.should_stop = False + + def stop(self, trial_id, result): + max_iter = 5 if args.smoke_test else 100 + if not self.should_stop and result["mean_accuracy"] > 0.96: + self.should_stop = True + return self.should_stop or result["training_iteration"] >= max_iter + + stopper = Stopper() + analysis = tune.run( PytorchTrainble, name="pbt_test", scheduler=scheduler, reuse_actors=True, verbose=1, - stop={ - "training_iteration": 5 if args.smoke_test else 100, - }, + stop=stopper.stop, + export_formats=[ExportFormat.MODEL], num_samples=4, config={ "lr": tune.uniform(0.001, 1), diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py index d4d7948f2..9a4fe461d 100644 --- a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py +++ b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py @@ -7,6 +7,7 @@ from __future__ import print_function import ray from ray import tune from ray.tune.schedulers import PopulationBasedTraining +from ray.tune.trial import ExportFormat import argparse import os @@ -285,6 +286,17 @@ class PytorchTrainable(tune.Trainable): self.config = new_config return True + def _export_model(self, export_formats, export_dir): + if export_formats == [ExportFormat.MODEL]: + path = os.path.join(export_dir, "exported_models") + torch.save({ + "netDmodel": self.netD.state_dict(), + "netGmodel": self.netG.state_dict() + }, path) + return {ExportFormat.MODEL: path} + else: + raise ValueError("unexpected formats: " + str(export_formats)) + # __Trainable_end__ @@ -343,6 +355,7 @@ if __name__ == "__main__": "training_iteration": tune_iter, }, num_samples=8, + export_formats=[ExportFormat.MODEL], config={ "netG_lr": tune.sample_from( lambda spec: random.choice([0.0001, 0.0002, 0.0005])), @@ -357,7 +370,7 @@ if __name__ == "__main__": img_list = 
[] fixed_noise = torch.randn(64, nz, 1, 1) for d in logdirs: - netG_path = d + "/checkpoint_" + str(tune_iter) + "/checkpoint" + netG_path = os.path.join(d, "exported_models") loadedG = Generator() loadedG.load_state_dict(torch.load(netG_path)["netGmodel"]) with torch.no_grad():