This commit is contained in:
Kashif Rasul
2022-04-04 15:21:04 +02:00
parent 648cb8851f
commit 2596806068
10 changed files with 7 additions and 140 deletions
+1 -2
View File
@@ -28,10 +28,9 @@ from gluonts.transform import (
VstackFeatures,
)
from gluonts.transform.sampler import InstanceSampler
from torch.utils.data import DataLoader
from lightning_module import HopfieldLightningModule
from module import HopfieldModel
from torch.utils.data import DataLoader
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
-1
View File
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from module import HopfieldModel
-2
View File
@@ -1,6 +1,5 @@
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from gluonts.core.component import validated
@@ -8,7 +7,6 @@ from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
from hflayers import Hopfield
from hflayers.transformer import HopfieldDecoderLayer, HopfieldEncoderLayer
+2 -2
View File
@@ -1,6 +1,6 @@
from .module import TFTModel
from .lightning_module import TFTLightningModule
from .estimator import TFTEstimator
from .lightning_module import TFTLightningModule
from .module import TFTModel
__all__ = [
"TFTModel",
+1 -2
View File
@@ -28,10 +28,9 @@ from gluonts.transform import (
VstackFeatures,
)
from gluonts.transform.sampler import InstanceSampler
from torch.utils.data import DataLoader
from lightning_module import TFTLightningModule
from module import TFTModel
from torch.utils.data import DataLoader
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
-1
View File
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from module import TFTModel
-125
View File
@@ -1,125 +0,0 @@
from typing import Iterator, List
import numpy as np
from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
InstanceSplitter,
MapTransformation,
shift_timestamp,
target_transformation_length,
)
from gluonts.transform.sampler import InstanceSampler
class BroadcastTo(MapTransformation):
@validated()
def __init__(
self,
field: str,
ext_length: int = 0,
target_field: str = FieldName.TARGET,
) -> None:
self.field = field
self.ext_length = ext_length
self.target_field = target_field
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
length = target_transformation_length(
data[self.target_field], self.ext_length, is_train
)
data[self.field] = np.broadcast_to(
data[self.field],
(data[self.field].shape[:-1] + (length,)),
)
return data
class TFTInstanceSplitter(InstanceSplitter):
@validated()
def __init__(
self,
instance_sampler: InstanceSampler,
past_length: int,
future_length: int,
target_field: str = FieldName.TARGET,
is_pad_field: str = FieldName.IS_PAD,
start_field: str = FieldName.START,
forecast_start_field: str = FieldName.FORECAST_START,
observed_value_field: str = FieldName.OBSERVED_VALUES,
lead_time: int = 0,
output_NTC: bool = True,
time_series_fields: List[str] = [],
past_time_series_fields: List[str] = [],
dummy_value: float = 0.0,
) -> None:
super().__init__(
target_field=target_field,
is_pad_field=is_pad_field,
start_field=start_field,
forecast_start_field=forecast_start_field,
instance_sampler=instance_sampler,
past_length=past_length,
future_length=future_length,
lead_time=lead_time,
output_NTC=output_NTC,
time_series_fields=time_series_fields,
dummy_value=dummy_value,
)
assert past_length > 0, "The value of `past_length` should be > 0"
assert future_length > 0, "The value of `future_length` should be > 0"
self.observed_value_field = observed_value_field
self.past_ts_fields = past_time_series_fields
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
pl = self.future_length
lt = self.lead_time
target = data[self.target_field]
sampled_indices = self.instance_sampler(target)
slice_cols = (
self.ts_fields
+ self.past_ts_fields
+ [self.target_field, self.observed_value_field]
)
for i in sampled_indices:
pad_length = max(self.past_length - i, 0)
d = data.copy()
for field in slice_cols:
if i >= self.past_length:
past_piece = d[field][..., i - self.past_length : i]
else:
pad_block = np.full(
shape=d[field].shape[:-1] + (pad_length,),
fill_value=self.dummy_value,
dtype=d[field].dtype,
)
past_piece = np.concatenate([pad_block, d[field][..., :i]], axis=-1)
future_piece = d[field][..., (i + lt) : (i + lt + pl)]
if field in self.ts_fields:
piece = np.concatenate([past_piece, future_piece], axis=-1)
if self.output_NTC:
piece = piece.transpose()
d[field] = piece
else:
if self.output_NTC:
past_piece = past_piece.transpose()
future_piece = future_piece.transpose()
if field not in self.past_ts_fields:
d[self._past(field)] = past_piece
d[self._future(field)] = future_piece
del d[field]
else:
d[field] = past_piece
pad_indicator = np.zeros(self.past_length)
if pad_length > 0:
pad_indicator[:pad_length] = 1
d[self._past(self.is_pad_field)] = pad_indicator
d[self.forecast_start_field] = shift_timestamp(d[self.start_field], i + lt)
yield d
+2 -2
View File
@@ -1,6 +1,6 @@
from .module import TransformerModel
from .lightning_module import TransformerLightningModule
from .estimator import TransformerEstimator
from .lightning_module import TransformerLightningModule
from .module import TransformerModel
__all__ = [
"TransformerModel",
+1 -2
View File
@@ -28,10 +28,9 @@ from gluonts.transform import (
VstackFeatures,
)
from gluonts.transform.sampler import InstanceSampler
from torch.utils.data import DataLoader
from lightning_module import TransformerLightningModule
from module import TransformerModel
from torch.utils.data import DataLoader
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
-1
View File
@@ -2,7 +2,6 @@ import pytorch_lightning as pl
import torch
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import weighted_average
from module import TransformerModel