diff --git a/perceiverar/module.py b/perceiverar/module.py index f489567..d9e200a 100644 --- a/perceiverar/module.py +++ b/perceiverar/module.py @@ -1,4 +1,3 @@ -import pdb from typing import List, Optional, Tuple import torch @@ -411,7 +410,7 @@ class PerceiverARModel(nn.Module): perciever_input = torch.cat((lags, features), dim=-1) prefix, x = ( - perciever_input, + perciever_input[:, : self.context_length - 1, ...],, perciever_input[:, self.context_length - 1 :, ...], )