mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-26 16:00:39 +08:00
100 lines
2.7 KiB
Python
100 lines
2.7 KiB
Python
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
from functools import partial
|
|
import numpy as np
|
|
from tqdm import tqdm_notebook as tqdm
|
|
|
|
def encode_dataset(*splits, encoder):
    """Encode every text field of each split with ``encoder``.

    Args:
        *splits: Any number of splits. Each split is an iterable of fields,
            and each field is an indexable, sized sequence.
        encoder: Object exposing ``encode(field)``; it is called once per
            text field and receives the whole field.

    Returns:
        A list with one entry per split, each a list of that split's fields
        in their original order. A field whose first element is a ``str`` is
        replaced by ``encoder.encode(field)``; every other field is passed
        through unchanged.
    """
    encoded_splits = []
    for split in splits:
        fields = []
        for field in split:
            # Only the first element is inspected to decide whether a field
            # holds text. An empty field has nothing to inspect (and nothing
            # to encode), so it passes through instead of raising IndexError.
            if len(field) > 0 and isinstance(field[0], str):
                field = encoder.encode(field)
            fields.append(field)
        encoded_splits.append(fields)
    return encoded_splits
|
|
|
|
def stsb_label_encoding(labels, nclass=6):
    """
    Label encoding from Tree LSTM paper (Tai, Socher, Manning)

    Each label ``y`` is spread over the two neighbouring integer classes:
    class ``floor(y)`` receives ``floor(y) - y + 1`` and class
    ``floor(y) + 1`` receives ``y - floor(y)``. A class index that falls
    outside ``range(nclass)`` is simply not written.

    Args:
        labels: Sized iterable of numeric scores.
        nclass: Number of columns in the encoded matrix.

    Returns:
        A float32 array of shape ``(len(labels), nclass)``.
    """
    encoded = np.zeros((len(labels), nclass)).astype(np.float32)
    for row, score in enumerate(labels):
        lower = np.floor(score)
        # (class index, weight) pairs, in the order the checks are applied.
        targets = (
            (lower + 1, score - lower),
            (lower, lower - score + 1),
        )
        for col in range(nclass):
            for target, weight in targets:
                if col == target:
                    encoded[row, col] = weight
    return encoded
|
|
|
|
def np_softmax(x, t=1):
    """Softmax over the last axis with temperature ``t``.

    The input is divided by ``t`` and then shifted by its per-row maximum
    before exponentiation, so the largest exponent is always zero.

    Args:
        x: Array of scores (presumably a NumPy array; it must support
            division by ``t``).
        t: Temperature divisor applied before the softmax.

    Returns:
        Array of the same shape as ``x`` whose last axis sums to one.
    """
    scaled = x / t
    shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    normaliser = np.sum(exps, axis=-1, keepdims=True)
    return exps / normaliser
|
|
|
|
def make_path(f):
    """Ensure the parent directory of ``f`` exists and return ``f``.

    Args:
        f: File path. A path with no directory component needs no
            directory and is returned untouched.

    Returns:
        ``f`` unchanged, so the call can be nested inside ``open(...)``.
    """
    directory = os.path.dirname(f)
    if directory and not os.path.exists(directory):
        # exist_ok=True closes the window between the existence check and
        # the creation: if another process creates the directory in between,
        # makedirs no longer raises FileExistsError.
        os.makedirs(directory, exist_ok=True)
    return f
|
|
|
|
def _identity_init(shape, dtype, partition_info, scale):
    """Build a scaled identity matrix as a float32 initial value.

    The matrix is ``shape[-1]`` x ``shape[-1]``. When exactly two entries of
    ``shape`` differ from 1, the matrix is reshaped to ``shape``; otherwise
    it is returned as the square matrix. ``dtype`` and ``partition_info``
    are accepted but not used.
    """
    side = shape[-1]
    weights = np.eye(side) * scale
    non_unit_dims = sum(1 for dim in shape if dim != 1)
    if non_unit_dims == 2:
        weights = weights.reshape(shape)
    return weights.astype(np.float32)


def identity_init(scale=1.0):
    """Return an initializer callable producing ``scale`` times identity.

    The returned callable takes ``(shape, dtype, partition_info)``.
    """
    initializer = partial(_identity_init, scale=scale)
    return initializer
