Long bugfix.

This commit is contained in:
Anton Kiselev
2019-05-18 15:21:50 +03:00
parent cf1495d896
commit 0ffa9e170b
+2 -2
View File
@@ -52,9 +52,9 @@ class TokenizedDataFrameDataset(Dataset):
return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]])
def preprocess_label(self, label: int):
result = torch.zeros(len(self.y_labels))
result = torch.zeros(len(self.y_labels)).long()
result[label] = 1
return torch.LongTensor(result)
return result
def __getitem__(self, index) -> dict:
sample = self.df.iloc[index]