mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
Modifying the code of DoubleHeadModel to allow different task heads.
This commit is contained in:
+14
-5
@@ -2,6 +2,7 @@ import copy
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -185,7 +186,7 @@ class LMHead(nn.Module):
|
||||
h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
|
||||
lm_logits = self.decoder(h_trunc)
|
||||
return lm_logits
|
||||
-
|
||||
|
||||
|
||||
class MultipleChoiceHead(nn.Module):
|
||||
""" Classifier Head for the transformer """
|
||||
@@ -266,14 +267,22 @@ class SimilarityHead(nn.Module):
|
||||
|
||||
class DoubleHeadModel(nn.Module):
|
||||
""" Transformer with language model and task specific heads """
|
||||
def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512, n_class = None):
|
||||
def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
|
||||
super(DoubleHeadModel, self).__init__()
|
||||
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
|
||||
self.lm_head = LMHead(self.transformer, cfg)
|
||||
if n_class is None:
|
||||
self.task_head = MultipleChoiceHead(clf_token, cfg)
|
||||
else:
|
||||
if isinstance(task_head_type, str):
|
||||
if task_head_type == 'multiple_choice':
|
||||
self.task_head = MultipleChoiceHead(clf_token, cfg)
|
||||
elif task_head_type == 'similarity':
|
||||
self.task_head = SimilarityHead(clf_token, cfg)
|
||||
elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
|
||||
task_head_type[0] == 'classification':
|
||||
n_class = task_head_type[1]
|
||||
self.task_head = ClfHead(clf_token, cfg, n_class)
|
||||
else:
|
||||
raise ValueError(f"task_head_type expected to be 'multiple_choice', "
|
||||
f"'similarity' or ('classification', n_class) got {task_head_type}.")
|
||||
|
||||
def forward(self, x):
|
||||
h = self.transformer(x)
|
||||
|
||||
@@ -263,7 +263,7 @@ if __name__ == '__main__':
|
||||
n_batch_train = args.n_batch * max(n_gpu, 1)
|
||||
n_updates_total = (n_train // n_batch_train) * args.n_iter
|
||||
|
||||
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
|
||||
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(reduce=False)
|
||||
model_opt = OpenAIAdam(dh_model.parameters(),
|
||||
|
||||
Reference in New Issue
Block a user