fix model and training

This commit is contained in:
thomwolf
2018-06-14 16:40:00 +02:00
parent 6a20d66253
commit 3d8d70937c
5 changed files with 37 additions and 30 deletions
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

+1 -1
View File
@@ -27,7 +27,7 @@ def _rocstories(path):
y.append(int(line[-1])-1) y.append(int(line[-1])-1)
return st, ct1, ct2, y return st, ct1, ct2, y
def rocstories(data_dir, n_train=1497, n_valid=2): #374): # TODO: set this back def rocstories(data_dir, n_train=1497, n_valid=374):
storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv')) storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv'))
teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv')) teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv'))
tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed) tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
+10 -11
View File
@@ -60,13 +60,12 @@ class Conv1D(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, nx, cfg, scale=False): def __init__(self, nx, n_ctx, cfg, scale=False):
super(Attention, self).__init__() super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
#[switch nx => n_state from Block to Attention to keep identical to TF implem] #[switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % cfg.n_head==0 assert n_state % cfg.n_head==0
mask_size = n_state // cfg.n_head self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size))
self.n_head = cfg.n_head self.n_head = cfg.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
@@ -126,10 +125,10 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, cfg, scale=False): def __init__(self, n_ctx, cfg, scale=False):
super(Block, self).__init__() super(Block, self).__init__()
nx = cfg.n_embd nx = cfg.n_embd
self.attn = Attention(nx, cfg, scale) self.attn = Attention(nx, n_ctx, cfg, scale)
self.ln_1 = LayerNorm(nx) self.ln_1 = LayerNorm(nx)
self.mlp = MLP(4*nx, cfg) self.mlp = MLP(4*nx, cfg)
self.ln_2 = LayerNorm(nx) self.ln_2 = LayerNorm(nx)
@@ -144,12 +143,12 @@ class Block(nn.Module):
class Model(nn.Module): class Model(nn.Module):
""" Transformer model """ """ Transformer model """
def __init__(self, vocab, cfg): def __init__(self, vocab, n_ctx, cfg):
super(Model, self).__init__() super(Model, self).__init__()
self.vocab = vocab self.vocab = vocab
self.embed = nn.Embedding(vocab, cfg.n_embd) self.embed = nn.Embedding(vocab, cfg.n_embd)
self.drop = nn.Dropout(cfg.embd_pdrop) self.drop = nn.Dropout(cfg.embd_pdrop)
block = Block(cfg, scale=True) block = Block(n_ctx, cfg, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False) self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
self.decoder.weight = self.embed.weight # Tied weights self.decoder.weight = self.embed.weight # Tied weights
@@ -205,12 +204,12 @@ class ClfHead(nn.Module):
return clf_logits.view(-1, 2) return clf_logits.view(-1, 2)
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='model'): def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'):
# Load weights from TF model # Load weights from TF model
shapes = json.load(open(path + '/params_shapes.json')) names = json.load(open(path_names + 'parameters_names.json'))
names = json.load(open(path + '/parameters_names.json')) shapes = json.load(open(path + 'params_shapes.json'))
offsets = np.cumsum([np.prod(shape) for shape in shapes]) offsets = np.cumsum([np.prod(shape) for shape in shapes])
init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)] init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
if n_ctx > 0: if n_ctx > 0:
+1
View File
@@ -0,0 +1 @@
["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"]
+25 -18
View File
@@ -3,7 +3,6 @@ import os
import time import time
import math import math
import json import json
import joblib
import random import random
import argparse import argparse
import numpy as np import numpy as np
@@ -76,8 +75,9 @@ def transform_roc(X1, X2, X3):
return xmb, mmb return xmb, mmb
def iter_apply(Xs, Ms, Ys): def iter_apply(Xs, Ms, Ys):
fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))] # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
results = [] logits = []
cost = 0
with torch.no_grad(): with torch.no_grad():
model.eval() model.eval()
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True): for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
@@ -87,11 +87,13 @@ def iter_apply(Xs, Ms, Ys):
MMB = torch.tensor(mmb).to(device) MMB = torch.tensor(mmb).to(device)
h = model(XMB) h = model(XMB)
clf_logits = clf_head(h, XMB) clf_logits = clf_head(h, XMB)
clf_logits *= n
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True) clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
res = (clf_logits.numpy()*n, clf_losses.numpy()*n) clf_losses *= n
results.append(res) logits.append(clf_logits.to("cpu").numpy())
results = zip(*results) cost += clf_losses.sum().item()
return [fn(res) for res, fn in zip(results, fns)] logits = np.concatenate(logits, 0)
return logits, cost
def iter_predict(Xs, Ms): def iter_predict(Xs, Ms):
logits = [] logits = []
@@ -103,12 +105,13 @@ def iter_predict(Xs, Ms):
MMB = torch.tensor(mmb).to(device) MMB = torch.tensor(mmb).to(device)
h = model(XMB) h = model(XMB)
clf_logits = clf_head(h, XMB) clf_logits = clf_head(h, XMB)
logits.append(clf_logits.numpy()) logits.append(clf_logits.to("cpu").numpy())
logits = np.concatenate(logits, 0) logits = np.concatenate(logits, 0)
return logits return logits
def log(): def log():
global best_score global best_score
print("Logging")
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid]) tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
va_logits, va_cost = iter_apply(vaX, vaM, vaY) va_logits, va_cost = iter_apply(vaX, vaM, vaY)
tr_cost = tr_cost/len(trY[:n_valid]) tr_cost = tr_cost/len(trY[:n_valid])
@@ -142,10 +145,10 @@ def run_epoch():
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
n_batch=n_batch_train, truncate=True, verbose=True): n_batch=n_batch_train, truncate=True, verbose=True):
global n_updates global n_updates
model.train()
XMB = torch.tensor(xmb, dtype=torch.long).to(device) XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device) YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device) MMB = torch.tensor(mmb).to(device)
model.train()
h = model(XMB) h = model(XMB)
lm_logits = lm_head(h) lm_logits = lm_head(h)
clf_logits = clf_head(h, XMB) clf_logits = clf_head(h, XMB)
@@ -194,7 +197,7 @@ if __name__ == '__main__':
parser.add_argument('--clf_pdrop', type=float, default=0.1) parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01) parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true') parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--n_gpu', type=int, default=4) parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic
parser.add_argument('--opt', type=str, default='adam') parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu') parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear') parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
@@ -205,6 +208,7 @@ if __name__ == '__main__':
parser.add_argument('--b1', type=float, default=0.9) parser.add_argument('--b1', type=float, default=0.9)
parser.add_argument('--b2', type=float, default=0.999) parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--e', type=float, default=1e-8) parser.add_argument('--e', type=float, default=1e-8)
parser.add_argument('--n_valid', type=int, default=374)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
@@ -215,14 +219,14 @@ if __name__ == '__main__':
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
# torch.device object used throughout this script TODO add gpu setting # torch.device object used throughout this script TODO add gpu setting
device = torch.device("cpu") #"cuda" if use_cuda else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__) logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
text_encoder = TextEncoder(encoder_path, bpe_path) text_encoder = TextEncoder(encoder_path, bpe_path)
encoder = text_encoder.encoder encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder) n_vocab = len(text_encoder.encoder)
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder) (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder)
n_y = 2 n_y = 2
encoder['_start_'] = len(encoder) encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder) encoder['_delimiter_'] = len(encoder)
@@ -231,10 +235,13 @@ if __name__ == '__main__':
n_special = 3 n_special = 3
max_len = n_ctx//2-2 max_len = n_ctx//2-2
n_ctx = min(max( n_ctx = min(max(
[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)] [len(x1[:max_len]) + max(len(x2[:max_len]),
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)] len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)] +[len(x1[:max_len]) + max(len(x2[:max_len]),
)+3, len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+[len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3,
n_ctx) n_ctx)
vocab = n_vocab + n_special + n_ctx vocab = n_vocab + n_special + n_ctx
trX, trM = transform_roc(trX1, trX2, trX3) trX, trM = transform_roc(trX1, trX2, trX3)
@@ -247,7 +254,7 @@ if __name__ == '__main__':
n_batch_train = n_batch*n_gpu n_batch_train = n_batch*n_gpu
n_updates_total = (n_train//n_batch_train)*n_iter n_updates_total = (n_train//n_batch_train)*n_iter
model = Model(vocab, args) model = Model(vocab, n_ctx, args)
lm_head = LMHead(model, args) lm_head = LMHead(model, args)
clf_head = ClfHead(clf_token, args) clf_head = ClfHead(clf_token, args)
@@ -271,8 +278,8 @@ if __name__ == '__main__':
path = os.path.join(save_dir, desc, 'best_params') path = os.path.join(save_dir, desc, 'best_params')
torch.save(model.state_dict(), make_path(path)) torch.save(model.state_dict(), make_path(path))
best_score = 0 best_score = 0
log()
for i in range(n_iter): for i in range(n_iter):
print("running epoch", i)
run_epoch() run_epoch()
n_epochs += 1 n_epochs += 1
log() log()