Files
pytorch-ts/pts/modules/lambda_layer.py
T
2020-03-30 13:01:08 +02:00

11 lines
215 B
Python

import torch.nn as nn
class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable so it can be used as an ``nn.Module``.

    The layer holds no parameters of its own; calling it simply invokes the
    wrapped callable with the arguments it receives.

    Args:
        function: The callable to invoke on every forward pass.

    Raises:
        TypeError: If ``function`` is not callable.
    """

    def __init__(self, function):
        super().__init__()
        # Fail at construction time instead of on the first forward pass,
        # where the error would be much further from its cause.
        if not callable(function):
            raise TypeError(
                f"function must be callable, got {type(function).__name__}"
            )
        self._func = function

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped callable.

        Args:
            x: The primary input, passed as the first positional argument.
            *args: Extra positional arguments forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the wrapped callable returns.
        """
        return self._func(x, *args, **kwargs)