[Tune] Add example and tutorial for DCGAN (#6400)

This commit is contained in:
Yuhao Yang
2019-12-13 14:15:44 -08:00
committed by Richard Liaw
parent be5dd8eb5e
commit ad4da17899
11 changed files with 753 additions and 0 deletions
+8
View File
@@ -126,6 +126,14 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
python /ray/python/ray/tune/examples/pbt_memnn_example.py \
--smoke-test
$SUPPRESS_OUTPUT --force-direct docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_convnet_example.py \
--smoke-test
$SUPPRESS_OUTPUT --force-direct docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \
--smoke-test
# uncomment once statsmodels is updated.
# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
# python /ray/python/ray/tune/examples/bohb_example.py \
Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1011 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 89 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 216 KiB

+1
View File
@@ -246,6 +246,7 @@ Getting Involved
tune.rst
tune-tutorial.rst
tune-advanced-tutorial.rst
tune-usage.rst
tune-distributed.rst
tune-schedulers.rst
+259
View File
@@ -0,0 +1,259 @@
Tune Advanced Tutorials
=======================
In this page, we will explore some advanced functionality in Tune with more examples.
On this page:
.. contents::
:local:
:backlinks: none
A native example of Trainable
-----------------------------
As mentioned in `Tune User Guide <tune-usage.html#Tune Training API>`_, Training can be done
with either the `Trainable <tune-usage.html#trainable-api>`__ **Class API** or
**function-based API**. Comparably, ``Trainable`` is stateful, supports checkpoint/restore functionality,
and is preferable for advanced algorithms.
A naive example for ``Trainable`` is a simple number guesser:
.. code-block:: python
import ray
from ray import tune
from ray.tune import Trainable
class Guesser(Trainable):
def _setup(self, config):
self.config = config
self.password = 1024
def _train(self):
result_dict = {"diff": abs(self.config['guess'] - self.password)}
return result_dict
ray.init()
analysis = tune.run(
Guesser,
stop={
"training_iteration": 1,
},
num_samples=10,
config={
"guess": tune.randint(1, 10000)
})
print('best config: ', analysis.get_best_config(metric="diff", mode="min"))
The program randomly picks 10 number from [1, 10000) and finds which is closer to the password.
As a subclass of ``ray.tune.Trainable``, Tune will convert ``Guesser`` into a Ray actor, which
runs on a separate process on a worker. ``_setup`` function is invoked once for each Actor for custom
initialization.
``_train`` execute one logical iteration of training in the tuning process,
which may include several iterations of actual training (see the next example). As a rule of
thumb, the execution time of one train call should be large enough to avoid overheads
(i.e. more than a few seconds), but short enough to report progress periodically
(i.e. at most a few minutes).
We only implemented ``_setup`` and ``_train`` methods for simplification, usually it's also required
to implement ``_save``, and ``_restore`` for checkpoint and fault tolerance.
Next, we train a Pytorch convolution model with Trainable and PBT.
Trainable with Population Based Training (PBT)
----------------------------------------------
Tune includes a distributed implementation of `Population Based Training (PBT) <https://deepmind.com/blog/population-based-training-neural-networks>`__ as
a scheduler `PopulationBasedTraining <tune-schedulers.html#Population Based Training (PBT)>`__ .
PBT starts by training many neural networks in parallel with random hyperparameters. But instead of the
networks training independently, it uses information from the rest of the population to refine the
hyperparameters and direct computational resources to models which show promise.
.. image:: images/tune_advanced_paper1.png
This takes its inspiration from genetic algorithms where each member of the population
can exploit information from the remainder of the population. For example, a worker might
copy the model parameters from a better performing worker. It can also explore new hyperparameters by
changing the current values randomly.
As the training of the population of neural networks progresses, this process of exploiting and exploring
is performed periodically, ensuring that all the workers in the population have a good base level of performance
and also that new hyperparameters are consistently explored.
This means that PBT can quickly exploit good hyperparameters, can dedicate more training time to
promising models and, crucially, can adapt the hyperparameter values throughout training,
leading to automatic learning of the best configurations.
First we define a Trainable that wraps a ConvNet model.
.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py
:language: python
:start-after: __trainable_begin__
:end-before: __trainable_end__
The example reuses some of the functions in ray/tune/examples/mnist_pytorch.py, and is also a good
demo for how to decouple the tuning logic and original training code.
Here, we also override ``reset_config``. This method is optional but can be implemented to speed
up algorithms such as PBT, and to allow performance optimizations such as running experiments
with ``reuse_actors=True``.
Then, we define a PBT scheduler:
.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py
:language: python
:start-after: __pbt_begin__
:end-before: __pbt_end__
Some of the most important parameters are:
- ``hyperparam_mutations`` and ``custom_explore_fn`` are used to mutate the hyperparameters.
``hyperparam_mutations`` is a dictionary where each key/value pair specifies the candidates
or function for a hyperparameter. custom_explore_fn is applied after built-in perturbations
from hyperparam_mutations are applied, and should return config updated as needed.
- ``resample_probability``: The probability of resampling from the original distribution
when applying hyperparam_mutations. If not resampled, the value will be perturbed by a
factor of 1.2 or 0.8 if continuous, or changed to an adjacent value if discrete. Note that
``resample_probability`` by default is 0.25, thus hyperparameter with a distribution
may go out of the specific range.
Now we can kick off the tuning process by invoking tune.run:
.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py
:language: python
:start-after: __tune_begin__
:end-before: __tune_end__
During the training, we can constantly check the status of the models from console log:
.. code-block:: bash
== Status ==
Memory usage on this node: 10.4/16.0 GiB
PopulationBasedTraining: 4 checkpoints, 1 perturbs
Resources requested: 4/12 CPUs, 0/0 GPUs, 0.0/3.42 GiB heap, 0.0/1.17 GiB objects
Number of trials: 4 ({'RUNNING': 4})
Result logdir: /Users/yuhao.yang/ray_results/pbt_test
+--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+
| Trial name | status | loc | lr | momentum | iter | total time (s) | acc |
|--------------------------+----------+---------------------+----------+------------+--------+------------------+----------|
| PytorchTrainble_3b42d914 | RUNNING | 30.57.180.224:49840 | 0.122032 | 0.302176 | 18 | 3.8689 | 0.8875 |
| PytorchTrainble_3b45091e | RUNNING | 30.57.180.224:49835 | 0.505325 | 0.628559 | 18 | 3.90404 | 0.134375 |
| PytorchTrainble_3b454c46 | RUNNING | 30.57.180.224:49843 | 0.490228 | 0.969013 | 17 | 3.72111 | 0.0875 |
| PytorchTrainble_3b458a9c | RUNNING | 30.57.180.224:49833 | 0.961861 | 0.169701 | 13 | 2.72594 | 0.1125 |
+--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+
In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt
and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs:
[target trial tag, clone trial tag, target trial iteration, clone trial iteration,
old config, new config] on each perturbation step.
Checking the accuracy:
.. code-block:: python
# Plot by wall-clock time
dfs = analysis.fetch_trial_dataframes()
# This plots everything on the same plot
ax = None
for d in dfs.values():
ax = d.plot("training_iteration", "mean_accuracy", ax=ax, legend=False)
plt.xlabel("iterations")
plt.ylabel("Test Accuracy")
print('best config:', analysis.get_best_config("mean_accuracy"))
.. image:: images/tune_advanced_plot1.png
DCGAN with Trainable and PBT
----------------------------
The Generative Adversarial Networks (GAN) (Goodfellow et al., 2014) framework learns generative
models via a training paradigm consisting of two competing modules a generator and a
discriminator. GAN training can be remarkably brittle and unstable in the face of suboptimal
hyperparameter selection with generators often collapsing to a single mode or diverging entirely.
As presented in `Population Based Training (PBT) <https://deepmind.com/blog/population-based-training-neural-networks>`__,
PBT can help with the DCGAN training. We will now walk through how to do this in Tune.
Complete code example at `github <https://github.com/ray-project/ray/tree/master/python/ray/tune/examples/pbt_dcgan_mnist>`__
We define the Generator and Discriminator with standard Pytorch API:
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
:language: python
:start-after: __GANmodel_begin__
:end-before: __GANmodel_end__
To train the model with PBT, we need to define a metric for the scheduler to evaluate
the model candidates. For a GAN network, inception score is arguably the most
commonly used metric. We trained a mnist classification model (LeNet) and use
it to inference the generated images and evaluate the image quality.
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
:language: python
:start-after: __INCEPTION_SCORE_begin__
:end-before: __INCEPTION_SCORE_end__
The ``Trainable`` class includes a Generator and a Discriminator, each with an
independent learning rate and optimizer.
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
:language: python
:start-after: __Trainable_begin__
:end-before: __Trainable_end__
We specify inception score as the metric and start the tuning:
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
:language: python
:start-after: __tune_begin__
:end-before: __tune_end__
The trained Generator models can be loaded from checkpoints, and generate images
from noise signals.
.. image:: images/tune_advanced_dcgan_generated.gif
Visualize the increasing inception score from the training logs.
.. code-block:: python
lossG = [df['is_score'].tolist() for df in list(analysis.trial_dataframes.values())]
plt.figure(figsize=(10,5))
plt.title("Inception Score During Training")
for i, lossg in enumerate(lossG):
plt.plot(lossg,label=i)
plt.xlabel("iterations")
plt.ylabel("is_score")
plt.legend()
plt.show()
.. image:: images/tune_advanced_dcgan_inscore.png
And the Generator loss:
.. code-block:: python
lossG = [df['lossg'].tolist() for df in list(analysis.trial_dataframes.values())]
plt.figure(figsize=(10,5))
plt.title("Generator Loss During Training")
for i, lossg in enumerate(lossG):
plt.plot(lossg,label=i)
plt.xlabel("iterations")
plt.ylabel("LossG")
plt.legend()
plt.show()
.. image:: images/tune_advanced_dcgan_Gloss.png
Training of the MNist Generator takes about several minutes. The example can be easily
altered to generate images for other dataset, e.g. cifar10 or LSUN.
@@ -0,0 +1,108 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# __tutorial_imports_begin__
import argparse
import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
get_data_loaders
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.util import validate_save_restore
# __tutorial_imports_end__
# __trainable_begin__
class PytorchTrainble(tune.Trainable):
"""Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
scheduler. The example reuse some of the functions in mnist_pytorch,
and is a good demo for how to add the tuning function without
changing the original training code.
"""
def _setup(self, config):
self.train_loader, self.test_loader = get_data_loaders()
self.model = ConvNet()
self.optimizer = optim.SGD(
self.model.parameters(),
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))
def _train(self):
train(self.model, self.optimizer, self.train_loader)
acc = test(self.model, self.test_loader)
return {"mean_accuracy": acc}
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return checkpoint_path
def _restore(self, checkpoint_path):
self.model.load_state_dict(torch.load(checkpoint_path))
def reset_config(self, new_config):
del self.optimizer
self.optimizer = optim.SGD(
self.model.parameters(),
lr=new_config.get("lr", 0.01),
momentum=new_config.get("momentum", 0.9))
return True
# __trainable_end__
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
datasets.MNIST("~/data", train=True, download=True)
# check if PytorchTrainble will save/restore correctly before execution
validate_save_restore(PytorchTrainble)
validate_save_restore(PytorchTrainble, use_object_store=True)
print("Success!")
# __pbt_begin__
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
perturbation_interval=5,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: np.random.uniform(0.0001, 1),
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
# __pbt_end__
# __tune_begin__
analysis = tune.run(
PytorchTrainble,
name="pbt_test",
scheduler=scheduler,
reuse_actors=True,
verbose=1,
stop={
"training_iteration": 5 if args.smoke_test else 100,
},
num_samples=4,
config={
"lr": tune.uniform(0.001, 1),
"momentum": tune.uniform(0.001, 1),
})
# __tune_end__
@@ -0,0 +1,377 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torch.autograd import Variable
from torch.nn import functional as F
from scipy.stats import entropy
# Training parameters
dataroot = "/tmp/"
workers = 2
batch_size = 64
image_size = 32
# Number of channels in the training images. For color images this is 3
nc = 1
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 32
# Size of feature maps in discriminator
ndf = 32
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# iterations of actual training in each Trainable _train
train_iterations_per_step = 5
def get_data_loader():
dataset = dset.MNIST(
root=dataroot,
download=True,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, )),
]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
return dataloader
# __GANmodel_begin__
# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# Generator Code
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh())
def forward(self, input):
return self.main(input)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
def forward(self, input):
return self.main(input)
# __GANmodel_end__
# __INCEPTION_SCORE_begin__
class Net(nn.Module):
"""
LeNet for MNist classification, used for inception_score
"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def inception_score(imgs, batch_size=32, splits=1):
N = len(imgs)
dtype = torch.FloatTensor
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
cm = ray.get(mnist_model_ref)
up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
def get_pred(x):
x = up(x)
x = cm(x)
return F.softmax(x).data.cpu().numpy()
preds = np.zeros((N, 10))
for i, batch in enumerate(dataloader, 0):
batch = batch.type(dtype)
batchv = Variable(batch)
batch_size_i = batch.size()[0]
preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
# Now compute the mean kl-div
split_scores = []
for k in range(splits):
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
# __INCEPTION_SCORE_end__
def train(netD, netG, optimG, optimD, criterion, dataloader, iteration,
device):
real_label = 1
fake_label = 0
for i, data in enumerate(dataloader, 0):
if i >= train_iterations_per_step:
break
netD.zero_grad()
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size, ), real_label, device=device)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
D_x = output.mean().item()
noise = torch.randn(b_size, nz, 1, 1, device=device)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach()).view(-1)
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
optimD.step()
netG.zero_grad()
label.fill_(real_label)
output = netD(fake).view(-1)
errG = criterion(output, label)
errG.backward()
D_G_z2 = output.mean().item()
optimG.step()
is_score, is_std = inception_score(fake)
# Output training stats
if iteration % 10 == 0:
print("[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))"
": %.4f / %.4f \tInception score: %.4f" %
(iteration, len(dataloader), errD.item(), errG.item(), D_x,
D_G_z1, D_G_z2, is_score))
return errG.item(), errD.item(), is_score
# __Trainable_begin__
class PytorchTrainable(tune.Trainable):
def _setup(self, config):
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.netD = Discriminator().to(self.device)
self.netD.apply(weights_init)
self.netG = Generator().to(self.device)
self.netG.apply(weights_init)
self.criterion = nn.BCELoss()
self.optimizerD = optim.Adam(
self.netD.parameters(),
lr=config.get("lr", 0.01),
betas=(beta1, 0.999))
self.optimizerG = optim.Adam(
self.netG.parameters(),
lr=config.get("lr", 0.01),
betas=(beta1, 0.999))
self.dataloader = get_data_loader()
def _train(self):
lossG, lossD, is_score = train(
self.netD, self.netG, self.optimizerG, self.optimizerD,
self.criterion, self.dataloader, self._iteration, self.device)
return {"lossg": lossG, "lossd": lossD, "is_score": is_score}
def _save(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save({
"netDmodel": self.netD.state_dict(),
"netGmodel": self.netG.state_dict(),
"optimD": self.optimizerD.state_dict(),
"optimG": self.optimizerG.state_dict(),
}, path)
return checkpoint_dir
def _restore(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
checkpoint = torch.load(path)
self.netD.load_state_dict(checkpoint["netDmodel"])
self.netG.load_state_dict(checkpoint["netGmodel"])
self.optimizerD.load_state_dict(checkpoint["optimD"])
self.optimizerG.load_state_dict(checkpoint["optimG"])
def reset_config(self, new_config):
del self.optimizerD
del self.optimizerG
self.optimizerD = optim.Adam(
self.netD.parameters(),
lr=new_config.get("netD_lr"),
betas=(beta1, 0.999))
self.optimizerG = optim.Adam(
self.netG.parameters(),
lr=new_config.get("netG_lr"),
betas=(beta1, 0.999))
self.config = new_config
return True
# __Trainable_end__
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
dataloader = get_data_loader()
if not args.smoke_test:
# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Original Images")
plt.imshow(
np.transpose(
vutils.make_grid(
real_batch[0][:64], padding=2, normalize=True).cpu(),
(1, 2, 0)))
plt.show()
# load the pretrained mnist classification model for inception_score
mnist_cnn = Net()
model_path = os.path.join(
os.path.dirname(ray.__file__),
"tune/examples/pbt_dcgan_mnist/mnist_cnn.pt")
mnist_cnn.load_state_dict(torch.load(model_path))
mnist_cnn.eval()
mnist_model_ref = ray.put(mnist_cnn)
# __tune_begin__
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="is_score",
mode="max",
perturbation_interval=5,
hyperparam_mutations={
# distribution for resampling
"netG_lr": lambda: np.random.uniform(1e-2, 1e-5),
"netD_lr": lambda: np.random.uniform(1e-2, 1e-5),
})
tune_iter = 5 if args.smoke_test else 300
analysis = tune.run(
PytorchTrainable,
name="pbt_dcgan_mnist",
scheduler=scheduler,
reuse_actors=True,
verbose=1,
checkpoint_at_end=True,
stop={
"training_iteration": tune_iter,
},
num_samples=8,
config={
"netG_lr": tune.sample_from(
lambda spec: random.choice([0.0001, 0.0002, 0.0005])),
"netD_lr": tune.sample_from(
lambda spec: random.choice([0.0001, 0.0002, 0.0005]))
})
# __tune_end__
# demo of the trained Generators
if not args.smoke_test:
logdirs = analysis.dataframe()["logdir"].tolist()
img_list = []
fixed_noise = torch.randn(64, nz, 1, 1)
for d in logdirs:
netG_path = d + "/checkpoint_" + str(tune_iter) + "/checkpoint"
loadedG = Generator()
loadedG.load_state_dict(torch.load(netG_path)["netGmodel"])
with torch.no_grad():
fake = loadedG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)]
for i in img_list]
ani = animation.ArtistAnimation(
fig, ims, interval=1000, repeat_delay=1000, blit=True)
ani.save("./generated.gif", writer="imagemagick", dpi=72)
plt.show()