mirror of
https://github.com/wassname/bert_adapter.git
synced 2026-06-27 16:29:32 +08:00
Validation bugfix.
This commit is contained in:
@@ -92,7 +92,13 @@ if __name__ == '__main__':
|
||||
optimizer.zero_grad()
|
||||
|
||||
input_ids, input_mask, segment_ids = batch['x']
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
y = batch['y']
|
||||
y = y.to(device)
|
||||
|
||||
|
||||
loss, logits = model.forward(input_ids, input_mask, segment_ids, labels=y)
|
||||
labels.append(torch.argmax(y, dim=1))
|
||||
|
||||
Reference in New Issue
Block a user