Fixing similarity head to output a (batch, 3) dimensional tensor.

This commit is contained in:
Grégory Châtel
2018-07-13 17:25:26 +02:00
parent 4e6775287d
commit ea7f5006d5
+1 -1
View File
@@ -249,7 +249,7 @@ class SimilarityHead(nn.Module):
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout(cfg.clf_pdrop)
self.linear = nn.Linear(cfg_n_embd, 1)
self.linear = nn.Linear(cfg_n_embd, 3)
nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)