Neural Temporal Point Process

This commit is contained in:
kanghoon
2020-09-15 15:25:43 +09:00
parent 21f900b5d6
commit c0cfe2186e
+8 -6
View File
@@ -8,7 +8,7 @@ from argparse import ArgumentParser
from torch.utils.data import DataLoader
from utils import read_timeseries, generate_sequence, plt_lmbda
from utils import read_timeseries,generate_sequence, plt_lmbda
from module import GTPP
@@ -48,6 +48,8 @@ if __name__ == "__main__":
val_data = read_timeseries(path + config.data + '_validation.csv')
test_data = read_timeseries(path + config.data + '_testing.csv')
train_timeseq, train_eventseq = generate_sequence(train_data, config.seq_len, log_mode=config.log_mode)
train_loader = DataLoader(torch.utils.data.TensorDataset(train_timeseq, train_eventseq), shuffle=True, batch_size=config.batch_size)
val_timeseq, val_eventseq = generate_sequence(val_data, config.seq_len, log_mode=config.log_mode)
@@ -57,7 +59,7 @@ if __name__ == "__main__":
best_loss = 1e3
patients = 0
tol = 20
tol = 30
for epoch in range(config.epochs):
@@ -90,10 +92,10 @@ if __name__ == "__main__":
if epoch % config.prt_evry == 0:
print("Epochs:{}".format(epoch))
print("Training Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(loss1/config.batch_size, -loss2 / config.batch_size, loss3 / config.batch_size))
print("Validation Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(val_loss / config.batch_size,
-val_log_lmbda / config.batch_size,
val_int_lmbda/ config.batch_size))
print("Training Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(loss1/train_timeseq.size(0), -loss2 / train_timeseq.size(0), loss3 / train_timeseq.size(0)))
print("Validation Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(val_loss / val_timeseq.size(0),
-val_log_lmbda / val_timeseq.size(0),
val_int_lmbda/val_timeseq.size(0)))
plt_lmbda(train_data[0], model=model, seq_len=config.seq_len, log_mode=config.log_mode)
# plt_lmbda(test_data[0], model=model, seq_len=config.seq_len, log_mode=config.log_mode)