mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 13:47:22 +08:00
[Tune] Add example and tutorial for DCGAN (#6400)
This commit is contained in:
@@ -126,6 +126,14 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
|
||||
python /ray/python/ray/tune/examples/pbt_memnn_example.py \
|
||||
--smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT --force-direct docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/pbt_convnet_example.py \
|
||||
--smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT --force-direct docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py \
|
||||
--smoke-test
|
||||
|
||||
# uncomment once statsmodels is updated.
|
||||
# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
# python /ray/python/ray/tune/examples/bohb_example.py \
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 1011 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 89 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 88 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 216 KiB |
@@ -246,6 +246,7 @@ Getting Involved
|
||||
|
||||
tune.rst
|
||||
tune-tutorial.rst
|
||||
tune-advanced-tutorial.rst
|
||||
tune-usage.rst
|
||||
tune-distributed.rst
|
||||
tune-schedulers.rst
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
Tune Advanced Tutorials
|
||||
=======================
|
||||
|
||||
In this page, we will explore some advanced functionality in Tune with more examples.
|
||||
|
||||
On this page:
|
||||
|
||||
.. contents::
|
||||
:local:
|
||||
:backlinks: none
|
||||
|
||||
A native example of Trainable
|
||||
-----------------------------
|
||||
As mentioned in `Tune User Guide <tune-usage.html#Tune Training API>`_, Training can be done
|
||||
with either the `Trainable <tune-usage.html#trainable-api>`__ **Class API** or
|
||||
**function-based API**. Comparably, ``Trainable`` is stateful, supports checkpoint/restore functionality,
|
||||
and is preferable for advanced algorithms.
|
||||
|
||||
A naive example for ``Trainable`` is a simple number guesser:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
|
||||
class Guesser(Trainable):
|
||||
def _setup(self, config):
|
||||
self.config = config
|
||||
self.password = 1024
|
||||
|
||||
def _train(self):
|
||||
result_dict = {"diff": abs(self.config['guess'] - self.password)}
|
||||
return result_dict
|
||||
|
||||
ray.init()
|
||||
analysis = tune.run(
|
||||
Guesser,
|
||||
stop={
|
||||
"training_iteration": 1,
|
||||
},
|
||||
num_samples=10,
|
||||
config={
|
||||
"guess": tune.randint(1, 10000)
|
||||
})
|
||||
|
||||
print('best config: ', analysis.get_best_config(metric="diff", mode="min"))
|
||||
|
||||
The program randomly picks 10 number from [1, 10000) and finds which is closer to the password.
|
||||
As a subclass of ``ray.tune.Trainable``, Tune will convert ``Guesser`` into a Ray actor, which
|
||||
runs on a separate process on a worker. ``_setup`` function is invoked once for each Actor for custom
|
||||
initialization.
|
||||
|
||||
``_train`` execute one logical iteration of training in the tuning process,
|
||||
which may include several iterations of actual training (see the next example). As a rule of
|
||||
thumb, the execution time of one train call should be large enough to avoid overheads
|
||||
(i.e. more than a few seconds), but short enough to report progress periodically
|
||||
(i.e. at most a few minutes).
|
||||
|
||||
We only implemented ``_setup`` and ``_train`` methods for simplification, usually it's also required
|
||||
to implement ``_save``, and ``_restore`` for checkpoint and fault tolerance.
|
||||
|
||||
Next, we train a Pytorch convolution model with Trainable and PBT.
|
||||
|
||||
Trainable with Population Based Training (PBT)
|
||||
----------------------------------------------
|
||||
|
||||
Tune includes a distributed implementation of `Population Based Training (PBT) <https://deepmind.com/blog/population-based-training-neural-networks>`__ as
|
||||
a scheduler `PopulationBasedTraining <tune-schedulers.html#Population Based Training (PBT)>`__ .
|
||||
|
||||
PBT starts by training many neural networks in parallel with random hyperparameters. But instead of the
|
||||
networks training independently, it uses information from the rest of the population to refine the
|
||||
hyperparameters and direct computational resources to models which show promise.
|
||||
|
||||
.. image:: images/tune_advanced_paper1.png
|
||||
|
||||
This takes its inspiration from genetic algorithms where each member of the population
|
||||
can exploit information from the remainder of the population. For example, a worker might
|
||||
copy the model parameters from a better performing worker. It can also explore new hyperparameters by
|
||||
changing the current values randomly.
|
||||
|
||||
As the training of the population of neural networks progresses, this process of exploiting and exploring
|
||||
is performed periodically, ensuring that all the workers in the population have a good base level of performance
|
||||
and also that new hyperparameters are consistently explored.
|
||||
|
||||
This means that PBT can quickly exploit good hyperparameters, can dedicate more training time to
|
||||
promising models and, crucially, can adapt the hyperparameter values throughout training,
|
||||
leading to automatic learning of the best configurations.
|
||||
|
||||
First we define a Trainable that wraps a ConvNet model.
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py
|
||||
:language: python
|
||||
:start-after: __trainable_begin__
|
||||
:end-before: __trainable_end__
|
||||
|
||||
The example reuses some of the functions in ray/tune/examples/mnist_pytorch.py, and is also a good
|
||||
demo for how to decouple the tuning logic and original training code.
|
||||
|
||||
Here, we also override ``reset_config``. This method is optional but can be implemented to speed
|
||||
up algorithms such as PBT, and to allow performance optimizations such as running experiments
|
||||
with ``reuse_actors=True``.
|
||||
|
||||
Then, we define a PBT scheduler:
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py
|
||||
:language: python
|
||||
:start-after: __pbt_begin__
|
||||
:end-before: __pbt_end__
|
||||
|
||||
Some of the most important parameters are:
|
||||
|
||||
- ``hyperparam_mutations`` and ``custom_explore_fn`` are used to mutate the hyperparameters.
|
||||
``hyperparam_mutations`` is a dictionary where each key/value pair specifies the candidates
|
||||
or function for a hyperparameter. custom_explore_fn is applied after built-in perturbations
|
||||
from hyperparam_mutations are applied, and should return config updated as needed.
|
||||
|
||||
- ``resample_probability``: The probability of resampling from the original distribution
|
||||
when applying hyperparam_mutations. If not resampled, the value will be perturbed by a
|
||||
factor of 1.2 or 0.8 if continuous, or changed to an adjacent value if discrete. Note that
|
||||
``resample_probability`` by default is 0.25, thus hyperparameter with a distribution
|
||||
may go out of the specific range.
|
||||
|
||||
Now we can kick off the tuning process by invoking tune.run:
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_convnet_example.py
|
||||
:language: python
|
||||
:start-after: __tune_begin__
|
||||
:end-before: __tune_end__
|
||||
|
||||
During the training, we can constantly check the status of the models from console log:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
== Status ==
|
||||
Memory usage on this node: 10.4/16.0 GiB
|
||||
PopulationBasedTraining: 4 checkpoints, 1 perturbs
|
||||
Resources requested: 4/12 CPUs, 0/0 GPUs, 0.0/3.42 GiB heap, 0.0/1.17 GiB objects
|
||||
Number of trials: 4 ({'RUNNING': 4})
|
||||
Result logdir: /Users/yuhao.yang/ray_results/pbt_test
|
||||
+--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+
|
||||
| Trial name | status | loc | lr | momentum | iter | total time (s) | acc |
|
||||
|--------------------------+----------+---------------------+----------+------------+--------+------------------+----------|
|
||||
| PytorchTrainble_3b42d914 | RUNNING | 30.57.180.224:49840 | 0.122032 | 0.302176 | 18 | 3.8689 | 0.8875 |
|
||||
| PytorchTrainble_3b45091e | RUNNING | 30.57.180.224:49835 | 0.505325 | 0.628559 | 18 | 3.90404 | 0.134375 |
|
||||
| PytorchTrainble_3b454c46 | RUNNING | 30.57.180.224:49843 | 0.490228 | 0.969013 | 17 | 3.72111 | 0.0875 |
|
||||
| PytorchTrainble_3b458a9c | RUNNING | 30.57.180.224:49833 | 0.961861 | 0.169701 | 13 | 2.72594 | 0.1125 |
|
||||
+--------------------------+----------+---------------------+----------+------------+--------+------------------+----------+
|
||||
|
||||
In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt
|
||||
and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs:
|
||||
[target trial tag, clone trial tag, target trial iteration, clone trial iteration,
|
||||
old config, new config] on each perturbation step.
|
||||
|
||||
Checking the accuracy:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Plot by wall-clock time
|
||||
dfs = analysis.fetch_trial_dataframes()
|
||||
# This plots everything on the same plot
|
||||
ax = None
|
||||
for d in dfs.values():
|
||||
ax = d.plot("training_iteration", "mean_accuracy", ax=ax, legend=False)
|
||||
|
||||
plt.xlabel("iterations")
|
||||
plt.ylabel("Test Accuracy")
|
||||
|
||||
print('best config:', analysis.get_best_config("mean_accuracy"))
|
||||
|
||||
.. image:: images/tune_advanced_plot1.png
|
||||
|
||||
DCGAN with Trainable and PBT
|
||||
----------------------------
|
||||
|
||||
The Generative Adversarial Networks (GAN) (Goodfellow et al., 2014) framework learns generative
|
||||
models via a training paradigm consisting of two competing modules – a generator and a
|
||||
discriminator. GAN training can be remarkably brittle and unstable in the face of suboptimal
|
||||
hyperparameter selection with generators often collapsing to a single mode or diverging entirely.
|
||||
|
||||
As presented in `Population Based Training (PBT) <https://deepmind.com/blog/population-based-training-neural-networks>`__,
|
||||
PBT can help with the DCGAN training. We will now walk through how to do this in Tune.
|
||||
Complete code example at `github <https://github.com/ray-project/ray/tree/master/python/ray/tune/examples/pbt_dcgan_mnist>`__
|
||||
|
||||
We define the Generator and Discriminator with standard Pytorch API:
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
|
||||
:language: python
|
||||
:start-after: __GANmodel_begin__
|
||||
:end-before: __GANmodel_end__
|
||||
|
||||
To train the model with PBT, we need to define a metric for the scheduler to evaluate
|
||||
the model candidates. For a GAN network, inception score is arguably the most
|
||||
commonly used metric. We trained a mnist classification model (LeNet) and use
|
||||
it to inference the generated images and evaluate the image quality.
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
|
||||
:language: python
|
||||
:start-after: __INCEPTION_SCORE_begin__
|
||||
:end-before: __INCEPTION_SCORE_end__
|
||||
|
||||
The ``Trainable`` class includes a Generator and a Discriminator, each with an
|
||||
independent learning rate and optimizer.
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
|
||||
:language: python
|
||||
:start-after: __Trainable_begin__
|
||||
:end-before: __Trainable_end__
|
||||
|
||||
We specify inception score as the metric and start the tuning:
|
||||
|
||||
.. literalinclude:: ../../python/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist.py
|
||||
:language: python
|
||||
:start-after: __tune_begin__
|
||||
:end-before: __tune_end__
|
||||
|
||||
The trained Generator models can be loaded from checkpoints, and generate images
|
||||
from noise signals.
|
||||
|
||||
.. image:: images/tune_advanced_dcgan_generated.gif
|
||||
|
||||
Visualize the increasing inception score from the training logs.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
lossG = [df['is_score'].tolist() for df in list(analysis.trial_dataframes.values())]
|
||||
|
||||
plt.figure(figsize=(10,5))
|
||||
plt.title("Inception Score During Training")
|
||||
for i, lossg in enumerate(lossG):
|
||||
plt.plot(lossg,label=i)
|
||||
|
||||
plt.xlabel("iterations")
|
||||
plt.ylabel("is_score")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
.. image:: images/tune_advanced_dcgan_inscore.png
|
||||
|
||||
And the Generator loss:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
lossG = [df['lossg'].tolist() for df in list(analysis.trial_dataframes.values())]
|
||||
|
||||
plt.figure(figsize=(10,5))
|
||||
plt.title("Generator Loss During Training")
|
||||
for i, lossg in enumerate(lossG):
|
||||
plt.plot(lossg,label=i)
|
||||
|
||||
plt.xlabel("iterations")
|
||||
plt.ylabel("LossG")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
.. image:: images/tune_advanced_dcgan_Gloss.png
|
||||
|
||||
Training of the MNist Generator takes about several minutes. The example can be easily
|
||||
altered to generate images for other dataset, e.g. cifar10 or LSUN.
|
||||
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# __tutorial_imports_begin__
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets
|
||||
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
|
||||
get_data_loaders
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from ray.tune.util import validate_save_restore
|
||||
|
||||
# __tutorial_imports_end__
|
||||
|
||||
|
||||
# __trainable_begin__
|
||||
class PytorchTrainble(tune.Trainable):
|
||||
"""Train a Pytorch ConvNet with Trainable and PopulationBasedTraining
|
||||
scheduler. The example reuse some of the functions in mnist_pytorch,
|
||||
and is a good demo for how to add the tuning function without
|
||||
changing the original training code.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
self.train_loader, self.test_loader = get_data_loaders()
|
||||
self.model = ConvNet()
|
||||
self.optimizer = optim.SGD(
|
||||
self.model.parameters(),
|
||||
lr=config.get("lr", 0.01),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
|
||||
def _train(self):
|
||||
train(self.model, self.optimizer, self.train_loader)
|
||||
acc = test(self.model, self.test_loader)
|
||||
return {"mean_accuracy": acc}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
|
||||
torch.save(self.model.state_dict(), checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
self.model.load_state_dict(torch.load(checkpoint_path))
|
||||
|
||||
def reset_config(self, new_config):
|
||||
del self.optimizer
|
||||
self.optimizer = optim.SGD(
|
||||
self.model.parameters(),
|
||||
lr=new_config.get("lr", 0.01),
|
||||
momentum=new_config.get("momentum", 0.9))
|
||||
return True
|
||||
|
||||
|
||||
# __trainable_end__
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
ray.init()
|
||||
datasets.MNIST("~/data", train=True, download=True)
|
||||
|
||||
# check if PytorchTrainble will save/restore correctly before execution
|
||||
validate_save_restore(PytorchTrainble)
|
||||
validate_save_restore(PytorchTrainble, use_object_store=True)
|
||||
print("Success!")
|
||||
|
||||
# __pbt_begin__
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_accuracy",
|
||||
mode="max",
|
||||
perturbation_interval=5,
|
||||
hyperparam_mutations={
|
||||
# distribution for resampling
|
||||
"lr": lambda: np.random.uniform(0.0001, 1),
|
||||
# allow perturbations within this set of categorical values
|
||||
"momentum": [0.8, 0.9, 0.99],
|
||||
})
|
||||
# __pbt_end__
|
||||
|
||||
# __tune_begin__
|
||||
analysis = tune.run(
|
||||
PytorchTrainble,
|
||||
name="pbt_test",
|
||||
scheduler=scheduler,
|
||||
reuse_actors=True,
|
||||
verbose=1,
|
||||
stop={
|
||||
"training_iteration": 5 if args.smoke_test else 100,
|
||||
},
|
||||
num_samples=4,
|
||||
config={
|
||||
"lr": tune.uniform(0.001, 1),
|
||||
"momentum": tune.uniform(0.001, 1),
|
||||
})
|
||||
# __tune_end__
|
||||
Binary file not shown.
@@ -0,0 +1,377 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.utils as vutils
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.animation as animation
|
||||
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
from scipy.stats import entropy
|
||||
|
||||
# Training parameters
|
||||
dataroot = "/tmp/"
|
||||
workers = 2
|
||||
batch_size = 64
|
||||
image_size = 32
|
||||
|
||||
# Number of channels in the training images. For color images this is 3
|
||||
nc = 1
|
||||
|
||||
# Size of z latent vector (i.e. size of generator input)
|
||||
nz = 100
|
||||
|
||||
# Size of feature maps in generator
|
||||
ngf = 32
|
||||
|
||||
# Size of feature maps in discriminator
|
||||
ndf = 32
|
||||
|
||||
# Beta1 hyperparam for Adam optimizers
|
||||
beta1 = 0.5
|
||||
|
||||
# iterations of actual training in each Trainable _train
|
||||
train_iterations_per_step = 5
|
||||
|
||||
|
||||
def get_data_loader():
|
||||
dataset = dset.MNIST(
|
||||
root=dataroot,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize(image_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, ), (0.5, )),
|
||||
]))
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
# __GANmodel_begin__
|
||||
# custom weights initialization called on netG and netD
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
|
||||
# Generator Code
|
||||
class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
# input is Z, going into a convolution
|
||||
nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(ngf * 4),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ngf * 2),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ngf),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
|
||||
nn.Tanh())
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Discriminator, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input)
|
||||
|
||||
|
||||
# __GANmodel_end__
|
||||
|
||||
|
||||
# __INCEPTION_SCORE_begin__
|
||||
class Net(nn.Module):
|
||||
"""
|
||||
LeNet for MNist classification, used for inception_score
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.conv2_drop = nn.Dropout2d()
|
||||
self.fc1 = nn.Linear(320, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
||||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
||||
x = x.view(-1, 320)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.dropout(x, training=self.training)
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
def inception_score(imgs, batch_size=32, splits=1):
|
||||
N = len(imgs)
|
||||
dtype = torch.FloatTensor
|
||||
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
|
||||
cm = ray.get(mnist_model_ref)
|
||||
up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
|
||||
|
||||
def get_pred(x):
|
||||
x = up(x)
|
||||
x = cm(x)
|
||||
return F.softmax(x).data.cpu().numpy()
|
||||
|
||||
preds = np.zeros((N, 10))
|
||||
for i, batch in enumerate(dataloader, 0):
|
||||
batch = batch.type(dtype)
|
||||
batchv = Variable(batch)
|
||||
batch_size_i = batch.size()[0]
|
||||
preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
|
||||
|
||||
# Now compute the mean kl-div
|
||||
split_scores = []
|
||||
for k in range(splits):
|
||||
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
|
||||
py = np.mean(part, axis=0)
|
||||
scores = []
|
||||
for i in range(part.shape[0]):
|
||||
pyx = part[i, :]
|
||||
scores.append(entropy(pyx, py))
|
||||
split_scores.append(np.exp(np.mean(scores)))
|
||||
|
||||
return np.mean(split_scores), np.std(split_scores)
|
||||
|
||||
|
||||
# __INCEPTION_SCORE_end__
|
||||
|
||||
|
||||
def train(netD, netG, optimG, optimD, criterion, dataloader, iteration,
|
||||
device):
|
||||
real_label = 1
|
||||
fake_label = 0
|
||||
|
||||
for i, data in enumerate(dataloader, 0):
|
||||
if i >= train_iterations_per_step:
|
||||
break
|
||||
|
||||
netD.zero_grad()
|
||||
real_cpu = data[0].to(device)
|
||||
b_size = real_cpu.size(0)
|
||||
label = torch.full((b_size, ), real_label, device=device)
|
||||
output = netD(real_cpu).view(-1)
|
||||
errD_real = criterion(output, label)
|
||||
errD_real.backward()
|
||||
D_x = output.mean().item()
|
||||
|
||||
noise = torch.randn(b_size, nz, 1, 1, device=device)
|
||||
fake = netG(noise)
|
||||
label.fill_(fake_label)
|
||||
output = netD(fake.detach()).view(-1)
|
||||
errD_fake = criterion(output, label)
|
||||
errD_fake.backward()
|
||||
D_G_z1 = output.mean().item()
|
||||
errD = errD_real + errD_fake
|
||||
optimD.step()
|
||||
|
||||
netG.zero_grad()
|
||||
label.fill_(real_label)
|
||||
output = netD(fake).view(-1)
|
||||
errG = criterion(output, label)
|
||||
errG.backward()
|
||||
D_G_z2 = output.mean().item()
|
||||
optimG.step()
|
||||
|
||||
is_score, is_std = inception_score(fake)
|
||||
|
||||
# Output training stats
|
||||
if iteration % 10 == 0:
|
||||
print("[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))"
|
||||
": %.4f / %.4f \tInception score: %.4f" %
|
||||
(iteration, len(dataloader), errD.item(), errG.item(), D_x,
|
||||
D_G_z1, D_G_z2, is_score))
|
||||
|
||||
return errG.item(), errD.item(), is_score
|
||||
|
||||
|
||||
# __Trainable_begin__
|
||||
class PytorchTrainable(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
|
||||
self.device = torch.device("cuda" if use_cuda else "cpu")
|
||||
self.netD = Discriminator().to(self.device)
|
||||
self.netD.apply(weights_init)
|
||||
self.netG = Generator().to(self.device)
|
||||
self.netG.apply(weights_init)
|
||||
self.criterion = nn.BCELoss()
|
||||
self.optimizerD = optim.Adam(
|
||||
self.netD.parameters(),
|
||||
lr=config.get("lr", 0.01),
|
||||
betas=(beta1, 0.999))
|
||||
self.optimizerG = optim.Adam(
|
||||
self.netG.parameters(),
|
||||
lr=config.get("lr", 0.01),
|
||||
betas=(beta1, 0.999))
|
||||
self.dataloader = get_data_loader()
|
||||
|
||||
def _train(self):
|
||||
lossG, lossD, is_score = train(
|
||||
self.netD, self.netG, self.optimizerG, self.optimizerD,
|
||||
self.criterion, self.dataloader, self._iteration, self.device)
|
||||
return {"lossg": lossG, "lossd": lossD, "is_score": is_score}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save({
|
||||
"netDmodel": self.netD.state_dict(),
|
||||
"netGmodel": self.netG.state_dict(),
|
||||
"optimD": self.optimizerD.state_dict(),
|
||||
"optimG": self.optimizerG.state_dict(),
|
||||
}, path)
|
||||
|
||||
return checkpoint_dir
|
||||
|
||||
def _restore(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
checkpoint = torch.load(path)
|
||||
self.netD.load_state_dict(checkpoint["netDmodel"])
|
||||
self.netG.load_state_dict(checkpoint["netGmodel"])
|
||||
self.optimizerD.load_state_dict(checkpoint["optimD"])
|
||||
self.optimizerG.load_state_dict(checkpoint["optimG"])
|
||||
|
||||
def reset_config(self, new_config):
|
||||
del self.optimizerD
|
||||
del self.optimizerG
|
||||
self.optimizerD = optim.Adam(
|
||||
self.netD.parameters(),
|
||||
lr=new_config.get("netD_lr"),
|
||||
betas=(beta1, 0.999))
|
||||
self.optimizerG = optim.Adam(
|
||||
self.netG.parameters(),
|
||||
lr=new_config.get("netG_lr"),
|
||||
betas=(beta1, 0.999))
|
||||
self.config = new_config
|
||||
return True
|
||||
|
||||
|
||||
# __Trainable_end__
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
dataloader = get_data_loader()
|
||||
if not args.smoke_test:
|
||||
# Plot some training images
|
||||
real_batch = next(iter(dataloader))
|
||||
plt.figure(figsize=(8, 8))
|
||||
plt.axis("off")
|
||||
plt.title("Original Images")
|
||||
plt.imshow(
|
||||
np.transpose(
|
||||
vutils.make_grid(
|
||||
real_batch[0][:64], padding=2, normalize=True).cpu(),
|
||||
(1, 2, 0)))
|
||||
|
||||
plt.show()
|
||||
|
||||
# load the pretrained mnist classification model for inception_score
|
||||
mnist_cnn = Net()
|
||||
model_path = os.path.join(
|
||||
os.path.dirname(ray.__file__),
|
||||
"tune/examples/pbt_dcgan_mnist/mnist_cnn.pt")
|
||||
mnist_cnn.load_state_dict(torch.load(model_path))
|
||||
mnist_cnn.eval()
|
||||
mnist_model_ref = ray.put(mnist_cnn)
|
||||
|
||||
# __tune_begin__
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="is_score",
|
||||
mode="max",
|
||||
perturbation_interval=5,
|
||||
hyperparam_mutations={
|
||||
# distribution for resampling
|
||||
"netG_lr": lambda: np.random.uniform(1e-2, 1e-5),
|
||||
"netD_lr": lambda: np.random.uniform(1e-2, 1e-5),
|
||||
})
|
||||
|
||||
tune_iter = 5 if args.smoke_test else 300
|
||||
analysis = tune.run(
|
||||
PytorchTrainable,
|
||||
name="pbt_dcgan_mnist",
|
||||
scheduler=scheduler,
|
||||
reuse_actors=True,
|
||||
verbose=1,
|
||||
checkpoint_at_end=True,
|
||||
stop={
|
||||
"training_iteration": tune_iter,
|
||||
},
|
||||
num_samples=8,
|
||||
config={
|
||||
"netG_lr": tune.sample_from(
|
||||
lambda spec: random.choice([0.0001, 0.0002, 0.0005])),
|
||||
"netD_lr": tune.sample_from(
|
||||
lambda spec: random.choice([0.0001, 0.0002, 0.0005]))
|
||||
})
|
||||
# __tune_end__
|
||||
|
||||
# demo of the trained Generators
|
||||
if not args.smoke_test:
|
||||
logdirs = analysis.dataframe()["logdir"].tolist()
|
||||
img_list = []
|
||||
fixed_noise = torch.randn(64, nz, 1, 1)
|
||||
for d in logdirs:
|
||||
netG_path = d + "/checkpoint_" + str(tune_iter) + "/checkpoint"
|
||||
loadedG = Generator()
|
||||
loadedG.load_state_dict(torch.load(netG_path)["netGmodel"])
|
||||
with torch.no_grad():
|
||||
fake = loadedG(fixed_noise).detach().cpu()
|
||||
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
|
||||
|
||||
fig = plt.figure(figsize=(8, 8))
|
||||
plt.axis("off")
|
||||
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)]
|
||||
for i in img_list]
|
||||
ani = animation.ArtistAnimation(
|
||||
fig, ims, interval=1000, repeat_delay=1000, blit=True)
|
||||
ani.save("./generated.gif", writer="imagemagick", dpi=72)
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user