def create_masks(
    input_size, hidden_size, n_hidden, input_order="sequential", input_degrees=None
):
    """Build the binary connectivity masks of a MADE network (MADE paper, sec. 4).

    Every unit is assigned an integer degree; a connection from a unit of
    degree ``d0`` to a unit of degree ``d1`` in the next layer is kept only
    when ``d1 >= d0``.  The output degrees are the input degrees shifted down
    by one, so an output never sees an input of equal or higher degree.

    Args:
        input_size: number of input (and output) units.
        hidden_size: number of units in every hidden layer.
        n_hidden: hidden-layer count; ``n_hidden + 1`` hidden degree vectors
            are generated, hence ``n_hidden + 2`` masks are returned.
        input_order: ``"sequential"`` or ``"random"`` -- how degrees are
            initialised when ``input_degrees`` is not given.
        input_degrees: optional 1-D integer tensor of input degrees (e.g. the
            flipped order of the previous layer in a stack of MADEs); when
            given it is used as-is instead of generating an order.

    Returns:
        Tuple ``(masks, input_degrees)`` where ``masks`` is a list of float
        tensors of shape ``(units_out, units_in)`` holding 0/1 entries and
        ``input_degrees`` is the degree vector actually used for the inputs.

    Raises:
        ValueError: if ``input_order`` is neither ``"sequential"`` nor
            ``"random"``.
    """
    # Fail early with a clear message; otherwise `degrees` would stay empty
    # and the final `degrees[0]` would raise an opaque IndexError.
    if input_order not in ("sequential", "random"):
        raise ValueError(
            "input_order must be 'sequential' or 'random', got {!r}".format(
                input_order
            )
        )

    degrees = []

    if input_order == "sequential":
        first = torch.arange(input_size) if input_degrees is None else input_degrees
        degrees.append(first)
        for _ in range(n_hidden + 1):
            # hidden degrees cycle through 0 .. input_size - 2
            degrees.append(torch.arange(hidden_size) % (input_size - 1))
        # output degrees: input degrees shifted down by one
        degrees.append(first % input_size - 1)
    else:  # "random"
        degrees.append(
            torch.randperm(input_size) if input_degrees is None else input_degrees
        )
        for _ in range(n_hidden + 1):
            # never sample below the smallest degree of the previous layer,
            # so every unit keeps at least one incoming connection
            min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
            degrees.append(torch.randint(min_prev_degree, input_size, (hidden_size,)))
        min_prev_degree = min(degrees[-1].min().item(), input_size - 1)
        if input_degrees is None:
            degrees.append(
                torch.randint(min_prev_degree, input_size, (input_size,)) - 1
            )
        else:
            degrees.append(input_degrees - 1)

    # mask[i, j] == 1 iff unit i of the next layer may read unit j of this one
    masks = [
        (d_out.unsqueeze(-1) >= d_in.unsqueeze(0)).float()
        for d_in, d_out in zip(degrees[:-1], degrees[1:])
    ]

    return masks, degrees[0]
