mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
953 lines
28 KiB
Plaintext
953 lines
28 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Language model code from:\n",
|
|
" https://github.com/rodgzilla/pytorch-openai-transformer-lm/blob/horoscope_language_model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:15.794826Z",
|
|
"start_time": "2018-11-04T11:15:15.434879Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%matplotlib inline\n",
|
|
"%reload_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:17.344506Z",
|
|
"start_time": "2018-11-04T11:15:15.798659Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import pandas as pd\n",
|
|
"import pdb\n",
|
|
"import argparse\n",
|
|
"import itertools\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"from sklearn.metrics import accuracy_score\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"\n",
|
|
"from model_pytorch import TransformerModel, LMHead, load_openai_pretrained_model, DEFAULT_CONFIG\n",
|
|
"from model_pytorch import LanguageModel\n",
|
|
"from utils import encode_dataset, flatten, iter_data, ResultLogger, make_path\n",
|
|
"from text_utils import TextEncoder\n",
|
|
"from opt import OpenAIAdam\n",
|
|
"from loss import LanguageModelingLossCompute"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:17.384350Z",
|
|
"start_time": "2018-11-04T11:15:17.347522Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"n_updates = 0\n",
|
|
"best_score = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Helpers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:17.433527Z",
|
|
"start_time": "2018-11-04T11:15:17.387055Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def _chunk_word_list(word_list, max_sequence_len = 50000):\n",
|
|
" # We have to split the text into text of 100.000 characters\n",
|
|
" # because of the parser limitations.\n",
|
|
" word_sequences = [[]]\n",
|
|
" last_sequence_len = 0\n",
|
|
" for word in word_list:\n",
|
|
" # If the last word list has reached the maximum size\n",
|
|
" if last_sequence_len + len(word) > max_sequence_len:\n",
|
|
" # We transform it into a string by rejoining the words\n",
|
|
" word_sequences[-1] = ' '.join(word_sequences[-1])\n",
|
|
" # and then begin a new word sequence\n",
|
|
" word_sequences.append([])\n",
|
|
" last_sequence_len = 0\n",
|
|
" word_sequences[-1].append(word)\n",
|
|
" last_sequence_len += len(word)\n",
|
|
"\n",
|
|
" if type(word_sequences[-1]) == list:\n",
|
|
" word_sequences[-1] = ' '.join(word_sequences[-1])\n",
|
|
"\n",
|
|
" return word_sequences\n",
|
|
"\n",
|
|
"def load_dataset(text_encoder, window_size, path = 'data/erotic_gutenberg_dataset.csv',\n",
|
|
" shuffle = True, seed = 142857,\n",
|
|
" test_size = 0.2):\n",
|
|
" df = pd.read_csv(path)\n",
|
|
" all_text = ' '.join(df.TEXT)\n",
|
|
" word_list = all_text.split(' ')\n",
|
|
" word_sequences = _chunk_word_list(word_list, )\n",
|
|
" encoded_text = text_encoder.encode(word_sequences)\n",
|
|
" word_idx_list = list(itertools.chain.from_iterable(encoded_text))\n",
|
|
" context_list = []\n",
|
|
" target_list = []\n",
|
|
"\n",
|
|
" for start_idx in range(len(word_idx_list) - window_size - 1):\n",
|
|
" context_list.append(word_idx_list[start_idx : start_idx + window_size])\n",
|
|
" target_list.append(word_idx_list[start_idx + window_size])\n",
|
|
"\n",
|
|
" X_train, X_val, y_train, y_val = train_test_split(\n",
|
|
" context_list,\n",
|
|
" target_list,\n",
|
|
" test_size = test_size,\n",
|
|
" shuffle = shuffle,\n",
|
|
" random_state = seed\n",
|
|
" )\n",
|
|
" return (X_train, y_train), (X_val, y_val)\n",
|
|
"\n",
|
|
"def transform_dataset(dataset, encoder, max_len, n_vocab, n_special, n_ctx):\n",
|
|
" n_batch = len(dataset)\n",
|
|
" xmb = np.zeros((n_batch, n_ctx, 2), dtype = np.int32)\n",
|
|
" mmb = np.zeros((n_batch, n_ctx), dtype = np.float32)\n",
|
|
" start = encoder.encoder['_start_']\n",
|
|
" clf_token = encoder.encoder['_classify_']\n",
|
|
" for i, x in enumerate(dataset):\n",
|
|
" x_with_tokens = [start] + x[:max_len] + [clf_token]\n",
|
|
" l_x = len(x_with_tokens)\n",
|
|
" xmb[i, :l_x, 0] = x_with_tokens\n",
|
|
" mmb[i, :l_x] = 1\n",
|
|
" xmb[:, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)\n",
|
|
"\n",
|
|
" return xmb, mmb\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:17.474287Z",
|
|
"start_time": "2018-11-04T11:15:17.435784Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def iter_apply(model, n_batch_train, device, compute_loss_fct, Xs, Ms, Ys, return_logits = True):\n",
|
|
" if return_logits:\n",
|
|
" logits = []\n",
|
|
" cost = 0\n",
|
|
" with torch.no_grad():\n",
|
|
" model.eval()\n",
|
|
" for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):\n",
|
|
" n = len(xmb)\n",
|
|
" XMB = torch.tensor(xmb, dtype=torch.long).to(device)\n",
|
|
" YMB = torch.tensor(ymb, dtype=torch.long).to(device)\n",
|
|
" MMB = torch.tensor(mmb).to(device)\n",
|
|
" lm_logits = model(XMB)\n",
|
|
" lm_logits *= n\n",
|
|
" lm_losses = compute_loss_fct(XMB, YMB, MMB, lm_logits, only_return_losses=True)\n",
|
|
" lm_losses *= n\n",
|
|
" if return_logits:\n",
|
|
" logits.append(lm_logits.to(\"cpu\").numpy())\n",
|
|
" cost += lm_losses.sum().item()\n",
|
|
"\n",
|
|
" if return_logits:\n",
|
|
" logits = np.concatenate(logits, 0)\n",
|
|
" return logits, cost\n",
|
|
"\n",
|
|
" return cost\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:17.523802Z",
|
|
"start_time": "2018-11-04T11:15:17.476838Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def decode_word(text_encoder, idx):\n",
|
|
" if idx not in text_encoder.decoder:\n",
|
|
" return '<oov>'\n",
|
|
"\n",
|
|
" word = text_encoder.decoder[idx]\n",
|
|
"\n",
|
|
" return word[:-4] if word[-4:] == '</w>' else word\n",
|
|
"\n",
|
|
"def decode_sentence(text_encoder, idx_list):\n",
|
|
" word_list = [decode_word(text_encoder, idx) for idx in idx_list]\n",
|
|
"\n",
|
|
" # Fix some weird grammer, but not all\n",
|
|
" replace = [\n",
|
|
" [\"' \", \"'\"],\n",
|
|
" [\" '\", \"'\"],\n",
|
|
" [\" ,\", \",\"],\n",
|
|
" [\" .\", \".\"],\n",
|
|
" [\" i \", \" I \"],\n",
|
|
" [\" n't\", \"n't\"],\n",
|
|
" ]\n",
|
|
" results2 = ' '.join(word_list)\n",
|
|
" for a,b in replace:\n",
|
|
" results2 = results2.replace(a, b)\n",
|
|
"\n",
|
|
" return results2\n",
|
|
"\n",
|
|
"def try_on_a_sentence(model, text_encoder, sentence, window_size,\n",
|
|
" n_vocab, n_special, n_ctx, device,\n",
|
|
" final_len = 200):\n",
|
|
" model.eval()\n",
|
|
" start_token = text_encoder.encoder['_start_']\n",
|
|
" clf_token = text_encoder.encoder['_classify_']\n",
|
|
" encoded_text = text_encoder.encode([sentence])[0]\n",
|
|
" while len(encoded_text) < final_len:\n",
|
|
" # We take the last 'window_size' words of the text being generated\n",
|
|
" # and run it through the model.\n",
|
|
" context = encoded_text[-window_size:]\n",
|
|
" X_trans, X_mask = transform_dataset(\n",
|
|
" [context],\n",
|
|
" text_encoder,\n",
|
|
" window_size,\n",
|
|
" n_vocab,\n",
|
|
" n_special,\n",
|
|
" n_ctx\n",
|
|
" )\n",
|
|
" XMB = torch.tensor(X_trans, dtype = torch.long).to(device)\n",
|
|
" lm_logits = model(XMB)\n",
|
|
" \n",
|
|
" # We truncate the resulting predictions to actual vocabulary\n",
|
|
" # words in order to exclude special tokens and positional\n",
|
|
" # embeddings.\n",
|
|
" lm_logits = lm_logits[:, : n_vocab]\n",
|
|
" \n",
|
|
" # We then select the logit corresponding to the 'clf_token'\n",
|
|
" # position (last one of the sequence).\n",
|
|
" X_trans_tensor = torch.from_numpy(X_trans)\n",
|
|
" clf_token_bool_idx = X_trans_tensor[0, :, 0] == clf_token\n",
|
|
" \n",
|
|
" # probabilistic sample so we don't get into loops\n",
|
|
" predictions = torch.distributions.Multinomial(logits=lm_logits).sample().argmax(dim = 1)\n",
|
|
" pred = predictions[clf_token_bool_idx[1:]].item()\n",
|
|
" encoded_text.append(pred)\n",
|
|
"\n",
|
|
" return decode_sentence(text_encoder, encoded_text)\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T10:59:05.230384Z",
|
|
"start_time": "2018-11-04T10:58:58.265400Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T10:57:46.556244Z",
|
|
"start_time": "2018-11-04T10:57:46.513733Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Run"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:17.577518Z",
|
|
"start_time": "2018-11-04T11:15:17.526474Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"def run_epoch(model, n_batch_train, device, compute_loss_fct, logger,\n",
|
|
" save_dir, desc, submit, n_valid, n_epochs, X_train,\n",
|
|
" X_train_mask, y_train, X_val, X_val_mask, y_val,\n",
|
|
" generation_params):\n",
|
|
" for xmb, mmb, ymb in iter_data(X_train,\n",
|
|
" X_train_mask,\n",
|
|
" y_train,\n",
|
|
" n_batch = n_batch_train,\n",
|
|
" truncate=True,\n",
|
|
" verbose=True):\n",
|
|
" global n_updates\n",
|
|
" model.train()\n",
|
|
" XMB = torch.tensor(xmb, dtype=torch.long).to(device)\n",
|
|
" YMB = torch.tensor(ymb, dtype=torch.long).to(device)\n",
|
|
" MMB = torch.tensor(mmb).to(device)\n",
|
|
" lm_logits = model(XMB)\n",
|
|
" compute_loss_fct(XMB, YMB, MMB, lm_logits)\n",
|
|
" if n_updates % 500 == 0:\n",
|
|
" log(\n",
|
|
" model,\n",
|
|
" n_batch_train,\n",
|
|
" device,\n",
|
|
" compute_loss_fct,\n",
|
|
" logger,\n",
|
|
" save_dir,\n",
|
|
" desc,\n",
|
|
" submit,\n",
|
|
" n_valid,\n",
|
|
" n_epochs,\n",
|
|
" n_updates,\n",
|
|
" X_train,\n",
|
|
" X_train_mask,\n",
|
|
" y_train,\n",
|
|
" X_val,\n",
|
|
" X_val_mask,\n",
|
|
" y_val,\n",
|
|
" generation_params\n",
|
|
" )\n",
|
|
" n_updates += 1\n",
|
|
"\n",
|
|
"def log(model, n_batch_train, device, compute_loss_fct, logger,\n",
|
|
" save_dir, desc, submit, n_valid, n_epochs, n_updates, X_train,\n",
|
|
" X_train_mask, y_train, X_val, X_val_mask, y_val,\n",
|
|
" generation_params):\n",
|
|
" global best_score\n",
|
|
" result = try_on_a_sentence(**generation_params)\n",
|
|
" print(\"\\n\\n Base: {} \\n\\n Result: {}\".format(generation_params['sentence'], result))\n",
|
|
" print(\"\\nLogging\")\n",
|
|
" tr_cost = iter_apply(\n",
|
|
" model,\n",
|
|
" n_batch_train,\n",
|
|
" device,\n",
|
|
" compute_loss_fct,\n",
|
|
" X_train[:n_valid],\n",
|
|
" X_train_mask[:n_valid],\n",
|
|
" y_train[:n_valid],\n",
|
|
" False\n",
|
|
" )\n",
|
|
" va_cost = iter_apply(\n",
|
|
" model,\n",
|
|
" n_batch_train,\n",
|
|
" device,\n",
|
|
" compute_loss_fct,\n",
|
|
" X_val,\n",
|
|
" X_val_mask,\n",
|
|
" y_val,\n",
|
|
" False\n",
|
|
" )\n",
|
|
" tr_cost = tr_cost / len(y_train[:n_valid])\n",
|
|
" va_cost = va_cost / n_valid\n",
|
|
" logger.log(\n",
|
|
" n_epochs = n_epochs,\n",
|
|
" n_updates = n_updates,\n",
|
|
" tr_cost = tr_cost,\n",
|
|
" va_cost = va_cost\n",
|
|
" )\n",
|
|
" print('\\n%d %d %.3f %.3f' % (n_epochs, n_updates, tr_cost, va_cost))\n",
|
|
" if submit:\n",
|
|
" score = va_cost\n",
|
|
" if score > best_score:\n",
|
|
" best_score = score\n",
|
|
" path = os.path.join(save_dir, desc, 'best_params')\n",
|
|
" torch.save(model.state_dict(), make_path(path))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Params"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:15:18.255342Z",
|
|
"start_time": "2018-11-04T11:15:17.580075Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"({'afn': 'gelu',\n",
|
|
" 'attn_pdrop': 0.1,\n",
|
|
" 'clf_pdrop': 0.1,\n",
|
|
" 'embd_pdrop': 0.1,\n",
|
|
" 'n_embd': 768,\n",
|
|
" 'n_head': 12,\n",
|
|
" 'n_layer': 12,\n",
|
|
" 'resid_pdrop': 0.1},\n",
|
|
" {'n_ctx': 130, 'n_special': 2, 'total_vocab_size': 40610})"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Training configuration\n",
|
|
"epochs = 3\n",
|
|
"n_batch_train = 18\n",
|
|
"window_size = 128\n",
|
|
"max_len = window_size\n",
|
|
"# General configuration\n",
|
|
"save_dir = 'save/'\n",
|
|
"log_dir = 'log/'\n",
|
|
"desc = 'erotic_gutenberg'\n",
|
|
"submit = True\n",
|
|
"args = DEFAULT_CONFIG\n",
|
|
"logger = ResultLogger(\n",
|
|
" path = os.path.join(\n",
|
|
" log_dir,\n",
|
|
" '{}.jsonl'.format(desc)\n",
|
|
" ),\n",
|
|
" **args.__dict__\n",
|
|
")\n",
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"bpe_path = 'model/vocab_40000.bpe'\n",
|
|
"encoder_path = 'model/encoder_bpe_40000.json'\n",
|
|
"data_path = 'data/erotic_gutenberg_dataset.csv'\n",
|
|
"text_encoder = TextEncoder(encoder_path, bpe_path)\n",
|
|
"encoder = text_encoder.encoder\n",
|
|
"n_special = 2\n",
|
|
"n_vocab = len(encoder)\n",
|
|
"encoder['_start_'] = len(encoder)\n",
|
|
"encoder['_classify_'] = len(encoder)\n",
|
|
"clf_token = encoder['_classify_']\n",
|
|
"\n",
|
|
"n_ctx = window_size + n_special\n",
|
|
"total_vocab_size = n_vocab + n_special + n_ctx\n",
|
|
"\n",
|
|
"args, dict(n_ctx=n_ctx, total_vocab_size=total_vocab_size, n_special=n_special)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T09:22:38.377992Z",
|
|
"start_time": "2018-11-04T09:22:29.584Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:16:50.651449Z",
|
|
"start_time": "2018-11-04T11:15:18.257548Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=151), HTML(value='')), layout=Layout(display=…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\r"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"((1644206, 130, 2), (1644206, 130))"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"(X_train, y_train), (X_val, y_val) = load_dataset(\n",
|
|
" text_encoder,\n",
|
|
" window_size = window_size,\n",
|
|
" path = data_path\n",
|
|
")\n",
|
|
"n_train = len(y_train)\n",
|
|
"n_valid = len(y_val) // 10\n",
|
|
"n_updates_total = (n_train // n_batch_train) * epochs\n",
|
|
"\n",
|
|
"X_train_trans, X_train_mask = transform_dataset(\n",
|
|
" X_train,\n",
|
|
" text_encoder,\n",
|
|
" window_size,\n",
|
|
" n_vocab,\n",
|
|
" n_special,\n",
|
|
" n_ctx\n",
|
|
")\n",
|
|
"X_val_trans, X_val_mask = transform_dataset(\n",
|
|
" X_val,\n",
|
|
" text_encoder,\n",
|
|
" window_size,\n",
|
|
" n_vocab,\n",
|
|
" n_special,\n",
|
|
" n_ctx\n",
|
|
")\n",
|
|
"X_train_trans.shape, X_train_mask.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T09:22:38.380301Z",
|
|
"start_time": "2018-11-04T09:22:29.593Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:16:58.993708Z",
|
|
"start_time": "2018-11-04T11:16:50.654425Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading weights...\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"1"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"language_model = LanguageModel(\n",
|
|
" args,\n",
|
|
" vocab = total_vocab_size,\n",
|
|
" n_ctx = n_ctx\n",
|
|
")\n",
|
|
"load_openai_pretrained_model(\n",
|
|
" language_model.transformer,\n",
|
|
" n_ctx = n_ctx,\n",
|
|
" n_special = n_special\n",
|
|
")\n",
|
|
"language_model.to(device)\n",
|
|
"1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# init opt, loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2018-11-04T11:16:59.051091Z",
|
|
"start_time": "2018-11-04T11:16:59.002806Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/wassname/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
|
|
" warnings.warn(warning.format(ret))\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model_opt = OpenAIAdam(\n",
|
|
" params = language_model.parameters(),\n",
|
|
" lr = 6.25e-5,\n",
|
|
" schedule = 'warmup_linear',\n",
|
|
" warmup = 0.002,\n",
|
|
" t_total = n_updates_total,\n",
|
|
" b1 = 0.9,\n",
|
|
" b2 = 0.999,\n",
|
|
" e = 1e-8,\n",
|
|
" l2 = 0.01,\n",
|
|
" vector_l2 = 'store_true',\n",
|
|
" max_grad_norm = 1\n",
|
|
")\n",
|
|
"criterion = nn.CrossEntropyLoss(reduce = False)\n",
|
|
"compute_loss_fct = LanguageModelingLossCompute(\n",
|
|
" lm_criterion = criterion,\n",
|
|
" opt = model_opt\n",
|
|
")\n",
|
|
"\n",
|
|
"generation_parameters = {\n",
|
|
" 'model' : language_model,\n",
|
|
" 'text_encoder' : text_encoder,\n",
|
|
" 'sentence' : 'You had a great morning but your afternoon will be ruined because',\n",
|
|
" 'window_size' : window_size,\n",
|
|
" 'n_vocab' : n_vocab,\n",
|
|
" 'n_special' : n_special,\n",
|
|
" 'n_ctx' : n_ctx,\n",
|
|
" 'device' : device,\n",
|
|
" 'final_len' : 150\n",
|
|
"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# run"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2018-11-04T11:15:15.224Z"
|
|
},
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "541a6befc51c4f9280ad5107ac3482d9",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=91344), HTML(value='')), layout=Layout(displa…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1), HTML(value='')), layout=Layout(display='i…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
" Base: You had a great morning but your afternoon will be ruined because \n",
|
|
"\n",
|
|
" Result: you had a great morning but your afternoon will be ruined because of me ! \" \n",
|
|
" \" alright ! no more of this horrible whispered reports and comments from agnes. I hear derrick has been distracted with her. maybe it will bring some sense into his head, \" I say. \n",
|
|
" we're quiet for a moment when derrick talks. \" I didn't realize you two were gay. I knew it ! if I had known about this, I wouldn't have had you hide in your closet ! \" \n",
|
|
" I look at him and in a japanese accent, say, \" honey, that would have made your afternoon a lot at least better. you can delegate your responsibility to derrick and I have spoken to him a lot lately ! \" \n",
|
|
" \" I can't\n",
|
|
"\n",
|
|
"Logging\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "fd3c2c8585e7445a8312b206ed4830fd",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=2283), HTML(value='')), layout=Layout(display…"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"for epoch in range(epochs):\n",
|
|
" run_epoch(\n",
|
|
" model = language_model,\n",
|
|
" n_batch_train = n_batch_train,\n",
|
|
" device = device,\n",
|
|
" compute_loss_fct = compute_loss_fct,\n",
|
|
" logger = logger,\n",
|
|
" save_dir = save_dir,\n",
|
|
" desc = desc,\n",
|
|
" submit = submit,\n",
|
|
" n_valid = n_valid,\n",
|
|
" n_epochs = epoch,\n",
|
|
" X_train = X_train_trans,\n",
|
|
" X_train_mask = X_train_mask,\n",
|
|
" y_train = y_train,\n",
|
|
" X_val = X_val_trans,\n",
|
|
" X_val_mask = X_val_mask,\n",
|
|
" y_val = y_val,\n",
|
|
" generation_params = generation_parameters\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2018-11-04T11:15:15.233Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"sentances = [\n",
|
|
" 'i want you to want',\n",
|
|
" 'please help me',\n",
|
|
" 'let us run far away from',\n",
|
|
" 'rosy',\n",
|
|
" 'that unspeakable creature'\n",
|
|
" 'when can I see you',\n",
|
|
" 'I must',\n",
|
|
" 'gaze at your enhanting'\n",
|
|
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2018-11-04T11:15:15.245Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"for sentence in sentances:\n",
|
|
" generation_parameters['sentence'] = sentence\n",
|
|
" result = try_on_a_sentence(**generation_parameters)\n",
|
|
" print(\"\\n\\n Base: {} \\n\\n Result: {}\".format(generation_parameters['sentence'], result))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# DEBUG check for produced string in source text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2018-11-04T11:15:15.249Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"input_data = open(data_path).read().lower()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2018-11-04T11:15:15.256Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"last_i = 1\n",
|
|
"while last_i>0:\n",
|
|
" i = input_data[last_i+50:].index('unspeakable ') + last_i+50\n",
|
|
" print(input_data[i-10:i+50])\n",
|
|
" last_i=i"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "jupyter3",
|
|
"language": "python",
|
|
"name": "jupyter3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.5.3"
|
|
},
|
|
"toc": {
|
|
"colors": {
|
|
"hover_highlight": "#DAA520",
|
|
"navigate_num": "#000000",
|
|
"navigate_text": "#333333",
|
|
"running_highlight": "#FF0000",
|
|
"selected_highlight": "#FFD700",
|
|
"sidebar_border": "#EEEEEE",
|
|
"wrapper_background": "#FFFFFF"
|
|
},
|
|
"moveMenuLeft": true,
|
|
"nav_menu": {
|
|
"height": "240px",
|
|
"width": "251px"
|
|
},
|
|
"navigate_menu": true,
|
|
"number_sections": true,
|
|
"sideBar": true,
|
|
"threshold": 4,
|
|
"toc_cell": false,
|
|
"toc_position": {
|
|
"height": "526px",
|
|
"left": "0px",
|
|
"right": "1191px",
|
|
"top": "149px",
|
|
"width": "185px"
|
|
},
|
|
"toc_section_display": "block",
|
|
"toc_window_display": true,
|
|
"widenNotebook": false
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|