mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 16:46:28 +08:00
isort
This commit is contained in:
@@ -28,10 +28,9 @@ from gluonts.transform import (
|
||||
VstackFeatures,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_module import HopfieldLightningModule
|
||||
from module import HopfieldModel
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
|
||||
from module import HopfieldModel
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gluonts.core.component import validated
|
||||
@@ -8,7 +7,6 @@ from gluonts.time_feature import get_lags_for_frequency
|
||||
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
|
||||
from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
from hflayers import Hopfield
|
||||
from hflayers.transformer import HopfieldDecoderLayer, HopfieldEncoderLayer
|
||||
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
from .module import TFTModel
|
||||
from .lightning_module import TFTLightningModule
|
||||
from .estimator import TFTEstimator
|
||||
from .lightning_module import TFTLightningModule
|
||||
from .module import TFTModel
|
||||
|
||||
__all__ = [
|
||||
"TFTModel",
|
||||
|
||||
+1
-2
@@ -28,10 +28,9 @@ from gluonts.transform import (
|
||||
VstackFeatures,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_module import TFTLightningModule
|
||||
from module import TFTModel
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
|
||||
from module import TFTModel
|
||||
|
||||
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
from typing import Iterator, List
|
||||
|
||||
import numpy as np
|
||||
from gluonts.core.component import validated
|
||||
from gluonts.dataset.common import DataEntry
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.transform import (
|
||||
InstanceSplitter,
|
||||
MapTransformation,
|
||||
shift_timestamp,
|
||||
target_transformation_length,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
|
||||
|
||||
class BroadcastTo(MapTransformation):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
field: str,
|
||||
ext_length: int = 0,
|
||||
target_field: str = FieldName.TARGET,
|
||||
) -> None:
|
||||
self.field = field
|
||||
self.ext_length = ext_length
|
||||
self.target_field = target_field
|
||||
|
||||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
||||
length = target_transformation_length(
|
||||
data[self.target_field], self.ext_length, is_train
|
||||
)
|
||||
data[self.field] = np.broadcast_to(
|
||||
data[self.field],
|
||||
(data[self.field].shape[:-1] + (length,)),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class TFTInstanceSplitter(InstanceSplitter):
|
||||
@validated()
|
||||
def __init__(
|
||||
self,
|
||||
instance_sampler: InstanceSampler,
|
||||
past_length: int,
|
||||
future_length: int,
|
||||
target_field: str = FieldName.TARGET,
|
||||
is_pad_field: str = FieldName.IS_PAD,
|
||||
start_field: str = FieldName.START,
|
||||
forecast_start_field: str = FieldName.FORECAST_START,
|
||||
observed_value_field: str = FieldName.OBSERVED_VALUES,
|
||||
lead_time: int = 0,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: List[str] = [],
|
||||
past_time_series_fields: List[str] = [],
|
||||
dummy_value: float = 0.0,
|
||||
) -> None:
|
||||
|
||||
super().__init__(
|
||||
target_field=target_field,
|
||||
is_pad_field=is_pad_field,
|
||||
start_field=start_field,
|
||||
forecast_start_field=forecast_start_field,
|
||||
instance_sampler=instance_sampler,
|
||||
past_length=past_length,
|
||||
future_length=future_length,
|
||||
lead_time=lead_time,
|
||||
output_NTC=output_NTC,
|
||||
time_series_fields=time_series_fields,
|
||||
dummy_value=dummy_value,
|
||||
)
|
||||
|
||||
assert past_length > 0, "The value of `past_length` should be > 0"
|
||||
assert future_length > 0, "The value of `future_length` should be > 0"
|
||||
|
||||
self.observed_value_field = observed_value_field
|
||||
self.past_ts_fields = past_time_series_fields
|
||||
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
pl = self.future_length
|
||||
lt = self.lead_time
|
||||
target = data[self.target_field]
|
||||
|
||||
sampled_indices = self.instance_sampler(target)
|
||||
|
||||
slice_cols = (
|
||||
self.ts_fields
|
||||
+ self.past_ts_fields
|
||||
+ [self.target_field, self.observed_value_field]
|
||||
)
|
||||
for i in sampled_indices:
|
||||
pad_length = max(self.past_length - i, 0)
|
||||
d = data.copy()
|
||||
|
||||
for field in slice_cols:
|
||||
if i >= self.past_length:
|
||||
past_piece = d[field][..., i - self.past_length : i]
|
||||
else:
|
||||
pad_block = np.full(
|
||||
shape=d[field].shape[:-1] + (pad_length,),
|
||||
fill_value=self.dummy_value,
|
||||
dtype=d[field].dtype,
|
||||
)
|
||||
past_piece = np.concatenate([pad_block, d[field][..., :i]], axis=-1)
|
||||
future_piece = d[field][..., (i + lt) : (i + lt + pl)]
|
||||
if field in self.ts_fields:
|
||||
piece = np.concatenate([past_piece, future_piece], axis=-1)
|
||||
if self.output_NTC:
|
||||
piece = piece.transpose()
|
||||
d[field] = piece
|
||||
else:
|
||||
if self.output_NTC:
|
||||
past_piece = past_piece.transpose()
|
||||
future_piece = future_piece.transpose()
|
||||
if field not in self.past_ts_fields:
|
||||
d[self._past(field)] = past_piece
|
||||
d[self._future(field)] = future_piece
|
||||
del d[field]
|
||||
else:
|
||||
d[field] = past_piece
|
||||
pad_indicator = np.zeros(self.past_length)
|
||||
if pad_length > 0:
|
||||
pad_indicator[:pad_length] = 1
|
||||
d[self._past(self.is_pad_field)] = pad_indicator
|
||||
d[self.forecast_start_field] = shift_timestamp(d[self.start_field], i + lt)
|
||||
yield d
|
||||
@@ -1,6 +1,6 @@
|
||||
from .module import TransformerModel
|
||||
from .lightning_module import TransformerLightningModule
|
||||
from .estimator import TransformerEstimator
|
||||
from .lightning_module import TransformerLightningModule
|
||||
from .module import TransformerModel
|
||||
|
||||
__all__ = [
|
||||
"TransformerModel",
|
||||
|
||||
@@ -28,10 +28,9 @@ from gluonts.transform import (
|
||||
VstackFeatures,
|
||||
)
|
||||
from gluonts.transform.sampler import InstanceSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_module import TransformerLightningModule
|
||||
from module import TransformerModel
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
PREDICTION_INPUT_NAMES = [
|
||||
"feat_static_cat",
|
||||
|
||||
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
|
||||
from gluonts.torch.util import weighted_average
|
||||
|
||||
from module import TransformerModel
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user