mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 16:29:32 +08:00
CrossEntropyLoss bugfix.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import NamedTuple, Callable, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pytorch_pretrained_bert.modeling import ACT2FN, BertLayerNorm, BertModel, BertSelfOutput
|
||||
|
||||
@@ -92,6 +93,8 @@ class ClassificationModel(nn.Module):
|
||||
|
||||
if labels is not None:
|
||||
loss_function = nn.CrossEntropyLoss()
|
||||
if len(labels.shape) > 1:
|
||||
labels = torch.argmax(labels, dim=1)
|
||||
loss = loss_function(logits, labels)
|
||||
return loss, logits
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user