fix imports

This commit is contained in:
Kashif Rasul
2022-03-31 11:57:10 +02:00
parent 6c6380300d
commit e198ba4544
3 changed files with 7 additions and 3 deletions
+4
View File
@@ -22,6 +22,7 @@ class FeatureEmbedder(BaseFeatureEmbedder):
class GatedResidualNetwork(nn.Module):
@validated()
def __init__(
self,
d_hidden: int,
@@ -71,6 +72,7 @@ class GatedResidualNetwork(nn.Module):
class VariableSelectionNetwork(nn.Module):
@validated()
def __init__(
self,
d_hidden: int,
@@ -112,6 +114,7 @@ class VariableSelectionNetwork(nn.Module):
class TemporalFusionEncoder(nn.Module):
@validated()
def __init__(
self,
d_input: int,
@@ -162,6 +165,7 @@ class TemporalFusionEncoder(nn.Module):
class TemporalFusionDecoder(nn.Module):
@validated()
def __init__(
self,
context_length: int,