diff --git a/.gitignore b/.gitignore index 43821c0..b3702b4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ model save log +submission .vscode diff --git a/model_py.py b/model_py.py index d1c21b0..136f466 100644 --- a/model_py.py +++ b/model_py.py @@ -20,51 +20,6 @@ ACT_FNS = { 'gelu': gelu } -def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'): - # Load weights from TF model - n_transfer = cfg.n_transfer - shapes = json.load(open(path + '/params_shapes.json')) - names = json.load(open(path + '/parameters_names.json')) - offsets = np.cumsum([np.prod(shape) for shape in shapes]) - init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)] - init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] - init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] - init_params[0] = init_params[0][:n_ctx] - init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0) - del init_params[1] - if n_transfer == -1: - n_transfer = 0 - else: - n_transfer = 1+n_transfer*12 - init_params = [arr.squeeze() for arr in init_params] - try: - assert model.embed.weight.shape == init_params[0].shape - except AssertionError as e: - e.args += (model.embed.weight.shape, init_params[0].shape) - raise - model.embed.weight.data = torch.from_numpy(init_params[0]) - for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]): - name = name[6:] # skip "model/" - assert name[-2:] == ":0" - name = name[:-2] - name = name.split('/') - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+\d+', m_name): - l = re.split(r'(\d+)', m_name) - else: - l = [m_name] - pointer = getattr(pointer, l[0]) - if len(l) >= 2: - num = int(l[1]) - pointer = pointer[num] - try: - assert pointer.shape == ip.shape - except AssertionError as e: - e.args += (pointer.shape, ip.shape) - raise - pointer.data = torch.from_numpy(ip) - class LayerNorm(nn.Module): "Construct a layernorm module (See citation for details)." @@ -87,7 +42,9 @@ class Conv1D(nn.Module): self.rf = rf self.nf = nf if rf == 1: #faster 1x1 conv - self.w = Parameter(torch.ones(nx, nf)) # TODO change to random normal + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.w = Parameter(w) self.b = Parameter(torch.zeros(nf)) else: #was used to train LM raise NotImplementedError @@ -123,7 +80,7 @@ class Attention(nn.Module): if self.scale: w = w / math.sqrt(v.size(-1)) w = w * self.b + -1e9*(1-self.b) # TF implem method: mask_attn_weights - w = nn.Softmax()(w) + w = nn.Softmax(dim=-1)(w) w = self.attn_dropout(w) return torch.matmul(w, v) @@ -198,6 +155,8 @@ class Model(nn.Module): self.decoder.weight = self.embed.weight # Tied weights self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation + nn.init.normal_(self.embed.weight, std=0.02) + def forward(self, x): x = x.view(-1, x.size(2), x.size(3)) e = self.embed(x) @@ -230,6 +189,8 @@ class ClfHead(nn.Module): self.clf_token = clf_token self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation self.linear = nn.Linear(cfg.n_embd, 1) + nn.init.normal_(self.linear.weight, std=0.02) + nn.init.normal_(self.linear.bias, 0) def forward(self, h, x): # Classification logits @@ -242,3 +203,49 @@ class ClfHead(nn.Module): clf_h = clf_h.view(-1, self.n_embd) clf_logits = self.linear(clf_h) return clf_logits.view(-1, 2) + + +def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'): + # Load weights from TF model + n_transfer = cfg.n_transfer + shapes = json.load(open(path + '/params_shapes.json')) + names = json.load(open(path + '/parameters_names.json')) + offsets = np.cumsum([np.prod(shape) for shape in shapes]) + init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)] + init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] + init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] + init_params[0] = init_params[0][:n_ctx] + init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0) + del init_params[1] + if n_transfer == -1: + n_transfer = 0 + else: + n_transfer = 1+n_transfer*12 + init_params = [arr.squeeze() for arr in init_params] + try: + assert model.embed.weight.shape == init_params[0].shape + except AssertionError as e: + e.args += (model.embed.weight.shape, init_params[0].shape) + raise + model.embed.weight.data = torch.from_numpy(init_params[0]) + for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]): + name = name[6:] # skip "model/" + assert name[-2:] == ":0" + name = name[:-2] + name = name.split('/') + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+\d+', m_name): + l = re.split(r'(\d+)', m_name) + else: + l = [m_name] + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + try: + assert pointer.shape == ip.shape + except AssertionError as e: + e.args += (pointer.shape, ip.shape) + raise + pointer.data = torch.from_numpy(ip) diff --git a/opt.py b/opt.py index bf1d322..991d269 100644 --- a/opt.py +++ b/opt.py @@ -1,7 +1,7 @@ import math import torch from torch.optim import Optimizer -from torch.nn.utils import clip_grad_norm +from torch.nn.utils import clip_grad_norm_ def warmup_cosine(x, warmup=0.002): s = 1 if x <= warmup else 0 @@ -81,12 +81,12 @@ class OpenAIAdam(Optimizer): # Add grad clipping if group['max_grad_norm'] > 0: - clip_grad_norm(p, group['max_grad_norm']) + clip_grad_norm_(p, group['max_grad_norm']) # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(1 - beta1, grad) exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - denom = exp_avg_sq.sqrt().add_(group['eps']) + denom = exp_avg_sq.sqrt().add_(group['e']) bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] diff --git a/train.py b/train.py index 6c36c72..c723c3c 100644 --- a/train.py +++ b/train.py @@ -33,27 +33,28 @@ class LossCompute: self.lm_coef = lm_coef self.opt = opt - def __call__(self, X, Y, M, lm_logits, clf_logits): + def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False): # Language modeling loss - x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252 - M = M.view(-1, M.size(2)) - lm_losses = self.lm_criterion(lm_logits, x_shifted) - lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1) - lm_losses = lm_losses * M[:, 1:] - lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1) - + if lm_logits is not None: + x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252 + M = M.view(-1, M.size(2)) + lm_losses = self.lm_criterion(lm_logits, x_shifted) + lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1) + lm_losses = lm_losses * M[:, 1:] + lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1) # Classification loss clf_losses = self.clf_criterion(clf_logits, Y) + if only_return_losses: + return (clf_losses, lm_losses) if lm_logits is not None else clf_losses - if self.lm_coef > 0: + if self.lm_coef > 0 and lm_logits is not None: train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum() else: train_loss = clf_losses.sum() - train_loss.backward() if self.opt is not None: self.opt.step() - self.opt.optimizer.zero_grad() + self.opt.zero_grad() return train_loss.item() @@ -75,60 +76,84 @@ def transform_roc(X1, X2, X3): xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx) return xmb, mmb -# def iter_apply(Xs, Ms, Ys): -# fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))] -# results = [] -# for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True): -# n = len(xmb) -# if n == n_batch_train: -# res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb}) -# else: -# res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb}) -# res = [r*n for r in res] -# results.append(res) -# results = zip(*results) -# return [fn(res) for res, fn in zip(results, fns)] +def iter_apply(Xs, Ms, Ys): + fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))] + results = [] + with torch.no_grad(): + model.eval() + for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True): + n = len(xmb) + XMB = torch.tensor(xmb, dtype=torch.long).to(device) + YMB = torch.tensor(ymb, dtype=torch.long).to(device) + MMB = torch.tensor(mmb).to(device) + h = model(XMB) + clf_logits = clf_head(h, XMB) + clf_losses = compute_loss(XMB, YMB, MMB, clf_logits, only_return_losses=True) + res = (clf_logits.numpy()*n, clf_losses.numpy()*n) + results.append(res) + results = zip(*results) + return [fn(res) for res, fn in zip(results, fns)] -# def iter_predict(Xs, Ms): -# logits = [] -# for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True): -# n = len(xmb) -# if n == n_batch_train: -# logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb})) -# else: -# logits.append(sess.run(eval_logits, {X:xmb, M:mmb})) -# logits = np.concatenate(logits, 0) -# return logits +def iter_predict(Xs, Ms): + logits = [] + with torch.no_grad(): + model.eval() + for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True): + n = len(xmb) + XMB = torch.tensor(xmb, dtype=torch.long).to(device) + MMB = torch.tensor(mmb).to(device) + h = model(XMB) + clf_logits = clf_head(h, XMB) + logits.append(clf_logits.numpy()) + logits = np.concatenate(logits, 0) + return logits -# def log(): -# global best_score -# tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid]) -# va_logits, va_cost = iter_apply(vaX, vaM, vaY) -# tr_cost = tr_cost/len(trY[:n_valid]) -# va_cost = va_cost/n_valid -# tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100. -# va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100. -# logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc) -# print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc)) -# if submit: -# score = va_acc -# if score > best_score: -# best_score = score -# save(os.path.join(save_dir, desc, 'best_params.jl')) +def log(): + global best_score + tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid]) + va_logits, va_cost = iter_apply(vaX, vaM, vaY) + tr_cost = tr_cost/len(trY[:n_valid]) + va_cost = va_cost/n_valid + tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100. + va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100. + logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc) + print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc)) + if submit: + score = va_acc + if score > best_score: + best_score = score + path = os.path.join(save_dir, desc, 'best_params') + torch.save(model.state_dict(), make_path(path)) -# def predict(): -# filename = filenames[dataset] -# pred_fn = pred_fns[dataset] -# label_decoder = label_decoders[dataset] -# predictions = pred_fn(iter_predict(teX, teM)) -# if label_decoder is not None: -# predictions = [label_decoder[prediction] for prediction in predictions] -# path = os.path.join(submission_dir, filename) -# os.makedirs(os.path.dirname(path), exist_ok=True) -# with open(path, 'w') as f: -# f.write('{}\t{}\n'.format('index', 'prediction')) -# for i, prediction in enumerate(predictions): -# f.write('{}\t{}\n'.format(i, prediction)) +def predict(): + filename = filenames[dataset] + pred_fn = pred_fns[dataset] + label_decoder = label_decoders[dataset] + predictions = pred_fn(iter_predict(teX, teM)) + if label_decoder is not None: + predictions = [label_decoder[prediction] for prediction in predictions] + path = os.path.join(submission_dir, filename) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as f: + f.write('{}\t{}\n'.format('index', 'prediction')) + for i, prediction in enumerate(predictions): + f.write('{}\t{}\n'.format(i, prediction)) + +def run_epoch(): + for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), + n_batch=n_batch_train, truncate=True, verbose=True): + global n_updates + XMB = torch.tensor(xmb, dtype=torch.long).to(device) + YMB = torch.tensor(ymb, dtype=torch.long).to(device) + MMB = torch.tensor(mmb).to(device) + model.train() + h = model(XMB) + lm_logits = lm_head(h) + clf_logits = clf_head(h, XMB) + compute_loss(XMB, YMB, MMB, clf_logits, lm_logits) + n_updates += 1 + if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0: + log() argmax = lambda x:np.argmax(x, 1) @@ -235,7 +260,6 @@ if __name__ == '__main__': max_grad_norm=max_grad_norm) compute_loss = LossCompute(criterion, criterion, lm_coef, model_opt) - # TODO Initialize model (?) # TODO add train() and eval() load_openai_pretrained_model(model, n_ctx, n_special, args) @@ -250,26 +274,16 @@ if __name__ == '__main__': if submit: path = os.path.join(save_dir, desc, 'best_params') torch.save(model.state_dict(), make_path(path)) - best_score = 0 + log() for i in range(n_iter): - for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random), - n_batch=n_batch_train, truncate=True, verbose=True): - XMB = torch.tensor(xmb, dtype=torch.long).to(device) - YMB = torch.tensor(ymb, dtype=torch.long).to(device) - MMB = torch.tensor(mmb).to(device) - model.train() - h = model(XMB) - lm_logits = lm_head(h) - clf_logits = clf_head(h, XMB) - loss = compute_loss(XMB, YMB, MMB, lm_logits, clf_logits) - n_updates += 1 - #if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0: - # log() + run_epoch() n_epochs += 1 - # log() - # if submit: - # sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))]) - # predict() - # if analysis: - # rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl')) + log() + if submit: + path = os.path.join(save_dir, desc, 'best_params') + model.load_state_dict(torch.load(path)) + predict() + if analysis: + rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), + os.path.join(log_dir, 'rocstories.jsonl'))