mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-26 16:00:39 +08:00
fixed modified Adam + added evaluation code
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
model
|
||||
save
|
||||
log
|
||||
submission
|
||||
|
||||
.vscode
|
||||
|
||||
|
||||
+54
-47
@@ -20,51 +20,6 @@ ACT_FNS = {
|
||||
'gelu': gelu
|
||||
}
|
||||
|
||||
def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
|
||||
# Load weights from TF model
|
||||
n_transfer = cfg.n_transfer
|
||||
shapes = json.load(open(path + '/params_shapes.json'))
|
||||
names = json.load(open(path + '/parameters_names.json'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
init_params[0] = init_params[0][:n_ctx]
|
||||
init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0)
|
||||
del init_params[1]
|
||||
if n_transfer == -1:
|
||||
n_transfer = 0
|
||||
else:
|
||||
n_transfer = 1+n_transfer*12
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
|
||||
name = name[6:] # skip "model/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == ip.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, ip.shape)
|
||||
raise
|
||||
pointer.data = torch.from_numpy(ip)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"Construct a layernorm module (See citation for details)."
|
||||
@@ -87,7 +42,9 @@ class Conv1D(nn.Module):
|
||||
self.rf = rf
|
||||
self.nf = nf
|
||||
if rf == 1: #faster 1x1 conv
|
||||
self.w = Parameter(torch.ones(nx, nf)) # TODO change to random normal
|
||||
w = torch.empty(nx, nf)
|
||||
nn.init.normal_(w, std=0.02)
|
||||
self.w = Parameter(w)
|
||||
self.b = Parameter(torch.zeros(nf))
|
||||
else: #was used to train LM
|
||||
raise NotImplementedError
|
||||
@@ -123,7 +80,7 @@ class Attention(nn.Module):
|
||||
if self.scale:
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
w = w * self.b + -1e9*(1-self.b) # TF implem method: mask_attn_weights
|
||||
w = nn.Softmax()(w)
|
||||
w = nn.Softmax(dim=-1)(w)
|
||||
w = self.attn_dropout(w)
|
||||
return torch.matmul(w, v)
|
||||
|
||||
@@ -198,6 +155,8 @@ class Model(nn.Module):
|
||||
self.decoder.weight = self.embed.weight # Tied weights
|
||||
self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
|
||||
nn.init.normal_(self.embed.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.view(-1, x.size(2), x.size(3))
|
||||
e = self.embed(x)
|
||||
@@ -230,6 +189,8 @@ class ClfHead(nn.Module):
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
self.linear = nn.Linear(cfg.n_embd, 1)
|
||||
nn.init.normal_(self.linear.weight, std=0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, h, x):
|
||||
# Classification logits
|
||||
@@ -242,3 +203,49 @@ class ClfHead(nn.Module):
|
||||
clf_h = clf_h.view(-1, self.n_embd)
|
||||
clf_logits = self.linear(clf_h)
|
||||
return clf_logits.view(-1, 2)
|
||||
|
||||
|
||||
def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
|
||||
# Load weights from TF model
|
||||
n_transfer = cfg.n_transfer
|
||||
shapes = json.load(open(path + '/params_shapes.json'))
|
||||
names = json.load(open(path + '/parameters_names.json'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
init_params[0] = init_params[0][:n_ctx]
|
||||
init_params[0] = np.concatenate([init_params[1], (np.random.randn(n_special, cfg.n_embd)*0.02).astype(np.float32), init_params[0]], 0)
|
||||
del init_params[1]
|
||||
if n_transfer == -1:
|
||||
n_transfer = 0
|
||||
else:
|
||||
n_transfer = 1+n_transfer*12
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
|
||||
name = name[6:] # skip "model/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == ip.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, ip.shape)
|
||||
raise
|
||||
pointer.data = torch.from_numpy(ip)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils import clip_grad_norm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
def warmup_cosine(x, warmup=0.002):
|
||||
s = 1 if x <= warmup else 0
|
||||
@@ -81,12 +81,12 @@ class OpenAIAdam(Optimizer):
|
||||
|
||||
# Add grad clipping
|
||||
if group['max_grad_norm'] > 0:
|
||||
clip_grad_norm(p, group['max_grad_norm'])
|
||||
clip_grad_norm_(p, group['max_grad_norm'])
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
denom = exp_avg_sq.sqrt().add_(group['e'])
|
||||
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
@@ -33,27 +33,28 @@ class LossCompute:
|
||||
self.lm_coef = lm_coef
|
||||
self.opt = opt
|
||||
|
||||
def __call__(self, X, Y, M, lm_logits, clf_logits):
|
||||
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
|
||||
# Language modeling loss
|
||||
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
|
||||
M = M.view(-1, M.size(2))
|
||||
lm_losses = self.lm_criterion(lm_logits, x_shifted)
|
||||
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1)
|
||||
lm_losses = lm_losses * M[:, 1:]
|
||||
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
|
||||
|
||||
if lm_logits is not None:
|
||||
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
|
||||
M = M.view(-1, M.size(2))
|
||||
lm_losses = self.lm_criterion(lm_logits, x_shifted)
|
||||
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2)-1)
|
||||
lm_losses = lm_losses * M[:, 1:]
|
||||
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
|
||||
# Classification loss
|
||||
clf_losses = self.clf_criterion(clf_logits, Y)
|
||||
if only_return_losses:
|
||||
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
|
||||
|
||||
if self.lm_coef > 0:
|
||||
if self.lm_coef > 0 and lm_logits is not None:
|
||||
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
|
||||
else:
|
||||
train_loss = clf_losses.sum()
|
||||
|
||||
train_loss.backward()
|
||||
if self.opt is not None:
|
||||
self.opt.step()
|
||||
self.opt.optimizer.zero_grad()
|
||||
self.opt.zero_grad()
|
||||
return train_loss.item()
|
||||
|
||||
|
||||
@@ -75,60 +76,84 @@ def transform_roc(X1, X2, X3):
|
||||
xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
|
||||
return xmb, mmb
|
||||
|
||||
# def iter_apply(Xs, Ms, Ys):
|
||||
# fns = [lambda x:np.concatenate(x, 0), lambda x:float(np.sum(x))]
|
||||
# results = []
|
||||
# for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
|
||||
# n = len(xmb)
|
||||
# if n == n_batch_train:
|
||||
# res = sess.run([eval_mgpu_logits, eval_mgpu_clf_loss], {X_train:xmb, M_train:mmb, Y_train:ymb})
|
||||
# else:
|
||||
# res = sess.run([eval_logits, eval_clf_loss], {X:xmb, M:mmb, Y:ymb})
|
||||
# res = [r*n for r in res]
|
||||
# results.append(res)
|
||||
# results = zip(*results)
|
||||
# return [fn(res) for res, fn in zip(results, fns)]
|
||||
def iter_apply(Xs, Ms, Ys):
|
||||
fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
|
||||
results = []
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
|
||||
n = len(xmb)
|
||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
h = model(XMB)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
clf_losses = compute_loss(XMB, YMB, MMB, clf_logits, only_return_losses=True)
|
||||
res = (clf_logits.numpy()*n, clf_losses.numpy()*n)
|
||||
results.append(res)
|
||||
results = zip(*results)
|
||||
return [fn(res) for res, fn in zip(results, fns)]
|
||||
|
||||
# def iter_predict(Xs, Ms):
|
||||
# logits = []
|
||||
# for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
|
||||
# n = len(xmb)
|
||||
# if n == n_batch_train:
|
||||
# logits.append(sess.run(eval_mgpu_logits, {X_train:xmb, M_train:mmb}))
|
||||
# else:
|
||||
# logits.append(sess.run(eval_logits, {X:xmb, M:mmb}))
|
||||
# logits = np.concatenate(logits, 0)
|
||||
# return logits
|
||||
def iter_predict(Xs, Ms):
|
||||
logits = []
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
|
||||
n = len(xmb)
|
||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
h = model(XMB)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
logits.append(clf_logits.numpy())
|
||||
logits = np.concatenate(logits, 0)
|
||||
return logits
|
||||
|
||||
# def log():
|
||||
# global best_score
|
||||
# tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
|
||||
# va_logits, va_cost = iter_apply(vaX, vaM, vaY)
|
||||
# tr_cost = tr_cost/len(trY[:n_valid])
|
||||
# va_cost = va_cost/n_valid
|
||||
# tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
|
||||
# va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
|
||||
# logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
|
||||
# print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
|
||||
# if submit:
|
||||
# score = va_acc
|
||||
# if score > best_score:
|
||||
# best_score = score
|
||||
# save(os.path.join(save_dir, desc, 'best_params.jl'))
|
||||
def log():
|
||||
global best_score
|
||||
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
|
||||
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
|
||||
tr_cost = tr_cost/len(trY[:n_valid])
|
||||
va_cost = va_cost/n_valid
|
||||
tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1))*100.
|
||||
va_acc = accuracy_score(vaY, np.argmax(va_logits, 1))*100.
|
||||
logger.log(n_epochs=n_epochs, n_updates=n_updates, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)
|
||||
print('%d %d %.3f %.3f %.2f %.2f'%(n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))
|
||||
if submit:
|
||||
score = va_acc
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
torch.save(model.state_dict(), make_path(path))
|
||||
|
||||
# def predict():
|
||||
# filename = filenames[dataset]
|
||||
# pred_fn = pred_fns[dataset]
|
||||
# label_decoder = label_decoders[dataset]
|
||||
# predictions = pred_fn(iter_predict(teX, teM))
|
||||
# if label_decoder is not None:
|
||||
# predictions = [label_decoder[prediction] for prediction in predictions]
|
||||
# path = os.path.join(submission_dir, filename)
|
||||
# os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
# with open(path, 'w') as f:
|
||||
# f.write('{}\t{}\n'.format('index', 'prediction'))
|
||||
# for i, prediction in enumerate(predictions):
|
||||
# f.write('{}\t{}\n'.format(i, prediction))
|
||||
def predict():
|
||||
filename = filenames[dataset]
|
||||
pred_fn = pred_fns[dataset]
|
||||
label_decoder = label_decoders[dataset]
|
||||
predictions = pred_fn(iter_predict(teX, teM))
|
||||
if label_decoder is not None:
|
||||
predictions = [label_decoder[prediction] for prediction in predictions]
|
||||
path = os.path.join(submission_dir, filename)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, 'w') as f:
|
||||
f.write('{}\t{}\n'.format('index', 'prediction'))
|
||||
for i, prediction in enumerate(predictions):
|
||||
f.write('{}\t{}\n'.format(i, prediction))
|
||||
|
||||
def run_epoch():
|
||||
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
|
||||
n_batch=n_batch_train, truncate=True, verbose=True):
|
||||
global n_updates
|
||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
model.train()
|
||||
h = model(XMB)
|
||||
lm_logits = lm_head(h)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
compute_loss(XMB, YMB, MMB, clf_logits, lm_logits)
|
||||
n_updates += 1
|
||||
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
|
||||
log()
|
||||
|
||||
argmax = lambda x:np.argmax(x, 1)
|
||||
|
||||
@@ -235,7 +260,6 @@ if __name__ == '__main__':
|
||||
max_grad_norm=max_grad_norm)
|
||||
|
||||
compute_loss = LossCompute(criterion, criterion, lm_coef, model_opt)
|
||||
# TODO Initialize model (?)
|
||||
# TODO add train() and eval()
|
||||
load_openai_pretrained_model(model, n_ctx, n_special, args)
|
||||
|
||||
@@ -250,26 +274,16 @@ if __name__ == '__main__':
|
||||
if submit:
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
torch.save(model.state_dict(), make_path(path))
|
||||
|
||||
best_score = 0
|
||||
log()
|
||||
for i in range(n_iter):
|
||||
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
|
||||
n_batch=n_batch_train, truncate=True, verbose=True):
|
||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
|
||||
MMB = torch.tensor(mmb).to(device)
|
||||
model.train()
|
||||
h = model(XMB)
|
||||
lm_logits = lm_head(h)
|
||||
clf_logits = clf_head(h, XMB)
|
||||
loss = compute_loss(XMB, YMB, MMB, lm_logits, clf_logits)
|
||||
n_updates += 1
|
||||
#if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
|
||||
# log()
|
||||
run_epoch()
|
||||
n_epochs += 1
|
||||
# log()
|
||||
# if submit:
|
||||
# sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
|
||||
# predict()
|
||||
# if analysis:
|
||||
# rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
|
||||
log()
|
||||
if submit:
|
||||
path = os.path.join(save_dir, desc, 'best_params')
|
||||
model.load_state_dict(torch.load(path))
|
||||
predict()
|
||||
if analysis:
|
||||
rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'),
|
||||
os.path.join(log_dir, 'rocstories.jsonl'))
|
||||
|
||||
Reference in New Issue
Block a user