mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
Merge pull request #25 from rodgzilla/multiple_choice_head
Simplifying the use of the model to perform different tasks
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
|
||||
class MultipleChoiceLossCompute:
    """Loss computation and (optional) training step for multiple choice tasks.

    The instance is callable. It combines a classification loss with an
    optional, mask-averaged language modeling loss, back-propagates the sum
    and, when an optimizer was supplied, applies one optimizer step.
    """

    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
        """Store the criteria, the LM loss weight and the optional optimizer.

        Args:
            lm_criterion: callable invoked as ``lm_criterion(lm_logits, targets)``.
                Its result is reshaped with ``.view`` below, so it has to be
                an un-reduced (per-target) loss.
            clf_criterion: callable invoked as ``clf_criterion(clf_logits, Y)``.
                ``.sum()`` is called on its result.
            lm_coef: weight of the language modeling loss in the training
                loss; the LM term is only added when this is ``> 0``.
            opt: optional object exposing ``step()`` and ``zero_grad()``.
        """
        self.lm_criterion = lm_criterion
        self.clf_criterion = clf_criterion
        self.lm_coef = lm_coef
        self.opt = opt

    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
        """Compute the losses and, unless told otherwise, run a training step.

        Args:
            X: input tensor indexed as ``X[:, :, 1:, 0]`` (so at least 4-D).
                Presumably (batch, choice, position, channel) with the token
                id in channel 0 -- confirm against the caller.
            Y: classification targets handed to ``clf_criterion``.
            M: mask tensor, reshaped to ``(-1, M.size(2))``; it weights the
                per-position LM losses.
            clf_logits: classification logits handed to ``clf_criterion``.
            lm_logits: language modeling logits, or ``None`` to skip the LM
                loss entirely.
            only_return_losses: when true, return the un-summed losses and do
                not call ``backward()`` or touch the optimizer.

        Returns:
            If ``only_return_losses``: ``(clf_losses, lm_losses)`` when
            ``lm_logits`` was given, else ``clf_losses`` alone.
            Otherwise: the summed training loss as a Python number
            (``train_loss.item()``).
        """
        # Language modeling loss
        if lm_logits is not None:
            # Targets are the inputs shifted by one position.
            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)
            M = M.view(-1, M.size(2))
            lm_losses = self.lm_criterion(lm_logits, x_shifted)
            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
            # Zero out masked positions, then average over the kept ones.
            lm_losses = lm_losses * M[:, 1:]
            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)

        # Classification loss
        clf_losses = self.clf_criterion(clf_logits, Y)
        if only_return_losses:
            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

        if self.lm_coef > 0 and lm_logits is not None:
            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
        else:
            train_loss = clf_losses.sum()
        train_loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.zero_grad()
        return train_loss.item()
