Merge pull request #9 from nottombrown/master

Remove globals to make code easier to follow
This commit is contained in:
thomwolf
2018-06-27 09:03:34 +02:00
committed by GitHub
4 changed files with 95 additions and 72 deletions
+2 -1
View File
@@ -109,4 +109,5 @@ venv.bak/
/site
# mypy
.mypy_cache/
.mypy_cache/cloze_data
cloze_data/
+3 -2
View File
@@ -6,7 +6,7 @@ This implementation comprises **a script to load in the PyTorch model the weight
![Transformer Language Model](assets/ftlm.png)
The model classes and loading script are located in [model_py.py](model_py.py).
The model classes and loading script are located in [model_pytorch.py](model_pytorch.py).
The names of the modules in the PyTorch model follow the names of the Variable in the TensorFlow implementation. This implementation tries to follow the original code as closely as possible to minimize the discrepancies.
@@ -15,7 +15,7 @@ This implementation thus also comprises a modified Adam optimization algorithm a
- scheduled learning rate as [commonly used for Transformers](http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer).
## Requirements
To use the model it-self by importing [model_py.py](model_py.py), you just need:
To use the model itself by importing [model_pytorch.py](model_pytorch.py), you just need:
- PyTorch (version >=0.4)
To run the classifier training script in [train.py](train.py) you will need in addition:
@@ -49,6 +49,7 @@ The ROCStories dataset can be downloaded from the associated [website](http://cs
As with the [TensorFlow code](https://github.com/openai/finetune-transformer-lm), this code implements the ROCStories Cloze Test result reported in the paper which can be reproduced by running:
```bash
python -m spacy download en
python train.py --dataset rocstories --desc rocstories --submit --analysis --data_dir [path to data here]
```
+6 -5
View File
@@ -1,14 +1,15 @@
import re
import math
import json
import copy
import numpy as np
import json
import math
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
# Gaussian Error Linear Unit activation, computed with the tanh approximation:
#   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
# `x` is passed to torch.tanh / torch.pow, so it is expected to be a torch tensor;
# the result is computed element-wise.
def gelu(x):
return 0.5*x*(1+torch.tanh(math.sqrt(2/math.pi)*(x+0.044715*torch.pow(x, 3))))
@@ -63,7 +64,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, cfg, scale=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
n_state = nx # in Attention: n_state=768 (nx=n_embd)
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % cfg.n_head==0
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+84 -64
View File
@@ -1,31 +1,25 @@
import re
import os
import time
import math
import json
import random
import argparse
import numpy as np
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from functools import partial
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model
from opt import OpenAIAdam
from datasets import rocstories
from analysis import rocstories as rocstories_analysis
from datasets import rocstories
from model_pytorch import Model, LMHead, ClfHead, load_openai_pretrained_model
from opt import OpenAIAdam
from text_utils import TextEncoder
from utils import (encode_dataset, flatten, iter_data,
from utils import (encode_dataset, iter_data,
ResultLogger, make_path)
class LossCompute:
"A Loss compute and train function."
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
self.lm_criterion = lm_criterion
self.clf_criterion = clf_criterion
@@ -35,10 +29,10 @@ class LossCompute:
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
# Language modeling loss
if lm_logits is not None:
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
M = M.view(-1, M.size(2))
lm_losses = self.lm_criterion(lm_logits, x_shifted)
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1)
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
lm_losses = lm_losses * M[:, 1:]
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
# Classification loss
@@ -56,6 +50,7 @@ class LossCompute:
self.opt.zero_grad()
return train_loss.item()
def transform_roc(X1, X2, X3):
n_batch = len(X1)
xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
@@ -63,17 +58,18 @@ def transform_roc(X1, X2, X3):
start = encoder['_start_']
delimiter = encoder['_delimiter_']
for i, (x1, x2, x3), in enumerate(zip(X1, X2, X3)):
x12 = [start]+x1[:max_len]+[delimiter]+x2[:max_len]+[clf_token]
x13 = [start]+x1[:max_len]+[delimiter]+x3[:max_len]+[clf_token]
x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]
x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]
l12 = len(x12)
l13 = len(x13)
xmb[i, 0, :l12, 0] = x12
xmb[i, 1, :l13, 0] = x13
mmb[i, 0, :l12] = 1
mmb[i, 1, :l13] = 1
xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
return xmb, mmb
def iter_apply(Xs, Ms, Ys):
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
logits = []
@@ -95,6 +91,7 @@ def iter_apply(Xs, Ms, Ys):
logits = np.concatenate(logits, 0)
return logits, cost
def iter_predict(Xs, Ms):
logits = []
with torch.no_grad():
@@ -109,17 +106,18 @@ def iter_predict(Xs, Ms):
logits = np.concatenate(logits, 0)
return logits
def log():
def log(save_dir, desc):
global best_score
print("Logging")
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
tr_cost = tr_cost/len(trY[:n_valid])
va_cost = va_cost/n_valid
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
tr_cost = tr_cost / len(trY[:n_valid])
va_cost = va_cost / n_valid
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.
logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
if submit:
score = va_acc
if score > best_score:
@@ -127,7 +125,8 @@ def log():
path = os.path.join(save_dir, desc, 'best_params')
torch.save(model.state_dict(), make_path(path))
def predict():
def predict(dataset, submission_dir):
filename = filenames[dataset]
pred_fn = pred_fns[dataset]
label_decoder = label_decoders[dataset]
@@ -141,9 +140,10 @@ def predict():
for i, prediction in enumerate(predictions):
f.write('{}\t{}\n'.format(i, prediction))
def run_epoch():
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
n_batch=n_batch_train, truncate=True, verbose=True):
n_batch=n_batch_train, truncate=True, verbose=True):
global n_updates
model.train()
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
@@ -157,23 +157,24 @@ def run_epoch():
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
log()
argmax = lambda x:np.argmax(x, 1)
argmax = lambda x: np.argmax(x, 1)
pred_fns = {
'rocstories':argmax,
'rocstories': argmax,
}
filenames = {
'rocstories':'ROCStories.tsv',
'rocstories': 'ROCStories.tsv',
}
label_decoders = {
'rocstories':None,
'rocstories': None,
}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--desc', type=str)
parser.add_argument('--desc', type=str, help="Description")
parser.add_argument('--dataset', type=str)
parser.add_argument('--log_dir', type=str, default='log/')
parser.add_argument('--save_dir', type=str, default='save/')
@@ -197,7 +198,7 @@ if __name__ == '__main__':
parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic
parser.add_argument('--n_gpu', type=int, default=1) # 4) # TODO add mutli-gpu training logic
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
@@ -212,37 +213,47 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
globals().update(args.__dict__) #TODO maybe we want to remove these gobal variables to make it cleaner
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# globals().update(args.__dict__) # TODO maybe we want to remove these global variables to make it cleaner
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Constants
submit = args.submit
dataset = args.dataset
n_ctx = args.n_ctx
save_dir = args.save_dir
desc = args.desc
data_dir = args.data_dir
log_dir = args.log_dir
# torch.device object used throughout this script TODO add gpu setting
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
text_encoder = TextEncoder(encoder_path, bpe_path)
text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder)
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(
rocstories(data_dir, n_valid=args.n_valid), encoder=text_encoder)
n_y = 2
encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
clf_token = encoder['_classify_']
n_special = 3
max_len = n_ctx//2-2
max_len = n_ctx // 2 - 2
n_ctx = min(max(
[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3,
n_ctx)
[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3, n_ctx)
vocab = n_vocab + n_special + n_ctx
trX, trM = transform_roc(trX1, trX2, trX3)
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
@@ -251,20 +262,29 @@ if __name__ == '__main__':
n_train = len(trY)
n_valid = len(vaY)
n_batch_train = n_batch*n_gpu
n_updates_total = (n_train//n_batch_train)*n_iter
n_batch_train = args.n_batch * args.n_gpu
n_updates_total = (n_train // n_batch_train) * args.n_iter
model = Model(args, vocab, n_ctx)
lm_head = LMHead(model, args)
clf_head = ClfHead(clf_token, args)
criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions
criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions
model_opt = OpenAIAdam(list(model.parameters()) + list(clf_head.parameters()) + list(lm_head.parameters()),
lr=lr, schedule=lr_schedule,
warmup=lr_warmup, t_total=n_updates_total, b1=b1,
b2=b2, e=e, l2=l2, vector_l2=vector_l2,
max_grad_norm=max_grad_norm)
compute_loss_fct = LossCompute(criterion, criterion, lm_coef, model_opt)
lr=args.lr,
schedule=args.lr_schedule,
warmup=args.lr_warmup,
t_total=n_updates_total,
b1=args.b1,
b2=args.b2,
e=args.e,
l2=args.l2,
vector_l2=args.vector_l2,
max_grad_norm=args.max_grad_norm)
compute_loss_fct = LossCompute(criterion,
criterion,
args.lm_coef,
model_opt)
load_openai_pretrained_model(model, n_ctx=n_ctx, n_special=n_special)
model.to(device)
@@ -279,15 +299,15 @@ if __name__ == '__main__':
path = os.path.join(save_dir, desc, 'best_params')
torch.save(model.state_dict(), make_path(path))
best_score = 0
for i in range(n_iter):
for i in range(args.n_iter):
print("running epoch", i)
run_epoch()
n_epochs += 1
log()
log(save_dir, desc)
if submit:
path = os.path.join(save_dir, desc, 'best_params')
model.load_state_dict(torch.load(path))
predict()
if analysis:
rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'),
predict(dataset, args.submission_dir)
if args.analysis:
rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
os.path.join(log_dir, 'rocstories.jsonl'))