mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
Merge pull request #9 from nottombrown/master
Remove globals to make code easier to follow
This commit is contained in:
+2
-1
@@ -109,4 +109,5 @@ venv.bak/
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.mypy_cache/cloze_data
|
||||
cloze_data/
|
||||
|
||||
@@ -6,7 +6,7 @@ This implementation comprises **a script to load in the PyTorch model the weight
|
||||
|
||||

|
||||
|
||||
The model classes and loading script are located in [model_py.py](model_py.py).
|
||||
The model classes and loading script are located in [model_pytorch.py](model_pytorch.py).
|
||||
|
||||
The names of the modules in the PyTorch model follow the names of the Variables in the TensorFlow implementation. This implementation tries to follow the original code as closely as possible to minimize the discrepancies.
|
||||
|
||||
@@ -15,7 +15,7 @@ This implementation thus also comprises a modified Adam optimization algorithm a
|
||||
- scheduled learning rate as [commonly used for Transformers](http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer).
|
||||
|
||||
## Requirements
|
||||
To use the model it-self by importing [model_py.py](model_py.py), you just need:
|
||||
To use the model itself by importing [model_pytorch.py](model_pytorch.py), you just need:
|
||||
- PyTorch (version >=0.4)
|
||||
|
||||
To run the classifier training script in [train.py](train.py) you will need in addition:
|
||||
@@ -49,6 +49,7 @@ The ROCStories dataset can be downloaded from the associated [website](http://cs
|
||||
As with the [TensorFlow code](https://github.com/openai/finetune-transformer-lm), this code implements the ROCStories Cloze Test result reported in the paper which can be reproduced by running:
|
||||
|
||||
```bash
|
||||
python -m spacy download en
|
||||
python train.py --dataset rocstories --desc rocstories --submit --analysis --data_dir [path to data here]
|
||||
```
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import re
|
||||
import math
|
||||
import json
|
||||
import copy
|
||||
import numpy as np
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5*x*(1+torch.tanh(math.sqrt(2/math.pi)*(x+0.044715*torch.pow(x, 3))))
|
||||
|
||||
@@ -63,7 +64,7 @@ class Conv1D(nn.Module):
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, nx, n_ctx, cfg, scale=False):
|
||||
super(Attention, self).__init__()
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % cfg.n_head==0
|
||||
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
@@ -1,31 +1,25 @@
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import json
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from sklearn.utils import shuffle
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.utils import shuffle
|
||||
|
||||
from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model
|
||||
from opt import OpenAIAdam
|
||||
from datasets import rocstories
|
||||
from analysis import rocstories as rocstories_analysis
|
||||
from datasets import rocstories
|
||||
from model_pytorch import Model, LMHead, ClfHead, load_openai_pretrained_model
|
||||
from opt import OpenAIAdam
|
||||
from text_utils import TextEncoder
|
||||
from utils import (encode_dataset, flatten, iter_data,
|
||||
from utils import (encode_dataset, iter_data,
|
||||
ResultLogger, make_path)
|
||||
|
||||
|
||||
class LossCompute:
|
||||
"A Loss compute and train function."
|
||||
|
||||
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
|
||||
self.lm_criterion = lm_criterion
|
||||
self.clf_criterion = clf_criterion
|
||||
@@ -35,10 +29,10 @@ class LossCompute:
|
||||
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
|
||||
# Language modeling loss
|
||||
if lm_logits is not None:
|
||||
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
|
||||
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
|
||||
M = M.view(-1, M.size(2))
|
||||
lm_losses = self.lm_criterion(lm_logits, x_shifted)
|
||||
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1)
|
||||
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
|
||||
lm_losses = lm_losses * M[:, 1:]
|
||||
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
|
||||
# Classification loss
|
||||
@@ -56,6 +50,7 @@ class LossCompute:
|
||||
self.opt.zero_grad()
|
||||
return train_loss.item()
|
||||
|
||||
|
||||
def transform_roc(X1, X2, X3):
|
||||
n_batch = len(X1)
|
||||
xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
|
||||
@@ -63,17 +58,18 @@ def transform_roc(X1, X2, X3):
|
||||
start = encoder['_start_']
|
||||
delimiter = encoder['_delimiter_']
|
||||
for i, (x1, x2, x3), in enumerate(zip(X1, X2, X3)):
|
||||
x12 = [start]+x1[:max_len]+[delimiter]+x2[:max_len]+[clf_token]
|
||||
x13 = [start]+x1[:max_len]+[delimiter]+x3[:max_len]+[clf_token]
|
||||
x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
|
||||
x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
|
||||
l12 = len(x12)
|
||||
l13 = len(x13)
|
||||
xmb[i, 0, :l12, 0] = x12
|
||||
xmb[i, 1, :l13, 0] = x13
|
||||
mmb[i, 0, :l12] = 1
|
||||
mmb[i, 1, :l13] = 1
|
||||
xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
|
||||
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
|
||||
return xmb, mmb
|
||||
|
||||
|
||||
def iter_apply(Xs, Ms, Ys):
|
||||
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
|
||||
logits = []
|
||||
@@ -95,6 +91,7 @@ def iter_apply(Xs, Ms, Ys):
|
||||
logits = np.concatenate(logits, 0)
|
||||
return logits, cost
|
||||
|
||||
|
||||
def iter_predict(Xs, Ms):
|
||||
logits = []
|
||||
with torch.no_grad():
|
||||
@@ -109,17 +106,18 @@ def iter_predict(Xs, Ms):
|
||||
logits = np.concatenate(logits, 0)
|
||||
return logits
|
||||
|
||||
def log():
|
||||
|
||||
def log(save_dir, desc):
|
||||
global best_score
|
||||
print("Logging")
|
||||
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
|
||||
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
|
||||
tr_cost = tr_cost/len(trY[:n_valid])
|
||||
va_cost = va_cost/n_valid
|
||||
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
|
||||
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
|
||||
tr_cost = tr_cost / len(trY[:n_valid])
|
||||
va_cost = va_cost / n_valid
|
||||
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
|
||||
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
|
||||
logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
|
||||
print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
|
||||
print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
|
||||
if submit:
|
||||
score = va_acc
|
||||
if score > best_score:
|
||||
@@ -127,7 +125,8 @@ def log():
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
torch.save(model.state_dict(), make_path(path))
|
||||
|
||||
def predict():
|
||||
|
||||
def predict(dataset, submission_dir):
|
||||
filename = filenames[dataset]
|
||||
pred_fn = pred_fns[dataset]
|
||||
label_decoder = label_decoders[dataset]
|
||||
@@ -141,9 +140,10 @@ def predict():
|
||||
for i, prediction in enumerate(predictions):
|
||||
f.write('{}\t{}\n'.format(i, prediction))
|
||||
|
||||
|
||||
def run_epoch():
|
||||
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
|
||||
n_batch=n_batch_train, truncate=True, verbose=True):
|
||||
n_batch=n_batch_train, truncate=True, verbose=True):
|
||||
global n_updates
|
||||
model.train()
|
||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||
@@ -157,23 +157,24 @@ def run_epoch():
|
||||
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
|
||||
log()
|
||||
|
||||
argmax = lambda x:np.argmax(x, 1)
|
||||
|
||||
argmax = lambda x: np.argmax(x, 1)
|
||||
|
||||
pred_fns = {
|
||||
'rocstories':argmax,
|
||||
'rocstories': argmax,
|
||||
}
|
||||
|
||||
filenames = {
|
||||
'rocstories':'ROCStories.tsv',
|
||||
'rocstories': 'ROCStories.tsv',
|
||||
}
|
||||
|
||||
label_decoders = {
|
||||
'rocstories':None,
|
||||
'rocstories': None,
|
||||
}
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--desc', type=str)
|
||||
parser.add_argument('--desc', type=str, help="Description")
|
||||
parser.add_argument('--dataset', type=str)
|
||||
parser.add_argument('--log_dir', type=str, default='log/')
|
||||
parser.add_argument('--save_dir', type=str, default='save/')
|
||||
@@ -197,7 +198,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--clf_pdrop', type=float, default=0.1)
|
||||
parser.add_argument('--l2', type=float, default=0.01)
|
||||
parser.add_argument('--vector_l2', action='store_true')
|
||||
parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic
|
||||
parser.add_argument('--n_gpu', type=int, default=1) # 4) # TODO add mutli-gpu training logic
|
||||
parser.add_argument('--opt', type=str, default='adam')
|
||||
parser.add_argument('--afn', type=str, default='gelu')
|
||||
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
||||
@@ -212,37 +213,47 @@ if __name__ == '__main__':
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
globals().update(args.__dict__) #TODO maybe we want to remove these gobal variables to make it cleaner
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# globals().update(args.__dict__) # TODO maybe we want to remove these global variables to make it cleaner
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
# Constants
|
||||
submit = args.submit
|
||||
dataset = args.dataset
|
||||
n_ctx = args.n_ctx
|
||||
save_dir = args.save_dir
|
||||
desc = args.desc
|
||||
data_dir = args.data_dir
|
||||
log_dir = args.log_dir
|
||||
|
||||
# torch.device object used throughout this script TODO add gpu setting
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
|
||||
text_encoder = TextEncoder(encoder_path, bpe_path)
|
||||
text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
|
||||
encoder = text_encoder.encoder
|
||||
n_vocab = len(text_encoder.encoder)
|
||||
|
||||
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder)
|
||||
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(
|
||||
rocstories(data_dir, n_valid=args.n_valid), encoder=text_encoder)
|
||||
n_y = 2
|
||||
encoder['_start_'] = len(encoder)
|
||||
encoder['_delimiter_'] = len(encoder)
|
||||
encoder['_classify_'] = len(encoder)
|
||||
clf_token = encoder['_classify_']
|
||||
n_special = 3
|
||||
max_len = n_ctx//2-2
|
||||
max_len = n_ctx // 2 - 2
|
||||
n_ctx = min(max(
|
||||
[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
|
||||
+[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
|
||||
+[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
|
||||
) + 3,
|
||||
n_ctx)
|
||||
[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
|
||||
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
|
||||
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
|
||||
) + 3, n_ctx)
|
||||
vocab = n_vocab + n_special + n_ctx
|
||||
trX, trM = transform_roc(trX1, trX2, trX3)
|
||||
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
|
||||
@@ -251,20 +262,29 @@ if __name__ == '__main__':
|
||||
|
||||
n_train = len(trY)
|
||||
n_valid = len(vaY)
|
||||
n_batch_train = n_batch*n_gpu
|
||||
n_updates_total = (n_train//n_batch_train)*n_iter
|
||||
n_batch_train = args.n_batch * args.n_gpu
|
||||
n_updates_total = (n_train // n_batch_train) * args.n_iter
|
||||
|
||||
model = Model(args, vocab, n_ctx)
|
||||
lm_head = LMHead(model, args)
|
||||
clf_head = ClfHead(clf_token, args)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions
|
||||
criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions
|
||||
model_opt = OpenAIAdam(list(model.parameters()) + list(clf_head.parameters()) + list(lm_head.parameters()),
|
||||
lr=lr, schedule=lr_schedule,
|
||||
warmup=lr_warmup, t_total=n_updates_total, b1=b1,
|
||||
b2=b2, e=e, l2=l2, vector_l2=vector_l2,
|
||||
max_grad_norm=max_grad_norm)
|
||||
compute_loss_fct = LossCompute(criterion, criterion, lm_coef, model_opt)
|
||||
lr=args.lr,
|
||||
schedule=args.lr_schedule,
|
||||
warmup=args.lr_warmup,
|
||||
t_total=n_updates_total,
|
||||
b1=args.b1,
|
||||
b2=args.b2,
|
||||
e=args.e,
|
||||
l2=args.l2,
|
||||
vector_l2=args.vector_l2,
|
||||
max_grad_norm=args.max_grad_norm)
|
||||
compute_loss_fct = LossCompute(criterion,
|
||||
criterion,
|
||||
args.lm_coef,
|
||||
model_opt)
|
||||
load_openai_pretrained_model(model, n_ctx=n_ctx, n_special=n_special)
|
||||
|
||||
model.to(device)
|
||||
@@ -279,15 +299,15 @@ if __name__ == '__main__':
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
torch.save(model.state_dict(), make_path(path))
|
||||
best_score = 0
|
||||
for i in range(n_iter):
|
||||
for i in range(args.n_iter):
|
||||
print("running epoch", i)
|
||||
run_epoch()
|
||||
n_epochs += 1
|
||||
log()
|
||||
log(save_dir, desc)
|
||||
if submit:
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
model.load_state_dict(torch.load(path))
|
||||
predict()
|
||||
if analysis:
|
||||
rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'),
|
||||
predict(dataset, args.submission_dir)
|
||||
if args.analysis:
|
||||
rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
|
||||
os.path.join(log_dir, 'rocstories.jsonl'))
|
||||
|
||||
Reference in New Issue
Block a user