mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 15:14:41 +08:00
Long bugfix.
This commit is contained in:
@@ -52,9 +52,9 @@ class TokenizedDataFrameDataset(Dataset):
|
||||
return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]])
|
||||
|
||||
def preprocess_label(self, label: int):
|
||||
result = torch.zeros(len(self.y_labels))
|
||||
result = torch.zeros(len(self.y_labels)).long()
|
||||
result[label] = 1
|
||||
return torch.LongTensor(result)
|
||||
return result
|
||||
|
||||
def __getitem__(self, index) -> dict:
|
||||
sample = self.df.iloc[index]
|
||||
|
||||
Reference in New Issue
Block a user