"""Recurrent Attentive Neural Process."""
import torch
from torch import nn
import torch.nn.functional as F
import math


class LSTMBlock(nn.Module):
    """Wrapper around ``nn.LSTM`` that returns only the output sequence.

    ``nn.LSTM`` returns ``(output, (h_n, c_n))``; this block keeps only
    ``output`` so it can be used like any other layer.

    Note: ``batchnorm`` is accepted for signature parity with the other
    blocks in this file but is not used.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0,
        batchnorm=False,
        bias=False,
        num_layers=1,
    ):
        super().__init__()
        self._lstm = nn.LSTM(
            input_size=in_channels,
            hidden_size=out_channels,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bias=bias,
        )

    def forward(self, x):
        # batch_first=True, so x is (Batch, Sequence, in_channels).
        return self._lstm(x)[0]


class NPBlockRelu2d(nn.Module):
    """Linear -> ReLU -> optional BatchNorm -> Dropout block for Neural Processes.

    Batchnorm and dropout are applied per channel. To reuse ``BatchNorm2d`` and
    ``Dropout2d``, the (Batch, Sequence, Channels) activations are temporarily
    reshaped to (Batch, Channels, Sequence, 1).
    """

    def __init__(
        self, in_channels, out_channels, dropout=0, batchnorm=False, bias=False
    ):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels, bias=bias)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout2d(dropout)
        # None when disabled; no parameters are registered in that case.
        self.norm = nn.BatchNorm2d(out_channels) if batchnorm else None

    def forward(self, x):
        # x.shape is (Batch, Sequence, Channels); the linear acts on Channels.
        x = self.act(self.linear(x))
        # Reshape to (Batch, Channels, Sequence, 1) for the *2d layers.
        x = x.permute(0, 2, 1)[:, :, :, None]
        if self.norm is not None:
            x = self.norm(x)
        # Dropout2d zeroes whole channels, i.e. a dropped feature is zero
        # at every sequence position of that sample.
        x = self.dropout(x)
        return x[:, :, :, 0].permute(0, 2, 1)


class BatchMLP(nn.Module):
    """Apply an MLP to the final axis of a 3D tensor.

    Maps ``[B, n, input_size]`` to ``[B, n, output_size]``. The network is one
    ``NPBlockRelu2d`` (input_size -> output_size), then ``num_layers - 2``
    further ``NPBlockRelu2d`` blocks (output_size -> output_size), then a plain
    ``nn.Linear`` with no activation. For ``num_layers < 2`` the network still
    contains the initial block and the final linear layer.

    Args:
        input_size: size of the last axis of the input.
        output_size: size of the last axis of the output (also hidden width).
        num_layers: total number of layers (see above).
        dropout: dropout probability in the ``NPBlockRelu2d`` blocks.
        batchnorm: whether the ``NPBlockRelu2d`` blocks use batchnorm.
    """

    def __init__(
        self, input_size, output_size, num_layers=2, dropout=0, batchnorm=False
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.initial = NPBlockRelu2d(
            input_size, output_size, dropout=dropout, batchnorm=batchnorm
        )
        self.encoder = nn.Sequential(
            *[
                NPBlockRelu2d(
                    output_size, output_size, dropout=dropout, batchnorm=batchnorm
                )
                for _ in range(num_layers - 2)
            ]
        )
        self.final = nn.Linear(output_size, output_size)

    def forward(self, x):
        x = self.initial(x)
        x = self.encoder(x)
        return self.final(x)


class Attention(nn.Module):
    """Multi-head attention with optional MLP projection of keys and queries.

    With ``rep="mlp"`` keys and queries (last axis ``x_dim``) are first mapped
    to ``hidden_dim`` by separate ``BatchMLP``s. Any other ``rep`` value uses
    them unchanged, so they must already have last axis ``hidden_dim``. Values
    must always have last axis ``hidden_dim``.

    Inputs and output are batch-first, ``[B, n, hidden_dim]``.
    """

    def __init__(
        self,
        hidden_dim,
        attention_layers=2,
        n_heads=8,
        x_dim=1,
        rep="mlp",
        dropout=0,
        batchnorm=False,
    ):
        super().__init__()
        self._rep = rep

        if self._rep == "mlp":
            self.batch_mlp_k = BatchMLP(
                x_dim,
                hidden_dim,
                attention_layers,
                dropout=dropout,
                batchnorm=batchnorm,
            )
            self.batch_mlp_q = BatchMLP(
                x_dim,
                hidden_dim,
                attention_layers,
                dropout=dropout,
                batchnorm=batchnorm,
            )

        self._W = torch.nn.MultiheadAttention(
            hidden_dim, n_heads, bias=False, dropout=dropout
        )
        self._attention_func = self._pytorch_multihead_attention

    def forward(self, k, v, q):
        if self._rep == "mlp":
            k = self.batch_mlp_k(k)
            q = self.batch_mlp_q(q)
        rep = self._attention_func(k, v, q)
        return rep

    def _pytorch_multihead_attention(self, k, v, q):
        # nn.MultiheadAttention (batch_first=False) expects (Sequence, Batch,
        # Channels) and takes (query, key, value) in that order, so permute
        # in, reorder the arguments, and permute back to batch-first.
        q = q.permute(1, 0, 2)
        k = k.permute(1, 0, 2)
        v = v.permute(1, 0, 2)
        o = self._W(q, k, v)[0]
        return o.permute(1, 0, 2)


class LatentEncoder(nn.Module):
    """Encode (x, y) pairs into a diagonal Gaussian over the global latent.

    ``forward`` returns ``(dist, log_var)``. Here ``dist`` is a
    ``Normal(mean, sigma)`` whose sigma is bounded to ``(min_std, 1)`` by a
    sigmoid, and ``log_var`` is the raw pre-activation output.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim=32,
        latent_dim=32,
        n_encoder_layers=3,
        min_std=0.01,
        batchnorm=False,
        dropout=0,
        nhead=8,
        attention_dropout=0,
        attention_layers=2,
    ):
        super().__init__()
        self._encoder = BatchMLP(
            input_dim,
            hidden_dim,
            batchnorm=batchnorm,
            dropout=dropout,
            num_layers=n_encoder_layers,
        )
        self._self_attention = Attention(
            hidden_dim,
            attention_layers,
            n_heads=nhead,
            rep="identity",
            dropout=attention_dropout,
        )
        self._penultimate_layer = nn.Linear(hidden_dim, hidden_dim)
        self._mean = nn.Linear(hidden_dim, latent_dim)
        self._log_var = nn.Linear(hidden_dim, latent_dim)
        self._min_std = min_std

    def forward(self, x, y):
        encoder_input = torch.cat([x, y], dim=-1)

        # Pass final axis through MLP
        encoded = self._encoder(encoder_input)

        # Self-attention over points, then aggregate by taking the mean over
        # the point axis (dim=1).
        attention_output = self._self_attention(encoded, encoded, encoded)
        mean_repr = attention_output.mean(dim=1)

        # Further MLP layer that maps to the parameters of the Gaussian latent
        mean_repr = torch.relu(self._penultimate_layer(mean_repr))

        # Then apply further linear layers to output latent mu and log sigma
        mean = self._mean(mean_repr)
        log_var = self._log_var(mean_repr)

        # sigmoid keeps sigma in (min_std, 1) for min_std < 1.
        sigma = self._min_std + (1 - self._min_std) * torch.sigmoid(log_var * 0.5)
        dist = torch.distributions.Normal(mean, sigma)
        return dist, log_var


