Renaming ClfHead in MultipleChoiceHead.

This commit is contained in:
Grégory Châtel
2018-07-12 14:23:09 +02:00
parent 93522a3b59
commit ed8bb28b50
+6 -6
View File
@@ -187,11 +187,11 @@ class LMHead(nn.Module):
return lm_logits
class ClfHead(nn.Module):
class MultipleChoiceHead(nn.Module):
""" Classifier Head for the transformer """
def __init__(self, clf_token, cfg):
super(ClfHead, self).__init__()
super(MultipleChoiceHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
@@ -216,18 +216,18 @@ class ClfHead(nn.Module):
class DoubleHeadModel(nn.Module):
""" Transformer with language model and classification heads """
""" Transformer with language model and multiple choice heads """
def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512):
super(DoubleHeadModel, self).__init__()
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
self.lm_head = LMHead(self.transformer, cfg)
self.clf_head = ClfHead(clf_token, cfg)
self.choice_head = MultipleChoiceHead(clf_token, cfg)
def forward(self, x):
h = self.transformer(x)
lm_logits = self.lm_head(h)
clf_logits = self.clf_head(h, x)
return lm_logits, clf_logits
choice_logits = self.choice_head(h, x)
return lm_logits, choice_logits
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',