mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 11:37:25 +08:00
[tune] Update MNIST Example (#4991)
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
# Original Code here:
|
||||
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -9,181 +12,123 @@ import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=64,
|
||||
metavar="N",
|
||||
help="input batch size for training (default: 64)")
|
||||
parser.add_argument(
|
||||
"--test-batch-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
metavar="N",
|
||||
help="input batch size for testing (default: 1000)")
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
type=int,
|
||||
default=1,
|
||||
metavar="N",
|
||||
help="number of epochs to train (default: 1)")
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=0.01,
|
||||
metavar="LR",
|
||||
help="learning rate (default: 0.01)")
|
||||
parser.add_argument(
|
||||
"--momentum",
|
||||
type=float,
|
||||
default=0.5,
|
||||
metavar="M",
|
||||
help="SGD momentum (default: 0.5)")
|
||||
parser.add_argument(
|
||||
"--no-cuda",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="disables CUDA training")
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=1,
|
||||
metavar="S",
|
||||
help="random seed (default: 1)")
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import track
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
|
||||
# Change these values if you want the training to run quicker or slower.
|
||||
EPOCH_SIZE = 512
|
||||
TEST_SIZE = 256
|
||||
|
||||
|
||||
def train_mnist(args, config, reporter):
|
||||
vars(args).update(config)
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
class Net(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
|
||||
self.fc = nn.Linear(192, 10)
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
if args.cuda:
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
def forward(self, x):
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 3))
|
||||
x = x.view(-1, 192)
|
||||
x = self.fc(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
def train(model, optimizer, train_loader, device):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
if batch_idx * len(data) > EPOCH_SIZE:
|
||||
return
|
||||
data, target = data.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def test(model, data_loader, device):
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (data, target) in enumerate(data_loader):
|
||||
if batch_idx * len(data) > TEST_SIZE:
|
||||
break
|
||||
data, target = data.to(device), target.to(device)
|
||||
outputs = model(data)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += target.size(0)
|
||||
correct += (predicted == target).sum().item()
|
||||
|
||||
return correct / total
|
||||
|
||||
|
||||
def get_data_loaders():
|
||||
mnist_transforms = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])
|
||||
|
||||
kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(
|
||||
"~/data",
|
||||
train=True,
|
||||
download=False,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||
])),
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
**kwargs)
|
||||
"~/data", train=True, download=True, transform=mnist_transforms),
|
||||
batch_size=64,
|
||||
shuffle=True)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
datasets.MNIST(
|
||||
"~/data",
|
||||
train=False,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||
])),
|
||||
batch_size=args.test_batch_size,
|
||||
shuffle=True,
|
||||
**kwargs)
|
||||
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
|
||||
batch_size=64,
|
||||
shuffle=True)
|
||||
return train_loader, test_loader
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.conv2_drop = nn.Dropout2d()
|
||||
self.fc1 = nn.Linear(320, 50)
|
||||
self.fc2 = nn.Linear(50, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
||||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
||||
x = x.view(-1, 320)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.dropout(x, training=self.training)
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
model = Net()
|
||||
if args.cuda:
|
||||
model.cuda()
|
||||
def train_mnist(config):
|
||||
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
train_loader, test_loader = get_data_loaders()
|
||||
model = Net(config).to(device)
|
||||
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=args.lr, momentum=args.momentum)
|
||||
model.parameters(), lr=config["lr"], momentum=config["momentum"])
|
||||
|
||||
def train(epoch):
|
||||
model.train()
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
if args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = F.nll_loss(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
def test():
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in test_loader:
|
||||
if args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
output = model(data)
|
||||
# sum up batch loss
|
||||
test_loss += F.nll_loss(output, target, reduction="sum").item()
|
||||
# get the index of the max log-probability
|
||||
pred = output.argmax(dim=1, keepdim=True)
|
||||
correct += pred.eq(
|
||||
target.data.view_as(pred)).long().cpu().sum()
|
||||
|
||||
test_loss = test_loss / len(test_loader.dataset)
|
||||
accuracy = correct.item() / len(test_loader.dataset)
|
||||
reporter(mean_loss=test_loss, mean_accuracy=accuracy)
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
train(epoch)
|
||||
test()
|
||||
while True:
|
||||
train(model, optimizer, train_loader, device)
|
||||
acc = test(model, test_loader, device)
|
||||
track.log(mean_accuracy=acc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
datasets.MNIST("~/data", train=True, download=True)
|
||||
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
|
||||
parser.add_argument(
|
||||
"--cuda",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
parser.add_argument(
|
||||
"--ray-redis-address",
|
||||
help="Address of Ray cluster for seamless distributed execution.")
|
||||
args = parser.parse_args()
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import AsyncHyperBandScheduler
|
||||
|
||||
ray.init()
|
||||
if args.ray_redis_address:
|
||||
ray.init(redis_address=args.ray_redis_address)
|
||||
sched = AsyncHyperBandScheduler(
|
||||
time_attr="training_iteration",
|
||||
metric="mean_loss",
|
||||
mode="min",
|
||||
max_t=400,
|
||||
grace_period=20)
|
||||
tune.register_trainable(
|
||||
"TRAIN_FN",
|
||||
lambda config, reporter: train_mnist(args, config, reporter))
|
||||
time_attr="training_iteration", metric="mean_accuracy")
|
||||
tune.run(
|
||||
"TRAIN_FN",
|
||||
train_mnist,
|
||||
name="exp",
|
||||
scheduler=sched,
|
||||
**{
|
||||
"stop": {
|
||||
"mean_accuracy": 0.98,
|
||||
"training_iteration": 1 if args.smoke_test else 20
|
||||
},
|
||||
"resources_per_trial": {
|
||||
"cpu": 3,
|
||||
"gpu": int(not args.no_cuda)
|
||||
},
|
||||
"num_samples": 1 if args.smoke_test else 10,
|
||||
"config": {
|
||||
"lr": tune.uniform(0.001, 0.1),
|
||||
"momentum": tune.uniform(0.1, 0.9),
|
||||
}
|
||||
stop={
|
||||
"mean_accuracy": 0.98,
|
||||
"training_iteration": 5 if args.smoke_test else 20
|
||||
},
|
||||
resources_per_trial={
|
||||
"cpu": 2,
|
||||
"gpu": int(args.cuda)
|
||||
},
|
||||
num_samples=1 if args.smoke_test else 10,
|
||||
config={
|
||||
"lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
|
||||
"momentum": tune.uniform(0.1, 0.9),
|
||||
"use_gpu": int(args.cuda)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user