Files
pytorch-ts/pts/dataset/loader.py
T
Kashif Rasul ea9b2b7df5 Gluon master (#29)
* Estimator needs an create_instance_splitter now

* updated estimators and tests

* fix test

* validated
2021-02-07 17:43:07 +01:00

38 lines
1.1 KiB
Python

from typing import Optional
import itertools
from torch.utils.data import IterableDataset
from gluonts.dataset.common import Dataset
from gluonts.transform import Transformation, TransformedDataset
from gluonts.itertools import Cyclic, PseudoShuffled, Cached
class TransformedIterableDataset(IterableDataset):
    """Expose a transformed, endlessly repeating GluonTS dataset to PyTorch.

    The source ``dataset`` is wrapped in ``Cyclic`` and passed, together with
    ``transform``, to ``TransformedDataset``. Iterating this object yields the
    entries of that transformed dataset, optionally routed through a
    ``PseudoShuffled`` buffer.

    Args:
        dataset: Source dataset whose entries are cycled over.
        transform: Transformation handed to ``TransformedDataset``.
        is_train: Forwarded unchanged to ``TransformedDataset``.
        shuffle_buffer_length: When not ``None``, every iterator is wrapped in
            ``PseudoShuffled`` using this buffer length; when ``None``, the
            transformed dataset is iterated directly.
        cache_data: When true, the source dataset is wrapped in ``Cached``
            before being cycled.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: Transformation,
        is_train: bool = True,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ):
        super().__init__()
        self.shuffle_buffer_length = shuffle_buffer_length

        # Cache the finite dataset and cycle over the cache, not the other
        # way round: ``Cyclic`` never terminates, so ``Cached(Cyclic(...))``
        # would keep appending to its cache forever and a later iterator
        # would only replay whatever partial cache existed at that moment.
        source = Cached(dataset) if cache_data else dataset
        self.transformed_dataset = TransformedDataset(
            Cyclic(source),
            transform,
            is_train=is_train,
        )

    def __iter__(self):
        """Return a fresh iterator over the (optionally shuffled) entries."""
        if self.shuffle_buffer_length is None:
            return iter(self.transformed_dataset)

        shuffled = PseudoShuffled(
            self.transformed_dataset,
            shuffle_buffer_length=self.shuffle_buffer_length,
        )
        return iter(shuffled)