mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-07-02 10:46:07 +08:00
158 lines
5.0 KiB
Python
158 lines
5.0 KiB
Python
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
# You may not use this file except in compliance with the License.
|
|
# A copy of the License is located at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# or in the "license" file accompanying this file. This file is distributed
|
|
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
|
# express or implied. See the License for the specific language governing
|
|
# permissions and limitations under the License.
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
from functools import reduce
|
|
from typing import Callable, Iterator, Iterable, List
|
|
|
|
from gluonts.core.component import validated
|
|
from pts.dataset import DataEntry
|
|
|
|
MAX_IDLE_TRANSFORMS = 100
|
|
|
|
|
|
class Transformation(ABC):
|
|
@abstractmethod
|
|
def __call__(
|
|
self, data_it: Iterable[DataEntry], is_train: bool
|
|
) -> Iterator[DataEntry]:
|
|
pass
|
|
|
|
def estimate(self, data_it: Iterable[DataEntry]) -> Iterator[DataEntry]:
|
|
return data_it # default is to pass through without estimation
|
|
|
|
def chain(self, other: "Transformation") -> "Chain":
|
|
return Chain([self, other])
|
|
|
|
def __add__(self, other: "Transformation") -> "Chain":
|
|
return self.chain(other)
|
|
|
|
|
|
class Chain(Transformation):
|
|
"""
|
|
Chain multiple transformations together.
|
|
"""
|
|
|
|
@validated()
|
|
def __init__(self, trans: List[Transformation]) -> None:
|
|
self.transformations = []
|
|
for transformation in trans:
|
|
# flatten chains
|
|
if isinstance(transformation, Chain):
|
|
self.transformations.extend(transformation.transformations)
|
|
else:
|
|
self.transformations.append(transformation)
|
|
|
|
def __call__(
|
|
self, data_it: Iterable[DataEntry], is_train: bool
|
|
) -> Iterator[DataEntry]:
|
|
tmp = data_it
|
|
for t in self.transformations:
|
|
tmp = t(tmp, is_train)
|
|
return tmp
|
|
|
|
def estimate(self, data_it: Iterator[DataEntry]) -> Iterator[DataEntry]:
|
|
return reduce(lambda x, y: y.estimate(x), self.transformations, data_it)
|
|
|
|
|
|
class Identity(Transformation):
|
|
def __call__(
|
|
self, data_it: Iterable[DataEntry], is_train: bool
|
|
) -> Iterator[DataEntry]:
|
|
return data_it
|
|
|
|
|
|
class MapTransformation(Transformation):
|
|
"""
|
|
Base class for Transformations that returns exactly one result per input in the stream.
|
|
"""
|
|
|
|
def __call__(self, data_it: Iterable[DataEntry], is_train: bool) -> Iterator:
|
|
for data_entry in data_it:
|
|
try:
|
|
yield self.map_transform(data_entry.copy(), is_train)
|
|
except Exception as e:
|
|
raise e
|
|
|
|
@abstractmethod
|
|
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
|
pass
|
|
|
|
|
|
class SimpleTransformation(MapTransformation):
|
|
"""
|
|
Element wise transformations that are the same in train and test mode
|
|
"""
|
|
|
|
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
|
|
return self.transform(data)
|
|
|
|
@abstractmethod
|
|
def transform(self, data: DataEntry) -> DataEntry:
|
|
pass
|
|
|
|
|
|
class AdhocTransform(SimpleTransformation):
|
|
"""
|
|
Applies a function as a transformation
|
|
This is called ad-hoc, because it is not serializable.
|
|
It is OK to use this for experiments and outside of a model pipeline that
|
|
needs to be serialized.
|
|
"""
|
|
|
|
def __init__(self, func: Callable[[DataEntry], DataEntry]) -> None:
|
|
self.func = func
|
|
|
|
def transform(self, data: DataEntry) -> DataEntry:
|
|
return self.func(data.copy())
|
|
|
|
|
|
class FlatMapTransformation(Transformation):
|
|
"""
|
|
Transformations that yield zero or more results per input, but do not combine
|
|
elements from the input stream.
|
|
"""
|
|
|
|
def __call__(self, data_it: Iterable[DataEntry], is_train: bool) -> Iterator:
|
|
num_idle_transforms = 0
|
|
for data_entry in data_it:
|
|
num_idle_transforms += 1
|
|
try:
|
|
for result in self.flatmap_transform(data_entry.copy(), is_train):
|
|
num_idle_transforms = 0
|
|
yield result
|
|
except Exception as e:
|
|
raise e
|
|
if num_idle_transforms > MAX_IDLE_TRANSFORMS:
|
|
raise Exception(
|
|
f"Reached maximum number of idle transformation calls.\n"
|
|
f"This means the transformation looped over "
|
|
f"MAX_IDLE_TRANSFORMS={MAX_IDLE_TRANSFORMS} "
|
|
f"inputs without returning any output.\n"
|
|
f"This occurred in the following transformation:\n{self}"
|
|
)
|
|
|
|
@abstractmethod
|
|
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
|
pass
|
|
|
|
|
|
class FilterTransformation(FlatMapTransformation):
|
|
def __init__(self, condition: Callable[[DataEntry], bool]) -> None:
|
|
self.condition = condition
|
|
|
|
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
|
|
if self.condition(data):
|
|
yield data
|