diff --git a/assets/ftlm.png b/assets/ftlm.png new file mode 100644 index 0000000..67bbc09 Binary files /dev/null and b/assets/ftlm.png differ diff --git a/datasets.py b/datasets.py index a53c334..41272c4 100644 --- a/datasets.py +++ b/datasets.py @@ -27,7 +27,7 @@ def _rocstories(path): y.append(int(line[-1])-1) return st, ct1, ct2, y -def rocstories(data_dir, n_train=1497, n_valid=2): #374): # TODO: set this back +def rocstories(data_dir, n_train=1497, n_valid=374): storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv')) teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv')) tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed) diff --git a/model_py.py b/model_py.py index 37cc635..b57d8aa 100644 --- a/model_py.py +++ b/model_py.py @@ -60,13 +60,12 @@ class Conv1D(nn.Module): class Attention(nn.Module): - def __init__(self, nx, cfg, scale=False): + def __init__(self, nx, n_ctx, cfg, scale=False): super(Attention, self).__init__() n_state = nx # in Attention: n_state=768 (nx=n_embd) #[switch nx => n_state from Block to Attention to keep identical to TF implem] assert n_state % cfg.n_head==0 - mask_size = n_state // cfg.n_head - self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size)) + self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) self.n_head = cfg.n_head self.split_size = n_state self.scale = scale @@ -126,10 +125,10 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, cfg, scale=False): + def __init__(self, n_ctx, cfg, scale=False): super(Block, self).__init__() nx = cfg.n_embd - self.attn = Attention(nx, cfg, scale) + self.attn = Attention(nx, n_ctx, cfg, scale) self.ln_1 = LayerNorm(nx) self.mlp = MLP(4*nx, cfg) self.ln_2 = LayerNorm(nx) @@ -144,12 +143,12 @@ class Block(nn.Module): class Model(nn.Module): """ Transformer model """ - def __init__(self, vocab, cfg): + def __init__(self, vocab, n_ctx, cfg): super(Model, self).__init__() self.vocab = vocab self.embed = nn.Embedding(vocab, cfg.n_embd) self.drop = nn.Dropout(cfg.embd_pdrop) - block = Block(cfg, scale=True) + block = Block(n_ctx, cfg, scale=True) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)]) self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False) self.decoder.weight = self.embed.weight # Tied weights @@ -205,12 +204,12 @@ class ClfHead(nn.Module): return clf_logits.view(-1, 2) -def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='model'): +def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'): # Load weights from TF model - shapes = json.load(open(path + '/params_shapes.json')) - names = json.load(open(path + '/parameters_names.json')) + names = json.load(open(path_names + 'parameters_names.json')) + shapes = json.load(open(path + 'params_shapes.json')) offsets = np.cumsum([np.prod(shape) for shape in shapes]) - init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)] + init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)] init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] if n_ctx > 0: diff --git a/parameters_names.json b/parameters_names.json new file mode 100644 index 0000000..6081e28 --- /dev/null +++ b/parameters_names.json @@ -0,0 +1 @@ +["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"] \ No newline at end of file diff --git a/train.py b/train.py index 04c4f28..60cdc71 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,6 @@ import os import time import math import json -import joblib import random import argparse import numpy as np @@ -76,8 +75,9 @@ def transform_roc(X1, X2, X3): return xmb, mmb def iter_apply(Xs, Ms, Ys): - fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))] - results = [] + # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))] + logits = [] + cost = 0 with torch.no_grad(): model.eval() for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True): @@ -87,11 +87,13 @@ def iter_apply(Xs, Ms, Ys): MMB = torch.tensor(mmb).to(device) h = model(XMB) clf_logits = clf_head(h, XMB) + clf_logits *= n clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True) - res = (clf_logits.numpy()*n, clf_losses.numpy()*n) - results.append(res) - results = zip(*results) - return [fn(res) for res, fn in zip(results, fns)] + clf_losses *= n + logits.append(clf_logits.to("cpu").numpy()) + cost += clf_losses.sum().item() + logits = np.concatenate(logits, 0) + return logits, cost def iter_predict(Xs, Ms): logits = [] @@ -103,12 +105,13 @@ def iter_predict(Xs, Ms): MMB = torch.tensor(mmb).to(device) h = model(XMB) clf_logits = clf_head(h, XMB) - logits.append(clf_logits.numpy()) + logits.append(clf_logits.to("cpu").numpy()) logits = np.concatenate(logits, 0) return logits def log(): global best_score + print("Logging") tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid]) va_logits, va_cost = iter_apply(vaX, vaM, vaY) tr_cost = tr_cost/len(trY[:n_valid]) @@ -142,10 +145,10 @@ def run_epoch(): for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), n_batch=n_batch_train, truncate=True, verbose=True): global n_updates + model.train() XMB = torch.tensor(xmb, dtype=torch.long).to(device) YMB = torch.tensor(ymb, dtype=torch.long).to(device) MMB = torch.tensor(mmb).to(device) - model.train() h = model(XMB) lm_logits = lm_head(h) clf_logits = clf_head(h, XMB) @@ -194,7 +197,7 @@ if __name__ == '__main__': parser.add_argument('--clf_pdrop', type=float, default=0.1) parser.add_argument('--l2', type=float, default=0.01) parser.add_argument('--vector_l2', action='store_true') - parser.add_argument('--n_gpu', type=int, default=4) + parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic parser.add_argument('--opt', type=str, default='adam') parser.add_argument('--afn', type=str, default='gelu') parser.add_argument('--lr_schedule', type=str, default='warmup_linear') @@ -205,6 +208,7 @@ if __name__ == '__main__': parser.add_argument('--b1', type=float, default=0.9) parser.add_argument('--b2', type=float, default=0.999) parser.add_argument('--e', type=float, default=1e-8) + parser.add_argument('--n_valid', type=int, default=374) args = parser.parse_args() print(args) @@ -215,14 +219,14 @@ if __name__ == '__main__': torch.cuda.manual_seed_all(seed) # torch.device object used throughout this script TODO add gpu setting - device = torch.device("cpu") #"cuda" if use_cuda else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__) text_encoder = TextEncoder(encoder_path, bpe_path) encoder = text_encoder.encoder n_vocab = len(text_encoder.encoder) - (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder) + (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder) n_y = 2 encoder['_start_'] = len(encoder) encoder['_delimiter_'] = len(encoder) @@ -231,10 +235,13 @@ if __name__ == '__main__': n_special = 3 max_len = n_ctx//2-2 n_ctx = min(max( - [len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)] - +[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)] - +[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)] - )+3, + [len(x1[:max_len]) + max(len(x2[:max_len]), + len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)] + +[len(x1[:max_len]) + max(len(x2[:max_len]), + len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)] + +[len(x1[:max_len]) + max(len(x2[:max_len]), + len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)] + ) + 3, n_ctx) vocab = n_vocab + n_special + n_ctx trX, trM = transform_roc(trX1, trX2, trX3) @@ -247,7 +254,7 @@ if __name__ == '__main__': n_batch_train = n_batch*n_gpu n_updates_total = (n_train//n_batch_train)*n_iter - model = Model(vocab, args) + model = Model(vocab, n_ctx, args) lm_head = LMHead(model, args) clf_head = ClfHead(clf_token, args) @@ -271,8 +278,8 @@ if __name__ == '__main__': path = os.path.join(save_dir, desc, 'best_params') torch.save(model.state_dict(), make_path(path)) best_score = 0 - log() for i in range(n_iter): + print("running epoch", i) run_epoch() n_epochs += 1 log()