Files
pytorch-ts/pts/transform/transform.py
T
2020-12-17 17:04:56 +01:00

158 lines
5.0 KiB
Python

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from abc import ABC, abstractmethod
from functools import reduce
from typing import Callable, Iterator, Iterable, List
from gluonts.core.component import validated
from pts.dataset import DataEntry
MAX_IDLE_TRANSFORMS = 100
class Transformation(ABC):
@abstractmethod
def __call__(
self, data_it: Iterable[DataEntry], is_train: bool
) -> Iterator[DataEntry]:
pass
def estimate(self, data_it: Iterable[DataEntry]) -> Iterator[DataEntry]:
return data_it # default is to pass through without estimation
def chain(self, other: "Transformation") -> "Chain":
return Chain([self, other])
def __add__(self, other: "Transformation") -> "Chain":
return self.chain(other)
class Chain(Transformation):
"""
Chain multiple transformations together.
"""
@validated()
def __init__(self, trans: List[Transformation]) -> None:
self.transformations = []
for transformation in trans:
# flatten chains
if isinstance(transformation, Chain):
self.transformations.extend(transformation.transformations)
else:
self.transformations.append(transformation)
def __call__(
self, data_it: Iterable[DataEntry], is_train: bool
) -> Iterator[DataEntry]:
tmp = data_it
for t in self.transformations:
tmp = t(tmp, is_train)
return tmp
def estimate(self, data_it: Iterator[DataEntry]) -> Iterator[DataEntry]:
return reduce(lambda x, y: y.estimate(x), self.transformations, data_it)
class Identity(Transformation):
def __call__(
self, data_it: Iterable[DataEntry], is_train: bool
) -> Iterator[DataEntry]:
return data_it
class MapTransformation(Transformation):
"""
Base class for Transformations that returns exactly one result per input in the stream.
"""
def __call__(self, data_it: Iterable[DataEntry], is_train: bool) -> Iterator:
for data_entry in data_it:
try:
yield self.map_transform(data_entry.copy(), is_train)
except Exception as e:
raise e
@abstractmethod
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
pass
class SimpleTransformation(MapTransformation):
"""
Element wise transformations that are the same in train and test mode
"""
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
return self.transform(data)
@abstractmethod
def transform(self, data: DataEntry) -> DataEntry:
pass
class AdhocTransform(SimpleTransformation):
"""
Applies a function as a transformation
This is called ad-hoc, because it is not serializable.
It is OK to use this for experiments and outside of a model pipeline that
needs to be serialized.
"""
def __init__(self, func: Callable[[DataEntry], DataEntry]) -> None:
self.func = func
def transform(self, data: DataEntry) -> DataEntry:
return self.func(data.copy())
class FlatMapTransformation(Transformation):
"""
Transformations that yield zero or more results per input, but do not combine
elements from the input stream.
"""
def __call__(self, data_it: Iterable[DataEntry], is_train: bool) -> Iterator:
num_idle_transforms = 0
for data_entry in data_it:
num_idle_transforms += 1
try:
for result in self.flatmap_transform(data_entry.copy(), is_train):
num_idle_transforms = 0
yield result
except Exception as e:
raise e
if num_idle_transforms > MAX_IDLE_TRANSFORMS:
raise Exception(
f"Reached maximum number of idle transformation calls.\n"
f"This means the transformation looped over "
f"MAX_IDLE_TRANSFORMS={MAX_IDLE_TRANSFORMS} "
f"inputs without returning any output.\n"
f"This occurred in the following transformation:\n{self}"
)
@abstractmethod
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
pass
class FilterTransformation(FlatMapTransformation):
def __init__(self, condition: Callable[[DataEntry], bool]) -> None:
self.condition = condition
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
if self.condition(data):
yield data