Files
pytorch-ts/pts/modules/lambda_layer.py
T
Kashif Rasul 4ad01ea2e3 ran isort
isort --recursive --atomic --apply pts
2019-10-30 09:38:19 +01:00

12 lines
228 B
Python

import torch
import torch.nn as nn
class LambdaLayer(nn.Module):
def __init__(self, function):
super().__init__()
self._func = function
def forward(self, x, *args):
return self._func(x, *args)