fix model and training

This commit is contained in:
thomwolf
2018-06-14 16:40:00 +02:00
parent 6a20d66253
commit 3d8d70937c
5 changed files with 37 additions and 30 deletions
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

+1 -1
View File
@@ -27,7 +27,7 @@ def _rocstories(path):
y.append(int(line[-1])-1)
return st, ct1, ct2, y
def rocstories(data_dir, n_train=1497, n_valid=2): #374): # TODO: set this back
def rocstories(data_dir, n_train=1497, n_valid=374):
storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv'))
teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv'))
tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
+10 -11
View File
@@ -60,13 +60,12 @@ class Conv1D(nn.Module):
class Attention(nn.Module):
def __init__(self, nx, cfg, scale=False):
def __init__(self, nx, n_ctx, cfg, scale=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % cfg.n_head==0
mask_size = n_state // cfg.n_head
self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size))
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_head = cfg.n_head
self.split_size = n_state
self.scale = scale
@@ -126,10 +125,10 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, cfg, scale=False):
def __init__(self, n_ctx, cfg, scale=False):
super(Block, self).__init__()
nx = cfg.n_embd
self.attn = Attention(nx, cfg, scale)
self.attn = Attention(nx, n_ctx, cfg, scale)
self.ln_1 = LayerNorm(nx)
self.mlp = MLP(4*nx, cfg)
self.ln_2 = LayerNorm(nx)
@@ -144,12 +143,12 @@ class Block(nn.Module):
class Model(nn.Module):
""" Transformer model """
def __init__(self, vocab, cfg):
def __init__(self, vocab, n_ctx, cfg):
super(Model, self).__init__()
self.vocab = vocab
self.embed = nn.Embedding(vocab, cfg.n_embd)
self.drop = nn.Dropout(cfg.embd_pdrop)
block = Block(cfg, scale=True)
block = Block(n_ctx, cfg, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
self.decoder.weight = self.embed.weight # Tied weights
@@ -205,12 +204,12 @@ class ClfHead(nn.Module):
return clf_logits.view(-1, 2)
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='model'):
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'):
# Load weights from TF model
shapes = json.load(open(path + '/params_shapes.json'))
names = json.load(open(path + '/parameters_names.json'))
names = json.load(open(path_names + 'parameters_names.json'))
shapes = json.load(open(path + 'params_shapes.json'))
offsets = np.cumsum([np.prod(shape) for shape in shapes])
init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
if n_ctx > 0:
+1
View File
@@ -0,0 +1 @@
["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"]
+24 -17
View File
@@ -3,7 +3,6 @@ import os
import time
import math
import json
import joblib
import random
import argparse
import numpy as np
@@ -76,8 +75,9 @@ def transform_roc(X1, X2, X3):
return xmb, mmb
def iter_apply(Xs, Ms, Ys):
fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
results = []
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
logits = []
cost = 0
with torch.no_grad():
model.eval()
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
@@ -87,11 +87,13 @@ def iter_apply(Xs, Ms, Ys):
MMB = torch.tensor(mmb).to(device)
h = model(XMB)
clf_logits = clf_head(h, XMB)
clf_logits *= n
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
res = (clf_logits.numpy()*n, clf_losses.numpy()*n)
results.append(res)
results = zip(*results)
return [fn(res) for res, fn in zip(results, fns)]
clf_losses *= n
logits.append(clf_logits.to("cpu").numpy())
cost += clf_losses.sum().item()
logits = np.concatenate(logits, 0)
return logits, cost
def iter_predict(Xs, Ms):
logits = []
@@ -103,12 +105,13 @@ def iter_predict(Xs, Ms):
MMB = torch.tensor(mmb).to(device)
h = model(XMB)
clf_logits = clf_head(h, XMB)
logits.append(clf_logits.numpy())
logits.append(clf_logits.to("cpu").numpy())
logits = np.concatenate(logits, 0)
return logits
def log():
global best_score
print("Logging")
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
tr_cost = tr_cost/len(trY[:n_valid])
@@ -142,10 +145,10 @@ def run_epoch():
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
n_batch=n_batch_train, truncate=True, verbose=True):
global n_updates
model.train()
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
model.train()
h = model(XMB)
lm_logits = lm_head(h)
clf_logits = clf_head(h, XMB)
@@ -194,7 +197,7 @@ if __name__ == '__main__':
parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--n_gpu', type=int, default=4)
parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
@@ -205,6 +208,7 @@ if __name__ == '__main__':
parser.add_argument('--b1', type=float, default=0.9)
parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--e', type=float, default=1e-8)
parser.add_argument('--n_valid', type=int, default=374)
args = parser.parse_args()
print(args)
@@ -215,14 +219,14 @@ if __name__ == '__main__':
torch.cuda.manual_seed_all(seed)
# torch.device object used throughout this script TODO add gpu setting
device = torch.device("cpu") #"cuda" if use_cuda else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
text_encoder = TextEncoder(encoder_path, bpe_path)
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder)
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder)
n_y = 2
encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
@@ -231,9 +235,12 @@ if __name__ == '__main__':
n_special = 3
max_len = n_ctx//2-2
n_ctx = min(max(
[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3,
n_ctx)
vocab = n_vocab + n_special + n_ctx
@@ -247,7 +254,7 @@ if __name__ == '__main__':
n_batch_train = n_batch*n_gpu
n_updates_total = (n_train//n_batch_train)*n_iter
model = Model(vocab, args)
model = Model(vocab, n_ctx, args)
lm_head = LMHead(model, args)
clf_head = ClfHead(clf_token, args)
@@ -271,8 +278,8 @@ if __name__ == '__main__':
path = os.path.join(save_dir, desc, 'best_params')
torch.save(model.state_dict(), make_path(path))
best_score = 0
log()
for i in range(n_iter):
print("running epoch", i)
run_epoch()
n_epochs += 1
log()