mirror of
https://github.com/wassname/DeepTime.git
synced 2026-06-27 19:13:18 +08:00
59 lines
2.1 KiB
Python
59 lines
2.1 KiB
Python
from typing import Optional, Dict
|
|
import logging
|
|
from os.path import join
|
|
|
|
import gin
|
|
import torch
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
@gin.configurable()
|
|
class Checkpoint:
|
|
def __init__(self,
|
|
checkpoint_dir: str,
|
|
patience: Optional[int] = 7,
|
|
delta: Optional[float] = 0.):
|
|
self.checkpoint_dir = checkpoint_dir
|
|
self.model_path = join(checkpoint_dir, 'model.pth')
|
|
|
|
# early stopping
|
|
self.patience = patience
|
|
self.counter = 0
|
|
self.best_loss = float('inf')
|
|
self.early_stop = False
|
|
self.delta = delta
|
|
|
|
# logging
|
|
self.summary_writer = SummaryWriter(log_dir=checkpoint_dir)
|
|
|
|
def __call__(self,
|
|
epoch: int,
|
|
model: torch.nn.Module,
|
|
scalars: Optional[Dict[str, float]] = None):
|
|
for name, value in scalars.items():
|
|
# logging
|
|
self.summary_writer.add_scalar(name, value, epoch)
|
|
|
|
# early stopping
|
|
if name == 'Loss/Val':
|
|
val_loss = value
|
|
if val_loss <= self.best_loss + self.delta:
|
|
logging.info(
|
|
f"Validation loss decreased ({self.best_loss:.3f} --> {val_loss:.3f}). Saving model ...")
|
|
torch.save(model.state_dict(), self.model_path)
|
|
self.best_loss = val_loss
|
|
self.counter = 0
|
|
else:
|
|
self.counter += 1
|
|
logging.info(f"Validation loss increased ({self.best_loss:.3f} --> {val_loss:.3f}). "
|
|
f"Early stopping counter: {self.counter} out of {self.patience}")
|
|
if self.counter >= self.patience >= 0:
|
|
self.early_stop = True
|
|
|
|
self.summary_writer.flush()
|
|
|
|
def close(self, scores: Optional[Dict[str, float]] = None):
|
|
if scores is not None:
|
|
for name, value in scores.items():
|
|
self.summary_writer.add_scalar(name, value)
|
|
self.summary_writer.close() |