Files
pytorch-ts/test/model/deepar/test_auxillary_outputs.py
T
2021-07-06 13:16:48 +02:00

74 lines
2.3 KiB
Python

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from itertools import islice
import torch
from gluonts.dataset.artificial import constant_dataset
from gluonts.dataset.loader import TrainDataLoader
from gluonts.torch.batchify import batchify
from pts import Trainer
from pts.model import get_module_forward_input_names
from pts.model.deepar import DeepAREstimator
from pts.modules import StudentTOutput
# Module-level fixtures shared by the tests in this file: the result of
# constant_dataset() is unpacked into the dataset info and the train/test
# datasets, and the frequency and prediction horizon are read off the info.
ds_info, train_ds, test_ds = constant_dataset()
freq, prediction_length = ds_info.metadata.freq, ds_info.prediction_length
def test_distribution():
    """
    Makes sure additional tensors can be accessed and have expected shapes.

    Trains a DeepAR estimator for one epoch of one batch on the module-level
    constant dataset, rebuilds a training data loader from the fitted
    transformation, feeds a single batch through the trained network's
    ``distribution`` method and checks the shape of samples drawn from the
    returned distribution.
    """
    estimator = DeepAREstimator(
        freq=freq,
        prediction_length=prediction_length,
        input_size=15,
        trainer=Trainer(epochs=1, num_batches_per_epoch=1),
        distr_output=StudentTOutput(),
    )

    train_output = estimator.train_model(train_ds)

    # TODO: adapt loader to anomaly detection use-case
    batch_size = 2
    num_samples = 3

    training_data_loader = TrainDataLoader(
        train_ds,
        transform=train_output.transformation
        + estimator.create_instance_splitter("training"),
        batch_size=batch_size,
        num_batches_per_epoch=estimator.trainer.num_batches_per_epoch,
        stack_fn=batchify,
    )

    # NOTE(review): presumably the training window is context + prediction
    # with the context length defaulting to the prediction length, hence the
    # factor of two -- confirm against DeepAREstimator defaults.
    seq_len = 2 * prediction_length

    # The forward-input names depend only on the trained network, not on the
    # batch, so resolve them once outside the loop.
    input_names = get_module_forward_input_names(train_output.trained_net)

    batches_checked = 0
    for data_entry in islice(training_data_loader, 1):
        distr = train_output.trained_net.distribution(
            *[data_entry[k] for k in input_names]
        )

        assert distr.sample((num_samples,)).shape == (
            num_samples,
            batch_size,
            seq_len,
        )
        batches_checked += 1

    # Without this guard an empty loader would skip the loop body entirely
    # and the test would pass without checking anything.
    assert batches_checked == 1, "training data loader produced no batches"