mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 17:49:41 +08:00
129 lines
3.9 KiB
Python
129 lines
3.9 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from gluonts.core.component import validated
|
|
|
|
|
|
class Scaler(ABC, nn.Module):
    """
    Abstract module that divides a batch of series by a per-item scale.

    Subclasses supply the scale through ``compute_scale``; ``forward``
    broadcasts it along the time axis and applies the division.

    Parameters
    ----------
    keepdim
        if ``True`` the returned scale keeps a singleton time axis,
        otherwise the scale is returned as produced by ``compute_scale``.
    time_first
        if ``True`` the time axis is axis 1, otherwise it is axis 2.
    """

    def __init__(self, keepdim: bool = False, time_first: bool = True):
        super().__init__()
        self.keepdim = keepdim
        self.time_first = time_first

    @abstractmethod
    def compute_scale(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the scale for ``data``, without the time axis.

        Must be implemented by subclasses.
        """

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scale ``data`` by the value returned from ``compute_scale``.

        Parameters
        ----------
        data
            tensor of shape (N, T, C) if ``time_first == True`` or (N, C, T)
            if ``time_first == False`` containing the data to be scaled

        observed_indicator
            binary tensor with the same shape as ``data``, that has 1 in
            correspondence of observed data points, and 0 in correspondence
            of missing data points.

        Returns
        -------
        Tensor
            Tensor containing the "scaled" data, shape: (N, T, C) or (N, C, T).
        Tensor
            Tensor containing the scale, of shape (N, C) if ``keepdim == False``,
            and shape (N, 1, C) or (N, C, 1) if ``keepdim == True``.
        """
        time_axis = 1 if self.time_first else 2

        raw_scale = self.compute_scale(data, observed_indicator)
        # Re-insert the time axis so the scale broadcasts against ``data``.
        expanded_scale = raw_scale.unsqueeze(dim=time_axis)

        scaled = data / expanded_scale
        returned_scale = expanded_scale if self.keepdim else raw_scale
        return scaled, returned_scale
|
|
|
|
|
|
class MeanScaler(Scaler):
    """
    Scale each item by the mean absolute value of its observed entries.

    The mean is taken along the time axis and only over positions flagged
    as observed by the second argument. When an item has no observed
    entries, or its observed entries are all zero, the scale falls back to
    the mean absolute value computed over the whole batch.

    Parameters
    ----------
    minimum_scale
        lower bound applied to the final scale; this is the value used when
        the time series has only zeros.
    """

    @validated()
    def __init__(self, minimum_scale: float = 1e-10, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer("minimum_scale", torch.tensor(minimum_scale))

    def compute_scale(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the detached per-item, per-dimension mean-absolute scale.
        """
        time_axis = 1 if self.time_first else 2

        # Both reductions drop the time axis, leaving shape (N, C).
        abs_total = (data.abs() * observed_indicator).sum(dim=time_axis)
        obs_count = observed_indicator.sum(dim=time_axis)

        # Batch-wide scale per dimension; the count is floored at one so an
        # entirely unobserved dimension does not divide by zero.
        batch_count = obs_count.sum(dim=0)
        batch_scale = abs_total.sum(dim=0) / torch.max(
            batch_count, torch.ones_like(batch_count)
        )

        # Per-item, per-dimension scale with the same floor on the count.
        item_scale = abs_total / torch.max(
            obs_count, torch.ones_like(obs_count)
        )

        # Items with nothing observed, or only zeros observed, take the
        # batch-wide scale instead of their own.
        has_signal = abs_total > torch.zeros_like(abs_total)
        fallback = batch_scale * torch.ones_like(obs_count)
        chosen = torch.where(has_signal, item_scale, fallback)

        return torch.max(chosen, self.minimum_scale).detach()
|
|
|
|
|
|
class NOPScaler(Scaler):
    """
    Identity scaler: every item gets a scale equal to 1, so calling the
    ``NOPScaler`` leaves the data values unscaled.
    """

    @validated()
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_scale(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        Return a tensor of ones shaped like ``data`` without its time axis.
        """
        time_axis = 1 if self.time_first else 2
        # Averaging a tensor of ones over time collapses that axis and
        # leaves ones behind.
        unit = torch.ones_like(data)
        return unit.mean(dim=time_axis)
|