prefix is the full sequence

This commit is contained in:
Kashif Rasul
2022-08-10 22:43:53 +02:00
parent eddc77dd74
commit d8e8030994
2 changed files with 207 additions and 88 deletions
+1 -1
View File
@@ -411,7 +411,7 @@ class PerceiverARModel(nn.Module):
perciever_input = torch.cat((lags, features), dim=-1)
prefix, x = (
perciever_input[:, : self.context_length - 1, ...],
perciever_input,
perciever_input[:, self.context_length - 1 :, ...],
)
File diff suppressed because one or more lines are too long