class DeterministicEncoder(nn.Module):
    """Deterministic path: per-target representation via cross-attention.

    Context pairs are encoded by an MLP and self-attention. Target inputs
    ``future_x`` then attend over them, using context ``past_x`` as keys.
    """

    def __init__(
        self,
        input_dim,
        x_dim,
        hidden_dim=32,
        n_d_encoder_layers=3,
        attention_layers=2,
        batchnorm=False,
        dropout=0,
        nhead=8,
        attention_dropout=0,
    ):
        super().__init__()
        self._d_encoder = BatchMLP(
            input_dim,
            hidden_dim,
            batchnorm=batchnorm,
            dropout=dropout,
            num_layers=n_d_encoder_layers,
        )
        self._self_attention = Attention(
            hidden_dim,
            attention_layers,
            n_heads=nhead,
            rep="identity",
            dropout=attention_dropout,
        )
        # NOTE(review): unlike the self-attention above, the cross-attention
        # does not receive attention_dropout/batchnorm; confirm this is intended.
        self._cross_attention = Attention(
            hidden_dim,
            x_dim=x_dim,
            n_heads=nhead,
            attention_layers=attention_layers,
        )

    def forward(self, past_x, past_y, future_x):
        # Concatenate x and y along the feature axis
        d_encoder_input = torch.cat([past_x, past_y], dim=-1)

        # Pass final axis through MLP
        d_encoded = self._d_encoder(d_encoder_input)

        d_encoded = self._self_attention(d_encoded, d_encoded, d_encoded)

        # Targets query the context: keys=past_x, values=encoded context.
        h = self._cross_attention(past_x, d_encoded, future_x)

        return h


