fix prefix and latent split

This commit is contained in:
Kashif Rasul
2022-08-16 23:18:10 +02:00
parent 6f77e3839f
commit c75cebd2bd
+1 -2
View File
@@ -1,4 +1,3 @@
import pdb
from typing import List, Optional, Tuple
import torch
@@ -411,7 +410,7 @@ class PerceiverARModel(nn.Module):
perciever_input = torch.cat((lags, features), dim=-1)
prefix, x = (
perciever_input,
perciever_input[:, : self.context_length - 1, ...],,
perciever_input[:, self.context_length - 1 :, ...],
)