formatting

This commit is contained in:
Kashif Rasul
2019-07-15 14:15:12 +02:00
parent 56d90d0e6c
commit 873d8528ae
5 changed files with 75 additions and 72 deletions
+4 -2
View File
@@ -1,3 +1,5 @@
from typing import Union, List
import numpy as np
from abc import ABC, abstractmethod
@@ -6,7 +8,7 @@ from .stat import ScaleHistogram
class InstanceSampler(ABC):
@abstractmethod
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
def __call__(self, ts: np.ndarray, a: int, b: int) -> Union[np.ndarray, List[int]]:
pass
@@ -90,7 +92,7 @@ class BucketInstanceSampler(InstanceSampler):
self.scale_histogram = scale_histogram
self.lookup = np.arange(2 ** 13)
def __call__(self, ts: np.ndarray, a: int, b: int) -> None:
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
while ts.shape[-1] >= len(self.lookup):
self.lookup = np.arange(2 * len(self.lookup))
p = 1.0 / self.scale_histogram.count(ts)
+2
View File
@@ -1,5 +1,7 @@
from collections import defaultdict
from typing import Optional
import math
import numpy as np
+1 -1
View File
@@ -19,7 +19,7 @@ class TransformedDataset(Dataset):
"""
def __init__(
self, base_dataset: Dataset, transformations: List[Transformation]
self, base_dataset: Dataset, transformations: List[Transformation]
) -> None:
self.base_dataset = base_dataset
self.transformations = Chain(transformations)
+2 -3
View File
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod
class TimeFeature(ABC):
def __init__(self, normalized: bool = True):
@@ -95,4 +95,3 @@ class WeekOfYear(TimeFeature):
return index.weekofyear / 51.0 - 0.5
else:
return index.weekofyear.map(float)
+66 -66
View File
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from functools import lru_cache, reduce
from collections import Counter
from functools import lru_cache, reduce
from typing import Iterator, List, Callable, Any, Optional, Dict, Tuple
import numpy as np
@@ -27,7 +27,7 @@ def shift_timestamp(ts: pd.Timestamp, offset: int) -> pd.Timestamp:
def target_transformation_length(
target: np.array, pred_length: int, is_train: bool
target: np.array, pred_length: int, is_train: bool
) -> int:
return target.shape[-1] + (0 if is_train else pred_length)
@@ -35,7 +35,7 @@ def target_transformation_length(
class Transformation(ABC):
@abstractmethod
def __call__(
self, data_it: Iterator[DataEntry], is_train: bool
self, data_it: Iterator[DataEntry], is_train: bool
) -> Iterator[DataEntry]:
pass
@@ -52,7 +52,7 @@ class Chain(Transformation):
self.trans = trans
def __call__(
self, data_it: Iterator[DataEntry], is_train: bool
self, data_it: Iterator[DataEntry], is_train: bool
) -> Iterator[DataEntry]:
tmp = data_it
for t in self.trans:
@@ -65,7 +65,7 @@ class Chain(Transformation):
class Identity(Transformation):
def __call__(
self, data_it: Iterator[DataEntry], is_train: bool
self, data_it: Iterator[DataEntry], is_train: bool
) -> Iterator[DataEntry]:
return data_it
@@ -190,7 +190,7 @@ class SetFieldIfNotPresent(SimpleTransformation):
Sets a field in the dictionary to the given value, if the field is not already present.
Parameters
----------
output_field
field
Name of the field that will be set
value
Value to be set
@@ -219,7 +219,7 @@ class AsNumpyArray(SimpleTransformation):
"""
def __init__(
self, field: str, expected_ndim: int, dtype: np.dtype = np.float32
self, field: str, expected_ndim: int, dtype: np.dtype = np.float32
) -> None:
self.field = field
self.expected_ndim = expected_ndim
@@ -286,7 +286,7 @@ class VstackFeatures(SimpleTransformation):
"""
def __init__(
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
) -> None:
self.output_field = output_field
self.input_fields = input_fields
@@ -320,7 +320,7 @@ class ConcatFeatures(SimpleTransformation):
"""
def __init__(
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
) -> None:
self.output_field = output_field
self.input_fields = input_fields
@@ -385,7 +385,7 @@ class ListFeatures(SimpleTransformation):
"""
def __init__(
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
) -> None:
self.output_field = output_field
self.input_fields = input_fields
@@ -423,11 +423,11 @@ class AddObservedValuesIndicator(SimpleTransformation):
"""
def __init__(
self,
target_field: str,
output_field: str,
dummy_value: int = 0,
convert_nans: bool = True,
self,
target_field: str,
output_field: str,
dummy_value: int = 0,
convert_nans: bool = True,
) -> None:
self.dummy_value = dummy_value
self.target_field = target_field
@@ -496,12 +496,12 @@ class AddConstFeature(MapTransformation):
"""
def __init__(
self,
output_field: str,
target_field: str,
pred_length: int,
const: float = 1.0,
dtype: np.dtype = np.float32,
self,
output_field: str,
target_field: str,
pred_length: int,
const: float = 1.0,
dtype: np.dtype = np.float32,
) -> None:
self.pred_length = pred_length
self.const = const
@@ -539,22 +539,22 @@ class AddTimeFeatures(MapTransformation):
"""
def __init__(
self,
start_field: str,
target_field: str,
output_field: str,
time_features: List[TimeFeature],
pred_length: int,
self,
start_field: str,
target_field: str,
output_field: str,
time_features: List[TimeFeature],
pred_length: int,
) -> None:
self.date_features = time_features
self.pred_length = pred_length
self.start_field = start_field
self.target_field = target_field
self.output_field = output_field
self._min_time_point: pd.Timestamp = None
self._max_time_point: pd.Timestamp = None
self._full_range_date_features: np.ndarray = None
self._date_index: pd.DatetimeIndex = None
self._min_time_point: Optional[pd.Timestamp] = None
self._max_time_point: Optional[pd.Timestamp] = None
self._full_range_date_features: Optional[np.ndarray] = None
self._date_index: Optional[pd.DatetimeIndex] = None
def _update_cache(self, start: pd.Timestamp, length: int) -> None:
end = shift_timestamp(start, length)
@@ -583,7 +583,7 @@ class AddTimeFeatures(MapTransformation):
)
self._update_cache(start, length)
i0 = self._date_index[start]
features = self._full_range_date_features[..., i0 : i0 + length]
features = self._full_range_date_features[..., i0:i0 + length]
data[self.output_field] = features
return data
@@ -608,11 +608,11 @@ class AddAgeFeature(MapTransformation):
"""
def __init__(
self,
target_field: str,
output_field: str,
pred_length: int,
log_scale: bool = True,
self,
target_field: str,
output_field: str,
pred_length: int,
log_scale: bool = True,
) -> None:
self.pred_length = pred_length
self.target_field = target_field
@@ -681,17 +681,17 @@ class InstanceSplitter(FlatMapTransformation):
"""
def __init__(
self,
target_field: str,
is_pad_field: str,
start_field: str,
forecast_start_field: str,
train_sampler: InstanceSampler,
past_length: int,
future_length: int,
output_NTC: bool = True,
time_series_fields: Optional[List[str]] = None,
pick_incomplete: bool = True,
self,
target_field: str,
is_pad_field: str,
start_field: str,
forecast_start_field: str,
train_sampler: InstanceSampler,
past_length: int,
future_length: int,
output_NTC: bool = True,
time_series_fields: Optional[List[str]] = None,
pick_incomplete: bool = True,
) -> None:
assert future_length > 0
@@ -746,7 +746,7 @@ class InstanceSplitter(FlatMapTransformation):
for ts_field in slice_cols:
if i > self.past_length:
# truncate to past_length
past_piece = d[ts_field][..., i - self.past_length : i]
past_piece = d[ts_field][..., i - self.past_length: i]
elif i < self.past_length:
pad_block = np.zeros(
d[ts_field].shape[:-1] + (pad_length,), dtype=d[ts_field].dtype
@@ -757,7 +757,7 @@ class InstanceSplitter(FlatMapTransformation):
else:
past_piece = d[ts_field][..., :i]
d[self._past(ts_field)] = past_piece
d[self._future(ts_field)] = d[ts_field][..., i : i + pl]
d[self._future(ts_field)] = d[ts_field][..., i: i + pl]
del d[ts_field]
pad_indicator = np.zeros(self.past_length)
if pad_length > 0:
@@ -824,19 +824,19 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
"""
def __init__(
self,
target_field: str,
is_pad_field: str,
start_field: str,
forecast_start_field: str,
instance_sampler: InstanceSampler,
instance_length: int,
output_NTC: bool = True,
time_series_fields: List[str] = [],
allow_target_padding: bool = False,
pad_value: float = 0.0,
use_prediction_features: bool = False,
prediction_length: Optional[int] = None,
self,
target_field: str,
is_pad_field: str,
start_field: str,
forecast_start_field: str,
instance_sampler: InstanceSampler,
instance_length: int,
output_NTC: bool = True,
time_series_fields: List[str] = [],
allow_target_padding: bool = False,
pad_value: float = 0.0,
use_prediction_features: bool = False,
prediction_length: Optional[int] = None,
) -> None:
self.instance_sampler = instance_sampler
self.instance_length = instance_length
@@ -850,7 +850,7 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
self.forecast_start_field = forecast_start_field
assert (
not use_prediction_features or prediction_length is not None
not use_prediction_features or prediction_length is not None
), "You must specify `prediction_length` if `use_prediction_features`"
self.use_prediction_features = use_prediction_features
@@ -908,14 +908,14 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
)
past_ts = np.concatenate([pad_pre, full_ts[..., :i]], axis=-1)
else:
past_ts = full_ts[..., (i - self.instance_length) : i]
past_ts = full_ts[..., (i - self.instance_length): i]
past_ts = past_ts.transpose() if self.output_NTC else past_ts
d[self._past(ts_field)] = past_ts
if self.use_prediction_features and not is_train:
if not ts_field == self.target_field:
future_ts = full_ts[..., i : i + self.prediction_length]
future_ts = full_ts[..., i: i + self.prediction_length]
future_ts = (
future_ts.transpose() if self.output_NTC else future_ts
)