This commit is contained in:
Kashif Rasul
2022-08-16 23:19:05 +02:00
parent c75cebd2bd
commit e264e8856f
+1 -1
View File
@@ -410,7 +410,7 @@ class PerceiverARModel(nn.Module):
perciever_input = torch.cat((lags, features), dim=-1)
prefix, x = (
perciever_input[:, : self.context_length - 1, ...],,
perciever_input[:, : self.context_length - 1, ...],
perciever_input[:, self.context_length - 1 :, ...],
)