mirror of
https://github.com/wassname/torch-neuralpointprocess.git
synced 2026-06-27 16:32:26 +08:00
Neural Temporal Point Process
This commit is contained in:
@@ -8,7 +8,7 @@ from argparse import ArgumentParser
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from utils import read_timeseries, generate_sequence, plt_lmbda
|
||||
from utils import read_timeseries,generate_sequence, plt_lmbda
|
||||
from module import GTPP
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ if __name__ == "__main__":
|
||||
val_data = read_timeseries(path + config.data + '_validation.csv')
|
||||
test_data = read_timeseries(path + config.data + '_testing.csv')
|
||||
|
||||
|
||||
|
||||
train_timeseq, train_eventseq = generate_sequence(train_data, config.seq_len, log_mode=config.log_mode)
|
||||
train_loader = DataLoader(torch.utils.data.TensorDataset(train_timeseq, train_eventseq), shuffle=True, batch_size=config.batch_size)
|
||||
val_timeseq, val_eventseq = generate_sequence(val_data, config.seq_len, log_mode=config.log_mode)
|
||||
@@ -57,7 +59,7 @@ if __name__ == "__main__":
|
||||
|
||||
best_loss = 1e3
|
||||
patients = 0
|
||||
tol = 20
|
||||
tol = 30
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
|
||||
@@ -90,10 +92,10 @@ if __name__ == "__main__":
|
||||
|
||||
if epoch % config.prt_evry == 0:
|
||||
print("Epochs:{}".format(epoch))
|
||||
print("Training Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(loss1/config.batch_size, -loss2 / config.batch_size, loss3 / config.batch_size))
|
||||
print("Validation Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(val_loss / config.batch_size,
|
||||
-val_log_lmbda / config.batch_size,
|
||||
val_int_lmbda/ config.batch_size))
|
||||
print("Training Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(loss1/train_timeseq.size(0), -loss2 / train_timeseq.size(0), loss3 / train_timeseq.size(0)))
|
||||
print("Validation Negative Log Likelihood:{} Log Lambda:{}: Integral Lambda:{}".format(val_loss / val_timeseq.size(0),
|
||||
-val_log_lmbda / val_timeseq.size(0),
|
||||
val_int_lmbda/val_timeseq.size(0)))
|
||||
plt_lmbda(train_data[0], model=model, seq_len=config.seq_len, log_mode=config.log_mode)
|
||||
# plt_lmbda(test_data[0], model=model, seq_len=config.seq_len, log_mode=config.log_mode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user