mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
ea9b2b7df5
* Estimator needs a create_instance_splitter now * updated estimators and tests * fix test * validated
38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
from typing import Optional
|
|
import itertools
|
|
|
|
from torch.utils.data import IterableDataset
|
|
|
|
from gluonts.dataset.common import Dataset
|
|
from gluonts.transform import Transformation, TransformedDataset
|
|
from gluonts.itertools import Cyclic, PseudoShuffled, Cached
|
|
|
|
|
|
class TransformedIterableDataset(IterableDataset):
    """Endless, optionally shuffled stream of transformed dataset entries.

    The input ``dataset`` is wrapped in gluonts' ``Cyclic`` so that iteration
    starts over from the beginning whenever the dataset is exhausted. Each
    entry is passed through ``transform`` lazily via ``TransformedDataset``.

    Args:
        dataset: Source dataset whose entries are streamed.
        transform: Transformation applied to each entry on the fly.
        is_train: Forwarded unchanged to ``TransformedDataset`` as
            ``is_train``.
        shuffle_buffer_length: If given, the transformed stream is wrapped in
            ``PseudoShuffled`` with a buffer of this many elements. If
            ``None``, entries are yielded in stream order.
        cache_data: If True, the raw (untransformed) entries are cached in
            memory during the first pass over ``dataset`` and replayed on
            later passes. ``transform`` is still applied on every pass,
            because caching happens before the transformation.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: Transformation,
        is_train: bool = True,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ):
        super().__init__()

        self.shuffle_buffer_length = shuffle_buffer_length

        # Cache the *finite* dataset, then cycle over the cache.
        # The reverse nesting, Cached(Cyclic(dataset)), never finishes a
        # first pass because Cyclic is endless. Its cache list would then
        # grow without bound and never be replayed.
        source = Cached(dataset) if cache_data else dataset

        self.transformed_dataset = TransformedDataset(
            Cyclic(source),
            transform,
            is_train=is_train,
        )

    def __iter__(self):
        """Return an iterator over the (optionally pseudo-shuffled) stream."""
        if self.shuffle_buffer_length is None:
            return iter(self.transformed_dataset)

        shuffled = PseudoShuffled(
            self.transformed_dataset,
            shuffle_buffer_length=self.shuffle_buffer_length,
        )
        return iter(shuffled)
|