Files
DeepTime/utils/checkpoint.py
2022-07-13 16:03:34 +08:00

59 lines
2.1 KiB
Python

from typing import Optional, Dict
import logging
from os.path import join
import gin
import torch
from torch.utils.tensorboard import SummaryWriter
@gin.configurable()
class Checkpoint:
def __init__(self,
checkpoint_dir: str,
patience: Optional[int] = 7,
delta: Optional[float] = 0.):
self.checkpoint_dir = checkpoint_dir
self.model_path = join(checkpoint_dir, 'model.pth')
# early stopping
self.patience = patience
self.counter = 0
self.best_loss = float('inf')
self.early_stop = False
self.delta = delta
# logging
self.summary_writer = SummaryWriter(log_dir=checkpoint_dir)
def __call__(self,
epoch: int,
model: torch.nn.Module,
scalars: Optional[Dict[str, float]] = None):
for name, value in scalars.items():
# logging
self.summary_writer.add_scalar(name, value, epoch)
# early stopping
if name == 'Loss/Val':
val_loss = value
if val_loss <= self.best_loss + self.delta:
logging.info(
f"Validation loss decreased ({self.best_loss:.3f} --> {val_loss:.3f}). Saving model ...")
torch.save(model.state_dict(), self.model_path)
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
logging.info(f"Validation loss increased ({self.best_loss:.3f} --> {val_loss:.3f}). "
f"Early stopping counter: {self.counter} out of {self.patience}")
if self.counter >= self.patience >= 0:
self.early_stop = True
self.summary_writer.flush()
def close(self, scores: Optional[Dict[str, float]] = None):
if scores is not None:
for name, value in scores.items():
self.summary_writer.add_scalar(name, value)
self.summary_writer.close()