|
||||
|
||||
class ClassificationLossCompute:
    """Loss computation and (optional) training step for classification tasks.

    The instance is callable. It combines a classification loss with an
    optional, mask-averaged language modeling loss, back-propagates the sum
    and, when an optimizer was supplied, applies one optimizer step.
    """

    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
        """Store the criteria, the LM loss weight and the optional optimizer.

        Args:
            lm_criterion: callable invoked as ``lm_criterion(lm_logits, targets)``.
                Its result is reshaped with ``.view`` below, so it has to be
                an un-reduced (per-target) loss.
            clf_criterion: callable invoked as ``clf_criterion(clf_logits, Y)``.
                ``.sum()`` is called on its result.
            lm_coef: weight of the language modeling loss in the training
                loss; the LM term is only added when this is ``> 0``.
            opt: optional object exposing ``step()`` and ``zero_grad()``.
        """
        self.lm_criterion = lm_criterion
        self.clf_criterion = clf_criterion
        self.lm_coef = lm_coef
        self.opt = opt

    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
        """Compute the losses and, unless told otherwise, run a training step.

        Args:
            X: input tensor indexed as ``X[:, 1:, 0]`` (so at least 3-D).
                Presumably (batch, position, channel) with the token id in
                channel 0 -- confirm against the caller.
            Y: classification targets handed to ``clf_criterion``.
            M: mask tensor, reshaped to ``(-1, M.size(-1))``; it weights the
                per-position LM losses.
            clf_logits: classification logits handed to ``clf_criterion``.
            lm_logits: language modeling logits, or ``None`` to skip the LM
                loss entirely.
            only_return_losses: when true, return the un-summed losses and do
                not call ``backward()`` or touch the optimizer.

        Returns:
            If ``only_return_losses``: ``(clf_losses, lm_losses)`` when
            ``lm_logits`` was given, else ``clf_losses`` alone.
            Otherwise: the summed training loss as a Python number
            (``train_loss.item()``).
        """
        # Language modeling loss
        if lm_logits is not None:
            # Targets are the inputs shifted by one position.
            x_shifted = X[:, 1:, 0].contiguous().view(-1)
            M = M.view(-1, M.size(-1))
            lm_losses = self.lm_criterion(lm_logits, x_shifted)
            lm_losses = lm_losses.view(X.size(0), X.size(-2) - 1)
            # Zero out masked positions, then average over the kept ones.
            lm_losses = lm_losses * M[:, 1:]
            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)

        # Classification loss
        clf_losses = self.clf_criterion(clf_logits, Y)
        if only_return_losses:
            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

        if self.lm_coef > 0 and lm_logits is not None:
            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
        else:
            train_loss = clf_losses.sum()
        train_loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.zero_grad()
        return train_loss.item()
|
||||
|
||||
# TODO: Implement a LossCompute class for similarity tasks.
|
||||
+79
-9
@@ -2,6 +2,7 @@ import copy
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -187,22 +188,23 @@ class LMHead(nn.Module):
|
||||
return lm_logits
|
||||
|
||||
|
||||
class ClfHead(nn.Module):
|
||||
class MultipleChoiceHead(nn.Module):
|
||||
""" Classifier Head for the transformer """
|
||||
|
||||
def __init__(self, clf_token, cfg):
|
||||
super(ClfHead, self).__init__()
|
||||
super(MultipleChoiceHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
self.linear = nn.Linear(cfg.n_embd, 1)
|
||||
nn.init.normal_(self.linear.weight, std=0.02)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, h, x):
|
||||
# Classification logits
|
||||
clf_h = h.view(-1, self.n_embd)
|
||||
flat = x[:, :, :, 0].contiguous().view(-1)
|
||||
flat = x[..., 0].contiguous().view(-1)
|
||||
clf_h = clf_h[flat == self.clf_token, :]
|
||||
clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
|
||||
# This double transposition is there to replicate the behavior
|
||||
@@ -212,22 +214,90 @@ class ClfHead(nn.Module):
|
||||
clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
|
||||
clf_h = clf_h.contiguous().view(-1, self.n_embd)
|
||||
clf_logits = self.linear(clf_h)
|
||||
|
||||
return clf_logits.view(-1, x.size(1))
|
||||
|
||||
|
||||
class ClfHead(nn.Module):
    """Classification Head for the transformer

    Picks the hidden state at every position whose token equals
    ``clf_token``, applies dropout and projects it to ``n_class`` logits.

    TODO: test this class."""

    def __init__(self, clf_token, cfg, n_class):
        """Build the dropout and linear projection layers.

        Args:
            clf_token: token id marking the position to classify from.
            cfg: configuration object; ``cfg.n_embd`` and ``cfg.clf_pdrop``
                are read.
            n_class: number of output classes.
        """
        super(ClfHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, n_class)

        nn.init.normal_(self.linear.weight, std=0.02)
        # The bias starts at exactly zero. `normal_(bias, 0)` would only set
        # the *mean* to 0 and sample the bias with the default std of 1.
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, h, x):
        """Return one row of class logits per ``clf_token`` occurrence.

        Args:
            h: hidden states; flattened to ``(-1, n_embd)``.
            x: model input; ``x[..., 0]`` is flattened and compared with
                ``clf_token``, so it must hold as many entries as ``h`` has
                rows once flattened.

        Returns:
            Tensor of shape ``(number of clf_token positions, n_class)``.
        """
        clf_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = self.dropout(clf_h)
        clf_logits = self.linear(clf_h)

        return clf_logits
|
||||
|
||||
class SimilarityHead(nn.Module):
    """ Similarity Head for the transformer

    Picks the hidden state at every ``clf_token`` position, adds together the
    states that belong to the same example (axis 1 of the input) and projects
    the sum to a single logit.

    TODO: test this class."""

    def __init__(self, clf_token, cfg):
        """Build the dropout and linear projection layers.

        Args:
            clf_token: token id marking the position to read the sequence
                representation from.
            cfg: configuration object; ``cfg.n_embd`` and ``cfg.clf_pdrop``
                are read.
        """
        super(SimilarityHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        # The bias starts at exactly zero. `normal_(bias, 0)` would only set
        # the *mean* to 0 and sample the bias with the default std of 1.
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, h, x):
        """Return one similarity logit per example.

        Args:
            h: hidden states; flattened to ``(-1, n_embd)``.
            x: model input; ``x[..., 0]`` is flattened and compared with
                ``clf_token``. ``x.size(1)`` is taken as the number of
                sequences per example (assumed layout: batch, sequence,
                position, channel, with one ``clf_token`` per sequence --
                confirm against the caller).

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        sim_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        sim_h = sim_h[flat == self.clf_token, :]
        sim_h = self.dropout(sim_h)
        # Sum the representations of the sequences of one example. Summing
        # over dim 1 of the flat (N, n_embd) tensor would collapse the
        # embedding axis and leave nothing for the linear layer to project.
        sim_h = sim_h.view(-1, x.size(1), self.n_embd).sum(dim=1)
        sim_logits = self.linear(sim_h)

        return sim_logits
|
||||
|
||||
class DoubleHeadModel(nn.Module):
    """ Transformer with language model and task specific heads """

    def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
        """Build the transformer, the LM head and the requested task head.

        Args:
            cfg: configuration object forwarded to the sub-modules.
            clf_token: token id forwarded to the task head.
            task_head_type: one of the strings ``'multiple_choice'``,
                ``'similarity'`` or ``'inference'``, or a 2-item sequence
                ``('classification', n_class)``.
            vocab: forwarded to ``TransformerModel`` as ``vocab``.
            n_ctx: forwarded to ``TransformerModel`` as ``n_ctx``.

        Raises:
            ValueError: if ``task_head_type`` is none of the accepted forms.
        """
        super(DoubleHeadModel, self).__init__()
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg)
        if isinstance(task_head_type, str):
            if task_head_type == 'multiple_choice':
                self.task_head = MultipleChoiceHead(clf_token, cfg)
            elif task_head_type == 'similarity':
                self.task_head = SimilarityHead(clf_token, cfg)
            elif task_head_type == 'inference':
                # the three classes correspond to entailment, contradiction and neutral.
                self.task_head = ClfHead(clf_token, cfg, 3)
            else:
                raise self._unknown_task_head(task_head_type)
        elif (isinstance(task_head_type, collections.abc.Sequence)
              and len(task_head_type) == 2
              and task_head_type[0] == 'classification'):
            n_class = task_head_type[1]
            self.task_head = ClfHead(clf_token, cfg, n_class)
        else:
            raise self._unknown_task_head(task_head_type)

    @staticmethod
    def _unknown_task_head(task_head_type):
        """Build the ValueError raised for an unsupported ``task_head_type``."""
        return ValueError("task_head_type is expected to be 'multiple_choice', "
                          "'similarity', 'inference' or ('classification', n_class), "
                          f"got {task_head_type}.")

    def forward(self, x):
        """Run the transformer and both heads on ``x``.

        Returns:
            Tuple ``(lm_logits, task_logits)``: the LM head output and the
            task head output for the same hidden states.
        """
        h = self.transformer(x)
        lm_logits = self.lm_head(h)
        task_logits = self.task_head(h, x)

        return lm_logits, task_logits
|
||||
|
||||
|
||||
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
|
||||
|
||||
@@ -15,41 +15,7 @@ from opt import OpenAIAdam
|
||||
from text_utils import TextEncoder
|
||||
from utils import (encode_dataset, iter_data,
|
||||
ResultLogger, make_path)
|
||||
|
||||
|
||||
class LossCompute:
    """Loss computation and (optional) training step.

    The instance is callable. It combines a classification loss with an
    optional, mask-averaged language modeling loss, back-propagates the sum
    and, when an optimizer was supplied, applies one optimizer step.
    """

    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
        """Store the criteria, the LM loss weight and the optional optimizer.

        Args:
            lm_criterion: callable invoked as ``lm_criterion(lm_logits, targets)``.
                Its result is reshaped with ``.view`` below, so it has to be
                an un-reduced (per-target) loss.
            clf_criterion: callable invoked as ``clf_criterion(clf_logits, Y)``.
                ``.sum()`` is called on its result.
            lm_coef: weight of the language modeling loss in the training
                loss; the LM term is only added when this is ``> 0``.
            opt: optional object exposing ``step()`` and ``zero_grad()``.
        """
        self.lm_criterion = lm_criterion
        self.clf_criterion = clf_criterion
        self.lm_coef = lm_coef
        self.opt = opt

    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
        """Compute the losses and, unless told otherwise, run a training step.

        Args:
            X: input tensor indexed as ``X[:, :, 1:, 0]`` (so at least 4-D).
                Presumably (batch, choice, position, channel) with the token
                id in channel 0 -- confirm against the caller.
            Y: classification targets handed to ``clf_criterion``.
            M: mask tensor, reshaped to ``(-1, M.size(2))``; it weights the
                per-position LM losses.
            clf_logits: classification logits handed to ``clf_criterion``.
            lm_logits: language modeling logits, or ``None`` to skip the LM
                loss entirely.
            only_return_losses: when true, return the un-summed losses and do
                not call ``backward()`` or touch the optimizer.

        Returns:
            If ``only_return_losses``: ``(clf_losses, lm_losses)`` when
            ``lm_logits`` was given, else ``clf_losses`` alone.
            Otherwise: the summed training loss as a Python number
            (``train_loss.item()``).
        """
        # Language modeling loss
        if lm_logits is not None:
            # Targets are the inputs shifted by one position.
            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)
            M = M.view(-1, M.size(2))
            lm_losses = self.lm_criterion(lm_logits, x_shifted)
            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
            # Zero out masked positions, then average over the kept ones.
            lm_losses = lm_losses * M[:, 1:]
            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)

        # Classification loss
        clf_losses = self.clf_criterion(clf_logits, Y)
        if only_return_losses:
            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

        if self.lm_coef > 0 and lm_logits is not None:
            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
        else:
            train_loss = clf_losses.sum()
        train_loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.zero_grad()
        return train_loss.item()
|
||||
|
||||
from loss import MultipleChoiceLossCompute
|
||||
|
||||
def transform_roc(X1, X2, X3):
|
||||
n_batch = len(X1)
|
||||
@@ -263,7 +229,7 @@ if __name__ == '__main__':
|
||||
n_batch_train = args.n_batch * max(n_gpu, 1)
|
||||
n_updates_total = (n_train // n_batch_train) * args.n_iter
|
||||
|
||||
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
|
||||
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(reduce=False)
|
||||
model_opt = OpenAIAdam(dh_model.parameters(),
|
||||
@@ -277,10 +243,10 @@ if __name__ == '__main__':
|
||||
l2=args.l2,
|
||||
vector_l2=args.vector_l2,
|
||||
max_grad_norm=args.max_grad_norm)
|
||||
compute_loss_fct = LossCompute(criterion,
|
||||
criterion,
|
||||
args.lm_coef,
|
||||
model_opt)
|
||||
compute_loss_fct = MultipleChoiceLossCompute(criterion,
|
||||
criterion,
|
||||
args.lm_coef,
|
||||
model_opt)
|
||||
load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)
|
||||
|
||||
dh_model.to(device)
|
||||
|
||||
Reference in New Issue
Block a user