import os
from os.path import join
import math
import logging
from typing import Callable, Optional, Union, Dict, Tuple

import gin
from fire import Fire
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import optim
from torch import nn

from experiments.base import Experiment
from data.datasets import ForecastDataset
from models import get_model
from utils.checkpoint import Checkpoint
from utils.ops import default_device, to_tensor
from utils.losses import get_loss_fn
from utils.metrics import calc_metrics


class ForecastExperiment(Experiment):
    """Experiment that trains a forecasting model and evaluates it on the val/test splits."""

    @gin.configurable()
    def instance(self,
                 model_type: str,
                 save_vals: Optional[bool] = True):
        """Run one full train + evaluate cycle.

        Args:
            model_type: Name passed to ``get_model`` to select the architecture.
            save_vals: When truthy, the test inputs, predictions and targets are
                written as ``.npy`` files under ``self.root``.
        """
        # load datasets, model, checkpointer
        train_set, train_loader = get_data(flag='train')
        val_set, val_loader = get_data(flag='val')
        test_set, test_loader = get_data(flag='test')
        model = get_model(model_type,
                          dim_size=train_set.data_x.shape[1],
                          datetime_feats=train_set.timestamps.shape[-1]).to(default_device())
        checkpoint = Checkpoint(self.root)

        # train forecasting task
        model = train(model, checkpoint, train_loader, val_loader, test_loader)

        # testing
        val_metrics = validate(model, loader=val_loader, report_metrics=True)
        test_metrics = validate(model, loader=test_loader, report_metrics=True,
                                save_path=self.root if save_vals else None)

        # np.save stores the dict as a pickled object array.
        np.save(join(self.root, 'metrics.npy'), {'val': val_metrics, 'test': test_metrics})

        # Prefix the metric names before handing them to the checkpointer.
        val_metrics = {f'ValMetric/{k}': v for k, v in val_metrics.items()}
        test_metrics = {f'TestMetric/{k}': v for k, v in test_metrics.items()}
        checkpoint.close({**val_metrics, **test_metrics})


@gin.configurable()
def get_optimizer(model: nn.Module,
                  lr: Optional[float] = 1e-3,
                  lambda_lr: Optional[float] = 1.,
                  weight_decay: Optional[float] = 1e-2) -> optim.Optimizer:
    """Build an Adam optimizer with three parameter groups selected by parameter name.

    * names containing ``'_lambda'``: learning rate ``lambda_lr``, no weight decay,
      tagged with the ``'cosine_annealing'`` schedule;
    * names containing ``'bias'`` or ``'norm'``: no weight decay, tagged with
      ``'cosine_annealing_with_linear_warmup'``;
    * everything else: default ``lr`` / ``weight_decay``, tagged with
      ``'cosine_annealing_with_linear_warmup'``.

    The extra ``'scheduler'`` key in each group is read back by ``get_scheduler``.

    Args:
        model: Module whose named parameters are optimized.
        lr: Default learning rate.
        lambda_lr: Learning rate for the ``'_lambda'`` parameters.
        weight_decay: Default weight decay.

    Returns:
        The configured ``optim.Adam`` instance.
    """
    group1 = []  # lambda
    group2 = []  # no decay
    group3 = []  # decay
    no_decay_list = ('bias', 'norm',)
    for param_name, param in model.named_parameters():
        if '_lambda' in param_name:
            group1.append(param)
        elif any(mod in param_name for mod in no_decay_list):
            group2.append(param)
        else:
            group3.append(param)
    optimizer = optim.Adam([
        {'params': group1, 'weight_decay': 0, 'lr': lambda_lr, 'scheduler': 'cosine_annealing'},
        {'params': group2, 'weight_decay': 0, 'scheduler': 'cosine_annealing_with_linear_warmup'},
        {'params': group3, 'scheduler': 'cosine_annealing_with_linear_warmup'}
    ], lr=lr, weight_decay=weight_decay)
    return optimizer


def _constant_factor(T_cur: int) -> float:
    """LR multiplier for the ``'none'`` schedule: always 1."""
    return 1