class Decoder(nn.Module):
    """Decode (deterministic rep, latent sample, target x) into a Normal over y.

    ``forward`` returns ``(dist, log_sigma)``. Here ``sigma = min_std +
    (1 - min_std) * softplus(log_sigma)``, so sigma exceeds ``min_std`` (for
    ``min_std < 1``) and has no upper bound.
    """

    def __init__(
        self,
        x_dim,
        y_dim,
        hidden_dim=32,
        latent_dim=32,
        n_decoder_layers=3,
        use_deterministic_path=True,
        min_std=0.01,
        batchnorm=False,
        dropout=0,
    ):
        super().__init__()
        self._future_transform = nn.Linear(x_dim, hidden_dim)
        # Decoder input is [r, z, transformed x] or [z, transformed x].
        if use_deterministic_path:
            hidden_dim_2 = 2 * hidden_dim + latent_dim
        else:
            hidden_dim_2 = hidden_dim + latent_dim
        self._decoder = BatchMLP(
            hidden_dim_2,
            hidden_dim_2,
            batchnorm=batchnorm,
            dropout=dropout,
            num_layers=n_decoder_layers,
        )
        self._mean = nn.Linear(hidden_dim_2, y_dim)
        self._std = nn.Linear(hidden_dim_2, y_dim)
        self._use_deterministic_path = use_deterministic_path
        self._min_std = min_std

    def forward(self, r, z, future_x):
        # concatenate future_x and representation
        x = self._future_transform(future_x)

        if self._use_deterministic_path:
            z = torch.cat([r, z], dim=-1)

        r = torch.cat([z, x], dim=-1)

        r = self._decoder(r)

        # Get the mean and the variance
        mean = self._mean(r)
        log_sigma = self._std(r)

        # Bound the standard deviation from below
        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)

        dist = torch.distributions.Normal(mean, sigma)
        return dist, log_sigma


