mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:49:47 +08:00
[tune] PyTorch CIFAR10 example (#9338)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu> Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -90,6 +90,9 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
|
||||
python /ray/python/ray/tune/examples/bayesopt_example.py \
|
||||
--smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/cifar10_pytorch.py --smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/hyperopt_example.py \
|
||||
--smoke-test
|
||||
@@ -116,8 +119,7 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
|
||||
python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/mnist_pytorch_lightning.py \
|
||||
--smoke-test
|
||||
python /ray/python/ray/tune/examples/mnist_pytorch_lightning.py --smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
@@ -70,6 +70,11 @@ These pages will demonstrate the various features and configurations of Tune.
|
||||
:figure: /images/pytorch_lightning_small.png
|
||||
:description: :doc:`Tuning PyTorch Lightning modules <tune-pytorch-lightning>`
|
||||
|
||||
.. customgalleryitem::
|
||||
:tooltip: How to use Tune with PyTorch
|
||||
:figure: /images/pytorch_logo.png
|
||||
:description: :doc:`How to use Tune with PyTorch <tune-pytorch-cifar>`
|
||||
|
||||
.. customgalleryitem::
|
||||
:tooltip: Tuning XGBoost parameters.
|
||||
:figure: /images/xgboost_logo.png
|
||||
@@ -87,6 +92,7 @@ These pages will demonstrate the various features and configurations of Tune.
|
||||
tune-advanced-tutorial.rst
|
||||
tune-distributed.rst
|
||||
tune-sklearn.rst
|
||||
tune-pytorch-cifar.rst
|
||||
tune-pytorch-lightning.rst
|
||||
tune-xgboost.rst
|
||||
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
.. _tune-pytorch-cifar:
|
||||
|
||||
How to use Tune with PyTorch
|
||||
============================
|
||||
|
||||
In this walkthrough, we will show you how to integrate Tune into your PyTorch
|
||||
training workflow. We will follow `this tutorial from the PyTorch documentation
|
||||
<https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_ for training
|
||||
a CIFAR10 image classifier.
|
||||
|
||||
.. image:: /images/pytorch_logo.png
|
||||
:align: center
|
||||
|
||||
Hyperparameter tuning can make the difference between an average model and a highly
|
||||
accurate one. Often simple things like choosing a different learning rate or changing
|
||||
a network layer size can have a dramatic impact on your model performance. Fortunately,
|
||||
Tune makes exploring these optimal parameter combinations easy - and works nicely
|
||||
together with PyTorch.
|
||||
|
||||
As you will see, we only need to add some slight modifications. In particular, we
|
||||
need to
|
||||
|
||||
1. wrap data loading and training in functions,
|
||||
2. make some network parameters configurable,
|
||||
3. add checkpointing (optional),
|
||||
4. and define the search space for the model tuning
|
||||
|
||||
.. note::
|
||||
|
||||
To run this example, you will need to install the following:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ pip install ray torch torchvision
|
||||
|
||||
.. contents::
|
||||
:local:
|
||||
:backlinks: none
|
||||
|
||||
Setup / Imports
|
||||
---------------
|
||||
Let's start with the imports:
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py
|
||||
:language: python
|
||||
:start-after: __import_begin__
|
||||
:end-before: __import_end__
|
||||
|
||||
Most of the imports are needed for building the PyTorch model. Only the last three
|
||||
imports are for Ray Tune.
|
||||
|
||||
Data loaders
|
||||
------------
|
||||
We wrap the data loaders in their own function and pass a global data directory.
|
||||
This way we can share a data directory between different trials.
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py
|
||||
:language: python
|
||||
:start-after: __load_data_begin__
|
||||
:end-before: __load_data_end__
|
||||
|
||||
Configurable neural network
|
||||
---------------------------
|
||||
We can only tune those parameters that are configurable. In this example, we can specify
|
||||
the layer sizes of the fully connected layers:
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py
|
||||
:language: python
|
||||
:start-after: __net_begin__
|
||||
:end-before: __net_end__
|
||||
|
||||
The train function
|
||||
------------------
|
||||
Now it gets interesting, because we introduce some changes to the example `from the PyTorch
|
||||
documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
|
||||
|
||||
We wrap the training script in a function ``train_cifar(config, checkpoint=None)``. As you
|
||||
can guess, the ``config`` parameter will receive the hyperparameters we would like to
|
||||
train with. The ``checkpoint`` parameter is used to restore checkpoints.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
net = Net(config["l1"], config["l2"])
|
||||
|
||||
if checkpoint:
|
||||
net.load_state_dict(torch.load(checkpoint))
|
||||
|
||||
The learning rate of the optimizer is made configurable, too:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
|
||||
|
||||
We also split the training data into a training and validation subset. We thus train on
|
||||
80% of the data and calculate the validation loss on the remaining 20%. The batch sizes
|
||||
with which we iterate through the training and test sets are configurable as well.
|
||||
|
||||
Adding (multi) GPU support with DataParallel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Image classification benefits largely from GPUs. Luckily, we can continue to use
|
||||
PyTorch's abstractions in Ray Tune. Thus, we can wrap our model in ``nn.DataParallel``
|
||||
to support data parallel training on multiple GPUs:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
if torch.cuda.device_count() > 1:
|
||||
net = nn.DataParallel(net)
|
||||
net.to(device)
|
||||
|
||||
By using a ``device`` variable we make sure that training also works when we have
|
||||
no GPUs available. PyTorch requires us to send our data to the GPU memory explicitly,
|
||||
like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
for i, data in enumerate(trainloader, 0):
|
||||
inputs, labels = data
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
The code now supports training on CPUs, on a single GPU, and on multiple GPUs. Notably, Ray
|
||||
also supports :doc:`fractional GPUs </using-ray-with-gpus>`
|
||||
so we can share GPUs among trials, as long as the model still fits on the GPU memory. We'll come back
|
||||
to that later.
|
||||
|
||||
Communicating with Ray Tune
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The most interesting part is the communication with Tune:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(epoch)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save((net.state_dict(), optimizer.state_dict()), path)
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
|
||||
|
||||
Here we first save a checkpoint and then report some metrics back to Tune. Specifically,
|
||||
we send the validation loss and accuracy back to Tune. Tune can then use these metrics
|
||||
to decide which hyperparameter configuration lead to the best results. These metrics
|
||||
can also be used to stop bad performing trials early in order to avoid wasting
|
||||
resources on those trials.
|
||||
|
||||
The checkpoint saving is optional, however, it is necessary if we wanted to use advanced
|
||||
schedulers like `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
|
||||
Also, by saving the checkpoint we can later load the trained models and validate them
|
||||
on a test set.
|
||||
|
||||
Full training function
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The full code example looks like this:
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py
|
||||
:language: python
|
||||
:start-after: __train_begin__
|
||||
:end-before: __train_end__
|
||||
:emphasize-lines: 2,4-9,12,14-18,28,33,43,70,81-84,86
|
||||
|
||||
As you can see, most of the code is adapted directly from the example.
|
||||
|
||||
Test set accuracy
|
||||
-----------------
|
||||
Commonly the performance of a machine learning model is tested on a hold-out test
|
||||
set with data that has not been used for training the model. We also wrap this in a
|
||||
function:
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py
|
||||
:language: python
|
||||
:start-after: __test_acc_begin__
|
||||
:end-before: __test_acc_end__
|
||||
|
||||
As you can see, the function also expects a ``device`` parameter, so we can do the
|
||||
test set validation on a GPU.
|
||||
|
||||
Configuring the search space
|
||||
----------------------------
|
||||
Lastly, we need to define Tune's search space. Here is an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
config = {
|
||||
"l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
|
||||
"l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
|
||||
"lr": tune.loguniform(1e-4, 1e-1),
|
||||
"batch_size": tune.choice([2, 4, 8, 16]),
|
||||
"data_dir": data_dir
|
||||
}
|
||||
|
||||
The ``tune.sample_from()`` function makes it possible to define your own sample
|
||||
methods to obtain hyperparameters. In this example, the ``l1`` and ``l2`` parameters
|
||||
should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256.
|
||||
The ``lr`` (learning rate) should be uniformly sampled between 0.0001 and 0.1. Lastly,
|
||||
the batch size is a choice between 2, 4, 8, and 16.
|
||||
|
||||
At each trial, Tune will now randomly sample a combination of parameters from these
|
||||
search spaces. It will then train a number of models in parallel and find the best
|
||||
performing one among these. We also use the ``ASHAScheduler`` which will terminate bad
|
||||
performing trials early.
|
||||
|
||||
We wrap the ``train_cifar`` function with ``functools.partial`` to set the constant
|
||||
``data_dir`` parameter. We can also tell Ray Tune what resources should be
|
||||
available for each trial:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
gpus_per_trial = 2
|
||||
# ...
|
||||
result = tune.run(
|
||||
partial(train_cifar, data_dir=data_dir),
|
||||
resources_per_trial={"cpu": 8, "gpu": gpus_per_trial},
|
||||
config=config,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
progress_reporter=reporter,
|
||||
checkpoint_at_end=True)
|
||||
|
||||
You can specify the number of CPUs, which are then available e.g.
|
||||
to increase the ``num_workers`` of the PyTorch ``DataLoader`` instances. The selected
|
||||
number of GPUs are made visible to PyTorch in each trial. Trials do not have access to
|
||||
GPUs that haven't been requested for them - so you don't have to care about two trials
|
||||
using the same set of resources.
|
||||
|
||||
Here we can also specify fractional GPUs, so something like ``gpus_per_trial=0.5`` is
|
||||
completely valid. The trials will then share GPUs among each other.
|
||||
You just have to make sure that the models still fit in the GPU memory.
|
||||
|
||||
After training the models, we will find the best performing one and load the trained
|
||||
network from the checkpoint file. We then obtain the test set accuracy and report
|
||||
everything by printing.
|
||||
|
||||
The full main function looks like this:
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/cifar10_pytorch.py
|
||||
:language: python
|
||||
:start-after: __main_begin__
|
||||
:end-before: __main_end__
|
||||
|
||||
If you run the code, an example output could look like this:
|
||||
|
||||
.. code-block::
|
||||
:emphasize-lines: 7
|
||||
|
||||
Number of trials: 10 (10 TERMINATED)
|
||||
+-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+
|
||||
| Trial name | status | loc | l1 | l2 | lr | batch_size | loss | accuracy | training_iteration |
|
||||
|-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------|
|
||||
| train_cifar_87d1f_00000 | TERMINATED | | 64 | 4 | 0.00011629 | 2 | 1.87273 | 0.244 | 2 |
|
||||
| train_cifar_87d1f_00001 | TERMINATED | | 32 | 64 | 0.000339763 | 8 | 1.23603 | 0.567 | 8 |
|
||||
| train_cifar_87d1f_00002 | TERMINATED | | 8 | 16 | 0.00276249 | 16 | 1.1815 | 0.5836 | 10 |
|
||||
| train_cifar_87d1f_00003 | TERMINATED | | 4 | 64 | 0.000648721 | 4 | 1.31131 | 0.5224 | 8 |
|
||||
| train_cifar_87d1f_00004 | TERMINATED | | 32 | 16 | 0.000340753 | 8 | 1.26454 | 0.5444 | 8 |
|
||||
| train_cifar_87d1f_00005 | TERMINATED | | 8 | 4 | 0.000699775 | 8 | 1.99594 | 0.1983 | 2 |
|
||||
| train_cifar_87d1f_00006 | TERMINATED | | 256 | 8 | 0.0839654 | 16 | 2.3119 | 0.0993 | 1 |
|
||||
| train_cifar_87d1f_00007 | TERMINATED | | 16 | 128 | 0.0758154 | 16 | 2.33575 | 0.1327 | 1 |
|
||||
| train_cifar_87d1f_00008 | TERMINATED | | 16 | 8 | 0.0763312 | 16 | 2.31129 | 0.1042 | 4 |
|
||||
| train_cifar_87d1f_00009 | TERMINATED | | 128 | 16 | 0.000124903 | 4 | 2.26917 | 0.1945 | 1 |
|
||||
+-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+
|
||||
|
||||
|
||||
Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.0027624906698231976, 'batch_size': 16, 'data_dir': '...'}
|
||||
Best trial final validation loss: 1.1815014744281769
|
||||
Best trial final validation accuracy: 0.5836
|
||||
Best trial test set accuracy: 0.5806
|
||||
|
||||
As you can see, most trials have been stopped early in order to avoid wasting resources.
|
||||
The best performing trial achieved a validation accuracy of about 58%, which could
|
||||
be confirmed on the test set.
|
||||
|
||||
So that's it! You can now tune the parameters of your PyTorch models.
|
||||
|
||||
If you consider switching to PyTorch Lightning to get rid of some of your boilerplate
|
||||
training code, please know that we also have a walkthrough on :doc:`how to use Tune with
|
||||
PyTorch Lightning models <tune-pytorch-lightning>`.
|
||||
@@ -0,0 +1,236 @@
|
||||
# flake8: noqa
|
||||
# yapf: disable
|
||||
|
||||
# __import_begin__
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import random_split
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from ray import tune
|
||||
from ray.tune import CLIReporter
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
# __import_end__
|
||||
|
||||
|
||||
# __load_data_begin__
|
||||
def load_data(data_dir="./data"):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(
|
||||
root=data_dir, train=True, download=True, transform=transform)
|
||||
|
||||
testset = torchvision.datasets.CIFAR10(
|
||||
root=data_dir, train=False, download=True, transform=transform)
|
||||
|
||||
return trainset, testset
|
||||
# __load_data_end__
|
||||
|
||||
|
||||
# __net_begin__
|
||||
class Net(nn.Module):
|
||||
def __init__(self, l1=120, l2=84):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
self.fc1 = nn.Linear(16 * 5 * 5, l1)
|
||||
self.fc2 = nn.Linear(l1, l2)
|
||||
self.fc3 = nn.Linear(l2, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x = self.pool(F.relu(self.conv2(x)))
|
||||
x = x.view(-1, 16 * 5 * 5)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
# __net_end__
|
||||
|
||||
|
||||
# __train_begin__
|
||||
def train_cifar(config, checkpoint=None, data_dir=None):
|
||||
net = Net(config["l1"], config["l2"])
|
||||
|
||||
device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
if torch.cuda.device_count() > 1:
|
||||
net = nn.DataParallel(net)
|
||||
net.to(device)
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
|
||||
|
||||
if checkpoint:
|
||||
print("loading checkpoint {}".format(checkpoint))
|
||||
model_state, optimizer_state = torch.load(checkpoint)
|
||||
net.load_state_dict(model_state)
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
trainset, testset = load_data(data_dir)
|
||||
|
||||
test_abs = int(len(trainset) * 0.8)
|
||||
train_subset, val_subset = random_split(
|
||||
trainset, [test_abs, len(trainset) - test_abs])
|
||||
|
||||
trainloader = torch.utils.data.DataLoader(
|
||||
train_subset,
|
||||
batch_size=int(config["batch_size"]),
|
||||
shuffle=True,
|
||||
num_workers=8)
|
||||
valloader = torch.utils.data.DataLoader(
|
||||
val_subset,
|
||||
batch_size=int(config["batch_size"]),
|
||||
shuffle=True,
|
||||
num_workers=8)
|
||||
|
||||
for epoch in range(10): # loop over the dataset multiple times
|
||||
running_loss = 0.0
|
||||
epoch_steps = 0
|
||||
for i, data in enumerate(trainloader, 0):
|
||||
# get the inputs; data is a list of [inputs, labels]
|
||||
inputs, labels = data
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward + optimize
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# print statistics
|
||||
running_loss += loss.item()
|
||||
epoch_steps += 1
|
||||
if i % 2000 == 1999: # print every 2000 mini-batches
|
||||
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
|
||||
running_loss / epoch_steps))
|
||||
running_loss = 0.0
|
||||
|
||||
# Validation loss
|
||||
val_loss = 0.0
|
||||
val_steps = 0
|
||||
total = 0
|
||||
correct = 0
|
||||
for i, data in enumerate(valloader, 0):
|
||||
with torch.no_grad():
|
||||
inputs, labels = data
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
outputs = net(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
loss = criterion(outputs, labels)
|
||||
val_loss += loss.cpu().numpy()
|
||||
val_steps += 1
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(epoch)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save((net.state_dict(), optimizer.state_dict()), path)
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
|
||||
print("Finished Training")
|
||||
# __train_end__
|
||||
|
||||
|
||||
# __test_acc_begin__
|
||||
def test_accuracy(net, device="cpu"):
|
||||
trainset, testset = load_data()
|
||||
|
||||
testloader = torch.utils.data.DataLoader(
|
||||
testset, batch_size=4, shuffle=False, num_workers=2)
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for data in testloader:
|
||||
images, labels = data
|
||||
images, labels = images.to(device), labels.to(device)
|
||||
outputs = net(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
return correct / total
|
||||
# __test_acc_end__
|
||||
|
||||
|
||||
# __main_begin__
|
||||
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
|
||||
data_dir = os.path.abspath("./data")
|
||||
load_data(data_dir)
|
||||
config = {
|
||||
"l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
|
||||
"l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
|
||||
"lr": tune.loguniform(1e-4, 1e-1),
|
||||
"batch_size": tune.choice([2, 4, 8, 16])
|
||||
}
|
||||
scheduler = ASHAScheduler(
|
||||
metric="loss",
|
||||
mode="min",
|
||||
max_t=max_num_epochs,
|
||||
grace_period=1,
|
||||
reduction_factor=2)
|
||||
reporter = CLIReporter(
|
||||
# parameter_columns=["l1", "l2", "lr", "batch_size"],
|
||||
metric_columns=["loss", "accuracy", "training_iteration"])
|
||||
result = tune.run(
|
||||
partial(train_cifar, data_dir=data_dir),
|
||||
resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
|
||||
config=config,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
progress_reporter=reporter,
|
||||
checkpoint_at_end=True)
|
||||
|
||||
best_trial = result.get_best_trial("loss", "min", "last")
|
||||
print("Best trial config: {}".format(best_trial.config))
|
||||
print("Best trial final validation loss: {}".format(
|
||||
best_trial.last_result["loss"]))
|
||||
print("Best trial final validation accuracy: {}".format(
|
||||
best_trial.last_result["accuracy"]))
|
||||
|
||||
best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
|
||||
device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
if gpus_per_trial > 1:
|
||||
best_trained_model = nn.DataParallel(best_trained_model)
|
||||
best_trained_model.to(device)
|
||||
|
||||
model_state, optimizer_state = torch.load(best_trial.checkpoint.value)
|
||||
best_trained_model.load_state_dict(model_state)
|
||||
|
||||
test_acc = test_accuracy(best_trained_model, device)
|
||||
print("Best trial test set accuracy: {}".format(test_acc))
|
||||
# __main_end__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if args.smoke_test:
|
||||
main(num_samples=1, max_num_epochs=1, gpus_per_trial=0)
|
||||
else:
|
||||
# Change this to activate training on GPUs
|
||||
main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
|
||||
@@ -185,7 +185,7 @@ def train_mnist_tune_checkpoint(config, checkpoint=None):
|
||||
|
||||
|
||||
# __tune_asha_begin__
|
||||
def tune_mnist_asha():
|
||||
def tune_mnist_asha(num_samples=10, max_num_epochs=10):
|
||||
data_dir = mkdtemp(prefix="mnist_data_")
|
||||
LightningMNISTClassifier.download_data(data_dir)
|
||||
config = {
|
||||
@@ -198,7 +198,7 @@ def tune_mnist_asha():
|
||||
scheduler = ASHAScheduler(
|
||||
metric="loss",
|
||||
mode="min",
|
||||
max_t=10,
|
||||
max_t=max_num_epochs,
|
||||
grace_period=1,
|
||||
reduction_factor=2)
|
||||
reporter = CLIReporter(
|
||||
@@ -208,7 +208,7 @@ def tune_mnist_asha():
|
||||
train_mnist_tune,
|
||||
resources_per_trial={"cpu": 1},
|
||||
config=config,
|
||||
num_samples=10,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
progress_reporter=reporter)
|
||||
shutil.rmtree(data_dir)
|
||||
@@ -250,5 +250,15 @@ def tune_mnist_pbt():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# tune_mnist_asha() # ASHA scheduler
|
||||
tune_mnist_pbt() # population based training
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if args.smoke_test:
|
||||
tune_mnist_asha(1, 1)
|
||||
else:
|
||||
tune_mnist_asha() # ASHA scheduler
|
||||
tune_mnist_pbt() # population based training
|
||||
|
||||
Reference in New Issue
Block a user