mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 17:49:41 +08:00
597 lines
32 KiB
Plaintext
597 lines
32 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%matplotlib inline\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
|
|
"from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
|
|
"from gluonts.evaluation.backtest import make_evaluation_predictions\n",
|
|
"from gluonts.evaluation import MultivariateEvaluator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from pts.model.tempflow import TempFlowEstimator\n",
|
|
"from pts.model.time_grad import TimeGradEstimator\n",
|
|
"from pts.model.transformer_tempflow import TransformerTempFlowEstimator\n",
|
|
"from pts import Trainer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def plot(target, forecast, prediction_length, prediction_intervals=(50.0, 90.0), color='g', fname=None):\n",
|
|
"    \"\"\"Plot recent observations and forecast quantile bands, one subplot per dimension.\n",
|
|
"\n",
|
|
"    At most rows * cols (16) target dimensions are drawn; the legend goes on the first subplot.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        target: 2-D time-indexed frame with one column per dimension; the last\n",
|
|
"            2 * prediction_length rows are plotted (presumably a pandas DataFrame\n",
|
|
"            from make_evaluation_predictions - TODO confirm).\n",
|
|
"        forecast: forecast object providing .quantile(q) as a (time, dim) array and .index.\n",
|
|
"        prediction_length: forecast horizon used to size the observation window.\n",
|
|
"        prediction_intervals: central interval widths, in percent, to shade.\n",
|
|
"        color: matplotlib color for the median line and interval bands.\n",
|
|
"        fname: optional path; when given the figure is also saved there.\n",
|
|
"    \"\"\"\n",
|
|
"    label_prefix = \"\"\n",
|
|
"    rows = 4\n",
|
|
"    cols = 4\n",
|
|
"    fig, axs = plt.subplots(rows, cols, figsize=(24, 24))\n",
|
|
"    axx = axs.ravel()\n",
|
|
"    seq_len, target_dim = target.shape\n",
|
|
"\n",
|
|
"    # The median plus the lower and upper percentile of each interval (e.g. 90 -> 5.0 and 95.0).\n",
|
|
"    ps = [50.0] + [\n",
|
|
"        50.0 + f * c / 2.0 for c in prediction_intervals for f in [-1.0, +1.0]\n",
|
|
"    ]\n",
|
|
"    percentiles_sorted = sorted(set(ps))\n",
|
|
"\n",
|
|
"    def alpha_for_percentile(p):\n",
|
|
"        # Lower percentiles (wider intervals) get a more transparent band.\n",
|
|
"        return (p / 100.0) ** 0.3\n",
|
|
"\n",
|
|
"    for dim in range(0, min(rows * cols, target_dim)):\n",
|
|
"        ax = axx[dim]\n",
|
|
"\n",
|
|
"        target[-2 * prediction_length :][dim].plot(ax=ax)\n",
|
|
"\n",
|
|
"        ps_data = [forecast.quantile(p / 100.0)[:, dim] for p in percentiles_sorted]\n",
|
|
"        i_p50 = len(percentiles_sorted) // 2\n",
|
|
"\n",
|
|
"        p50_data = ps_data[i_p50]\n",
|
|
"        p50_series = pd.Series(data=p50_data, index=forecast.index)\n",
|
|
"        p50_series.plot(color=color, ls=\"-\", label=f\"{label_prefix}median\", ax=ax)\n",
|
|
"\n",
|
|
"        # Shade between the i-th lowest and i-th highest percentile, widest interval first.\n",
|
|
"        for i in range(len(percentiles_sorted) // 2):\n",
|
|
"            ptile = percentiles_sorted[i]\n",
|
|
"            alpha = alpha_for_percentile(ptile)\n",
|
|
"            ax.fill_between(\n",
|
|
"                forecast.index,\n",
|
|
"                ps_data[i],\n",
|
|
"                ps_data[-i - 1],\n",
|
|
"                facecolor=color,\n",
|
|
"                alpha=alpha,\n",
|
|
"                interpolate=True,\n",
|
|
"            )\n",
|
|
"            # Legend proxy for this interval: a one-point line draws nothing visible,\n",
|
|
"            # but gives the legend a thick swatch with this interval's alpha.\n",
|
|
"            pd.Series(data=p50_data[:1], index=forecast.index[:1]).plot(\n",
|
|
"                color=color,\n",
|
|
"                alpha=alpha,\n",
|
|
"                linewidth=10,\n",
|
|
"                label=f\"{label_prefix}{100 - ptile * 2}%\",\n",
|
|
"                ax=ax,\n",
|
|
"            )\n",
|
|
"\n",
|
|
"    # Lines on each axis, in plotting order: observations, median, then one proxy per\n",
|
|
"    # interval from widest to narrowest. Hand them to legend() explicitly: with labels\n",
|
|
"    # alone matplotlib matches labels to artists purely by order, and the fill_between\n",
|
|
"    # bands are artists too, so interval labels can land on the wrong swatches.\n",
|
|
"    interval_labels = [f\"{k}% prediction interval\" for k in sorted(prediction_intervals, reverse=True)]\n",
|
|
"    legend = [\"observations\", \"median prediction\"] + interval_labels\n",
|
|
"    axx[0].legend(axx[0].get_lines()[: len(legend)], legend, loc=\"upper left\")\n",
|
|
"\n",
|
|
"    if fname is not None:\n",
|
|
"        fig.savefig(fname, bbox_inches='tight', pad_inches=0.05)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Available datasets: ['constant', 'exchange_rate', 'solar-energy', 'electricity', 'traffic', 'exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki-rolling_nips', 'taxi_30min', 'kaggle_web_traffic_with_missing', 'kaggle_web_traffic_without_missing', 'kaggle_web_traffic_weekly', 'm1_yearly', 'm1_quarterly', 'm1_monthly', 'nn5_daily_with_missing', 'nn5_daily_without_missing', 'nn5_weekly', 'tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'cif_2016', 'london_smart_meters_without_missing', 'wind_farms_without_missing', 'car_parts_without_missing', 'dominick', 'fred_md', 'pedestrian_counts', 'hospital', 'covid_deaths', 'kdd_cup_2018_without_missing', 'weather', 'm3_monthly', 'm3_quarterly', 'm3_yearly', 'm3_other', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5', 'uber_tlc_daily', 'uber_tlc_hourly', 'airpassengers']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(f\"Available datasets: {list(dataset_recipes.keys())}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Options: exchange_rate_nips, electricity_nips, traffic_nips, solar_nips, wiki-rolling_nips (taxi_30min is still buggy)\n",
|
|
"dataset = get_dataset(\"electricity_nips\", regenerate=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"MetaData(freq='H', target=None, feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='370')], feat_static_real=[], feat_dynamic_real=[], feat_dynamic_cat=[], prediction_length=24)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"dataset.metadata"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"max_target_dim = min(2000, int(dataset.metadata.feat_static_cat[0].cardinality))\n",
|
|
"\n",
|
|
"train_grouper = MultivariateGrouper(max_target_dim=max_target_dim)\n",
|
|
"\n",
|
|
"# Assumes dataset.test holds every training series once per rolling test date (how the\n",
|
|
"# gluonts *_nips datasets appear to be built - TODO confirm). Each grouped test entry should\n",
|
|
"# then contain all series, so the number of test dates is len(test) / len(train); a larger\n",
|
|
"# value would presumably cut each date's series into narrower entries than target_dim.\n",
|
|
"num_test_dates, remainder = divmod(len(dataset.test), len(dataset.train))\n",
|
|
"assert remainder == 0, f\"test split ({len(dataset.test)}) is not a multiple of train split ({len(dataset.train)})\"\n",
|
|
"\n",
|
|
"test_grouper = MultivariateGrouper(num_test_dates=num_test_dates, max_target_dim=max_target_dim)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/wassname/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/dataset/multivariate_grouper.py:191: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
|
|
" return {FieldName.TARGET: np.array([funcs(data) for data in dataset])}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dataset_train = train_grouper(dataset.train)\n",
|
|
"dataset_test = test_grouper(dataset.test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"estimator = TimeGradEstimator(\n",
|
|
" target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
|
|
" prediction_length=dataset.metadata.prediction_length,\n",
|
|
" context_length=dataset.metadata.prediction_length,\n",
|
|
" cell_type='GRU',\n",
|
|
" input_size=1484,\n",
|
|
" freq=dataset.metadata.freq,\n",
|
|
" loss_type='l2',\n",
|
|
" scaling=True,\n",
|
|
" diff_steps=100,\n",
|
|
" beta_end=0.1,\n",
|
|
" beta_schedule=\"linear\",\n",
|
|
" trainer=Trainer(device=device,\n",
|
|
" epochs=20,\n",
|
|
" learning_rate=1e-3,\n",
|
|
" num_batches_per_epoch=100,\n",
|
|
" batch_size=64,)\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "2d3955c7a20746bb9d07b9eb45bc51e9",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "ea6af75e92ab4883a56cbaeffb05d768",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "fee0881d3ff848e293f4d143c5447375",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "df64130c0378493b990ccdf5166de116",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "7dd79b177cda413c9d507f7397d8439d",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "f35cd15726194b629e9e4540a52fb7ed",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "c179fe7b63344b6a83d024033403208e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "dac3a1016e9a40138cdea4fb958fd582",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "897a2d5ea4314803b057ba3a36c5c096",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "6db7be9502924e63a2bb84784f6aea55",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "89c7a5e00f5b4af1a34b8501e9c215c2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "d30da9d6bd42402b89ac15b65e85ac93",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "daa9fd85d41d4dd0b761ef952f730cb5",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "ae9f6fed31a44e5da269aee65209281c",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "076119382fd94289ac5009a203d79732",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "27864727f376427292672bd91418bcfa",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1e8a316110a84b6396d177a815eac0a9",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/99 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"ename": "Exception",
|
|
"evalue": "Reached maximum number of idle transformation calls.\nThis means the transformation looped over 10 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=585345038, n=104191), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[1;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad-Electricity.ipynb Cell 12\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/examples/Time-Grad-Electricity.ipynb#X14sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m predictor \u001b[39m=\u001b[39m estimator\u001b[39m.\u001b[39;49mtrain(dataset_train, num_workers\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m)\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:179\u001b[0m, in \u001b[0;36mPyTorchEstimator.train\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtrain\u001b[39m(\n\u001b[1;32m 170\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 171\u001b[0m training_data: Dataset,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 178\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m PyTorchPredictor:\n\u001b[0;32m--> 179\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrain_model(\n\u001b[1;32m 180\u001b[0m training_data,\n\u001b[1;32m 181\u001b[0m validation_data,\n\u001b[1;32m 182\u001b[0m num_workers\u001b[39m=\u001b[39;49mnum_workers,\n\u001b[1;32m 183\u001b[0m prefetch_factor\u001b[39m=\u001b[39;49mprefetch_factor,\n\u001b[1;32m 184\u001b[0m shuffle_buffer_length\u001b[39m=\u001b[39;49mshuffle_buffer_length,\n\u001b[1;32m 185\u001b[0m cache_data\u001b[39m=\u001b[39;49mcache_data,\n\u001b[1;32m 186\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 187\u001b[0m )\u001b[39m.\u001b[39mpredictor\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/model/estimator.py:151\u001b[0m, in \u001b[0;36mPyTorchEstimator.train_model\u001b[0;34m(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)\u001b[0m\n\u001b[1;32m 133\u001b[0m validation_iter_dataset \u001b[39m=\u001b[39m TransformedIterableDataset(\n\u001b[1;32m 134\u001b[0m dataset\u001b[39m=\u001b[39mvalidation_data,\n\u001b[1;32m 135\u001b[0m transform\u001b[39m=\u001b[39mtransformation\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 139\u001b[0m cache_data\u001b[39m=\u001b[39mcache_data,\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m validation_data_loader \u001b[39m=\u001b[39m DataLoader(\n\u001b[1;32m 142\u001b[0m validation_iter_dataset,\n\u001b[1;32m 143\u001b[0m batch_size\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrainer\u001b[39m.\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 149\u001b[0m )\n\u001b[0;32m--> 151\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrainer(\n\u001b[1;32m 152\u001b[0m net\u001b[39m=\u001b[39;49mtrained_net,\n\u001b[1;32m 153\u001b[0m train_iter\u001b[39m=\u001b[39;49mtraining_data_loader,\n\u001b[1;32m 154\u001b[0m validation_iter\u001b[39m=\u001b[39;49mvalidation_data_loader,\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m \u001b[39mreturn\u001b[39;00m TrainOutput(\n\u001b[1;32m 158\u001b[0m transformation\u001b[39m=\u001b[39mtransformation,\n\u001b[1;32m 159\u001b[0m trained_net\u001b[39m=\u001b[39mtrained_net,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 162\u001b[0m ),\n\u001b[1;32m 163\u001b[0m )\n",
|
|
"File \u001b[0;32m/media/wassname/SGIronWolf/projects5/timeseries/pytorch-ts/pts/trainer.py:63\u001b[0m, in \u001b[0;36mTrainer.__call__\u001b[0;34m(self, net, train_iter, validation_iter)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[39m# training loop\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[39mwith\u001b[39;00m tqdm(train_iter, total\u001b[39m=\u001b[39mtotal) \u001b[39mas\u001b[39;00m it:\n\u001b[0;32m---> 63\u001b[0m \u001b[39mfor\u001b[39;00m batch_no, data_entry \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(it, start\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m):\n\u001b[1;32m 64\u001b[0m optimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 66\u001b[0m inputs \u001b[39m=\u001b[39m [v\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdevice) \u001b[39mfor\u001b[39;00m v \u001b[39min\u001b[39;00m data_entry\u001b[39m.\u001b[39mvalues()]\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/notebook.py:259\u001b[0m, in \u001b[0;36mtqdm_notebook.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 258\u001b[0m it \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39m(tqdm_notebook, \u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__iter__\u001b[39m()\n\u001b[0;32m--> 259\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m it:\n\u001b[1;32m 260\u001b[0m \u001b[39m# return super(tqdm...) will not catch exception\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 262\u001b[0m \u001b[39m# NB: except ... [ as ...] breaks IPython async KeyboardInterrupt\u001b[39;00m\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/tqdm/std.py:1195\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m time \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_time\n\u001b[1;32m 1194\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m 1196\u001b[0m \u001b[39myield\u001b[39;00m obj\n\u001b[1;32m 1197\u001b[0m \u001b[39m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[39m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 629\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py:34\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[1;32m 33\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 34\u001b[0m data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[1;32m 35\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[1;32m 36\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:103\u001b[0m, in \u001b[0;36mTransformedDataset.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__iter__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator[DataEntry]:\n\u001b[0;32m--> 103\u001b[0m \u001b[39myield from\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformation(\n\u001b[1;32m 104\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_dataset, is_train\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_train\n\u001b[1;32m 105\u001b[0m )\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:124\u001b[0m, in \u001b[0;36mMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 122\u001b[0m \u001b[39mself\u001b[39m, data_it: Iterable[DataEntry], is_train: \u001b[39mbool\u001b[39m\n\u001b[1;32m 123\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Iterator:\n\u001b[0;32m--> 124\u001b[0m \u001b[39mfor\u001b[39;00m data_entry \u001b[39min\u001b[39;00m data_it:\n\u001b[1;32m 125\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 126\u001b[0m \u001b[39myield\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmap_transform(data_entry\u001b[39m.\u001b[39mcopy(), is_train)\n",
|
|
"File \u001b[0;32m~/miniforge3/envs/glounts/lib/python3.9/site-packages/gluonts/transform/_base.py:189\u001b[0m, in \u001b[0;36mFlatMapTransformation.__call__\u001b[0;34m(self, data_it, is_train)\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[39myield\u001b[39;00m result\n\u001b[1;32m 184\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 185\u001b[0m \u001b[39m# negative values disable the check\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m 187\u001b[0m \u001b[39mand\u001b[39;00m num_idle_transforms \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\n\u001b[1;32m 188\u001b[0m ):\n\u001b[0;32m--> 189\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mException\u001b[39;00m(\n\u001b[1;32m 190\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mReached maximum number of idle transformation\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m calls.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis means the transformation looped over\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 192\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmax_idle_transforms\u001b[39m}\u001b[39;00m\u001b[39m inputs without returning any\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 193\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m output.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThis occurred in the following\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 194\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m transformation:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 195\u001b[0m )\n",
|
|
"\u001b[0;31mException\u001b[0m: Reached maximum number of idle transformation calls.\nThis means the transformation looped over 10 inputs without returning any output.\nThis occurred in the following transformation:\ngluonts.transform.split.InstanceSplitter(dummy_value=0.0, forecast_start_field=\"forecast_start\", future_length=24, instance_sampler=gluonts.transform.sampler.ExpectedNumInstanceSampler(axis=-1, min_past=192, min_future=24, num_instances=1.0, total_length=585345038, n=104191), is_pad_field=\"is_pad\", lead_time=0, output_NTC=True, past_length=192, start_field=\"start\", target_field=\"target\", time_series_fields=[\"time_feat\", \"observed_values\"])"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"predictor = estimator.train(dataset_train, num_workers=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test,\n",
|
|
" predictor=predictor,\n",
|
|
" num_samples=100)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"forecasts = list(forecast_it)\n",
|
|
"targets = list(ts_it)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"plot(\n",
|
|
" target=targets[0],\n",
|
|
" forecast=forecasts[0],\n",
|
|
" prediction_length=dataset.metadata.prediction_length,\n",
|
|
")\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:], \n",
|
|
" target_agg_funcs={'sum': np.sum})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"agg_metric, item_metrics = evaluator(targets, forecasts, num_series=len(dataset_test))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(\"CRPS:\", agg_metric[\"mean_wQuantileLoss\"])\n",
|
|
"print(\"ND:\", agg_metric[\"ND\"])\n",
|
|
"print(\"NRMSE:\", agg_metric[\"NRMSE\"])\n",
|
|
"print(\"\")\n",
|
|
"print(\"CRPS-Sum:\", agg_metric[\"m_sum_mean_wQuantileLoss\"])\n",
|
|
"print(\"ND-Sum:\", agg_metric[\"m_sum_ND\"])\n",
|
|
"print(\"NRMSE-Sum:\", agg_metric[\"m_sum_NRMSE\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.9.15 ('glounts')",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.15"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "7f25a1f13147a60511cf6766827402baf95cbe50d53a241197155306ee38fe70"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|