mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:38:16 +08:00
[raysgd] Custom training operator (#7211)
This commit is contained in:
@@ -3,6 +3,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
PyTorchTrainer = None
|
||||
PyTorchTrainable = None
|
||||
TrainingOperator = None
|
||||
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
@@ -10,6 +11,8 @@ try:
|
||||
from ray.util.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
|
||||
PyTorchTrainable)
|
||||
|
||||
__all__ = ["PyTorchTrainer", "PyTorchTrainable"]
|
||||
from ray.util.sgd.pytorch.training_operator import TrainingOperator
|
||||
|
||||
__all__ = ["PyTorchTrainer", "PyTorchTrainable", "TrainingOperator"]
|
||||
except ImportError:
|
||||
logger.warning("PyTorch not found. PyTorchTrainer will not be available")
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# Keys that the runner writes into the ``info`` dict handed to the training
# operator (see ``PyTorchRunner.train_epoch``).
USE_FP16 = "__use_fp16__"
SCHEDULER_STEP = "scheduler_step"

# NOTE(review): presumably the key under which the per-epoch batch counter
# is reported -- its use is not visible from here, confirm against the
# training operator.
BATCH_COUNT = "batch_count"

# Accepted values for the scheduler step frequency.
SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH = "batch", "epoch"

# Set used to validate a user-supplied ``scheduler_step_freq``.
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
|
||||
@@ -95,15 +95,22 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
val_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
def step(self):
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=self.models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
schedulers=self.schedulers,
|
||||
use_fp16=self.use_fp16)
|
||||
|
||||
def train_epoch(self, **kwargs):
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
|
||||
Automatically sets epoch of sampler if possible.
|
||||
"""
|
||||
logger.debug("Starting step")
|
||||
if hasattr(self.train_loader.sampler, "set_epoch"):
|
||||
self.train_loader.sampler.set_epoch(self.epoch)
|
||||
return super(DistributedPyTorchRunner, self).step()
|
||||
self.train_loader.sampler.set_epoch(self.epochs)
|
||||
return super(DistributedPyTorchRunner, self).train_epoch(**kwargs)
|
||||
|
||||
def _get_model_state_dicts(self):
|
||||
"""Fetch state from ``model.module`` instead of ``model``.
|
||||
|
||||
@@ -10,10 +10,9 @@ import torchvision.transforms as transforms
|
||||
import ray
|
||||
from ray.util.sgd.pytorch import (PyTorchTrainer, PyTorchTrainable)
|
||||
from ray.util.sgd.pytorch.resnet import ResNet18
|
||||
from ray.util.sgd.pytorch.utils import TEST_MODE
|
||||
|
||||
|
||||
def initialization_hook(runner):
|
||||
def initialization_hook():
|
||||
print("NCCL DEBUG SET")
|
||||
# Need this for avoiding a connection restart issue
|
||||
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
|
||||
@@ -40,6 +39,11 @@ def cifar_creator(config):
|
||||
validation_dataset = torchvision.datasets.CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
|
||||
if config.get("test_mode"):
|
||||
train_dataset = torch.utils.data.Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = torch.utils.data.Subset(validation_dataset,
|
||||
list(range(64)))
|
||||
|
||||
return train_dataset, validation_dataset
|
||||
|
||||
|
||||
@@ -58,7 +62,6 @@ def train_example(num_replicas=1,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
test_mode=False):
|
||||
config = {TEST_MODE: test_mode}
|
||||
trainer1 = PyTorchTrainer(
|
||||
ResNet18,
|
||||
cifar_creator,
|
||||
@@ -67,7 +70,10 @@ def train_example(num_replicas=1,
|
||||
scheduler_creator=scheduler_creator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_replicas=num_replicas,
|
||||
config=config,
|
||||
config={
|
||||
"lr": 0.01,
|
||||
"test_mode": test_mode
|
||||
},
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo",
|
||||
@@ -88,14 +94,14 @@ def tune_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
"model_creator": ResNet18,
|
||||
"data_creator": cifar_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.CrossEntropyLoss(),
|
||||
"loss_creator": nn.CrossEntropyLoss,
|
||||
"num_replicas": num_replicas,
|
||||
"initialization_hook": initialization_hook,
|
||||
"use_gpu": use_gpu,
|
||||
"batch_size": 16 if test_mode else 512,
|
||||
"config": {
|
||||
"lr": tune.choice([1e-4, 1e-3, 5e-3, 1e-2]),
|
||||
TEST_MODE: test_mode
|
||||
"lr": tune.choice([1e-4, 1e-3]),
|
||||
"test_mode": test_mode
|
||||
},
|
||||
"backend": "nccl" if use_gpu else "gloo"
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
|
||||
@@ -16,25 +16,12 @@ from scipy.stats import entropy
|
||||
|
||||
import ray
|
||||
from ray.util.sgd import PyTorchTrainer
|
||||
from ray.util.sgd.pytorch.utils import TEST_MODE
|
||||
|
||||
# Training parameters
|
||||
TRAIN_BATCHES = 5
|
||||
# Number of channels in the training images. For color images this is 3
|
||||
num_channels = 1
|
||||
|
||||
# Size of z latent vector (i.e. size of generator input)
|
||||
latent_vector_size = 100
|
||||
|
||||
# Size of feature maps in generator
|
||||
features_g = 32
|
||||
|
||||
# Size of feature maps in discriminator
|
||||
features_d = 32
|
||||
from ray.util.sgd.utils import override
|
||||
from ray.util.sgd.pytorch import TrainingOperator
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
return dset.MNIST(
|
||||
dataset = datasets.MNIST(
|
||||
root="~/mnist/",
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
@@ -42,62 +29,56 @@ def data_creator(config):
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, ), (0.5, )),
|
||||
]))
|
||||
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
if config.get("test_mode"):
|
||||
dataset = torch.utils.data.Subset(dataset, list(range(64)))
|
||||
return dataset
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, latent_vector_size, features=32, num_channels=1):
|
||||
super(Generator, self).__init__()
|
||||
self.latent_vector_size = latent_vector_size
|
||||
self.main = nn.Sequential(
|
||||
# input is Z, going into a convolution
|
||||
nn.ConvTranspose2d(
|
||||
latent_vector_size, features_g * 4, 4, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(features_g * 4),
|
||||
latent_vector_size, features * 4, 4, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(features * 4),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 4, features_g * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_g * 2),
|
||||
features * 4, features * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features * 2),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(
|
||||
features_g * 2, features_g, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_g),
|
||||
nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features),
|
||||
nn.ReLU(True),
|
||||
nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False),
|
||||
nn.ConvTranspose2d(features, num_channels, 4, 2, 1, bias=False),
|
||||
nn.Tanh())
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input)
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, features=32, num_channels=1):
|
||||
super(Discriminator, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
nn.Conv2d(num_channels, features_d, 4, 2, 1, bias=False),
|
||||
nn.Conv2d(num_channels, features, 4, 2, 1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_d * 2), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features_d * 4), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features_d * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
|
||||
nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features * 2), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(features * 4), nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input)
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
class LeNet(nn.Module):
|
||||
"""LeNet for MNist classification, used for inception_score."""
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
||||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
||||
self.conv2_drop = nn.Dropout2d()
|
||||
@@ -114,92 +95,22 @@ class Net(nn.Module):
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
def inception_score(imgs, batch_size=32, splits=1):
|
||||
N = len(imgs)
|
||||
dtype = torch.FloatTensor
|
||||
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
|
||||
cm = Net()
|
||||
cm.load_state_dict(torch.load(model_path))
|
||||
cm.eval()
|
||||
up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)
|
||||
|
||||
def get_pred(x):
|
||||
x = up(x)
|
||||
x = cm(x)
|
||||
return F.softmax(x).data.cpu().numpy()
|
||||
|
||||
preds = np.zeros((N, 10))
|
||||
for i, batch in enumerate(dataloader, 0):
|
||||
batch = batch.type(dtype)
|
||||
batchv = Variable(batch)
|
||||
batch_size_i = batch.size()[0]
|
||||
preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
|
||||
|
||||
# Now compute the mean kl-div
|
||||
split_scores = []
|
||||
for k in range(splits):
|
||||
part = preds[k * (N // splits):(k + 1) * (N // splits), :]
|
||||
py = np.mean(part, axis=0)
|
||||
scores = []
|
||||
for i in range(part.shape[0]):
|
||||
pyx = part[i, :]
|
||||
scores.append(entropy(pyx, py))
|
||||
split_scores.append(np.exp(np.mean(scores)))
|
||||
|
||||
return np.mean(split_scores), np.std(split_scores)
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
netD = Discriminator()
|
||||
netD.apply(weights_init)
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
||||
elif classname.find("BatchNorm") != -1:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0)
|
||||
|
||||
netG = Generator()
|
||||
netG.apply(weights_init)
|
||||
return netD, netG
|
||||
discriminator = Discriminator()
|
||||
discriminator.apply(weights_init)
|
||||
|
||||
|
||||
def train(config, models, dataloader, criterion, optimizers, **kwargs):
|
||||
netD, netG = models
|
||||
optimD, optimG = optimizers
|
||||
real_label = 1
|
||||
fake_label = 0
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
for i, data in enumerate(dataloader, 0):
|
||||
if i >= TRAIN_BATCHES and config.get(TEST_MODE):
|
||||
break
|
||||
|
||||
netD.zero_grad()
|
||||
real_cpu = data[0].to(device)
|
||||
b_size = real_cpu.size(0)
|
||||
label = torch.full((b_size, ), real_label, device=device)
|
||||
output = netD(real_cpu).view(-1)
|
||||
errD_real = criterion(output, label)
|
||||
errD_real.backward()
|
||||
|
||||
noise = torch.randn(b_size, latent_vector_size, 1, 1, device=device)
|
||||
fake = netG(noise)
|
||||
label.fill_(fake_label)
|
||||
output = netD(fake.detach()).view(-1)
|
||||
errD_fake = criterion(output, label)
|
||||
errD_fake.backward()
|
||||
errD = errD_real + errD_fake
|
||||
optimD.step()
|
||||
|
||||
netG.zero_grad()
|
||||
label.fill_(real_label)
|
||||
output = netD(fake).view(-1)
|
||||
errG = criterion(output, label)
|
||||
errG.backward()
|
||||
optimG.step()
|
||||
|
||||
is_score, is_std = inception_score(fake)
|
||||
|
||||
return {
|
||||
"loss_g": errG.item(),
|
||||
"loss_d": errD.item(),
|
||||
"inception": is_score
|
||||
}
|
||||
generator = Generator(
|
||||
latent_vector_size=config.get("latent_vector_size", 100))
|
||||
generator.apply(weights_init)
|
||||
return discriminator, generator
|
||||
|
||||
|
||||
def optimizer_creator(models, config):
|
||||
@@ -211,22 +122,122 @@ def optimizer_creator(models, config):
|
||||
return discriminator_opt, generator_opt
|
||||
|
||||
|
||||
class GANOperator(TrainingOperator):
    """Training operator that performs one GAN update per batch.

    ``self.models`` is unpacked as ``(discriminator, generator)`` and
    ``self.optimizers`` as the matching ``(optimD, optimG)`` pair. Every
    batch reports both adversarial losses plus an inception-style score of
    the generated images, computed with a ``LeNet`` classifier.
    """

    def setup(self, config):
        """Select the training device and load the scoring classifier.

        Args:
            config (dict): Must contain ``"classification_model_path"``,
                the path of the saved ``LeNet`` state dict.
        """
        self.device = torch.device("cuda"
                                   if torch.cuda.is_available() else "cpu")

        self.classifier = LeNet()
        # The classifier is only ever fed CPU float tensors (see
        # ``inception_score``), so map the checkpoint onto the CPU. This
        # also lets a checkpoint saved from a GPU load on a CPU-only node.
        self.classifier.load_state_dict(
            torch.load(
                config["classification_model_path"], map_location="cpu"))
        self.classifier.eval()

    def inception_score(self, imgs, batch_size=32, splits=1):
        """Calculate the inception score of the generated images.

        Args:
            imgs: Indexable, sized collection of image tensors (here the
                generator output for one training batch).
            batch_size (int): Number of images classified at once.
            splits (int): Number of equally sized groups the predictions
                are divided into; one score is computed per group.

        Returns:
            Tuple ``(mean, std)`` of the per-split scores.
        """
        N = len(imgs)
        dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
        up = nn.Upsample(
            size=(28, 28), mode="bilinear").type(torch.FloatTensor)

        def get_pred(x):
            x = up(x)
            x = self.classifier(x)
            # The classifier output is (batch, classes); name the class
            # axis explicitly instead of relying on the deprecated
            # implicit ``dim``.
            return F.softmax(x, dim=1).detach().cpu().numpy()

        # Obtain predictions for the fake provided images. Scoring is pure
        # evaluation, so no autograd graph is recorded for it.
        preds = np.zeros((N, 10))
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                # ``type(torch.FloatTensor)`` also brings the batch to CPU.
                batch = batch.type(torch.FloatTensor)
                start = i * batch_size
                preds[start:start + batch.size(0)] = get_pred(batch)

        # Now compute the mean kl-div
        split_scores = []
        for k in range(splits):
            part = preds[k * (N // splits):(k + 1) * (N // splits), :]
            py = np.mean(part, axis=0)
            scores = [entropy(pyx, py) for pyx in part]
            split_scores.append(np.exp(np.mean(scores)))

        return np.mean(split_scores), np.std(split_scores)

    @override(TrainingOperator)
    def train_batch(self, batch, batch_info):
        """Trains on one batch of data from the data creator.

        Args:
            batch: Batch from the training dataloader; ``batch[0]`` holds
                the real images.
            batch_info: Per-batch metadata supplied by the caller (unused).

        Returns:
            dict with the generator loss (``loss_g``), the discriminator
            loss (``loss_d``), the inception score of the generated images
            (``inception``) and the batch size (``num_samples``).
        """
        real_label = 1
        fake_label = 0
        discriminator, generator = self.models
        optimD, optimG = self.optimizers

        # Compute a discriminator update for real images
        discriminator.zero_grad()
        real_cpu = batch[0].to(self.device)
        batch_size = real_cpu.size(0)
        # Request a float tensor explicitly: with an integer fill value
        # newer PyTorch releases infer an integer dtype, which criteria
        # expecting float targets (e.g. ``nn.BCELoss``) reject.
        label = torch.full(
            (batch_size, ),
            real_label,
            dtype=torch.float,
            device=self.device)
        output = discriminator(real_cpu).view(-1)
        errD_real = self.criterion(output, label)
        errD_real.backward()

        # Compute a discriminator update for fake images
        noise = torch.randn(
            batch_size,
            self.config.get("latent_vector_size", 100),
            1,
            1,
            device=self.device)
        fake = generator(noise)
        label.fill_(fake_label)
        # Detach so this backward pass does not reach the generator.
        output = discriminator(fake.detach()).view(-1)
        errD_fake = self.criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake

        # Update the discriminator
        optimD.step()

        # Update the generator: it is rewarded when the discriminator
        # labels its samples as real.
        generator.zero_grad()
        label.fill_(real_label)
        output = discriminator(fake).view(-1)
        errG = self.criterion(output, label)
        errG.backward()
        optimG.step()

        is_score, _ = self.inception_score(fake)

        return {
            "loss_g": errG.item(),
            "loss_d": errD.item(),
            "inception": is_score,
            "num_samples": batch_size
        }
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
config = {TEST_MODE: test_mode}
|
||||
config = {
|
||||
"test_mode": test_mode,
|
||||
"classification_model_path": os.path.join(
|
||||
os.path.dirname(ray.__file__),
|
||||
"util/sgd/pytorch/examples/mnist_cnn.pt")
|
||||
}
|
||||
trainer = PyTorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
nn.BCELoss,
|
||||
train_function=train,
|
||||
validation_function=False,
|
||||
training_operator_cls=GANOperator,
|
||||
num_replicas=num_replicas,
|
||||
config=config,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo")
|
||||
for i in range(10):
|
||||
stats = trainer.train(max_retries=3)
|
||||
for i in range(5):
|
||||
stats = trainer.train()
|
||||
print(stats)
|
||||
|
||||
return trainer
|
||||
@@ -240,7 +251,7 @@ if __name__ == "__main__":
|
||||
"--address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
help="the address to use to connect to a cluster.")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"-n",
|
||||
@@ -255,10 +266,6 @@ if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(address=args.address)
|
||||
|
||||
path = os.path.dirname(ray.__file__)
|
||||
model_path = os.path.join(path, "util/sgd/pytorch/examples/mnist_cnn.pt")
|
||||
# load the pretrained mnist classification model for inception_score
|
||||
|
||||
trainer = train_example(
|
||||
num_replicas=args.num_replicas,
|
||||
use_gpu=args.use_gpu,
|
||||
|
||||
@@ -2,13 +2,15 @@ import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.pytorch import utils as pytorch_utils
|
||||
from ray.util.sgd.pytorch.constants import USE_FP16, SCHEDULER_STEP
|
||||
from ray.util.sgd.pytorch.training_operator import TrainingOperator
|
||||
from ray.util.sgd import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,8 +33,7 @@ class PyTorchRunner:
|
||||
loss_creator (dict -> loss | Loss class): see pytorch_trainer.py.
|
||||
scheduler_creator (optimizers, dict -> schedulers): see
|
||||
pytorch_trainer.py.
|
||||
train_function: see pytorch_trainer.py
|
||||
validation_function: see pytorch_trainer.py
|
||||
training_operator_cls: see pytorch_trainer.py
|
||||
config (dict): see pytorch_trainer.py.
|
||||
dataloader_config (dict): See pytorch_trainer.py.
|
||||
batch_size (int): see pytorch_trainer.py.
|
||||
@@ -47,8 +48,7 @@ class PyTorchRunner:
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
scheduler_creator=None,
|
||||
train_function=None,
|
||||
validation_function=None,
|
||||
training_operator_cls=None,
|
||||
config=None,
|
||||
dataloader_config=None,
|
||||
batch_size=16,
|
||||
@@ -60,17 +60,15 @@ class PyTorchRunner:
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.scheduler_creator = scheduler_creator
|
||||
self.training_operator_cls = training_operator_cls or TrainingOperator
|
||||
self.config = {} if config is None else config
|
||||
self.dataloader_config = {
|
||||
"num_workers": 2
|
||||
} if dataloader_config is None else dataloader_config
|
||||
self.train_function = train_function or pytorch_utils.train
|
||||
self.validation_function = (validation_function
|
||||
or pytorch_utils.validate)
|
||||
self.batch_size = batch_size
|
||||
self.verbose = True
|
||||
|
||||
self.epoch = 0
|
||||
self.epochs = 0
|
||||
self._timers = {
|
||||
k: utils.TimerStat(window_size=1)
|
||||
for k in [
|
||||
@@ -160,6 +158,14 @@ class PyTorchRunner:
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
val_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=self.models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
schedulers=self.schedulers,
|
||||
use_fp16=self.use_fp16)
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
return ray.services.get_node_ip_address()
|
||||
@@ -168,47 +174,42 @@ class PyTorchRunner:
|
||||
"""Finds a free port on the current node."""
|
||||
return utils.find_free_port()
|
||||
|
||||
def step(self):
|
||||
def train_epoch(self, num_steps=None, info=None):
|
||||
"""Runs a training epoch and updates the model parameters."""
|
||||
logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
|
||||
training_config = self.config.copy()
|
||||
training_config.update({
|
||||
pytorch_utils.USE_FP16: self.use_fp16,
|
||||
pytorch_utils.SCHEDULER_STEP: self.scheduler_step_freq
|
||||
logger.debug("Begin Training Step {}".format(self.epochs + 1))
|
||||
info = info or {}
|
||||
info.update({
|
||||
USE_FP16: self.use_fp16,
|
||||
SCHEDULER_STEP: self.scheduler_step_freq
|
||||
})
|
||||
with self._timers["training"]:
|
||||
train_stats = self.train_function(
|
||||
training_config,
|
||||
self.given_models,
|
||||
self.train_loader,
|
||||
self.criterion,
|
||||
self.given_optimizers,
|
||||
scheduler=self.given_schedulers)
|
||||
train_stats["epoch"] = self.epoch
|
||||
|
||||
self.epoch += 1
|
||||
iterator = self.train_loader
|
||||
if num_steps:
|
||||
iterator = itertools.islice(iter(self.train_loader), num_steps)
|
||||
train_stats = self.training_operator.train_epoch(iterator, info)
|
||||
|
||||
self.epochs += 1
|
||||
train_stats.update(self.stats())
|
||||
return train_stats
|
||||
|
||||
def validate(self):
|
||||
def validate(self, num_steps=None, info=None):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
if self.validation_loader is None:
|
||||
raise ValueError("No validation dataloader provided.")
|
||||
info = info or {}
|
||||
with self._timers["validation"]:
|
||||
validation_stats = self.validation_function(
|
||||
self.config,
|
||||
self.given_models,
|
||||
self.validation_loader,
|
||||
self.criterion,
|
||||
scheduler=self.given_schedulers)
|
||||
iterator = self.validation_loader
|
||||
if num_steps:
|
||||
iterator = itertools.islice(
|
||||
iter(self.validation_loader), num_steps)
|
||||
validation_stats = self.training_operator.validate(iterator, info)
|
||||
|
||||
validation_stats.update(self.stats())
|
||||
return validation_stats
|
||||
|
||||
def stats(self):
|
||||
"""Returns a dictionary of statistics collected."""
|
||||
stats = {"epoch": self.epoch}
|
||||
stats = {"epoch": self.epochs}
|
||||
for k, t in self._timers.items():
|
||||
stats[k + "_time_mean"] = t.mean
|
||||
stats[k + "_time_total"] = t.sum
|
||||
@@ -233,7 +234,8 @@ class PyTorchRunner:
|
||||
"""Returns the state of the runner."""
|
||||
|
||||
state = {
|
||||
"epoch": self.epoch,
|
||||
"epoch": self.epochs,
|
||||
"operator": self.training_operator.state_dict(),
|
||||
"models": self._get_model_state_dicts(),
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers],
|
||||
"stats": self.stats()
|
||||
@@ -262,13 +264,18 @@ class PyTorchRunner:
|
||||
|
||||
if self.use_fp16 and "amp" in state and amp:
|
||||
amp.load_state_dict(state["amp"])
|
||||
self.epoch = state["stats"]["epoch"]
|
||||
self.epochs = state["stats"]["epoch"]
|
||||
self.training_operator.load_state_dict(state_dict)
|
||||
|
||||
def apply_fn(self, fn):
|
||||
return fn(self)
|
||||
def apply(self, fn):
    """Call ``fn`` with no arguments on this runner and return its result."""
    result = fn()
    return result
|
||||
|
||||
def apply_operator(self, fn):
    """Call ``fn`` on this runner's training operator and return the result.

    Args:
        fn (Callable): Function taking ``self.training_operator`` as its
            only argument.
    """
    operator = self.training_operator
    return fn(operator)
|
||||
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
del self.training_operator
|
||||
del self.validation_loader
|
||||
del self.train_loader
|
||||
del self.criterion
|
||||
|
||||
@@ -15,12 +15,20 @@ from ray.util.sgd.pytorch.distributed_pytorch_runner import (
|
||||
DistributedPyTorchRunner)
|
||||
from ray.util.sgd import utils
|
||||
from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
from ray.util.sgd.pytorch import utils as pytorch_utils
|
||||
from ray.util.sgd.pytorch.constants import VALID_SCHEDULER_STEP
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RESIZE_COOLDOWN_S = 10
|
||||
|
||||
|
||||
def _validate_scheduler_step_freq(scheduler_step_freq):
    """Reject a scheduler step frequency that is not a recognised value.

    Falsy values (such as ``None``) are accepted without any check.

    Args:
        scheduler_step_freq: The value to validate.

    Raises:
        ValueError: If a truthy value is not in ``VALID_SCHEDULER_STEP``.
    """
    if not scheduler_step_freq:
        return
    if scheduler_step_freq in VALID_SCHEDULER_STEP:
        return
    message = "Scheduler step freq must be in {}. Got {}".format(
        VALID_SCHEDULER_STEP, scheduler_step_freq)
    raise ValueError(message)
|
||||
|
||||
|
||||
class PyTorchTrainer:
|
||||
"""Train a PyTorch model using distributed PyTorch.
|
||||
|
||||
@@ -48,14 +56,15 @@ class PyTorchTrainer:
|
||||
loss_creator=nn.MSELoss,
|
||||
use_gpu=True
|
||||
)
|
||||
trainer.train()
|
||||
for i in range(4):
|
||||
trainer.train()
|
||||
|
||||
|
||||
Args:
|
||||
model_creator (dict -> Model(s)): Constructor function that takes in
|
||||
config and returns the model(s) to be optimized. These must be
|
||||
``torch.nn.Module`` objects. If multiple models are returned,
|
||||
a ``train_function`` must be specified. You do not need to
|
||||
a ``training_operator_cls`` must be specified. You do not need to
|
||||
handle GPU/devices in this function; RaySGD will do that under
|
||||
the hood.
|
||||
data_creator (dict -> Dataset(s)): Constructor function
|
||||
@@ -75,22 +84,18 @@ class PyTorchTrainer:
|
||||
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
|
||||
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
|
||||
scheduler_creator (optimizers, dict -> scheduler):
|
||||
A constructor function for the scheduler loss. This is
|
||||
A constructor function for the torch scheduler. This is
|
||||
a function that takes in the generated optimizers (from
|
||||
``optimizer_creator``) provided config for customization.
|
||||
Be sure to set ``scheduler_step_freq`` to increment the
|
||||
scheduler correctly.
|
||||
train_function: Custom function for training. This function
|
||||
will be executed in parallel across all workers at once. The
|
||||
function needs to take in (models, train_dataloader, criterion,
|
||||
optimizers, config), and return a dict of training stats.
|
||||
validation_function: Custom function for validation. This function
|
||||
will be executed in parallel across all workers at once.
|
||||
This takes in (model, val_dataloader, criterion, config)
|
||||
and returns a dict of validation stats.
|
||||
training_operator_cls (type): Custom training operator class
|
||||
that subclasses the TrainingOperator class. This class
|
||||
will be copied onto all remote workers and used to specify
|
||||
custom training and validation operations. Defaults to
|
||||
TrainingOperator.
|
||||
config (dict): Custom configuration value to be passed to
|
||||
"model_creator", "data_creator", "optimizer_creator", and
|
||||
"loss_creator".
|
||||
all creator and operator constructors.
|
||||
dataloader_config (dict): Configuration values to be passed into
|
||||
the ``torch.utils.data.DataLoader`` object that wraps
|
||||
the dataset on each parallel worker for both training
|
||||
@@ -130,8 +135,7 @@ class PyTorchTrainer:
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
scheduler_creator=None,
|
||||
train_function=None,
|
||||
validation_function=None,
|
||||
training_operator_cls=None,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
dataloader_config=None,
|
||||
@@ -151,11 +155,10 @@ class PyTorchTrainer:
|
||||
|
||||
self.model_creator = model_creator
|
||||
self.data_creator = data_creator
|
||||
self.train_function = train_function
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.scheduler_creator = scheduler_creator
|
||||
self.validation_function = validation_function
|
||||
self.training_operator_cls = training_operator_cls
|
||||
self.initialization_hook = initialization_hook
|
||||
self.config = {} if config is None else config
|
||||
self.dataloader_config = dataloader_config
|
||||
@@ -166,6 +169,8 @@ class PyTorchTrainer:
|
||||
|
||||
logger.info("Using {} as backend.".format(backend))
|
||||
self.backend = backend
|
||||
|
||||
# TODO: Have an auto "use_gpu" option to detect and use GPUs.
|
||||
self.use_gpu = use_gpu
|
||||
self.batch_size = batch_size
|
||||
self.max_replicas = num_replicas
|
||||
@@ -180,12 +185,7 @@ class PyTorchTrainer:
|
||||
self._num_failures = 0
|
||||
self._last_resize = float("-inf")
|
||||
|
||||
if scheduler_step_freq and (
|
||||
scheduler_step_freq not in pytorch_utils.VALID_SCHEDULER_STEP):
|
||||
raise ValueError(
|
||||
"Scheduler step freq must be in {}. Got {}".format(
|
||||
pytorch_utils.VALID_SCHEDULER_STEP, scheduler_step_freq))
|
||||
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
self._start_workers(self.max_replicas)
|
||||
@@ -204,8 +204,7 @@ class PyTorchTrainer:
|
||||
self.optimizer_creator,
|
||||
self.loss_creator,
|
||||
self.scheduler_creator,
|
||||
train_function=self.train_function,
|
||||
validation_function=self.validation_function,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=self.config,
|
||||
dataloader_config=self.dataloader_config,
|
||||
batch_size=self.batch_size,
|
||||
@@ -243,8 +242,7 @@ class PyTorchTrainer:
|
||||
self.loss_creator,
|
||||
self.scheduler_creator,
|
||||
backend=self.backend,
|
||||
train_function=self.train_function,
|
||||
validation_function=self.validation_function,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=self.config,
|
||||
dataloader_config=self.dataloader_config,
|
||||
batch_size=batch_size_per_replica,
|
||||
@@ -266,21 +264,35 @@ class PyTorchTrainer:
|
||||
for i, worker in enumerate(self.workers)
|
||||
])
|
||||
|
||||
def train(self, max_retries=0, checkpoint="auto"):
|
||||
def train(self,
|
||||
num_steps=None,
|
||||
max_retries=0,
|
||||
checkpoint="auto",
|
||||
info=None):
|
||||
"""Runs a training epoch.
|
||||
|
||||
Runs an average over all values returned from workers. Set
|
||||
`max_retries` to enable fault handling in case of instance preemption.
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of batches to compute update steps on.
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.train_batch`` is called.
|
||||
max_retries (int): Must be non-negative. If set to N, will
|
||||
kill all current workers, query the Ray global state for
|
||||
total available resources, and re-launch up to the
|
||||
available resources. Behavior is not well-defined
|
||||
in case of shared cluster usage.
|
||||
checkpoint (str): Path to checkpoint to restore from if retrying.
|
||||
If max_retries is set and checkpoint == "auto", PyTorchTrainer
|
||||
will save a checkpoint before starting to train.
|
||||
If max_retries is set and ``checkpoint == "auto"``,
|
||||
PyTorchTrainer will save a checkpoint before starting to train.
|
||||
info (dict): Optional dictionary passed to the training
|
||||
operator for ``train_epoch`` and ``train_batch``.
|
||||
|
||||
Returns:
|
||||
A dictionary of metrics for training.
|
||||
You can provide custom metrics by passing in a custom
|
||||
``training_operator_cls``.
|
||||
"""
|
||||
assert max_retries >= 0, "`max_retries` must be non-negative."
|
||||
if max_retries:
|
||||
@@ -296,7 +308,8 @@ class PyTorchTrainer:
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
|
||||
with self.optimizer_timer:
|
||||
success, worker_stats = self._train_step()
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, info=info)
|
||||
# Fault handling
|
||||
for i in range(max_retries):
|
||||
if success:
|
||||
@@ -306,7 +319,8 @@ class PyTorchTrainer:
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
logger.info("Retrying training step with %d workers." % len(
|
||||
self.workers))
|
||||
success, worker_stats = self._train_step()
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, info=info)
|
||||
if not success:
|
||||
raise RuntimeError("Training run failed.")
|
||||
|
||||
@@ -321,19 +335,58 @@ class PyTorchTrainer:
|
||||
train_stats[stat_key] = worker_stats[0][stat_key]
|
||||
return train_stats
|
||||
|
||||
def _train_step(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
def _train_epoch(self, num_steps=None, info=None):
    """Launch one ``train_epoch`` call on every worker.

    Args:
        num_steps (int): Forwarded unchanged to each worker's
            ``train_epoch``.
        info (dict): Forwarded unchanged to each worker's ``train_epoch``.

    Returns:
        A ``(success, worker_stats)`` tuple: ``success`` is the result of
        ``utils.check_for_failure`` on the launched calls, and
        ``worker_stats`` is the list of values returned by ``.remote()``,
        one per worker, in worker order.
    """
    pending = []
    for worker in self.workers:
        pending.append(
            worker.train_epoch.remote(num_steps=num_steps, info=info))
    succeeded = utils.check_for_failure(pending)
    return succeeded, pending
|
||||
|
||||
def apply_all_workers(self, fn):
    """Run a function on every worker.

    Args:
        fn (Callable): A function that takes in no arguments.

    Returns:
        A list of objects returned by ``fn`` on each worker.

    """
    return ray.get([w.apply.remote(fn) for w in self.workers])
|
||||
|
||||
def apply_all_operators(self, fn):
    """Call ``fn`` with the training operator of each worker.

    Args:
        fn (Callable[TrainingOperator]): A function that takes in a
            TrainingOperator.

    Returns:
        A list with the object returned by ``fn`` for each operator, in
        worker order.

    """
    pending = [worker.apply_operator.remote(fn) for worker in self.workers]
    return ray.get(pending)
|
||||
|
||||
def validate(self, num_steps=None, info=None):
|
||||
"""Evaluates the model on the validation data set.
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of batches to compute update steps on.
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.validate_batch`` is called.
|
||||
info (dict): Optional dictionary passed to the training
|
||||
operator for `validate` and `validate_batch`.
|
||||
|
||||
Returns:
|
||||
A dictionary of metrics for validation.
|
||||
You can provide custom metrics by passing in a custom
|
||||
``training_operator_cls``.
|
||||
"""
|
||||
worker_stats = ray.get([
|
||||
w.validate.remote(num_steps=num_steps, info=info)
|
||||
for w in self.workers
|
||||
])
|
||||
|
||||
validation_stats = {}
|
||||
for stat_key in worker_stats[0]:
|
||||
@@ -346,8 +399,8 @@ class PyTorchTrainer:
|
||||
|
||||
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
|
||||
"""
|
||||
self.apply_all_workers(
|
||||
lambda runner: [sched.step(metric) for sched in runner.schedulers])
|
||||
self.apply_all_operators(
|
||||
lambda op: [sched.step(metric) for sched in op.schedulers])
|
||||
|
||||
def get_model(self):
|
||||
"""Returns the learned model(s)."""
|
||||
@@ -366,17 +419,18 @@ class PyTorchTrainer:
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
|
||||
Returns:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
"""
|
||||
state = ray.get(self.workers[0].get_state.remote())
|
||||
torch.save(state, checkpoint)
|
||||
return checkpoint
|
||||
|
||||
def restore(self, checkpoint):
|
||||
"""Restores the model from the provided checkpoint.
|
||||
"""Restores the Trainer and all workers from the provided checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
|
||||
"""
|
||||
state = torch.load(checkpoint)
|
||||
state_id = ray.put(state)
|
||||
@@ -450,7 +504,6 @@ class PyTorchTrainable(Trainable):
|
||||
validation_stats = self._trainer.validate()
|
||||
|
||||
train_stats.update(validation_stats)
|
||||
|
||||
# output {"mean_loss": test_loss, "mean_accuracy": accuracy}
|
||||
return train_stats
|
||||
|
||||
|
||||
@@ -0,0 +1,343 @@
|
||||
import collections
|
||||
import torch
|
||||
|
||||
from ray.util.sgd.utils import TimerStat, AverageMeter
|
||||
from ray.util.sgd.pytorch.constants import (
|
||||
SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_BATCH, SCHEDULER_STEP, BATCH_COUNT)
|
||||
|
||||
amp = None
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
# Apex library is not installed, so we cannot enable mixed precision.
|
||||
# We don't log here because logging happens in the pytorch_runner,
|
||||
# where amp is initialized.
|
||||
pass
|
||||
|
||||
|
||||
def _is_multiple(component):
|
||||
"""Checks if a component (optimizer, model, etc) is not singular."""
|
||||
return isinstance(component, collections.Iterable) and len(component) > 1
|
||||
|
||||
|
||||
class TrainingOperator:
    """Abstract class for custom training or validation loops.

    The scheduler will only be called at a batch or epoch frequency, depending
    on the user parameter. Be sure to set ``scheduler_step_freq`` in
    ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler
    correctly during training. If using a learning rate scheduler
    that depends on validation loss, you can use ``trainer.update_scheduler``.

    For both training and validation, there are two granularities that
    you can provide customization: per epoch or per batch.
    You do not need to override both.

    .. image:: raysgd-custom.jpg
        :scale: 80%
        :align: center

    Raises:
        ValueError if multiple models/optimizers/schedulers are provided.
            You are expected to subclass this class if you wish
            to train over multiple models/optimizers/schedulers.
    """

    def __init__(self,
                 config,
                 models,
                 optimizers,
                 criterion,
                 schedulers=None,
                 use_fp16=False):
        # You are not expected to override this method.
        # ``collections.Iterable`` was removed in Python 3.10; use the ABC
        # from ``collections.abc`` instead.
        from collections.abc import Iterable

        self.timers = {
            k: TimerStat()
            for k in ["fwd", "grad", "apply", "epoch_time"]
        }
        self._validated_customization = False
        self._models = models  # List of models
        assert isinstance(models, Iterable), (
            "Components need to be iterable. Got: {}".format(type(models)))
        self._optimizers = optimizers  # List of optimizers
        assert isinstance(optimizers, Iterable), (
            "Components need to be iterable. Got: {}".format(type(optimizers)))
        self._criterion = criterion
        self._schedulers = schedulers
        if schedulers:
            assert isinstance(schedulers, Iterable), (
                "Components need to be iterable. Got: {}".format(
                    type(schedulers)))
        self._config = config
        self._use_fp16 = use_fp16
        self.global_step = 0

        # Only the base class is restricted to single components; subclasses
        # are expected to handle multiple models/optimizers/schedulers.
        if type(self) is TrainingOperator:
            for component in (models, schedulers, optimizers):
                if _is_multiple(component):
                    raise ValueError(
                        "Need to provide a custom operator subclassing "
                        "TrainingOperator if using multi-scheduler, "
                        "multi-model or multi-optimizer training/validation.")

        self.setup(config)

    def setup(self, config):
        """Override this method to implement custom operator setup.

        Args:
            config (dict): Custom configuration value to be passed to
                all creator and operator constructors. Same as ``self.config``.
        """
        pass

    def train_epoch(self, iterator, info):
        """Runs one standard training pass over the train_iterator.

        By default, this method will iterate over the given iterator and
        call ``self.train_batch`` over each batch.

        If ``scheduler_step_freq`` is set, this class will also step the
        scheduler accordingly.

        You do not need to call ``train_batch`` in this method if you plan
        to implement a custom optimization/training routine here.

        Args:
            iterator (iter): Iterator over the training data for the entire
                epoch. This iterator is expected to be entirely consumed.
            info (dict): Dictionary for information to be used for custom
                training operations. ``None`` is treated as an empty dict.

        Returns:
            A dict of metrics from training.
        """
        if info is None:
            info = {}
        self._losses = AverageMeter()

        self.model.train()
        # Counted explicitly so that an empty iterator reports 0 batches
        # instead of failing on an unbound loop variable.
        num_batches = 0
        with self.timers["epoch_time"]:
            for batch_idx, batch in enumerate(iterator):
                batch_info = {
                    "batch_idx": batch_idx,
                    "global_step": self.global_step
                }
                batch_info.update(info)
                metrics = self.train_batch(batch, batch_info=batch_info)

                if self.scheduler and batch_info.get(
                        SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
                    self.scheduler.step()

                if "loss" in metrics:
                    self._losses.update(
                        metrics["loss"], n=metrics.get("num_samples", 1))
                self.global_step += 1
                num_batches += 1

        if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
            self.scheduler.step()

        stats = {
            BATCH_COUNT: num_batches,
            "mean_train_loss": self._losses.avg,
            "last_train_loss": self._losses.val,
            "epoch_time": self.timers["epoch_time"].last
        }
        # NOTE(review): ``self.timers`` also has an "epoch_time" entry, so the
        # update below replaces the ``.last`` value set above with the timer
        # mean. Kept as-is to preserve the reported value -- confirm intent.
        stats.update({
            timer_tag: timer.mean
            for timer_tag, timer in self.timers.items()
        })
        return stats

    def train_batch(self, batch, batch_info):
        """Computes loss and updates the model over one batch.

        This method is responsible for computing the loss and gradient and
        updating the model.

        By default, this method implementation assumes that batches
        are in (features, labels) format. If using amp/fp16
        training, it will also scale the loss automatically.

        You can provide custom loss metrics and training operations if you
        override this method. If overriding this method, you can access model,
        optimizer, criterion via ``self.model``, ``self.optimizer``,
        and ``self.criterion``.

        You do not need to override this method if you plan to
        override ``train_epoch``.

        Args:
            batch: One item of the training iterator.
            batch_info (dict): Information dict passed in from ``train_epoch``.

        Returns:
            A dictionary of metrics.
                By default, this dictionary contains "loss" and "num_samples".
                "num_samples" corresponds to number of datapoints in the batch.
                However, you can provide any number of other values.

        """
        features, target = batch
        # Create non_blocking tensors for distributed training
        if torch.cuda.is_available():
            features = features.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # Compute output.
        with self.timers["fwd"]:
            output = self.model(features)
            loss = self.criterion(output, target)

        # Compute gradients in a backward pass.
        with self.timers["grad"]:
            self.optimizer.zero_grad()
            if self.use_fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        # Call step of optimizer to update model params.
        with self.timers["apply"]:
            self.optimizer.step()
        return {"loss": loss.item(), "num_samples": features.size(0)}

    def validate(self, val_iterator, info):
        """Runs one standard validation pass over the val_iterator.

        This will call ``model.eval()`` and ``torch.no_grad`` when iterating
        over the validation dataset.

        If overriding this method, you can access model, criterion via
        ``self.model`` and ``self.criterion``. You also do not need to call
        ``validate_batch`` if overriding this method.

        Args:
            val_iterator (iter): Iterable constructed over the
                validation dataset.
            info: (dict): Dictionary for information to be used for custom
                validation operations. ``None`` is treated as an empty dict.

        Returns:
            A dict of metrics from the evaluation.
                By default, returns "mean_accuracy" and "mean_validation_loss"
                which is computed by aggregating "loss" and "correct" values
                from ``validate_batch`` and dividing it by the sum of
                ``num_samples`` from all calls to ``self.validate_batch``.
                "mean_accuracy" is 0.0 when no loss samples were recorded.
        """
        if info is None:
            info = {}
        losses = AverageMeter()
        total_correct = 0
        # Counted explicitly so that an empty iterator reports 0 batches
        # instead of failing on an unbound loop variable.
        num_batches = 0

        # switch to evaluate mode
        self.model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_iterator):
                batch_info = {"batch_idx": batch_idx}
                batch_info.update(info)
                metrics = self.validate_batch(batch, batch_info)
                if "loss" in metrics:
                    losses.update(
                        metrics["loss"], n=metrics.get("num_samples", 1))

                if "num_correct" in metrics:
                    total_correct += metrics["num_correct"]
                num_batches += 1

        # Guard against an empty iterator or batches that report no "loss",
        # either of which leaves ``losses.count`` at zero.
        if losses.count:
            mean_accuracy = total_correct / losses.count
        else:
            mean_accuracy = 0.0

        stats = {
            BATCH_COUNT: num_batches,
            "mean_validation_loss": losses.avg,
            "mean_accuracy": mean_accuracy
        }
        return stats

    def validate_batch(self, batch, batch_info):
        """Calculates the loss and accuracy over a given batch.

        You can override this method to provide arbitrary metrics.

        Args:
            batch: One item of the validation iterator.
            batch_info (dict): Contains information per batch from
                ``validate()``.

        Returns:
            A dict of metrics.
                By default, returns "loss", "num_correct", and "num_samples".
        """
        features, target = batch
        if torch.cuda.is_available():
            features = features.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output
        output = self.model(features)
        loss = self.criterion(output, target)
        _, predicted = torch.max(output.data, 1)

        return {
            "loss": loss.item(),
            "num_correct": (predicted == target).sum().item(),
            "num_samples": target.size(0)
        }

    def state_dict(self):
        """Returns a serializable representation of the operator state."""
        pass

    def load_state_dict(self, state_dict):
        """Loads a serializable representation of the operator state."""
        pass

    @property
    def config(self):
        """Dictionary as provided into PyTorchTrainer."""
        return self._config

    @property
    def model(self):
        """First or only model created by the provided ``model_creator``."""
        return self._models[0]

    @property
    def models(self):
        """List of models created by the provided ``model_creator``."""
        return self._models

    @property
    def optimizer(self):
        """First or only optimizer(s) created by the ``optimizer_creator``."""
        return self._optimizers[0]

    @property
    def optimizers(self):
        """List of optimizers created by the ``optimizer_creator``."""
        return self._optimizers

    @property
    def criterion(self):
        """Criterion created by the provided ``loss_creator``."""
        return self._criterion

    @property
    def scheduler(self):
        """First or only scheduler(s) created by the ``scheduler_creator``.

        Returns ``None`` when no schedulers were provided.
        """
        if self._schedulers:
            return self._schedulers[0]
        return None

    @property
    def schedulers(self):
        """List of schedulers created by the ``scheduler_creator``."""
        return self._schedulers

    @property
    def use_fp16(self):
        """Whether the model and optimizer have been FP16 enabled."""
        return self._use_fp16
|
||||
|
||||
|
||||
class _TestingOperator(TrainingOperator):
    """Operator used by tests to inject a custom ``train_epoch`` routine."""

    def train_epoch(self, iterator, info):
        """Delegate to ``config["custom_func"]`` when it is callable.

        The custom function is called as ``func(self, iterator, info)`` and
        its return value is returned. Otherwise ``{"done": 1}`` is returned.
        """
        custom_epoch = self.config.get("custom_func")
        if not callable(custom_epoch):
            return {"done": 1}
        return custom_epoch(self, iterator, info)
|
||||
@@ -1,229 +0,0 @@
|
||||
import collections
|
||||
import time
|
||||
import torch
|
||||
|
||||
from ray.util.sgd.utils import TimerStat
|
||||
|
||||
amp = None
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
# Apex library is not installed, so we cannot enable mixed precision.
|
||||
# We don't log here because logging happens in the pytorch_runner,
|
||||
# where amp is initialized.
|
||||
pass
|
||||
|
||||
# Config key: when truthy, ``train`` scales the loss through apex ``amp``.
USE_FP16 = "__use_fp16__"
# Config key: when truthy, ``train``/``validate`` stop after the first batch.
TEST_MODE = "__test_mode__"
# Key under which the returned stats dicts report ``batch_idx + 1``.
BATCH_COUNT = "batch_processed"
# Config key selecting when ``train`` steps the scheduler.
SCHEDULER_STEP = "scheduler_step"
# Values compared against ``config[SCHEDULER_STEP]``: step the scheduler
# after every batch, or once after the whole pass.
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"

# The two scheduler step frequencies defined above.
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
|
||||
|
||||
|
||||
def train(config, model, train_iterator, criterion, optimizer, scheduler=None):
    """Runs one standard training pass over the train_iterator.

    This function automatically measures timing for various operations such
    as host to device transfer, gradient calculation, and gradient application.

    It also automatically detects and places the data on the given GPU device
    if available.

    The scheduler will only be called at a batch or epoch frequency, depending
    on the user parameter. Be sure to set ``scheduler_step_freq`` in
    ``PyTorchTrainer`` to either "batch" or "epoch" to increment the scheduler
    correctly during training. If using a learning rate scheduler
    that depends on validation loss, you can use ``trainer.update_scheduler``.

    Raises:
        ValueError if multiple models/optimizers/schedulers are provided. You
            are expected to have a custom training function if you wish
            to use multiple models/optimizers/schedulers.

    Args:
        config: (dict): A user configuration provided into the Trainer
            constructor.
        model: The model as created by the model_creator.
        train_iterator: An iterator created from the DataLoader which
            wraps the provided Dataset.
        criterion: The loss object created by the loss_creator.
        optimizer: The torch.optim.Optimizer object as created by the
            optimizer_creator.
        scheduler (optional): The torch.optim.lr_scheduler object
            as created by the scheduler_creator. Be sure to set
            ``scheduler_step_freq`` in ``PyTorchTrainer``
            to increment the scheduler correctly.

    Returns:
        A dict of metrics from training.
    """
    # ``collections.Iterable`` was removed in Python 3.10; use the ABC from
    # ``collections.abc`` instead.
    from collections.abc import Iterable

    if any(
            isinstance(component, Iterable)
            for component in (model, optimizer, scheduler)):
        raise ValueError(
            "Need to provide custom training function if using multi-model "
            "or multi-scheduler or multi-optimizer training.")

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    timers = {k: TimerStat() for k in ["h2d", "fwd", "grad", "apply"]}

    # switch to train mode
    model.train()

    end = time.time()

    # Pre-set so that an empty iterator reports 0 batches instead of failing
    # on an unbound loop variable when building ``stats``.
    batch_idx = -1
    for batch_idx, (features, target) in enumerate(train_iterator):
        # measure data loading time
        data_time.update(time.time() - end)

        # Create non_blocking tensors for distributed training
        with timers["h2d"]:
            if torch.cuda.is_available():
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

        # compute output
        with timers["fwd"]:
            output = model(features)
            loss = criterion(output, target)

            # measure accuracy and record loss
            losses.update(loss.item(), features.size(0))

        with timers["grad"]:
            # compute gradients in a backward pass
            optimizer.zero_grad()

            if config.get(USE_FP16):
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        with timers["apply"]:
            # Call step of optimizer to update model params
            optimizer.step()

        if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
            scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # In test mode only the first batch is processed.
        if config.get(TEST_MODE) and batch_idx == 0:
            break

    if scheduler and config.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
        scheduler.step()

    stats = {
        "batch_time": batch_time.avg,
        BATCH_COUNT: batch_idx + 1,
        "train_loss": losses.avg,
        "data_time": data_time.avg,
    }
    stats.update({k: t.mean for k, t in timers.items()})
    return stats
|
||||
|
||||
|
||||
def validate(config, model, val_iterator, criterion, scheduler=None):
    """Runs one standard validation pass over the val_iterator.

    This function automatically measures the processing time for each batch.

    It also automatically detects and places the data on the given GPU device
    if available.

    Raises:
        ValueError if multiple models/schedulers are provided. You
            are expected to have a custom validation function if you wish
            to use multiple models/schedulers.

    Args:
        config: (dict): A user configuration provided into the Trainer
            constructor.
        model: The model as created by the model_creator.
        val_iterator: An iterator over the validation batches, each a
            (features, target) pair.
        criterion: The loss object created by the loss_creator.
        scheduler (optional): The torch.optim.lr_scheduler object
            as created by the scheduler_creator. By default,
            this is not used in this function.

    Returns:
        A dict of metrics from the evaluation. "mean_accuracy" and
        "mean_loss" are 0.0 when no samples were evaluated.
    """
    # ``collections.Iterable`` was removed in Python 3.10; use the ABC from
    # ``collections.abc`` instead.
    from collections.abc import Iterable

    if isinstance(model, Iterable) or isinstance(scheduler, Iterable):
        raise ValueError(
            "Need to provide custom validation function if using multi-model "
            "or multi-scheduler training.")
    batch_time = AverageMeter()
    losses = AverageMeter()

    # switch to evaluate mode
    model.eval()
    correct = 0
    total = 0
    # Tracked separately from ``batch_idx`` so that an empty iterator is
    # reported as 0 batches rather than 1.
    num_batches = 0
    with torch.no_grad():
        end = time.time()
        for batch_idx, (features, target) in enumerate(val_iterator):
            if torch.cuda.is_available():
                features = features.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(features)
            loss = criterion(output, target)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # measure accuracy and record loss
            losses.update(loss.item(), features.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            num_batches = batch_idx + 1

            # In test mode only the first batch is processed.
            if config.get(TEST_MODE) and batch_idx == 0:
                break

    stats = {
        BATCH_COUNT: num_batches,
        "batch_time": batch_time.avg,
        "validation_loss": losses.avg,
        # Guard against division by zero when nothing was evaluated.
        "mean_accuracy": correct / total if total else 0.0,
        "mean_loss": losses.sum / total if total else 0.0,
    }
    return stats
|
||||
|
||||
|
||||
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted sum of all values), ``count`` (total weight) and ``avg``
    (``sum / count``, or 0 while ``count`` is zero).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clears all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Records ``val`` with weight ``n``.

        Args:
            val: The value to record (e.g. a batch-mean loss).
            n (int): Weight of the value (e.g. number of samples).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. only empty batches so far) would divide
        # by zero; leave the average untouched in that case.
        if self.count:
            self.avg = self.sum / self.count
|
||||
@@ -10,28 +10,34 @@ import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
|
||||
from ray.util.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
|
||||
from ray.util.sgd.pytorch.utils import (train, BATCH_COUNT, TEST_MODE,
|
||||
SCHEDULER_STEP)
|
||||
from ray.util.sgd.pytorch.training_operator import _TestingOperator
|
||||
from ray.util.sgd.pytorch.constants import BATCH_COUNT, SCHEDULER_STEP
|
||||
from ray.util.sgd.utils import check_for_failure
|
||||
|
||||
from ray.util.sgd.pytorch.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator, LinearDataset)
|
||||
|
||||
|
||||
def test_test_mode(ray_start_2_cpus): # noqa: F811
|
||||
@pytest.fixture
def ray_start_2_cpus():
    """Run a test against a Ray instance started with ``num_cpus=2``.

    Yields the value returned by ``ray.init`` and calls ``ray.shutdown()``
    once the test using the fixture has finished.
    """
    yield ray.init(num_cpus=2)
    # Everything below the yield is executed as teardown.
    ray.shutdown()
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus):  # noqa: F811
    """Check that ``num_steps=1`` limits train and validate to one batch."""
    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_replicas=1)
    metrics = trainer.train(num_steps=1)
    assert metrics[BATCH_COUNT] == 1

    val_metrics = trainer.validate(num_steps=1)
    assert val_metrics[BATCH_COUNT] == 1
|
||||
|
||||
|
||||
@@ -45,29 +51,51 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
for i in range(3):
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["validation_loss"]
|
||||
train_loss1 = trainer.train()["mean_train_loss"]
|
||||
validation_loss1 = trainer.validate()["mean_validation_loss"]
|
||||
|
||||
for i in range(3):
|
||||
train_loss2 = trainer.train()["train_loss"]
|
||||
validation_loss2 = trainer.validate()["validation_loss"]
|
||||
train_loss2 = trainer.train()["mean_train_loss"]
|
||||
validation_loss2 = trainer.validate()["mean_validation_loss"]
|
||||
|
||||
print(train_loss1, train_loss2)
|
||||
print(validation_loss1, validation_loss2)
|
||||
|
||||
assert train_loss2 <= train_loss1
|
||||
assert validation_loss2 <= validation_loss1
|
||||
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
|
||||
assert validation_loss2 <= validation_loss1, (validation_loss2,
|
||||
validation_loss1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
def custom_train(config, models, dataloader, criterion, optimizers,
|
||||
**kwargs):
|
||||
def test_multi_model(ray_start_2_cpus, num_replicas):
|
||||
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
|
||||
model.train()
|
||||
train_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
for batch_idx, (inputs, targets) in enumerate(dataloader):
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += targets.size(0)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
return {
|
||||
"accuracy": correct / total,
|
||||
"train_loss": train_loss / (batch_idx + 1)
|
||||
}
|
||||
|
||||
def train_epoch(self, iterator, info):
|
||||
result = {}
|
||||
for i, (model, optimizer) in enumerate(zip(models, optimizers)):
|
||||
result["model_{}".format(i)] = train(config, model, dataloader,
|
||||
criterion, optimizer)
|
||||
for i, (model, optimizer) in enumerate(
|
||||
zip(self.models, self.optimizers)):
|
||||
result["model_{}".format(i)] = train(
|
||||
model=model,
|
||||
criterion=self.criterion,
|
||||
optimizer=optimizer,
|
||||
dataloader=iterator)
|
||||
return result
|
||||
|
||||
def multi_model_creator(config):
|
||||
@@ -84,7 +112,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
train_function=custom_train,
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_replicas=num_replicas)
|
||||
trainer1.train()
|
||||
|
||||
@@ -100,6 +129,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_replicas=num_replicas)
|
||||
trainer2.restore(filename)
|
||||
|
||||
@@ -123,16 +154,17 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
def custom_train(config, model, dataloader, criterion, optimizer,
|
||||
scheduler):
|
||||
if config.get("models", 1) > 1:
|
||||
assert len(model) == config["models"], config
|
||||
def train_epoch(self, iterator, info):
|
||||
if self.config.get("models", 1) > 1:
|
||||
assert len(self.models) == self.config["models"], self.config
|
||||
|
||||
if config.get("optimizers", 1) > 1:
|
||||
assert len(optimizer) == config["optimizers"], config
|
||||
if self.config.get("optimizers", 1) > 1:
|
||||
assert len(
|
||||
self.optimizers) == self.config["optimizers"], self.config
|
||||
|
||||
if config.get("schedulers", 1) > 1:
|
||||
assert len(scheduler) == config["schedulers"], config
|
||||
if self.config.get("schedulers", 1) > 1:
|
||||
assert len(
|
||||
self.schedulers) == self.config["schedulers"], self.config
|
||||
return {"done": 1}
|
||||
|
||||
def multi_model_creator(config):
|
||||
@@ -167,12 +199,13 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
multi_optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=multi_scheduler_creator,
|
||||
train_function=custom_train,
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_replicas=num_replicas,
|
||||
config={
|
||||
"models": model_count,
|
||||
"optimizers": optimizer_count,
|
||||
"schedulers": scheduler_count
|
||||
"schedulers": scheduler_count,
|
||||
"custom_func": train_epoch
|
||||
})
|
||||
trainer.train()
|
||||
trainer.shutdown()
|
||||
@@ -180,9 +213,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"])
|
||||
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
def custom_train(config, model, dataloader, criterion, optimizer,
|
||||
scheduler):
|
||||
assert config[SCHEDULER_STEP] == scheduler_freq
|
||||
def train_epoch(self, iterator, info):
|
||||
assert info[SCHEDULER_STEP] == scheduler_freq
|
||||
return {"done": 1}
|
||||
|
||||
def scheduler_creator(optimizer, config):
|
||||
@@ -194,18 +226,17 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
scheduler_creator=scheduler_creator)
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
scheduler_creator=scheduler_creator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
|
||||
for i in range(3):
|
||||
trainer.train()["train_loss"]
|
||||
trainer.train()
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
def custom_train(config, model, dataloader, criterion, optimizer,
|
||||
scheduler):
|
||||
return {"done": 1}
|
||||
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
trainer = PyTorchTrainer(
|
||||
@@ -213,11 +244,13 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer))
|
||||
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
|
||||
training_operator_cls=_TestingOperator)
|
||||
trainer.update_scheduler(0.5)
|
||||
trainer.update_scheduler(0.5)
|
||||
assert all(
|
||||
trainer.apply_all_workers(lambda r: r.schedulers[0].last_epoch == 2))
|
||||
trainer.apply_all_operators(
|
||||
lambda op: op.schedulers[0].last_epoch == 2))
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@@ -248,13 +281,13 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
# checks loss decreasing for every trials
|
||||
for path, df in analysis.trial_dataframes.items():
|
||||
train_loss1 = df.loc[0, "train_loss"]
|
||||
train_loss2 = df.loc[1, "train_loss"]
|
||||
validation_loss1 = df.loc[0, "validation_loss"]
|
||||
validation_loss2 = df.loc[1, "validation_loss"]
|
||||
mean_train_loss1 = df.loc[0, "mean_train_loss"]
|
||||
mean_train_loss2 = df.loc[1, "mean_train_loss"]
|
||||
mean_validation_loss1 = df.loc[0, "mean_validation_loss"]
|
||||
mean_validation_loss2 = df.loc[1, "mean_validation_loss"]
|
||||
|
||||
assert train_loss2 <= train_loss1
|
||||
assert validation_loss2 <= validation_loss1
|
||||
assert mean_train_loss2 <= mean_train_loss1
|
||||
assert mean_validation_loss2 <= mean_validation_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
@@ -303,15 +336,17 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
|
||||
def step_with_fail(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
def step_with_fail(self, *args, **kwargs):
|
||||
worker_stats = [
|
||||
w.train_epoch.remote(*args, **kwargs) for w in self.workers
|
||||
]
|
||||
if self._num_failures < 3:
|
||||
time.sleep(1) # Make the batch will fail correctly.
|
||||
self.workers[0].__ray_kill__()
|
||||
success = check_for_failure(worker_stats)
|
||||
return success, worker_stats
|
||||
|
||||
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
|
||||
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = PyTorchTrainer(
|
||||
model_creator,
|
||||
single_loader,
|
||||
@@ -331,15 +366,17 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
|
||||
def step_with_fail(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
def step_with_fail(self, *args, **kwargs):
|
||||
worker_stats = [
|
||||
w.train_epoch.remote(*args, **kwargs) for w in self.workers
|
||||
]
|
||||
if self._num_failures < 1:
|
||||
time.sleep(1) # Make the batch will fail correctly.
|
||||
self.workers[0].__ray_kill__()
|
||||
success = check_for_failure(worker_stats)
|
||||
return success, worker_stats
|
||||
|
||||
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
|
||||
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = PyTorchTrainer(
|
||||
model_creator,
|
||||
single_loader,
|
||||
@@ -365,15 +402,17 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
|
||||
def step_with_fail(self):
|
||||
worker_stats = [w.step.remote() for w in self.workers]
|
||||
def step_with_fail(self, *args, **kwargs):
|
||||
worker_stats = [
|
||||
w.train_epoch.remote(*args, **kwargs) for w in self.workers
|
||||
]
|
||||
if self._num_failures < 2:
|
||||
time.sleep(1)
|
||||
self.workers[0].__ray_kill__()
|
||||
success = check_for_failure(worker_stats)
|
||||
return success, worker_stats
|
||||
|
||||
with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
|
||||
with patch.object(PyTorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = PyTorchTrainer(
|
||||
model_creator,
|
||||
single_loader,
|
||||
@@ -383,3 +422,9 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
num_replicas=2)
|
||||
|
||||
trainer1.train(max_retries=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-x", __file__]))
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch.nn as nn
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from ray.util.sgd.pytorch.training_operator import TrainingOperator
|
||||
from ray.util.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
|
||||
|
||||
@@ -46,39 +47,55 @@ def create_dataloaders(config):
|
||||
|
||||
class TestPyTorchRunner(unittest.TestCase):
|
||||
def testValidate(self):
|
||||
mock_function = MagicMock(returns=dict(mean_accuracy=10))
|
||||
class MockOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
|
||||
self.validate = MagicMock(returns=dict(mean_accuracy=10))
|
||||
|
||||
runner = PyTorchRunner(
|
||||
model_creator,
|
||||
create_dataloaders,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
validation_function=mock_function)
|
||||
training_operator_cls=MockOperator)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
runner.step()
|
||||
runner.step()
|
||||
self.assertEqual(mock_function.call_count, 0)
|
||||
runner.train_epoch()
|
||||
runner.train_epoch()
|
||||
runner.train_epoch()
|
||||
self.assertEqual(runner.training_operator.validate.call_count, 0)
|
||||
runner.validate()
|
||||
self.assertTrue(mock_function.called)
|
||||
self.assertTrue(runner.training_operator.validate.called)
|
||||
self.assertEqual(runner.stats()["epoch"], 3)
|
||||
|
||||
def testStep(self):
|
||||
mock_function = MagicMock(return_value=dict(mean_accuracy=10))
|
||||
def testtrain_epoch(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self.count = 0
|
||||
|
||||
def train_epoch(self, *args, **kwargs):
|
||||
self.count += 1
|
||||
return {"count": self.count}
|
||||
|
||||
runner = PyTorchRunner(
|
||||
model_creator,
|
||||
create_dataloaders,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
train_function=mock_function)
|
||||
training_operator_cls=MockOperator)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
runner.step()
|
||||
result = runner.step()
|
||||
self.assertEqual(mock_function.call_count, 3)
|
||||
self.assertEqual(result["epoch"], 3)
|
||||
runner.train_epoch(num_steps=1)
|
||||
runner.train_epoch(num_steps=1)
|
||||
result = runner.train_epoch()
|
||||
self.assertEqual(runner.training_operator.count, 3)
|
||||
self.assertEqual(result["count"], 3)
|
||||
self.assertEqual(runner.stats()["epoch"], 3)
|
||||
|
||||
def testGivens(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self.train_epoch = MagicMock(returns=dict(mean_accuracy=10))
|
||||
self.validate = MagicMock(returns=dict(mean_accuracy=10))
|
||||
|
||||
def three_model_creator(config):
|
||||
return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)
|
||||
|
||||
@@ -88,8 +105,12 @@ class TestPyTorchRunner(unittest.TestCase):
|
||||
]
|
||||
return opts[0], opts[1], opts[2]
|
||||
|
||||
runner = PyTorchRunner(three_model_creator, single_loader,
|
||||
three_optimizer_creator, loss_creator)
|
||||
runner = PyTorchRunner(
|
||||
three_model_creator,
|
||||
single_loader,
|
||||
three_optimizer_creator,
|
||||
loss_creator,
|
||||
training_operator_cls=MockOperator)
|
||||
runner.setup()
|
||||
|
||||
self.assertEqual(len(runner.given_models), 3)
|
||||
@@ -121,7 +142,7 @@ class TestPyTorchRunner(unittest.TestCase):
|
||||
runner = PyTorchRunner(model_creator, single_loader, optimizer_creator,
|
||||
loss_creator)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
runner.train_epoch()
|
||||
with self.assertRaises(ValueError):
|
||||
runner.validate()
|
||||
|
||||
@@ -132,7 +153,7 @@ class TestPyTorchRunner(unittest.TestCase):
|
||||
optimizer_creator,
|
||||
loss_creator=nn.MSELoss)
|
||||
runner.setup()
|
||||
runner.step()
|
||||
runner.train_epoch()
|
||||
|
||||
def testMultiModel(self):
|
||||
def multi_model_creator(config):
|
||||
@@ -146,6 +167,6 @@ class TestPyTorchRunner(unittest.TestCase):
|
||||
|
||||
runner = PyTorchRunner(multi_model_creator, single_loader,
|
||||
multi_optimizer_creator, loss_creator)
|
||||
runner.setup()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
runner.step()
|
||||
runner.setup()
|
||||
|
||||
@@ -149,3 +149,11 @@ def check_for_failure(remote_values):
|
||||
except RayActorError as exc:
|
||||
logger.exception(str(exc))
|
||||
return False
|
||||
|
||||
|
||||
def override(interface_class):
|
||||
def overrider(method):
|
||||
assert (method.__name__ in dir(interface_class))
|
||||
return method
|
||||
|
||||
return overrider
|
||||
|
||||
Reference in New Issue
Block a user