mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:31:19 +08:00
use dict instead of class
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# +
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -32,7 +31,6 @@ from gluonts.transform import (
|
||||
ValidationSplitSampler,
|
||||
VstackFeatures,
|
||||
)
|
||||
from torchscale.architecture.config import EncoderDecoderConfig
|
||||
|
||||
from lightning_module import TorchscaleLightningModule
|
||||
from module import TorchscaleModel
|
||||
@@ -60,7 +58,7 @@ class TorchscaleEstimator(PyTorchLightningEstimator):
|
||||
freq: str,
|
||||
prediction_length: int,
|
||||
# Torchscale arguments
|
||||
enc_dec_config: EncoderDecoderConfig,
|
||||
enc_dec_config: Dict[str, Any],
|
||||
input_size: int = 1,
|
||||
context_length: Optional[int] = None,
|
||||
num_feat_dynamic_real: int = 0,
|
||||
|
||||
@@ -287,7 +287,7 @@ class TorchscaleModel(nn.Module):
|
||||
num_feat_static_cat: int,
|
||||
cardinality: List[int],
|
||||
# torchscale config
|
||||
enc_dec_config: EncoderDecoderConfig,
|
||||
enc_dec_config: Dict[str, Any],
|
||||
input_size: int = 1,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
@@ -328,11 +328,12 @@ class TorchscaleModel(nn.Module):
|
||||
self.distr_output = distr_output
|
||||
self.param_proj = distr_output.get_args_proj(d_model)
|
||||
|
||||
enc_dec_config.encoder_embed_dim = d_model
|
||||
enc_dec_config.decoder_embed_dim = d_model
|
||||
config = EncoderDecoderConfig(**enc_dec_config)
|
||||
config.encoder_embed_dim = d_model
|
||||
config.decoder_embed_dim = d_model
|
||||
|
||||
self.encoder = Encoder(enc_dec_config)
|
||||
self.decoder = Decoder(enc_dec_config)
|
||||
self.encoder = Encoder(config)
|
||||
self.decoder = Decoder(config)
|
||||
|
||||
# attention_args["dropout"] = dropout
|
||||
# attention_args["causal"] = False
|
||||
@@ -565,7 +566,7 @@ class TorchscaleModel(nn.Module):
|
||||
future_time_feat,
|
||||
)
|
||||
|
||||
enc_out = self.encoder(src=encoder_inputs)
|
||||
enc_out = self.encoder(encoder_inputs)
|
||||
|
||||
params = self.param_proj(enc_out.transpose(0, 1)) # (B, T, D)
|
||||
distr = self.output_distribution(params, trailing_n=1)
|
||||
@@ -584,7 +585,7 @@ class TorchscaleModel(nn.Module):
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
repeated_enc_out = enc_out.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
repeats=self.num_parallel_samples, dim=1
|
||||
)
|
||||
|
||||
future_samples = []
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user