diff --git a/modules.py b/modules.py index 473a9ad..62d5d8f 100644 --- a/modules.py +++ b/modules.py @@ -1,6 +1,7 @@ import logging from typing import NamedTuple, Callable, Union +import torch import torch.nn as nn from pytorch_pretrained_bert.modeling import ACT2FN, BertLayerNorm, BertModel, BertSelfOutput @@ -92,6 +93,8 @@ class ClassificationModel(nn.Module): if labels is not None: loss_function = nn.CrossEntropyLoss() + if len(labels.shape) > 1: + labels = torch.argmax(labels, dim=1) loss = loss_function(logits, labels) return loss, logits else: