Validation bugfix.

This commit is contained in:
Anton Kiselev
2019-05-18 15:27:34 +03:00
parent 3fce888757
commit 96d37e23e5
+6
View File
@@ -92,7 +92,13 @@ if __name__ == '__main__':
optimizer.zero_grad()
input_ids, input_mask, segment_ids = batch['x']
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
y = batch['y']
y = y.to(device)
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
labels.append(torch.argmax(y, dim=1))