diff --git a/pts/dataset/repository/_m5.py b/pts/dataset/repository/_m5.py index cd70b0b..bc26599 100644 --- a/pts/dataset/repository/_m5.py +++ b/pts/dataset/repository/_m5.py @@ -7,15 +7,21 @@ import numpy as np import pandas as pd from gluonts.dataset.field_names import FieldName -from gluonts.dataset.repository._util import metadata, save_to_file +from gluonts.dataset.repository._util import metadata from gluonts.time_feature.holiday import squared_exponential_kernel from pts.feature import CustomDateFeatureSet +from gluonts.dataset import DatasetWriter +from gluonts.dataset.common import MetaData, TrainDatasets + +from pathlib import Path def generate_pts_m5_dataset( dataset_path: Path, pandas_freq: str, - prediction_length: int = 28, + prediction_length: int, + m5_file_path: Path, + dataset_writer: DatasetWriter, alpha: float = 0.5, ): cal_path = f"{dataset_path}/calendar.csv" @@ -134,7 +140,6 @@ def generate_pts_m5_dataset( ] # Build training set - train_file = dataset_path / "train" / "data.json" train_ds = [] for index, item in sales_train_validation.iterrows(): id, item_id, dept_id, cat_id, store_id, state_id = index @@ -178,30 +183,17 @@ def generate_pts_m5_dataset( .astype(np.float32) .tolist() ) + d = { + FieldName.TARGET: time_series["target"], + FieldName.START: time_series["start"], + FieldName.FEAT_DYNAMIC_REAL: time_series["feat_dynamic_real"], + FieldName.FEAT_STATIC_CAT: time_series["feat_static_cat"], + FieldName.ITEM_ID: time_series["item_id"], + } - train_ds.append(time_series.copy()) - - # Build training set - train_file = dataset_path / "train" / "data.json" - save_to_file(train_file, train_ds) - - # Create metadata file - meta_file = dataset_path / "metadata.json" - with open(meta_file, "w") as f: - f.write( - json.dumps( - { - "freq": pandas_freq, - "prediction_length": prediction_length, - "feat_static_cat": feat_static_cat, - "feat_dynamic_real": feat_dynamic_real, - "cardinality": len(train_ds), - } - ) - ) + train_ds.append(d) # Build testing set - test_file = dataset_path / "test" / "data.json" test_ds = [] for index, item in sales_train_evaluation.iterrows(): id, item_id, dept_id, cat_id, store_id, state_id = index @@ -245,7 +237,27 @@ def generate_pts_m5_dataset( .astype(np.float32) .tolist() ) + d = { + FieldName.TARGET: time_series["target"], + FieldName.START: time_series["start"], + FieldName.FEAT_DYNAMIC_REAL: time_series["feat_dynamic_real"], + FieldName.FEAT_STATIC_CAT: time_series["feat_static_cat"], + FieldName.ITEM_ID: time_series["item_id"], + } - test_ds.append(time_series.copy()) + test_ds.append(d) - save_to_file(test_file, test_ds) + # Create metadata file + meta = MetaData( + **metadata( + cardinality=len(train_ds), + freq=pandas_freq, + feat_static_cat=feat_static_cat, + feat_dynamic_real=feat_dynamic_real, + prediction_length=prediction_length, + ) + ) + dataset = TrainDatasets(metadata=meta, train=train_ds, test=test_ds) + dataset.save( + path_str=str(dataset_path), writer=dataset_writer, overwrite=True + )