def _cosine_annealing_factor(base_lr: float,
                             eta_min: float,
                             T_max: int,
                             warmup_epochs: int) -> Callable[[int], float]:
    """Return an LR-multiplier function for cosine annealing from ``base_lr`` to ``eta_min``.

    ``base_lr`` is bound here, at creation time, so each param group keeps its
    own peak learning rate.
    """
    def fn(T_cur):
        return (eta_min + 0.5 * (base_lr - eta_min) * (
            1.0 + math.cos((T_cur - warmup_epochs) / (T_max - warmup_epochs) * math.pi))) / base_lr
    return fn


def _cosine_annealing_with_linear_warmup_factor(base_lr: float,
                                                eta_min: float,
                                                T_max: int,
                                                warmup_epochs: int) -> Callable[[int], float]:
    """Return an LR-multiplier function: linear warmup, then cosine annealing."""
    # https://blog.csdn.net/qq_36560894/article/details/114004799
    cosine = _cosine_annealing_factor(base_lr, eta_min, T_max, warmup_epochs)

    def fn(T_cur):
        if T_cur < warmup_epochs:
            return T_cur / warmup_epochs
        return cosine(T_cur)
    return fn


@gin.configurable()
def get_scheduler(optimizer: optim.Optimizer,
                  T_max: int,
                  warmup_epochs: int,
                  eta_min: Optional[float] = 0.) -> optim.lr_scheduler.LambdaLR:
    """Build a ``LambdaLR`` with one multiplier function per optimizer param group.

    The schedule of each group is chosen by its ``'scheduler'`` entry
    (``'none'``, ``'cosine_annealing'`` or ``'cosine_annealing_with_linear_warmup'``).

    Each multiplier is created by a factory that captures the group's own
    learning rate. Building plain lambdas inside the loop would make them all
    share the loop variables and silently use the last group's learning rate.

    Args:
        optimizer: Optimizer whose param groups carry a ``'scheduler'`` key.
        T_max: Epoch at which the cosine reaches ``eta_min``; must differ from
            ``warmup_epochs`` because their difference is used as a divisor.
        warmup_epochs: Number of linear warmup epochs (also the cosine offset).
        eta_min: Minimum learning rate of the cosine schedules.

    Returns:
        The ``LambdaLR`` scheduler.

    Raises:
        ValueError: If a param group names an unknown scheduler.
    """
    scheduler_fns = []
    for param_group in optimizer.param_groups:
        scheduler_name = param_group['scheduler']
        if scheduler_name == 'none':
            fn = _constant_factor
        elif scheduler_name == 'cosine_annealing':
            fn = _cosine_annealing_factor(param_group['lr'], eta_min, T_max, warmup_epochs)
        elif scheduler_name == 'cosine_annealing_with_linear_warmup':
            fn = _cosine_annealing_with_linear_warmup_factor(
                param_group['lr'], eta_min, T_max, warmup_epochs)
        else:
            raise ValueError(f'No such scheduler, {scheduler_name}')
        scheduler_fns.append(fn)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_fns)
    return scheduler


@gin.configurable()
def get_data(flag: str,
             batch_size: int) -> Tuple[ForecastDataset, DataLoader]:
    """Create the dataset and data loader for one split.

    Args:
        flag: One of ``'train'``, ``'val'`` or ``'test'``. The training split is
            shuffled and drops its last incomplete batch; the others do neither.
        batch_size: Batch size of the loader.

    Returns:
        ``(dataset, data_loader)``.

    Raises:
        ValueError: If ``flag`` is not a known split name.
    """
    if flag in ('val', 'test'):
        shuffle = False
        drop_last = False
    elif flag == 'train':
        shuffle = True
        drop_last = True
    else:
        raise ValueError(f'no such flag {flag}')
    dataset = ForecastDataset(flag)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             drop_last=drop_last)
    return dataset, data_loader


