Files
DeepTime/utils/losses.py
2022-07-13 16:03:34 +08:00

16 lines
469 B
Python

from typing import Optional, Callable
from functools import partial
import torch
import torch.nn.functional as F
from torch import Tensor
def get_loss_fn(loss_name: str,
                delta: Optional[float] = 1.0,
                beta: Optional[float] = 1.0) -> Callable:
    """Return the loss function registered under ``loss_name``.

    Supported names are ``'mse'``, ``'mae'``, ``'huber'`` and
    ``'smooth_l1'``. The first two map directly to ``F.mse_loss`` and
    ``F.l1_loss``; the last two are ``functools.partial`` wrappers around
    ``F.huber_loss`` / ``F.smooth_l1_loss`` with ``delta`` / ``beta`` bound.

    Args:
        loss_name: Key selecting the loss function.
        delta: Value bound to the ``delta`` argument of ``F.huber_loss``.
            ``None`` falls back to 1.0.
        beta: Value bound to the ``beta`` argument of ``F.smooth_l1_loss``.
            ``None`` falls back to 1.0.

    Returns:
        A callable loss function.

    Raises:
        KeyError: If ``loss_name`` is not one of the supported names.
    """
    # The annotations allow None, but the torch functionals reject it at call
    # time; substitute the default so the returned callable is always usable.
    delta = 1.0 if delta is None else delta
    beta = 1.0 if beta is None else beta

    loss_fns = {
        'mse': F.mse_loss,
        'mae': F.l1_loss,
        'huber': partial(F.huber_loss, delta=delta),
        'smooth_l1': partial(F.smooth_l1_loss, beta=beta),
    }
    try:
        return loss_fns[loss_name]
    except KeyError:
        # Keep the KeyError type for existing callers, but say what is valid.
        raise KeyError(
            f"Unknown loss_name {loss_name!r}; "
            f"expected one of {sorted(loss_fns)}"
        ) from None