mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 17:49:41 +08:00
updating to gluon 10 or 11
This commit is contained in:
@@ -7,15 +7,21 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from gluonts.dataset.field_names import FieldName
|
||||
from gluonts.dataset.repository._util import metadata, save_to_file
|
||||
from gluonts.dataset.repository._util import metadata
|
||||
from gluonts.time_feature.holiday import squared_exponential_kernel
|
||||
from pts.feature import CustomDateFeatureSet
|
||||
|
||||
from gluonts.dataset import DatasetWriter
|
||||
from gluonts.dataset.common import MetaData, TrainDatasets
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
def generate_pts_m5_dataset(
|
||||
dataset_path: Path,
|
||||
pandas_freq: str,
|
||||
prediction_length: int = 28,
|
||||
prediction_length: int,
|
||||
m5_file_path: Path,
|
||||
dataset_writer: DatasetWriter,
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
cal_path = f"{dataset_path}/calendar.csv"
|
||||
@@ -134,7 +140,6 @@ def generate_pts_m5_dataset(
|
||||
]
|
||||
|
||||
# Build training set
|
||||
train_file = dataset_path / "train" / "data.json"
|
||||
train_ds = []
|
||||
for index, item in sales_train_validation.iterrows():
|
||||
id, item_id, dept_id, cat_id, store_id, state_id = index
|
||||
@@ -178,30 +183,17 @@ def generate_pts_m5_dataset(
|
||||
.astype(np.float32)
|
||||
.tolist()
|
||||
)
|
||||
d = {
|
||||
FieldName.TARGET: time_series["target"],
|
||||
FieldName.START: time_series["start"],
|
||||
FieldName.FEAT_DYNAMIC_REAL: time_series["feat_dynamic_real"],
|
||||
FieldName.FEAT_STATIC_CAT: time_series["feat_static_cat"],
|
||||
FieldName.ITEM_ID: time_series["item_id"],
|
||||
}
|
||||
|
||||
train_ds.append(time_series.copy())
|
||||
|
||||
# Build training set
|
||||
train_file = dataset_path / "train" / "data.json"
|
||||
save_to_file(train_file, train_ds)
|
||||
|
||||
# Create metadata file
|
||||
meta_file = dataset_path / "metadata.json"
|
||||
with open(meta_file, "w") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"freq": pandas_freq,
|
||||
"prediction_length": prediction_length,
|
||||
"feat_static_cat": feat_static_cat,
|
||||
"feat_dynamic_real": feat_dynamic_real,
|
||||
"cardinality": len(train_ds),
|
||||
}
|
||||
)
|
||||
)
|
||||
train_ds.append(d)
|
||||
|
||||
# Build testing set
|
||||
test_file = dataset_path / "test" / "data.json"
|
||||
test_ds = []
|
||||
for index, item in sales_train_evaluation.iterrows():
|
||||
id, item_id, dept_id, cat_id, store_id, state_id = index
|
||||
@@ -245,7 +237,27 @@ def generate_pts_m5_dataset(
|
||||
.astype(np.float32)
|
||||
.tolist()
|
||||
)
|
||||
d = {
|
||||
FieldName.TARGET: time_series["target"],
|
||||
FieldName.START: time_series["start"],
|
||||
FieldName.FEAT_DYNAMIC_REAL: time_series["feat_dynamic_real"],
|
||||
FieldName.FEAT_STATIC_CAT: time_series["feat_static_cat"],
|
||||
FieldName.ITEM_ID: time_series["item_id"],
|
||||
}
|
||||
|
||||
test_ds.append(time_series.copy())
|
||||
test_ds.append(d)
|
||||
|
||||
save_to_file(test_file, test_ds)
|
||||
# Create metadata file
|
||||
meta = MetaData(
|
||||
**metadata(
|
||||
cardinality=len(train_ds),
|
||||
freq=pandas_freq,
|
||||
feat_static_cat=feat_static_cat,
|
||||
feat_dynamic_real=feat_dynamic_real,
|
||||
prediction_length=prediction_length,
|
||||
)
|
||||
)
|
||||
dataset = TrainDatasets(metadata=meta, train=train_ds, test=test_ds)
|
||||
dataset.save(
|
||||
path_str=str(dataset_path), writer=dataset_writer, overwrite=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user