From e198ba4544a170e59e8a8a8cf8a8fcd45009c09b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 31 Mar 2022 11:57:10 +0200 Subject: [PATCH] fix imports --- tft/module.py | 4 ++++ transformer/estimator.py | 4 ++-- transformer/lightning_module.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tft/module.py b/tft/module.py index 49a43b4..7547c1f 100644 --- a/tft/module.py +++ b/tft/module.py @@ -22,6 +22,7 @@ class FeatureEmbedder(BaseFeatureEmbedder): class GatedResidualNetwork(nn.Module): + @validated() def __init__( self, d_hidden: int, @@ -71,6 +72,7 @@ class GatedResidualNetwork(nn.Module): class VariableSelectionNetwork(nn.Module): + @validated() def __init__( self, d_hidden: int, @@ -112,6 +114,7 @@ class VariableSelectionNetwork(nn.Module): class TemporalFusionEncoder(nn.Module): + @validated() def __init__( self, d_input: int, @@ -162,6 +165,7 @@ class TemporalFusionEncoder(nn.Module): class TemporalFusionDecoder(nn.Module): + @validated() def __init__( self, context_length: int, diff --git a/transformer/estimator.py b/transformer/estimator.py index ab0480b..7fb4713 100644 --- a/transformer/estimator.py +++ b/transformer/estimator.py @@ -30,8 +30,8 @@ from gluonts.transform import ( from gluonts.transform.sampler import InstanceSampler from torch.utils.data import DataLoader -from .lightning_module import TransformerLightningModule -from .module import TransformerModel +from lightning_module import TransformerLightningModule +from module import TransformerModel PREDICTION_INPUT_NAMES = [ "feat_static_cat", diff --git a/transformer/lightning_module.py b/transformer/lightning_module.py index 08a7428..8486db1 100644 --- a/transformer/lightning_module.py +++ b/transformer/lightning_module.py @@ -3,7 +3,7 @@ import torch from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood from gluonts.torch.util import weighted_average -from .module import TransformerModel +from module import TransformerModel class TransformerLightningModule(pl.LightningModule):