# Source header (git web view): commit 7742b48f69 "tidy" by wassname,
# 2022-02-11 20:07:18 +08:00 — original file 96 lines, 3.1 KiB, Python.

import torch
from torch import nn
from torch.autograd import grad
from torch.optim import Adam
from torch.nn import functional as F
from matplotlib import pyplot as plt
class IntensityNet(nn.Module):
    """Monotone MLP for the cumulative intensity Lambda(t) of a temporal point process.

    The network maps (last hidden state, target time) to a non-negative scalar
    ``int_lmbda`` per sample. Every parameter is kept strictly positive (see
    ``init_weights_positive`` and the clamp at the start of ``forward``), and
    every activation used (tanh, softplus, optional log of the time) is
    increasing. Together these make ``int_lmbda`` non-decreasing in the target
    time. The intensity ``lmbda`` is the autograd derivative of ``int_lmbda``
    with respect to the target time.

    Config attributes read: ``hid_dim``, ``mlp_dim``, ``mlp_layer``,
    ``mean_first``, ``log_t``.
    """

    # Floor used for positive-parameter clamping and inside logarithms.
    _EPS = 1e-10

    def __init__(self, config):
        super(IntensityNet, self).__init__()
        # Scalar target time -> 1-d time feature.
        self.linear1 = nn.Linear(in_features=1, out_features=1)
        # Last hidden state concatenated with the time feature.
        self.linear2 = nn.Linear(in_features=config.hid_dim + 1, out_features=config.mlp_dim)
        self.module_list = nn.ModuleList(
            [nn.Linear(in_features=config.mlp_dim, out_features=config.mlp_dim)
             for _ in range(config.mlp_layer - 1)]
        )
        self.linear3 = nn.Sequential(nn.Linear(in_features=config.mlp_dim, out_features=1), nn.Softplus())
        self.mean_first = config.mean_first
        self.log_t = config.log_t
        self.init_weights_positive()

    def init_weights_positive(self):
        """Replace every parameter in place by ``max(|p|, _EPS)``."""
        with torch.no_grad():
            for p in self.parameters():
                p.abs_().clamp_(min=self._EPS)

    def forward(self, hidden_state, target_time):
        """Compute the negative log-likelihood of ``target_time`` under the model.

        Args:
            hidden_state: tensor indexed as ``hidden_state[:, -1, :]``. Its last
                dim must equal ``config.hid_dim``. Presumably the
                ``(batch, seq, hid_dim)`` output of an RNN (TODO confirm
                against callers).
            target_time: floating-point tensor of shape ``(batch,)`` holding
                the time at which the intensity is evaluated. The caller's
                tensor is not modified.

        Returns:
            ``[nll, log_lmbda_mean, int_lmbda_mean, lmbda]``:
            ``nll`` is the batch-mean of ``int_lmbda - log(lmbda + eps)``.
            ``log_lmbda_mean`` and ``int_lmbda_mean`` are the batch means of
            those two terms. ``lmbda`` has the shape of ``target_time`` and
            equals d int_lmbda / d target_time per sample. If the caller has
            autograd disabled (e.g. ``torch.no_grad()``), all outputs are
            returned detached.
        """
        eps = self._EPS
        # Re-impose positivity (e.g. after an optimizer step) so Lambda stays monotone.
        with torch.no_grad():
            for p in self.parameters():
                p.clamp_(min=eps)

        # The intensity is an autograd derivative, so autograd must be on even
        # when the caller evaluates under torch.no_grad().
        grad_enabled = torch.is_grad_enabled()
        with torch.enable_grad():
            time = target_time
            if not time.requires_grad:
                # Differentiate w.r.t. a fresh leaf that shares storage, instead
                # of flipping the caller's tensor's requires_grad flag in place.
                time = time.detach().requires_grad_(True)

            # log_t is only an input-feature transform. The derivative below is
            # still taken w.r.t. the raw time, so lmbda is dLambda/dt (the chain
            # rule supplies the 1/(t+eps) factor).
            time_feature = torch.log(time + eps) if self.log_t else time
            t = self.linear1(time_feature.unsqueeze(dim=-1))
            out = torch.tanh(self.linear2(torch.cat([hidden_state[:, -1, :], t], dim=-1)))
            for layer in self.module_list:
                out = torch.tanh(layer(out))
            # NOTE(review): linear3 already ends in Softplus, so softplus is
            # applied twice. This is kept so trained weights keep their meaning.
            # squeeze(-1) gives shape (batch,), matching lmbda below and
            # avoiding a (batch, batch) broadcast.
            int_lmbda = F.softplus(self.linear3(out)).squeeze(-1)
            int_lmbda_mean = int_lmbda.mean()
            # Each row's Lambda depends only on that row's time, so the gradient
            # of the SUM is the per-sample derivative. (The gradient of the mean
            # would divide every intensity by the batch size.)
            lmbda = grad(
                int_lmbda.sum(),
                time,
                create_graph=True, retain_graph=True)[0]
            log_lmbda = (lmbda + eps).log()
            log_lmbda_mean = log_lmbda.mean()
            # Both branches are equal in exact arithmetic (mean is linear).
            if self.mean_first:
                nll = int_lmbda_mean - log_lmbda_mean
            else:
                nll = (int_lmbda - log_lmbda).mean()

        outputs = [nll, log_lmbda_mean, int_lmbda_mean, lmbda]
        if not grad_enabled:
            # Don't hand graph-carrying tensors back to a no-grad caller.
            outputs = [x.detach() for x in outputs]
        return outputs
