use dict instead of class

This commit is contained in:
Kashif Rasul
2022-11-28 23:13:27 +01:00
parent f6e7d77fe4
commit 4e78dd7dc5
3 changed files with 92 additions and 81 deletions
+1 -3
View File
@@ -1,4 +1,3 @@
# +
from typing import Any, Dict, Iterable, List, Optional
import torch
@@ -32,7 +31,6 @@ from gluonts.transform import (
ValidationSplitSampler,
VstackFeatures,
)
from torchscale.architecture.config import EncoderDecoderConfig
from lightning_module import TorchscaleLightningModule
from module import TorchscaleModel
@@ -60,7 +58,7 @@ class TorchscaleEstimator(PyTorchLightningEstimator):
freq: str,
prediction_length: int,
# Torchscale arguments
enc_dec_config: EncoderDecoderConfig,
enc_dec_config: Dict[str, Any],
input_size: int = 1,
context_length: Optional[int] = None,
num_feat_dynamic_real: int = 0,
+8 -7
View File
@@ -287,7 +287,7 @@ class TorchscaleModel(nn.Module):
num_feat_static_cat: int,
cardinality: List[int],
# torchscale config
enc_dec_config: EncoderDecoderConfig,
enc_dec_config: Dict[str, Any],
input_size: int = 1,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
@@ -328,11 +328,12 @@ class TorchscaleModel(nn.Module):
self.distr_output = distr_output
self.param_proj = distr_output.get_args_proj(d_model)
enc_dec_config.encoder_embed_dim = d_model
enc_dec_config.decoder_embed_dim = d_model
config = EncoderDecoderConfig(**enc_dec_config)
config.encoder_embed_dim = d_model
config.decoder_embed_dim = d_model
self.encoder = Encoder(enc_dec_config)
self.decoder = Decoder(enc_dec_config)
self.encoder = Encoder(config)
self.decoder = Decoder(config)
# attention_args["dropout"] = dropout
# attention_args["causal"] = False
@@ -565,7 +566,7 @@ class TorchscaleModel(nn.Module):
future_time_feat,
)
enc_out = self.encoder(src=encoder_inputs)
enc_out = self.encoder(encoder_inputs)
params = self.param_proj(enc_out.transpose(0, 1)) # (B, T, D)
distr = self.output_distribution(params, trailing_n=1)
@@ -584,7 +585,7 @@ class TorchscaleModel(nn.Module):
repeats=self.num_parallel_samples, dim=0
)
repeated_enc_out = enc_out.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
repeats=self.num_parallel_samples, dim=1
)
future_samples = []
File diff suppressed because one or more lines are too long