mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:31:19 +08:00
fix prefix and latent split
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import pdb
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -411,7 +410,7 @@ class PerceiverARModel(nn.Module):
|
||||
perciever_input = torch.cat((lags, features), dim=-1)
|
||||
|
||||
prefix, x = (
|
||||
perciever_input,
|
||||
perciever_input[:, : self.context_length - 1, ...],,
|
||||
perciever_input[:, self.context_length - 1 :, ...],
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user