diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 5c63f9f7c..e488a6d99 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -126,6 +126,14 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} python /ray/python/ray/tune/examples/pbt_memnn_example.py \ --smoke-test +$SUPPRESS_OUTPUT --force-direct docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/tune/examples/pbt_convnet_example.py \ + --smoke-test + +$SUPPRESS_OUTPUT --force-direct docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \ + --smoke-test + # uncomment once statsmodels is updated. # $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ # python /ray/python/ray/tune/examples/bohb_example.py \ diff --git a/doc/source/images/tune_advanced_dcgan_Gloss.png b/doc/source/images/tune_advanced_dcgan_Gloss.png new file mode 100644 index 000000000..44d6c1dcb Binary files /dev/null and b/doc/source/images/tune_advanced_dcgan_Gloss.png differ diff --git a/doc/source/images/tune_advanced_dcgan_generated.gif b/doc/source/images/tune_advanced_dcgan_generated.gif new file mode 100644 index 000000000..c4ded5676 Binary files /dev/null and b/doc/source/images/tune_advanced_dcgan_generated.gif differ diff --git a/doc/source/images/tune_advanced_dcgan_inscore.png b/doc/source/images/tune_advanced_dcgan_inscore.png new file mode 100644 index 000000000..ddec535fd Binary files /dev/null and b/doc/source/images/tune_advanced_dcgan_inscore.png differ diff --git a/doc/source/images/tune_advanced_paper1.png b/doc/source/images/tune_advanced_paper1.png new file mode 100644 index 000000000..7c4dac654 Binary files /dev/null and b/doc/source/images/tune_advanced_paper1.png differ diff --git a/doc/source/images/tune_advanced_plot1.png 
b/doc/source/images/tune_advanced_plot1.png new file mode 100644 index 000000000..d54bde455 Binary files /dev/null and b/doc/source/images/tune_advanced_plot1.png differ diff --git a/doc/source/index.rst b/doc/source/index.rst index 051530930..76b4f645e 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -246,6 +246,7 @@ Getting Involved tune.rst tune-tutorial.rst + tune-advanced-tutorial.rst tune-usage.rst tune-distributed.rst tune-schedulers.rst diff --git a/doc/source/tune-advanced-tutorial.rst b/doc/source/tune-advanced-tutorial.rst new file mode 100644 index 000000000..f48050190 --- /dev/null +++ b/doc/source/tune-advanced-tutorial.rst @@ -0,0 +1,259 @@ +Tune Advanced Tutorials +======================= + +In this page, we will explore some advanced functionality in Tune with more examples. + +On this page: + +.. contents:: + :local: + :backlinks: none + +A native example of Trainable +----------------------------- +As mentioned in `Tune User Guide `_, training can be done +with either the `Trainable `__ **Class API** or +**function-based API**. Comparably, ``Trainable`` is stateful, supports checkpoint/restore functionality, +and is preferable for advanced algorithms. + +A naive example for ``Trainable`` is a simple number guesser: + +.. code-block:: python + + import ray + from ray import tune + from ray.tune import Trainable + + class Guesser(Trainable): + def _setup(self, config): + self.config = config + self.password = 1024 + + def _train(self): + result_dict = {"diff": abs(self.config['guess'] - self.password)} + return result_dict + + ray.init() + analysis = tune.run( + Guesser, + stop={ + "training_iteration": 1, + }, + num_samples=10, + config={ + "guess": tune.randint(1, 10000) + }) + + print('best config: ', analysis.get_best_config(metric="diff", mode="min")) + +The program randomly picks 10 numbers from [1, 10000) and finds which one is closest to the password.
+Because ``Guesser`` is a subclass of ``ray.tune.Trainable``, Tune will convert it into a Ray actor, which +runs on a separate process on a worker. The ``_setup`` function is invoked once for each actor for custom +initialization. + +``_train`` executes one logical iteration of training in the tuning process, +which may include several iterations of actual training (see the next example). As a rule of +thumb, the execution time of one train call should be large enough to avoid overheads +(i.e. more than a few seconds), but short enough to report progress periodically +(i.e. at most a few minutes). + +We only implemented the ``_setup`` and ``_train`` methods for simplicity; usually it is also required +to implement ``_save`` and ``_restore`` for checkpointing and fault tolerance. + +Next, we train a PyTorch convolutional model with Trainable and PBT. + +Trainable with Population Based Training (PBT) +---------------------------------------------- + +Tune includes a distributed implementation of `Population Based Training (PBT) `__ as +a scheduler `PopulationBasedTraining `__ . + +PBT starts by training many neural networks in parallel with random hyperparameters. But instead of the +networks training independently, it uses information from the rest of the population to refine the +hyperparameters and direct computational resources to models which show promise. + +.. image:: images/tune_advanced_paper1.png + +This takes its inspiration from genetic algorithms where each member of the population +can exploit information from the remainder of the population. For example, a worker might +copy the model parameters from a better performing worker. It can also explore new hyperparameters by +changing the current values randomly.
+ +As the training of the population of neural networks progresses, this process of exploiting and exploring +is performed periodically, ensuring that all the workers in the population have a good base level of performance +and also that new hyperparameters are consistently explored. + +This means that PBT can quickly exploit good hyperparameters, can dedicate more training time to +promising models and, crucially, can adapt the hyperparameter values throughout training, +leading to automatic learning of the best configurations. + +First we define a Trainable that wraps a ConvNet model. + +.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py + :language: python + :start-after: __trainable_begin__ + :end-before: __trainable_end__ + +The example reuses some of the functions in ray/tune/examples/mnist_pytorch.py, and is also a good +demo for how to decouple the tuning logic and original training code. + +Here, we also override ``reset_config``. This method is optional but can be implemented to speed +up algorithms such as PBT, and to allow performance optimizations such as running experiments +with ``reuse_actors=True``. + +Then, we define a PBT scheduler: + +.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py + :language: python + :start-after: __pbt_begin__ + :end-before: __pbt_end__ + +Some of the most important parameters are: + +- ``hyperparam_mutations`` and ``custom_explore_fn`` are used to mutate the hyperparameters. + ``hyperparam_mutations`` is a dictionary where each key/value pair specifies the candidates + or function for a hyperparameter. custom_explore_fn is applied after built-in perturbations + from hyperparam_mutations are applied, and should return config updated as needed. + +- ``resample_probability``: The probability of resampling from the original distribution + when applying hyperparam_mutations. 
If not resampled, the value will be perturbed by a + factor of 1.2 or 0.8 if continuous, or changed to an adjacent value if discrete. Note that + ``resample_probability`` is 0.25 by default, so most perturbations scale the current value instead of + resampling it; a hyperparameter with a distribution may therefore go out of its specified range. + +Now we can kick off the tuning process by invoking tune.run: + +.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py + :language: python + :start-after: __tune_begin__ + :end-before: __tune_end__ + +During the training, we can continuously check the status of the trials in the console log: + +.. code-block:: bash + + == Status == + Memory usage on this node: 10.4/16.0 GiB + PopulationBasedTraining: 4 checkpoints, 1 perturbs + Resources requested: 4/12 CPUs, 0/0 GPUs, 0.0/3.42 GiB heap, 0.0/1.17 GiB objects + Number of trials: 4 ({'RUNNING': 4}) + Result logdir: /Users/yuhao.yang/ray_results/pbt_test + +--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+ + | Trial name | status | loc | lr | momentum | iter | total time (s) | acc | + |--------------------------+----------+---------------------+----------+------------+--------+------------------+----------| + | PytorchTrainble_3b42d914 | RUNNING | 30.57.180.224:49840 | 0.122032 | 0.302176 | 18 | 3.8689 | 0.8875 | + | PytorchTrainble_3b45091e | RUNNING | 30.57.180.224:49835 | 0.505325 | 0.628559 | 18 | 3.90404 | 0.134375 | + | PytorchTrainble_3b454c46 | RUNNING | 30.57.180.224:49843 | 0.490228 | 0.969013 | 17 | 3.72111 | 0.0875 | + | PytorchTrainble_3b458a9c | RUNNING | 30.57.180.224:49833 | 0.961861 | 0.169701 | 13 | 2.72594 | 0.1125 | + +--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+ + +In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt +and individual policy perturbations are recorded in pbt_policy_{i}.txt.
Tune logs: +[target trial tag, clone trial tag, target trial iteration, clone trial iteration, +old config, new config] on each perturbation step. + +Checking the accuracy: + +.. code-block:: python + + # Plot by training iteration + dfs = analysis.fetch_trial_dataframes() + # This plots everything on the same plot + ax = None + for d in dfs.values(): + ax = d.plot("training_iteration", "mean_accuracy", ax=ax, legend=False) + + plt.xlabel("iterations") + plt.ylabel("Test Accuracy") + + print('best config:', analysis.get_best_config("mean_accuracy")) + +.. image:: images/tune_advanced_plot1.png + +DCGAN with Trainable and PBT +---------------------------- + +The Generative Adversarial Networks (GAN) (Goodfellow et al., 2014) framework learns generative +models via a training paradigm consisting of two competing modules – a generator and a +discriminator. GAN training can be remarkably brittle and unstable in the face of suboptimal +hyperparameter selection with generators often collapsing to a single mode or diverging entirely. + +As presented in `Population Based Training (PBT) `__, +PBT can help with the DCGAN training. We will now walk through how to do this in Tune. +Complete code example at `github `__ + +We define the Generator and Discriminator with the standard PyTorch API: + +.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py + :language: python + :start-after: __GANmodel_begin__ + :end-before: __GANmodel_end__ + +To train the model with PBT, we need to define a metric for the scheduler to evaluate +the model candidates. For a GAN, inception score is arguably the most +commonly used metric. We trained an MNIST classification model (LeNet) and use +it to run inference on the generated images and evaluate the image quality. + +..
literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py + :language: python + :start-after: __INCEPTION_SCORE_begin__ + :end-before: __INCEPTION_SCORE_end__ + +The ``Trainable`` class includes a Generator and a Discriminator, each with an +independent learning rate and optimizer. + +.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py + :language: python + :start-after: __Trainable_begin__ + :end-before: __Trainable_end__ + +We specify inception score as the metric and start the tuning: + +.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py + :language: python + :start-after: __tune_begin__ + :end-before: __tune_end__ + +The trained Generator models can be loaded from checkpoints, and generate images +from noise signals. + +.. image:: images/tune_advanced_dcgan_generated.gif + +Visualize the increasing inception score from the training logs. + +.. code-block:: python + + lossG = [df['is_score'].tolist() for df in list(analysis.trial_dataframes.values())] + + plt.figure(figsize=(10,5)) + plt.title("Inception Score During Training") + for i, lossg in enumerate(lossG): + plt.plot(lossg,label=i) + + plt.xlabel("iterations") + plt.ylabel("is_score") + plt.legend() + plt.show() + +.. image:: images/tune_advanced_dcgan_inscore.png + +And the Generator loss: + +.. code-block:: python + + lossG = [df['lossg'].tolist() for df in list(analysis.trial_dataframes.values())] + + plt.figure(figsize=(10,5)) + plt.title("Generator Loss During Training") + for i, lossg in enumerate(lossG): + plt.plot(lossg,label=i) + + plt.xlabel("iterations") + plt.ylabel("LossG") + plt.legend() + plt.show() + +.. image:: images/tune_advanced_dcgan_Gloss.png + +Training of the MNIST Generator takes several minutes. The example can be easily +altered to generate images for other datasets, e.g. cifar10 or LSUN.
diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py
new file mode 100644
index 000000000..cfe62c120
--- /dev/null
+++ b/python/ray/tune/examples/pbt_convnet_example.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# __tutorial_imports_begin__
+import argparse
+import os
+import numpy as np
+import torch
+import torch.optim as optim
+from torchvision import datasets
+from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
+    get_data_loaders
+
+import ray
+from ray import tune
+from ray.tune.schedulers import PopulationBasedTraining
+from ray.tune.util import validate_save_restore
+
+# __tutorial_imports_end__
+
+
+# __trainable_begin__
+class PytorchTrainble(tune.Trainable):
+    """Train a PyTorch ConvNet with Trainable and PopulationBasedTraining.
+
+    The example reuses the model, data loaders and train/test helpers from
+    mnist_pytorch, and shows how to add tuning logic without changing the
+    original training code.
+    """
+
+    def _setup(self, config):
+        """Create the data loaders, the model and its SGD optimizer."""
+        self.train_loader, self.test_loader = get_data_loaders()
+        self.model = ConvNet()
+        self.optimizer = optim.SGD(
+            self.model.parameters(),
+            lr=config.get("lr", 0.01),
+            momentum=config.get("momentum", 0.9))
+
+    def _train(self):
+        """Call the shared train/test helpers once; report the accuracy."""
+        train(self.model, self.optimizer, self.train_loader)
+        acc = test(self.model, self.test_loader)
+        return {"mean_accuracy": acc}
+
+    def _save(self, checkpoint_dir):
+        """Save the model weights and return the checkpoint file path."""
+        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
+        torch.save(self.model.state_dict(), checkpoint_path)
+        return checkpoint_path
+
+    def _restore(self, checkpoint_path):
+        """Load the model weights written by _save."""
+        self.model.load_state_dict(torch.load(checkpoint_path))
+
+    def reset_config(self, new_config):
+        """Rebuild the optimizer with the hyperparameters in new_config."""
+        del self.optimizer
+        self.optimizer = optim.SGD(
+            self.model.parameters(),
+            lr=new_config.get("lr", 0.01),
+            momentum=new_config.get("momentum", 0.9))
+        # Keep self.config in sync with the hyperparameters now in use.
+        self.config = new_config
+        return True
+
+
+# __trainable_end__
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test", action="store_true", help="Finish quickly for testing")
+    args, _ = parser.parse_known_args()
+
+    ray.init()
+    datasets.MNIST("~/data", train=True, download=True)
+
+    # check if PytorchTrainble will save/restore correctly before execution
+    validate_save_restore(PytorchTrainble)
+    validate_save_restore(PytorchTrainble, use_object_store=True)
+    print("Success!")
+
+    # __pbt_begin__
+    scheduler = PopulationBasedTraining(
+        time_attr="training_iteration",
+        metric="mean_accuracy",
+        mode="max",
+        perturbation_interval=5,
+        hyperparam_mutations={
+            # distribution for resampling
+            "lr": lambda: np.random.uniform(0.0001, 1),
+            # allow perturbations within this set of categorical values
+            "momentum": [0.8, 0.9, 0.99],
+        })
+    # __pbt_end__
+
+    # __tune_begin__
+    analysis = tune.run(
+        PytorchTrainble,
+        name="pbt_test",
+        scheduler=scheduler,
+        reuse_actors=True,
+        verbose=1,
+        stop={
+            "training_iteration": 5 if args.smoke_test else 100,
+        },
+        num_samples=4,
+        config={
+            "lr": tune.uniform(0.001, 1),
+            "momentum": tune.uniform(0.001, 1),
+        })
+    # __tune_end__
diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/mnist_cnn.pt b/python/ray/tune/examples/pbt_dcgan_mnist/mnist_cnn.pt
new file mode 100644
index 000000000..1c4364e16
Binary files /dev/null and b/python/ray/tune/examples/pbt_dcgan_mnist/mnist_cnn.pt differ
diff --git a/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
new file mode 100644
index 000000000..83862a2ab
--- /dev/null
+++ b/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
@@ -0,0 +1,412 @@
+#!/usr/bin/env python
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ray
+from ray import tune
+from ray.tune.schedulers import PopulationBasedTraining
+
+import argparse
+import os
+import random
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.optim as optim
+import torch.utils.data
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+
+from torch.autograd import Variable
+from torch.nn import functional as F
+from scipy.stats import entropy
+
+# Training parameters
+dataroot = "/tmp/"
+workers = 2
+batch_size = 64
+image_size = 32
+
+# Number of channels in the training images. For color images this is 3
+nc = 1
+
+# Size of z latent vector (i.e. size of generator input)
+nz = 100
+
+# Size of feature maps in generator
+ngf = 32
+
+# Size of feature maps in discriminator
+ndf = 32
+
+# Beta1 hyperparam for Adam optimizers
+beta1 = 0.5
+
+# iterations of actual training in each Trainable _train
+train_iterations_per_step = 5
+
+
+def get_data_loader():
+    """Return a shuffled DataLoader over MNIST resized to image_size."""
+    dataset = dset.MNIST(
+        root=dataroot,
+        download=True,
+        transform=transforms.Compose([
+            transforms.Resize(image_size),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, ), (0.5, )),
+        ]))
+
+    # Create the dataloader
+    dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
+
+    return dataloader
+
+
+# __GANmodel_begin__
+# custom weights initialization called on netG and netD
+def weights_init(m):
+    """Re-initialize the weights of Conv and BatchNorm layers in place."""
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find("BatchNorm") != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
+# Generator Code
+class Generator(nn.Module):
+    """Map a (N, nz, 1, 1) noise batch to (N, nc, 32, 32) images."""
+
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.main = nn.Sequential(
+            # input is Z, going into a convolution
+            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(ngf * 4),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 2),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+            nn.Tanh())
+
+    def forward(self, input):
+        return self.main(input)
+
+
+class Discriminator(nn.Module):
+    """Map (N, nc, 32, 32) images to (N, 1, 1, 1) real/fake probabilities."""
+
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        self.main = nn.Sequential(
+            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
+
+    def forward(self, input):
+        return self.main(input)
+
+
+# __GANmodel_end__
+
+
+# __INCEPTION_SCORE_begin__
+class Net(nn.Module):
+    """
+    LeNet for MNist classification, used for inception_score
+    """
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+def inception_score(imgs, batch_size=32, splits=1):
+    """Return the mean and std of the per-split inception scores of ``imgs``.
+
+    Predictions come from the pretrained MNIST classifier fetched from the
+    module-level ``mnist_model_ref``, which is only defined when this file
+    runs as a script.
+    """
+    N = len(imgs)
+    dtype = torch.FloatTensor
+    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
+    cm = ray.get(mnist_model_ref)
+    up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
+
+    def get_pred(x):
+        x = up(x)
+        x = cm(x)
+        # Net returns per-class log-probabilities; softmax over the class
+        # dimension (dim=1) turns them back into probabilities.
+        return F.softmax(x, dim=1).data.cpu().numpy()
+
+    preds = np.zeros((N, 10))
+    for i, batch in enumerate(dataloader, 0):
+        batch = batch.type(dtype)
+        batchv = Variable(batch)
+        batch_size_i = batch.size()[0]
+        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
+
+    # Now compute the mean kl-div
+    split_scores = []
+    for k in range(splits):
+        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
+        py = np.mean(part, axis=0)
+        scores = []
+        for i in range(part.shape[0]):
+            pyx = part[i, :]
+            scores.append(entropy(pyx, py))
+        split_scores.append(np.exp(np.mean(scores)))
+
+    return np.mean(split_scores), np.std(split_scores)
+
+
+# __INCEPTION_SCORE_end__
+
+
+def train(netD, netG, optimG, optimD, criterion, dataloader, iteration,
+          device):
+    """Run up to ``train_iterations_per_step`` GAN updates.
+
+    Each update trains netD on a real and a generated batch, then trains
+    netG, and computes the inception score of the generated batch.
+
+    Returns (generator loss, discriminator loss, inception score) of the
+    last update.
+    """
+    real_label = 1
+    fake_label = 0
+
+    for i, data in enumerate(dataloader, 0):
+        if i >= train_iterations_per_step:
+            break
+
+        netD.zero_grad()
+        real_cpu = data[0].to(device)
+        b_size = real_cpu.size(0)
+        # Create float labels explicitly: BCELoss needs float targets, and
+        # newer PyTorch infers an integer dtype from an integer fill value.
+        label = torch.full(
+            (b_size, ), real_label, dtype=torch.float, device=device)
+        output = netD(real_cpu).view(-1)
+        errD_real = criterion(output, label)
+        errD_real.backward()
+        D_x = output.mean().item()
+
+        noise = torch.randn(b_size, nz, 1, 1, device=device)
+        fake = netG(noise)
+        label.fill_(fake_label)
+        output = netD(fake.detach()).view(-1)
+        errD_fake = criterion(output, label)
+        errD_fake.backward()
+        D_G_z1 = output.mean().item()
+        errD = errD_real + errD_fake
+        optimD.step()
+
+        netG.zero_grad()
+        label.fill_(real_label)
+        output = netD(fake).view(-1)
+        errG = criterion(output, label)
+        errG.backward()
+        D_G_z2 = output.mean().item()
+        optimG.step()
+
+        is_score, is_std = inception_score(fake)
+
+        # Output training stats
+        if iteration % 10 == 0:
+            print("[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))"
+                  ": %.4f / %.4f \tInception score: %.4f" %
+                  (iteration, len(dataloader), errD.item(), errG.item(), D_x,
+                   D_G_z1, D_G_z2, is_score))
+
+    return errG.item(), errD.item(), is_score
+
+
+# __Trainable_begin__
+class PytorchTrainable(tune.Trainable):
+    """Trainable that trains a DCGAN Generator/Discriminator pair."""
+
+    def _setup(self, config):
+        """Build both networks, their Adam optimizers and the data loader."""
+        use_cuda = config.get("use_gpu") and torch.cuda.is_available()
+        self.device = torch.device("cuda" if use_cuda else "cpu")
+        self.netD = Discriminator().to(self.device)
+        self.netD.apply(weights_init)
+        self.netG = Generator().to(self.device)
+        self.netG.apply(weights_init)
+        self.criterion = nn.BCELoss()
+        # Use the per-network learning rates that tune.run and PBT set; the
+        # shared "lr" key is kept only as a backward-compatible fallback.
+        self.optimizerD = optim.Adam(
+            self.netD.parameters(),
+            lr=config.get("netD_lr", config.get("lr", 0.01)),
+            betas=(beta1, 0.999))
+        self.optimizerG = optim.Adam(
+            self.netG.parameters(),
+            lr=config.get("netG_lr", config.get("lr", 0.01)),
+            betas=(beta1, 0.999))
+        self.dataloader = get_data_loader()
+
+    def _train(self):
+        """Run one training step and report losses and inception score."""
+        lossG, lossD, is_score = train(
+            self.netD, self.netG, self.optimizerG, self.optimizerD,
+            self.criterion, self.dataloader, self._iteration, self.device)
+        return {"lossg": lossG, "lossd": lossD, "is_score": is_score}
+
+    def _save(self, checkpoint_dir):
+        """Save model and optimizer states; return the checkpoint directory."""
+        path = os.path.join(checkpoint_dir, "checkpoint")
+        torch.save({
+            "netDmodel": self.netD.state_dict(),
+            "netGmodel": self.netG.state_dict(),
+            "optimD": self.optimizerD.state_dict(),
+            "optimG": self.optimizerG.state_dict(),
+        }, path)
+
+        return checkpoint_dir
+
+    def _restore(self, checkpoint_dir):
+        """Load the model and optimizer states written by _save."""
+        path = os.path.join(checkpoint_dir, "checkpoint")
+        checkpoint = torch.load(path)
+        self.netD.load_state_dict(checkpoint["netDmodel"])
+        self.netG.load_state_dict(checkpoint["netGmodel"])
+        self.optimizerD.load_state_dict(checkpoint["optimD"])
+        self.optimizerG.load_state_dict(checkpoint["optimG"])
+
+    def reset_config(self, new_config):
+        """Rebuild both optimizers with the learning rates in new_config."""
+        del self.optimizerD
+        del self.optimizerG
+        self.optimizerD = optim.Adam(
+            self.netD.parameters(),
+            lr=new_config.get("netD_lr", 0.01),
+            betas=(beta1, 0.999))
+        self.optimizerG = optim.Adam(
+            self.netG.parameters(),
+            lr=new_config.get("netG_lr", 0.01),
+            betas=(beta1, 0.999))
+        self.config = new_config
+        return True
+
+
+# __Trainable_end__
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test", action="store_true", help="Finish quickly for testing")
+    args, _ = parser.parse_known_args()
+    ray.init()
+
+    dataloader = get_data_loader()
+    if not args.smoke_test:
+        # Plot some training images
+        real_batch = next(iter(dataloader))
+        plt.figure(figsize=(8, 8))
+        plt.axis("off")
+        plt.title("Original Images")
+        plt.imshow(
+            np.transpose(
+                vutils.make_grid(
+                    real_batch[0][:64], padding=2, normalize=True).cpu(),
+                (1, 2, 0)))
+
+        plt.show()
+
+    # load the pretrained mnist classification model for inception_score
+    mnist_cnn = Net()
+    model_path = os.path.join(
+        os.path.dirname(ray.__file__),
+        "tune/examples/pbt_dcgan_mnist/mnist_cnn.pt")
+    mnist_cnn.load_state_dict(torch.load(model_path))
+    mnist_cnn.eval()
+    mnist_model_ref = ray.put(mnist_cnn)
+
+    # __tune_begin__
+    scheduler = PopulationBasedTraining(
+        time_attr="training_iteration",
+        metric="is_score",
+        mode="max",
+        perturbation_interval=5,
+        hyperparam_mutations={
+            # distribution for resampling; uniform() takes (low, high)
+            "netG_lr": lambda: np.random.uniform(1e-5, 1e-2),
+            "netD_lr": lambda: np.random.uniform(1e-5, 1e-2),
+        })
+
+    tune_iter = 5 if args.smoke_test else 300
+    analysis = tune.run(
+        PytorchTrainable,
+        name="pbt_dcgan_mnist",
+        scheduler=scheduler,
+        reuse_actors=True,
+        verbose=1,
+        checkpoint_at_end=True,
+        stop={
+            "training_iteration": tune_iter,
+        },
+        num_samples=8,
+        config={
+            "netG_lr": tune.sample_from(
+                lambda spec: random.choice([0.0001, 0.0002, 0.0005])),
+            "netD_lr": tune.sample_from(
+                lambda spec: random.choice([0.0001, 0.0002, 0.0005]))
+        })
+    # __tune_end__
+
+    # demo of the trained Generators
+    if not args.smoke_test:
+        logdirs = analysis.dataframe()["logdir"].tolist()
+        img_list = []
+        fixed_noise = torch.randn(64, nz, 1, 1)
+        for d in logdirs:
+            netG_path = os.path.join(
+                d, "checkpoint_" + str(tune_iter), "checkpoint")
+            loadedG = Generator()
+            loadedG.load_state_dict(torch.load(netG_path)["netGmodel"])
+            with torch.no_grad():
+                fake = loadedG(fixed_noise).detach().cpu()
+            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
+
+        fig = plt.figure(figsize=(8, 8))
+        plt.axis("off")
+        ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)]
+               for i in img_list]
+        ani = animation.ArtistAnimation(
+            fig, ims, interval=1000, repeat_delay=1000, blit=True)
+        ani.save("./generated.gif", writer="imagemagick", dpi=72)
+        plt.show()