# Files
# Dr. Kashif Rasul d5577a2c9e fix some tests
# 2020-12-30 19:17:54 +01:00
#
# 129 lines
# 3.9 KiB
# Python

from abc import ABC, abstractmethod
from typing import Tuple
import torch
import torch.nn as nn
from gluonts.core.component import validated
class Scaler(ABC, nn.Module):
    """
    Abstract base module for scalers.

    Subclasses implement ``compute_scale``; ``forward`` divides the data by
    that scale along the time axis.

    Parameters
    ----------
    keepdim
        if ``True``, ``forward`` returns the scale with a singleton time
        axis; otherwise the scale is returned as ``compute_scale`` made it.
    time_first
        if ``True`` the time axis is dimension 1, otherwise dimension 2.
    """

    def __init__(self, keepdim: bool = False, time_first: bool = True):
        super().__init__()
        self.time_first = time_first
        self.keepdim = keepdim

    @abstractmethod
    def compute_scale(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the scale for ``data``.

        ``forward`` inserts the time axis into the returned tensor before
        dividing, so implementations return the scale without that axis.
        """

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scale ``data`` by the result of ``compute_scale``.

        Parameters
        ----------
        data
            tensor of shape (N, T, C) if ``time_first == True`` or (N, C, T)
            if ``time_first == False`` containing the data to be scaled
        observed_indicator
            binary tensor with the same shape as ``data``, that has 1 in
            correspondence of observed data points, and 0 in correspondence
            of missing data points.

        Returns
        -------
        Tensor
            Tensor containing the "scaled" data, shape: (N, T, C) or
            (N, C, T).
        Tensor
            Tensor containing the scale, of shape (N, C) if
            ``keepdim == False``, and shape (N, 1, C) or (N, C, 1) if
            ``keepdim == True``.
        """
        time_axis = 1 if self.time_first else 2
        raw_scale = self.compute_scale(data, observed_indicator)

        # The division always needs the singleton time axis for
        # broadcasting; only the returned scale depends on ``keepdim``.
        expanded_scale = raw_scale.unsqueeze(dim=time_axis)
        scaled = data / expanded_scale

        returned_scale = expanded_scale if self.keepdim else raw_scale
        return scaled, returned_scale
class MeanScaler(Scaler):
    """
    The ``MeanScaler`` computes a per-item scale according to the average
    absolute value over time of each item. The average is computed only among
    the observed values in the data tensor, as indicated by the second
    argument. Items with no observed data, or whose observed values are all
    zero, are assigned a scale based on the average over the whole batch.

    Parameters
    ----------
    minimum_scale
        lower bound applied to every computed scale. It becomes the scale
        when the batch-wide average is itself below it, e.g. when all
        observed values are zero.
    """

    @validated()
    def __init__(self, minimum_scale: float = 1e-10, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer("minimum_scale", torch.tensor(minimum_scale))

    def compute_scale(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the mean absolute observed value per item and dimension.

        Parameters
        ----------
        data
            tensor with the time axis at dimension 1 if ``time_first`` is
            ``True``, otherwise at dimension 2.
        observed_indicator
            tensor with the same shape as ``data``; positions where it is
            zero do not contribute to the scale.

        Returns
        -------
        Tensor
            detached scale with the time axis reduced, bounded below by
            ``minimum_scale``.
        """
        dim = 1 if self.time_first else 2

        # Multiplying by the indicator alone is not enough: NaN * 0 and
        # inf * 0 are NaN, so a non-finite placeholder at an unobserved
        # position would poison both the item's scale and the batch-wide
        # default. Explicitly zero those positions instead.
        weighted_abs = data.abs() * observed_indicator
        weighted_abs = torch.where(
            observed_indicator != 0,
            weighted_abs,
            torch.zeros_like(weighted_abs),
        )

        # these will have shape (N, C)
        num_observed = observed_indicator.sum(dim=dim)
        sum_observed = weighted_abs.sum(dim=dim)

        # first compute a global scale per-dimension
        total_observed = num_observed.sum(dim=0)
        denominator = torch.max(
            total_observed, torch.ones_like(total_observed)
        )
        default_scale = sum_observed.sum(dim=0) / denominator

        # then compute a per-item, per-dimension scale
        # (denominator is at least 1 to avoid dividing by zero)
        denominator = torch.max(num_observed, torch.ones_like(num_observed))
        scale = sum_observed / denominator

        # use per-batch scale when no element is observed
        # or when the sequence contains only zeros
        scale = torch.where(
            sum_observed > torch.zeros_like(sum_observed),
            scale,
            default_scale * torch.ones_like(num_observed),
        )

        return torch.max(scale, self.minimum_scale).detach()
class NOPScaler(Scaler):
    """
    The ``NOPScaler`` assigns a scale equal to 1 to each input item, i.e.,
    no scaling is applied upon calling the ``NOPScaler``.
    """

    @validated()
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_scale(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> torch.Tensor:
        """
        Return a tensor of ones shaped like ``data`` without its time axis.

        Parameters
        ----------
        data
            tensor with the time axis at dimension 1 if ``time_first`` is
            ``True``, otherwise at dimension 2.
        observed_indicator
            unused; accepted to match the ``Scaler`` interface.

        Returns
        -------
        Tensor
            ones with the dtype and device of ``data`` and the time axis
            removed.

        Raises
        ------
        IndexError
            if ``data`` has no time axis at the expected dimension.
        """
        dim = 1 if self.time_first else 2

        # Build the result directly rather than averaging a full-size tensor
        # of ones: that avoids the extra allocation and reduction, works for
        # integer dtypes (``mean`` requires floating point) and does not
        # yield NaN when the time axis has length 0.
        reduced_shape = list(data.shape)
        del reduced_shape[dim]
        return data.new_ones(reduced_shape)