diff --git a/training.py b/training.py index ca19b10..dc60abe 100644 --- a/training.py +++ b/training.py @@ -92,7 +92,13 @@ if __name__ == '__main__': optimizer.zero_grad() input_ids, input_mask, segment_ids = batch['x'] + input_ids = input_ids.to(device) + input_mask = input_mask.to(device) + segment_ids = segment_ids.to(device) + y = batch['y'] + y = y.to(device) + loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y) labels.append(torch.argmax(y, dim=1))