class RANP(nn.Module):
    """Recurrent Attentive Neural Process for Sequential Data.

    ``forward(past_x, past_y, future_x, future_y=None)`` returns
    ``(dist, {'loss': loss})``. ``dist`` is the predictive Normal over the
    targets. ``loss`` is the negative ELBO-style objective when ``future_y``
    is given, otherwise ``None``. Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(
        self,
        x_dim,  # features in input
        y_dim,  # number of features in output
        hidden_dim=32,  # size of hidden space
        latent_dim=32,  # size of latent space
        n_latent_encoder_layers=2,
        n_det_encoder_layers=2,  # number of deterministic encoder layers
        n_decoder_layers=4,
        use_deterministic_path=True,
        min_std=0.01,  # To avoid collapse use a minimum standard deviation, should be much smaller than variation in labels
        dropout=0,
        nhead=8,
        attention_dropout=0,
        batchnorm=False,
        attention_layers=2,
        use_rnn=True,  # use RNN/LSTM
        **kwargs,
    ):
        super().__init__()

        self._use_rnn = use_rnn

        if self._use_rnn:
            self._lstm = nn.LSTM(
                input_size=x_dim,
                hidden_size=hidden_dim,
                num_layers=attention_layers,
                dropout=dropout,
                batch_first=True,
            )
            # Downstream modules see the LSTM output instead of raw x.
            x_dim = hidden_dim

        self._latent_encoder = LatentEncoder(
            x_dim + y_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            n_encoder_layers=n_latent_encoder_layers,
            attention_layers=attention_layers,
            dropout=dropout,
            nhead=nhead,
            attention_dropout=attention_dropout,
            batchnorm=batchnorm,
            min_std=min_std,
        )

        self._deterministic_encoder = DeterministicEncoder(
            input_dim=x_dim + y_dim,
            x_dim=x_dim,
            hidden_dim=hidden_dim,
            n_d_encoder_layers=n_det_encoder_layers,
            attention_layers=attention_layers,
            dropout=dropout,
            nhead=nhead,
            batchnorm=batchnorm,
            attention_dropout=attention_dropout,
        )

        self._decoder = Decoder(
            x_dim,
            y_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            dropout=dropout,
            batchnorm=batchnorm,
            min_std=min_std,
            n_decoder_layers=n_decoder_layers,
            use_deterministic_path=use_deterministic_path,
        )
        self._use_deterministic_path = use_deterministic_path

    def forward(self, past_x, past_y, future_x, future_y=None):
        if self._use_rnn:
            # see https://arxiv.org/abs/1910.09323 where x is substituted with h = RNN(x)
            # x need to be provided as [B, T, H]
            num_past = past_x.shape[1]
            encoded_x, _ = self._lstm(torch.cat([past_x, future_x], 1))
            past_x = encoded_x[:, :num_past]
            future_x = encoded_x[:, num_past:]

        dist_prior, _ = self._latent_encoder(past_x, past_y)

        if future_y is not None:
            # The posterior conditions on context + targets. Build its inputs
            # from past_x/future_x (RNN-encoded or raw) so this works whether
            # or not use_rnn is set.
            x = torch.cat([past_x, future_x], 1)
            y = torch.cat([past_y, future_y], 1)
            dist_post, _ = self._latent_encoder(x, y)

        if self.training and (future_y is not None):
            # Use posterior during training, if possible
            z = dist_post.rsample()
        else:
            # Otherwise use the prior mean (its most probable value)
            z = dist_prior.loc

        num_targets = future_x.size(1)
        z = z.unsqueeze(1).repeat(1, num_targets, 1)  # [B, T_target, H]

        if self._use_deterministic_path:
            r = self._deterministic_encoder(
                past_x, past_y, future_x
            )  # [B, T_target, H]
        else:
            r = None

        dist, _ = self._decoder(r, z, future_x)

        loss = None
        if future_y is not None:
            log_p = dist.log_prob(future_y).mean(-1)  # [B, T_target]
            kl_loss = torch.distributions.kl_divergence(dist_post, dist_prior).mean(
                -1
            )  # [B]
            kl_loss = kl_loss[:, None].expand(log_p.shape)
            loss = (kl_loss - log_p).mean()

        return dist, {'loss': loss}


# class NP(RANP):
#     """Recurrent Attentive Neural Process for Sequential Data."""

#     def __init__(
#         self,
#         use_self_attn=True,
#         # TODO use cross attention flag
#         use_rnn=True,  # use RNN/LSTM
#         use_lstm_le=False,  # use another LSTM in latent encoder instead of MLP
#         use_lstm_de=False,  # use another LSTM in determinstic encoder instead of MLP
#         use_lstm_d=False,  # use another lstm in decoder instead of MLP
#         **kwargs,
#     ):
#         kwargs
#         super().__init__(**kwargs)