updating to gluon 10 or 11

This commit is contained in:
wassname
2022-12-04 09:01:58 +08:00
parent 81be06bcc1
commit a4a3e953de
+38 -26
View File
@@ -7,15 +7,21 @@ import numpy as np
import pandas as pd
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.repository._util import metadata, save_to_file
from gluonts.dataset.repository._util import metadata
from gluonts.time_feature.holiday import squared_exponential_kernel
from pts.feature import CustomDateFeatureSet
from gluonts.dataset import DatasetWriter
from gluonts.dataset.common import MetaData, TrainDatasets
from pathlib import Path
def generate_pts_m5_dataset(
dataset_path: Path,
pandas_freq: str,
prediction_length: int = 28,
prediction_length: int,
m5_file_path: Path,
dataset_writer: DatasetWriter,
alpha: float = 0.5,
):
cal_path = f"{dataset_path}/calendar.csv"
@@ -134,7 +140,6 @@ def generate_pts_m5_dataset(
]
# Build training set
train_file = dataset_path / "train" / "data.json"
train_ds = []
for index, item in sales_train_validation.iterrows():
id, item_id, dept_id, cat_id, store_id, state_id = index
@@ -178,30 +183,17 @@ def generate_pts_m5_dataset(
.astype(np.float32)
.tolist()
)
d = {
FieldName.TARGET: time_series["target"],
FieldName.START: time_series["start"],
FieldName.FEAT_DYNAMIC_REAL: time_series["feat_dynamic_real"],
FieldName.FEAT_STATIC_CAT: time_series["feat_static_cat"],
FieldName.ITEM_ID: time_series["item_id"],
}
train_ds.append(time_series.copy())
# Build training set
train_file = dataset_path / "train" / "data.json"
save_to_file(train_file, train_ds)
# Create metadata file
meta_file = dataset_path / "metadata.json"
with open(meta_file, "w") as f:
f.write(
json.dumps(
{
"freq": pandas_freq,
"prediction_length": prediction_length,
"feat_static_cat": feat_static_cat,
"feat_dynamic_real": feat_dynamic_real,
"cardinality": len(train_ds),
}
)
)
train_ds.append(d)
# Build testing set
test_file = dataset_path / "test" / "data.json"
test_ds = []
for index, item in sales_train_evaluation.iterrows():
id, item_id, dept_id, cat_id, store_id, state_id = index
@@ -245,7 +237,27 @@ def generate_pts_m5_dataset(
.astype(np.float32)
.tolist()
)
d = {
FieldName.TARGET: time_series["target"],
FieldName.START: time_series["start"],
FieldName.FEAT_DYNAMIC_REAL: time_series["feat_dynamic_real"],
FieldName.FEAT_STATIC_CAT: time_series["feat_static_cat"],
FieldName.ITEM_ID: time_series["item_id"],
}
test_ds.append(time_series.copy())
test_ds.append(d)
save_to_file(test_file, test_ds)
# Create metadata file
meta = MetaData(
**metadata(
cardinality=len(train_ds),
freq=pandas_freq,
feat_static_cat=feat_static_cat,
feat_dynamic_real=feat_dynamic_real,
prediction_length=prediction_length,
)
)
dataset = TrainDatasets(metadata=meta, train=train_ds, test=test_ds)
dataset.save(
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
)