[tune] a tiny ptl example (#11497)

This commit is contained in:
Richard Liaw
2020-10-22 18:50:34 -07:00
committed by GitHub
parent 4348ecf850
commit e7aa6441b7
7 changed files with 146 additions and 7 deletions
+9
View File
@@ -475,6 +475,15 @@ py_test(
args = ["--smoke-test"]
)
# Bazel test target that runs examples/mnist_ptl_mini.py with the
# --smoke-test flag.
py_test(
name = "mnist_ptl_mini",
size = "medium",
srcs = ["examples/mnist_ptl_mini.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example", "pytorch"],
args = ["--smoke-test"]
)
py_test(
name = "mnist_pytorch_trainable",
size = "small",
+117
View File
@@ -0,0 +1,117 @@
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import tempfile
from ray import tune
class LightningMNISTClassifier(pl.LightningModule):
    """Fully connected classifier with two configurable hidden layers.

    The ``config`` mapping must provide ``"lr"`` (learning rate for Adam)
    and ``"layer_1"`` / ``"layer_2"`` (hidden layer widths). The network
    flattens each input to 28 * 28 features and emits log-probabilities
    over 10 classes.
    """

    def __init__(self, config, data_dir=None):
        super().__init__()

        # Fall back to the current working directory when no path is given.
        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]

        hidden_1, hidden_2 = config["layer_1"], config["layer_2"]

        # Input features are the 28 * 28 flattened pixels of one image.
        self.layer_1 = torch.nn.Linear(28 * 28, hidden_1)
        self.layer_2 = torch.nn.Linear(hidden_1, hidden_2)
        self.layer_3 = torch.nn.Linear(hidden_2, 10)
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``."""
        # Unpacking into four names means any input that is not
        # four-dimensional raises here.
        batch_size, _channels, _width, _height = x.size()
        flat = x.view(batch_size, -1)
        hidden = torch.relu(self.layer_1(flat))
        hidden = torch.relu(self.layer_2(hidden))
        return torch.log_softmax(self.layer_3(hidden), dim=1)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the NLL loss and accuracy; return the loss."""
        inputs, targets = train_batch
        log_probs = self.forward(inputs)
        loss = F.nll_loss(log_probs, targets)
        acc = self.accuracy(log_probs, targets)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Return this batch's loss and accuracy for epoch-end averaging."""
        inputs, targets = val_batch
        log_probs = self.forward(inputs)
        return {
            "val_loss": F.nll_loss(log_probs, targets),
            "val_accuracy": self.accuracy(log_probs, targets),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results and log the means."""

        def epoch_mean(key):
            return torch.stack([step[key] for step in outputs]).mean()

        avg_loss = epoch_mean("val_loss")
        avg_acc = epoch_mean("val_accuracy")
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Train one ``LightningMNISTClassifier`` and report metrics to Tune.

    Args:
        config: Hyperparameters; read here for ``"batch_size"`` and passed
            whole to the classifier.
        data_dir: Directory handed to both the classifier and the
            ``MNISTDataModule``.
        num_epochs: Value passed as the trainer's ``max_epochs``.
        num_gpus: Value passed as the trainer's ``gpus``.
    """
    classifier = LightningMNISTClassifier(config, data_dir)
    datamodule = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])

    # Maps the names reported to Tune onto the names the module logs.
    report_callback = TuneReportCallback(
        {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"},
        on="validation_end")

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[report_callback])
    trainer.fit(classifier, datamodule)
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run a Tune hyperparameter search over the MNIST classifier.

    Args:
        num_samples: Number of configurations sampled from the search space.
        num_epochs: Epoch count forwarded to each training run.
        gpus_per_trial: GPUs requested per trial; also forwarded to the
            trainer as its GPU count.

    Returns:
        The object returned by ``tune.run``, so callers can inspect the
        results of the search (previously the result was discarded).
    """
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download the data once up front, before any trial starts.
    MNISTDataModule(data_dir=data_dir).prepare_data()

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    # Bind the arguments that stay constant across trials; Tune supplies
    # only ``config`` per trial.
    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial)

    analysis = tune.run(
        trainable,
        resources_per_trial={
            "cpu": 1,
            "gpu": gpus_per_trial
        },
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist")
    return analysis
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # Unrecognised arguments are tolerated and ignored.
    args, _ = cli.parse_known_args()

    # A smoke test runs one sample for one epoch; a full run uses ten of each.
    budget = 1 if args.smoke_test else 10
    tune_mnist(num_samples=budget, num_epochs=budget, gpus_per_trial=0)
@@ -185,7 +185,8 @@ def train(netD, netG, optimG, optimD, criterion, dataloader, iteration, device,
netD.zero_grad()
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size, ), real_label, device=device)
label = torch.full(
(b_size, ), real_label, dtype=torch.float, device=device)
output = netD(real_cpu).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
+1
View File
@@ -18,6 +18,7 @@ nevergrad
optuna
pytest-remotedata>=0.3.1
pytorch-lightning
pytorch-lightning-bolts
scikit-optimize
sigopt
smart_open