# Mirror of https://github.com/wassname/DeepTime.git
# Synced 2026-06-27 19:29:06 +08:00 (227 lines, 8.0 KiB, Python)
import os
|
|
from os.path import join
|
|
import math
|
|
import logging
|
|
from typing import Callable, Optional, Union, Dict, Tuple
|
|
|
|
import gin
|
|
from fire import Fire
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from torch import optim
|
|
from torch import nn
|
|
|
|
from experiments.base import Experiment
|
|
from data.datasets import ForecastDataset
|
|
from models import get_model
|
|
from utils.checkpoint import Checkpoint
|
|
from utils.ops import default_device, to_tensor
|
|
from utils.losses import get_loss_fn
|
|
from utils.metrics import calc_metrics
|
|
|
|
|
|
class ForecastExperiment(Experiment):
    """Experiment that trains a forecasting model and scores it on the val/test splits."""

    @gin.configurable()
    def instance(self,
                 model_type: str,
                 save_vals: Optional[bool] = True,):
        """Run one train-then-evaluate cycle and persist the resulting metrics.

        Args:
            model_type: Identifier handed to ``get_model`` to select the model.
            save_vals: When truthy, the test-split evaluation is given
                ``self.root`` as ``save_path`` so it also dumps its arrays there.
        """
        # Build the three data pipelines. Only the training dataset object is
        # used afterwards (to size the model), so the other two are discarded.
        train_set, train_loader = get_data(flag='train')
        _, val_loader = get_data(flag='val')
        _, test_loader = get_data(flag='test')

        # Model dimensions are read off the training dataset.
        n_dims = train_set.data_x.shape[1]
        n_time_feats = train_set.timestamps.shape[-1]
        net = get_model(model_type, dim_size=n_dims, datetime_feats=n_time_feats)
        net = net.to(default_device())
        checkpoint = Checkpoint(self.root)

        # Fit the forecasting task.
        net = train(net, checkpoint, train_loader, val_loader, test_loader)

        # Final evaluation on both held-out splits.
        val_scores = validate(net, loader=val_loader, report_metrics=True)
        dump_dir = self.root if save_vals else None
        test_scores = validate(net, loader=test_loader, report_metrics=True,
                               save_path=dump_dir)
        np.save(join(self.root, 'metrics.npy'), {'val': val_scores, 'test': test_scores})

        # Prefix the metric names per split before handing them to the checkpointer.
        summary = {f'ValMetric/{k}': v for k, v in val_scores.items()}
        summary.update({f'TestMetric/{k}': v for k, v in test_scores.items()})
        checkpoint.close(summary)
|
|
|
|
|
|
@gin.configurable()
|
|
def get_optimizer(model: nn.Module,
                  lr: Optional[float] = 1e-3,
                  lambda_lr: Optional[float] = 1.,
                  weight_decay: Optional[float] = 1e-2) -> optim.Optimizer:
    """Build an Adam optimizer whose parameter groups are chosen by parameter name.

    Parameters are split into three groups, in this order:

    1. names containing ``'_lambda'``: learning rate ``lambda_lr``, no weight
       decay, tagged ``'cosine_annealing'``;
    2. names containing ``'bias'`` or ``'norm'``: no weight decay, tagged
       ``'cosine_annealing_with_linear_warmup'``;
    3. everything else: the optimizer-wide ``lr`` and ``weight_decay``, tagged
       ``'cosine_annealing_with_linear_warmup'``.

    The ``'scheduler'`` entry stored in every group is an extra key that names
    the learning-rate schedule intended for that group.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        lr: Default learning rate (groups 2 and 3).
        lambda_lr: Learning rate for the ``'_lambda'`` group.
        weight_decay: Default weight decay (group 3 only).

    Returns:
        The configured ``torch.optim.Adam`` instance.
    """
    no_decay_markers = ('bias', 'norm')
    lambda_params = []
    no_decay_params = []
    decay_params = []

    for param_name, param in model.named_parameters():
        if '_lambda' in param_name:
            lambda_params.append(param)
        elif any(marker in param_name for marker in no_decay_markers):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return optim.Adam([
        {'params': lambda_params, 'weight_decay': 0, 'lr': lambda_lr,
         'scheduler': 'cosine_annealing'},
        {'params': no_decay_params, 'weight_decay': 0,
         'scheduler': 'cosine_annealing_with_linear_warmup'},
        {'params': decay_params,
         'scheduler': 'cosine_annealing_with_linear_warmup'},
    ], lr=lr, weight_decay=weight_decay)
