diff --git a/data.py b/data.py index e975681..98a1a96 100644 --- a/data.py +++ b/data.py @@ -52,9 +52,9 @@ class TokenizedDataFrameDataset(Dataset): return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]]) def preprocess_label(self, label: int): - result = torch.zeros(len(self.y_labels)) + result = torch.zeros(len(self.y_labels)).long() result[label] = 1 - return torch.LongTensor(result) + return result def __getitem__(self, index) -> dict: sample = self.df.iloc[index]