CrossEntropyLoss bugfix.

This commit is contained in:
Anton Kiselev
2019-05-18 15:24:05 +03:00
parent 0ffa9e170b
commit 3fce888757
+3
View File
@@ -1,6 +1,7 @@
import logging
from typing import NamedTuple, Callable, Union
import torch
import torch.nn as nn
from pytorch_pretrained_bert.modeling import ACT2FN, BertLayerNorm, BertModel, BertSelfOutput
@@ -92,6 +93,8 @@ class ClassificationModel(nn.Module):
if labels is not None:
loss_function = nn.CrossEntropyLoss()
if len(labels.shape) > 1:
labels = torch.argmax(labels, dim=1)
loss = loss_function(logits, labels)
return loss, logits
else: