mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
290 lines
11 KiB
Python
290 lines
11 KiB
Python
import re
|
|
import os
|
|
import time
|
|
import math
|
|
import json
|
|
import joblib
|
|
import random
|
|
import argparse
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from tqdm import tqdm
|
|
from functools import partial
|
|
from sklearn.utils import shuffle
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model
|
|
from opt import OpenAIAdam
|
|
from datasets import rocstories
|
|
from analysis import rocstories as rocstories_analysis
|
|
from text_utils import TextEncoder
|
|
from utils import (encode_dataset, flatten, iter_data,
|
|
ResultLogger, make_path)
|
|
|
|
class LossCompute:
    """Combine the classification loss with an optional language-modelling loss.

    Calling an instance either returns the per-example losses
    (``only_return_losses=True``) or backpropagates their weighted sum and,
    when an optimizer was supplied, takes one optimizer step.
    """

    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
        # Both criteria must return unreduced (element-wise) losses: the LM
        # loss is reshaped, masked and averaged per sequence below.
        self.lm_criterion = lm_criterion
        self.clf_criterion = clf_criterion
        self.lm_coef = lm_coef  # weight of the LM loss in the training objective
        self.opt = opt

    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
        """Compute losses and optionally perform one training step.

        Args:
            X: rank-4 input tensor indexed as (batch, choices, seq, channel);
                channel 0 holds the token ids used as LM targets.
            Y: classification targets passed to ``clf_criterion``.
            M: mask of shape (batch, choices, seq); 1 marks real tokens.
            clf_logits: classifier output passed to ``clf_criterion``.
            lm_logits: LM output with one row per shifted target token,
                i.e. ``X.size(0) * X.size(1) * (X.size(2) - 1)`` rows, or None
                to skip the LM loss.
            only_return_losses: if True, return the losses instead of training.

        Returns:
            With ``only_return_losses``: ``(clf_losses, lm_losses)`` when
            ``lm_logits`` is given, otherwise just ``clf_losses``.
            Otherwise: the scalar training loss as a Python float.
        """
        has_lm = lm_logits is not None
        if has_lm:
            n_seqs = X.size(0) * X.size(1)
            seq_len = X.size(2)
            # Next-token targets: every token id except the first of each sequence.
            targets = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
            flat_mask = M.view(-1, M.size(2))[:, 1:]
            per_token = self.lm_criterion(lm_logits, targets).view(n_seqs, seq_len - 1)
            # Mean LM loss over the real (unpadded) target positions of each sequence.
            lm_losses = (per_token * flat_mask).sum(1) / torch.sum(flat_mask, 1)

        clf_losses = self.clf_criterion(clf_logits, Y)

        if only_return_losses:
            if has_lm:
                return (clf_losses, lm_losses)
            return clf_losses

        objective = clf_losses.sum()
        if self.lm_coef > 0 and has_lm:
            objective = objective + self.lm_coef * lm_losses.sum()
        objective.backward()

        if self.opt is not None:
            self.opt.step()
            self.opt.zero_grad()
        return objective.item()
|
|
|
|
|
|
def transform_roc(X1, X2, X3):
    """Pack ROCStories (story, ending1, ending2) triples into fixed-width arrays.

    For each triple, two sequences are built, one per candidate ending:
    ``[_start_] story [_delimiter_] ending [clf_token]``. The story and the
    ending are each truncated to ``max_len`` tokens.

    Returns:
        xmb: int32 array of shape (len(X1), 2, n_ctx, 2). Channel 0 holds the
            token ids, zero padded. Channel 1 holds the position ids
            ``n_vocab + n_special + arange(n_ctx)``.
        mmb: float32 array of shape (len(X1), 2, n_ctx); 1 where channel 0
            holds a real token, 0 on padding.

    Reads the module globals ``n_ctx``, ``encoder``, ``max_len``,
    ``clf_token``, ``n_vocab`` and ``n_special`` set up in ``__main__``.
    """
    batch_size = len(X1)
    tokens = np.zeros((batch_size, 2, n_ctx, 2), dtype=np.int32)
    masks = np.zeros((batch_size, 2, n_ctx), dtype=np.float32)
    start_id = encoder['_start_']
    delim_id = encoder['_delimiter_']
    for row, (story, ending_a, ending_b) in enumerate(zip(X1, X2, X3)):
        prefix = [start_id] + story[:max_len] + [delim_id]
        for choice, ending in enumerate((ending_a, ending_b)):
            seq = prefix + ending[:max_len] + [clf_token]
            tokens[row, choice, :len(seq), 0] = seq
            masks[row, choice, :len(seq)] = 1
    # Position ids occupy the embedding rows after the word and special tokens.
    first_position = n_vocab + n_special
    tokens[:, :, :, 1] = np.arange(first_position, first_position + n_ctx)
    return tokens, masks
|
|
|
|
def iter_apply(Xs, Ms, Ys):
    """Run the classifier over a labelled dataset without updating weights.

    Args:
        Xs: input id array as produced by ``transform_roc``.
        Ms: matching mask array.
        Ys: integer class labels.

    Returns:
        A two-element list ``[logits, cost]``. ``logits`` holds the classifier
        outputs for every example, concatenated along axis 0 as a numpy array.
        ``cost`` is the float sum of the losses returned by ``compute_loss``.
        These are per-example losses, because ``__main__`` builds the
        criterion without reduction, so dividing ``cost`` by the number of
        examples gives the mean loss.

    Reads the module globals ``model``, ``clf_head``, ``compute_loss``,
    ``device`` and ``n_batch_train``.
    """
    logits = []
    cost = 0.0
    # clf_head is a separate module, so model.eval() does not reach it; switch
    # it to eval mode explicitly (disabling any dropout) and restore its
    # previous mode afterwards so training is unaffected.
    clf_was_training = clf_head.training
    try:
        with torch.no_grad():
            model.eval()
            clf_head.eval()
            # truncate=False: presumably keeps the final partial batch.
            for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
                XMB = torch.tensor(xmb, dtype=torch.long).to(device)
                YMB = torch.tensor(ymb, dtype=torch.long).to(device)
                MMB = torch.tensor(mmb).to(device)
                h = model(XMB)
                clf_logits = clf_head(h, XMB)
                clf_losses = compute_loss(XMB, YMB, MMB, clf_logits, only_return_losses=True)
                # The losses are already per example, so summing them gives the
                # batch total. The previous `* n` scaling inflated the cost by
                # the batch size, and np.sum over ragged per-batch arrays fails
                # when the last batch is partial.
                logits.append(clf_logits.cpu().numpy())
                cost += float(clf_losses.sum().item())
    finally:
        clf_head.train(clf_was_training)
    return [np.concatenate(logits, 0), cost]
|
|
|
|
def iter_predict(Xs, Ms):
    """Return classifier logits for an unlabelled dataset as one numpy array.

    Args:
        Xs: input id array as produced by ``transform_roc``.
        Ms: matching mask array; it is batched alongside ``Xs`` but not used
            by the classifier.

    Returns:
        Logits for every example, concatenated along axis 0.

    Reads the module globals ``model``, ``clf_head``, ``device`` and
    ``n_batch_train``.
    """
    logits = []
    # model.eval() does not propagate to the separate clf_head module; put it
    # in eval mode for inference and restore its previous mode afterwards.
    clf_was_training = clf_head.training
    try:
        with torch.no_grad():
            model.eval()
            clf_head.eval()
            for xmb, _mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
                XMB = torch.tensor(xmb, dtype=torch.long).to(device)
                h = model(XMB)
                clf_logits = clf_head(h, XMB)
                # .cpu() keeps this working if `device` is ever a GPU.
                logits.append(clf_logits.cpu().numpy())
    finally:
        clf_head.train(clf_was_training)
    return np.concatenate(logits, 0)
|
|
|
|
def log():
    """Evaluate the current model and record the results.

    Evaluates the first ``n_valid`` training examples and the full validation
    set. Costs are divided by the number of examples evaluated, and
    accuracies are reported as percentages. The results go to ``logger`` and
    stdout.

    When ``submit`` is set and the validation accuracy beats ``best_score``,
    ``best_score`` is updated and ``model.state_dict()`` is saved to
    ``<save_dir>/<desc>/best_params``.
    """
    global best_score
    head = slice(None, n_valid)
    tr_logits, tr_cost = iter_apply(trX[head], trM[head], trY[head])
    va_logits, va_cost = iter_apply(vaX, vaM, vaY)
    tr_labels = trY[head]
    tr_cost /= len(tr_labels)
    va_cost /= n_valid
    tr_acc = accuracy_score(tr_labels, np.argmax(tr_logits, 1)) * 100.
    va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
    logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
    print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
    if not submit:
        return
    # Model selection uses validation accuracy.
    if va_acc > best_score:
        best_score = va_acc
        checkpoint = os.path.join(save_dir, desc, 'best_params')
        torch.save(model.state_dict(), make_path(checkpoint))
|
|
|
|
def predict():
    """Write test-set predictions to a tab-separated submission file.

    The file ``<submission_dir>/<filenames[dataset]>`` gets a header row
    ``index<TAB>prediction``, then one row per test example. Raw logits are
    turned into predictions by ``pred_fns[dataset]``. If
    ``label_decoders[dataset]`` is not None, each prediction is then mapped
    through it.
    """
    out_name = filenames[dataset]
    to_predictions = pred_fns[dataset]
    decoder = label_decoders[dataset]
    predictions = to_predictions(iter_predict(teX, teM))
    if label_decoder_is_set := decoder is not None:
        predictions = [decoder[p] for p in predictions]
    out_path = os.path.join(submission_dir, out_name)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    rows = ('{}\t{}\n'.format(i, p) for i, p in enumerate(predictions))
    with open(out_path, 'w') as f:
        f.write('{}\t{}\n'.format('index', 'prediction'))
        f.writelines(rows)
|
|
|
|
def run_epoch():
    """Train for one pass over a freshly shuffled training set.

    Each batch performs one update through ``compute_loss``, which
    backpropagates and steps the optimizer, and then increments the global
    ``n_updates``. During the first epoch (``n_epochs == 0``), ``log()`` is
    also called after updates 1000, 2000, 4000, 8000, 16000 and 32000.
    """
    global n_updates
    early_log_points = (1000, 2000, 4000, 8000, 16000, 32000)
    batches = iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
                        n_batch=n_batch_train, truncate=True, verbose=True)
    for xmb, mmb, ymb in batches:
        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
        MMB = torch.tensor(mmb).to(device)
        # Re-enter training mode every step, because a mid-epoch log() puts the
        # model in eval mode. The heads are separate modules that
        # model.train() does not reach, so set them explicitly as well.
        model.train()
        lm_head.train()
        clf_head.train()
        h = model(XMB)
        lm_logits = lm_head(h)
        clf_logits = clf_head(h, XMB)
        compute_loss(XMB, YMB, MMB, clf_logits, lm_logits)
        n_updates += 1
        if n_updates in early_log_points and n_epochs == 0:
            log()
|
|
|
|
def argmax(x):
    """Return the index of the largest value in each row of ``x`` (axis 1)."""
    return np.argmax(x, 1)


# Per-dataset function that turns the logits array from iter_predict() into predictions.
pred_fns = {
    'rocstories': argmax,
}

# Per-dataset submission file name written by predict(), relative to submission_dir.
filenames = {
    'rocstories': 'ROCStories.tsv',
}

# Per-dataset mapping from prediction to output label; None writes predictions unchanged.
label_decoders = {
    'rocstories': None,
}
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str)
    # Only ROCStories is wired up (see pred_fns / filenames / label_decoders),
    # so default to it; an omitted --dataset used to crash predict().
    parser.add_argument('--dataset', type=str, default='rocstories')
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    # Gradient-norm clipping threshold; a float so values such as 0.5 are accepted.
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--n_gpu', type=int, default=4)
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)

    args = parser.parse_args()
    print(args)
    # The helper functions above read the configuration (n_ctx, n_batch, ...)
    # and the run state (model, device, ...) as module globals.
    globals().update(args.__dict__)  # TODO remove globals
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # torch.device object used throughout this script. TODO add a GPU setting.
    device = torch.device("cpu")

    # Shared by the logger and the optional analysis step so both use the same file.
    log_path = os.path.join(log_dir, '{}.jsonl'.format(desc))
    logger = ResultLogger(path=log_path, **args.__dict__)
    text_encoder = TextEncoder(encoder_path, bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder)
    n_y = 2
    # Special tokens get ids directly after the BPE vocabulary.
    encoder['_start_'] = len(encoder)
    encoder['_delimiter_'] = len(encoder)
    encoder['_classify_'] = len(encoder)
    clf_token = encoder['_classify_']
    n_special = 3
    max_len = n_ctx // 2 - 2
    # Shrink the context to the longest (story + longer ending) over all
    # splits, plus the 3 special tokens, capped at the configured n_ctx.
    all_triples = (list(zip(trX1, trX2, trX3))
                   + list(zip(vaX1, vaX2, vaX3))
                   + list(zip(teX1, teX2, teX3)))
    n_ctx = min(
        max(len(x1[:max_len]) + max(len(x2[:max_len]), len(x3[:max_len]))
            for x1, x2, x3 in all_triples) + 3,
        n_ctx,
    )
    vocab = n_vocab + n_special + n_ctx
    trX, trM = transform_roc(trX1, trX2, trX3)
    vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
    if submit:
        teX, teM = transform_roc(teX1, teX2, teX3)

    n_train = len(trY)
    n_valid = len(vaY)
    n_batch_train = n_batch * n_gpu
    n_updates_total = (n_train // n_batch_train) * n_iter

    model = Model(vocab, args)
    lm_head = LMHead(model, args)
    clf_head = ClfHead(clf_token, args)

    # Unreduced (per-element) losses: LossCompute masks and averages them itself.
    criterion = nn.CrossEntropyLoss(reduction='none')
    # clf_head is a separate module, so its parameters must be passed to the
    # optimizer explicitly. Otherwise they never update, and their gradients
    # are never cleared by opt.zero_grad() (assuming OpenAIAdam follows
    # torch.optim.Optimizer semantics). lm_head is left out: it is built from
    # `model` and presumably reuses its embedding weights, so listing it could
    # register those weights twice. TODO confirm in model_py.
    trainable_params = list(model.parameters()) + list(clf_head.parameters())
    model_opt = OpenAIAdam(trainable_params, lr=lr, schedule=lr_schedule,
                           warmup=lr_warmup, t_total=n_updates_total, b1=b1,
                           b2=b2, e=e, l2=l2, vector_l2=vector_l2,
                           max_grad_norm=max_grad_norm)

    compute_loss = LossCompute(criterion, criterion, lm_coef, model_opt)
    load_openai_pretrained_model(model, n_ctx, n_special, args)

    model.to(device)
    lm_head.to(device)
    clf_head.to(device)

    n_updates = 0
    n_epochs = 0
    # Only ROCStories classification data is loaded above, so the labels are
    # used as-is. The former `dataset != 'stsb'` guard left trYt undefined.
    trYt = trY
    best_params_path = os.path.join(save_dir, desc, 'best_params')
    if submit:
        torch.save(model.state_dict(), make_path(best_params_path))
    best_score = 0
    log()
    for _ in range(n_iter):
        run_epoch()
        n_epochs += 1
        log()
    if submit:
        # NOTE(review): best_params (written by log()) holds only
        # model.state_dict(), so the restored transformer is paired with the
        # clf_head weights from the end of training.
        model.load_state_dict(torch.load(best_params_path, map_location=device))
        predict()
        if analysis:
            rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'),
                                log_path)
|