mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-07-03 04:54:08 +08:00
formatting fixes and cleanups
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from pts.dataset.common import DataEntry
|
||||
from pts.dataset.common import DataEntry, FieldName
|
||||
from pts.dataset.list_dataset import ListDataset
|
||||
from pts.dataset.sampler import (
|
||||
UniformSplitSampler,
|
||||
@@ -8,3 +8,4 @@ from pts.dataset.sampler import (
|
||||
)
|
||||
from pts.dataset.sampler import InstanceSampler, UniformSplitSampler, TestSplitSampler, ExpectedNumInstanceSampler, BucketInstanceSampler
|
||||
from pts.dataset.loader import DataLoader, TrainDataLoader, InferenceDataLoader
|
||||
from pts.dataset.utils import to_pandas
|
||||
+25
-1
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from typing import Any, Dict, Sized, Iterable, NamedTuple
|
||||
|
||||
|
||||
DataEntry = Dict[str, Any]
|
||||
|
||||
|
||||
@@ -11,6 +10,31 @@ class SourceContext(NamedTuple):
|
||||
row: int
|
||||
|
||||
|
||||
class FieldName:
    """
    A bundle of default field names to be used by clients when instantiating
    transformer instances.
    """

    # Every attribute below is a plain string constant; nothing in this class
    # is computed or validated.
    ITEM_ID = "item_id"

    START = "start"
    TARGET = "target"

    # Static/dynamic and categorical/real feature names.
    FEAT_STATIC_CAT = "feat_static_cat"
    FEAT_STATIC_REAL = "feat_static_real"
    FEAT_DYNAMIC_CAT = "feat_dynamic_cat"
    FEAT_DYNAMIC_REAL = "feat_dynamic_real"

    # NOTE: unlike the other FEAT_* constants, the value of FEAT_TIME is
    # "time_feat" (words swapped), not "feat_time".
    FEAT_TIME = "time_feat"
    FEAT_CONST = "feat_dynamic_const"
    FEAT_AGE = "feat_dynamic_age"

    # presumably names of fields produced by transformations rather than
    # supplied in the raw data -- TODO confirm against the transform module
    OBSERVED_VALUES = "observed_values"
    IS_PAD = "is_pad"
    FORECAST_START = "forecast_start"
|
||||
|
||||
|
||||
class Dataset(Sized, Iterable[DataEntry], ABC):
|
||||
@abstractmethod
|
||||
def __iter__(self) -> Iterable[DataEntry]:
|
||||
|
||||
@@ -5,9 +5,10 @@ from .process import ProcessDataEntry
|
||||
|
||||
|
||||
class ListDataset(Dataset):
|
||||
def __init__(
|
||||
self, data_iter: Iterable[DataEntry], freq: str, one_dim_target: bool = True
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
data_iter: Iterable[DataEntry],
|
||||
freq: str,
|
||||
one_dim_target: bool = True) -> None:
|
||||
process = ProcessDataEntry(freq, one_dim_target)
|
||||
self.list_data = [process(data) for data in data_iter]
|
||||
|
||||
|
||||
+38
-27
@@ -3,7 +3,7 @@ from typing import Callable, List, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.tseries.frequencies import to_offset
|
||||
from pandas.tseries.offsets import Tick
|
||||
|
||||
from .common import DataEntry
|
||||
|
||||
@@ -17,7 +17,8 @@ class ProcessStartField:
|
||||
try:
|
||||
value = ProcessStartField.process(data[self.name], self.freq)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise Exception(f'Error "{e}" occurred when reading field "{self.name}"')
|
||||
raise Exception(
|
||||
f'Error "{e}" occurred when reading field "{self.name}"')
|
||||
|
||||
data[self.name] = value
|
||||
|
||||
@@ -27,21 +28,26 @@ class ProcessStartField:
|
||||
@lru_cache(maxsize=10000)
|
||||
def process(string: str, freq: str) -> pd.Timestamp:
|
||||
timestamp = pd.Timestamp(string, freq=freq)
|
||||
# 'W-SUN' is the standardized freqstr for W
|
||||
if timestamp.freq.name in ("M", "W-SUN"):
|
||||
offset = to_offset(freq)
|
||||
timestamp = timestamp.replace(
|
||||
hour=0, minute=0, second=0, microsecond=0, nanosecond=0
|
||||
)
|
||||
return pd.Timestamp(offset.rollback(timestamp), freq=offset.freqstr)
|
||||
if timestamp.freq == "B":
|
||||
# does not floor on business day as it is not allowed
|
||||
return timestamp
|
||||
return pd.Timestamp(timestamp.floor(timestamp.freq), freq=timestamp.freq)
|
||||
|
||||
# operate on time information (days, hours, minute, second)
|
||||
if isinstance(timestamp.freq, Tick):
|
||||
return pd.Timestamp(timestamp.floor(timestamp.freq),
|
||||
timestamp.freq)
|
||||
|
||||
# since we are only interested in the data piece, we normalize the
|
||||
# time information
|
||||
timestamp = timestamp.replace(hour=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
nanosecond=0)
|
||||
|
||||
return timestamp.freq.rollforward(timestamp)
|
||||
|
||||
|
||||
class ProcessTimeSeriesField:
|
||||
def __init__(self, name, is_required: bool, is_static: bool, is_cat: bool) -> None:
|
||||
def __init__(self, name, is_required: bool, is_static: bool,
|
||||
is_cat: bool) -> None:
|
||||
self.name = name
|
||||
self.is_required = is_required
|
||||
self.req_ndim = 1 if is_static else 2
|
||||
@@ -65,7 +71,8 @@ class ProcessTimeSeriesField:
|
||||
elif not self.is_required:
|
||||
return data
|
||||
else:
|
||||
raise Exception(f"JSON object is missing a required field `{self.name}`")
|
||||
raise Exception(
|
||||
f"JSON object is missing a required field `{self.name}`")
|
||||
|
||||
|
||||
class ProcessDataEntry:
|
||||
@@ -74,24 +81,28 @@ class ProcessDataEntry:
|
||||
List[Callable[[DataEntry], DataEntry]],
|
||||
[
|
||||
ProcessStartField("start", freq=freq),
|
||||
ProcessTimeSeriesField(
|
||||
"target", is_required=True, is_cat=False, is_static=one_dim_target
|
||||
),
|
||||
ProcessTimeSeriesField(
|
||||
"feat_dynamic_cat", is_required=False, is_cat=True, is_static=False
|
||||
),
|
||||
ProcessTimeSeriesField("target",
|
||||
is_required=True,
|
||||
is_cat=False,
|
||||
is_static=one_dim_target),
|
||||
ProcessTimeSeriesField("feat_dynamic_cat",
|
||||
is_required=False,
|
||||
is_cat=True,
|
||||
is_static=False),
|
||||
ProcessTimeSeriesField(
|
||||
"feat_dynamic_real",
|
||||
is_required=False,
|
||||
is_cat=False,
|
||||
is_static=False,
|
||||
),
|
||||
ProcessTimeSeriesField(
|
||||
"feat_static_cat", is_required=False, is_cat=True, is_static=True
|
||||
),
|
||||
ProcessTimeSeriesField(
|
||||
"feat_static_real", is_required=False, is_cat=False, is_static=True
|
||||
),
|
||||
ProcessTimeSeriesField("feat_static_cat",
|
||||
is_required=False,
|
||||
is_cat=True,
|
||||
is_static=True),
|
||||
ProcessTimeSeriesField("feat_static_real",
|
||||
is_required=False,
|
||||
is_cat=False,
|
||||
is_static=True),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
+8
-11
@@ -8,7 +8,8 @@ from .stat import ScaleHistogram
|
||||
|
||||
class InstanceSampler(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> Union[np.ndarray, List[int]]:
|
||||
def __call__(self, ts: np.ndarray, a: int,
|
||||
b: int) -> Union[np.ndarray, List[int]]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -20,17 +21,16 @@ class UniformSplitSampler(InstanceSampler):
|
||||
p
|
||||
Probability of selecting a time point
|
||||
"""
|
||||
|
||||
def __init__(self, p: float = 1.0 / 20.0) -> None:
|
||||
self.p = p
|
||||
self.lookup = np.arange(2 ** 13)
|
||||
self.lookup = np.arange(2**13)
|
||||
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
|
||||
assert a <= b
|
||||
while ts.shape[-1] >= len(self.lookup):
|
||||
self.lookup = np.arange(2 * len(self.lookup))
|
||||
mask = np.random.uniform(low=0.0, high=1.0, size=b - a + 1) < self.p
|
||||
return self.lookup[a: a + len(mask)][mask]
|
||||
return self.lookup[a:a + len(mask)][mask]
|
||||
|
||||
|
||||
class TestSplitSampler(InstanceSampler):
|
||||
@@ -38,7 +38,6 @@ class TestSplitSampler(InstanceSampler):
|
||||
Sampler used for prediction. Always selects the last time point for
|
||||
splitting i.e. the forecast point for the time series.
|
||||
"""
|
||||
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
|
||||
return np.array([b])
|
||||
|
||||
@@ -53,12 +52,11 @@ class ExpectedNumInstanceSampler(InstanceSampler):
|
||||
num_instances
|
||||
number of training examples generated per time series on average
|
||||
"""
|
||||
|
||||
def __init__(self, num_instances: float) -> None:
|
||||
self.num_instances = num_instances
|
||||
self.avg_length = 0.0
|
||||
self.n = 0.0
|
||||
self.lookup = np.arange(2 ** 13)
|
||||
self.lookup = np.arange(2**13)
|
||||
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
|
||||
while ts.shape[-1] >= len(self.lookup):
|
||||
@@ -69,7 +67,7 @@ class ExpectedNumInstanceSampler(InstanceSampler):
|
||||
p = self.num_instances / self.avg_length
|
||||
|
||||
mask = np.random.uniform(low=0.0, high=1.0, size=b - a + 1) < p
|
||||
indices = self.lookup[a: a + len(mask)][mask]
|
||||
indices = self.lookup[a:a + len(mask)][mask]
|
||||
return indices
|
||||
|
||||
|
||||
@@ -85,17 +83,16 @@ class BucketInstanceSampler(InstanceSampler):
|
||||
The histogram of scale for the time series. Here scale is the mean abs
|
||||
value of the time series.
|
||||
"""
|
||||
|
||||
def __init__(self, scale_histogram: ScaleHistogram) -> None:
|
||||
# probability of sampling a bucket i is the inverse of its number of
|
||||
# elements
|
||||
self.scale_histogram = scale_histogram
|
||||
self.lookup = np.arange(2 ** 13)
|
||||
self.lookup = np.arange(2**13)
|
||||
|
||||
def __call__(self, ts: np.ndarray, a: int, b: int) -> np.ndarray:
|
||||
while ts.shape[-1] >= len(self.lookup):
|
||||
self.lookup = np.arange(2 * len(self.lookup))
|
||||
p = 1.0 / self.scale_histogram.count(ts)
|
||||
mask = np.random.uniform(low=0.0, high=1.0, size=b - a + 1) < p
|
||||
indices = self.lookup[a: a + len(mask)][mask]
|
||||
indices = self.lookup[a:a + len(mask)][mask]
|
||||
return indices
|
||||
|
||||
@@ -17,10 +17,8 @@ class TransformedDataset(Dataset):
|
||||
transformations
|
||||
List of transformations to apply
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, base_dataset: Dataset, transformations: List[Transformation]
|
||||
) -> None:
|
||||
def __init__(self, base_dataset: Dataset,
|
||||
transformations: List[Transformation]) -> None:
|
||||
self.base_dataset = base_dataset
|
||||
self.transformations = Chain(transformations)
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def to_pandas(instance: dict, freq: str = None) -> pd.Series:
    """
    Transform a dictionary into a pandas.Series object, using its
    "start" and "target" fields.

    Parameters
    ----------
    instance
        Dictionary containing the time series data.
    freq
        Frequency to use in the pandas.Series index. When omitted (or
        empty), the ``freqstr`` attribute of ``instance["start"]`` is used.

    Returns
    -------
    pandas.Series
        Pandas time series object.

    Raises
    ------
    ValueError
        If ``freq`` is not given and the start value has no frequency
        attached.
    """
    target = instance["target"]
    start = instance["start"]
    if not freq:
        # The start value may have no ``freqstr`` attribute, or it may be
        # None. Fail here with an explicit message instead of an opaque
        # AttributeError or a confusing error from ``pd.date_range``.
        freq = getattr(start, "freqstr", None)
        if not freq:
            raise ValueError(
                'cannot infer the frequency: pass `freq` explicitly or use a '
                '"start" value that has a frequency attached')
    index = pd.date_range(start=start, periods=len(target), freq=freq)
    return pd.Series(target, index=index)
|
||||
@@ -6,6 +6,7 @@ from pts.feature.time_feature import (
|
||||
DayOfYear,
|
||||
MonthOfYear,
|
||||
WeekOfYear,
|
||||
time_features_from_frequency_str,
|
||||
)
|
||||
|
||||
from pts.feature.transform import (
|
||||
@@ -35,4 +36,4 @@ from pts.feature.transform import (
|
||||
SelectFields,
|
||||
)
|
||||
|
||||
from pts.feature.lag import time_features_from_frequency_str, get_lags_for_frequency
|
||||
from pts.feature.lag import get_lags_for_frequency
|
||||
|
||||
+49
-105
@@ -12,82 +12,13 @@
|
||||
# permissions and limitations under the License.
|
||||
|
||||
# Standard library imports
|
||||
import re
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
# Third-party imports
|
||||
import numpy as np
|
||||
from pandas.tseries.frequencies import to_offset
|
||||
|
||||
# First-party imports
|
||||
from .time_feature import (
|
||||
DayOfMonth,
|
||||
DayOfWeek,
|
||||
DayOfYear,
|
||||
HourOfDay,
|
||||
MinuteOfHour,
|
||||
MonthOfYear,
|
||||
TimeFeature,
|
||||
WeekOfYear,
|
||||
)
|
||||
|
||||
|
||||
def get_granularity(freq_str: str) -> Tuple[int, str]:
    """
    Splits a frequency string such as "7D" into the multiple 7 and the base
    granularity "D".

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    Tuple[int, str]
        The multiple (``1`` when the string has no leading number) and the
        granularity.

    Raises
    ------
    AssertionError
        If ``freq_str`` cannot be parsed.
    """
    freq_regex = r"\s*((\d+)?)\s*([^\d]\w*)"
    m = re.match(freq_regex, freq_str)
    if m is None:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError("Cannot parse frequency string: %s" % freq_str)
    # The outer optional group is unused; the inner one is None when the
    # string has no leading digits.
    _, multiple, granularity = m.groups()
    return (int(multiple) if multiple is not None else 1), granularity
|
||||
|
||||
|
||||
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    List[TimeFeature]
        One default-constructed feature instance per feature class selected
        for the granularity of ``freq_str``.

    Raises
    ------
    RuntimeError
        If the granularity of ``freq_str`` is not one of the supported ones.
    """
    # Only the base granularity matters here; the multiple is ignored.
    _, granularity = get_granularity(freq_str)
    if granularity == "M":
        feature_classes = [MonthOfYear]
    elif granularity == "W":
        feature_classes = [DayOfMonth, WeekOfYear]
    elif granularity in ["D", "B"]:
        feature_classes = [DayOfWeek, DayOfMonth, DayOfYear]
    elif granularity == "H":
        feature_classes = [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear]
    elif granularity in ["min", "T"]:
        feature_classes = [
            MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear
        ]
    else:
        # The list below mirrors every granularity accepted by the branches
        # above, including "B" and the "T" alias.
        supported_freq_msg = f"""
        Unsupported frequency {freq_str}

        The following frequencies are supported:

            M   - monthly
            W   - week
            D   - daily
            B   - business day
            H   - hourly
            min - minutely (alias: T)
        """
        raise RuntimeError(supported_freq_msg)

    return [cls() for cls in feature_classes]
|
||||
from .utils import get_granularity
|
||||
|
||||
|
||||
def _make_lags(middle: int, delta: int) -> np.ndarray:
|
||||
@@ -97,9 +28,9 @@ def _make_lags(middle: int, delta: int) -> np.ndarray:
|
||||
return np.arange(middle - delta, middle + delta + 1).tolist()
|
||||
|
||||
|
||||
def get_lags_for_frequency(
|
||||
freq_str: str, lag_ub: int = 1200, num_lags: Optional[int] = None
|
||||
) -> List[int]:
|
||||
def get_lags_for_frequency(freq_str: str,
|
||||
lag_ub: int = 1200,
|
||||
num_lags: Optional[int] = None) -> List[int]:
|
||||
"""
|
||||
Generates a list of lags that that are appropriate for the given frequency string.
|
||||
|
||||
@@ -125,57 +56,70 @@ def get_lags_for_frequency(
|
||||
# Lags are target values at the same `season` (+/- delta) but in the previous cycle.
|
||||
def _make_lags_for_minute(multiple, num_cycles=3):
|
||||
# We use previous ``num_cycles`` hours to generate lags
|
||||
return [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]
|
||||
return [
|
||||
_make_lags(k * 60 // multiple, 2)
|
||||
for k in range(1, num_cycles + 1)
|
||||
]
|
||||
|
||||
def _make_lags_for_hour(multiple, num_cycles=7):
|
||||
# We use previous ``num_cycles`` days to generate lags
|
||||
return [_make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)]
|
||||
return [
|
||||
_make_lags(k * 24 // multiple, 1)
|
||||
for k in range(1, num_cycles + 1)
|
||||
]
|
||||
|
||||
def _make_lags_for_day(multiple, num_cycles=4):
|
||||
# We use previous ``num_cycles`` weeks to generate lags
|
||||
# We use the last month (in addition to 4 weeks) to generate lag.
|
||||
return [_make_lags(k * 7 // multiple, 1) for k in range(1, num_cycles + 1)] + [
|
||||
_make_lags(30 // multiple, 1)
|
||||
]
|
||||
return [
|
||||
_make_lags(k * 7 // multiple, 1) for k in range(1, num_cycles + 1)
|
||||
] + [_make_lags(30 // multiple, 1)]
|
||||
|
||||
def _make_lags_for_week(multiple, num_cycles=3):
|
||||
# We use previous ``num_cycles`` years to generate lags
|
||||
# Additionally, we use previous 4, 8, 12 weeks
|
||||
return [_make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)] + [
|
||||
[4 // multiple, 8 // multiple, 12 // multiple]
|
||||
]
|
||||
return [
|
||||
_make_lags(k * 52 // multiple, 1)
|
||||
for k in range(1, num_cycles + 1)
|
||||
] + [[4 // multiple, 8 // multiple, 12 // multiple]]
|
||||
|
||||
def _make_lags_for_month(multiple, num_cycles=3):
|
||||
# We use previous ``num_cycles`` years to generate lags
|
||||
return [_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)]
|
||||
return [
|
||||
_make_lags(k * 12 // multiple, 1)
|
||||
for k in range(1, num_cycles + 1)
|
||||
]
|
||||
|
||||
if granularity == "M":
|
||||
lags = _make_lags_for_month(multiple)
|
||||
elif granularity == "W":
|
||||
lags = _make_lags_for_week(multiple)
|
||||
elif granularity == "D":
|
||||
lags = _make_lags_for_day(multiple) + _make_lags_for_week(multiple / 7.0)
|
||||
elif granularity == "B":
|
||||
# multiple, granularity = get_granularity(freq_str)
|
||||
offset = to_offset(freq_str)
|
||||
|
||||
if offset.name == "M":
|
||||
lags = _make_lags_for_month(offset.n)
|
||||
elif offset.name == "W-SUN":
|
||||
lags = _make_lags_for_week(offset.n)
|
||||
elif offset.name == "D":
|
||||
lags = _make_lags_for_day(offset.n) + _make_lags_for_week(
|
||||
offset.n / 7.0)
|
||||
elif offset.name == "B":
|
||||
# todo find good lags for business day
|
||||
lags = []
|
||||
elif granularity == "H":
|
||||
lags = (
|
||||
_make_lags_for_hour(multiple)
|
||||
+ _make_lags_for_day(multiple / 24.0)
|
||||
+ _make_lags_for_week(multiple / (24.0 * 7))
|
||||
)
|
||||
elif granularity == "min":
|
||||
lags = (
|
||||
_make_lags_for_minute(multiple)
|
||||
+ _make_lags_for_hour(multiple / 60.0)
|
||||
+ _make_lags_for_day(multiple / (60.0 * 24))
|
||||
+ _make_lags_for_week(multiple / (60.0 * 24 * 7))
|
||||
)
|
||||
elif offset.name == "H":
|
||||
lags = (_make_lags_for_hour(offset.n) +
|
||||
_make_lags_for_day(offset.n / 24.0) +
|
||||
_make_lags_for_week(offset.n / (24.0 * 7)))
|
||||
# minutes
|
||||
elif offset.name == "T":
|
||||
lags = (_make_lags_for_minute(offset.n) +
|
||||
_make_lags_for_hour(offset.n / 60.0) +
|
||||
_make_lags_for_day(offset.n / (60.0 * 24)) +
|
||||
_make_lags_for_week(offset.n / (60.0 * 24 * 7)))
|
||||
else:
|
||||
raise Exception("invalid frequency")
|
||||
|
||||
# flatten lags list and filter
|
||||
lags = [int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub]
|
||||
lags = [
|
||||
int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub
|
||||
]
|
||||
lags = [1, 2, 3, 4, 5, 6, 7] + sorted(list(set(lags)))
|
||||
|
||||
return lags[:num_lags]
|
||||
return lags[:num_lags]
|
||||
@@ -1,8 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .utils import get_granularity
|
||||
|
||||
class TimeFeature(ABC):
|
||||
def __init__(self, normalized: bool = True):
|
||||
@@ -17,7 +19,6 @@ class MinuteOfHour(TimeFeature):
|
||||
"""
|
||||
Minute of hour encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.minute / 59.0 - 0.5
|
||||
@@ -29,7 +30,6 @@ class HourOfDay(TimeFeature):
|
||||
"""
|
||||
Hour of day encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.hour / 23.0 - 0.5
|
||||
@@ -41,7 +41,6 @@ class DayOfWeek(TimeFeature):
|
||||
"""
|
||||
Hour of day encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.dayofweek / 6.0 - 0.5
|
||||
@@ -53,7 +52,6 @@ class DayOfMonth(TimeFeature):
|
||||
"""
|
||||
Day of month encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.day / 30.0 - 0.5
|
||||
@@ -65,7 +63,6 @@ class DayOfYear(TimeFeature):
|
||||
"""
|
||||
Day of year encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.dayofyear / 364.0 - 0.5
|
||||
@@ -77,7 +74,6 @@ class MonthOfYear(TimeFeature):
|
||||
"""
|
||||
Month of year encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.month / 11.0 - 0.5
|
||||
@@ -89,9 +85,49 @@ class WeekOfYear(TimeFeature):
|
||||
"""
|
||||
Week of year encoded as value between [-0.5, 0.5]
|
||||
"""
|
||||
|
||||
def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
|
||||
if self.normalized:
|
||||
return index.weekofyear / 51.0 - 0.5
|
||||
else:
|
||||
return index.weekofyear.map(float)
|
||||
|
||||
|
||||
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    List[TimeFeature]
        One default-constructed feature instance per feature class selected
        for the granularity of ``freq_str``.

    Raises
    ------
    RuntimeError
        If the granularity of ``freq_str`` is not one of the supported ones.
    """
    # Only the base granularity matters here; the multiple is ignored.
    _, granularity = get_granularity(freq_str)
    if granularity == "M":
        feature_classes = [MonthOfYear]
    elif granularity == "W":
        feature_classes = [DayOfMonth, WeekOfYear]
    elif granularity in ["D", "B"]:
        feature_classes = [DayOfWeek, DayOfMonth, DayOfYear]
    elif granularity == "H":
        feature_classes = [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear]
    elif granularity in ["min", "T"]:
        feature_classes = [
            MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear
        ]
    else:
        # The list below mirrors every granularity accepted by the branches
        # above, including "B" and the "T" alias.
        supported_freq_msg = f"""
        Unsupported frequency {freq_str}

        The following frequencies are supported:

            M   - monthly
            W   - week
            D   - daily
            B   - business day
            H   - hourly
            min - minutely (alias: T)
        """
        raise RuntimeError(supported_freq_msg)

    return [cls() for cls in feature_classes]
|
||||
|
||||
+111
-133
@@ -24,17 +24,15 @@ def shift_timestamp(ts: pd.Timestamp, offset: int) -> pd.Timestamp:
|
||||
raise Exception(ex)
|
||||
|
||||
|
||||
def target_transformation_length(
|
||||
target: np.array, pred_length: int, is_train: bool
|
||||
) -> int:
|
||||
def target_transformation_length(target: np.array, pred_length: int,
|
||||
is_train: bool) -> int:
|
||||
return target.shape[-1] + (0 if is_train else pred_length)
|
||||
|
||||
|
||||
class Transformation(ABC):
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
) -> Iterator[DataEntry]:
|
||||
def __call__(self, data_it: Iterator[DataEntry],
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
pass
|
||||
|
||||
def estimate(self, data_it: Iterator[DataEntry]) -> Iterator[DataEntry]:
|
||||
@@ -45,13 +43,11 @@ class Chain(Transformation):
|
||||
"""
|
||||
Chain multiple transformations together.
|
||||
"""
|
||||
|
||||
def __init__(self, trans: List[Transformation]) -> None:
|
||||
self.trans = trans
|
||||
|
||||
def __call__(
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
) -> Iterator[DataEntry]:
|
||||
def __call__(self, data_it: Iterator[DataEntry],
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
tmp = data_it
|
||||
for t in self.trans:
|
||||
tmp = t(tmp, is_train)
|
||||
@@ -62,9 +58,8 @@ class Chain(Transformation):
|
||||
|
||||
|
||||
class IdentityTransformation(Transformation):
|
||||
def __call__(
|
||||
self, data_it: Iterator[DataEntry], is_train: bool
|
||||
) -> Iterator[DataEntry]:
|
||||
def __call__(self, data_it: Iterator[DataEntry],
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
return data_it
|
||||
|
||||
|
||||
@@ -72,8 +67,8 @@ class MapTransformation(Transformation):
|
||||
"""
|
||||
Base class for Transformations that returns exactly one result per input in the stream.
|
||||
"""
|
||||
|
||||
def __call__(self, data_it: Iterator[DataEntry], is_train: bool) -> Iterator:
|
||||
def __call__(self, data_it: Iterator[DataEntry],
|
||||
is_train: bool) -> Iterator:
|
||||
for data_entry in data_it:
|
||||
try:
|
||||
yield self.map_transform(data_entry.copy(), is_train)
|
||||
@@ -89,7 +84,6 @@ class SimpleTransformation(MapTransformation):
|
||||
"""
|
||||
Element wise transformations that are the same in train and test mode
|
||||
"""
|
||||
|
||||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
||||
return self.transform(data)
|
||||
|
||||
@@ -105,7 +99,6 @@ class AdhocTransform(SimpleTransformation):
|
||||
It is OK to use this for experiments and outside of a model pipeline that
|
||||
needs to be serialized.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable[[DataEntry], DataEntry]) -> None:
|
||||
self.func = func
|
||||
|
||||
@@ -118,13 +111,14 @@ class FlatMapTransformation(Transformation):
|
||||
Transformations that yield zero or more results per input, but do not combine
|
||||
elements from the input stream.
|
||||
"""
|
||||
|
||||
def __call__(self, data_it: Iterator[DataEntry], is_train: bool) -> Iterator:
|
||||
def __call__(self, data_it: Iterator[DataEntry],
|
||||
is_train: bool) -> Iterator:
|
||||
num_idle_transforms = 0
|
||||
for data_entry in data_it:
|
||||
num_idle_transforms += 1
|
||||
try:
|
||||
for result in self.flatmap_transform(data_entry.copy(), is_train):
|
||||
for result in self.flatmap_transform(data_entry.copy(),
|
||||
is_train):
|
||||
num_idle_transforms = 0
|
||||
yield result
|
||||
except Exception as e:
|
||||
@@ -135,11 +129,11 @@ class FlatMapTransformation(Transformation):
|
||||
f"This means the transformation looped over "
|
||||
f"MAX_IDLE_TRANSFORMS={MAX_IDLE_TRANSFORMS} "
|
||||
f"inputs without returning any output.\n"
|
||||
f"This occurred in the following transformation:\n{self}"
|
||||
)
|
||||
f"This occurred in the following transformation:\n{self}")
|
||||
|
||||
@abstractmethod
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
def flatmap_transform(self, data: DataEntry,
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -147,7 +141,8 @@ class FilterTransformation(FlatMapTransformation):
|
||||
def __init__(self, condition: Callable[[DataEntry], bool]) -> None:
|
||||
self.condition = condition
|
||||
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
def flatmap_transform(self, data: DataEntry,
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
if self.condition(data):
|
||||
yield data
|
||||
|
||||
@@ -173,7 +168,6 @@ class SetField(SimpleTransformation):
|
||||
value
|
||||
Value to be set
|
||||
"""
|
||||
|
||||
def __init__(self, output_field: str, value: Any) -> None:
|
||||
self.output_field = output_field
|
||||
self.value = value
|
||||
@@ -193,7 +187,6 @@ class SetFieldIfNotPresent(SimpleTransformation):
|
||||
value
|
||||
Value to be set
|
||||
"""
|
||||
|
||||
def __init__(self, field: str, value: Any) -> None:
|
||||
self.output_field = field
|
||||
self.value = value
|
||||
@@ -215,10 +208,10 @@ class AsNumpyArray(SimpleTransformation):
|
||||
dtype
|
||||
numpy dtype to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, field: str, expected_ndim: int, dtype: np.dtype = np.float32
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
field: str,
|
||||
expected_ndim: int,
|
||||
dtype: np.dtype = np.float32) -> None:
|
||||
self.field = field
|
||||
self.expected_ndim = expected_ndim
|
||||
self.dtype = dtype
|
||||
@@ -258,7 +251,6 @@ class ExpandDimArray(SimpleTransformation):
|
||||
axis
|
||||
Axis to expand (see np.expand_dims for details)
|
||||
"""
|
||||
|
||||
def __init__(self, field: str, axis: Optional[int] = None) -> None:
|
||||
self.field = field
|
||||
self.axis = axis
|
||||
@@ -282,20 +274,21 @@ class VstackFeatures(SimpleTransformation):
|
||||
drop_inputs
|
||||
If set to true the input fields will be dropped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
output_field: str,
|
||||
input_fields: List[str],
|
||||
drop_inputs: bool = True) -> None:
|
||||
self.output_field = output_field
|
||||
self.input_fields = input_fields
|
||||
self.cols_to_drop = (
|
||||
[]
|
||||
if not drop_inputs
|
||||
else [fname for fname in self.input_fields if fname != output_field]
|
||||
)
|
||||
self.cols_to_drop = ([] if not drop_inputs else [
|
||||
fname for fname in self.input_fields if fname != output_field
|
||||
])
|
||||
|
||||
def transform(self, data: DataEntry) -> DataEntry:
|
||||
r = [data[fname] for fname in self.input_fields if data[fname] is not None]
|
||||
r = [
|
||||
data[fname] for fname in self.input_fields
|
||||
if data[fname] is not None
|
||||
]
|
||||
output = np.vstack(r)
|
||||
data[self.output_field] = output
|
||||
for fname in self.cols_to_drop:
|
||||
@@ -316,20 +309,21 @@ class ConcatFeatures(SimpleTransformation):
|
||||
drop_inputs
|
||||
If set to true the input fields will be dropped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
output_field: str,
|
||||
input_fields: List[str],
|
||||
drop_inputs: bool = True) -> None:
|
||||
self.output_field = output_field
|
||||
self.input_fields = input_fields
|
||||
self.cols_to_drop = (
|
||||
[]
|
||||
if not drop_inputs
|
||||
else [fname for fname in self.input_fields if fname != output_field]
|
||||
)
|
||||
self.cols_to_drop = ([] if not drop_inputs else [
|
||||
fname for fname in self.input_fields if fname != output_field
|
||||
])
|
||||
|
||||
def transform(self, data: DataEntry) -> DataEntry:
|
||||
r = [data[fname] for fname in self.input_fields if data[fname] is not None]
|
||||
r = [
|
||||
data[fname] for fname in self.input_fields
|
||||
if data[fname] is not None
|
||||
]
|
||||
output = np.concatenate(r)
|
||||
data[self.output_field] = output
|
||||
for fname in self.cols_to_drop:
|
||||
@@ -347,7 +341,6 @@ class SwapAxes(SimpleTransformation):
|
||||
axes
|
||||
Axes to use
|
||||
"""
|
||||
|
||||
def __init__(self, input_fields: List[str], axes: Tuple[int, int]) -> None:
|
||||
self.input_fields = input_fields
|
||||
self.axis1, self.axis2 = axes
|
||||
@@ -365,8 +358,7 @@ class SwapAxes(SimpleTransformation):
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected field type {type(v).__name__}, expected "
|
||||
f"np.ndarray or list[np.ndarray]"
|
||||
)
|
||||
f"np.ndarray or list[np.ndarray]")
|
||||
|
||||
|
||||
class ListFeatures(SimpleTransformation):
|
||||
@@ -381,17 +373,15 @@ class ListFeatures(SimpleTransformation):
|
||||
drop_inputs
|
||||
If true the input fields will be removed from the result.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, output_field: str, input_fields: List[str], drop_inputs: bool = True
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
output_field: str,
|
||||
input_fields: List[str],
|
||||
drop_inputs: bool = True) -> None:
|
||||
self.output_field = output_field
|
||||
self.input_fields = input_fields
|
||||
self.cols_to_drop = (
|
||||
[]
|
||||
if not drop_inputs
|
||||
else [fname for fname in self.input_fields if fname != output_field]
|
||||
)
|
||||
self.cols_to_drop = ([] if not drop_inputs else [
|
||||
fname for fname in self.input_fields if fname != output_field
|
||||
])
|
||||
|
||||
def transform(self, data: DataEntry) -> DataEntry:
|
||||
data[self.output_field] = [data[fname] for fname in self.input_fields]
|
||||
@@ -402,10 +392,10 @@ class ListFeatures(SimpleTransformation):
|
||||
|
||||
class AddObservedValuesIndicator(SimpleTransformation):
|
||||
"""
|
||||
Replaces missing values in a numpy array (NaNs) with a dummy value and adds an "observed"-indicator
|
||||
that is
|
||||
1 - when values are observed
|
||||
0 - when values are missing
|
||||
Replaces missing values in a numpy array (NaNs) with a dummy value and adds
|
||||
an "observed"-indicator that is ``1`` when values are observed and ``0``
|
||||
when values are missing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_field
|
||||
@@ -419,7 +409,6 @@ class AddObservedValuesIndicator(SimpleTransformation):
|
||||
they will not be replaced. In any case the indicator is included in the
|
||||
result.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
@@ -454,7 +443,6 @@ class RenameFields(SimpleTransformation):
|
||||
mapping
|
||||
Name mapping `input_name -> output_name`
|
||||
"""
|
||||
|
||||
def __init__(self, mapping: Dict[str, str]) -> None:
|
||||
self.mapping = mapping
|
||||
values_count = Counter(mapping.values())
|
||||
@@ -492,7 +480,6 @@ class AddConstFeature(MapTransformation):
|
||||
dtype
|
||||
Numpy dtype to use for resulting array.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_field: str,
|
||||
@@ -508,12 +495,11 @@ class AddConstFeature(MapTransformation):
|
||||
self.target_field = target_field
|
||||
|
||||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
||||
length = target_transformation_length(
|
||||
data[self.target_field], self.pred_length, is_train=is_train
|
||||
)
|
||||
data[self.output_field] = self.const * np.ones(
|
||||
shape=(1, length), dtype=self.dtype
|
||||
)
|
||||
length = target_transformation_length(data[self.target_field],
|
||||
self.pred_length,
|
||||
is_train=is_train)
|
||||
data[self.output_field] = self.const * np.ones(shape=(1, length),
|
||||
dtype=self.dtype)
|
||||
return data
|
||||
|
||||
|
||||
@@ -535,7 +521,6 @@ class AddTimeFeatures(MapTransformation):
|
||||
pred_length
|
||||
Prediction length
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_field: str,
|
||||
@@ -562,23 +547,23 @@ class AddTimeFeatures(MapTransformation):
|
||||
if self._min_time_point is None:
|
||||
self._min_time_point = start
|
||||
self._max_time_point = end
|
||||
self._min_time_point = min(shift_timestamp(start, -50), self._min_time_point)
|
||||
self._max_time_point = max(shift_timestamp(end, 50), self._max_time_point)
|
||||
self.full_date_range = pd.date_range(
|
||||
self._min_time_point, self._max_time_point, freq=start.freq
|
||||
)
|
||||
self._min_time_point = min(shift_timestamp(start, -50),
|
||||
self._min_time_point)
|
||||
self._max_time_point = max(shift_timestamp(end, 50),
|
||||
self._max_time_point)
|
||||
self.full_date_range = pd.date_range(self._min_time_point,
|
||||
self._max_time_point,
|
||||
freq=start.freq)
|
||||
self._full_range_date_features = np.vstack(
|
||||
[feat(self.full_date_range) for feat in self.date_features]
|
||||
)
|
||||
self._date_index = pd.Series(
|
||||
index=self.full_date_range, data=np.arange(len(self.full_date_range))
|
||||
)
|
||||
[feat(self.full_date_range) for feat in self.date_features])
|
||||
self._date_index = pd.Series(index=self.full_date_range,
|
||||
data=np.arange(len(self.full_date_range)))
|
||||
|
||||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
||||
start = data[self.start_field]
|
||||
length = target_transformation_length(
|
||||
data[self.target_field], self.pred_length, is_train=is_train
|
||||
)
|
||||
length = target_transformation_length(data[self.target_field],
|
||||
self.pred_length,
|
||||
is_train=is_train)
|
||||
self._update_cache(start, length)
|
||||
i0 = self._date_index[start]
|
||||
features = self._full_range_date_features[..., i0:i0 + length]
|
||||
@@ -604,7 +589,6 @@ class AddAgeFeature(MapTransformation):
|
||||
log_scale
|
||||
If set to true the age feature grows logarithmically otherwise linearly over time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
@@ -619,9 +603,9 @@ class AddAgeFeature(MapTransformation):
|
||||
self._age_feature = np.zeros(0)
|
||||
|
||||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
||||
length = target_transformation_length(
|
||||
data[self.target_field], self.pred_length, is_train=is_train
|
||||
)
|
||||
length = target_transformation_length(data[self.target_field],
|
||||
self.pred_length,
|
||||
is_train=is_train)
|
||||
|
||||
if self.log_scale:
|
||||
age = np.log10(2.0 + np.arange(length, dtype=np.float32))
|
||||
@@ -648,6 +632,7 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
The transformation also adds a field 'past_is_pad' that indicates whether
|
||||
values were padded or not.
|
||||
Convention: time axis is always the last axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_field
|
||||
@@ -677,7 +662,6 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
cold-start. In such case, is_pad_out contains an indicator whether
|
||||
data is padded or not.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
@@ -711,7 +695,8 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
def _future(self, col_name):
|
||||
return f"future_{col_name}"
|
||||
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
def flatmap_transform(self, data: DataEntry,
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
pl = self.future_length
|
||||
slice_cols = self.ts_fields + [self.target_field]
|
||||
target = data[self.target_field]
|
||||
@@ -728,12 +713,11 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
else:
|
||||
if self.pick_incomplete:
|
||||
sampling_indices = self.train_sampler(
|
||||
target, 0, len_target - self.future_length
|
||||
)
|
||||
target, 0, len_target - self.future_length)
|
||||
else:
|
||||
sampling_indices = self.train_sampler(
|
||||
target, self.past_length, len_target - self.future_length
|
||||
)
|
||||
target, self.past_length,
|
||||
len_target - self.future_length)
|
||||
else:
|
||||
sampling_indices = [len_target]
|
||||
for i in sampling_indices:
|
||||
@@ -744,18 +728,17 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
for ts_field in slice_cols:
|
||||
if i > self.past_length:
|
||||
# truncate to past_length
|
||||
past_piece = d[ts_field][..., i - self.past_length: i]
|
||||
past_piece = d[ts_field][..., i - self.past_length:i]
|
||||
elif i < self.past_length:
|
||||
pad_block = np.zeros(
|
||||
d[ts_field].shape[:-1] + (pad_length,), dtype=d[ts_field].dtype
|
||||
)
|
||||
pad_block = np.zeros(d[ts_field].shape[:-1] +
|
||||
(pad_length, ),
|
||||
dtype=d[ts_field].dtype)
|
||||
past_piece = np.concatenate(
|
||||
[pad_block, d[ts_field][..., :i]], axis=-1
|
||||
)
|
||||
[pad_block, d[ts_field][..., :i]], axis=-1)
|
||||
else:
|
||||
past_piece = d[ts_field][..., :i]
|
||||
d[self._past(ts_field)] = past_piece
|
||||
d[self._future(ts_field)] = d[ts_field][..., i: i + pl]
|
||||
d[self._future(ts_field)] = d[ts_field][..., i:i + pl]
|
||||
del d[ts_field]
|
||||
pad_indicator = np.zeros(self.past_length)
|
||||
if pad_length > 0:
|
||||
@@ -763,11 +746,14 @@ class InstanceSplitter(FlatMapTransformation):
|
||||
|
||||
if self.batch_first:
|
||||
for ts_field in slice_cols:
|
||||
d[self._past(ts_field)] = d[self._past(ts_field)].transpose()
|
||||
d[self._future(ts_field)] = d[self._future(ts_field)].transpose()
|
||||
d[self._past(ts_field)] = d[self._past(
|
||||
ts_field)].transpose()
|
||||
d[self._future(ts_field)] = d[self._future(
|
||||
ts_field)].transpose()
|
||||
|
||||
d[self._past(self.is_pad_field)] = pad_indicator
|
||||
d[self.forecast_start_field] = shift_timestamp(d[self.start_field], i)
|
||||
d[self.forecast_start_field] = shift_timestamp(
|
||||
d[self.start_field], i)
|
||||
yield d
|
||||
|
||||
|
||||
@@ -820,7 +806,6 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
length of the prediction range, must be set if
|
||||
use_prediction_features is True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_field: str,
|
||||
@@ -848,7 +833,7 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
self.forecast_start_field = forecast_start_field
|
||||
|
||||
assert (
|
||||
not use_prediction_features or prediction_length is not None
|
||||
not use_prediction_features or prediction_length is not None
|
||||
), "You must specify `prediction_length` if `use_prediction_features`"
|
||||
|
||||
self.use_prediction_features = use_prediction_features
|
||||
@@ -860,7 +845,8 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
def _future(self, col_name):
|
||||
return f"future_{col_name}"
|
||||
|
||||
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
||||
def flatmap_transform(self, data: DataEntry,
|
||||
is_train: bool) -> Iterator[DataEntry]:
|
||||
ts_fields = self.dynamic_feature_fields + [self.target_field]
|
||||
ts_target = data[self.target_field]
|
||||
|
||||
@@ -870,14 +856,10 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
if len_target < self.instance_length:
|
||||
sampling_indices = (
|
||||
# Returning [] for all time series will cause this to be in loop forever!
|
||||
[len_target]
|
||||
if self.allow_target_padding
|
||||
else []
|
||||
)
|
||||
[len_target] if self.allow_target_padding else [])
|
||||
else:
|
||||
sampling_indices = self.instance_sampler(
|
||||
ts_target, self.instance_length, len_target
|
||||
)
|
||||
ts_target, self.instance_length, len_target)
|
||||
else:
|
||||
sampling_indices = [len_target]
|
||||
|
||||
@@ -887,9 +869,8 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
pad_length = max(self.instance_length - i, 0)
|
||||
|
||||
# update start field
|
||||
d[self.start_field] = shift_timestamp(
|
||||
data[self.start_field], i - self.instance_length
|
||||
)
|
||||
d[self.start_field] = shift_timestamp(data[self.start_field],
|
||||
i - self.instance_length)
|
||||
|
||||
# set is_pad field
|
||||
is_pad = np.zeros(self.instance_length)
|
||||
@@ -902,28 +883,26 @@ class CanonicalInstanceSplitter(FlatMapTransformation):
|
||||
full_ts = data[ts_field]
|
||||
if pad_length > 0:
|
||||
pad_pre = self.pad_value * np.ones(
|
||||
shape=full_ts.shape[:-1] + (pad_length,)
|
||||
)
|
||||
past_ts = np.concatenate([pad_pre, full_ts[..., :i]], axis=-1)
|
||||
shape=full_ts.shape[:-1] + (pad_length, ))
|
||||
past_ts = np.concatenate([pad_pre, full_ts[..., :i]],
|
||||
axis=-1)
|
||||
else:
|
||||
past_ts = full_ts[..., (i - self.instance_length): i]
|
||||
past_ts = full_ts[..., (i - self.instance_length):i]
|
||||
|
||||
past_ts = past_ts.transpose() if self.batch_first else past_ts
|
||||
d[self._past(ts_field)] = past_ts
|
||||
|
||||
if self.use_prediction_features and not is_train:
|
||||
if not ts_field == self.target_field:
|
||||
future_ts = full_ts[..., i: i + self.prediction_length]
|
||||
future_ts = (
|
||||
future_ts.transpose() if self.batch_first else future_ts
|
||||
)
|
||||
future_ts = full_ts[..., i:i + self.prediction_length]
|
||||
future_ts = (future_ts.transpose()
|
||||
if self.batch_first else future_ts)
|
||||
d[self._future(ts_field)] = future_ts
|
||||
|
||||
del d[ts_field]
|
||||
|
||||
d[self.forecast_start_field] = shift_timestamp(
|
||||
d[self.start_field], self.instance_length
|
||||
)
|
||||
d[self.start_field], self.instance_length)
|
||||
|
||||
yield d
|
||||
|
||||
@@ -936,7 +915,6 @@ class SelectFields(MapTransformation):
|
||||
input_fields
|
||||
List of fields to keep.
|
||||
"""
|
||||
|
||||
def __init__(self, input_fields: List[str]) -> None:
|
||||
self.input_fields = input_fields
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from typing import Tuple
|
||||
import re
|
||||
|
||||
def get_granularity(freq_str: str) -> Tuple[int, str]:
    """
    Splits a frequency string such as "7D" into the multiple 7 and the base
    granularity "D".

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H",
        "5min", "1D" etc. Leading whitespace and whitespace between the
        multiple and the granularity are skipped by the pattern.

    Returns
    -------
    Tuple[int, str]
        The integer multiple (``1`` when the string has no leading digits)
        and the granularity string.

    Raises
    ------
    AssertionError
        If ``freq_str`` cannot be parsed, e.g. it is empty or consists only
        of digits.
    """
    # Group 1 wraps the optional digits, group 2 is the digits themselves
    # (None when absent) and group 3 is the granularity.
    freq_regex = r"\s*((\d+)?)\s*([^\d]\w*)"
    m = re.match(freq_regex, freq_str)
    if m is None:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is kept unchanged
        # for existing callers.
        raise AssertionError("Cannot parse frequency string: %s" % freq_str)
    digits = m.group(2)
    multiple = int(digits) if digits is not None else 1
    granularity = m.group(3)
    return multiple, granularity
|
||||
Reference in New Issue
Block a user