fix imports

This commit is contained in:
Kashif Rasul
2022-03-31 11:57:10 +02:00
parent 6c6380300d
commit e198ba4544
3 changed files with 7 additions and 3 deletions
+4
View File
@@ -22,6 +22,7 @@ class FeatureEmbedder(BaseFeatureEmbedder):
class GatedResidualNetwork(nn.Module):
@validated()
def __init__(
self,
d_hidden: int,
@@ -71,6 +72,7 @@ class GatedResidualNetwork(nn.Module):
class VariableSelectionNetwork(nn.Module):
@validated()
def __init__(
self,
d_hidden: int,
@@ -112,6 +114,7 @@ class VariableSelectionNetwork(nn.Module):
class TemporalFusionEncoder(nn.Module):
@validated()
def __init__(
self,
d_input: int,
@@ -162,6 +165,7 @@ class TemporalFusionEncoder(nn.Module):
class TemporalFusionDecoder(nn.Module):
@validated()
def __init__(
self,
context_length: int,
+2 -2
View File
@@ -30,8 +30,8 @@ from gluonts.transform import (
from gluonts.transform.sampler import InstanceSampler
from torch.utils.data import DataLoader
from .lightning_module import TransformerLightningModule
from .module import TransformerModel
from lightning_module import TransformerLightningModule
from module import TransformerModel
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
+1 -1
View File
@@ -3,7 +3,7 @@ import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from .module import TransformerModel
from module import TransformerModel
class TransformerLightningModule(pl.LightningModule):