diff --git a/pts/dataset/__init__.py b/pts/dataset/__init__.py index c55d88b..813e49a 100644 --- a/pts/dataset/__init__.py +++ b/pts/dataset/__init__.py @@ -1,6 +1,5 @@ from .common import DataEntry, FieldName from .list_dataset import ListDataset -from .loader import DataLoader, InferenceDataLoader, TrainDataLoader from .sampler import ( InstanceSampler, BucketInstanceSampler, @@ -8,4 +7,5 @@ from .sampler import ( TestSplitSampler, UniformSplitSampler, ) +from .process import ProcessStartField from .utils import to_pandas diff --git a/pts/dataset/loader.py b/pts/dataset/loader.py index 0081642..9b8d737 100644 --- a/pts/dataset/loader.py +++ b/pts/dataset/loader.py @@ -7,7 +7,7 @@ import numpy as np # Third-party imports import torch -from pts.feature.transform import Transformation +from pts.feature import Transformation # First-party imports from .common import DataEntry, Dataset diff --git a/pts/feature/transform.py b/pts/feature/transform.py index ed152b7..ac3b91c 100644 --- a/pts/feature/transform.py +++ b/pts/feature/transform.py @@ -7,8 +7,7 @@ import numpy as np import pandas as pd from pts.exception import assert_pts -from pts.dataset import DataEntry -from pts.dataset.sampler import InstanceSampler +from pts.dataset import DataEntry, InstanceSampler from .time_feature import TimeFeature diff --git a/pts/model/estimator.py b/pts/model/estimator.py index 07eccbf..4328db6 100644 --- a/pts/model/estimator.py +++ b/pts/model/estimator.py @@ -4,8 +4,7 @@ import numpy as np import torch import torch.nn as nn -from pts.dataset import TrainDataLoader -from pts.dataset.common import Dataset +from pts.dataset import Dataset, TrainDataLoader from pts.feature import Transformation from .predictor import Predictor diff --git a/pts/model/predictor.py b/pts/model/predictor.py index cb08438..cc364a2 100644 --- a/pts/model/predictor.py +++ b/pts/model/predictor.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Iterator -from pts.dataset.common import Dataset +from pts.dataset import Dataset from .forecast import Forecast from .predictor import Predictor diff --git a/test/dataset/test_common.py b/test/dataset/test_common.py new file mode 100644 index 0000000..bb46862 --- /dev/null +++ b/test/dataset/test_common.py @@ -0,0 +1,15 @@ +from pts.dataset import FieldName + +def test_dataset_fields(): + assert ( + "feat_static_cat" == FieldName.FEAT_STATIC_CAT + ), "Error in the FieldName 'feat_static_cat'." + assert ( + "feat_static_real" == FieldName.FEAT_STATIC_REAL + ), "Error in the FieldName 'feat_static_real'." + assert ( + "feat_dynamic_cat" == FieldName.FEAT_DYNAMIC_CAT + ), "Error in the FieldName 'feat_dynamic_cat'." + assert ( + "feat_dynamic_real" == FieldName.FEAT_DYNAMIC_REAL + ), "Error in the FieldName 'feat_dynamic_real'." diff --git a/test/dataset/test_process.py b/test/dataset/test_process.py new file mode 100644 index 0000000..26840ff --- /dev/null +++ b/test/dataset/test_process.py @@ -0,0 +1,20 @@ +import pytest +import pandas as pd + +from pts.dataset import ProcessStartField + +@pytest.mark.parametrize( + "freq, expected", + [ + ("B", "2019-11-01"), + ("W", "2019-11-03"), + ("M", "2019-11-30"), + ("12M", "2019-11-30"), + ("A-DEC", "2019-12-31"), + ], +) +def test_process_start_field(freq, expected): + process = ProcessStartField.process + given = "2019-11-01 12:34:56" + + assert process(given, freq) == pd.Timestamp(expected, freq) \ No newline at end of file diff --git a/test/feature/test_lag.py b/test/feature/test_lag.py new file mode 100644 index 0000000..7019150 --- /dev/null +++ b/test/feature/test_lag.py @@ -0,0 +1,299 @@ +from pts.feature import get_lags_for_frequency + +# These are the expected lags for common frequencies and corner cases. +# By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7]. +# Remaining lags correspond to the same `season` (+/- `delta`) in previous `k` cycles. +expected_lags = { + # (apart from the default lags) centered around each of the last 3 hours (delta = 2) + "min": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 58, + 59, + 60, + 61, + 62, + 118, + 119, + 120, + 121, + 122, + 178, + 179, + 180, + 181, + 182, + ], + # centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + "15min": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + + [ + 95, + 96, + 97, + 191, + 192, + 193, + 287, + 288, + 289, + 383, + 384, + 385, + 479, + 480, + 481, + 575, + 576, + 577, + 671, + 672, + 673, + ], + # centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + 3 weeks (delta = 1) + "30min": [1, 2, 3, 4, 5, 6, 7, 8] + + [ + 47, + 48, + 49, + 95, + 96, + 97, + 143, + 144, + 145, + 191, + 192, + 193, + 239, + 240, + 241, + 287, + 288, + 289, + 335, + 336, + 337, + ] + + [671, 672, 673, 1007, 1008, 1009], + # centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + last 6 weeks (delta = 1) + "59min": [1, 2, 3, 4, 5, 6, 7] + + [ + 23, + 24, + 25, + 47, + 48, + 49, + 72, + 73, + 74, + 96, + 97, + 98, + 121, + 122, + 123, + 145, + 146, + 147, + 169, + 170, + 171, + ] + + [340, 341, 342, 511, 512, 513, 682, 683, 684, 731, 732, 733], + # centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + last 6 weeks (delta = 1) + "61min": [1, 2, 3, 4, 5, 6, 7] + + [ + 22, + 23, + 24, + 46, + 47, + 48, + 69, + 70, + 71, + 93, + 94, + 95, + 117, + 118, + 119, + 140, + 141, + 142, + 164, + 165, + 166, + ] + + [329, 330, 331, 494, 495, 496, 659, 660, 661, 707, 708, 709], + # centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + last 6 weeks (delta = 1) + "H": [1, 2, 3, 4, 5, 6, 7] + + [ + 23, + 24, + 25, + 47, + 48, + 49, + 71, + 72, + 73, + 95, + 96, + 97, + 119, + 120, + 121, + 143, + 144, + 145, + 167, + 168, + 169, + ] + + [335, 336, 337, 503, 504, 505, 671, 672, 673, 719, 720, 721], + # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + + # last 8th and 12th weeks (delta = 0) + "6H": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 11, + 12, + 13, + 15, + 16, + 17, + 19, + 20, + 21, + 23, + 24, + 25, + 27, + 28, + 29, + ] + + [55, 56, 57, 83, 84, 85, 111, 112, 113] + + [119, 120, 121] + + [224, 336], + # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + + # last 8th and 12th weeks (delta = 0) + last year (delta = 1) + "12H": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + + [27, 28, 29, 41, 42, 43, 55, 56, 57] + + [59, 60, 61] + + [112, 168] + + [727, 728, 729], + # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + + # last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1) + "23H": [1, 2, 3, 4, 5, 6, 7, 8] + + [13, 14, 15, 20, 21, 22, 28, 29] + + [30, 31, 32] + + [58, 87] + + [378, 379, 380, 758, 759, 760, 1138, 1139, 1140], + # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + + # last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1) + "25H": [1, 2, 3, 4, 5, 6, 7] + + [12, 13, 14, 19, 20, 21, 25, 26, 27] + + [28, 29] + + [53, 80] + + [348, 349, 350, 697, 698, 699, 1047, 1048, 1049], + # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + + # last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1) + "D": [1, 2, 3, 4, 5, 6, 7, 8] + + [13, 14, 15, 20, 21, 22, 27, 28, 29] + + [30, 31] + + [56, 84] + + [363, 364, 365, 727, 728, 729, 1091, 1092, 1093], + # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + + # last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1) + "2D": [1, 2, 3, 4, 5] + + [6, 7, 8, 9, 10, 11, 13, 14, 15] + + [16] + + [28, 42] + + [181, 182, 183, 363, 364, 365, 545, 546, 547], + # centered around each of the last 3 months (delta = 0) + last 3 years (delta = 1) (assuming 52 weeks per year) + "6D": [1, 2, 3, 4, 5, 6, 7, 9, 14] + + [59, 60, 61, 120, 121, 122, 181, 182, 183], + # centered around each of the last 3 months (delta = 0) + last 3 years (delta = 1) (assuming 52 weeks per year) + "W": [1, 2, 3, 4, 5, 6, 7, 8, 12] + + [51, 52, 53, 103, 104, 105, 155, 156, 157], + # centered around each of the last 3 months (delta = 0) + last 3 years (delta = 1) (assuming 52 weeks per year) + "8D": [1, 2, 3, 4, 5, 6, 7, 10] + [44, 45, 46, 90, 91, 92, 135, 136, 137], + # centered around each of the last 3 years (delta = 1) + "4W": [1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 25, 26, 27, 38, 39, 40], + # centered around each of the last 3 years (delta = 1) + "3W": [1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 33, 34, 35, 51, 52, 53], + # centered around each of the last 3 years (delta = 1) + "5W": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 19, 20, 21, 30, 31, 32], + # centered around each of the last 3 years (delta = 1) + "M": [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37], + # default + "6M": [1, 2, 3, 4, 5, 6, 7], + # default + "12M": [1, 2, 3, 4, 5, 6, 7], +} + +# For the default multiple (1) +for freq in ["min", "H", "D", "W", "M"]: + expected_lags["1" + freq] = expected_lags[freq] + +# For frequencies that do not have unique form +expected_lags["60min"] = expected_lags["1H"] +expected_lags["24H"] = expected_lags["1D"] +expected_lags["7D"] = expected_lags["1W"] + + +def test_lags(): + + freq_strs = [ + "min", + "1min", + "15min", + "30min", + "59min", + "60min", + "61min", + "H", + "1H", + "6H", + "12H", + "23H", + "24H", + "25H", + "D", + "1D", + "2D", + "6D", + "7D", + "8D", + "W", + "1W", + "3W", + "4W", + "5W", + "M", + "6M", + "12M", + ] + + for freq_str in freq_strs: + lags = get_lags_for_frequency(freq_str) + + assert ( + lags == expected_lags[freq_str] + ), "lags do not match for the frequency '{}':\nexpected: {},\nprovided: {}".format( + freq_str, expected_lags[freq_str], lags + )