@gin.configurable()
def train(model: nn.Module,
          checkpoint: Checkpoint,
          train_loader: DataLoader,
          val_loader: DataLoader,
          test_loader: DataLoader,
          loss_name: str,
          epochs: int,
          clip: float) -> nn.Module:
    """Train ``model`` for up to ``epochs`` epochs.

    After every epoch the scheduler is stepped, the val/test losses are
    computed, and the checkpointer is called with the epoch scalars.

    Args:
        model: Model to train.
        checkpoint: Checkpointer; called once per epoch and queried for
            ``early_stop`` and ``model_path``.
        train_loader: Loader yielding ``(x, y, x_time, y_time)`` batches.
        val_loader: Validation loader.
        test_loader: Test loader.
        loss_name: Name passed to ``get_loss_fn``.
        epochs: Maximum number of epochs; also used as ``T_max`` of the scheduler.
        clip: Max gradient norm for ``clip_grad_norm_``.

    Returns:
        The model, with the weights stored at ``checkpoint.model_path`` loaded
        when ``epochs > 0``.
    """
    optimizer = get_optimizer(model)
    scheduler = get_scheduler(optimizer=optimizer, T_max=epochs)
    training_loss_fn = get_loss_fn(loss_name)

    for epoch in range(epochs):
        train_loss = []
        model.train()
        for it, data in enumerate(train_loader):
            optimizer.zero_grad()

            x, y, x_time, y_time = map(to_tensor, data)
            forecast = model(x, x_time, y_time)

            if isinstance(forecast, tuple):
                # for models which require reconstruction + forecast loss
                loss = training_loss_fn(forecast[0], x) + \
                       training_loss_fn(forecast[1], y)
            else:
                loss = training_loss_fn(forecast, y)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            train_loss.append(loss.item())
            if (it + 1) % 100 == 0:
                logging.info(f"epochs: {epoch + 1}, iters: {it + 1} | training loss: {loss.item():.2f}")

        # The schedule is epoch-based: one step per completed epoch.
        scheduler.step()

        train_loss = np.average(train_loss)
        val_loss = validate(model, loader=val_loader, loss_fn=training_loss_fn)
        test_loss = validate(model, loader=test_loader, loss_fn=training_loss_fn)

        scalars = {'Loss/Train': train_loss,
                   'Loss/Val': val_loss,
                   'Loss/Test': test_loss}
        checkpoint(epoch + 1, model, scalars=scalars)

        if checkpoint.early_stop:
            logging.info("Early stopping")
            break

    if epochs > 0:
        # NOTE(review): presumably the best checkpoint written by `checkpoint` — confirm.
        model.load_state_dict(torch.load(checkpoint.model_path))
    return model


@torch.no_grad()
def validate(model: nn.Module,
             loader: DataLoader,
             loss_fn: Optional[Callable] = None,
             report_metrics: Optional[bool] = False,
             save_path: Optional[str] = None) -> Union[Dict[str, float], float]:
    """Evaluate ``model`` on ``loader``.

    Args:
        model: Model to evaluate (switched to eval mode).
        loader: Loader yielding ``(x, y, x_time, y_time)`` batches.
        loss_fn: Loss called as ``loss_fn(forecast, y, reduction='none')``;
            required when ``report_metrics`` is false.
        report_metrics: When true, collect predictions and targets and return
            ``calc_metrics(preds, trues)`` instead of a loss.
        save_path: When given together with ``report_metrics``, ``inps.npy``,
            ``preds.npy`` and ``trues.npy`` are written to this directory.

    Returns:
        The result of ``calc_metrics`` when ``report_metrics`` is true, otherwise
        the average of the concatenated unreduced losses.
    """
    model.eval()
    preds = []
    trues = []
    inps = []
    total_loss = []
    for data in loader:
        x, y, x_time, y_time = map(to_tensor, data)

        if x.shape[0] == 1:
            # skip final batch if batch_size == 1
            # due to bug in torch.linalg.solve which raises error when batch_size == 1
            continue

        forecast = model(x, x_time, y_time)

        if report_metrics:
            preds.append(forecast)
            trues.append(y)
            if save_path is not None:
                inps.append(x)
        else:
            loss = loss_fn(forecast, y, reduction='none')
            total_loss.append(loss)

    if report_metrics:
        preds = torch.cat(preds, dim=0).detach().cpu().numpy()
        trues = torch.cat(trues, dim=0).detach().cpu().numpy()
        if save_path is not None:
            inps = torch.cat(inps, dim=0).detach().cpu().numpy()
            np.save(join(save_path, 'inps.npy'), inps)
            np.save(join(save_path, 'preds.npy'), preds)
            np.save(join(save_path, 'trues.npy'), trues)
        metrics = calc_metrics(preds, trues)
        return metrics

    total_loss = torch.cat(total_loss, dim=0).cpu()
    return np.average(total_loss)


if __name__ == '__main__':
    logging.root.setLevel(logging.INFO)
    Fire(ForecastExperiment)