|
|
|
|
|
|
@gin.configurable()
|
|
def get_scheduler(optimizer: optim.Optimizer,
                  T_max: int,
                  warmup_epochs: int,
                  eta_min: Optional[float] = 0.) -> optim.lr_scheduler.LambdaLR:
    """Create a ``LambdaLR`` with one multiplier function per parameter group.

    Each parameter group must carry a ``'scheduler'`` entry, one of:

    * ``'none'``: constant multiplier 1;
    * ``'cosine_annealing'``: cosine curve between the group's learning rate
      and ``eta_min``, with the phase measured from ``warmup_epochs``;
    * ``'cosine_annealing_with_linear_warmup'``: ``T_cur / warmup_epochs``
      while ``T_cur < warmup_epochs``, then the same cosine curve.

    The functions return multipliers (``LambdaLR`` multiplies them by the
    group's initial learning rate), hence the division by the group's rate.

    Args:
        optimizer: Optimizer whose param groups define the schedules.
        T_max: Step index at which the cosine curve reaches ``eta_min``.
        warmup_epochs: Length of the linear warmup and phase offset of the
            cosine curve. Must differ from ``T_max``, otherwise the cosine
            term divides by zero.
        eta_min: Learning-rate floor of the cosine curve.

    Returns:
        The ``LambdaLR`` scheduler wrapping ``optimizer``.

    Raises:
        ValueError: If a group names an unknown scheduler.
    """

    def cosine_factor(T_cur, eta_max):
        # Cosine-annealed learning rate, expressed relative to eta_max.
        phase = (T_cur - warmup_epochs) / (T_max - warmup_epochs) * math.pi
        return (eta_min + 0.5 * (eta_max - eta_min) * (1.0 + math.cos(phase))) / eta_max

    def warmup_cosine_factor(T_cur, eta_max):
        # https://blog.csdn.net/qq_36560894/article/details/114004799
        if T_cur < warmup_epochs:
            return T_cur / warmup_epochs
        return cosine_factor(T_cur, eta_max)

    scheduler_fns = []
    for param_group in optimizer.param_groups:
        schedule_name = param_group['scheduler']
        # The group's learning rate is bound as a default argument so each
        # lambda keeps its OWN value. A plain closure over a loop variable is
        # resolved at call time and would make every group use the last
        # group's learning rate.
        if schedule_name == 'none':
            fn = lambda T_cur: 1
        elif schedule_name == 'cosine_annealing':
            fn = lambda T_cur, eta_max=param_group['lr']: cosine_factor(T_cur, eta_max)
        elif schedule_name == 'cosine_annealing_with_linear_warmup':
            fn = lambda T_cur, eta_max=param_group['lr']: warmup_cosine_factor(T_cur, eta_max)
        else:
            raise ValueError(f'No such scheduler, {schedule_name}')
        scheduler_fns.append(fn)

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_fns)
|
|
|
|
|
|
@gin.configurable()
def get_data(flag: str,
             batch_size: int) -> Tuple[ForecastDataset, DataLoader]:
    """Build the dataset and data loader for one split.

    The training split is shuffled and drops its last incomplete batch; the
    validation and test splits are iterated in order and keep every batch.

    Args:
        flag: Split name, one of ``'train'``, ``'val'`` or ``'test'``. It is
            passed through to ``ForecastDataset``.
        batch_size: Batch size handed to the ``DataLoader``.

    Returns:
        A ``(dataset, data_loader)`` pair.

    Raises:
        ValueError: If ``flag`` is not one of the three split names.
    """
    if flag not in ('train', 'val', 'test'):
        raise ValueError(f'no such flag {flag}')

    # Shuffling and dropping the ragged final batch apply to training only.
    is_train = flag == 'train'
    dataset = ForecastDataset(flag)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=is_train,
                             drop_last=is_train)
    return dataset, data_loader
