mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
readme, clean up
This commit is contained in:
@@ -1,3 +1,28 @@
|
||||
This uses Guntenberg books as source data, and generated the next words. In this example I used ~60 Erotic novels from last century.
|
||||
|
||||
The model is a pytorch implementation of OpenAI's Finetuned Transformer Language Model, with pretrained weights.
|
||||
|
||||
Example outputs:
|
||||
|
||||
|
||||
Base: rosy buttocks
|
||||
|
||||
Result: rosy buttocks but noted that his fingers had been already working him into a post mor tem, a replacement for his earlier foray with that unspeakable je ering creature.
|
||||
" vince ? "
|
||||
it was his niece, jessica, who had just reached lunch time. she eyed vince warily when she arrived, quickly averting her gaze when he glanced her way. he moved to victoria's side, but she stopped him with a slight gesture.
|
||||
" eat with me in costume, darling, " she urged, and then recalled the phone call this morning from emily and all the others at the supreme court. nobody had was meeting here after the inquisition, according to the files on the table. there was no way to avoid revealing her visit.
|
||||
vince nodded his acknowledgment
|
||||
|
||||
|
||||
|
||||
Base: I want you
|
||||
|
||||
Result: i want you to know in case i'm wrong and you'll show up soon. "
|
||||
" why do you say that ? " he said. " do you, and everybody in the world, think there's something wrong ? "
|
||||
she took a drink. " i'll get to that later. listen, i'm sorry for taking you away from your father, " she said. " colleen's a great person, but it's a lot to put on your shoulders, and we already know that. okay, she's a tough lady, you know her. she struggles, she's worried about the ranch, but her love for them gives them what they need. they let her go after a long absence... in her own way, anyway
|
||||
|
||||
|
||||
|
||||
# PyTorch implementation of OpenAI's Finetuned Transformer Language Model
|
||||
|
||||
This is a PyTorch implementation of the [TensorFlow code](https://github.com/openai/finetune-transformer-lm) provided with OpenAI's paper ["Improving Language Understanding by Generative Pre-Training"](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
|
||||
@@ -27,37 +52,11 @@ To run the classifier training script in [train.py](train.py) you will need in a
|
||||
|
||||
You can download the weights of the OpenAI pre-trained version by cloning [Alec Radford's repo](https://github.com/openai/finetune-transformer-lm) and placing the `model` folder containing the pre-trained weights in the present repo.
|
||||
|
||||
## Using the pre-trained model as a Transformer Language Model
|
||||
The model can be used as a transformer language model with OpenAI's pre-trained weights as follow:
|
||||
```python
|
||||
from model_pytorch import TransformerModel, load_openai_pretrained_model, DEFAULT_CONFIG
|
||||
|
||||
args = DEFAULT_CONFIG
|
||||
model = TransformerModel(args)
|
||||
load_openai_pretrained_model(model)
|
||||
```
|
||||
|
||||
This model generates Transformer's hidden states. You can use the `LMHead` class in [model_pytorch.py](model_pytorch.py) to add a decoder tied with the weights of the encoder and get a full language model. You can also use the `ClfHead` class in [model_pytorch.py](model_pytorch.py) to add a classifier on top of the transformer and get a classifier as described in OpenAI's publication. (see an example of both in the `__main__` function of [train.py](train.py))
|
||||
|
||||
To use the positional encoder of the transformer, you should encode your dataset using the `encode_dataset()` function of [utils.py](utils.py). Please refer to the beginning of the `__main__` function in [train.py](train.py) to see how to properly define the vocabulary and encode your dataset.
|
||||
|
||||
## Fine-tuning the pre-trained model on a classification task
|
||||
This model can also be integrated in a classifier as detailed in [OpenAI's paper](https://blog.openai.com/language-unsupervised/). An example of fine-tuning on the ROCStories Cloze task is included with the training code in [train.py](train.py)
|
||||
|
||||
The ROCStories dataset can be downloaded from the associated [website](http://cs.rochester.edu/nlp/rocstories/).
|
||||
|
||||
As with the [TensorFlow code](https://github.com/openai/finetune-transformer-lm), this code implements the ROCStories Cloze Test result reported in the paper which can be reproduced by running:
|
||||
|
||||
```bash
|
||||
python -m spacy download en
|
||||
python train.py --dataset rocstories --desc rocstories --submit --analysis --data_dir [path to data here]
|
||||
```
|
||||
|
||||
#### First experiments on the ROCStories test set
|
||||
Finetuning the PyTorch model for 3 Epochs on ROCStories takes 10 minutes to run on a single NVidia K-80.
|
||||
## Fine-tuning the pre-trained model on a classification task
|
||||
|
||||
The single run test accuracy of this PyTorch version is 85.84%, while the authors reports a median accuracy with the TensorFlow code of 85.8% and the paper reports a best single run accuracy of 86.5%.
|
||||
Use the train.ipynb notebook
|
||||
|
||||
The authors implementations uses 8 GPU and can thus accomodate a batch of 64 samples while the present implementation is single GPU and is in consequence limited to 20 instances on a K80 for memory reasons. In our test, increasing the batch size from 8 to 20 samples increased the test accuracy by 2.5 points. A better accuracy may be obtained by using a multi-GPU setting (not tried yet).
|
||||
|
||||
The previous SOTA on the ROCStories dataset is 77.6% ("Hidden Coherence Model" of Chaturvedi et al. published in "Story Comprehension for Predicting What Happens Next" EMNLP 2017, which is a very nice paper too!)
|
||||
|
||||
-18
@@ -1,18 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
from datasets import _rocstories
|
||||
|
||||
def rocstories(data_dir, pred_path, log_path):
|
||||
preds = pd.read_csv(pred_path, delimiter='\t')['prediction'].values.tolist()
|
||||
_, _, _, labels = _rocstories(os.path.join(data_dir, 'erotic_gutenberg_VAL.csv'))
|
||||
test_accuracy = accuracy_score(labels, preds)*100.
|
||||
logs = [json.loads(line) for line in open(log_path)][1:]
|
||||
best_validation_index = np.argmax([log['va_acc'] for log in logs])
|
||||
valid_accuracy = logs[best_validation_index]['va_acc']
|
||||
print('ROCStories Valid Accuracy: %.2f'%(valid_accuracy))
|
||||
print('ROCStories Test Accuracy: %.2f'%(test_accuracy))
|
||||
-55
@@ -1,55 +0,0 @@
|
||||
import os
|
||||
import csv
|
||||
import numpy as np
|
||||
|
||||
from tqdm import tqdm_notebook as tqdm
|
||||
|
||||
from sklearn.utils import shuffle
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
seed = 3535999445
|
||||
|
||||
def _rocstories(path):
|
||||
with open(path, encoding='utf_8') as f:
|
||||
f = csv.reader(f)
|
||||
st = []
|
||||
ct1 = []
|
||||
ct2 = []
|
||||
y = []
|
||||
for i, line in enumerate(tqdm(list(f), ncols=80, mininterval=10, leave=False)):
|
||||
if i > 0:
|
||||
s = ' '.join(line[1:5]) # 4 sentances
|
||||
st.append(s)
|
||||
|
||||
c1 = line[5] # 2 possible answers
|
||||
c2 = line[6]
|
||||
ct1.append(c1)
|
||||
ct2.append(c2)
|
||||
|
||||
# correct answer
|
||||
y.append(int(line[-1])-1)
|
||||
return st, ct1, ct2, y
|
||||
|
||||
def rocstories(data_dir, n_train=1497, n_valid=374):
|
||||
storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'erotic_gutenberg_TRAIN.csv'))
|
||||
teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'erotic_gutenberg_VAL.csv'))
|
||||
tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
|
||||
trX1, trX2, trX3 = [], [], []
|
||||
trY = []
|
||||
for s, c1, c2, y in zip(tr_storys, tr_comps1, tr_comps2, tr_ys):
|
||||
trX1.append(s)
|
||||
trX2.append(c1)
|
||||
trX3.append(c2)
|
||||
trY.append(y)
|
||||
|
||||
vaX1, vaX2, vaX3 = [], [], []
|
||||
vaY = []
|
||||
for s, c1, c2, y in zip(va_storys, va_comps1, va_comps2, va_ys):
|
||||
vaX1.append(s)
|
||||
vaX2.append(c1)
|
||||
vaX3.append(c2)
|
||||
vaY.append(y)
|
||||
trY = np.asarray(trY, dtype=np.int32)
|
||||
vaY = np.asarray(vaY, dtype=np.int32)
|
||||
# (stories, answer1, answer2, correct_anser_int)...
|
||||
return (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3)
|
||||
@@ -1,71 +1,5 @@
|
||||
import torch
|
||||
|
||||
class MultipleChoiceLossCompute:
|
||||
"A Loss compute and train function for multiple choice tasks."
|
||||
|
||||
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
|
||||
self.lm_criterion = lm_criterion
|
||||
self.clf_criterion = clf_criterion
|
||||
self.lm_coef = lm_coef
|
||||
self.opt = opt
|
||||
|
||||
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
|
||||
# Language modeling loss
|
||||
if lm_logits is not None:
|
||||
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
|
||||
M = M.view(-1, M.size(2))
|
||||
lm_losses = self.lm_criterion(lm_logits, x_shifted)
|
||||
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
|
||||
lm_losses = lm_losses * M[:, 1:]
|
||||
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
|
||||
# Classification loss
|
||||
clf_losses = self.clf_criterion(clf_logits, Y)
|
||||
if only_return_losses:
|
||||
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
|
||||
|
||||
if self.lm_coef > 0 and lm_logits is not None:
|
||||
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
|
||||
else:
|
||||
train_loss = clf_losses.sum()
|
||||
train_loss.backward()
|
||||
if self.opt is not None:
|
||||
self.opt.step()
|
||||
self.opt.zero_grad()
|
||||
return train_loss.item()
|
||||
|
||||
class ClassificationLossCompute:
|
||||
"A Loss compute and train function for classification tasks."
|
||||
|
||||
def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
|
||||
self.lm_criterion = lm_criterion
|
||||
self.clf_criterion = clf_criterion
|
||||
self.lm_coef = lm_coef
|
||||
self.opt = opt
|
||||
|
||||
def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
|
||||
# Language modeling loss
|
||||
if lm_logits is not None:
|
||||
x_shifted = X[:, 1:, 0].contiguous().view(-1)
|
||||
M = M.view(-1, M.size(-1))
|
||||
lm_losses = self.lm_criterion(lm_logits, x_shifted)
|
||||
lm_losses = lm_losses.view(X.size(0), X.size(-2) - 1)
|
||||
lm_losses = lm_losses * M[:, 1:]
|
||||
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
|
||||
# Classification loss
|
||||
clf_losses = self.clf_criterion(clf_logits, Y)
|
||||
if only_return_losses:
|
||||
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
|
||||
|
||||
if self.lm_coef > 0 and lm_logits is not None:
|
||||
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
|
||||
else:
|
||||
train_loss = clf_losses.sum()
|
||||
train_loss.backward()
|
||||
if self.opt is not None:
|
||||
self.opt.step()
|
||||
self.opt.zero_grad()
|
||||
return train_loss.item()
|
||||
|
||||
class LanguageModelingLossCompute:
|
||||
" A Loss compute and train function for language modeling tasks."
|
||||
def __init__(self, lm_criterion, opt=None):
|
||||
@@ -90,5 +24,3 @@ class LanguageModelingLossCompute:
|
||||
self.opt.step()
|
||||
self.opt.zero_grad()
|
||||
return train_loss.item()
|
||||
|
||||
# TODO Implement a LossCompute class for similiraty tasks.
|
||||
|
||||
@@ -187,117 +187,6 @@ class LMHead(nn.Module):
|
||||
return lm_logits
|
||||
|
||||
|
||||
class MultipleChoiceHead(nn.Module):
|
||||
""" Classifier Head for the transformer """
|
||||
|
||||
def __init__(self, clf_token, cfg):
|
||||
super(MultipleChoiceHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
|
||||
self.linear = nn.Linear(cfg.n_embd, 1)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, h, x):
|
||||
# Classification logits
|
||||
clf_h = h.view(-1, self.n_embd)
|
||||
flat = x[..., 0].contiguous().view(-1)
|
||||
clf_h = clf_h[flat == self.clf_token, :]
|
||||
clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
|
||||
# This double transposition is there to replicate the behavior
|
||||
# of the noise_shape argument in the tensorflow
|
||||
# implementation. For more details, see
|
||||
# https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
|
||||
clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
|
||||
clf_h = clf_h.contiguous().view(-1, self.n_embd)
|
||||
clf_logits = self.linear(clf_h)
|
||||
|
||||
return clf_logits.view(-1, x.size(1))
|
||||
|
||||
|
||||
class ClfHead(nn.Module):
|
||||
"""Classification Head for the transformer
|
||||
|
||||
TODO: test this class."""
|
||||
def __init__(self, clf_token, cfg, n_class):
|
||||
super(ClfHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout(cfg.clf_pdrop)
|
||||
self.linear = nn.Linear(cfg.n_embd, n_class)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, h, x):
|
||||
clf_h = h.view(-1, self.n_embd)
|
||||
flat = x[..., 0].contiguous().view(-1)
|
||||
clf_h = clf_h[flat == self.clf_token, :]
|
||||
clf_h = self.dropout(clf_h)
|
||||
clf_logits = self.linear(clf_h)
|
||||
|
||||
return clf_logits
|
||||
|
||||
class SimilarityHead(nn.Module):
|
||||
""" Similarity Head for the transformer
|
||||
|
||||
TODO: test this class."""
|
||||
def __init__(self, clf_token, cfg):
|
||||
super(SimilarityHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.clf_token = clf_token
|
||||
self.dropout = nn.Dropout(cfg.clf_pdrop)
|
||||
self.linear = nn.Linear(cfg.n_embd, 1)
|
||||
|
||||
nn.init.normal_(self.linear.weight, std = 0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, h, x):
|
||||
sim_h = h.view(-1, self.n_embd)
|
||||
flat = x[..., 0].contiguous().view(-1)
|
||||
sim_h = sim_h[flat == self.clf_token, :]
|
||||
sim_h = self.dropout(sim_h)
|
||||
sim_h = sim_h.sum(dim = 1)
|
||||
sim_logits = self.linear(sim_h)
|
||||
|
||||
return sim_logits
|
||||
|
||||
class DoubleHeadModel(nn.Module):
|
||||
""" Transformer with language model and task specific heads """
|
||||
def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
|
||||
super(DoubleHeadModel, self).__init__()
|
||||
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
|
||||
self.lm_head = LMHead(self.transformer, cfg)
|
||||
if isinstance(task_head_type, str):
|
||||
if task_head_type == 'multiple_choice':
|
||||
self.task_head = MultipleChoiceHead(clf_token, cfg)
|
||||
elif task_head_type == 'similarity':
|
||||
self.task_head = SimilarityHead(clf_token, cfg)
|
||||
elif task_head_type == 'inference':
|
||||
# the three classes correspond to entailment, contradiction and neutral.
|
||||
self.task_head = ClfHead(clf_token, cfg, 3)
|
||||
else:
|
||||
raise ValueError("task_head_type is expected to be 'multiple_choice' "
|
||||
"'similarity', 'inference' or ('classification', n_class) "
|
||||
"got {task_head_type}.".format(task_head_type=task_head_type))
|
||||
elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
|
||||
task_head_type[0] == 'classification':
|
||||
n_class = task_head_type[1]
|
||||
self.task_head = ClfHead(clf_token, cfg, n_class)
|
||||
else:
|
||||
raise ValueError("task_head_type is expected to be 'multiple_choice' "
|
||||
"'similarity', 'inference' or ('classification', n_class) "
|
||||
"got {task_head_type}.".format(task_head_type=task_head_type))
|
||||
|
||||
def forward(self, x):
|
||||
h = self.transformer(x)
|
||||
lm_logits = self.lm_head(h)
|
||||
task_logits = self.task_head(h, x)
|
||||
|
||||
return lm_logits, task_logits
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
""" Transformer with language model """
|
||||
def __init__(self, cfg, vocab=40990, n_ctx=512):
|
||||
|
||||
+260
-41
@@ -13,8 +13,8 @@
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:47.026194Z",
|
||||
"start_time": "2018-11-04T09:35:46.675400Z"
|
||||
"end_time": "2018-11-04T11:15:15.794826Z",
|
||||
"start_time": "2018-11-04T11:15:15.434879Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -36,8 +36,8 @@
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:48.585230Z",
|
||||
"start_time": "2018-11-04T09:35:47.030292Z"
|
||||
"end_time": "2018-11-04T11:15:17.344506Z",
|
||||
"start_time": "2018-11-04T11:15:15.798659Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -70,8 +70,8 @@
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:48.631139Z",
|
||||
"start_time": "2018-11-04T09:35:48.588393Z"
|
||||
"end_time": "2018-11-04T11:15:17.384350Z",
|
||||
"start_time": "2018-11-04T11:15:17.347522Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -107,8 +107,8 @@
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:48.682286Z",
|
||||
"start_time": "2018-11-04T09:35:48.634489Z"
|
||||
"end_time": "2018-11-04T11:15:17.433527Z",
|
||||
"start_time": "2018-11-04T11:15:17.387055Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -189,8 +189,8 @@
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:48.734192Z",
|
||||
"start_time": "2018-11-04T09:35:48.685031Z"
|
||||
"end_time": "2018-11-04T11:15:17.474287Z",
|
||||
"start_time": "2018-11-04T11:15:17.435784Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -219,7 +219,20 @@
|
||||
" logits = np.concatenate(logits, 0)\n",
|
||||
" return logits, cost\n",
|
||||
"\n",
|
||||
" return cost\n",
|
||||
" return cost\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T11:15:17.523802Z",
|
||||
"start_time": "2018-11-04T11:15:17.476838Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"def decode_word(text_encoder, idx):\n",
|
||||
" if idx not in text_encoder.decoder:\n",
|
||||
@@ -232,7 +245,20 @@
|
||||
"def decode_sentence(text_encoder, idx_list):\n",
|
||||
" word_list = [decode_word(text_encoder, idx) for idx in idx_list]\n",
|
||||
"\n",
|
||||
" return ' '.join(word_list)\n",
|
||||
" # Fix some weird grammer, but not all\n",
|
||||
" replace = [\n",
|
||||
" [\"' \", \"'\"],\n",
|
||||
" [\" '\", \"'\"],\n",
|
||||
" [\" ,\", \",\"],\n",
|
||||
" [\" .\", \".\"],\n",
|
||||
" [\" i \", \" I \"],\n",
|
||||
" [\" n't\", \"n't\"],\n",
|
||||
" ]\n",
|
||||
" results2 = ' '.join(word_list)\n",
|
||||
" for a,b in replace:\n",
|
||||
" results2 = results2.replace(a, b)\n",
|
||||
"\n",
|
||||
" return results2\n",
|
||||
"\n",
|
||||
"def try_on_a_sentence(model, text_encoder, sentence, window_size,\n",
|
||||
" n_vocab, n_special, n_ctx, device,\n",
|
||||
@@ -255,21 +281,50 @@
|
||||
" )\n",
|
||||
" XMB = torch.tensor(X_trans, dtype = torch.long).to(device)\n",
|
||||
" lm_logits = model(XMB)\n",
|
||||
" \n",
|
||||
" # We truncate the resulting predictions to actual vocabulary\n",
|
||||
" # words in order to exclude special tokens and positional\n",
|
||||
" # embeddings.\n",
|
||||
" lm_logits = lm_logits[:, : n_vocab]\n",
|
||||
" X_trans_tensor = torch.from_numpy(X_trans)\n",
|
||||
" \n",
|
||||
" # We then select the logit corresponding to the 'clf_token'\n",
|
||||
" # position (last one of the sequence).\n",
|
||||
" X_trans_tensor = torch.from_numpy(X_trans)\n",
|
||||
" clf_token_bool_idx = X_trans_tensor[0, :, 0] == clf_token\n",
|
||||
" predictions = lm_logits.max(dim = 1)[1]\n",
|
||||
" \n",
|
||||
" # probabilistic sample so we don't get into loops\n",
|
||||
" predictions = torch.distributions.Multinomial(logits=lm_logits).sample().argmax(dim = 1)\n",
|
||||
" pred = predictions[clf_token_bool_idx[1:]].item()\n",
|
||||
" encoded_text.append(pred)\n",
|
||||
"\n",
|
||||
" return decode_sentence(text_encoder, encoded_text)"
|
||||
" return decode_sentence(text_encoder, encoded_text)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T10:59:05.230384Z",
|
||||
"start_time": "2018-11-04T10:58:58.265400Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T10:57:46.556244Z",
|
||||
"start_time": "2018-11-04T10:57:46.513733Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -279,11 +334,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:48.790923Z",
|
||||
"start_time": "2018-11-04T09:35:48.737560Z"
|
||||
"end_time": "2018-11-04T11:15:17.577518Z",
|
||||
"start_time": "2018-11-04T11:15:17.526474Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -383,11 +438,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2018-11-04T09:35:49.459204Z",
|
||||
"start_time": "2018-11-04T09:35:48.793536Z"
|
||||
"end_time": "2018-11-04T11:15:18.255342Z",
|
||||
"start_time": "2018-11-04T11:15:17.580075Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -402,10 +457,10 @@
|
||||
" 'n_head': 12,\n",
|
||||
" 'n_layer': 12,\n",
|
||||
" 'resid_pdrop': 0.1},\n",
|
||||
" {'n_ctx': 258, 'n_special': 2, 'total_vocab_size': 40738})"
|
||||
" {'n_ctx': 130, 'n_special': 2, 'total_vocab_size': 40610})"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -413,8 +468,8 @@
|
||||
"source": [
|
||||
"# Training configuration\n",
|
||||
"epochs = 3\n",
|
||||
"n_batch_train = 20\n",
|
||||
"window_size = 256\n",
|
||||
"n_batch_train = 18\n",
|
||||
"window_size = 128\n",
|
||||
"max_len = window_size\n",
|
||||
"# General configuration\n",
|
||||
"save_dir = 'save/'\n",
|
||||
@@ -432,6 +487,7 @@
|
||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"bpe_path = 'model/vocab_40000.bpe'\n",
|
||||
"encoder_path = 'model/encoder_bpe_40000.json'\n",
|
||||
"data_path = 'data/erotic_gutenberg_dataset.csv'\n",
|
||||
"text_encoder = TextEncoder(encoder_path, bpe_path)\n",
|
||||
"encoder = text_encoder.encoder\n",
|
||||
"n_special = 2\n",
|
||||
@@ -467,10 +523,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T09:35:46.407Z"
|
||||
"end_time": "2018-11-04T11:16:50.651449Z",
|
||||
"start_time": "2018-11-04T11:15:18.257548Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
@@ -494,13 +551,24 @@
|
||||
"text": [
|
||||
"\r"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"((1644206, 130, 2), (1644206, 130))"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"(X_train, y_train), (X_val, y_val) = load_dataset(\n",
|
||||
" text_encoder,\n",
|
||||
" window_size = window_size,\n",
|
||||
" path = 'data/erotic_gutenberg_dataset.csv'\n",
|
||||
" path = data_path\n",
|
||||
")\n",
|
||||
"n_train = len(y_train)\n",
|
||||
"n_valid = len(y_val) // 10\n",
|
||||
@@ -521,7 +589,8 @@
|
||||
" n_vocab,\n",
|
||||
" n_special,\n",
|
||||
" n_ctx\n",
|
||||
")"
|
||||
")\n",
|
||||
"X_train_trans.shape, X_train_mask.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -545,13 +614,32 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T09:35:46.414Z"
|
||||
"end_time": "2018-11-04T11:16:58.993708Z",
|
||||
"start_time": "2018-11-04T11:16:50.654425Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loading weights...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"language_model = LanguageModel(\n",
|
||||
" args,\n",
|
||||
@@ -576,13 +664,23 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T09:35:46.419Z"
|
||||
"end_time": "2018-11-04T11:16:59.051091Z",
|
||||
"start_time": "2018-11-04T11:16:59.002806Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/wassname/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
|
||||
" warnings.warn(warning.format(ret))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_opt = OpenAIAdam(\n",
|
||||
" params = language_model.parameters(),\n",
|
||||
@@ -628,12 +726,72 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T09:35:46.429Z"
|
||||
}
|
||||
"start_time": "2018-11-04T11:15:15.224Z"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "541a6befc51c4f9280ad5107ac3482d9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=91344), HTML(value='')), layout=Layout(displa…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1), HTML(value='')), layout=Layout(display='i…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
" Base: You had a great morning but your afternoon will be ruined because \n",
|
||||
"\n",
|
||||
" Result: you had a great morning but your afternoon will be ruined because of me ! \" \n",
|
||||
" \" alright ! no more of this horrible whispered reports and comments from agnes. I hear derrick has been distracted with her. maybe it will bring some sense into his head, \" I say. \n",
|
||||
" we're quiet for a moment when derrick talks. \" I didn't realize you two were gay. I knew it ! if I had known about this, I wouldn't have had you hide in your closet ! \" \n",
|
||||
" I look at him and in a japanese accent, say, \" honey, that would have made your afternoon a lot at least better. you can delegate your responsibility to derrick and I have spoken to him a lot lately ! \" \n",
|
||||
" \" I can't\n",
|
||||
"\n",
|
||||
"Logging\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "fd3c2c8585e7445a8312b206ed4830fd",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=2283), HTML(value='')), layout=Layout(display…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for epoch in range(epochs):\n",
|
||||
" run_epoch(\n",
|
||||
" model = language_model,\n",
|
||||
@@ -653,7 +811,7 @@
|
||||
" X_val_mask = X_val_mask,\n",
|
||||
" y_val = y_val,\n",
|
||||
" generation_params = generation_parameters\n",
|
||||
" )\n"
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -661,15 +819,76 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T09:35:46.435Z"
|
||||
"start_time": "2018-11-04T11:15:15.233Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentances = [\n",
|
||||
" 'i want you to want',\n",
|
||||
" 'please help me',\n",
|
||||
" 'let us run far away from',\n",
|
||||
" 'rosy',\n",
|
||||
" 'that unspeakable creature'\n",
|
||||
" 'when can I see you',\n",
|
||||
" 'I must',\n",
|
||||
" 'gaze at your enhanting'\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T11:15:15.245Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for sentence in sentances:\n",
|
||||
" generation_parameters['sentence'] = sentence\n",
|
||||
" result = try_on_a_sentence(**generation_parameters)\n",
|
||||
" print(\"\\n\\n Base: {} \\n\\n Result: {}\".format(generation_parameters['sentence'], result))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# DEBUG check for produced string in source text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T11:15:15.249Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_data = open(data_path).read().lower()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"start_time": "2018-11-04T11:15:15.256Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"last_i = 1\n",
|
||||
"while last_i>0:\n",
|
||||
" i = input_data[last_i+50:].index('unspeakable ') + last_i+50\n",
|
||||
" print(input_data[i-10:i+50])\n",
|
||||
" last_i=i"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
Reference in New Issue
Block a user