Files
DeepTime/experiments/forecast.py
T
2022-11-19 20:16:20 +08:00

227 lines
8.0 KiB
Python

import os
from os.path import join
import math
import logging
from typing import Callable, Optional, Union, Dict, Tuple
import gin
from fire import Fire
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from experiments.base import Experiment
from data.datasets import ForecastDataset
from models import get_model
from utils.checkpoint import Checkpoint
from utils.ops import default_device, to_tensor
from utils.losses import get_loss_fn
from utils.metrics import calc_metrics
class ForecastExperiment(Experiment):
    """Experiment that trains a forecasting model and evaluates it on val/test."""

    @gin.configurable()
    def instance(self,
                 model_type: str,
                 save_vals: Optional[bool] = True,):
        """Run one experiment: build data and model, train, evaluate, persist.

        :param model_type: identifier forwarded to ``get_model``.
        :param save_vals: when truthy, ``self.root`` is handed to the
            test-split ``validate`` call as ``save_path`` so that inputs,
            predictions and targets are written there.
        """
        # One (dataset, loader) pair per split; only the training dataset
        # itself is inspected below.
        train_set, train_loader = get_data(flag='train')
        _, val_loader = get_data(flag='val')
        _, test_loader = get_data(flag='test')

        # Model dimensions are read off the training dataset.
        dim_size = train_set.data_x.shape[1]
        datetime_feats = train_set.timestamps.shape[-1]
        model = get_model(model_type,
                          dim_size=dim_size,
                          datetime_feats=datetime_feats)
        model = model.to(default_device())
        checkpoint = Checkpoint(self.root)

        # Train on the forecasting task.
        model = train(model, checkpoint, train_loader, val_loader, test_loader)

        # Final evaluation with full metrics on both held-out splits.
        val_metrics = validate(model, loader=val_loader, report_metrics=True)
        save_dir = self.root if save_vals else None
        test_metrics = validate(model, loader=test_loader, report_metrics=True,
                                save_path=save_dir)
        np.save(join(self.root, 'metrics.npy'),
                {'val': val_metrics, 'test': test_metrics})

        # Prefix the metric names per split before passing them to the
        # checkpointer's close().
        tagged = {f'ValMetric/{name}': value
                  for name, value in val_metrics.items()}
        tagged.update({f'TestMetric/{name}': value
                       for name, value in test_metrics.items()})
        checkpoint.close(tagged)
@gin.configurable()
def get_optimizer(model: nn.Module,
                  lr: Optional[float] = 1e-3,
                  lambda_lr: Optional[float] = 1.,
                  weight_decay: Optional[float] = 1e-2) -> optim.Optimizer:
    """Create an Adam optimizer with three name-based parameter groups.

    Parameters are routed by their name in ``model.named_parameters()``:

    * names containing ``'_lambda'``: learning rate ``lambda_lr``, no weight
      decay, tagged ``'cosine_annealing'``;
    * other names containing ``'bias'`` or ``'norm'``: no weight decay,
      tagged ``'cosine_annealing_with_linear_warmup'``;
    * everything else: optimizer-level ``lr`` and ``weight_decay``, tagged
      ``'cosine_annealing_with_linear_warmup'``.

    The ``'scheduler'`` entry is an extra key stored in each param group; it
    is read back by ``get_scheduler`` to choose the schedule for that group.

    :param model: module whose named parameters are optimized.
    :param lr: default learning rate of the optimizer.
    :param lambda_lr: learning rate of the ``'_lambda'`` group.
    :param weight_decay: default weight decay of the optimizer.
    :return: the configured ``optim.Adam`` instance.
    """
    no_decay_keys = ('bias', 'norm',)
    lambda_params, no_decay_params, decay_params = [], [], []

    for name, tensor in model.named_parameters():
        if '_lambda' in name:
            bucket = lambda_params
        elif any(key in name for key in no_decay_keys):
            bucket = no_decay_params
        else:
            bucket = decay_params
        bucket.append(tensor)

    param_groups = [
        {'params': lambda_params, 'weight_decay': 0, 'lr': lambda_lr,
         'scheduler': 'cosine_annealing'},
        {'params': no_decay_params, 'weight_decay': 0,
         'scheduler': 'cosine_annealing_with_linear_warmup'},
        {'params': decay_params,
         'scheduler': 'cosine_annealing_with_linear_warmup'},
    ]
    return optim.Adam(param_groups, lr=lr, weight_decay=weight_decay)
