mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Broke utils library into its own repo so easier to share
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
[submodule "utils/pytorch_utils"]
|
||||
path = utils/pytorch_utils
|
||||
url = git@github.com:erikwijmans/etw_pytorch_utils.git
|
||||
Submodule
+1
Submodule utils/pytorch_utils added at 327d0c8842
@@ -1,744 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from torch.autograd.function import InplaceFunction
|
||||
from itertools import repeat
|
||||
import numpy as np
|
||||
import tensorboard_logger as tb_log
|
||||
import shutil, os
|
||||
from tqdm import tqdm
|
||||
from natsort import natsorted
|
||||
from operator import itemgetter
|
||||
from typing import List, Tuple
|
||||
from scipy.stats import t as student_t
|
||||
import statistics as stats
|
||||
import math
|
||||
|
||||
|
||||
class SharedMLP(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: List[int],
|
||||
*,
|
||||
bn: bool = False,
|
||||
activation=nn.ReLU(inplace=True),
|
||||
name: str = ""
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
for i in range(len(args) - 1):
|
||||
self.add_module(
|
||||
name + 'layer{}'.format(i),
|
||||
Conv2d(args[i], args[i + 1], bn=bn, activation=activation)
|
||||
)
|
||||
|
||||
|
||||
class _ConvBase(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=None,
|
||||
batch_norm=None,
|
||||
bias=True,
|
||||
name=""
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
bias = bias and (not bn)
|
||||
self.add_module(
|
||||
name + 'conv',
|
||||
conv(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias
|
||||
)
|
||||
)
|
||||
init(self[0].weight)
|
||||
|
||||
if bias:
|
||||
nn.init.constant(self[0].bias, 0)
|
||||
|
||||
if bn:
|
||||
self.add_module(name + 'bn', batch_norm(out_size))
|
||||
nn.init.constant(self[1].weight, 1)
|
||||
nn.init.constant(self[1].bias, 0)
|
||||
|
||||
if activation is not None:
|
||||
self.add_module(name + 'activation', activation)
|
||||
|
||||
|
||||
class Conv1d(_ConvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
kernel_size: int = 1,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=nn.init.kaiming_normal,
|
||||
bias: bool = True,
|
||||
name: str = ""
|
||||
):
|
||||
super().__init__(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=nn.Conv1d,
|
||||
batch_norm=nn.BatchNorm1d,
|
||||
bias=bias,
|
||||
name=name
|
||||
)
|
||||
|
||||
|
||||
class Conv2d(_ConvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
kernel_size: Tuple[int, int] = (1, 1),
|
||||
stride: Tuple[int, int] = (1, 1),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=nn.init.kaiming_normal,
|
||||
bias: bool = True,
|
||||
name: str = ""
|
||||
):
|
||||
super().__init__(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=nn.Conv2d,
|
||||
batch_norm=nn.BatchNorm2d,
|
||||
bias=bias,
|
||||
name=name
|
||||
)
|
||||
|
||||
|
||||
class Conv3d(_ConvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
kernel_size: Tuple[int, int, int] = (1, 1, 1),
|
||||
stride: Tuple[int, int, int] = (1, 1, 1),
|
||||
padding: Tuple[int, int, int] = (0, 0, 0),
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=nn.init.kaiming_normal,
|
||||
bias: bool = True,
|
||||
name: str = ""
|
||||
):
|
||||
super().__init__(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=nn.Conv3d,
|
||||
batch_norm=nn.BatchNorm3d,
|
||||
bias=bias,
|
||||
name=name
|
||||
)
|
||||
|
||||
|
||||
class FC(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=None,
|
||||
name: str = ""
|
||||
):
|
||||
super().__init__()
|
||||
self.add_module(name + 'fc', nn.Linear(in_size, out_size, bias=not bn))
|
||||
if init is not None:
|
||||
init(self[0].weight)
|
||||
|
||||
if not bn:
|
||||
nn.init.constant(self[0].bias, 0)
|
||||
|
||||
if bn:
|
||||
self.add_module(name + 'bn', nn.BatchNorm1d(out_size))
|
||||
nn.init.constant(self[1].weight, 1)
|
||||
nn.init.constant(self[1].bias, 0)
|
||||
|
||||
if activation is not None:
|
||||
self.add_module(name + 'activation', activation)
|
||||
|
||||
|
||||
class _DropoutNoScaling(InplaceFunction):
|
||||
|
||||
@staticmethod
|
||||
def _make_noise(input):
|
||||
return input.new().resize_as_(input)
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, input, p=0.5, train=False, inplace=False):
|
||||
if inplace:
|
||||
return None
|
||||
n = g.appendNode(
|
||||
g.create("Dropout", [input]).f_("ratio",
|
||||
p).i_("is_test", not train)
|
||||
)
|
||||
real = g.appendNode(g.createSelect(n, 0))
|
||||
g.appendNode(g.createSelect(n, 1))
|
||||
return real
|
||||
|
||||
@classmethod
|
||||
def forward(cls, ctx, input, p=0.5, train=False, inplace=False):
|
||||
if p < 0 or p > 1:
|
||||
raise ValueError(
|
||||
"dropout probability has to be between 0 and 1, "
|
||||
"but got {}".format(p)
|
||||
)
|
||||
ctx.p = p
|
||||
ctx.train = train
|
||||
ctx.inplace = inplace
|
||||
|
||||
if ctx.inplace:
|
||||
ctx.mark_dirty(input)
|
||||
output = input
|
||||
else:
|
||||
output = input.clone()
|
||||
|
||||
if ctx.p > 0 and ctx.train:
|
||||
ctx.noise = cls._make_noise(input)
|
||||
if ctx.p == 1:
|
||||
ctx.noise.fill_(0)
|
||||
else:
|
||||
ctx.noise.bernoulli_(1 - ctx.p)
|
||||
ctx.noise = ctx.noise.expand_as(input)
|
||||
output.mul_(ctx.noise)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.p > 0 and ctx.train:
|
||||
return grad_output.mul(Variable(ctx.noise)), None, None, None
|
||||
else:
|
||||
return grad_output, None, None, None
|
||||
|
||||
|
||||
dropout_no_scaling = _DropoutNoScaling.apply
|
||||
|
||||
|
||||
class _FeatureDropoutNoScaling(_DropoutNoScaling):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(input, p=0.5, train=False, inplace=False):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _make_noise(input):
|
||||
return input.new().resize_(
|
||||
input.size(0), input.size(1), *repeat(1,
|
||||
input.dim() - 2)
|
||||
)
|
||||
|
||||
|
||||
feature_dropout_no_scaling = _FeatureDropoutNoScaling.apply
|
||||
|
||||
|
||||
def checkpoint_state(model=None, optimizer=None, best_prec=None, epoch=None):
|
||||
return {
|
||||
'epoch': epoch,
|
||||
'best_prec': best_prec,
|
||||
'model_state': model.state_dict() if model is not None else None,
|
||||
'optimizer_state': optimizer.state_dict()
|
||||
if optimizer is not None else None
|
||||
}
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
state, is_best, filename='checkpoint', bestname='model_best'
|
||||
):
|
||||
filename = '{}.pth.tar'.format(filename)
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, '{}.pth.tar'.format(bestname))
|
||||
|
||||
|
||||
def load_checkpoint(model=None, optimizer=None, filename='checkpoint'):
|
||||
filename = "{}.pth.tar".format(filename)
|
||||
if os.path.isfile(filename):
|
||||
print("==> Loading from checkpoint '{}'".format(filename))
|
||||
checkpoint = torch.load(filename)
|
||||
epoch = checkpoint['epoch']
|
||||
best_prec = checkpoint['best_prec']
|
||||
if model is not None and checkpoint['model_state'] is not None:
|
||||
model.load_state_dict(checkpoint['model_state'])
|
||||
if optimizer is not None and checkpoint['optimizer_state'] is not None:
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state'])
|
||||
print("==> Done")
|
||||
else:
|
||||
print("==> Checkpoint '{}' not found".format(filename))
|
||||
|
||||
return epoch, best_prec
|
||||
|
||||
|
||||
def variable_size_collate(pad_val=0, use_shared_memory=True):
|
||||
import collections
|
||||
_numpy_type_map = {
|
||||
'float64': torch.DoubleTensor,
|
||||
'float32': torch.FloatTensor,
|
||||
'float16': torch.HalfTensor,
|
||||
'int64': torch.LongTensor,
|
||||
'int32': torch.IntTensor,
|
||||
'int16': torch.ShortTensor,
|
||||
'int8': torch.CharTensor,
|
||||
'uint8': torch.ByteTensor,
|
||||
}
|
||||
|
||||
def wrapped(batch):
|
||||
"Puts each data field into a tensor with outer dimension batch size"
|
||||
|
||||
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
||||
elem_type = type(batch[0])
|
||||
if torch.is_tensor(batch[0]):
|
||||
max_len = 0
|
||||
for b in batch:
|
||||
max_len = max(max_len, b.size(0))
|
||||
|
||||
numel = sum([int(b.numel() / b.size(0) * max_len) for b in batch])
|
||||
if use_shared_memory:
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
storage = batch[0].storage()._new_shared(numel)
|
||||
out = batch[0].new(storage)
|
||||
else:
|
||||
out = batch[0].new(numel)
|
||||
|
||||
out = out.view(
|
||||
len(batch), max_len,
|
||||
*[batch[0].size(i) for i in range(1, batch[0].dim())]
|
||||
)
|
||||
out.fill_(pad_val)
|
||||
for i in range(len(batch)):
|
||||
out[i, 0:batch[i].size(0)] = batch[i]
|
||||
|
||||
return out
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
and elem_type.__name__ != 'string_':
|
||||
elem = batch[0]
|
||||
if elem_type.__name__ == 'ndarray':
|
||||
# array of string classes and object
|
||||
if re.search('[SaUO]', elem.dtype.str) is not None:
|
||||
raise TypeError(error_msg.format(elem.dtype))
|
||||
|
||||
return wrapped([torch.from_numpy(b) for b in batch])
|
||||
if elem.shape == (): # scalars
|
||||
py_type = float if elem.dtype.name.startswith('float') else int
|
||||
return _numpy_type_map[elem.dtype.name](
|
||||
list(map(py_type, batch))
|
||||
)
|
||||
elif isinstance(batch[0], int):
|
||||
return torch.LongTensor(batch)
|
||||
elif isinstance(batch[0], float):
|
||||
return torch.DoubleTensor(batch)
|
||||
elif isinstance(batch[0], collections.Mapping):
|
||||
return {key: wrapped([d[key] for d in batch]) for key in batch[0]}
|
||||
elif isinstance(batch[0], collections.Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [wrapped(samples) for samples in transposed]
|
||||
|
||||
raise TypeError((error_msg.format(type(batch[0]))))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
class TrainValSplitter():
|
||||
r"""
|
||||
Creates a training and validation split to be used as the sampler in a pytorch DataLoader
|
||||
Parameters
|
||||
---------
|
||||
numel : int
|
||||
Number of elements in the entire training dataset
|
||||
percent_train : float
|
||||
Percentage of data in the training split
|
||||
shuffled : bool
|
||||
Whether or not shuffle which data goes to which split
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *, numel: int, percent_train: float, shuffled: bool = False
|
||||
):
|
||||
indicies = np.array([i for i in range(numel)])
|
||||
if shuffled:
|
||||
np.random.shuffle(indicies)
|
||||
|
||||
self.train = torch.utils.data.sampler.SubsetRandomSampler(
|
||||
indicies[0:int(percent_train * numel)]
|
||||
)
|
||||
self.val = torch.utils.data.sampler.SubsetRandomSampler(
|
||||
indicies[int(percent_train * numel):-1]
|
||||
)
|
||||
|
||||
|
||||
class CrossValSplitter():
|
||||
r"""
|
||||
Class that creates cross validation splits. The train and val splits can be used in pytorch DataLoaders. The splits can be updated
|
||||
by calling next(self) or using a loop:
|
||||
for _ in self:
|
||||
....
|
||||
Parameters
|
||||
---------
|
||||
numel : int
|
||||
Number of elements in the training set
|
||||
k_folds : int
|
||||
Number of folds
|
||||
shuffled : bool
|
||||
Whether or not to shuffle which data goes in which fold
|
||||
"""
|
||||
|
||||
def __init__(self, *, numel: int, k_folds: int, shuffled: bool = False):
|
||||
inidicies = np.array([i for i in range(numel)])
|
||||
if shuffled:
|
||||
np.random.shuffle(inidicies)
|
||||
|
||||
self.folds = np.array(np.array_split(inidicies, k_folds), dtype=object)
|
||||
self.current_v_ind = -1
|
||||
|
||||
self.val = torch.utils.data.sampler.SubsetRandomSampler(self.folds[0])
|
||||
self.train = torch.utils.data.sampler.SubsetRandomSampler(
|
||||
np.concatenate(self.folds[1:], axis=0)
|
||||
)
|
||||
|
||||
self.metrics = {}
|
||||
|
||||
def __iter__(self):
|
||||
self.current_v_ind = -1
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return len(self.folds)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert idx >= 0 and idx < len(self)
|
||||
self.val.inidicies = self.folds[idx]
|
||||
self.train.inidicies = np.concatenate(
|
||||
self.folds[np.arange(len(self)) != idx], axis=0
|
||||
)
|
||||
|
||||
def __next__(self):
|
||||
self.current_v_ind += 1
|
||||
if self.current_v_ind >= len(self):
|
||||
raise StopIteration
|
||||
|
||||
self[self.current_v_ind]
|
||||
|
||||
def update_metrics(self, to_post: dict):
|
||||
for k, v in to_post.items():
|
||||
if k in self.metrics:
|
||||
self.metrics[k].append(v)
|
||||
else:
|
||||
self.metrics[k] = [v]
|
||||
|
||||
def print_metrics(self):
|
||||
for name, samples in self.metrics.items():
|
||||
xbar = stats.mean(samples)
|
||||
sx = stats.stdev(samples, xbar)
|
||||
tstar = student_t.ppf(1.0 - 0.025, len(samples) - 1)
|
||||
margin_of_error = tstar * sx / sqrt(len(samples))
|
||||
print("{}: {} +/- {}".format(name, xbar, margin_of_error))
|
||||
|
||||
|
||||
def set_bn_momentum_default(bn_momentum):
|
||||
|
||||
def fn(m):
|
||||
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
|
||||
m.momentum = bn_momentum
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
class BNMomentumScheduler(object):
|
||||
|
||||
def __init__(
|
||||
self, model, bn_lambda, last_epoch=-1,
|
||||
setter=set_bn_momentum_default
|
||||
):
|
||||
if not isinstance(model, nn.Module):
|
||||
raise RuntimeError(
|
||||
"Class '{}' is not a PyTorch nn Module".format(
|
||||
type(model).__name__
|
||||
)
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.setter = setter
|
||||
self.lmbd = bn_lambda
|
||||
|
||||
self.step(last_epoch + 1)
|
||||
self.last_epoch = last_epoch
|
||||
|
||||
def step(self, epoch=None):
|
||||
if epoch is None:
|
||||
epoch = self.last_epoch + 1
|
||||
|
||||
self.last_epoch = epoch
|
||||
self.model.apply(self.setter(self.lmbd(epoch)))
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
r"""
|
||||
Reasonably generic trainer for pytorch models
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : pytorch model
|
||||
Model to be trained
|
||||
model_fn : function (model, inputs, labels) -> preds, loss, accuracy
|
||||
optimizer : torch.optim
|
||||
Optimizer for model
|
||||
checkpoint_name : str
|
||||
Name of file to save checkpoints to
|
||||
best_name : str
|
||||
Name of file to save best model to
|
||||
lr_scheduler : torch.optim.lr_scheduler
|
||||
Learning rate scheduler. .step() will be called at the start of every epoch
|
||||
bnm_scheduler : BNMomentumScheduler
|
||||
Batchnorm momentum scheduler. .step() will be called at the start of every epoch
|
||||
eval_frequency : int
|
||||
How often to run an eval
|
||||
log_name : str
|
||||
Name of file to output tensorboard_logger to
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model_fn,
|
||||
optimizer,
|
||||
checkpoint_name="ckpt",
|
||||
best_name="best",
|
||||
lr_scheduler=None,
|
||||
bnm_scheduler=None,
|
||||
eval_frequency=1,
|
||||
log_name=None
|
||||
):
|
||||
self.model, self.model_fn, self.optimizer, self.lr_scheduler, self.bnm_scheduler = (
|
||||
model, model_fn, optimizer, lr_scheduler, bnm_scheduler
|
||||
)
|
||||
|
||||
self.checkpoint_name, self.best_name = checkpoint_name, best_name
|
||||
self.eval_frequency = eval_frequency
|
||||
self.training_best = {}
|
||||
self.eval_best = {}
|
||||
|
||||
if log_name is not None:
|
||||
tb_log.configure(log_name)
|
||||
self.logging = True
|
||||
else:
|
||||
self.logging = False
|
||||
|
||||
@staticmethod
|
||||
def _print(mode, epoch, loss, eval_dict, count):
|
||||
to_print = "[{:d}] {}\tMean Loss: {:.4e}".format(
|
||||
epoch, mode, loss / count
|
||||
)
|
||||
for k, v in natsorted(eval_dict.items(), key=itemgetter(0)):
|
||||
to_print += "\tMean {}: {:2.3f}%".format(k, stats.mean(v) * 1e2)
|
||||
|
||||
print(to_print)
|
||||
|
||||
def _train_epoch(self, epoch, d_loader):
|
||||
self.model.train()
|
||||
total_loss = 0.0
|
||||
count = 0.0
|
||||
eval_dict = {}
|
||||
|
||||
for i, data in tqdm(enumerate(d_loader, 0), total=len(d_loader)):
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step(epoch - 1 + i / len(d_loader))
|
||||
|
||||
if self.bnm_scheduler is not None:
|
||||
self.bnm_scheduler.step(epoch - 1 + i / len(d_loader))
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
_, loss, eval_res = self.model_fn(self.model, data, epoch=epoch)
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
total_loss += loss.data[0]
|
||||
for k, v in eval_res.items():
|
||||
if v is not None:
|
||||
eval_dict[k] = eval_dict.get(k, []) + [v]
|
||||
|
||||
count += 1.0
|
||||
|
||||
if self.logging:
|
||||
idx = (epoch - 1) * len(d_loader) + i
|
||||
tb_log.log_value("Training loss", loss.data[0], step=idx)
|
||||
for k, v in eval_res.items():
|
||||
if v is not None:
|
||||
tb_log.log_value(
|
||||
"Training {}".format(k), 1.0 - v, step=idx
|
||||
)
|
||||
|
||||
d_loader.dataset.randomize()
|
||||
|
||||
self._print("Train", epoch, total_loss, eval_dict, count)
|
||||
|
||||
if 'loss' in self.training_best:
|
||||
self.training_best['loss'] = np.min(
|
||||
self.training_best['loss'], total_loss / count
|
||||
)
|
||||
else:
|
||||
self.training_best['loss'] = total_loss / count
|
||||
|
||||
for k, v in eval_dict.items():
|
||||
if k in self.training_best:
|
||||
self.training_best[k] = np.max(
|
||||
self.training_best[k], stats.means(v)
|
||||
)
|
||||
else:
|
||||
self.training_best[k] = stats.mean(v)
|
||||
|
||||
def eval_epoch(self, epoch, d_loader):
|
||||
if d_loader is None:
|
||||
return
|
||||
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
eval_dict = {}
|
||||
count = 0.0
|
||||
|
||||
for i, data in tqdm(enumerate(d_loader, 0), total=len(d_loader)):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
_, loss, eval_res = self.model_fn(
|
||||
self.model, data, eval=True, epoch=epoch
|
||||
)
|
||||
|
||||
total_loss += loss.data[0]
|
||||
count += 1
|
||||
for k, v in eval_res.items():
|
||||
if v is not None:
|
||||
eval_dict[k] = eval_dict.get(k, []) + [v]
|
||||
|
||||
if self.logging:
|
||||
idx = (epoch - 1) * len(d_loader) + i
|
||||
tb_log.log_value("Eval loss", loss.data[0], step=idx)
|
||||
for k, v in eval_res.items():
|
||||
if v is not None:
|
||||
tb_log.log_value("Eval {}".format(k), 1.0 - v, step=idx)
|
||||
|
||||
d_loader.dataset.randomize()
|
||||
|
||||
self._print("Eval", epoch, total_loss, eval_dict, count)
|
||||
|
||||
if 'loss' in self.eval_best:
|
||||
self.eval_best['loss'] = np.min(
|
||||
self.eval_best['loss'], total_loss / count
|
||||
)
|
||||
else:
|
||||
self.eval_best['loss'] = total_loss / count
|
||||
|
||||
for k, v in eval_dict.items():
|
||||
if k in self.eval_best:
|
||||
self.eval_best[k] = np.max(self.eval_best[k], stats.means(v))
|
||||
else:
|
||||
self.eval_best[k] = stats.mean(v)
|
||||
|
||||
return total_loss / count, eval_dict
|
||||
|
||||
def train(
|
||||
self,
|
||||
start_epoch,
|
||||
n_epochs,
|
||||
train_loader,
|
||||
test_loader=None,
|
||||
best_loss=0.0
|
||||
):
|
||||
r"""
|
||||
Call to begin training the model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_epoch : int
|
||||
Epoch to start at
|
||||
n_epochs : int
|
||||
Number of epochs to train for
|
||||
test_loader : torch.utils.data.DataLoader
|
||||
DataLoader of the test_data
|
||||
train_loader : torch.utils.data.DataLoader
|
||||
DataLoader of training data
|
||||
best_loss : float
|
||||
Testing loss of the best model
|
||||
"""
|
||||
for epoch in range(start_epoch, n_epochs + 1):
|
||||
|
||||
print("\n{0} Train Epoch {1:0>3d} {0}\n".format("-" * 5, epoch))
|
||||
self._train_epoch(epoch, train_loader)
|
||||
|
||||
if test_loader is not None and (epoch % self.eval_frequency) == 0:
|
||||
print("\n{0} Eval Epoch {1:0>3d} {0}\n".format("-" * 5, epoch))
|
||||
val_loss, _ = self.eval_epoch(epoch, test_loader)
|
||||
|
||||
is_best = val_loss < best_loss
|
||||
best_loss = min(best_loss, val_loss)
|
||||
save_checkpoint(
|
||||
checkpoint_state(
|
||||
self.model, self.optimizer, val_loss, epoch + 1
|
||||
),
|
||||
is_best,
|
||||
filename=self.checkpoint_name,
|
||||
bestname=self.best_name
|
||||
)
|
||||
|
||||
print("{0} Summary {0}".format("-" * 5))
|
||||
print("** Training Stats **")
|
||||
for k, v in natsorted(self.training_best.items(), key=itemgetter(0)):
|
||||
if k == 'loss':
|
||||
print("Best loss: {:.4e}".format(v))
|
||||
else:
|
||||
print("Best {}: {:2.3f}%".format(k, v * 1e2))
|
||||
|
||||
print("\n** Eval Stats **")
|
||||
for k, v in natsorted(self.eval_best.items(), key=itemgetter(0)):
|
||||
if k == 'loss':
|
||||
print("Best loss: {:.4e}".format(v))
|
||||
else:
|
||||
print("Best {}: {:2.3f}%".format(k, v * 1e2))
|
||||
|
||||
return best_loss
|
||||
Reference in New Issue
Block a user