diff --git a/pts/dataset/sampler.py b/pts/dataset/sampler.py index dc3f345..14716de 100644 --- a/pts/dataset/sampler.py +++ b/pts/dataset/sampler.py @@ -1,3 +1,5 @@ +from typing import Union, List + import numpy as np from abc import ABC, abstractmethod @@ -6,7 +8,7 @@ from .stat import ScaleHistogram class InstanceSampler(ABC): @abstractmethod - def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray: + def __call__(self, ts: np.ndarray, a: int, b: int) -> Union[np.ndarray, List[int]]: pass @@ -90,7 +92,7 @@ class BucketInstanceSampler(InstanceSampler): self.scale_histogram = scale_histogram self.lookup = np.arange(2 ** 13) - def __call__(self, ts: np.ndarray, a: int, b: int) -> None: + def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray: while ts.shape[-1] >= len(self.lookup): self.lookup = np.arange(2 * len(self.lookup)) p = 1.0 / self.scale_histogram.count(ts) diff --git a/pts/dataset/stat.py b/pts/dataset/stat.py index ce31352..1bdcf9d 100644 --- a/pts/dataset/stat.py +++ b/pts/dataset/stat.py @@ -1,5 +1,7 @@ from collections import defaultdict from typing import Optional +import math + import numpy as np diff --git a/pts/dataset/transformed_dataset.py b/pts/dataset/transformed_dataset.py index 2cb6501..69a3523 100644 --- a/pts/dataset/transformed_dataset.py +++ b/pts/dataset/transformed_dataset.py @@ -19,7 +19,7 @@ class TransformedDataset(Dataset): """ def __init__( - self, base_dataset: Dataset, transformations: List[Transformation] + self, base_dataset: Dataset, transformations: List[Transformation] ) -> None: self.base_dataset = base_dataset self.transformations = Chain(transformations) diff --git a/pts/feature/time_feature.py b/pts/feature/time_feature.py index 25dbaf0..2211f7a 100644 --- a/pts/feature/time_feature.py +++ b/pts/feature/time_feature.py @@ -1,8 +1,8 @@ +from abc import ABC, abstractmethod + import numpy as np import pandas as pd -from abc import ABC, abstractmethod - class TimeFeature(ABC): def __init__(self, normalized: bool = True): @@ -95,4 +95,3 @@ class WeekOfYear(TimeFeature): return index.weekofyear / 51.0 - 0.5 else: return index.weekofyear.map(float) - diff --git a/pts/feature/transform.py b/pts/feature/transform.py index 99e1a32..5220f86 100644 --- a/pts/feature/transform.py +++ b/pts/feature/transform.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from functools import lru_cache, reduce from collections import Counter +from functools import lru_cache, reduce from typing import Iterator, List, Callable, Any, Optional, Dict, Tuple import numpy as np @@ -27,7 +27,7 @@ def shift_timestamp(ts: pd.Timestamp, offset: int) -> pd.Timestamp: def target_transformation_length( - target: np.array, pred_length: int, is_train: bool + target: np.array, pred_length: int, is_train: bool ) -> int: return target.shape[-1] + (0 if is_train else pred_length) @@ -35,7 +35,7 @@ def target_transformation_length( class Transformation(ABC): @abstractmethod def __call__( - self, data_it: Iterator[DataEntry], is_train: bool + self, data_it: Iterator[DataEntry], is_train: bool ) -> Iterator[DataEntry]: pass @@ -52,7 +52,7 @@ class Chain(Transformation): self.trans = trans def __call__( - self, data_it: Iterator[DataEntry], is_train: bool + self, data_it: Iterator[DataEntry], is_train: bool ) -> Iterator[DataEntry]: tmp = data_it for t in self.trans: @@ -65,7 +65,7 @@ class Chain(Transformation): class Identity(Transformation): def __call__( - self, data_it: Iterator[DataEntry], is_train: bool + self, data_it: Iterator[DataEntry], is_train: bool ) -> Iterator[DataEntry]: return data_it @@ -190,7 +190,7 @@ class SetFieldIfNotPresent(SimpleTransformation): Sets a field in the dictionary with the given value, in case it does not exist already Parameters ---------- - output_field + field Name of the field that will be set value Value to be set @@ -219,7 +219,7 @@ class AsNumpyArray(SimpleTransformation): """ def __init__( - self, field: str, expected_ndim: int, dtype: np.dtype = np.float32 + self, field: str, expected_ndim: int, dtype: np.dtype = np.float32 ) -> None: self.field = field self.expected_ndim = expected_ndim @@ -286,7 +286,7 @@ class VstackFeatures(SimpleTransformation): """ def __init__( - self, output_field: str, input_fields: List[str], drop_inputs: bool = True + self, output_field: str, input_fields: List[str], drop_inputs: bool = True ) -> None: self.output_field = output_field self.input_fields = input_fields @@ -320,7 +320,7 @@ class ConcatFeatures(SimpleTransformation): """ def __init__( - self, output_field: str, input_fields: List[str], drop_inputs: bool = True + self, output_field: str, input_fields: List[str], drop_inputs: bool = True ) -> None: self.output_field = output_field self.input_fields = input_fields @@ -385,7 +385,7 @@ class ListFeatures(SimpleTransformation): """ def __init__( - self, output_field: str, input_fields: List[str], drop_inputs: bool = True + self, output_field: str, input_fields: List[str], drop_inputs: bool = True ) -> None: self.output_field = output_field self.input_fields = input_fields @@ -423,11 +423,11 @@ class AddObservedValuesIndicator(SimpleTransformation): """ def __init__( - self, - target_field: str, - output_field: str, - dummy_value: int = 0, - convert_nans: bool = True, + self, + target_field: str, + output_field: str, + dummy_value: int = 0, + convert_nans: bool = True, ) -> None: self.dummy_value = dummy_value self.target_field = target_field @@ -496,12 +496,12 @@ class AddConstFeature(MapTransformation): """ def __init__( - self, - output_field: str, - target_field: str, - pred_length: int, - const: float = 1.0, - dtype: np.dtype = np.float32, + self, + output_field: str, + target_field: str, + pred_length: int, + const: float = 1.0, + dtype: np.dtype = np.float32, ) -> None: self.pred_length = pred_length self.const = const @@ -539,22 +539,22 @@ class AddTimeFeatures(MapTransformation): """ def __init__( - self, - start_field: str, - target_field: str, - output_field: str, - time_features: List[TimeFeature], - pred_length: int, + self, + start_field: str, + target_field: str, + output_field: str, + time_features: List[TimeFeature], + pred_length: int, ) -> None: self.date_features = time_features self.pred_length = pred_length self.start_field = start_field self.target_field = target_field self.output_field = output_field - self._min_time_point: pd.Timestamp = None - self._max_time_point: pd.Timestamp = None - self._full_range_date_features: np.ndarray = None - self._date_index: pd.DatetimeIndex = None + self._min_time_point: Optional[pd.Timestamp] = None + self._max_time_point: Optional[pd.Timestamp] = None + self._full_range_date_features: Optional[np.ndarray] = None + self._date_index: Optional[pd.DatetimeIndex] = None def _update_cache(self, start: pd.Timestamp, length: int) -> None: end = shift_timestamp(start, length) @@ -583,7 +583,7 @@ class AddTimeFeatures(MapTransformation): ) self._update_cache(start, length) i0 = self._date_index[start] - features = self._full_range_date_features[..., i0 : i0 + length] + features = self._full_range_date_features[..., i0:i0 + length] data[self.output_field] = features return data @@ -608,11 +608,11 @@ class AddAgeFeature(MapTransformation): """ def __init__( - self, - target_field: str, - output_field: str, - pred_length: int, - log_scale: bool = True, + self, + target_field: str, + output_field: str, + pred_length: int, + log_scale: bool = True, ) -> None: self.pred_length = pred_length self.target_field = target_field @@ -681,17 +681,17 @@ class InstanceSplitter(FlatMapTransformation): """ def __init__( - self, - target_field: str, - is_pad_field: str, - start_field: str, - forecast_start_field: str, - train_sampler: InstanceSampler, - past_length: int, - future_length: int, - output_NTC: bool = True, - time_series_fields: Optional[List[str]] = None, - pick_incomplete: bool = True, + self, + target_field: str, + is_pad_field: str, + start_field: str, + forecast_start_field: str, + train_sampler: InstanceSampler, + past_length: int, + future_length: int, + output_NTC: bool = True, + time_series_fields: Optional[List[str]] = None, + pick_incomplete: bool = True, ) -> None: assert future_length > 0 @@ -746,7 +746,7 @@ class InstanceSplitter(FlatMapTransformation): for ts_field in slice_cols: if i > self.past_length: # truncate to past_length - past_piece = d[ts_field][..., i - self.past_length : i] + past_piece = d[ts_field][..., i - self.past_length: i] elif i < self.past_length: pad_block = np.zeros( d[ts_field].shape[:-1] + (pad_length,), dtype=d[ts_field].dtype @@ -757,7 +757,7 @@ class InstanceSplitter(FlatMapTransformation): else: past_piece = d[ts_field][..., :i] d[self._past(ts_field)] = past_piece - d[self._future(ts_field)] = d[ts_field][..., i : i + pl] + d[self._future(ts_field)] = d[ts_field][..., i: i + pl] del d[ts_field] pad_indicator = np.zeros(self.past_length) if pad_length > 0: @@ -824,19 +824,19 @@ class CanonicalInstanceSplitter(FlatMapTransformation): """ def __init__( - self, - target_field: str, - is_pad_field: str, - start_field: str, - forecast_start_field: str, - instance_sampler: InstanceSampler, - instance_length: int, - output_NTC: bool = True, - time_series_fields: List[str] = [], - allow_target_padding: bool = False, - pad_value: float = 0.0, - use_prediction_features: bool = False, - prediction_length: Optional[int] = None, + self, + target_field: str, + is_pad_field: str, + start_field: str, + forecast_start_field: str, + instance_sampler: InstanceSampler, + instance_length: int, + output_NTC: bool = True, + time_series_fields: List[str] = [], + allow_target_padding: bool = False, + pad_value: float = 0.0, + use_prediction_features: bool = False, + prediction_length: Optional[int] = None, ) -> None: self.instance_sampler = instance_sampler self.instance_length = instance_length @@ -850,7 +850,7 @@ class CanonicalInstanceSplitter(FlatMapTransformation): self.forecast_start_field = forecast_start_field assert ( - not use_prediction_features or prediction_length is not None + not use_prediction_features or prediction_length is not None ), "You must specify `prediction_length` if `use_prediction_features`" self.use_prediction_features = use_prediction_features @@ -908,14 +908,14 @@ class CanonicalInstanceSplitter(FlatMapTransformation): ) past_ts = np.concatenate([pad_pre, full_ts[..., :i]], axis=-1) else: - past_ts = full_ts[..., (i - self.instance_length) : i] + past_ts = full_ts[..., (i - self.instance_length): i] past_ts = past_ts.transpose() if self.output_NTC else past_ts d[self._past(ts_field)] = past_ts if self.use_prediction_features and not is_train: if not ts_field == self.target_field: - future_ts = full_ts[..., i : i + self.prediction_length] + future_ts = full_ts[..., i: i + self.prediction_length] future_ts = ( future_ts.transpose() if self.output_NTC else future_ts )