def _make_lr_lambda(param_group: dict,
                    T_max: int,
                    warmup_epochs: int,
                    eta_min: float) -> Callable[[int], float]:
    """Build the learning-rate multiplier function for one parameter group.

    The group's own learning rate is snapshotted here, in a scope private to
    this call, so that multipliers built for different groups never share
    state (closures created directly inside a loop would all see the values
    of the last iteration).

    :param param_group: optimizer param group; must contain ``'scheduler'``
        and, for the cosine schedules, ``'lr'``.
    :param T_max: epoch index at which the cosine term reaches ``eta_min``.
    :param warmup_epochs: epoch index at which the cosine phase starts.
    :param eta_min: learning rate at the end of the cosine phase.
    :return: function mapping the epoch index to a multiplier of the group's
        initial learning rate.
    :raises ValueError: if the group's ``'scheduler'`` tag is unknown.
    """
    kind = param_group['scheduler']
    if kind == 'none':
        return lambda T_cur: 1

    if kind not in ('cosine_annealing', 'cosine_annealing_with_linear_warmup'):
        raise ValueError(f'No such scheduler, {kind}')

    # Snapshot now: LambdaLR rewrites param_group['lr'] on every step.
    base_lr = param_group['lr']

    def cosine(T_cur):
        # Absolute cosine-annealed LR divided by the base LR, because
        # LambdaLR expects a multiplicative factor.
        phase = (T_cur - warmup_epochs) / (T_max - warmup_epochs) * math.pi
        return (eta_min + 0.5 * (base_lr - eta_min) * (
            1.0 + math.cos(phase))) / base_lr

    if kind == 'cosine_annealing':
        return cosine

    # https://blog.csdn.net/qq_36560894/article/details/114004799
    def warmup_then_cosine(T_cur):
        if T_cur < warmup_epochs:
            return T_cur / warmup_epochs
        return cosine(T_cur)

    return warmup_then_cosine


@gin.configurable()
def get_scheduler(optimizer: optim.Optimizer,
                  T_max: int,
                  warmup_epochs: int,
                  eta_min: Optional[float] = 0.) -> optim.lr_scheduler.LambdaLR:
    """Create a ``LambdaLR`` with one schedule per optimizer param group.

    Each param group selects its schedule through its ``'scheduler'`` entry:

    * ``'none'``: constant multiplier of 1;
    * ``'cosine_annealing'``: cosine curve centred on ``warmup_epochs`` that
      reaches ``eta_min`` at ``T_max``;
    * ``'cosine_annealing_with_linear_warmup'``: multiplier
      ``T_cur / warmup_epochs`` while ``T_cur < warmup_epochs``, then the
      same cosine curve.

    Every group is annealed relative to its own learning rate.

    Note: the cosine term divides by ``T_max - warmup_epochs`` and by the
    group's learning rate, so neither may be zero.

    :param optimizer: optimizer whose param groups carry a ``'scheduler'`` key.
    :param T_max: epoch index at which the cosine term reaches ``eta_min``.
    :param warmup_epochs: length of the linear warmup / start of cosine phase.
    :param eta_min: learning rate at the end of the cosine phase.
    :return: the configured ``LambdaLR`` scheduler.
    :raises ValueError: if a param group names an unknown scheduler.
    """
    scheduler_fns = [
        _make_lr_lambda(param_group, T_max, warmup_epochs, eta_min)
        for param_group in optimizer.param_groups
    ]
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_fns)
@gin.configurable()
def get_data(flag: str,
             batch_size: int) -> Tuple[ForecastDataset, DataLoader]:
    """Build the dataset and data loader for one data split.

    The training loader shuffles and drops the last incomplete batch; the
    validation and test loaders do neither.

    :param flag: split name, one of ``'train'``, ``'val'`` or ``'test'``.
    :param batch_size: batch size of the returned loader.
    :return: ``(dataset, data_loader)`` for the requested split.
    :raises ValueError: if ``flag`` is not a known split name.
    """
    if flag not in ('train', 'val', 'test'):
        raise ValueError(f'no such flag {flag}')

    # Shuffling and dropping the ragged final batch apply to training only.
    is_train = flag == 'train'
    dataset = ForecastDataset(flag)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=is_train,
                             drop_last=is_train)
    return dataset, data_loader
