mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
@@ -177,6 +177,7 @@ class LMHead(nn.Module):
|
||||
def __init__(self, model, cfg):
|
||||
super(LMHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
embed_shape = model.embed.weight.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.decoder.weight = model.embed.weight # Tied weights
|
||||
|
||||
|
||||
@@ -263,8 +263,7 @@ if __name__ == '__main__':
|
||||
n_batch_train = args.n_batch * max(n_gpu, 1)
|
||||
n_updates_total = (n_train // n_batch_train) * args.n_iter
|
||||
|
||||
if n_gpu > 1:
|
||||
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
|
||||
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(reduce=False)
|
||||
model_opt = OpenAIAdam(dh_model.parameters(),
|
||||
|
||||
Reference in New Issue
Block a user