mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 12:19:50 +08:00
0c3b9ebeef
Co-authored-by: krfricke <krfricke@users.noreply.github.com> Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
239 lines
7.5 KiB
Python
239 lines
7.5 KiB
Python
# flake8: noqa
|
|
# yapf: disable
|
|
|
|
# __import_begin__
|
|
from functools import partial
|
|
import numpy as np
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.utils.data import random_split
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from ray import tune
|
|
from ray.tune import CLIReporter
|
|
from ray.tune.schedulers import ASHAScheduler
|
|
# __import_end__
|
|
|
|
|
|
# __load_data_begin__
|
|
def load_data(data_dir="./data"):
    """Download (if needed) and return the CIFAR-10 train and test splits.

    Args:
        data_dir: Directory where torchvision stores / looks for CIFAR-10.

    Returns:
        A ``(trainset, testset)`` tuple of ``torchvision.datasets.CIFAR10``
        datasets that share the same preprocessing transform.
    """
    # ToTensor, then normalize every RGB channel with mean 0.5 / std 0.5.
    channel_stats = (0.5, 0.5, 0.5)
    preprocess = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(channel_stats, channel_stats)])

    def _split(is_train):
        # Build one CIFAR-10 split; download=True fetches it when missing.
        return torchvision.datasets.CIFAR10(
            root=data_dir, train=is_train, download=True, transform=preprocess)

    # Train split is constructed first, exactly as before.
    train_split = _split(True)
    test_split = _split(False)
    return train_split, test_split
|
|
# __load_data_end__
|
|
|
|
|
|
# __net_begin__
|
|
class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class scores for 3-channel images.

    Two conv -> ReLU -> 2x2 max-pool stages feed three fully connected
    layers. The flatten size ``16 * 5 * 5`` corresponds to 32x32 inputs.

    Args:
        l1: Width of the first hidden fully connected layer.
        l2: Width of the second hidden fully connected layer.
    """

    def __init__(self, l1=120, l2=84):
        super().__init__()
        # Feature extractor (registration order kept: conv1, pool, conv2).
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head.
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        """Return raw (unnormalized) class scores, one row of 10 per image."""
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))
        flat = x.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
|
|
# __net_end__
|
|
|
|
|
|
# __train_begin__
|
|
def train_cifar(config, checkpoint_dir=None, data_dir=None):
    """Ray Tune trainable: train ``Net`` on CIFAR-10 and report val metrics.

    80% of the CIFAR-10 training set is used for training and the remaining
    20% for validation. After each epoch a ``(model_state, optimizer_state)``
    checkpoint is written via ``tune.checkpoint_dir`` and the validation
    ``loss`` / ``accuracy`` are reported with ``tune.report``.

    Args:
        config: Sampled hyperparameters; reads ``"l1"``, ``"l2"``, ``"lr"``
            and ``"batch_size"``.
        checkpoint_dir: Optional directory holding a ``checkpoint`` file with
            ``(model_state, optimizer_state)`` to resume from.
        data_dir: Directory forwarded to ``load_data``.
    """
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        # Split each batch across every visible GPU.
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only worker (and vice versa) instead of raising.
        model_state, optimizer_state = torch.load(
            checkpoint, map_location=device)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, _ = load_data(data_dir)

    # Hold out 20% of the training images as a validation set.
    train_size = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [train_size, len(trainset) - train_size])

    batch_size = int(config["batch_size"])
    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=8)
    # Validation order is irrelevant to the metrics, so don't shuffle.
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=8)

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero grads, forward + backward + optimize
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                # Reset BOTH accumulators so the next report averages only
                # the mini-batches seen since this one (previously only the
                # loss sum was reset, skewing every later average).
                running_loss = 0.0
                epoch_steps = 0

        # Validation loss / accuracy
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_steps += 1

        with tune.checkpoint_dir(step=epoch) as ckpt_dir:
            path = os.path.join(ckpt_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")
|
|
# __train_end__
|
|
|
|
|
|
# __test_acc_begin__
|
|
def test_accuracy(net, device="cpu", data_dir="./data"):
    """Return the top-1 accuracy of ``net`` on the CIFAR-10 test split.

    Args:
        net: Model mapping an image batch to per-class scores.
        device: Device the test batches are moved to; should match the
            device ``net`` lives on.
        data_dir: Directory forwarded to ``load_data``. Defaults to the
            previously hard-coded ``"./data"`` so existing callers behave
            the same.

    Returns:
        Fraction of test images whose arg-max prediction equals the label.
    """
    _, testset = load_data(data_dir)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    # Evaluation mode disables training-only behaviour (e.g. dropout,
    # batch-norm statistic updates) for networks that use it.
    net.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total
|
|
# __test_acc_end__
|
|
|
|
|
|
# __main_begin__
|
|
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    """Tune ``train_cifar`` hyperparameters with ASHA, then test the best net.

    Args:
        num_samples: Number of hyperparameter configurations Tune samples.
        max_num_epochs: ``max_t`` for the ASHA scheduler, i.e. the most
            training iterations a trial may run.
        gpus_per_trial: GPUs requested for each trial.
    """
    data_dir = os.path.abspath("./data")
    # Fetch the dataset once in the driver so trials find it already present.
    load_data(data_dir)
    config = {
        # Hidden-layer widths: powers of two from 2**2 to 2**8.
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    # parameter_columns=[...] may also be passed to restrict which
    # hyperparameters appear in the progress table.
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration"])
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        checkpoint_at_end=True)

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        # train_cifar wraps the net in DataParallel when a trial sees more
        # than one GPU, which prefixes state_dict keys with "module.".
        # NOTE(review): this assumes each trial sees exactly gpus_per_trial
        # devices; confirm if resource assignment changes.
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")

    # map_location allows GPU-saved tensors to load on a CPU-only driver.
    # Only the model weights are needed; the optimizer state is discarded.
    model_state, _ = torch.load(checkpoint_path, map_location=device)
    best_trained_model.load_state_dict(model_state)

    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))
|
|
# __main_end__
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: --smoke-test runs a single, one-epoch,
    # CPU-only trial; otherwise a full 10-sample search is launched.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    cli_args, _unknown = arg_parser.parse_known_args()

    if cli_args.smoke_test:
        run_kwargs = dict(num_samples=1, max_num_epochs=1, gpus_per_trial=0)
    else:
        # Change this to activate training on GPUs
        run_kwargs = dict(num_samples=10, max_num_epochs=10, gpus_per_trial=0)
    main(**run_kwargs)
|