class MADE(nn.Module):
    """Masked autoregressive network (MADE) used as one layer of a MAF flow.

    The network outputs a mean and a log-scale for every input dimension;
    ``forward`` maps data to base-distribution noise and ``inverse`` maps
    noise back to data one dimension at a time.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        n_hidden,
        cond_label_size=None,
        activation="ReLU",
        input_order="sequential",
        input_degrees=None,
    ):
        """
        Args:
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers
            n_hidden -- scalar; number of hidden layers
            cond_label_size -- scalar or None; passed to the input MaskedLinear
                            (presumably the dim of the conditioning vector -- confirm
                            against MaskedLinear)
            activation -- str; activation function to use ("ReLU" or "Tanh")
            input_order -- str; variable order for creating the autoregressive masks (sequential|random)
            input_degrees -- tensor or None; input degrees to use instead of generating them,
                            e.g. the order flipped from the previous layer in a stack of MADEs

        Raises:
            ValueError -- if ``activation`` is not one of the supported names
        """
        super().__init__()
        # base distribution for calculation of log prob under the model
        self.register_buffer("base_dist_mean", torch.zeros(input_size))
        self.register_buffer("base_dist_var", torch.ones(input_size))

        # create masks
        masks, self.input_degrees = create_masks(
            input_size, hidden_size, n_hidden, input_order, input_degrees
        )

        # setup activation
        if activation == "ReLU":
            activation_fn = nn.ReLU()
        elif activation == "Tanh":
            activation_fn = nn.Tanh()
        else:
            raise ValueError("Check activation function.")

        # construct model
        self.net_input = MaskedLinear(
            input_size, hidden_size, masks[0], cond_label_size
        )
        layers = []
        for m in masks[1:-1]:
            layers += [activation_fn, MaskedLinear(hidden_size, hidden_size, m)]
        # the output layer emits mean and log-scale, so its mask is stacked twice
        layers += [
            activation_fn,
            MaskedLinear(hidden_size, 2 * input_size, masks[-1].repeat(2, 1)),
        ]
        self.net = nn.Sequential(*layers)

    @property
    def base_dist(self):
        """Base distribution built from the registered mean/scale buffers."""
        # `Normal` is what the module imports from torch.distributions; the
        # previous `D.Normal` referenced an alias that is not imported.
        return Normal(self.base_dist_mean, self.base_dist_var)

    def forward(self, x, y=None):
        """Map data ``x`` to noise ``u``; returns ``(u, log_abs_det_jacobian)``.

        ``y`` is handed to the input layer together with ``x``.  The returned
        log-determinant is element-wise (same shape as ``x``); it is summed
        over the last dimension in ``log_prob``.
        """
        # MAF eq 4 -- return mean and log std
        m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=-1)
        u = (x - m) * torch.exp(-loga)
        # MAF eq 5
        log_abs_det_jacobian = -loga
        return u, log_abs_det_jacobian

    def inverse(self, u, y=None, sum_log_abs_det_jacobians=None):
        """Map noise ``u`` back to data ``x``; returns ``(x, log_abs_det_jacobian)``.

        Runs one network pass per input dimension.  ``sum_log_abs_det_jacobians``
        is accepted for interface compatibility and is not used.
        """
        # MAF eq 3
        x = torch.zeros_like(u)
        # Fill positions in order of increasing degree: the output at a
        # position only depends on inputs of strictly lower degree.  For the
        # sequential order and its flips argsort(degrees) equals the degrees
        # themselves, so this matches iterating over the degrees directly;
        # for a general permutation it yields the positions, not the degrees.
        for i in torch.argsort(self.input_degrees):
            m, loga = self.net(self.net_input(x, y)).chunk(chunks=2, dim=-1)
            x[..., i] = u[..., i] * torch.exp(loga[..., i]) + m[..., i]
        log_abs_det_jacobian = loga
        return x, log_abs_det_jacobian

    def log_prob(self, x, y=None):
        """Log-density of ``x`` under the model, summed over the last dimension."""
        u, log_abs_det_jacobian = self.forward(x, y)
        return torch.sum(self.base_dist.log_prob(u) + log_abs_det_jacobian, dim=-1)
class MAF(nn.Module):
    """Masked autoregressive flow: a stack of MADE blocks, each optionally
    followed by a BatchNorm layer, over a Normal base distribution.
    """

    def __init__(
        self,
        n_blocks,
        input_size,
        hidden_size,
        n_hidden,
        cond_label_size=None,
        activation="ReLU",
        input_order="sequential",
        batch_norm=True,
    ):
        """
        Args:
            n_blocks -- scalar; number of MADE blocks in the flow
            input_size -- scalar; dim of inputs
            hidden_size -- scalar; dim of hidden layers of every MADE
            n_hidden -- scalar; number of hidden layers of every MADE
            cond_label_size -- scalar or None; forwarded to every MADE
            activation -- str; activation name forwarded to every MADE
            input_order -- str; input ordering strategy forwarded to every MADE
            batch_norm -- bool; whether a BatchNorm layer follows each MADE
        """
        super().__init__()
        # base distribution for calculation of log prob under the model
        self.register_buffer("base_dist_mean", torch.zeros(input_size))
        self.register_buffer("base_dist_var", torch.ones(input_size))

        # Must exist before forward()/inverse() read the `scale` property;
        # without it the getter raises AttributeError until a scale is set.
        self.__scale = None

        # construct model
        modules = []
        self.input_degrees = None
        for _ in range(n_blocks):
            modules += [
                MADE(
                    input_size,
                    hidden_size,
                    n_hidden,
                    cond_label_size,
                    activation,
                    input_order,
                    self.input_degrees,
                )
            ]
            # the next block sees the variables in reversed order
            self.input_degrees = modules[-1].input_degrees.flip(0)
            modules += batch_norm * [BatchNorm(input_size)]

        self.net = FlowSequential(*modules)

    @property
    def base_dist(self):
        """Base distribution built from the registered mean/scale buffers."""
        return Normal(self.base_dist_mean, self.base_dist_var)

    @property
    def scale(self):
        """Optional divisor applied to inputs in ``forward`` (``None`` = off)."""
        return self.__scale

    @scale.setter
    def scale(self, scale):
        self.__scale = scale

    def forward(self, x, y=None):
        """Run ``x`` (divided by ``scale`` when set) through the flow.

        Returns whatever the underlying flow container returns for ``(x, y)``;
        ``log_prob`` unpacks it as ``(u, sum_log_abs_det_jacobians)``.
        """
        if self.scale is not None:
            # out-of-place so the caller's tensor is not modified
            x = x / self.scale
        return self.net(x, y)

    def inverse(self, u, y=None):
        """Invert the flow on ``u`` and multiply the result by ``scale`` when set.

        Returns ``(x, log_abs_det_jacobian)``.
        """
        x, log_abs_det_jacobian = self.net.inverse(u, y)
        if self.scale is not None:
            x = x * self.scale
        return x, log_abs_det_jacobian

    def log_prob(self, x, y=None):
        """Log-density of ``x`` under the model, summed over the last dimension."""
        u, sum_log_abs_det_jacobians = self.forward(x, y)
        return torch.sum(self.base_dist.log_prob(u) + sum_log_abs_det_jacobians, dim=-1)

    def sample(self, sample_shape=torch.Size(), cond=None):
        """Draw samples by pushing base-distribution noise through ``inverse``.

        When ``cond`` is given, the sample shape is ``cond.shape[:-1]`` and
        ``sample_shape`` is ignored; otherwise ``sample_shape`` is used.
        """
        if cond is not None:
            shape = cond.shape[:-1]
        else:
            shape = sample_shape

        u = self.base_dist.sample(shape)
        sample, _ = self.inverse(u, cond)
        return sample