mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
fix model and training
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 204 KiB |
+1
-1
@@ -27,7 +27,7 @@ def _rocstories(path):
|
|||||||
y.append(int(line[-1])-1)
|
y.append(int(line[-1])-1)
|
||||||
return st, ct1, ct2, y
|
return st, ct1, ct2, y
|
||||||
|
|
||||||
def rocstories(data_dir, n_train=1497, n_valid=2): #374): # TODO: set this back
|
def rocstories(data_dir, n_train=1497, n_valid=374):
|
||||||
storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv'))
|
storys, comps1, comps2, ys = _rocstories(os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv'))
|
||||||
teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv'))
|
teX1, teX2, teX3, _ = _rocstories(os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv'))
|
||||||
tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
|
tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
|
||||||
|
|||||||
+10
-11
@@ -60,13 +60,12 @@ class Conv1D(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, nx, cfg, scale=False):
|
def __init__(self, nx, n_ctx, cfg, scale=False):
|
||||||
super(Attention, self).__init__()
|
super(Attention, self).__init__()
|
||||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||||
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
|
#[switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||||
assert n_state % cfg.n_head==0
|
assert n_state % cfg.n_head==0
|
||||||
mask_size = n_state // cfg.n_head
|
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||||
self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size))
|
|
||||||
self.n_head = cfg.n_head
|
self.n_head = cfg.n_head
|
||||||
self.split_size = n_state
|
self.split_size = n_state
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@@ -126,10 +125,10 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, cfg, scale=False):
|
def __init__(self, n_ctx, cfg, scale=False):
|
||||||
super(Block, self).__init__()
|
super(Block, self).__init__()
|
||||||
nx = cfg.n_embd
|
nx = cfg.n_embd
|
||||||
self.attn = Attention(nx, cfg, scale)
|
self.attn = Attention(nx, n_ctx, cfg, scale)
|
||||||
self.ln_1 = LayerNorm(nx)
|
self.ln_1 = LayerNorm(nx)
|
||||||
self.mlp = MLP(4*nx, cfg)
|
self.mlp = MLP(4*nx, cfg)
|
||||||
self.ln_2 = LayerNorm(nx)
|
self.ln_2 = LayerNorm(nx)
|
||||||
@@ -144,12 +143,12 @@ class Block(nn.Module):
|
|||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
""" Transformer model """
|
""" Transformer model """
|
||||||
def __init__(self, vocab, cfg):
|
def __init__(self, vocab, n_ctx, cfg):
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.embed = nn.Embedding(vocab, cfg.n_embd)
|
self.embed = nn.Embedding(vocab, cfg.n_embd)
|
||||||
self.drop = nn.Dropout(cfg.embd_pdrop)
|
self.drop = nn.Dropout(cfg.embd_pdrop)
|
||||||
block = Block(cfg, scale=True)
|
block = Block(n_ctx, cfg, scale=True)
|
||||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
|
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
|
||||||
self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
|
self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
|
||||||
self.decoder.weight = self.embed.weight # Tied weights
|
self.decoder.weight = self.embed.weight # Tied weights
|
||||||
@@ -205,12 +204,12 @@ class ClfHead(nn.Module):
|
|||||||
return clf_logits.view(-1, 2)
|
return clf_logits.view(-1, 2)
|
||||||
|
|
||||||
|
|
||||||
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='model'):
|
def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'):
|
||||||
# Load weights from TF model
|
# Load weights from TF model
|
||||||
shapes = json.load(open(path + '/params_shapes.json'))
|
names = json.load(open(path_names + 'parameters_names.json'))
|
||||||
names = json.load(open(path + '/parameters_names.json'))
|
shapes = json.load(open(path + 'params_shapes.json'))
|
||||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||||
init_params = [np.load(path + '/params_{}.npy'.format(n)) for n in range(10)]
|
init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
|
||||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||||
if n_ctx > 0:
|
if n_ctx > 0:
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
["model/we:0", "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0", "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0", "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0", "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0", "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0", "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0", "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0", "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0", "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0", "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0", "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0", "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0", "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0", "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0", "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0", "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0", "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0", "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0", "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0", "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0", "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0", "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0", "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0", "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0", "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0", "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0", "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0", "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0", "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0", "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0", "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0", "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0", "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0", "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0", "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0", "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0", "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"]
|
||||||
@@ -3,7 +3,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
import json
|
import json
|
||||||
import joblib
|
|
||||||
import random
|
import random
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -76,8 +75,9 @@ def transform_roc(X1, X2, X3):
|
|||||||
return xmb, mmb
|
return xmb, mmb
|
||||||
|
|
||||||
def iter_apply(Xs, Ms, Ys):
|
def iter_apply(Xs, Ms, Ys):
|
||||||
fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
|
# fns = [lambda x: np.concatenate(x, 0), lambda x: float(np.sum(x))]
|
||||||
results = []
|
logits = []
|
||||||
|
cost = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model.eval()
|
model.eval()
|
||||||
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
|
for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
|
||||||
@@ -87,11 +87,13 @@ def iter_apply(Xs, Ms, Ys):
|
|||||||
MMB = torch.tensor(mmb).to(device)
|
MMB = torch.tensor(mmb).to(device)
|
||||||
h = model(XMB)
|
h = model(XMB)
|
||||||
clf_logits = clf_head(h, XMB)
|
clf_logits = clf_head(h, XMB)
|
||||||
|
clf_logits *= n
|
||||||
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
|
clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
|
||||||
res = (clf_logits.numpy()*n, clf_losses.numpy()*n)
|
clf_losses *= n
|
||||||
results.append(res)
|
logits.append(clf_logits.to("cpu").numpy())
|
||||||
results = zip(*results)
|
cost += clf_losses.sum().item()
|
||||||
return [fn(res) for res, fn in zip(results, fns)]
|
logits = np.concatenate(logits, 0)
|
||||||
|
return logits, cost
|
||||||
|
|
||||||
def iter_predict(Xs, Ms):
|
def iter_predict(Xs, Ms):
|
||||||
logits = []
|
logits = []
|
||||||
@@ -103,12 +105,13 @@ def iter_predict(Xs, Ms):
|
|||||||
MMB = torch.tensor(mmb).to(device)
|
MMB = torch.tensor(mmb).to(device)
|
||||||
h = model(XMB)
|
h = model(XMB)
|
||||||
clf_logits = clf_head(h, XMB)
|
clf_logits = clf_head(h, XMB)
|
||||||
logits.append(clf_logits.numpy())
|
logits.append(clf_logits.to("cpu").numpy())
|
||||||
logits = np.concatenate(logits, 0)
|
logits = np.concatenate(logits, 0)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def log():
|
def log():
|
||||||
global best_score
|
global best_score
|
||||||
|
print("Logging")
|
||||||
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
|
tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], trY[:n_valid])
|
||||||
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
|
va_logits, va_cost = iter_apply(vaX, vaM, vaY)
|
||||||
tr_cost = tr_cost/len(trY[:n_valid])
|
tr_cost = tr_cost/len(trY[:n_valid])
|
||||||
@@ -142,10 +145,10 @@ def run_epoch():
|
|||||||
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
|
for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
|
||||||
n_batch=n_batch_train, truncate=True, verbose=True):
|
n_batch=n_batch_train, truncate=True, verbose=True):
|
||||||
global n_updates
|
global n_updates
|
||||||
|
model.train()
|
||||||
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
XMB = torch.tensor(xmb, dtype=torch.long).to(device)
|
||||||
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
|
YMB = torch.tensor(ymb, dtype=torch.long).to(device)
|
||||||
MMB = torch.tensor(mmb).to(device)
|
MMB = torch.tensor(mmb).to(device)
|
||||||
model.train()
|
|
||||||
h = model(XMB)
|
h = model(XMB)
|
||||||
lm_logits = lm_head(h)
|
lm_logits = lm_head(h)
|
||||||
clf_logits = clf_head(h, XMB)
|
clf_logits = clf_head(h, XMB)
|
||||||
@@ -194,7 +197,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--clf_pdrop', type=float, default=0.1)
|
parser.add_argument('--clf_pdrop', type=float, default=0.1)
|
||||||
parser.add_argument('--l2', type=float, default=0.01)
|
parser.add_argument('--l2', type=float, default=0.01)
|
||||||
parser.add_argument('--vector_l2', action='store_true')
|
parser.add_argument('--vector_l2', action='store_true')
|
||||||
parser.add_argument('--n_gpu', type=int, default=4)
|
parser.add_argument('--n_gpu', type=int, default=1)#4) # TODO add mutli-gpu training logic
|
||||||
parser.add_argument('--opt', type=str, default='adam')
|
parser.add_argument('--opt', type=str, default='adam')
|
||||||
parser.add_argument('--afn', type=str, default='gelu')
|
parser.add_argument('--afn', type=str, default='gelu')
|
||||||
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
||||||
@@ -205,6 +208,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--b1', type=float, default=0.9)
|
parser.add_argument('--b1', type=float, default=0.9)
|
||||||
parser.add_argument('--b2', type=float, default=0.999)
|
parser.add_argument('--b2', type=float, default=0.999)
|
||||||
parser.add_argument('--e', type=float, default=1e-8)
|
parser.add_argument('--e', type=float, default=1e-8)
|
||||||
|
parser.add_argument('--n_valid', type=int, default=374)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
@@ -215,14 +219,14 @@ if __name__ == '__main__':
|
|||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
# torch.device object used throughout this script TODO add gpu setting
|
# torch.device object used throughout this script TODO add gpu setting
|
||||||
device = torch.device("cpu") #"cuda" if use_cuda else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
|
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
|
||||||
text_encoder = TextEncoder(encoder_path, bpe_path)
|
text_encoder = TextEncoder(encoder_path, bpe_path)
|
||||||
encoder = text_encoder.encoder
|
encoder = text_encoder.encoder
|
||||||
n_vocab = len(text_encoder.encoder)
|
n_vocab = len(text_encoder.encoder)
|
||||||
|
|
||||||
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder)
|
(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir, n_valid=n_valid), encoder=text_encoder)
|
||||||
n_y = 2
|
n_y = 2
|
||||||
encoder['_start_'] = len(encoder)
|
encoder['_start_'] = len(encoder)
|
||||||
encoder['_delimiter_'] = len(encoder)
|
encoder['_delimiter_'] = len(encoder)
|
||||||
@@ -231,10 +235,13 @@ if __name__ == '__main__':
|
|||||||
n_special = 3
|
n_special = 3
|
||||||
max_len = n_ctx//2-2
|
max_len = n_ctx//2-2
|
||||||
n_ctx = min(max(
|
n_ctx = min(max(
|
||||||
[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
|
[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||||
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
|
len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]
|
||||||
+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
|
+[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||||
)+3,
|
len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
|
||||||
|
+[len(x1[:max_len]) + max(len(x2[:max_len]),
|
||||||
|
len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
|
||||||
|
) + 3,
|
||||||
n_ctx)
|
n_ctx)
|
||||||
vocab = n_vocab + n_special + n_ctx
|
vocab = n_vocab + n_special + n_ctx
|
||||||
trX, trM = transform_roc(trX1, trX2, trX3)
|
trX, trM = transform_roc(trX1, trX2, trX3)
|
||||||
@@ -247,7 +254,7 @@ if __name__ == '__main__':
|
|||||||
n_batch_train = n_batch*n_gpu
|
n_batch_train = n_batch*n_gpu
|
||||||
n_updates_total = (n_train//n_batch_train)*n_iter
|
n_updates_total = (n_train//n_batch_train)*n_iter
|
||||||
|
|
||||||
model = Model(vocab, args)
|
model = Model(vocab, n_ctx, args)
|
||||||
lm_head = LMHead(model, args)
|
lm_head = LMHead(model, args)
|
||||||
clf_head = ClfHead(clf_token, args)
|
clf_head = ClfHead(clf_token, args)
|
||||||
|
|
||||||
@@ -271,8 +278,8 @@ if __name__ == '__main__':
|
|||||||
path = os.path.join(save_dir, desc, 'best_params')
|
path = os.path.join(save_dir, desc, 'best_params')
|
||||||
torch.save(model.state_dict(), make_path(path))
|
torch.save(model.state_dict(), make_path(path))
|
||||||
best_score = 0
|
best_score = 0
|
||||||
log()
|
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
|
print("running epoch", i)
|
||||||
run_epoch()
|
run_epoch()
|
||||||
n_epochs += 1
|
n_epochs += 1
|
||||||
log()
|
log()
|
||||||
|
|||||||
Reference in New Issue
Block a user