mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 16:44:45 +08:00
Long bugfix.
This commit is contained in:
@@ -52,9 +52,9 @@ class TokenizedDataFrameDataset(Dataset):
|
|||||||
return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]])
|
return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]])
|
||||||
|
|
||||||
def preprocess_label(self, label: int):
|
def preprocess_label(self, label: int):
|
||||||
result = torch.zeros(len(self.y_labels))
|
result = torch.zeros(len(self.y_labels)).long()
|
||||||
result[label] = 1
|
result[label] = 1
|
||||||
return torch.LongTensor(result)
|
return result
|
||||||
|
|
||||||
def __getitem__(self, index) -> dict:
|
def __getitem__(self, index) -> dict:
|
||||||
sample = self.df.iloc[index]
|
sample = self.df.iloc[index]
|
||||||
|
|||||||
Reference in New Issue
Block a user