mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-07-05 22:36:46 +08:00
formatting
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -6,7 +8,7 @@ from .stat import ScaleHistogram
|
||||
|
||||
class InstanceSampler(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> Union[np.ndarray, List[int]]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -90,7 +92,7 @@ class BucketInstanceSampler(InstanceSampler):
|
||||
self.scale_histogram = scale_histogram
|
||||
self.lookup = np.arange(2 ** 13)
|
||||
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> None:
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
|
||||
while ts.shape[-1] >= len(self.lookup):
|
||||
self.lookup = np.arange(2 * len(self.lookup))
|
||||
p = 1.0 / self.scale_histogram.count(ts)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
import math
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class TransformedDataset(Dataset):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, base_dataset: Dataset, transformations: List[Transformation]
|
||||
self, base_dataset: Dataset, transformations: List[Transformation]
|
||||
) -> None:
|
||||
self.base_dataset = base_dataset
|
||||
self.transformations = Chain(transformations)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TimeFeature(ABC):
|
||||
def __init__(self, normalized: bool = True):
|
||||
@@ -95,4 +95,3 @@ class WeekOfYear(TimeFeature):
|
||||
return index.weekofyear / 51.0 - 0.5
|
||||
else:
|
||||
return index.weekofyear.map(float)
|
||||
|
||||
|
||||
+66
-66
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache, reduce
|
||||
from collections import Counter
|
||||
from functools import lru_cache, reduce
|
||||
from typing import Iterator, List, Callable, Any, Optional, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -27,7 +27,7 @@ def shift_timestamp(ts: pd.Timestamp, offset: int) -> pd.Timestamp:
|
||||
|
||||
|
||||
def target_transformation_length(
|
||||
target: np.array, pred_length: int, is_train: bool
|
||||
target: np.array, pred_length: int, is_train: bool
|
||||
) -> int:
|
||||
return target.shape[-1] + (0 if is_train else pred_length)
|
||||
|
||||
@@ -35,7 +35,7 @@ def target_transformation_length(
|
||||
class Transformation(ABC):
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
) -> Iterator[DataEntry]:
|
||||
pass
|
||||
|
||||
@@ -52,7 +52,7 @@ class Chain(Transformation):
|
||||
self.trans = trans
|
||||
|
||||
def __call__(
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
) -> Iterator[DataEntry]:
|
||||
tmp = data_it
|
||||
for t in self.trans:
|
||||
@@ -65,7 +65,7 @@ class Chain(Transformation):
|
||||
|
||||
class Identity(Transformation):
|
||||
def __call__(
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
) -> Iterator[DataEntry]:
|
||||
return data_it
|
||||
|
||||
@@ -190,7 +190,7 @@ class SetFieldIfNotPresent(SimpleTransformation):
|
||||
Sets a field in the dictionary with the given value, in case it does not exist already
|
||||
Parameters
|
||||
----------
|
||||
output_field
|
||||
field
|
||||
Name of the field that will be set
|
||||
value
|
||||
Value to be set
|
||||
@@ -219,7 +219,7 @@ class AsNumpyArray(SimpleTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, field: str, expected_ndim: int, dtype: np.dtype = np.float32
|
||||
self, field: str, expected_ndim: int, dtype: np.dtype = np.float32
|
||||
) -> None:
|
||||
self.field = field
|
||||
self.expected_ndim = expected_ndim
|
||||
@@ -286,7 +286,7 @@ class VstackFeatures(SimpleTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
) -> None:
|
||||
self.output_field = output_field
|
||||
self.input_fields = input_fields
|
||||
@@ -320,7 +320,7 @@ class ConcatFeatures(SimpleTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
) -> None:
|
||||
self.output_field = output_field
|
||||
self.input_fields = input_fields
|
||||
@@ -385,7 +385,7 @@ class ListFeatures(SimpleTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
) -> None:
|
||||
self.output_field = output_field
|
||||
self.input_fields = input_fields
|
||||
@@ -423,11 +423,11 @@ class AddObservedValuesIndicator(SimpleTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
output_field: str,
|
||||
dummy_value: int = 0,
|
||||
convert_nans: bool = True,
|
||||
self,
|
||||
target_field: str,
|
||||
output_field: str,
|
||||
dummy_value: int = 0,
|
||||
convert_nans: bool = True,
|
||||
) -> None:
|
||||
self.dummy_value = dummy_value
|
||||
self.target_field = target_field
|
||||
@@ -496,12 +496,12 @@ class AddConstFeature(MapTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_field: str,
|
||||
target_field: str,
|
||||
pred_length: int,
|
||||
const: float = 1.0,
|
||||
dtype: np.dtype = np.float32,
|
||||
self,
|
||||
output_field: str,
|
||||
target_field: str,
|
||||
pred_length: int,
|
||||
const: float = 1.0,
|
||||
dtype: np.dtype = np.float32,
|
||||
) -> None:
|
||||
self.pred_length = pred_length
|
||||
self.const = const
|
||||
@@ -539,22 +539,22 @@ class AddTimeFeatures(MapTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_field: str,
|
||||
target_field: str,
|
||||
output_field: str,
|
||||
time_features: List[TimeFeature],
|
||||
pred_length: int,
|
||||
self,
|
||||
start_field: str,
|
||||
target_field: str,
|
||||
output_field: str,
|
||||
time_features: List[TimeFeature],
|
||||
pred_length: int,
|
||||
) -> None:
|
||||
self.date_features = time_features
|
||||
self.pred_length = pred_length
|
||||
self.start_field = start_field
|
||||
self.target_field = target_field
|
||||
self.output_field = output_field
|
||||
self._min_time_point: pd.Timestamp = None
|
||||
self._max_time_point: pd.Timestamp = None
|
||||
self._full_range_date_features: np.ndarray = None
|
||||
self._date_index: pd.DatetimeIndex = None
|
||||
self._min_time_point: Optional[pd.Timestamp] = None
|
||||
self._max_time_point: Optional[pd.Timestamp] = None
|
||||
self._full_range_date_features: Optional[np.ndarray] = None
|
||||
self._date_index: Optional[pd.DatetimeIndex] = None
|
||||
|
||||
def _update_cache(self, start: pd.Timestamp, length: int) -> None:
|
||||
end = shift_timestamp(start, length)
|
||||
@@ -583,7 +583,7 @@ class AddTimeFeatures(MapTransformation):
|
||||
)
|
||||
self._update_cache(start, length)
|
||||
i0 = self._date_index[start]
|
||||
features = self._full_range_date_features[..., i0 : i0 + length]
|
||||
features = self._full_range_date_features[..., i0:i0 + length]
|
||||
data[self.output_field] = features
|
||||
return data
|
||||
|
||||
@@ -608,11 +608,11 @@ class AddAgeFeature(MapTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
output_field: str,
|
||||
pred_length: int,
|
||||
log_scale: bool = True,
|
||||
self,
|
||||
target_field: str,
|
||||
output_field: str,
|
||||
pred_length: int,
|
||||
log_scale: bool = True,
|
||||
) -> None:
|
||||
self.pred_length = pred_length
|
||||
self.target_field = target_field
|
||||
@@ -681,17 +681,17 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
is_pad_field: str,
|
||||
start_field: str,
|
||||
forecast_start_field: str,
|
||||
train_sampler: InstanceSampler,
|
||||
past_length: int,
|
||||
future_length: int,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: Optional[List[str]] = None,
|
||||
pick_incomplete: bool = True,
|
||||
self,
|
||||
target_field: str,
|
||||
is_pad_field: str,
|
||||
start_field: str,
|
||||
forecast_start_field: str,
|
||||
train_sampler: InstanceSampler,
|
||||
past_length: int,
|
||||
future_length: int,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: Optional[List[str]] = None,
|
||||
pick_incomplete: bool = True,
|
||||
) -> None:
|
||||
|
||||
assert future_length > 0
|
||||
@@ -746,7 +746,7 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
for ts_field in slice_cols:
|
||||
if i > self.past_length:
|
||||
# truncate to past_length
|
||||
past_piece = d[ts_field][..., i - self.past_length : i]
|
||||
past_piece = d[ts_field][..., i - self.past_length: i]
|
||||
elif i < self.past_length:
|
||||
pad_block = np.zeros(
|
||||
d[ts_field].shape[:-1] + (pad_length,), dtype=d[ts_field].dtype
|
||||
@@ -757,7 +757,7 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
else:
|
||||
past_piece = d[ts_field][..., :i]
|
||||
d[self._past(ts_field)] = past_piece
|
||||
d[self._future(ts_field)] = d[ts_field][..., i : i + pl]
|
||||
d[self._future(ts_field)] = d[ts_field][..., i: i + pl]
|
||||
del d[ts_field]
|
||||
pad_indicator = np.zeros(self.past_length)
|
||||
if pad_length > 0:
|
||||
@@ -824,19 +824,19 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
is_pad_field: str,
|
||||
start_field: str,
|
||||
forecast_start_field: str,
|
||||
instance_sampler: InstanceSampler,
|
||||
instance_length: int,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: List[str] = [],
|
||||
allow_target_padding: bool = False,
|
||||
pad_value: float = 0.0,
|
||||
use_prediction_features: bool = False,
|
||||
prediction_length: Optional[int] = None,
|
||||
self,
|
||||
target_field: str,
|
||||
is_pad_field: str,
|
||||
start_field: str,
|
||||
forecast_start_field: str,
|
||||
instance_sampler: InstanceSampler,
|
||||
instance_length: int,
|
||||
output_NTC: bool = True,
|
||||
time_series_fields: List[str] = [],
|
||||
allow_target_padding: bool = False,
|
||||
pad_value: float = 0.0,
|
||||
use_prediction_features: bool = False,
|
||||
prediction_length: Optional[int] = None,
|
||||
) -> None:
|
||||
self.instance_sampler = instance_sampler
|
||||
self.instance_length = instance_length
|
||||
@@ -850,7 +850,7 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
self.forecast_start_field = forecast_start_field
|
||||
|
||||
assert (
|
||||
not use_prediction_features or prediction_length is not None
|
||||
not use_prediction_features or prediction_length is not None
|
||||
), "You must specify `prediction_length` if `use_prediction_features`"
|
||||
|
||||
self.use_prediction_features = use_prediction_features
|
||||
@@ -908,14 +908,14 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
)
|
||||
past_ts = np.concatenate([pad_pre, full_ts[..., :i]], axis=-1)
|
||||
else:
|
||||
past_ts = full_ts[..., (i - self.instance_length) : i]
|
||||
past_ts = full_ts[..., (i - self.instance_length): i]
|
||||
|
||||
past_ts = past_ts.transpose() if self.output_NTC else past_ts
|
||||
d[self._past(ts_field)] = past_ts
|
||||
|
||||
if self.use_prediction_features and not is_train:
|
||||
if not ts_field == self.target_field:
|
||||
future_ts = full_ts[..., i : i + self.prediction_length]
|
||||
future_ts = full_ts[..., i: i + self.prediction_length]
|
||||
future_ts = (
|
||||
future_ts.transpose() if self.output_NTC else future_ts
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user