From 96d37e23e5b8f46d3c09b927c3e1d2e98767b456 Mon Sep 17 00:00:00 2001 From: Anton Kiselev Date: Sat, 18 May 2019 15:27:34 +0300 Subject: [PATCH] Validation bugfix. --- training.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/training.py b/training.py index ca19b10..dc60abe 100644 --- a/training.py +++ b/training.py @@ -92,7 +92,13 @@ if __name__ == '__main__': optimizer.zero_grad() input_ids, input_mask, segment_ids = batch['x'] + input_ids = input_ids.to(device) + input_mask = input_mask.to(device) + segment_ids = segment_ids.to(device) + y = batch['y'] + y = y.to(device) + loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y) labels.append(torch.argmax(y, dim=1))