Clean up multi-processing logic -- Switch to PyTorch 0.4 style

This commit is contained in:
thomwolf
2018-06-28 10:43:15 +02:00
parent 49ff9b5355
commit aded2b058a
5 changed files with 44 additions and 53 deletions
+1 -1
View File
@@ -58,6 +58,6 @@ Finetuning the PyTorch model for 3 Epochs on ROCStories takes 10 minutes to run
The single-run test accuracy of this PyTorch version is 85.84%, while the authors report a median accuracy of 85.8% with the TensorFlow code and the paper reports a best single-run accuracy of 86.5%.
The authors' implementation uses 8 GPUs and can thus accommodate a batch of 64 samples, while the present implementation is single-GPU and is consequently limited to 20 instances on a K80 for memory reasons. In our test, increasing the batch size from 8 to 20 samples increased the test accuracy by 2.5 points. A better accuracy may be obtained by using a multi-GPU setting (on the TO-DO list).
The authors' implementation uses 8 GPUs and can thus accommodate a batch of 64 samples, while the present implementation is single-GPU and is consequently limited to 20 instances on a K80 for memory reasons. In our test, increasing the batch size from 8 to 20 samples increased the test accuracy by 2.5 points. A better accuracy may be obtained by using a multi-GPU setting (not tried yet).
The previous SOTA on the ROCStories dataset is 77.6% ("Hidden Coherence Model" of Chaturvedi et al. published in "Story Comprehension for Predicting What Happens Next" EMNLP 2017, which is a very nice paper too!)
+1 -1
View File
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
seed = 3535999445
def _rocstories(path):
with open(path) as f:
with open(path, encoding='utf_8') as f:
f = csv.reader(f)
st = []
ct1 = []
+19 -19
View File
@@ -146,11 +146,11 @@ class Block(nn.Module):
return h
class Model(nn.Module):
class TransformerModel(nn.Module):
""" Transformer model """
def __init__(self, cfg, vocab=40990, n_ctx=512):
super(Model, self).__init__()
super(TransformerModel, self).__init__()
self.vocab = vocab
self.embed = nn.Embedding(vocab, cfg.n_embd)
self.drop = nn.Dropout(cfg.embd_pdrop)
@@ -181,7 +181,7 @@ class LMHead(nn.Module):
def forward(self, h):
# Truncated Language modeling logits (we remove the last token)
h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) # Shape: 252, 768
h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
lm_logits = self.decoder(h_trunc)
return lm_logits
@@ -202,24 +202,27 @@ class ClfHead(nn.Module):
# Classification logits
clf_h = h.view(-1, self.n_embd)
flat = x[:, :, :, 0].contiguous().view(-1)
# pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
clf_h = clf_h[flat == self.clf_token, :] # .index_select(0, pool_idx)
clf_h = clf_h.view(-1, 2, self.n_embd, 1)
clf_h = clf_h[flat == self.clf_token, :]
clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
clf_h = self.dropout(clf_h)
clf_h = clf_h.view(-1, self.n_embd)
clf_logits = self.linear(clf_h)
return clf_logits.view(-1, 2)
return clf_logits.view(-1, x.size(1))
class DataParallelWithEmbed(torch.nn.DataParallel):
"""DataParallel that proxies the embed property to the wrapped module"""
class DoubleHeadModel(nn.Module):
""" Transformer with language model and classification heads """
def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512):
super(DoubleHeadModel, self).__init__()
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
self.lm_head = LMHead(self.transformer, cfg)
self.clf_head = ClfHead(clf_token, cfg)
def __init__(self, model):
super(DataParallelWithEmbed, self).__init__(model)
@property
def embed(self):
return self.module.embed
def forward(self, x):
h = self.transformer(x)
lm_logits = self.lm_head(h)
clf_logits = self.clf_head(h, x)
return lm_logits, clf_logits
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
@@ -260,15 +263,12 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
model.embed.weight.data = torch.from_numpy(init_params[0])
# Load the weights into our torch module
module = model.module
for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
name = name[6:] # skip "model/"
assert name[-2:] == ":0"
name = name[:-2]
name = name.split('/')
pointer = module
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
l = re.split(r'(\d+)', m_name)
+1 -1
View File
@@ -41,7 +41,7 @@ class TextEncoder(object):
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.encoder = json.load(open(encoder_path))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(bpe_path).read().split('\n')[1:-1]
merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
+22 -31
View File
@@ -10,7 +10,7 @@ from sklearn.utils import shuffle
from analysis import rocstories as rocstories_analysis
from datasets import rocstories
from model_pytorch import Model, LMHead, ClfHead, load_openai_pretrained_model, DataParallelWithEmbed
from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
from opt import OpenAIAdam
from text_utils import TextEncoder
from utils import (encode_dataset, iter_data,
@@ -75,14 +75,13 @@ def iter_apply(Xs, Ms, Ys):
logits = []
cost = 0
with torch.no_grad():
model.eval()
dh_model.eval()
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
n = len(xmb)
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
h = model(XMB)
clf_logits = clf_head(h, XMB)
_, clf_logits = dh_model(XMB)
clf_logits *= n
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
clf_losses *= n
@@ -95,13 +94,12 @@ def iter_apply(Xs, Ms, Ys):
def iter_predict(Xs, Ms):
logits = []
with torch.no_grad():
model.eval()
dh_model.eval()
for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
n = len(xmb)
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
h = model(XMB)
clf_logits = clf_head(h, XMB)
_, clf_logits = dh_model(XMB)
logits.append(clf_logits.to("cpu").numpy())
logits = np.concatenate(logits, 0)
return logits
@@ -123,7 +121,7 @@ def log(save_dir, desc):
if score > best_score:
best_score = score
path = os.path.join(save_dir, desc, 'best_params')
torch.save(model.state_dict(), make_path(path))
torch.save(dh_model.state_dict(), make_path(path))
def predict(dataset, submission_dir):
@@ -145,13 +143,11 @@ def run_epoch():
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
n_batch=n_batch_train, truncate=True, verbose=True):
global n_updates
model.train()
dh_model.train()
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
MMB = torch.tensor(mmb).to(device)
h = model(XMB)
lm_logits = lm_head(h)
clf_logits = clf_head(h, XMB)
lm_logits, clf_logits = dh_model(XMB)
compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
n_updates += 1
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
@@ -198,7 +194,7 @@ if __name__ == '__main__':
parser.add_argument('--clf_pdrop', type=float, default=0.1)
parser.add_argument('--l2', type=float, default=0.01)
parser.add_argument('--vector_l2', action='store_true')
parser.add_argument('--n_gpu', type=int, default=1) # 4) # TODO add mutli-gpu training logic
parser.add_argument('--n_gpu', type=int, default=1)
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--afn', type=str, default='gelu')
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
@@ -213,7 +209,6 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
# globals().update(args.__dict__) # TODO maybe we want to remove these global variables to make it cleaner
random.seed(args.seed)
np.random.seed(args.seed)
@@ -238,9 +233,10 @@ if __name__ == '__main__':
n_vocab = len(text_encoder.encoder)
print("Encoding dataset...")
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(
rocstories(data_dir, n_valid=args.n_valid), encoder=text_encoder)
n_y = 2
((trX1, trX2, trX3, trY),
(vaX1, vaX2, vaX3, vaY),
(teX1, teX2, teX3)) = encode_dataset(rocstories(data_dir, n_valid=args.n_valid),
encoder=text_encoder)
encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
@@ -254,7 +250,7 @@ if __name__ == '__main__':
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
+ [len(x1[:max_len]) + max(len(x2[:max_len]),
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
) + 3, n_ctx)
) + 3, n_ctx)
vocab = n_vocab + n_special + n_ctx
trX, trM = transform_roc(trX1, trX2, trX3)
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)
@@ -266,14 +262,10 @@ if __name__ == '__main__':
n_batch_train = args.n_batch * args.n_gpu
n_updates_total = (n_train // n_batch_train) * args.n_iter
model = Model(args, vocab, n_ctx)
model = DataParallelWithEmbed(model).cuda()
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
lm_head = LMHead(model, args)
clf_head = ClfHead(clf_token, args)
criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions
model_opt = OpenAIAdam(list(model.parameters()) + list(clf_head.parameters()) + list(lm_head.parameters()),
criterion = nn.CrossEntropyLoss(reduce=False)
model_opt = OpenAIAdam(dh_model.parameters(),
lr=args.lr,
schedule=args.lr_schedule,
warmup=args.lr_warmup,
@@ -288,11 +280,10 @@ if __name__ == '__main__':
criterion,
args.lm_coef,
model_opt)
load_openai_pretrained_model(model, n_ctx=n_ctx, n_special=n_special)
load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)
model.to(device)
lm_head.to(device)
clf_head.to(device)
dh_model.to(device)
dh_model = nn.DataParallel(dh_model)
n_updates = 0
n_epochs = 0
@@ -300,7 +291,7 @@ if __name__ == '__main__':
trYt = trY
if submit:
path = os.path.join(save_dir, desc, 'best_params')
torch.save(model.state_dict(), make_path(path))
torch.save(dh_model.state_dict(), make_path(path))
best_score = 0
for i in range(args.n_iter):
print("running epoch", i)
@@ -309,7 +300,7 @@ if __name__ == '__main__':
log(save_dir, desc)
if submit:
path = os.path.join(save_dir, desc, 'best_params')
model.load_state_dict(torch.load(path))
dh_model.load_state_dict(torch.load(path))
predict(dataset, args.submission_dir)
if args.analysis:
rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),