@gin.configurable()
def train(model: nn.Module,
          checkpoint: Checkpoint,
          train_loader: DataLoader,
          val_loader: DataLoader,
          test_loader: DataLoader,
          loss_name: str,
          epochs: int,
          clip: float) -> nn.Module:
    """Train ``model`` and return it with the checkpointed weights loaded.

    After every epoch the scheduler is stepped, the loss on the validation
    and test loaders is computed, and the scalars are handed to
    ``checkpoint``. Training stops early once ``checkpoint.early_stop`` is
    set.

    :param model: model to optimize.
    :param checkpoint: callable invoked as
        ``checkpoint(epoch, model, scalars=...)``; also provides
        ``early_stop`` and ``model_path``.
    :param train_loader: loader yielding ``(x, y, x_time, y_time)`` batches.
    :param val_loader: loader used for the per-epoch validation loss.
    :param test_loader: loader used for the per-epoch test loss.
    :param loss_name: name forwarded to ``get_loss_fn``.
    :param epochs: number of epochs; also passed to the scheduler as ``T_max``.
    :param clip: maximum gradient norm for ``clip_grad_norm_``.
    :return: the model; when ``epochs > 0`` its state dict is reloaded from
        ``checkpoint.model_path``.
    """
    optimizer = get_optimizer(model)
    scheduler = get_scheduler(optimizer=optimizer, T_max=epochs)
    training_loss_fn = get_loss_fn(loss_name)

    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1
        batch_losses = []
        model.train()

        for step, batch in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            x, y, x_time, y_time = map(to_tensor, batch)
            output = model(x, x_time, y_time)

            if isinstance(output, tuple):
                # for models which require reconstruction + forecast loss
                loss = training_loss_fn(output[0], x) + \
                    training_loss_fn(output[1], y)
            else:
                loss = training_loss_fn(output, y)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            loss_value = loss.item()
            batch_losses.append(loss_value)
            if step % 100 == 0:
                logging.info(f"epochs: {epoch_no}, iters: {step} | training loss: {loss_value:.2f}")

        # The schedule advances once per epoch, not per iteration.
        scheduler.step()

        scalars = {
            'Loss/Train': np.average(batch_losses),
            'Loss/Val': validate(model, loader=val_loader, loss_fn=training_loss_fn),
            'Loss/Test': validate(model, loader=test_loader, loss_fn=training_loss_fn),
        }
        checkpoint(epoch_no, model, scalars=scalars)

        if checkpoint.early_stop:
            logging.info("Early stopping")
            break

    if epochs <= 0:
        # Nothing was trained, so there is no checkpoint to restore.
        return model

    # Presumably the best weights seen during training; confirm against
    # utils.checkpoint.
    model.load_state_dict(torch.load(checkpoint.model_path))
    return model
@torch.no_grad()
def validate(model: nn.Module,
             loader: DataLoader,
             loss_fn: Optional[Callable] = None,
             report_metrics: Optional[bool] = False,
             save_path: Optional[str] = None) -> Union[Dict[str, float], float]:
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Two modes are supported:

    * ``report_metrics`` falsy: ``loss_fn(forecast, y, reduction='none')`` is
      evaluated per batch and the average over all collected values is
      returned.
    * ``report_metrics`` truthy: forecasts and targets are concatenated and
      the result of ``calc_metrics`` is returned. If ``save_path`` is also
      given, ``inps.npy``, ``preds.npy`` and ``trues.npy`` are written there.

    :param model: model to evaluate; it is switched to eval mode.
    :param loader: loader yielding ``(x, y, x_time, y_time)`` batches.
    :param loss_fn: loss callable, used only when ``report_metrics`` is falsy.
    :param report_metrics: selects the metrics mode instead of the loss mode.
    :param save_path: directory for the saved arrays (metrics mode only).
    :return: metrics from ``calc_metrics`` or the averaged loss.
    """
    model.eval()
    keep_inputs = save_path is not None
    forecasts, targets, inputs, losses = [], [], [], []

    for batch in loader:
        x, y, x_time, y_time = map(to_tensor, batch)
        if x.shape[0] == 1:
            # skip final batch if batch_size == 1
            # due to bug in torch.linalg.solve which raises error when batch_size == 1
            continue

        forecast = model(x, x_time, y_time)
        if not report_metrics:
            losses.append(loss_fn(forecast, y, reduction='none'))
            continue

        forecasts.append(forecast)
        targets.append(y)
        if keep_inputs:
            inputs.append(x)

    if not report_metrics:
        return np.average(torch.cat(losses, dim=0).cpu())

    preds = torch.cat(forecasts, dim=0).detach().cpu().numpy()
    trues = torch.cat(targets, dim=0).detach().cpu().numpy()
    if keep_inputs:
        inps = torch.cat(inputs, dim=0).detach().cpu().numpy()
        for file_name, array in (('inps.npy', inps),
                                 ('preds.npy', preds),
                                 ('trues.npy', trues)):
            np.save(join(save_path, file_name), array)
    return calc_metrics(preds, trues)
if __name__ == '__main__':
    # Let INFO-level records through the root logger, then expose the
    # experiment class as a command-line interface via Fire.
    logging.getLogger().setLevel(logging.INFO)
    Fire(ForecastExperiment)