From 3fce888757141292f7c19cd856b38bce6f5483b9 Mon Sep 17 00:00:00 2001 From: Anton Kiselev Date: Sat, 18 May 2019 15:24:05 +0300 Subject: [PATCH] CrossEntropyLoss bugfix. --- modules.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules.py b/modules.py index 473a9ad..62d5d8f 100644 --- a/modules.py +++ b/modules.py @@ -1,6 +1,7 @@ import logging from typing import NamedTuple, Callable, Union +import torch import torch.nn as nn from pytorch_pretrained_bert.modeling import ACT2FN, BertLayerNorm, BertModel, BertSelfOutput @@ -92,6 +93,8 @@ class ClassificationModel(nn.Module): if labels is not None: loss_function = nn.CrossEntropyLoss() + if len(labels.shape) > 1: + labels = torch.argmax(labels, dim=1) loss = loss_function(logits, labels) return loss, logits else: