diff --git a/datasets.py b/datasets.py index 2d8d31d..1383be1 100644 --- a/datasets.py +++ b/datasets.py @@ -2,7 +2,7 @@ import os import csv import numpy as np -from tqdm import tqdm +from tqdm import tqdm_notebook as tqdm from sklearn.utils import shuffle from sklearn.model_selection import train_test_split @@ -16,7 +16,7 @@ def _rocstories(path): ct1 = [] ct2 = [] y = [] - for i, line in enumerate(tqdm(list(f), ncols=80, leave=False)): + for i, line in enumerate(tqdm(list(f), ncols=80, mininterval=10, leave=False)): if i > 0: s = ' '.join(line[1:5]) # 4 sentances st.append(s) diff --git a/download_gutenberg_erotica.ipynb b/download_gutenberg_erotica.ipynb index aa6aa92..15db002 100644 --- a/download_gutenberg_erotica.ipynb +++ b/download_gutenberg_erotica.ipynb @@ -40,11 +40,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:17:20.517091Z", - "start_time": "2018-11-04T05:17:20.141438Z" + "end_time": "2018-11-04T08:54:51.977129Z", + "start_time": "2018-11-04T08:54:51.641106Z" } }, "outputs": [], @@ -62,11 +62,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:17:20.567180Z", - "start_time": "2018-11-04T05:17:20.519956Z" + "end_time": "2018-11-04T08:54:52.021478Z", + "start_time": "2018-11-04T08:54:51.980984Z" } }, "outputs": [], @@ -134,11 +134,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:17:09.263186Z", - "start_time": "2018-11-04T05:17:09.246892Z" + "end_time": "2018-11-04T08:54:52.038785Z", + "start_time": "2018-11-04T08:54:52.024258Z" } }, "outputs": [], @@ -225,14 +225,73 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T02:41:00.594210Z", - "start_time": "2018-11-04T02:37:49.143793Z" + "end_time": "2018-11-04T08:55:45.417570Z", + "start_time": "2018-11-04T08:54:52.040977Z" } }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2e1c657a58ac42f9bb99cc8221138bee", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, max=56), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Problem: No title found\n", + "\n", + " Problem: No '*** START' seen\n", + "\n", + " Problem: No '*** END' seen\n", + "\n", + " Problem: No title found\n", + "\n", + " Problem: No '*** START' seen\n", + "\n", + " Problem: No '*** END' seen\n", + "\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# first download index\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mindex_url\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"http://www.gutenberg.org/files/{bid:}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrequests\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex_url\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_for_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0msoup\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbs4\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBeautifulSoup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"html5lib\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/requests/api.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(url, params, **kwargs)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'allow_redirects'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mrequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'get'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/requests/api.py\u001b[0m in \u001b[0;36mrequest\u001b[0;34m(method, url, **kwargs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;31m# cases, and look like a memory leak in others.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0msessions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSession\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/requests/sessions.py\u001b[0m in \u001b[0;36mrequest\u001b[0;34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001b[0m\n\u001b[1;32m 506\u001b[0m }\n\u001b[1;32m 507\u001b[0m \u001b[0msend_kwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msettings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 508\u001b[0;31m \u001b[0mresp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0msend_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 509\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 510\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/requests/sessions.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, request, **kwargs)\u001b[0m\n\u001b[1;32m 616\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 617\u001b[0m \u001b[0;31m# Send the request\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 618\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0madapter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 619\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 620\u001b[0m \u001b[0;31m# Total elapsed time of the request (approximately)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/requests/adapters.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, request, stream, timeout, verify, cert, proxies)\u001b[0m\n\u001b[1;32m 438\u001b[0m \u001b[0mdecode_content\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[0mretries\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_retries\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 440\u001b[0;31m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 441\u001b[0m )\n\u001b[1;32m 442\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/urllib3/connectionpool.py\u001b[0m in \u001b[0;36murlopen\u001b[0;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout_obj\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbody\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheaders\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 601\u001b[0;31m chunked=chunked)\n\u001b[0m\u001b[1;32m 602\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;31m# If we're going to release the connection in ``finally:``, then\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/urllib3/connectionpool.py\u001b[0m in \u001b[0;36m_make_request\u001b[0;34m(self, conn, method, url, timeout, chunked, **httplib_request_kw)\u001b[0m\n\u001b[1;32m 355\u001b[0m \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest_chunked\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mhttplib_request_kw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 357\u001b[0;31m \u001b[0mconn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mhttplib_request_kw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 358\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;31m# Reset the timeout for the recv() on the socket\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/lib/python3.5/http/client.py\u001b[0m in \u001b[0;36mrequest\u001b[0;34m(self, method, url, body, headers)\u001b[0m\n\u001b[1;32m 1105\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[0;34m\"\"\"Send a complete request to the server.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1107\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_send_request\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1109\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_set_content_length\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/lib/python3.5/http/client.py\u001b[0m in \u001b[0;36m_send_request\u001b[0;34m(self, method, url, body, headers)\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;31m# default charset of iso-8859-1.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1151\u001b[0m \u001b[0mbody\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_encode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'body'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1152\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendheaders\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1153\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1154\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgetresponse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/lib/python3.5/http/client.py\u001b[0m in \u001b[0;36mendheaders\u001b[0;34m(self, message_body)\u001b[0m\n\u001b[1;32m 1101\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1102\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mCannotSendHeader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_send_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage_body\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1105\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrequest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheaders\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/lib/python3.5/http/client.py\u001b[0m in \u001b[0;36m_send_output\u001b[0;34m(self, message_body)\u001b[0m\n\u001b[1;32m 932\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_buffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 933\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 934\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 935\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmessage_body\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 936\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmessage_body\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/lib/python3.5/http/client.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 875\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msock\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 876\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_open\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 877\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconnect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 878\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 879\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotConnected\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/urllib3/connection.py\u001b[0m in \u001b[0;36mconnect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mconnect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mconn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_new_conn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_prepare_conn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/urllib3/connection.py\u001b[0m in \u001b[0;36m_new_conn\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m conn = connection.create_connection(\n\u001b[0;32m--> 141\u001b[0;31m (self.host, self.port), self.timeout, **extra_kw)\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mSocketTimeout\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.pyenv/versions/3.5.3/envs/jupyter3/lib/python3.5/site-packages/urllib3/util/connection.py\u001b[0m in \u001b[0;36mcreate_connection\u001b[0;34m(address, timeout, source_address, socket_options)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msource_address\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0msock\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msource_address\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0msock\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconnect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msock\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "for bid in tqdm(ids):\n", " \n", @@ -294,14 +353,302 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:19:03.832336Z", - "start_time": "2018-11-04T05:18:59.180170Z" + "end_time": "2018-11-04T08:55:48.451860Z", + "start_time": "2018-11-04T08:55:46.818680Z" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
authorcontentextralanguagetitle
0Albert MordellTHE EROTIC MOTIVE IN LITERATURE\\n\\nTHE EROTIC ...[Project Gutenberg's The Erotic Motive in Lit...EnglishThe Erotic Motive in Literature
1Anonymous[Transcriber's note: Anonymous, _Laura Middlet...[Project Gutenberg's Laura Middleton; Her Brot...EnglishLaura Middleton; Her Brother and her Lover
2J. K. HuysmansLÀ-BAS\\n\\n(DOWN THERE)\\n\\nby J.K. HUYSMANS\\n\\n...[The Project Gutenberg EBook of Là-bas, by J. ...EnglishLà-bas
3John ClelandMEMOIRS OF FANNY HILL\\n\\nBy John Cleland\\n\\n_A...[The Project Gutenberg EBook of Memoirs Of Fan...EnglishMemoirs Of Fanny Hill, A New and Genuine Editi...
4Havelock EllisVOLUME 1 (OF 6)***\\n\\nE-text prepared by Julie...[The Project Gutenberg eBook, Studies in the P...EnglishStudies in the Psychology of Sex, Volume 1 (of 6)
5Anonymous[Transcriber's note: Anonymous, _Forbidden fru...[The Project Gutenberg EBook of Forbidden Frui...EnglishForbidden Fruit, Luscious and exciting story a...
6Kate PercivalThe Life and Amours\\n\\nOF THE\\n\\nBeautiful, Ga...[The Project Gutenberg EBook of The Life and A...EnglishThe Life and Amours of the Beautiful, Gay and ...
7Various[Transcriber's Note: The following was proofre...[The Project Gutenberg EBook of The Fifteen Co...EnglishThe Fifteen Comforts of Matrimony: Responses F...
8Anonymous[Transcriber's note: Anonymous, _The power of ...[The Project Gutenberg EBook of The Power of M...EnglishThe Power of Mesmerism, A Highly Erotic Narrat...
9Denis Diderot_Les Bijoux Indiscrets._\\n\\nOR,\\n\\nThe Indiscr...[The Project Gutenberg EBook of Les Bijoux Ind...EnglishLes Bijoux Indiscrets, or, The Indiscreet Toys
10AnonymousTHE LADIES DELIGHT.\\n\\nCONTAINING,\\n\\nI. An Ad...[The Project Gutenberg EBook of The Ladies Del...EnglishThe Ladies Delight
11Havelock EllisVOLUME 5 (OF 6)***\\n\\nE-text prepared by Julie...[The Project Gutenberg eBook, Studies in the P...EnglishStudies in the Psychology of Sex, Volume 5 (of 6)
12J. K. HuysmansLA-BAS\\n\\n(DOWN THERE)\\n\\nby J.K. HUYSMANS\\n\\n...[The Project Gutenberg EBook of La-bas, by J. ...EnglishLa-bas
13Havelock EllisVOLUME 5 (OF 6)***\\n\\nE-text prepared by Julie...[The Project Gutenberg eBook, Studies in the P...EnglishStudies in the Psychology of Sex, Volume 5 (of 6)
14Friedrich Karl ForbergMANUAL\\n\\nOF\\n\\nClassical Erotology\\n\\n(De fig...[The Project Gutenberg EBook of Manual of Cla...EnglishManual of Classical Erotology (De figuris Vene...
15Anonymous[Transcriber's note: Anonymous, _The power of ...[The Project Gutenberg EBook of The Power of M...EnglishThe Power of Mesmerism, A Highly Erotic Narrat...
16Georg BrandesMAIN CURRENTS IN NINETEEN CENTURY LITERATURE\\n...[The Project Gutenberg EBook of Main Currents...EnglishMain Currents in Nineteenth Century Literature...
17AnonymousThe Romance of Lust\\n\\n(1873)\\n\\nA classic Vic...[, The Project Gutenberg EBook of The Romance ...EnglishThe Romance of Lust A classic Victorian erotic...
18L. BrovanTwo Hundred and fifty Copies of this Work have...[The Project Gutenberg EBook of Anthologica R...EnglishAnthologica Rarissima: The Way of a Virgin
19Various[Transcriber's Note: The following was proofre...[The Project Gutenberg EBook of The Fifteen Co...EnglishThe Fifteen Comforts of Matrimony: Responses f...
20Kate PercivalThe Life and Amours\\n\\nOF THE\\n\\nBeautiful, Ga...[The Project Gutenberg EBook of The Life and A...EnglishThe Life and Amours of the Beautiful, Gay and ...
21Havelock EllisVOLUME 1 (OF 6)***\\n\\nE-text prepared by Julie...[The Project Gutenberg eBook, Studies in the P...EnglishStudies in the Psychology of Sex, Volume 1 (of 6)
\n", + "
" + ], + "text/plain": [ + " author content \\\n", + "0 Albert Mordell THE EROTIC MOTIVE IN LITERATURE\\n\\nTHE EROTIC ... \n", + "1 Anonymous [Transcriber's note: Anonymous, _Laura Middlet... \n", + "2 J. K. Huysmans LÀ-BAS\\n\\n(DOWN THERE)\\n\\nby J.K. HUYSMANS\\n\\n... \n", + "3 John Cleland MEMOIRS OF FANNY HILL\\n\\nBy John Cleland\\n\\n_A... \n", + "4 Havelock Ellis VOLUME 1 (OF 6)***\\n\\nE-text prepared by Julie... \n", + "5 Anonymous [Transcriber's note: Anonymous, _Forbidden fru... \n", + "6 Kate Percival The Life and Amours\\n\\nOF THE\\n\\nBeautiful, Ga... \n", + "7 Various [Transcriber's Note: The following was proofre... \n", + "8 Anonymous [Transcriber's note: Anonymous, _The power of ... \n", + "9 Denis Diderot _Les Bijoux Indiscrets._\\n\\nOR,\\n\\nThe Indiscr... \n", + "10 Anonymous THE LADIES DELIGHT.\\n\\nCONTAINING,\\n\\nI. An Ad... \n", + "11 Havelock Ellis VOLUME 5 (OF 6)***\\n\\nE-text prepared by Julie... \n", + "12 J. K. Huysmans LA-BAS\\n\\n(DOWN THERE)\\n\\nby J.K. HUYSMANS\\n\\n... \n", + "13 Havelock Ellis VOLUME 5 (OF 6)***\\n\\nE-text prepared by Julie... \n", + "14 Friedrich Karl Forberg MANUAL\\n\\nOF\\n\\nClassical Erotology\\n\\n(De fig... \n", + "15 Anonymous [Transcriber's note: Anonymous, _The power of ... \n", + "16 Georg Brandes MAIN CURRENTS IN NINETEEN CENTURY LITERATURE\\n... \n", + "17 Anonymous The Romance of Lust\\n\\n(1873)\\n\\nA classic Vic... \n", + "18 L. Brovan Two Hundred and fifty Copies of this Work have... \n", + "19 Various [Transcriber's Note: The following was proofre... \n", + "20 Kate Percival The Life and Amours\\n\\nOF THE\\n\\nBeautiful, Ga... \n", + "21 Havelock Ellis VOLUME 1 (OF 6)***\\n\\nE-text prepared by Julie... \n", + "\n", + " extra language \\\n", + "0 [Project Gutenberg's The Erotic Motive in Lit... English \n", + "1 [Project Gutenberg's Laura Middleton; Her Brot... English \n", + "2 [The Project Gutenberg EBook of Là-bas, by J. ... English \n", + "3 [The Project Gutenberg EBook of Memoirs Of Fan... English \n", + "4 [The Project Gutenberg eBook, Studies in the P... English \n", + "5 [The Project Gutenberg EBook of Forbidden Frui... English \n", + "6 [The Project Gutenberg EBook of The Life and A... English \n", + "7 [The Project Gutenberg EBook of The Fifteen Co... English \n", + "8 [The Project Gutenberg EBook of The Power of M... English \n", + "9 [The Project Gutenberg EBook of Les Bijoux Ind... English \n", + "10 [The Project Gutenberg EBook of The Ladies Del... English \n", + "11 [The Project Gutenberg eBook, Studies in the P... English \n", + "12 [The Project Gutenberg EBook of La-bas, by J. ... English \n", + "13 [The Project Gutenberg eBook, Studies in the P... English \n", + "14 [The Project Gutenberg EBook of Manual of Cla... English \n", + "15 [The Project Gutenberg EBook of The Power of M... English \n", + "16 [The Project Gutenberg EBook of Main Currents... English \n", + "17 [, The Project Gutenberg EBook of The Romance ... English \n", + "18 [The Project Gutenberg EBook of Anthologica R... English \n", + "19 [The Project Gutenberg EBook of The Fifteen Co... English \n", + "20 [The Project Gutenberg EBook of The Life and A... English \n", + "21 [The Project Gutenberg eBook, Studies in the P... English \n", + "\n", + " title \n", + "0 The Erotic Motive in Literature \n", + "1 Laura Middleton; Her Brother and her Lover \n", + "2 Là-bas \n", + "3 Memoirs Of Fanny Hill, A New and Genuine Editi... \n", + "4 Studies in the Psychology of Sex, Volume 1 (of 6) \n", + "5 Forbidden Fruit, Luscious and exciting story a... \n", + "6 The Life and Amours of the Beautiful, Gay and ... \n", + "7 The Fifteen Comforts of Matrimony: Responses F... \n", + "8 The Power of Mesmerism, A Highly Erotic Narrat... \n", + "9 Les Bijoux Indiscrets, or, The Indiscreet Toys \n", + "10 The Ladies Delight \n", + "11 Studies in the Psychology of Sex, Volume 5 (of 6) \n", + "12 La-bas \n", + "13 Studies in the Psychology of Sex, Volume 5 (of 6) \n", + "14 Manual of Classical Erotology (De figuris Vene... \n", + "15 The Power of Mesmerism, A Highly Erotic Narrat... \n", + "16 Main Currents in Nineteenth Century Literature... \n", + "17 The Romance of Lust A classic Victorian erotic... \n", + "18 Anthologica Rarissima: The Way of a Virgin \n", + "19 The Fifteen Comforts of Matrimony: Responses f... \n", + "20 The Life and Amours of the Beautiful, Gay and ... \n", + "21 Studies in the Psychology of Sex, Volume 1 (of 6) " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import uuid\n", "import pandas as pd\n", @@ -315,33 +662,18 @@ "for infile in os.listdir(dest_dir):\n", " path = os.path.join(dest_dir, infile)\n", " info = json.load(open(path))\n", - " paragraphs = info['content'].split('\\n\\n')\n", - " for paragraph in paragraphs:\n", - "# sentances = [p for p in paragraph.strip().split('. ')]\n", - " sentances = nltk.sent_tokenize(paragraph)\n", - " if len(sentances)>num_sent:\n", - " for i in range(len(sentances)//num_sent):\n", - " data.append(dict(\n", - " storyid=uuid.uuid4().hex,\n", - " sentence1=sentances[i*5+0][:max_len],\n", - " sentence2=sentances[i*5+1][:max_len],\n", - " sentence3=sentances[i*5+2][:max_len],\n", - " sentence4=sentances[i*5+3][:max_len],\n", - " sentence5=sentances[i*5+4][:max_len],\n", - " AnswerRightEnding=1\n", - " ))\n", + " data.append(info)\n", "df = pd.DataFrame(data)\n", - "df = df[['storyid', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'AnswerRightEnding']]\n", "df" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:19:03.839834Z", - "start_time": "2018-11-04T05:19:03.835656Z" + "end_time": "2018-11-04T08:55:59.624853Z", + "start_time": "2018-11-04T08:55:59.621367Z" } }, "outputs": [], @@ -359,91 +691,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:19:09.318413Z", - "start_time": "2018-11-04T05:19:08.980242Z" + "end_time": "2018-11-04T08:56:56.969876Z", + "start_time": "2018-11-04T08:56:56.965223Z" } }, "outputs": [], "source": [ - "%matplotlib inline\n", - "df['sentence1'].str.len().plot.hist(bins=55)\n", - "df['sentence1'].str.len().max()" + "df = df.rename(columns=dict(content='TEXT'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:19:14.924306Z", - "start_time": "2018-11-04T05:19:14.917985Z" + "end_time": "2018-11-04T08:56:58.359405Z", + "start_time": "2018-11-04T08:56:58.068470Z" } }, "outputs": [], "source": [ - "val_idx = int(len(df)*0.7)\n", - "df_train = df[:val_idx]\n", - "df_val = df[val_idx:]" + "df.to_csv('data/erotic_gutenberg_dataset.csv', index=False)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T05:19:15.503886Z", - "start_time": "2018-11-04T05:19:15.384466Z" - } - }, - "outputs": [], - "source": [ - "df_train.to_csv('data/erotic_gutenberg_TRAIN.csv', index=False)\n", - "df_val.to_csv('data/erotic_gutenberg_VAL.csv', index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T05:19:16.452211Z", - "start_time": "2018-11-04T05:19:16.281536Z" - } - }, - "outputs": [], - "source": [ - "# import csv\n", - "# def _rocstories(path):\n", - "# with open(path, encoding='utf_8') as f:\n", - "# f = csv.reader(f)\n", - "# st = []\n", - "# ct1 = []\n", - "# ct2 = []\n", - "# y = []\n", - "# for i, line in enumerate(tqdm(list(f), ncols=80, leave=False)):\n", - "# if i > 0:\n", - "# s = ' '.join(line[1:5])\n", - "# c1 = line[5]\n", - "# c2 = line[6]\n", - "# st.append(s)\n", - "# ct1.append(c1)\n", - "# ct2.append(c2)\n", - "# y.append(int(line[-1])-1)\n", - "# return st, ct1, ct2, y\n", - " \n", - "# _rocstories('data/erotic_gutenberg_TRAIN.csv')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, diff --git a/loss.py b/loss.py index c94b1b8..d4cd8ef 100644 --- a/loss.py +++ b/loss.py @@ -1,32 +1,5 @@ import torch -class LMLossCompute: - "A Loss compute and train function for multiple choice tasks." - - def __init__(self, lm_criterion, opt=None): - self.lm_criterion = lm_criterion - self.opt = opt - - def __call__(self, X, Y, M, lm_logits, only_return_losses=False): - # Language modeling loss - if lm_logits is not None: - x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252 - M = M.view(-1, M.size(2)) - lm_losses = self.lm_criterion(lm_logits, x_shifted) - lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1) - lm_losses = lm_losses * M[:, 1:] - lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1) - if only_return_losses: - return lm_losses - - train_loss = lm_losses.sum() - train_loss.backward() - if self.opt is not None: - self.opt.step() - self.opt.zero_grad() - return train_loss.item() - - class MultipleChoiceLossCompute: "A Loss compute and train function for multiple choice tasks." @@ -93,4 +66,29 @@ class ClassificationLossCompute: self.opt.zero_grad() return train_loss.item() +class LanguageModelingLossCompute: + " A Loss compute and train function for language modeling tasks." + def __init__(self, lm_criterion, opt=None): + self.lm_criterion = lm_criterion + self.opt = opt + + # def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False): + def __call__(self, X, Y, M, lm_logits, only_return_losses=False): + # Language modeling loss + x_shifted = X[:, 1:, 0].contiguous().view(-1) + M = M.view(-1, M.size(-1)) + lm_losses = self.lm_criterion(lm_logits, x_shifted) + lm_losses = lm_losses.view(X.size(0), X.size(-2) - 1) + lm_losses = lm_losses * M[:, 1:] + lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1) + if only_return_losses: + return lm_losses + + train_loss = lm_losses.sum() + train_loss.backward() + if self.opt is not None: + self.opt.step() + self.opt.zero_grad() + return train_loss.item() + # TODO Implement a LossCompute class for similiraty tasks. diff --git a/model_pytorch.py b/model_pytorch.py index 97c6f1f..706d166 100644 --- a/model_pytorch.py +++ b/model_pytorch.py @@ -279,8 +279,8 @@ class DoubleHeadModel(nn.Module): # the three classes correspond to entailment, contradiction and neutral. self.task_head = ClfHead(clf_token, cfg, 3) else: - raise ValueError("task_head_type is expected to be 'multiple_choice' "+ - "'similarity', 'inference' or ('classification', n_class) "+ + raise ValueError("task_head_type is expected to be 'multiple_choice' " + "'similarity', 'inference' or ('classification', n_class) " "got {task_head_type}.".format(task_head_type=task_head_type)) elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \ task_head_type[0] == 'classification': @@ -298,6 +298,18 @@ class DoubleHeadModel(nn.Module): return lm_logits, task_logits +class LanguageModel(nn.Module): + """ Transformer with language model """ + def __init__(self, cfg, vocab=40990, n_ctx=512): + super(LanguageModel, self).__init__() + self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx) + self.lm_head = LMHead(self.transformer, cfg) + + def forward(self, x): + h = self.transformer(x) + lm_logits = self.lm_head(h) + + return lm_logits def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'): diff --git a/text_utils.py b/text_utils.py index 59d78f2..9dc9d9f 100644 --- a/text_utils.py +++ b/text_utils.py @@ -3,7 +3,7 @@ import ftfy import json import spacy -from tqdm import tqdm +from tqdm import tqdm_notebook as tqdm def get_pairs(word): """ @@ -92,7 +92,7 @@ class TextEncoder(object): def encode(self, texts, verbose=True): texts_tokens = [] if verbose: - for text in tqdm(texts, ncols=80, leave=False): + for text in tqdm(texts, ncols=80, mininterval=10, leave=False): text = self.nlp(text_standardize(ftfy.fix_text(text))) text_tokens = [] for token in text: diff --git a/train.ipynb b/train.ipynb index d46a77f..f9f217a 100644 --- a/train.ipynb +++ b/train.ipynb @@ -1,12 +1,20 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Language model code from:\n", + " https://github.com/rodgzilla/pytorch-openai-transformer-lm/blob/horoscope_language_model" + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:19:26.254969Z", - "start_time": "2018-11-04T05:19:25.883714Z" + "end_time": "2018-11-04T09:35:47.026194Z", + "start_time": "2018-11-04T09:35:46.675400Z" } }, "outputs": [], @@ -16,42 +24,45 @@ "%autoreload 2" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports" + ] + }, { "cell_type": "code", - "execution_count": 187, + "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T07:26:02.575703Z", - "start_time": "2018-11-04T07:26:02.531435Z" + "end_time": "2018-11-04T09:35:48.585230Z", + "start_time": "2018-11-04T09:35:47.030292Z" } }, "outputs": [], "source": [ - "import argparse\n", "import os\n", - "import random\n", + "import pandas as pd\n", + "import pdb\n", + "import argparse\n", + "import itertools\n", "\n", "import numpy as np\n", + "\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", "import torch\n", "import torch.nn as nn\n", - "from sklearn.metrics import accuracy_score\n", - "from sklearn.utils import shuffle\n", + "import torch.nn.functional as F\n", "\n", - "from analysis import rocstories as rocstories_analysis\n", - "from datasets import rocstories\n", - "from model_pytorch import DoubleHeadModel, load_openai_pretrained_model\n", - "from opt import OpenAIAdam\n", + "from model_pytorch import TransformerModel, LMHead, load_openai_pretrained_model, DEFAULT_CONFIG\n", + "from model_pytorch import LanguageModel\n", + "from utils import encode_dataset, flatten, iter_data, ResultLogger, make_path\n", "from text_utils import TextEncoder\n", - "from utils import (encode_dataset, iter_data,\n", - " ResultLogger, make_path, np_softmax)\n", - "from loss import LMLossCompute" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Helpers" + "from opt import OpenAIAdam\n", + "from loss import LanguageModelingLossCompute" ] }, { @@ -59,256 +70,22 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T05:19:28.072871Z", - "start_time": "2018-11-04T05:19:28.014007Z" + "end_time": "2018-11-04T09:35:48.631139Z", + "start_time": "2018-11-04T09:35:48.588393Z" } }, "outputs": [], "source": [ "\n", - "\n", - "def transform_roc(X1, X2, X3):\n", - " \"\"\"pad and crop sequences\"\"\"\n", - " n_batch = len(X1)\n", - " xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)\n", - " mmb = np.zeros((n_batch, 2, n_ctx), dtype=np.float32)\n", - " start = encoder['_start_']\n", - " delimiter = encoder['_delimiter_']\n", - " for i, (x1, x2, x3), in enumerate(zip(X1, X2, X3)):\n", - " x12 = [start] + x1[:max_len] + [delimiter] + x2[:max_len] + [clf_token]\n", - " x13 = [start] + x1[:max_len] + [delimiter] + x3[:max_len] + [clf_token]\n", - " l12 = len(x12)\n", - " l13 = len(x13)\n", - " xmb[i, 0, :l12, 0] = x12\n", - " xmb[i, 1, :l13, 0] = x13\n", - " mmb[i, 0, :l12] = 1\n", - " mmb[i, 1, :l13] = 1\n", - " # Position information that is added to the input embeddings in the TransformerModel\n", - " xmb[:, :, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)\n", - " return xmb, mmb\n", - "\n", - "\n", - "# def iter_apply(Xs, Ms, Ys):\n", - "# # fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]\n", - "# logits = []\n", - "# cost = 0\n", - "# with torch.no_grad():\n", - "# dh_model.eval()\n", - "# for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):\n", - "# n = len(xmb)\n", - "# XMB = torch.tensor(xmb, dtype=torch.long).to(device)\n", - "# YMB = torch.tensor(ymb, dtype=torch.long).to(device)\n", - "# MMB = torch.tensor(mmb).to(device)\n", - "# lm_logits, clf_logits = dh_model(XMB)\n", - "# clf_logits *= n\n", - "# clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)\n", - "# clf_losses *= n\n", - "# logits.append(clf_logits.to(\"cpu\").numpy())\n", - "# cost += clf_losses.sum().item()\n", - "# logits = np.concatenate(logits, 0)\n", - "# return logits, cost\n", - "\n", - "\n", - "def log(save_dir, desc):\n", - " global best_score\n", - " print(\"Logging\")\n", - "# tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])\n", - "# va_logits, va_cost = iter_apply(vaX, vaM, vaY)\n", - "# tr_cost = tr_cost / len(trY[:n_valid])\n", - "# va_cost = va_cost / n_valid\n", - "# tr_acc = accuracy_score(trY[:n_valid], np.argmax(tr_logits, 1)) * 100.\n", - "# va_acc = accuracy_score(vaY, np.argmax(va_logits, 1)) * 100.\n", - " logger.log(n_epochs=n_epochs, n_updates=n_updates)#, tr_cost=tr_cost, va_cost=va_cost, tr_acc=tr_acc, va_acc=va_acc)\n", - " print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates))#, tr_cost, va_cost, tr_acc, va_acc))\n", - "# if submit:\n", - "# score = va_acc\n", - "# if score > best_score:\n", - "# best_score = score\n", - "# path = os.path.join(save_dir, desc, 'best_params')\n", - "# torch.save(dh_model.state_dict(), make_path(path))\n", - "\n", - "def run_epoch():\n", - " for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),\n", - " n_batch=n_batch_train, truncate=True, verbose=True):\n", - " global n_updates\n", - " dh_model.train()\n", - " XMB = torch.tensor(xmb, dtype=torch.long).to(device)\n", - " YMB = torch.tensor(ymb, dtype=torch.long).to(device)\n", - " MMB = torch.tensor(mmb).to(device)\n", - " lm_logits, _ = dh_model(XMB)\n", - " compute_loss_fct(XMB, YMB, MMB, lm_logits)\n", - " n_updates += 1\n", - " if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:\n", - " log(save_dir, desc)\n", - "\n" + "n_updates = 0\n", + "best_score = 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Params" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T05:19:28.114577Z", - "start_time": "2018-11-04T05:19:28.076114Z" - } - }, - "outputs": [], - "source": [ - "\n", - "argmax = lambda x: np.argmax(x, 1)\n", - "\n", - "pred_fns = {\n", - " 'rocstories': argmax,\n", - "}\n", - "\n", - "filenames = {\n", - " 'rocstories': 'ROCStories.tsv',\n", - "}\n", - "\n", - "label_decoders = {\n", - " 'rocstories': None,\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T07:34:59.634273Z", - "start_time": "2018-11-04T07:34:59.573211Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(afn='gelu', analysis=False, attn_pdrop=0.1, b1=0.9, b2=0.999, bpe_path='model/vocab_40000.bpe', clf_pdrop=0.1, data_dir='data/', dataset='data/corpus/erotic_gutenberg.csv', desc=None, e=1e-08, embd_pdrop=0.1, encoder_path='model/encoder_bpe_40000.json', l2=0.01, lm_coef=0.5, log_dir='log/', lr=6.25e-05, lr_schedule='warmup_linear', lr_warmup=0.002, max_grad_norm=1, n_batch=2, n_ctx=512, n_embd=768, n_head=12, n_iter=3, n_layer=12, n_transfer=12, n_valid=374, opt='adam', resid_pdrop=0.1, save_dir='save/', seed=42, submission_dir='submission/', submit=True, vector_l2=False)\n" - ] - } - ], - "source": [ - "\n", - "parser = argparse.ArgumentParser()\n", - "parser.add_argument('--desc', type=str, help=\"Description\")\n", - "parser.add_argument('--dataset', type=str)\n", - "parser.add_argument('--log_dir', type=str, default='log/')\n", - "parser.add_argument('--save_dir', type=str, default='save/')\n", - "parser.add_argument('--data_dir', type=str, default='data/')\n", - "parser.add_argument('--submission_dir', type=str, default='submission/')\n", - "parser.add_argument('--submit', action='store_true')\n", - "parser.add_argument('--analysis', action='store_true')\n", - "parser.add_argument('--seed', type=int, default=42)\n", - "parser.add_argument('--n_iter', type=int, default=3)\n", - "parser.add_argument('--n_batch', type=int, default=8)\n", - "parser.add_argument('--max_grad_norm', type=int, default=1)\n", - "parser.add_argument('--lr', type=float, default=6.25e-5)\n", - "parser.add_argument('--lr_warmup', type=float, default=0.002)\n", - "parser.add_argument('--n_ctx', type=int, default=512)\n", - "parser.add_argument('--n_embd', type=int, default=768)\n", - "parser.add_argument('--n_head', type=int, default=12)\n", - "parser.add_argument('--n_layer', type=int, default=12)\n", - "parser.add_argument('--embd_pdrop', type=float, default=0.1)\n", - "parser.add_argument('--attn_pdrop', type=float, default=0.1)\n", - "parser.add_argument('--resid_pdrop', type=float, default=0.1)\n", - "parser.add_argument('--clf_pdrop', type=float, default=0.1)\n", - "parser.add_argument('--l2', type=float, default=0.01)\n", - "parser.add_argument('--vector_l2', action='store_true')\n", - "parser.add_argument('--opt', type=str, default='adam')\n", - "parser.add_argument('--afn', type=str, default='gelu')\n", - "parser.add_argument('--lr_schedule', type=str, default='warmup_linear')\n", - "parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')\n", - "parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')\n", - "parser.add_argument('--n_transfer', type=int, default=12)\n", - "parser.add_argument('--lm_coef', type=float, default=0.5)\n", - "parser.add_argument('--b1', type=float, default=0.9)\n", - "parser.add_argument('--b2', type=float, default=0.999)\n", - "parser.add_argument('--e', type=float, default=1e-8)\n", - "parser.add_argument('--n_valid', type=int, default=374)\n", - "\n", - "\n", - "args = parser.parse_args('''\n", - "--dataset data/corpus/erotic_gutenberg.csv \n", - "--n_batch 2 \n", - "--submit \n", - "--n_iter 15\n", - "'''.replace('\\n','').split(' '))\n", - "print(args)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Init" - ] - }, - { - "cell_type": "code", - "execution_count": 208, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T07:35:01.392474Z", - "start_time": "2018-11-04T07:35:01.318042Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "device cuda n_gpu 1\n" - ] - } - ], - "source": [ - "random.seed(args.seed)\n", - "np.random.seed(args.seed)\n", - "torch.manual_seed(args.seed)\n", - "torch.cuda.manual_seed_all(args.seed)\n", - "\n", - "# Constants\n", - "submit = args.submit\n", - "dataset = args.dataset\n", - "n_ctx = args.n_ctx\n", - "save_dir = args.save_dir\n", - "desc = args.desc\n", - "data_dir = args.data_dir\n", - "log_dir = args.log_dir\n", - "submission_dir = args.submission_dir\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "n_gpu = torch.cuda.device_count()\n", - "print(\"device\", device, \"n_gpu\", n_gpu)" - ] - }, - { - "cell_type": "code", - "execution_count": 210, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T07:35:06.432028Z", - "start_time": "2018-11-04T07:35:05.791318Z" - } - }, - "outputs": [], - "source": [ - "logger = ResultLogger(\n", - " path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)\n", - "\n", - "# bpe tokenizer BYTE PAIR ENCODING https://en.wikipedia.org/wiki/Byte_pair_encoding\n", - "# this is compression where we replace frequent pairs with unused byte codes\n", - "text_encoder = TextEncoder(args.encoder_path, args.bpe_path)\n", - "encoder = text_encoder.encoder\n", - "n_vocab = len(text_encoder.encoder)" + "# Helpers" ] }, { @@ -321,803 +98,437 @@ { "cell_type": "code", "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T07:42:50.993933Z", - "start_time": "2018-11-04T07:42:50.954269Z" + "end_time": "2018-11-04T09:35:48.682286Z", + "start_time": "2018-11-04T09:35:48.634489Z" + } + }, + "outputs": [], + "source": [ + "\n", + "def _chunk_word_list(word_list, max_sequence_len = 50000):\n", + " # We have to split the text into text of 100.000 characters\n", + " # because of the parser limitations.\n", + " word_sequences = [[]]\n", + " last_sequence_len = 0\n", + " for word in word_list:\n", + " # If the last word list has reached the maximum size\n", + " if last_sequence_len + len(word) > max_sequence_len:\n", + " # We transform it into a string by rejoining the words\n", + " word_sequences[-1] = ' '.join(word_sequences[-1])\n", + " # and then begin a new word sequence\n", + " word_sequences.append([])\n", + " last_sequence_len = 0\n", + " word_sequences[-1].append(word)\n", + " last_sequence_len += len(word)\n", + "\n", + " if type(word_sequences[-1]) == list:\n", + " word_sequences[-1] = ' '.join(word_sequences[-1])\n", + "\n", + " return word_sequences\n", + "\n", + "def load_dataset(text_encoder, window_size, path = 'data/erotic_gutenberg_dataset.csv',\n", + " shuffle = True, seed = 142857,\n", + " test_size = 0.2):\n", + " df = pd.read_csv(path)\n", + " all_text = ' '.join(df.TEXT)\n", + " word_list = all_text.split(' ')\n", + " word_sequences = _chunk_word_list(word_list, )\n", + " encoded_text = text_encoder.encode(word_sequences)\n", + " word_idx_list = list(itertools.chain.from_iterable(encoded_text))\n", + " context_list = []\n", + " target_list = []\n", + "\n", + " for start_idx in range(len(word_idx_list) - window_size - 1):\n", + " context_list.append(word_idx_list[start_idx : start_idx + window_size])\n", + " target_list.append(word_idx_list[start_idx + window_size])\n", + "\n", + " X_train, X_val, y_train, y_val = train_test_split(\n", + " context_list,\n", + " target_list,\n", + " test_size = test_size,\n", + " shuffle = shuffle,\n", + " random_state = seed\n", + " )\n", + " return (X_train, y_train), (X_val, y_val)\n", + "\n", + "def transform_dataset(dataset, encoder, max_len, n_vocab, n_special, n_ctx):\n", + " n_batch = len(dataset)\n", + " xmb = np.zeros((n_batch, n_ctx, 2), dtype = np.int32)\n", + " mmb = np.zeros((n_batch, n_ctx), dtype = np.float32)\n", + " start = encoder.encoder['_start_']\n", + " clf_token = encoder.encoder['_classify_']\n", + " for i, x in enumerate(dataset):\n", + " x_with_tokens = [start] + x[:max_len] + [clf_token]\n", + " l_x = len(x_with_tokens)\n", + " xmb[i, :l_x, 0] = x_with_tokens\n", + " mmb[i, :l_x] = 1\n", + " xmb[:, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)\n", + "\n", + " return xmb, mmb\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-04T09:35:48.734192Z", + "start_time": "2018-11-04T09:35:48.685031Z" + } + }, + "outputs": [], + "source": [ + "\n", + "def iter_apply(model, n_batch_train, device, compute_loss_fct, Xs, Ms, Ys, return_logits = True):\n", + " if return_logits:\n", + " logits = []\n", + " cost = 0\n", + " with torch.no_grad():\n", + " model.eval()\n", + " for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):\n", + " n = len(xmb)\n", + " XMB = torch.tensor(xmb, dtype=torch.long).to(device)\n", + " YMB = torch.tensor(ymb, dtype=torch.long).to(device)\n", + " MMB = torch.tensor(mmb).to(device)\n", + " lm_logits = model(XMB)\n", + " lm_logits *= n\n", + " lm_losses = compute_loss_fct(XMB, YMB, MMB, lm_logits, only_return_losses=True)\n", + " lm_losses *= n\n", + " if return_logits:\n", + " logits.append(lm_logits.to(\"cpu\").numpy())\n", + " cost += lm_losses.sum().item()\n", + "\n", + " if return_logits:\n", + " logits = np.concatenate(logits, 0)\n", + " return logits, cost\n", + "\n", + " return cost\n", + "\n", + "def decode_word(text_encoder, idx):\n", + " if idx not in text_encoder.decoder:\n", + " return ''\n", + "\n", + " word = text_encoder.decoder[idx]\n", + "\n", + " return word[:-4] if word[-4:] == '' else word\n", + "\n", + "def decode_sentence(text_encoder, idx_list):\n", + " word_list = [decode_word(text_encoder, idx) for idx in idx_list]\n", + "\n", + " return ' '.join(word_list)\n", + "\n", + "def try_on_a_sentence(model, text_encoder, sentence, window_size,\n", + " n_vocab, n_special, n_ctx, device,\n", + " final_len = 200):\n", + " model.eval()\n", + " start_token = text_encoder.encoder['_start_']\n", + " clf_token = text_encoder.encoder['_classify_']\n", + " encoded_text = text_encoder.encode([sentence])[0]\n", + " while len(encoded_text) < final_len:\n", + " # We take the last 'window_size' words of the text being generated\n", + " # and run it through the model.\n", + " context = encoded_text[-window_size:]\n", + " X_trans, X_mask = transform_dataset(\n", + " [context],\n", + " text_encoder,\n", + " window_size,\n", + " n_vocab,\n", + " n_special,\n", + " n_ctx\n", + " )\n", + " XMB = torch.tensor(X_trans, dtype = torch.long).to(device)\n", + " lm_logits = model(XMB)\n", + " # We truncate the resulting predictions to actual vocabulary\n", + " # words in order to exclude special tokens and positional\n", + " # embeddings.\n", + " lm_logits = lm_logits[:, : n_vocab]\n", + " X_trans_tensor = torch.from_numpy(X_trans)\n", + " # We then select the logit corresponding to the 'clf_token'\n", + " # position (last one of the sequence).\n", + " clf_token_bool_idx = X_trans_tensor[0, :, 0] == clf_token\n", + " predictions = lm_logits.max(dim = 1)[1]\n", + " pred = predictions[clf_token_bool_idx[1:]].item()\n", + " encoded_text.append(pred)\n", + "\n", + " return decode_sentence(text_encoder, encoded_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-04T09:35:48.790923Z", + "start_time": "2018-11-04T09:35:48.737560Z" + } + }, + "outputs": [], + "source": [ + "\n", + "def run_epoch(model, n_batch_train, device, compute_loss_fct, logger,\n", + " save_dir, desc, submit, n_valid, n_epochs, X_train,\n", + " X_train_mask, y_train, X_val, X_val_mask, y_val,\n", + " generation_params):\n", + " for xmb, mmb, ymb in iter_data(X_train,\n", + " X_train_mask,\n", + " y_train,\n", + " n_batch = n_batch_train,\n", + " truncate=True,\n", + " verbose=True):\n", + " global n_updates\n", + " model.train()\n", + " XMB = torch.tensor(xmb, dtype=torch.long).to(device)\n", + " YMB = torch.tensor(ymb, dtype=torch.long).to(device)\n", + " MMB = torch.tensor(mmb).to(device)\n", + " lm_logits = model(XMB)\n", + " compute_loss_fct(XMB, YMB, MMB, lm_logits)\n", + " if n_updates % 500 == 0:\n", + " log(\n", + " model,\n", + " n_batch_train,\n", + " device,\n", + " compute_loss_fct,\n", + " logger,\n", + " save_dir,\n", + " desc,\n", + " submit,\n", + " n_valid,\n", + " n_epochs,\n", + " n_updates,\n", + " X_train,\n", + " X_train_mask,\n", + " y_train,\n", + " X_val,\n", + " X_val_mask,\n", + " y_val,\n", + " generation_params\n", + " )\n", + " n_updates += 1\n", + "\n", + "def log(model, n_batch_train, device, compute_loss_fct, logger,\n", + " save_dir, desc, submit, n_valid, n_epochs, n_updates, X_train,\n", + " X_train_mask, y_train, X_val, X_val_mask, y_val,\n", + " generation_params):\n", + " global best_score\n", + " result = try_on_a_sentence(**generation_params)\n", + " print(\"\\n\\n Base: {} \\n\\n Result: {}\".format(generation_params['sentence'], result))\n", + " print(\"\\nLogging\")\n", + " tr_cost = iter_apply(\n", + " model,\n", + " n_batch_train,\n", + " device,\n", + " compute_loss_fct,\n", + " X_train[:n_valid],\n", + " X_train_mask[:n_valid],\n", + " y_train[:n_valid],\n", + " False\n", + " )\n", + " va_cost = iter_apply(\n", + " model,\n", + " n_batch_train,\n", + " device,\n", + " compute_loss_fct,\n", + " X_val,\n", + " X_val_mask,\n", + " y_val,\n", + " False\n", + " )\n", + " tr_cost = tr_cost / len(y_train[:n_valid])\n", + " va_cost = va_cost / n_valid\n", + " logger.log(\n", + " n_epochs = n_epochs,\n", + " n_updates = n_updates,\n", + " tr_cost = tr_cost,\n", + " va_cost = va_cost\n", + " )\n", + " print('\\n%d %d %.3f %.3f' % (n_epochs, n_updates, tr_cost, va_cost))\n", + " if submit:\n", + " score = va_cost\n", + " if score > best_score:\n", + " best_score = score\n", + " path = os.path.join(save_dir, desc, 'best_params')\n", + " torch.save(model.state_dict(), make_path(path))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Params" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-04T09:35:49.459204Z", + "start_time": "2018-11-04T09:35:48.793536Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "({'afn': 'gelu',\n", + " 'attn_pdrop': 0.1,\n", + " 'clf_pdrop': 0.1,\n", + " 'embd_pdrop': 0.1,\n", + " 'n_embd': 768,\n", + " 'n_head': 12,\n", + " 'n_layer': 12,\n", + " 'resid_pdrop': 0.1},\n", + " {'n_ctx': 258, 'n_special': 2, 'total_vocab_size': 40738})" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Training configuration\n", + "epochs = 3\n", + "n_batch_train = 20\n", + "window_size = 256\n", + "max_len = window_size\n", + "# General configuration\n", + "save_dir = 'save/'\n", + "log_dir = 'log/'\n", + "desc = 'erotic_gutenberg'\n", + "submit = True\n", + "args = DEFAULT_CONFIG\n", + "logger = ResultLogger(\n", + " path = os.path.join(\n", + " log_dir,\n", + " '{}.jsonl'.format(desc)\n", + " ),\n", + " **args.__dict__\n", + ")\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "bpe_path = 'model/vocab_40000.bpe'\n", + "encoder_path = 'model/encoder_bpe_40000.json'\n", + "text_encoder = TextEncoder(encoder_path, bpe_path)\n", + "encoder = text_encoder.encoder\n", + "n_special = 2\n", + "n_vocab = len(encoder)\n", + "encoder['_start_'] = len(encoder)\n", + "encoder['_classify_'] = len(encoder)\n", + "clf_token = encoder['_classify_']\n", + "\n", + "n_ctx = window_size + n_special\n", + "total_vocab_size = n_vocab + n_special + n_ctx\n", + "\n", + "args, dict(n_ctx=n_ctx, total_vocab_size=total_vocab_size, n_special=n_special)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-04T09:22:38.377992Z", + "start_time": "2018-11-04T09:22:29.584Z" } }, "outputs": [], "source": [] }, { - "cell_type": "code", - "execution_count": 211, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-04T07:35:21.585455Z", - "start_time": "2018-11-04T07:35:06.931128Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/2402 [00:00', ''],\n", - " ['.', ''],\n", - " ['\"', ''],\n", - " ['', ''],\n", - " ['.', ''],\n", - " ['\"', ''],\n", - " ['', ''],\n", - " [',', ','],\n", - " [\"'\", \"'\"],\n", - " ['.', '.'],\n", - " ['', ' '],\n", - " ['\"', ''],\n", - " ['_delimiter_', '\\n'],\n", - "# ['_classify_', '\\n'],\n", - " ['_start_', '\\n']\n", - " \n", - "]" + "language_model = LanguageModel(\n", + " args,\n", + " vocab = total_vocab_size,\n", + " n_ctx = n_ctx\n", + ")\n", + "load_openai_pretrained_model(\n", + " language_model.transformer,\n", + " n_ctx = n_ctx,\n", + " n_special = n_special\n", + ")\n", + "language_model.to(device)\n", + "1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# init opt, loss" ] }, { "cell_type": "code", - "execution_count": 244, + "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2018-11-04T07:52:27.709221Z", - "start_time": "2018-11-04T07:52:11.977263Z" + "start_time": "2018-11-04T09:35:46.419Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\"., a same, a his silk he with evcast the trees forest he the h and to he is the moral's into way.. \n", - "1 _classify_1 \" \" ,;, ..,. ., .. . .,. ,,,,,,,,,, ., a same, a his blue she with same lengthened the chapel, he the h, to he, the man's into the.. \n", - ". _classify_. 2 19.. ., ., .. ;. . . , . . ,,,\n", - ",;,,,... \n", - "1 _classify_1 _classify__classify__,,,,a... \n", - ". _classify_\n", - "\n", - "\" \n", - " \n", - "\n", - "\" object and is _ the emerson the _, which has us fuller resemblance impression of the love of is has envy tolerance have caused about the race race, is no yet therto been written described. elaborated. to the original. it has to convey an more description of the the has existed and decent thinminded people beings have done offer in. notions about prejudices. order with the prejudices gnance. but see here poem of the very who lost so \" and his young desire strong man to the weak, a - bitter, selfish, and acious, a mother ments of innocent children and his of his children ; the was a and cold, he did him but felt, were ; his children. the floor. \n", - "was with and the and helen daughter of becomes left, she children are suffer at the father's death. and cling she is is not bear feel that. be a loss. \n", - "1 poet poet's a a a and _classify_1 . . . . ;.. . , . ... . ;.., ,,.,, \" patient book, _ the's byron _, was is us fuller resemblance idea of the relationship of has has prejudice tolerance have brought about the young race, was a been therto been a published. even. to the authors. it is to explain a general description of the the has matters and true principles minded people nature feel done offer.. notions, prejudices. literature with the prejudices gnancies. the have here poem of the young who had a loving and his young opposition and husband to the weak, a - cruel, selfish, and acious, a son turers of children children and children of his children ; the just a, sickly ; he met him and saw, had, the children. the stairs ; \n", - "was and he is and helen mother of the left, she child are suffer in seeing father's death. and they their is feels not bear feel the. be a pain. \n", - "1 _classify_one........ . \n", - "... ;. _classify_... .. . \n", - "... ,,, . ,. object was 90. a year of the death trial marriage., the the nineteenth of january, 20, a court regent of regent by the father as declare his was a his charlotte, wales, the it was was the have for the court beginning to her the slightest sponial of the monarch which that he the marriage meeting he london louis he park he he he young declared sent before the, and scarcely her her william ford sbury, \" do up a glass of port, \" am not want well. \" \n", - "1 malmesbury, him he glass of brandy would fix be sufficient to and which the sultan answered chivalof the throne, and that \" ceremony word, his mistress, _classify_1 _classify_,,,.. . . ;.,... \n", - " . ,,,,,,,.,,. iiian penis was 98, a year of the discovery revolution marriage., the this twenty of january, 20 the a court regent of regent by his friends as take his was a his diana, wales, prince far was was the have for the court beginning of the the smallest sponial of the court. that he the marriage meeting he 18louis he park he the he prince was just before him, and proposed her her howard msbury, \" you up out cup of wine, \" am not care so. \n", - "\n", - "1 _classify_2 19 \n", - "\n", - ", ,,,,,:\n", - "\n", - ". 18first stevenson to literary he his without to society, he was the his seventeenth year by of his life, when of the period of his the english sus _, was to deeper in the for place for the own. the productions. he thinking a for he were no foundation significance. the he was iciency of genan greatest of real facts facts in the life age. the he appear scholars religious observers can aid. in ideas of represent possess are, essential most and prejudices the personality. the personality, in the his ideas he he has back to the period period of his early, \n", - "1 were the -, ages in the, monstrous human, and alistic, in and real person being can be in of them. and they individual that they they is is this is ancients reader can as the subject and of the story -- is of by despised. the. _classify_1, ........ \", \n", - ",... .. .......... , .. .., . . . ; , ,,, ;the the novel stevenson to bestiand scott without to the, he was in a latter year by of his life, the of _ end of the the first sus _, was to one in the for space for the own. the productions. he content a for as were no foundation foundation. he he is iciency, gena great of any ideals facts in his life age. which he form scholars literary writers can show. he most of represent possess are so in motive and and and impressions, shelley character. they shelley reideal, he has back to the beginning period, his early, \n", - "1 _classify_3 22 \n", - " ., . ...., . .,. ,. :. \n", - " . ;. . it i were be any need, if demons, no earth, no which end,, nothing paracworld unvoid meaningless like nothingness meaningless, unmerciful led nothingness, then there creation were die be created eden's singu, my gospel, his ear, his blessing, his me in as last, the of the existence,, -- there all if in matter, or a than a the shape of i me, the, a my the,, dust, a should not to he not his and his arms arms, i carry me lips on me, as tell me from, and, down into -- everything there not the in to potent _, earth, to was at, \n", - "1 when he, not die presence exist on heaven the i ? in is, him : all,. soul, of he and and, and ! _classify_1 _classify_ . ...... _ _. .. ,. ., .,....,. ....,., . . . ,.,,.. i the is be any more, \" power, no earth, no which end,, no paracworld open void endless like void empty, dreamless merciful led space, i there the were are be equal own's,, my spirit, his sword, his blessing, his me with making universe surrounding the of all life world, if all the in in child in than a than a when presence of i me, the, the my the,, dust, a was appear to if be his up his arms winds, i kiss his lips on me, i kiss me from into and, down, \n", - "i not not so ? to potent _, the ? and not, in \n", - "1 _classify_. 22 _classify_. ..,..,.. ..,, ;,,,\n", - "\n", - ". was who he work were him, was short work for a moment, turned his his ring's portion, he, not the similar opinion from he old anic he the day actiwas a his book confessions pan _ great _ was not tive and character with the's _ ary eloqu. the they. his english. but the's a attack at crowd dread fled the and to flee over attack. the feet. \n", - "1 were him feet and they was feel avoid his feet ; and they was are no from he ton ne has shown, a numerous and. foot of on too by _classify_1 \",__classify_. .. ... ,,,.,,,the, who speaking name were him, was short work and a moment to looked his his ring's head. he, not a opinion opinion from he old anic was the poet actiwas no the book _ pan _ great _ was not tive. disguise with the's _ ary and. the they. the _. the he's a attack at women, which the rose to take over attack. his feet. \n", - "1 _classify_2 23.. \n", - ".. . , .,,,the was born born of action acuand and his it in well by he he did in a a, but in or small. but a. he was a first antiof the, he in was the the characteristics of he has to he moore was into of the his intellect of genius. he was built the lives poets of poets poets. history century country. and that has no the power century of this century literary.. poem. them own life oraries. \n", - "the was a as dickens of been been,, the genius in word. and had none he the gifted stic. the, sceptiwere not have any poetry have compared absent ated by affecting tching causes.. the same. _classify_1 \" \" ',.., . _classify_. ..,..........., .... . . . .. .,. \n", - " .. ., . ,'s a pioneer of action acu, and his great in well by he he wrote in a just, but in or mediocre. but a. his was the first master of the, he he was it the advantages of he possesses him and he had into of the his genius the the. he had in the very generation of poets poets. history generation english. and he has also the power century ; this century literary.. empire. them own life oraries. \n", - "1 _classify_. 23 26 ,,..,. , . ,:,..,,., : ., \n", - "\n", - "., in satisfied of the, his sm, said a in the own and he was his empire beyond it was the whole. he he did the in without his, opposed whole, ? supreme sin in his character ? a which rate a the poetry of by his period volume of his work - life, he defect was like much of poetic the, love - was the since in iled in the - indulgence. in poet of self and a the person intellect was the most, lacking him a denied him by \n", - "1 his the masterpiece poem in a true he had upon the first of and was was a than a to make him transition material aware him a extraordinary. _classify_1 \" _classify_... .. ;..,...... ......,. ... .. . : ,,,.. ,....... ,,.,,:,. -,. \n", - ".. . \"'s who healthy of the, the sm, made a in the work, and was his society and it was every whole. he he was his in without shelley when opposed whole, ? thousand sin ? his character ? a which rate an his poetic of by his period three of his life brilliant life, he defect, now bold of self the, the - was so since in iled in love - indulgence. he poet of self and a the person intellect, literature purest was replaced him a absent him. \n", - "1 _classify_2 19 \"...... .., \".... .... \n", - ".. ,,. .,,,... ;,,,,.,, ,, in by of be seen of the's works in though that most of his youth love from bermuda icorn to the ith um. the his voyage death and paris he lost, of the weather journey of sailing the friends wife, the for lost for and of the sad that the lost deserted suspected sable body of in story ilpoet, that ship of nailed to, be carried, the's ship was then to the pyre ; the's his lney, the cian creroman rites ance. the not the with the grief. the land, who, and, and other were as to in the scene, \n", - "1 sacrifice before sunny, the air beautiful magical. the sun, and the of the sun ëines of, _classify_1 ; . . .,,,. ,. . ; . , the by of be seen of the's life in his a most of his youth journey from ireland icorn to the oyum. and his adventures appearance which ireland he lost, of the loss voyage spent in the family shipin the for island for of of the terrible that his treasure mythical recognisable remains of the tale ilpoet in that witness of addressed to, be carried, the's boat was then to the pyre, a, his acle ney in the cian precision roman rites ance ; the so the with the epi. the fort was who, and, and other were were consumed into the pyre, \n", - "1 _classify_2 2 \n", - "..... , \n" - ] - }, - { - "ename": "IndexError", - "evalue": "index 5 is out of bounds for axis 0 with size 5", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mtemperature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.5\u001b[0m \u001b[0;31m# 1 is quite random, 0 is the most likely letter every time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mprobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlm_logits\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtemperature\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# softmax with temperature\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdistributions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMultinomial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# make distribution\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mIndexError\u001b[0m: index 5 is out of bounds for axis 0 with size 5" - ] - } - ], + "outputs": [], "source": [ - "decoder = {v:k for k,v in text_encoder.encoder.items()}\n", - "temperature = 0.5 # 1 is quite random, 0 is the most likely letter every time\n", - "for batch in range(probs.shape[0]):\n", - " probs=np_softmax(lm_logits[batch], t=temperature) # softmax with temperature\n", + "model_opt = OpenAIAdam(\n", + " params = language_model.parameters(),\n", + " lr = 6.25e-5,\n", + " schedule = 'warmup_linear',\n", + " warmup = 0.002,\n", + " t_total = n_updates_total,\n", + " b1 = 0.9,\n", + " b2 = 0.999,\n", + " e = 1e-8,\n", + " l2 = 0.01,\n", + " vector_l2 = 'store_true',\n", + " max_grad_norm = 1\n", + ")\n", + "criterion = nn.CrossEntropyLoss(reduce = False)\n", + "compute_loss_fct = LanguageModelingLossCompute(\n", + " lm_criterion = criterion,\n", + " opt = model_opt\n", + ")\n", "\n", - " dist = torch.distributions.Multinomial(probs=torch.from_numpy(probs)) # make distribution\n", - " y_encoded = dist.sample().argmax(-1).numpy() # sample\n", - " y_text = [decoder.get(i, \"\".format(i)) for i in y_encoded] # decode\n", - " y_string = '\\n'+''.join(y_text) # join into text\n", + "generation_parameters = {\n", + " 'model' : language_model,\n", + " 'text_encoder' : text_encoder,\n", + " 'sentence' : 'You had a great morning but your afternoon will be ruined because',\n", + " 'window_size' : window_size,\n", + " 'n_vocab' : n_vocab,\n", + " 'n_special' : n_special,\n", + " 'n_ctx' : n_ctx,\n", + " 'device' : device,\n", + " 'final_len' : 150\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "start_time": "2018-11-04T09:35:46.429Z" + } + }, + "outputs": [], + "source": [ "\n", - " # clean up tokens\n", - " y_string_raw = str(y_string)\n", - " for a, b in convert:\n", - " y_string = y_string.replace(a, b)\n", - "\n", - " print(y_string)" + "for epoch in range(epochs):\n", + " run_epoch(\n", + " model = language_model,\n", + " n_batch_train = n_batch_train,\n", + " device = device,\n", + " compute_loss_fct = compute_loss_fct,\n", + " logger = logger,\n", + " save_dir = save_dir,\n", + " desc = desc,\n", + " submit = submit,\n", + " n_valid = n_valid,\n", + " n_epochs = epoch,\n", + " X_train = X_train_trans,\n", + " X_train_mask = X_train_mask,\n", + " y_train = y_train,\n", + " X_val = X_val_trans,\n", + " X_val_mask = X_val_mask,\n", + " y_val = y_val,\n", + " generation_params = generation_parameters\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "start_time": "2018-11-04T09:35:46.435Z" + } + }, + "outputs": [], + "source": [ + "result = try_on_a_sentence(**generation_parameters)\n", + "print(\"\\n\\n Base: {} \\n\\n Result: {}\".format(generation_parameters['sentence'], result))" ] }, { @@ -1381,7 +708,7 @@ }, "moveMenuLeft": true, "nav_menu": { - "height": "171px", + "height": "240px", "width": "251px" }, "navigate_menu": true, @@ -1390,7 +717,7 @@ "threshold": 4, "toc_cell": false, "toc_position": { - "height": "553px", + "height": "526px", "left": "0px", "right": "1191px", "top": "149px", diff --git a/utils.py b/utils.py index 79f8fe2..061b8ae 100644 --- a/utils.py +++ b/utils.py @@ -4,9 +4,7 @@ import json import time from functools import partial import numpy as np -# import tensorflow as tf -# from tensorflow.python.framework import function -from tqdm import tqdm +from tqdm import tqdm_notebook as tqdm def encode_dataset(*splits, encoder): encoded_splits = [] @@ -92,7 +90,7 @@ def iter_data(*datas, n_batch=128, truncate=False, verbose=False, max_batches=fl f = sys.stderr else: f = open(os.devnull, 'w') - for i in tqdm(range(0, n, n_batch), total=n//n_batch, file=f, ncols=80, leave=False): + for i in tqdm(range(0, n, n_batch), total=n//n_batch, mininterval=10, file=f, ncols=80, leave=False): if n_batches >= max_batches: raise StopIteration if len(datas) == 1: yield datas[0][i:i+n_batch]