mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-26 16:00:39 +08:00
fix model and training
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 204 KiB |
+1
-1
@@ -27,7 +27,7 @@ def _rocstories(path):
|
||||
y.append(int(line[-1])-1)
|
||||
return st, ct1, ct2, y
|
||||
|
||||
def rocstories(data_dir, n_train=1497, n_valid=2): #374): # TODO: set this back
|
||||
def rocstories(data_dir, n_train=1497, n_valid=374):
|
||||
storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv'))
|
||||
teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv'))
|
||||
tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
|
||||
|
||||
+10
-11
@@ -60,13 +60,12 @@ class Conv1D(nn.Module):
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, nx, cfg, scale=False):
|
||||
def __init__(self, nx, n_ctx, cfg, scale=False):
|
||||
super(Attention, self).__init__()
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % cfg.n_head==0
|
||||
mask_size = n_state // cfg.n_head
|
||||
self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size))
|
||||
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.n_head = cfg.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
@@ -126,10 +125,10 @@ class MLP(nn.Module):
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, cfg, scale=False):
|
||||
def __init__(self, n_ctx, cfg, scale=False):
|
||||
super(Block, self).__init__()
|
||||
nx = cfg.n_embd
|
||||
self.attn = Attention(nx, cfg, scale)
|
||||
self.attn = Attention(nx, n_ctx, cfg, scale)
|
||||
self.ln_1 = LayerNorm(nx)
|
||||
self.mlp = MLP(4*nx, cfg)
|
||||
self.ln_2 = LayerNorm(nx)
|
||||
@@ -144,12 +143,12 @@ class Block(nn.Module):
|
||||
|
||||
class Model(nn.Module):
|
||||
""" Transformer model """
|
||||
def __init__(self, vocab, cfg):
|
||||
def __init__(self, vocab, n_ctx, cfg):
|
||||
super(Model, self).__init__()
|
||||
self.vocab = vocab
|
||||
self.embed = nn.Embedding(vocab, cfg.n_embd)
|
||||
self.drop = nn.Dropout(cfg.embd_pdrop)
|
||||
block = Block(cfg, scale=True)
|
||||
block = Block(n_ctx, cfg, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
|
||||
self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
|
||||
self.decoder.weight = self.embed.weight # Tied weights
|
||||
@@ -205,12 +204,12 @@ class ClfHead(nn.Module):
|
||||
return clf_logits.view(-1, 2)
|
||||
|
||||
|
||||
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='model'):
|
||||
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'):
|
||||
# Load weights from TF model
|
||||
shapes = json.load(open(path + '/params_shapes.json'))
|
||||
names = json.load(open(path + '/parameters_names.json'))
|
||||
names = json.load(open(path_names + 'parameters_names.json'))
|
||||
shapes = json.load(open(path + 'params_shapes.json'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
if n_ctx > 0:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"]
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import time
|
||||
import math
|
||||
import json
|
||||
import joblib
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
@@ -76,8 +75,9 @@ def transform_roc(X1, X2, X3):
|
||||
return xmb, mmb
|
||||
|
||||
def iter_apply(Xs, Ms, Ys):
|
||||
fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
|
||||
results = []
|
||||
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
|
||||
logits = []
|
||||
cost = 0
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
|
||||
@@ -87,11 +87,13 @@ def iter_apply(Xs, Ms, Ys):
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
h = model(XMB)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
clf_logits *= n
|
||||
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
|
||||
res = (clf_logits.numpy()*n, clf_losses.numpy()*n)
|
||||
results.append(res)
|
||||
results = zip(*results)
|
||||
return [fn(res) for res, fn in zip(results, fns)]
|
||||
clf_losses *= n
|
||||
logits.append(clf_logits.to("cpu").numpy())
|
||||
cost += clf_losses.sum().item()
|
||||
logits = np.concatenate(logits, 0)
|
||||
return logits, cost
|
||||
|
||||
def iter_predict(Xs, Ms):
|
||||
logits = []
|
||||
@@ -103,12 +105,13 @@ def iter_predict(Xs, Ms):
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
h = model(XMB)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
logits.append(clf_logits.numpy())
|
||||
logits.append(clf_logits.to("cpu").numpy())
|
||||
logits = np.concatenate(logits, 0)
|
||||
return logits
|
||||
|
||||
def log():
|
||||
global best_score
|
||||
print("Logging")
|
||||
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
|
||||
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
|
||||
tr_cost = tr_cost/len(trY[:n_valid])
|
||||
@@ -142,10 +145,10 @@ def run_epoch():
|
||||
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
|
||||
n_batch=n_batch_train, truncate=True, verbose=True):
|
||||
global n_updates
|
||||
model.train()
|
||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
model.train()
|
||||
h = model(XMB)
|
||||
lm_logits = lm_head(h)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
@@ -194,7 +197,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--clf_pdrop', type=float, default=0.1)
|
||||
parser.add_argument('--l2', type=float, default=0.01)
|
||||
parser.add_argument('--vector_l2', action='store_true')
|
||||
parser.add_argument('--n_gpu', type=int, default=4)
|
||||
parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic
|
||||
parser.add_argument('--opt', type=str, default='adam')
|
||||
parser.add_argument('--afn', type=str, default='gelu')
|
||||
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
||||
@@ -205,6 +208,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--b1', type=float, default=0.9)
|
||||
parser.add_argument('--b2', type=float, default=0.999)
|
||||
parser.add_argument('--e', type=float, default=1e-8)
|
||||
parser.add_argument('--n_valid', type=int, default=374)
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
@@ -215,14 +219,14 @@ if __name__ == '__main__':
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# torch.device object used throughout this script TODO add gpu setting
|
||||
device = torch.device("cpu") #"cuda" if use_cuda else "cpu")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
|
||||
text_encoder = TextEncoder(encoder_path, bpe_path)
|
||||
encoder = text_encoder.encoder
|
||||
n_vocab = len(text_encoder.encoder)
|
||||
|
||||
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder)
|
||||
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder)
|
||||
n_y = 2
|
||||
encoder['_start_'] = len(encoder)
|
||||
encoder['_delimiter_'] = len(encoder)
|
||||
@@ -231,10 +235,13 @@ if __name__ == '__main__':
|
||||
n_special = 3
|
||||
max_len = n_ctx//2-2
|
||||
n_ctx = min(max(
|
||||
[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
|
||||
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
|
||||
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
|
||||
)+3,
|
||||
[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
|
||||
+[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
|
||||
+[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
|
||||
) + 3,
|
||||
n_ctx)
|
||||
vocab = n_vocab + n_special + n_ctx
|
||||
trX, trM = transform_roc(trX1, trX2, trX3)
|
||||
@@ -247,7 +254,7 @@ if __name__ == '__main__':
|
||||
n_batch_train = n_batch*n_gpu
|
||||
n_updates_total = (n_train//n_batch_train)*n_iter
|
||||
|
||||
model = Model(vocab, args)
|
||||
model = Model(vocab, n_ctx, args)
|
||||
lm_head = LMHead(model, args)
|
||||
clf_head = ClfHead(clf_token, args)
|
||||
|
||||
@@ -271,8 +278,8 @@ if __name__ == '__main__':
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
torch.save(model.state_dict(), make_path(path))
|
||||
best_score = 0
|
||||
log()
|
||||
for i in range(n_iter):
|
||||
print("running epoch", i)
|
||||
run_epoch()
|
||||
n_epochs += 1
|
||||
log()
|
||||
|
||||
Reference in New Issue
Block a user