Solving missing variable issue.

This commit is contained in:
Grégory Châtel
2018-07-04 13:50:19 +02:00
parent 2b7e97e307
commit be407cdd37
+1
View File
@@ -177,6 +177,7 @@ class LMHead(nn.Module):
def __init__(self, model, cfg):
super(LMHead, self).__init__()
self.n_embd = cfg.n_embd
embed_shape = model.embed.weight.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model.embed.weight # Tied weights