[Tune] Add example and tutorial for DCGAN (#6400)

This commit is contained in:
Yuhao Yang
2019-12-13 14:15:44 -08:00
committed by Richard Liaw
parent be5dd8eb5e
commit ad4da17899
11 changed files with 753 additions and 0 deletions
@@ -0,0 +1,108 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# __tutorial_imports_begin__
import argparse
import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
get_data_loaders
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.util import validate_save_restore
# __tutorial_imports_end__
# __trainable_begin__
class PytorchTrainble(tune.Trainable):
    """Tune ``Trainable`` that fits the MNIST ``ConvNet`` with SGD.

    Meant to be driven by the ``PopulationBasedTraining`` scheduler. The data
    loaders, the model and the train/test loops are all imported from
    ``mnist_pytorch``, which shows how tuning can be layered on top of
    existing training code without modifying it.
    """

    def _make_sgd(self, hparams):
        """Return an SGD optimizer over ``self.model`` for ``hparams``.

        Missing keys fall back to ``lr=0.01`` and ``momentum=0.9``.
        """
        learning_rate = hparams.get("lr", 0.01)
        momentum = hparams.get("momentum", 0.9)
        return optim.SGD(
            self.model.parameters(), lr=learning_rate, momentum=momentum)

    def _setup(self, config):
        """Create the data loaders, the model and the optimizer."""
        loaders = get_data_loaders()
        self.train_loader, self.test_loader = loaders
        self.model = ConvNet()
        self.optimizer = self._make_sgd(config)

    def _train(self):
        """Run one train pass, then report the test accuracy."""
        train(self.model, self.optimizer, self.train_loader)
        return {"mean_accuracy": test(self.model, self.test_loader)}

    def _save(self, checkpoint_dir):
        """Write the model weights to ``model.pth`` and return that path.

        Only the model ``state_dict`` is stored; optimizer state is not.
        """
        target = os.path.join(checkpoint_dir, "model.pth")
        weights = self.model.state_dict()
        torch.save(weights, target)
        return target

    def _restore(self, checkpoint_path):
        """Load model weights from the file written by ``_save``."""
        weights = torch.load(checkpoint_path)
        self.model.load_state_dict(weights)

    def reset_config(self, new_config):
        """Swap in a fresh SGD optimizer built from ``new_config``.

        Returns ``True`` to signal that the config was applied in place.
        """
        del self.optimizer
        self.optimizer = self._make_sgd(new_config)
        return True
# __trainable_end__
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # parse_known_args tolerates flags this script does not declare.
    args, _ = cli.parse_known_args()
    ray.init()
    # Fetch the MNIST training split into ~/data before any trial starts.
    datasets.MNIST("~/data", train=True, download=True)

    # check if PytorchTrainble will save/restore correctly before execution
    for extra_kwargs in ({}, {"use_object_store": True}):
        validate_save_restore(PytorchTrainble, **extra_kwargs)
    print("Success!")

    # __pbt_begin__
    mutations = {
        # distribution for resampling
        "lr": lambda: np.random.uniform(0.0001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations=mutations)
    # __pbt_end__

    # __tune_begin__
    # A smoke test stops after 5 iterations instead of the full 100.
    max_iterations = 5 if args.smoke_test else 100
    initial_config = {
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
    }
    analysis = tune.run(
        PytorchTrainble,
        name="pbt_test",
        scheduler=scheduler,
        reuse_actors=True,
        verbose=1,
        stop={"training_iteration": max_iterations},
        num_samples=4,
        config=initial_config)
# __tune_end__
@@ -0,0 +1,377 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torch.autograd import Variable
from torch.nn import functional as F
from scipy.stats import entropy
# Training parameters (module-level constants shared by the functions below)
# Directory passed to the MNIST dataset as its ``root``
dataroot = "/tmp/"
# ``num_workers`` value for the training DataLoader
workers = 2
# Batch size of the training DataLoader
batch_size = 64
# Size that training images are resized to by ``transforms.Resize``
image_size = 32
# Number of channels in the training images. For color images this is 3.
# Used as the generator's output channels and the discriminator's input
# channels.
nc = 1
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 32
# Size of feature maps in discriminator
ndf = 32
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Maximum number of batches processed by each call to ``train`` (i.e. by each
# Trainable ``_train`` step)
train_iterations_per_step = 5
def get_data_loader():
    """Return a shuffled ``DataLoader`` over MNIST.

    The dataset is rooted at ``dataroot`` and downloaded if needed. Each image
    is resized to ``image_size``, converted to a tensor and normalized with
    mean 0.5 and std 0.5. Batching uses ``batch_size`` and ``workers``.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])
    mnist = dset.MNIST(root=dataroot, download=True, transform=preprocessing)
    # Create the dataloader
    return torch.utils.data.DataLoader(
        mnist, batch_size=batch_size, shuffle=True, num_workers=workers)
# __GANmodel_begin__
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialize module ``m`` in place, chosen by its class name.

    Classes whose name contains "Conv" get weights drawn from N(0, 0.02).
    Otherwise, classes whose name contains "BatchNorm" get weights drawn
    from N(1, 0.02) and a zero bias. Any other module is left untouched.
    Intended to be passed to ``nn.Module.apply``.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: a stack of bias-free transposed convolutions.

    Three ConvTranspose2d -> BatchNorm2d -> ReLU stages narrow the channel
    count from ``nz`` to ``ngf``. A final transposed convolution maps to
    ``nc`` channels and is followed by ``Tanh``. With the (N, nz, 1, 1) noise
    built in ``train``, the layer arithmetic gives a (N, nc, 32, 32) output.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def up_block(c_in, c_out, stride, padding):
            # One upsampling stage: 4x4 transposed conv, batch norm, ReLU.
            return [
                nn.ConvTranspose2d(
                    c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        # input is Z, going into a convolution
        layers = up_block(nz, ngf * 4, 1, 0)
        layers += up_block(ngf * 4, ngf * 2, 2, 1)
        layers += up_block(ngf * 2, ngf, 2, 1)
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        # Kept as a single Sequential named ``main`` so state-dict keys stay
        # of the form ``main.<index>.*``.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a latent batch to generated images."""
        generated = self.main(input)
        return generated
class Discriminator(nn.Module):
    """DCGAN discriminator: strided, bias-free convolutions ending in Sigmoid.

    The first convolution (``nc`` -> ``ndf``) is followed only by LeakyReLU.
    The next two double the channel count and add BatchNorm2d before the
    LeakyReLU. A final 4x4 convolution reduces to one channel, which Sigmoid
    squashes into (0, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def leaky():
            # Fresh activation instance per use, as in a literal layer list.
            return nn.LeakyReLU(0.2, inplace=True)

        layers = [nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), leaky()]
        for width in (ndf, ndf * 2):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                leaky(),
            ]
        layers += [nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        # Kept as a single Sequential named ``main`` so state-dict keys stay
        # of the form ``main.<index>.*``.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the discriminator's output for an image batch."""
        verdict = self.main(input)
        return verdict
# __GANmodel_end__
# __INCEPTION_SCORE_begin__
class Net(nn.Module):
    """LeNet-style MNIST classifier, used for inception_score.

    Two 5x5 convolutions (1 -> 10 -> 20 channels), each followed by 2x2 max
    pooling and ReLU, feed a 320 -> 50 -> 10 fully connected head. ``forward``
    returns log-probabilities over the 10 classes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        # Flatten to the 320 features that fc1 expects.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout must be told explicitly whether we are training.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def inception_score(imgs, batch_size=32, splits=1):
    """Compute an inception-style score for a set of generated images.

    Every image is resized to 28x28 and classified by the model stored in the
    Ray object store under the module-level ``mnist_model_ref``. For each of
    ``splits`` consecutive chunks of ``N // splits`` predictions, the score is
    ``exp(mean_i KL(p(y|x_i) || p(y)))``, where ``p(y)`` is the mean
    prediction of that chunk.

    Args:
        imgs: Indexable collection of images with a ``len``, consumable by
            ``torch.utils.data.DataLoader`` (``train`` passes the generator's
            output tensor).
        batch_size: Number of images classified per forward pass.
        splits: Number of chunks the predictions are divided into.

    Returns:
        Tuple ``(mean, std)`` of the per-split scores.
    """
    N = len(imgs)
    dtype = torch.FloatTensor
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
    # The classifier handle is created by the script's entry point.
    cm = ray.get(mnist_model_ref)
    # align_corners=False is the library default; stating it keeps the result
    # unchanged and avoids the "default behavior changed" warning.
    up = nn.Upsample(
        size=(28, 28), mode="bilinear", align_corners=False).type(dtype)

    def get_pred(x):
        x = up(x)
        x = cm(x)
        # Normalize over the class dimension explicitly; the implicit-dim
        # form of softmax is deprecated. If ``cm`` returns log-probabilities
        # (as ``Net`` in this file does), this recovers the probabilities.
        return F.softmax(x, dim=1).detach().cpu().numpy()

    preds = np.zeros((N, 10))
    # Scoring is evaluation only: skip autograd so no graph is recorded
    # through the classifier (``imgs`` may still be attached to a graph).
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = batch.type(dtype)
            start = i * batch_size
            preds[start:start + batch.size(0)] = get_pred(batch)

    # Now compute the mean kl-div
    split_size = N // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * split_size:(k + 1) * split_size, :]
        py = np.mean(part, axis=0)
        # scipy's entropy(pk, qk) is the KL divergence KL(pk || qk).
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))
    return np.mean(split_scores), np.std(split_scores)
# __INCEPTION_SCORE_end__
def train(netD, netG, optimG, optimD, criterion, dataloader, iteration,
          device):
    """Run up to ``train_iterations_per_step`` GAN updates.

    Each batch performs one discriminator update (real batch, then a detached
    fake batch) followed by one generator update, and then scores the fake
    batch with ``inception_score``.

    Args:
        netD: Discriminator network.
        netG: Generator network.
        optimG: Optimizer stepping ``netG``.
        optimD: Optimizer stepping ``netD``.
        criterion: Loss applied to discriminator outputs and labels.
        dataloader: Iterable of batches; element 0 of each batch is the
            image tensor.
        iteration: Caller's iteration counter; stats are printed only when it
            is a multiple of 10.
        device: Device the batches, labels and noise are placed on.

    Returns:
        Tuple ``(generator loss, discriminator loss, inception score)`` taken
        from the last processed batch.

    Raises:
        UnboundLocalError: If ``dataloader`` yields no batch, since the
            returned values are only assigned inside the loop.
    """
    real_label = 1
    fake_label = 0
    for i, data in enumerate(dataloader, 0):
        if i >= train_iterations_per_step:
            break

        # Discriminator update, part 1: real images labelled as real.
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # Force a float label tensor: with an int fill value torch.full
        # infers an integer dtype on recent PyTorch, which the loss rejects.
        label = torch.full(
            (b_size, ), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Discriminator update, part 2: generated images labelled as fake.
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass out of the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimD.step()

        # Generator update: it is rewarded when D labels its fakes as real.
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimG.step()

        is_score, _ = inception_score(fake)

        # Output training stats
        if iteration % 10 == 0:
            print("[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))"
                  ": %.4f / %.4f \tInception score: %.4f" %
                  (iteration, len(dataloader), errD.item(), errG.item(), D_x,
                   D_G_z1, D_G_z2, is_score))

    return errG.item(), errD.item(), is_score
# __Trainable_begin__
class PytorchTrainable(tune.Trainable):
    """Tune ``Trainable`` that trains the DCGAN generator and discriminator.

    Learning rates come from the config keys ``netD_lr`` and ``netG_lr``.
    When one is absent, the legacy ``lr`` key and then 0.01 are used.
    """

    _DEFAULT_LR = 0.01

    def _lr_from(self, config, key):
        """Return the learning rate stored under ``key``, with fallbacks."""
        return config.get(key, config.get("lr", self._DEFAULT_LR))

    def _make_adam(self, net, lr):
        """Return an Adam optimizer over ``net`` using the shared betas."""
        return optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))

    def _apply_config_lrs(self, config):
        """Overwrite optimizer learning rates with those named in ``config``.

        Optimizers whose key is missing from ``config`` are left unchanged.
        """
        pairs = ((self.optimizerD, "netD_lr"), (self.optimizerG, "netG_lr"))
        for optimizer, key in pairs:
            if key in config:
                for group in optimizer.param_groups:
                    group["lr"] = config[key]

    def _setup(self, config):
        """Build both networks, the loss, both optimizers and the data."""
        use_cuda = config.get("use_gpu") and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.netD = Discriminator().to(self.device)
        self.netD.apply(weights_init)
        self.netG = Generator().to(self.device)
        self.netG.apply(weights_init)
        self.criterion = nn.BCELoss()
        # Read the same keys that reset_config uses, so the learning rates
        # sampled for the trial take effect from the first step.
        self.optimizerD = self._make_adam(
            self.netD, self._lr_from(config, "netD_lr"))
        self.optimizerG = self._make_adam(
            self.netG, self._lr_from(config, "netG_lr"))
        self.dataloader = get_data_loader()

    def _train(self):
        """Run one ``train`` step and report both losses and the score."""
        lossG, lossD, is_score = train(
            self.netD, self.netG, self.optimizerG, self.optimizerD,
            self.criterion, self.dataloader, self._iteration, self.device)
        return {"lossg": lossG, "lossd": lossD, "is_score": is_score}

    def _save(self, checkpoint_dir):
        """Save model and optimizer states to ``<checkpoint_dir>/checkpoint``.

        Returns ``checkpoint_dir``, which is what ``_restore`` receives.
        """
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save({
            "netDmodel": self.netD.state_dict(),
            "netGmodel": self.netG.state_dict(),
            "optimD": self.optimizerD.state_dict(),
            "optimG": self.optimizerG.state_dict(),
        }, path)
        return checkpoint_dir

    def _restore(self, checkpoint_dir):
        """Load model and optimizer states written by ``_save``."""
        path = os.path.join(checkpoint_dir, "checkpoint")
        # map_location lets a checkpoint be loaded onto this trial's device
        # even if it was saved from a different one.
        checkpoint = torch.load(path, map_location=self.device)
        self.netD.load_state_dict(checkpoint["netDmodel"])
        self.netG.load_state_dict(checkpoint["netGmodel"])
        self.optimizerD.load_state_dict(checkpoint["optimD"])
        self.optimizerG.load_state_dict(checkpoint["optimG"])
        # Loading optimizer state also loads the learning rate stored in the
        # checkpoint; re-apply this trial's own values afterwards.
        # NOTE(review): assumes ``self.config`` holds the trial's current
        # hyperparameters when a restore happens -- confirm against Tune.
        self._apply_config_lrs(getattr(self, "config", None) or {})

    def reset_config(self, new_config):
        """Rebuild both optimizers for ``new_config`` and adopt it.

        Returns ``True`` to signal that the config was applied in place.
        """
        del self.optimizerD
        del self.optimizerG
        self.optimizerD = self._make_adam(
            self.netD, self._lr_from(new_config, "netD_lr"))
        self.optimizerG = self._make_adam(
            self.netG, self._lr_from(new_config, "netG_lr"))
        self.config = new_config
        return True
# __Trainable_end__
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # parse_known_args tolerates flags this script does not declare.
    args, _ = parser.parse_known_args()
    ray.init()

    dataloader = get_data_loader()
    if not args.smoke_test:
        # Plot some training images
        real_batch = next(iter(dataloader))
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Original Images")
        plt.imshow(
            np.transpose(
                vutils.make_grid(
                    real_batch[0][:64], padding=2, normalize=True).cpu(),
                (1, 2, 0)))
        plt.show()

    # load the pretrained mnist classification model for inception_score
    mnist_cnn = Net()
    model_path = os.path.join(
        os.path.dirname(ray.__file__),
        "tune/examples/pbt_dcgan_mnist/mnist_cnn.pt")
    mnist_cnn.load_state_dict(torch.load(model_path))
    mnist_cnn.eval()
    # Module-level handle that inception_score reads through ray.get.
    mnist_model_ref = ray.put(mnist_cnn)

    # __tune_begin__
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="is_score",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            # distribution for resampling
            # Bounds are passed as (low, high); NumPy documents the result
            # as undefined when high < low.
            "netG_lr": lambda: np.random.uniform(1e-5, 1e-2),
            "netD_lr": lambda: np.random.uniform(1e-5, 1e-2),
        })

    tune_iter = 5 if args.smoke_test else 300
    analysis = tune.run(
        PytorchTrainable,
        name="pbt_dcgan_mnist",
        scheduler=scheduler,
        reuse_actors=True,
        verbose=1,
        checkpoint_at_end=True,
        stop={
            "training_iteration": tune_iter,
        },
        num_samples=8,
        config={
            "netG_lr": tune.sample_from(
                lambda spec: random.choice([0.0001, 0.0002, 0.0005])),
            "netD_lr": tune.sample_from(
                lambda spec: random.choice([0.0001, 0.0002, 0.0005]))
        })
    # __tune_end__

    # demo of the trained Generators
    if not args.smoke_test:
        logdirs = analysis.dataframe()["logdir"].tolist()
        img_list = []
        # The same noise is fed to every generator so the frames compare.
        fixed_noise = torch.randn(64, nz, 1, 1)
        for d in logdirs:
            # checkpoint_at_end=True leaves a checkpoint for the final
            # iteration in each trial's log directory.
            netG_path = os.path.join(
                d, "checkpoint_" + str(tune_iter), "checkpoint")
            loadedG = Generator()
            # The demo generator lives on the CPU, so load tensors there.
            loadedG.load_state_dict(
                torch.load(netG_path, map_location="cpu")["netGmodel"])
            with torch.no_grad():
                fake = loadedG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        fig = plt.figure(figsize=(8, 8))
        plt.axis("off")
        ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)]
               for i in img_list]
        ani = animation.ArtistAnimation(
            fig, ims, interval=1000, repeat_delay=1000, blit=True)
        ani.save("./generated.gif", writer="imagemagick", dpi=72)
        plt.show()