mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
added MAF
This commit is contained in:
@@ -13,4 +13,4 @@ from .distribution_output import (
|
||||
from .lambda_layer import LambdaLayer
|
||||
from .feature import FeatureEmbedder, FeatureAssembler
|
||||
from .scaler import MeanScaler, NOPScaler
|
||||
from .flows import RealNVP
|
||||
from .flows import RealNVP, MAF
|
||||
|
||||
+230
-17
@@ -6,6 +6,49 @@ import torch.nn.functional as F
|
||||
from torch.distributions import Normal
|
||||
|
||||
|
||||
def create_masks(
|
||||
input_size, hidden_size, n_hidden, input_order="sequential", input_degrees=None
|
||||
):
|
||||
# MADE paper sec 4:
|
||||
# degrees of connections between layers -- ensure at most in_degree - 1 connections
|
||||
degrees = []
|
||||
|
||||
# set input degrees to what is provided in args (the flipped order of the previous layer in a stack of mades);
|
||||
# else init input degrees based on strategy in input_order (sequential or random)
|
||||
if input_order == "sequential":
|
||||
degrees += (
|
||||
[torch.arange(input_size)] if input_degrees is None else [input_degrees]
|
||||
)
|
||||
for _ in range(n_hidden + 1):
|
||||
degrees += [torch.arange(hidden_size) % (input_size - 1)]
|
||||
degrees += (
|
||||
[torch.arange(input_size) % input_size - 1]
|
||||
if input_degrees is None
|
||||
else [input_degrees % input_size - 1]
|
||||
)
|
||||
|
||||
elif input_order == "random":
|
||||
degrees += (
|
||||
[torch.randperm(input_size)] if input_degrees is None else [input_degrees]
|
||||
)
|
||||
for _ in range(n_hidden + 1):
|
||||
min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
|
||||
degrees += [torch.randint(min_prev_degree, input_size, (hidden_size,))]
|
||||
min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
|
||||
degrees += (
|
||||
[torch.randint(min_prev_degree, input_size, (input_size,)) - 1]
|
||||
if input_degrees is None
|
||||
else [input_degrees - 1]
|
||||
)
|
||||
|
||||
# construct masks
|
||||
masks = []
|
||||
for (d0, d1) in zip(degrees[:-1], degrees[1:]):
|
||||
masks += [(d1.unsqueeze(-1) >= d0.unsqueeze(0)).float()]
|
||||
|
||||
return masks, degrees[0]
|
||||
|
||||
|
||||
class FlowSequential(nn.Sequential):
|
||||
""" Container for layers of a normalizing flow """
|
||||
|
||||
@@ -35,8 +78,8 @@ class BatchNorm(nn.Module):
|
||||
self.log_gamma = nn.Parameter(torch.zeros(input_size))
|
||||
self.beta = nn.Parameter(torch.zeros(input_size))
|
||||
|
||||
self.register_buffer('running_mean', torch.zeros(input_size))
|
||||
self.register_buffer('running_var', torch.ones(input_size))
|
||||
self.register_buffer("running_mean", torch.zeros(input_size))
|
||||
self.register_buffer("running_var", torch.ones(input_size))
|
||||
|
||||
def forward(self, x, cond_y=None):
|
||||
if self.training:
|
||||
@@ -46,9 +89,11 @@ class BatchNorm(nn.Module):
|
||||
|
||||
# update running mean
|
||||
self.running_mean.mul_(self.momentum).add_(
|
||||
self.batch_mean.data * (1 - self.momentum))
|
||||
self.batch_mean.data * (1 - self.momentum)
|
||||
)
|
||||
self.running_var.mul_(self.momentum).add_(
|
||||
self.batch_var.data * (1 - self.momentum))
|
||||
self.batch_var.data * (1 - self.momentum)
|
||||
)
|
||||
|
||||
mean = self.batch_mean
|
||||
var = self.batch_var
|
||||
@@ -62,8 +107,8 @@ class BatchNorm(nn.Module):
|
||||
|
||||
# compute log_abs_det_jacobian (cf RealNVP paper)
|
||||
log_abs_det_jacobian = self.log_gamma - 0.5 * torch.log(var + self.eps)
|
||||
# print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format(
|
||||
# (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean()))
|
||||
# print('in sum log var {:6.3f} ; out sum log var {:6.3f}; sum log det {:8.3f}; mean log_gamma {:5.3f}; mean beta {:5.3f}'.format(
|
||||
# (var + self.eps).log().sum().data.numpy(), y.var(0).log().sum().data.numpy(), log_abs_det_jacobian.mean(0).item(), self.log_gamma.mean(), self.beta.mean()))
|
||||
return y, log_abs_det_jacobian.expand_as(x)
|
||||
|
||||
def inverse(self, y, cond_y=None):
|
||||
@@ -88,11 +133,15 @@ class LinearMaskedCoupling(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, n_hidden, mask, cond_label_size=None):
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer('mask', mask)
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
# scale function
|
||||
s_net = [nn.Linear(
|
||||
input_size + (cond_label_size if cond_label_size is not None else 0), hidden_size)]
|
||||
s_net = [
|
||||
nn.Linear(
|
||||
input_size + (cond_label_size if cond_label_size is not None else 0),
|
||||
hidden_size,
|
||||
)
|
||||
]
|
||||
for _ in range(n_hidden):
|
||||
s_net += [nn.Tanh(), nn.Linear(hidden_size, hidden_size)]
|
||||
s_net += [nn.Tanh(), nn.Linear(hidden_size, input_size)]
|
||||
@@ -116,7 +165,7 @@ class LinearMaskedCoupling(nn.Module):
|
||||
u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)
|
||||
|
||||
# log det du/dx; cf RealNVP 8 and 6; note, sum over input_size done at model log_prob
|
||||
log_abs_det_jacobian = - (1 - self.mask) * s
|
||||
log_abs_det_jacobian = -(1 - self.mask) * s
|
||||
|
||||
return u, log_abs_det_jacobian
|
||||
|
||||
@@ -134,13 +183,101 @@ class LinearMaskedCoupling(nn.Module):
|
||||
return x, log_abs_det_jacobian
|
||||
|
||||
|
||||
class MADE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
n_hidden,
|
||||
cond_label_size=None,
|
||||
activation="ReLU",
|
||||
input_order="sequential",
|
||||
input_degrees=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_size -- scalar; dim of inputs
|
||||
hidden_size -- scalar; dim of hidden layers
|
||||
n_hidden -- scalar; number of hidden layers
|
||||
activation -- str; activation function to use
|
||||
input_order -- str or tensor; variable order for creating the autoregressive masks (sequential|random)
|
||||
or the order flipped from the previous layer in a stack of MADEs
|
||||
conditional -- bool; whether model is conditional
|
||||
"""
|
||||
super().__init__()
|
||||
# base distribution for calculation of log prob under the model
|
||||
self.register_buffer("base_dist_mean", torch.zeros(input_size))
|
||||
self.register_buffer("base_dist_var", torch.ones(input_size))
|
||||
|
||||
# create masks
|
||||
masks, self.input_degrees = create_masks(
|
||||
input_size, hidden_size, n_hidden, input_order, input_degrees
|
||||
)
|
||||
|
||||
# setup activation
|
||||
if activation == "ReLU":
|
||||
activation_fn = nn.ReLU()
|
||||
elif activation == "Tanh":
|
||||
activation_fn = nn.Tanh()
|
||||
else:
|
||||
raise ValueError("Check activation function.")
|
||||
|
||||
# construct model
|
||||
self.net_input = MaskedLinear(
|
||||
input_size, hidden_size, masks[0], cond_label_size
|
||||
)
|
||||
self.net = []
|
||||
for m in masks[1:-1]:
|
||||
self.net += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)]
|
||||
self.net += [
|
||||
activation_fn,
|
||||
MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2, 1)),
|
||||
]
|
||||
self.net = nn.Sequential(*self.net)
|
||||
|
||||
@property
|
||||
def base_dist(self):
|
||||
return D.Normal(self.base_dist_mean, self.base_dist_var)
|
||||
|
||||
def forward(self, x, y=None):
|
||||
# MAF eq 4 -- return mean and log std
|
||||
m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=-1)
|
||||
u = (x - m) * torch.exp(-loga)
|
||||
# MAF eq 5
|
||||
log_abs_det_jacobian = -loga
|
||||
return u, log_abs_det_jacobian
|
||||
|
||||
def inverse(self, u, y=None, sum_log_abs_det_jacobians=None):
|
||||
# MAF eq 3
|
||||
# D = u.shape[-1]
|
||||
x = torch.zeros_like(u)
|
||||
# run through reverse model
|
||||
for i in self.input_degrees:
|
||||
m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=-1)
|
||||
x[..., i] = u[..., i] * torch.exp(loga[..., i]) + m[..., i]
|
||||
log_abs_det_jacobian = loga
|
||||
return x, log_abs_det_jacobian
|
||||
|
||||
def log_prob(self, x, y=None):
|
||||
u, log_abs_det_jacobian = self.forward(x, y)
|
||||
return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=-1)
|
||||
|
||||
|
||||
class RealNVP(nn.Module):
|
||||
def __init__(self, n_blocks, input_size, hidden_size, n_hidden, cond_label_size=None, batch_norm=True):
|
||||
def __init__(
|
||||
self,
|
||||
n_blocks,
|
||||
input_size,
|
||||
hidden_size,
|
||||
n_hidden,
|
||||
cond_label_size=None,
|
||||
batch_norm=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# base distribution for calculation of log prob under the model
|
||||
self.register_buffer('base_dist_mean', torch.zeros(input_size))
|
||||
self.register_buffer('base_dist_var', torch.ones(input_size))
|
||||
self.register_buffer("base_dist_mean", torch.zeros(input_size))
|
||||
self.register_buffer("base_dist_var", torch.ones(input_size))
|
||||
|
||||
self.__scale = None
|
||||
|
||||
@@ -148,10 +285,11 @@ class RealNVP(nn.Module):
|
||||
modules = []
|
||||
mask = torch.arange(input_size).float() % 2
|
||||
for i in range(n_blocks):
|
||||
modules += [LinearMaskedCoupling(input_size,
|
||||
hidden_size,
|
||||
n_hidden, mask,
|
||||
cond_label_size)]
|
||||
modules += [
|
||||
LinearMaskedCoupling(
|
||||
input_size, hidden_size, n_hidden, mask, cond_label_size
|
||||
)
|
||||
]
|
||||
mask = 1 - mask
|
||||
modules += batch_norm * [BatchNorm(input_size)]
|
||||
|
||||
@@ -193,3 +331,78 @@ class RealNVP(nn.Module):
|
||||
u = self.base_dist.sample(shape)
|
||||
sample, _ = self.inverse(u, cond)
|
||||
return sample
|
||||
|
||||
|
||||
class MAF(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_blocks,
|
||||
input_size,
|
||||
hidden_size,
|
||||
n_hidden,
|
||||
cond_label_size=None,
|
||||
activation="ReLU",
|
||||
input_order="sequential",
|
||||
batch_norm=True,
|
||||
):
|
||||
super().__init__()
|
||||
# base distribution for calculation of log prob under the model
|
||||
self.register_buffer("base_dist_mean", torch.zeros(input_size))
|
||||
self.register_buffer("base_dist_var", torch.ones(input_size))
|
||||
|
||||
# construct model
|
||||
modules = []
|
||||
self.input_degrees = None
|
||||
for i in range(n_blocks):
|
||||
modules += [
|
||||
MADE(
|
||||
input_size,
|
||||
hidden_size,
|
||||
n_hidden,
|
||||
cond_label_size,
|
||||
activation,
|
||||
input_order,
|
||||
self.input_degrees,
|
||||
)
|
||||
]
|
||||
self.input_degrees = modules[-1].input_degrees.flip(0)
|
||||
modules += batch_norm * [BatchNorm(input_size)]
|
||||
|
||||
self.net = FlowSequential(*modules)
|
||||
|
||||
@property
|
||||
def base_dist(self):
|
||||
return Normal(self.base_dist_mean, self.base_dist_var)
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
return self.__scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, scale):
|
||||
self.__scale = scale
|
||||
|
||||
def forward(self, x, y=None):
|
||||
if self.scale is not None:
|
||||
x /= self.scale
|
||||
return self.net(x, y)
|
||||
|
||||
def inverse(self, u, y=None):
|
||||
x, log_abs_det_jacobian = self.net.inverse(u, y)
|
||||
if self.scale is not None:
|
||||
x *= self.scale
|
||||
return x, log_abs_det_jacobian
|
||||
|
||||
def log_prob(self, x, y=None):
|
||||
u, sum_log_abs_det_jacobians = self.forward(x, y)
|
||||
return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=-1)
|
||||
|
||||
def sample(self, sample_shape=torch.Size(), cond=None):
|
||||
if cond is not None:
|
||||
shape = cond.shape[:-1]
|
||||
else:
|
||||
shape = sample_shape
|
||||
|
||||
u = self.base_dist.sample(shape)
|
||||
sample, _ = self.inverse(u, cond)
|
||||
return sample
|
||||
|
||||
Reference in New Issue
Block a user