class GTPP(nn.Module):
    """Event-sequence model: LSTM over the history plus a monotone intensity head.

    Each step's event-type embedding (with dropout) is concatenated with that
    step's time value and fed through a unidirectional, batch-first LSTM. The
    LSTM states and the final time of each sequence are then handed to
    ``IntensityNet``, which scores that final time.

    Config attributes read here: ``batch_size``, ``lr``, ``log_mode``,
    ``event_class``, ``emb_dim``, ``dropout``, ``hid_dim``. ``IntensityNet``
    reads further attributes from the same config.
    """

    def __init__(self, config):
        super().__init__()
        # Copied from config; none of these three is read within this class.
        self.batch_size = config.batch_size
        self.lr = config.lr
        # NOTE(review): TODO from the original author — "meant to be used here?"
        self.log_mode = config.log_mode

        # Keep this construction order. Parameter initialisation draws from the
        # global RNG, and registration order fixes state_dict()/parameters() order.
        emb_dim = config.emb_dim
        self.embedding = nn.Embedding(num_embeddings=config.event_class, embedding_dim=emb_dim)
        self.emb_drop = nn.Dropout(p=config.dropout)
        # One LSTM input step = event embedding + one scalar time value.
        self.lstm = nn.LSTM(
            input_size=emb_dim + 1,
            hidden_size=config.hid_dim,
            batch_first=True,
            bidirectional=False,
        )
        self.intensity_net = IntensityNet(config)

    def forward(self, batch):
        """Score the last time of each sequence given the events before it.

        Args:
            batch: pair ``(time_seq, event_seq)`` of ``(batch, seq)`` tensors.
                ``event_seq`` holds event-type indices and is cast to long.
                NOTE(review): whether ``time_seq`` holds absolute times or
                inter-event gaps is not visible here; confirm with the loader.

        Returns:
            ``[nll, log_lmbda, int_lmbda, lmbda]`` as produced by
            ``IntensityNet``. Only ``nll`` keeps its autograd graph; the other
            three are detached.
        """
        time_seq, event_seq = batch
        # Steps 0..S-2 form the observed history; step S-1 is the target.
        # The last step's embedding is computed but discarded by the slice.
        marks = self.emb_drop(self.embedding(event_seq.long()))
        step_times = time_seq[:, :-1].unsqueeze(-1)
        history = torch.cat((marks[:, :-1], step_times), dim=-1)
        states, _ = self.lstm(history)

        target = time_seq[:, -1]
        nll, log_lmbda, int_lmbda, lmbda = self.intensity_net(states, target)
        return [nll, *(x.detach() for x in (log_lmbda, int_lmbda, lmbda))]