|
|
|
|
def _np_init(shape, dtype, partition_info, w):
    """Return the pre-computed value ``w``; the other arguments are ignored."""
    return w


def np_init(w):
    """Return an initializer callable that always yields ``w``.

    The returned callable takes ``(shape, dtype, partition_info)`` and hands
    back the very same ``w`` object (no copy is made).
    """
    initializer = partial(_np_init, w=w)
    return initializer
|
|
|
|
class ResultLogger(object):
    """Write records to a file as JSON, one object per line.

    The first line holds the keyword arguments given to the constructor;
    every later line holds the keyword arguments of one ``log`` call. Each
    record gets a ``time`` key set to ``time.time()`` unless the caller
    supplies one.

    The logger can be used as a context manager so the file is closed even
    when the body raises::

        with ResultLogger(path, run="a") as logger:
            logger.log(step=1)
    """

    def __init__(self, path, *args, **kwargs):
        """Open ``path`` for writing (truncating it) and write the header.

        Args:
            path: Log file path; passed through ``make_path`` before opening.
            *args: Accepted for backward compatibility and ignored.
            **kwargs: JSON-serialisable header fields.
        """
        self.f_log = open(make_path(path), 'w')
        # Flush the header right away, as log() does for its records, so it
        # is on disk even if the process dies before the first log() call.
        self._write_record(kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the file handle; exceptions are not suppressed.
        self.close()

    def _write_record(self, record):
        """Timestamp ``record`` if needed, write it as one line, and flush."""
        if 'time' not in record:
            record['time'] = time.time()
        self.f_log.write(json.dumps(record)+'\n')
        self.f_log.flush()

    def log(self, **kwargs):
        """Append one record built from ``kwargs`` and flush it."""
        self._write_record(kwargs)

    def close(self):
        """Close the underlying log file."""
        self.f_log.close()
|
|
|
|
def flatten(outer):
    """Concatenate the elements of every iterable in ``outer`` into one list.

    Only one level of nesting is removed.
    """
    flat = []
    for inner in outer:
        flat.extend(inner)
    return flat
|
|
|
|
def remove_none(l):
    """Return a list of the items of ``l`` that are not ``None``.

    Other falsy values (``0``, ``""``, ``[]``...) are kept; order is preserved.
    """
    return list(filter(lambda item: item is not None, l))
|
|
|
|
def iter_data(*datas, n_batch=128, truncate=False, verbose=False, max_batches=float("inf")):
    """Yield consecutive mini-batches sliced from one or more sequences.

    Args:
        *datas: One or more sliceable sequences. The length of the first one
            determines how many items are iterated.
        n_batch: Number of items per batch.
        truncate: If true, drop the trailing partial batch so every batch
            has exactly ``n_batch`` items.
        verbose: If true, the progress bar is written to ``sys.stderr``;
            otherwise it is sent to ``os.devnull``.
        max_batches: Upper bound on the number of batches produced.

    Yields:
        With a single sequence: the slice ``datas[0][i:i + n_batch]``.
        With several sequences: a generator producing the matching slice of
        each sequence. That generator is evaluated lazily, so consume (e.g.
        unpack) it before requesting the next batch.
    """
    n = len(datas[0])
    if truncate:
        n = (n // n_batch) * n_batch
    n = min(n, max_batches * n_batch)
    n_batches = 0
    devnull = None
    if verbose:
        f = sys.stderr
    else:
        f = devnull = open(os.devnull, 'w')
    try:
        for i in tqdm(range(0, n, n_batch), total=n // n_batch, mininterval=10, file=f, ncols=80, leave=False):
            if n_batches >= max_batches:
                # A plain return ends the generator. Raising StopIteration
                # here is turned into RuntimeError by PEP 479 (Python 3.7+).
                return
            if len(datas) == 1:
                yield datas[0][i:i + n_batch]
            else:
                yield (d[i:i + n_batch] for d in datas)
            n_batches += 1
    finally:
        # Close the sink we opened ourselves, whether the generator is
        # exhausted, closed early, or interrupted by an exception.
        # sys.stderr is never closed.
        if devnull is not None:
            devnull.close()
|