mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
Rename ClfHead to MultipleChoiceHead.
This commit is contained in:
+6
-6
@@ -187,11 +187,11 @@ class LMHead(nn.Module):
|
||||
return lm_logits
|
||||
|
||||
|
||||
class ClfHead(nn.Module):
|
||||
class MultipleChoiceHead(nn.Module):
|
||||
""" Classifier Head for the transformer """
|
||||
|
||||
def __init__(self, clf_token, cfg):
|
||||
super(ClfHead, self).__init__()
|
||||
super(MultipleChoiceHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
@@ -216,18 +216,18 @@ class ClfHead(nn.Module):
|
||||
|
||||
|
||||
class DoubleHeadModel(nn.Module):
|
||||
""" Transformer with language model and classification heads """
|
||||
""" Transformer with language model and multiple choice heads """
|
||||
def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512):
|
||||
super(DoubleHeadModel, self).__init__()
|
||||
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
|
||||
self.lm_head = LMHead(self.transformer, cfg)
|
||||
self.clf_head = ClfHead(clf_token, cfg)
|
||||
self.choice_head = MultipleChoiceHead(clf_token, cfg)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.transformer(x)
|
||||
lm_logits = self.lm_head(h)
|
||||
clf_logits = self.clf_head(h, x)
|
||||
return lm_logits, clf_logits
|
||||
choice_logits = self.choice_head(h, x)
|
||||
return lm_logits, choice_logits
|
||||
|
||||
|
||||
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
|
||||
|
||||
Reference in New Issue
Block a user