Modifying the code of DoubleHeadModel to allow different task heads.

This commit is contained in:
Grégory Châtel
2018-07-13 17:27:33 +02:00
parent ac2250881a
commit 87b4901a81
2 changed files with 15 additions and 6 deletions
+14 -5
View File
@@ -2,6 +2,7 @@ import copy
import json
import math
import re
import collections
import numpy as np
import torch
@@ -185,7 +186,7 @@ class LMHead(nn.Module):
h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
lm_logits = self.decoder(h_trunc)
return lm_logits
-
class MultipleChoiceHead(nn.Module):
""" Classifier Head for the transformer """
@@ -266,14 +267,22 @@ class SimilarityHead(nn.Module):
class DoubleHeadModel(nn.Module):
""" Transformer with language model and task specific heads """
def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512, n_class = None):
def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
super(DoubleHeadModel, self).__init__()
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
self.lm_head = LMHead(self.transformer, cfg)
if n_class is None:
self.task_head = MultipleChoiceHead(clf_token, cfg)
else:
if isinstance(task_head_type, str):
if task_head_type == 'multiple_choice':
self.task_head = MultipleChoiceHead(clf_token, cfg)
elif task_head_type == 'similarity':
self.task_head = SimilarityHead(clf_token, cfg)
elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
task_head_type[0] == 'classification':
n_class = task_head[1]
self.task_head = ClfHead(clf_token, cfg, n_class)
else:
raise ValueError(f"task_head_type expected to be 'multiple_choice' "
"'similarity' or ('classification', n_class) got {task_head_type}.")
def forward(self, x):
h = self.transformer(x)
+1 -1
View File
@@ -263,7 +263,7 @@ if __name__ == '__main__':
n_batch_train = args.n_batch * max(n_gpu, 1)
n_updates_total = (n_train // n_batch_train) * args.n_iter
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)
criterion = nn.CrossEntropyLoss(reduce=False)
model_opt = OpenAIAdam(dh_model.parameters(),