mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 17:49:40 +08:00
fix imports
This commit is contained in:
@@ -22,6 +22,7 @@ class FeatureEmbedder(BaseFeatureEmbedder):
|
||||
|
||||
|
||||
class GatedResidualNetwork(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
d_hidden: int,
|
||||
@@ -71,6 +72,7 @@ class GatedResidualNetwork(nn.Module):
|
||||
|
||||
|
||||
class VariableSelectionNetwork(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
d_hidden: int,
|
||||
@@ -112,6 +114,7 @@ class VariableSelectionNetwork(nn.Module):
|
||||
|
||||
|
||||
class TemporalFusionEncoder(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
d_input: int,
|
||||
@@ -162,6 +165,7 @@ class TemporalFusionEncoder(nn.Module):
|
||||
|
||||
|
||||
class TemporalFusionDecoder(nn.Module):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
context_length: int,
|
||||
|
||||
@@ -30,8 +30,8 @@ from gluonts.transform import (
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .lightning_module import TransformerLightningModule
|
||||
from .module import TransformerModel
|
||||
from lightning_module import TransformerLightningModule
|
||||
from module import TransformerModel
|
||||
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
|
||||
from .module import TransformerModel
|
||||
from module import TransformerModel
|
||||
|
||||
|
||||
class TransformerLightningModule(pl.LightningModule):
|
||||
|
||||
Reference in New Issue
Block a user