# Mirror of https://github.com/wassname/ray.git
# Synced 2026-06-29 07:07:00 +08:00
# 144 lines, 4.4 KiB, Python
# Original Code here:
|
|
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
|
import os
|
|
import numpy as np
|
|
import argparse
|
|
from filelock import FileLock
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torchvision import datasets, transforms
|
|
|
|
import ray
|
|
from ray import tune
|
|
from ray.tune import track
|
|
from ray.tune.schedulers import AsyncHyperBandScheduler
|
|
|
|
# Change these values if you want the training to run quicker or slower.
|
|
EPOCH_SIZE = 512
|
|
TEST_SIZE = 256
|
|
|
|
|
|
class ConvNet(nn.Module):
    """Small convolutional classifier used by the MNIST tuning example.

    The network is a single 3x3 convolution (1 input channel, 3 output
    channels), a 3x3 max-pool, a ReLU, and one fully connected layer that
    maps 192 flattened features to 10 class scores.  For 1x28x28 inputs the
    pooled feature map is 3x8x8, i.e. exactly 192 values per sample.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        # Convolution first, then the linear head; the attribute names are
        # part of the state-dict layout and must stay as they are.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        convolved = self.conv1(x)
        pooled = F.max_pool2d(convolved, 3)
        activated = F.relu(pooled)
        # Collapse everything except the feature width expected by ``fc``.
        flattened = activated.view(-1, 192)
        logits = self.fc(flattened)
        return F.log_softmax(logits, dim=1)
|
|
|
|
|
|
def train(model, optimizer, train_loader, device=torch.device("cpu")):
    """Run one truncated training pass over ``train_loader``.

    The model is put into training mode and updated batch by batch using
    the negative log-likelihood loss.  The pass ends early once
    ``batch index * batch length`` exceeds the module-level ``EPOCH_SIZE``.

    Args:
        model: Network to optimise; called as ``model(batch)``.
        optimizer: Optimiser whose ``zero_grad``/``step`` drive the update.
        train_loader: Iterable yielding ``(data, target)`` batches.
        device: Device that each batch is moved to before the forward pass.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # Cap the amount of data consumed per call.
        if step * len(inputs) > EPOCH_SIZE:
            break
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
|
|
|
|
|
|
def test(model, data_loader, device=torch.device("cpu")):
    """Measure classification accuracy on a truncated pass over ``data_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking.  Evaluation stops once ``batch index * batch length`` exceeds
    the module-level ``TEST_SIZE``.

    Args:
        model: Network to evaluate; called as ``model(batch)``.
        data_loader: Iterable yielding ``(data, target)`` batches.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        The fraction ``correct / total`` over the evaluated samples.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches at all.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for step, (inputs, labels) in enumerate(data_loader):
            # Cap the amount of data evaluated per call.
            if step * len(inputs) > TEST_SIZE:
                break
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            # Index of the highest score along dim 1 is the predicted class.
            predicted = torch.max(scores.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return n_correct / n_seen
|
|
|
|
|
|
def get_data_loaders():
    """Build shuffled MNIST train and test ``DataLoader`` objects.

    Both loaders use a batch size of 64 and the same transform pipeline
    (``ToTensor`` followed by ``Normalize`` with mean 0.1307 and std 0.3081).
    The training dataset is created with ``download=True``.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.1307, ), (0.3081, ))
    mnist_transforms = transforms.Compose([to_tensor, normalize])

    # Several workers may try to download the dataset at the same time;
    # holding a file lock while the datasets and loaders are created keeps
    # them from overwriting each other's files.
    lock_path = os.path.expanduser("~/data.lock")
    with FileLock(lock_path):
        train_set = datasets.MNIST(
            "~/data", train=True, download=True, transform=mnist_transforms)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=64, shuffle=True)

        test_set = datasets.MNIST(
            "~/data", train=False, transform=mnist_transforms)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=64, shuffle=True)
    return train_loader, test_loader
|
|
|
|
|
|
def train_mnist(config):
    """Trainable function: fit a ``ConvNet`` on MNIST and report accuracy.

    Args:
        config: Mapping that must contain ``"lr"`` and ``"momentum"`` for the
            SGD optimiser and may contain ``"use_gpu"``; CUDA is used only
            when that flag is truthy and CUDA is available.

    This function never returns on its own: it alternates a training pass
    and an evaluation pass forever, logging ``mean_accuracy`` through
    ``track.log`` after each round.
    """
    gpu_requested = config.get("use_gpu")
    if gpu_requested and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"])

    # Loop indefinitely; each round is one train pass plus one evaluation.
    while True:
        train(model, optimizer, train_loader, device)
        accuracy = test(model, test_loader, device)
        track.log(mean_accuracy=accuracy)
|
|
|
|
|
|
if __name__ == "__main__":
    # ---- Command-line options -------------------------------------------
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--cuda", action="store_true", default=False,
        help="Enables GPU training")
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--ray-address",
        help="Address of Ray cluster for seamless distributed execution.")
    args = parser.parse_args()

    # ---- Ray start-up ---------------------------------------------------
    # Without an address a local Ray instance is started (limited to two
    # CPUs for smoke tests); otherwise the given address is passed to Ray.
    if not args.ray_address:
        ray.init(num_cpus=2 if args.smoke_test else None)
    else:
        ray.init(address=args.ray_address)

    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", metric="mean_accuracy")

    # ---- Experiment settings, built in the order they are passed on ------
    stop_criteria = {
        "mean_accuracy": 0.98,
        "training_iteration": 5 if args.smoke_test else 100,
    }
    trial_resources = {
        "cpu": 2,
        "gpu": int(args.cuda),
    }
    sample_count = 1 if args.smoke_test else 50
    search_space = {
        # Learning rate is drawn as 10 ** (-10 * u) with u from np.random.rand().
        "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
        "momentum": tune.uniform(0.1, 0.9),
        "use_gpu": int(args.cuda),
    }

    analysis = tune.run(
        train_mnist,
        name="exp",
        scheduler=sched,
        stop=stop_criteria,
        resources_per_trial=trial_resources,
        num_samples=sample_count,
        config=search_space)

    print("Best config is:", analysis.get_best_config(metric="mean_accuracy"))
|