# Mirror of https://github.com/wassname/torch-neuralpointprocess.git
# Synced 2026-06-27 16:32:26 +08:00 (96 lines, 3.1 KiB, Python).
import torch
|
|
from torch import nn
|
|
from torch.autograd import grad
|
|
from torch.optim import Adam
|
|
from torch.nn import functional as F
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
class IntensityNet(nn.Module):
    """Fully-neural cumulative-intensity head for a temporal point process.

    An MLP maps (last history hidden state, target time) to the cumulative
    intensity ``Lambda(t)``. The intensity ``lambda(t)`` is obtained as the
    autograd derivative ``dLambda/dt``. Every parameter is clamped to be
    positive, and all activations used (log, tanh, softplus) are monotone,
    so ``Lambda`` is non-decreasing in ``t``.

    ``config`` must provide ``hid_dim``, ``mlp_dim``, ``mlp_layer``,
    ``mean_first`` and ``log_t``.
    """

    def __init__(self, config):
        super(IntensityNet, self).__init__()

        # Time gets its own 1 -> 1 projection before being joined with the
        # history hidden state.
        self.linear1 = nn.Linear(in_features=1, out_features=1)
        self.linear2 = nn.Linear(in_features=config.hid_dim + 1, out_features=config.mlp_dim)
        self.module_list = nn.ModuleList(
            [
                nn.Linear(in_features=config.mlp_dim, out_features=config.mlp_dim)
                for _ in range(config.mlp_layer - 1)
            ]
        )
        self.linear3 = nn.Sequential(nn.Linear(in_features=config.mlp_dim, out_features=1), nn.Softplus())

        # True: loss = mean(Lambda) - mean(log lambda);
        # False: mean of the per-sample difference.
        self.mean_first = config.mean_first
        # True: the MLP is fed log(t) instead of t.
        self.log_t = config.log_t

        self.init_weights_positive()

    def init_weights_positive(self):
        """Replace every parameter by its absolute value, floored at 1e-10."""
        eps = 1e-10
        for p in self.parameters():
            p.data = torch.abs(p.data)
            p.data = torch.clamp(p.data, min=eps)

    def forward(self, hidden_state, target_time):
        """Compute the point-process NLL of ``target_time`` given the history.

        Args:
            hidden_state: recurrent output; only ``hidden_state[:, -1, :]``
                is used (indexing implies (batch, seq, hid_dim)).
            target_time: one time value per batch row. It is unsqueezed and
                concatenated with a (batch, hid_dim) tensor, which implies
                shape (batch,).

        Returns:
            ``[nll, log_lmbda_mean, int_lmbda_mean, lmbda]``. The first three
            are scalars: the NLL, mean log-intensity, and mean cumulative
            intensity. ``lmbda`` is the per-sample intensity dLambda/dt, with
            the same shape as ``target_time``.
        """
        eps = 1e-10

        # Re-project parameters onto the positive orthant (outside autograd)
        # so Lambda stays monotone in t after optimizer updates.
        for p in self.parameters():
            p.data = torch.clamp(p.data, min=eps)

        # lambda is obtained with autograd, so grad mode is forced on. This
        # lets the module also be evaluated inside ``torch.no_grad()``.
        with torch.enable_grad():
            # Differentiate w.r.t. the RAW time, even in log_t mode. A tensor
            # that does not track gradients yet is aliased with detach(), so
            # the caller's tensor flags are not modified.
            if target_time.requires_grad:
                time = target_time
            else:
                time = target_time.detach().requires_grad_(True)
            net_time = torch.log(time + eps) if self.log_t else time

            t = self.linear1(net_time.unsqueeze(dim=-1))
            out = torch.tanh(self.linear2(torch.cat([hidden_state[:, -1, :], t], dim=-1)))
            for layer in self.module_list:
                out = torch.tanh(layer(out))
            # NOTE(review): linear3 already ends in Softplus, so softplus is
            # applied twice here (Lambda >= log 2). Kept to preserve outputs.
            # squeeze(-1): (batch, 1) -> (batch,), matching lmbda's shape.
            int_lmbda = F.softplus(self.linear3(out)).squeeze(-1)

            # Lambda_i depends only on t_i, so the gradient of the SUM is the
            # per-sample derivative. The gradient of the mean would divide
            # every intensity by the batch size.
            lmbda = grad(int_lmbda.sum(), time, create_graph=True, retain_graph=True)[0]

        log_lmbda = (lmbda + eps).log()
        log_lmbda_mean = log_lmbda.mean()
        int_lmbda_mean = int_lmbda.mean()

        if self.mean_first:
            nll = int_lmbda_mean - log_lmbda_mean
        else:
            nll = (int_lmbda - log_lmbda).mean()

        return [nll, log_lmbda_mean, int_lmbda_mean, lmbda]
|
|
|
|
|
|
class GTPP(nn.Module):
    """Recurrent marked temporal point process model.

    Each event's mark is embedded (with dropout) and concatenated with its
    time value. An LSTM encodes every event except the last, and
    ``IntensityNet`` scores the last event's time given that encoding.

    ``config`` must provide ``batch_size``, ``lr``, ``log_mode``,
    ``event_class``, ``emb_dim``, ``dropout`` and ``hid_dim``, plus the
    fields read by ``IntensityNet``.
    """

    def __init__(self, config):
        super().__init__()

        # Training settings stored on the module; not read inside this class.
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.log_mode = config.log_mode  # TODO: log_mode is never read here; was it meant to be?

        # Submodules are built in this order so RNG draws at init and the
        # parameter ordering stay fixed.
        self.embedding = nn.Embedding(num_embeddings=config.event_class, embedding_dim=config.emb_dim)
        self.emb_drop = nn.Dropout(p=config.dropout)
        self.lstm = nn.LSTM(
            input_size=config.emb_dim + 1,  # mark embedding + one time value
            hidden_size=config.hid_dim,
            batch_first=True,
            bidirectional=False,
        )
        self.intensity_net = IntensityNet(config)

    def forward(self, batch):
        """Score the last event of each sequence given its predecessors.

        Args:
            batch: pair ``(time_seq, event_seq)``. Both are sliced as
                ``[:, :-1]`` / ``[:, -1]``, i.e. (batch, seq). Presumably
                these are time values and integer mark ids (marks are cast
                to long here).

        Returns:
            ``[nll, log_lmbda, int_lmbda, lmbda]`` as produced by
            ``IntensityNet``. Every entry except ``nll`` is detached.
        """
        times, marks = batch
        marked = self.emb_drop(self.embedding(marks.long()))

        # History = all steps but the last, each step being [embedding, time].
        history = torch.cat([marked[:, :-1], times[:, :-1].unsqueeze(-1)], dim=-1)
        encoded, _ = self.lstm(history)

        nll, *diagnostics = self.intensity_net(encoded, times[:, -1])
        return [nll] + [value.detach() for value in diagnostics]
|
|
|