From 0ffa9e170b687d3ef769e4e98ec3d159db6ba56e Mon Sep 17 00:00:00 2001 From: Anton Kiselev Date: Sat, 18 May 2019 15:21:50 +0300 Subject: [PATCH] Long bugfix. --- data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data.py b/data.py index e975681..98a1a96 100644 --- a/data.py +++ b/data.py @@ -52,9 +52,9 @@ class TokenizedDataFrameDataset(Dataset): return BertInput(*[torch.LongTensor(_) for _ in [input_ids, input_mask, segment_ids]]) def preprocess_label(self, label: int): - result = torch.zeros(len(self.y_labels)) + result = torch.zeros(len(self.y_labels)).long() result[label] = 1 - return torch.LongTensor(result) + return result def __getitem__(self, index) -> dict: sample = self.df.iloc[index]