|
|
|
|
|
|
@gin.configurable()
def train(model: nn.Module,
          checkpoint: Checkpoint,
          train_loader: DataLoader,
          val_loader: DataLoader,
          test_loader: DataLoader,
          loss_name: str,
          epochs: int,
          clip: float) -> nn.Module:
    """Fit ``model`` and return it with the checkpointed weights loaded.

    Every epoch runs one pass over ``train_loader``, evaluates the training
    loss function on the validation and test loaders, and reports the three
    losses to ``checkpoint``. Training stops early once
    ``checkpoint.early_stop`` is set.

    Args:
        model: Network to optimize; called as ``model(x, x_time, y_time)``.
        checkpoint: Callable invoked as ``checkpoint(epoch, model, scalars=...)``;
            also provides ``early_stop`` and ``model_path``.
        train_loader: Batches used for optimization.
        val_loader: Batches used for the per-epoch validation loss.
        test_loader: Batches used for the per-epoch test loss.
        loss_name: Name resolved to a loss function by ``get_loss_fn``.
        epochs: Maximum number of epochs, also used as the scheduler's ``T_max``.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        ``model``, reloaded from ``checkpoint.model_path`` when ``epochs > 0``.
    """
    optimizer = get_optimizer(model)
    scheduler = get_scheduler(optimizer=optimizer, T_max=epochs)
    criterion = get_loss_fn(loss_name)

    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for it, batch in enumerate(train_loader):
            optimizer.zero_grad()

            x, y, x_time, y_time = map(to_tensor, batch)
            output = model(x, x_time, y_time)

            if not isinstance(output, tuple):
                loss = criterion(output, y)
            else:
                # for models which require reconstruction + forecast loss:
                # element 0 is scored against the input, element 1 against the target
                reconstruction, prediction = output[0], output[1]
                loss = criterion(reconstruction, x) + criterion(prediction, y)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            batch_losses.append(loss.item())
            if (it + 1) % 100 == 0:
                logging.info(f"epochs: {epoch + 1}, iters: {it + 1} | training loss: {loss.item():.2f}")

        # NOTE(review): stepped once per epoch, since the scheduler is sized
        # with T_max=epochs. Confirm the nesting against the upstream file.
        scheduler.step()

        scalars = {
            'Loss/Train': np.average(batch_losses),
            'Loss/Val': validate(model, loader=val_loader, loss_fn=criterion),
            'Loss/Test': validate(model, loader=test_loader, loss_fn=criterion),
        }
        checkpoint(epoch + 1, model, scalars=scalars)

        if checkpoint.early_stop:
            logging.info("Early stopping")
            break

    if epochs <= 0:
        # Nothing was trained, so there is no checkpoint file to restore.
        return model

    model.load_state_dict(torch.load(checkpoint.model_path))
    return model
|
|
|
|
|
|
@torch.no_grad()
def validate(model: nn.Module,
             loader: DataLoader,
             loss_fn: Optional[Callable] = None,
             report_metrics: Optional[bool] = False,
             save_path: Optional[str] = None) -> Union[Dict[str, float], float]:
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Two modes, selected by ``report_metrics``:

    * truthy: forecasts and targets are concatenated over all batches and
      passed to ``calc_metrics``. If ``save_path`` is given, the inputs,
      forecasts and targets are also written there as ``inps.npy``,
      ``preds.npy`` and ``trues.npy``.
    * falsy: ``loss_fn(forecast, y, reduction='none')`` is collected per
      batch and the average over all elements is returned.

    Batches holding a single sample are skipped in both modes.

    Args:
        model: Network called as ``model(x, x_time, y_time)``.
        loader: Iterable of ``(x, y, x_time, y_time)`` batches.
        loss_fn: Loss accepting ``reduction='none'``; required when
            ``report_metrics`` is falsy.
        report_metrics: Whether to return metrics instead of the mean loss.
        save_path: Directory for the ``.npy`` dumps (metrics mode only).

    Returns:
        The result of ``calc_metrics`` in metrics mode, otherwise the mean loss.

    Raises:
        ValueError: If ``report_metrics`` is falsy and ``loss_fn`` is ``None``.
    """
    # Fail fast with a clear message, rather than hitting "'NoneType' object
    # is not callable" after the first forward pass.
    if not report_metrics and loss_fn is None:
        raise ValueError('loss_fn is required when report_metrics is False')

    model.eval()
    preds = []
    trues = []
    inps = []
    total_loss = []
    for data in loader:
        x, y, x_time, y_time = map(to_tensor, data)

        if x.shape[0] == 1:
            # skip final batch if batch_size == 1
            # due to bug in torch.linalg.solve which raises error when batch_size == 1
            continue

        forecast = model(x, x_time, y_time)

        if report_metrics:
            preds.append(forecast)
            trues.append(y)
            if save_path is not None:
                inps.append(x)
        else:
            total_loss.append(loss_fn(forecast, y, reduction='none'))

    if report_metrics:
        preds = torch.cat(preds, dim=0).detach().cpu().numpy()
        trues = torch.cat(trues, dim=0).detach().cpu().numpy()
        if save_path is not None:
            inps = torch.cat(inps, dim=0).detach().cpu().numpy()
            np.save(join(save_path, 'inps.npy'), inps)
            np.save(join(save_path, 'preds.npy'), preds)
            np.save(join(save_path, 'trues.npy'), trues)
        return calc_metrics(preds, trues)

    total_loss = torch.cat(total_loss, dim=0).cpu()
    return np.average(total_loss)
|
|
|
|
|
|
if __name__ == '__main__':
    # Let INFO-level records through the root logger.
    logging.getLogger().setLevel(logging.INFO)
    # Hand the experiment class to Fire, which builds the command-line interface.
    Fire(ForecastExperiment)
|