Files
DeepTime/models/modules/regressors.py
T
2022-07-13 16:03:34 +08:00

41 lines
1.3 KiB
Python

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
class RidgeRegressor(nn.Module):
    """Batched closed-form ridge regression with a learnable penalty.

    The module stores one unconstrained scalar parameter ``_lambda``; the
    effective regularisation coefficient is ``softplus(_lambda)`` and is
    therefore always strictly positive. For every batch element the module
    solves a ridge problem on the inputs augmented with a constant-one
    column, and returns the weight matrix and the bias row separately.

    Note that the penalty is added to every diagonal entry of the Gram
    matrix, so the bias term is regularised together with the weights.
    """

    def __init__(self, lambda_init: Optional[float] = 0.):
        """Create the regressor.

        Args:
            lambda_init: Initial value of the unconstrained parameter that is
                passed through softplus to obtain the penalty. ``None`` is
                treated like the default ``0.``.
        """
        super().__init__()
        if lambda_init is None:
            # The signature advertises Optional, but torch.as_tensor(None)
            # raises; fall back to the documented default instead.
            lambda_init = 0.
        self._lambda = nn.Parameter(torch.as_tensor(lambda_init, dtype=torch.float))

    def forward(
        self,
        reprs: Tensor,
        x: Tensor,
        reg_coeff: Optional[Union[float, Tensor]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Fit the ridge regression mapping ``reprs`` to ``x``.

        Args:
            reprs: Inputs of shape ``(batch, n_samples, n_dim)``.
            x: Targets of shape ``(batch, n_samples, out_dim)``.
            reg_coeff: Penalty to use. When ``None`` the learnable
                coefficient from :meth:`reg_coeff` is used.

        Returns:
            ``(w, b)`` with ``w`` of shape ``(batch, n_dim, out_dim)`` and
            ``b`` of shape ``(batch, 1, out_dim)``.
        """
        if reg_coeff is None:
            reg_coeff = self.reg_coeff()
        return self.get_weights(reprs, x, reg_coeff)

    def get_weights(
        self,
        X: Tensor,
        Y: Tensor,
        reg_coeff: Union[float, Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """Solve the batched ridge problem in closed form.

        Args:
            X: Inputs of shape ``(batch, n_samples, n_dim)``.
            Y: Targets of shape ``(batch, n_samples, out_dim)``.
            reg_coeff: Value added to the diagonal of the Gram matrix.

        Returns:
            ``(w, b)``: the first ``n_dim`` rows of the solution and its last
            row (the coefficients of the appended ones column).
        """
        batch_size, n_samples, n_dim = X.shape

        # Append a constant column so the bias is estimated jointly with the
        # weights. new_ones keeps the dtype and device of X, so inputs that
        # are not float32 are not silently promoted.
        ones = X.new_ones(batch_size, n_samples, 1)
        X = torch.cat([X, ones], dim=-1)

        # Compare against the augmented feature count (n_dim + 1): this
        # always selects the smaller linear system and avoids building a
        # rank-deficient Gram matrix when n_samples == n_dim.
        if n_samples >= n_dim + 1:
            # Primal form: (X^T X + reg * I) W = X^T Y
            A = torch.bmm(X.mT, X)
            A.diagonal(dim1=-2, dim2=-1).add_(reg_coeff)
            B = torch.bmm(X.mT, Y)
            weights = torch.linalg.solve(A, B)
        else:
            # Dual (Woodbury) form: W = X^T (X X^T + reg * I)^{-1} Y
            A = torch.bmm(X, X.mT)
            A.diagonal(dim1=-2, dim2=-1).add_(reg_coeff)
            weights = torch.bmm(X.mT, torch.linalg.solve(A, Y))

        return weights[:, :-1], weights[:, -1:]

    def reg_coeff(self) -> Tensor:
        """Return the positive penalty ``softplus(_lambda)``."""
        return F.softplus(self._lambda)