Files
pytorch-ts/pts/model/utils.py
T
2020-12-17 17:04:56 +01:00

42 lines
1.2 KiB
Python

import inspect
from typing import Optional
import torch
import torch.nn as nn
def get_module_forward_input_names(module: nn.Module):
    """Return the names of the explicit parameters of ``module.forward``.

    Variadic parameters (``*args`` and ``**kwargs``) are skipped; all other
    parameters are returned in declaration order. When ``module`` is an
    instance, ``module.forward`` is a bound method, so ``self`` is not part
    of the result.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    signature = inspect.signature(module.forward)
    return [
        name
        for name, param in signature.parameters.items()
        if param.kind not in variadic_kinds
    ]
def weighted_average(
    x: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None
) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given dim, masking
    values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Parameters
    ----------
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as `x` (or broadcastable with it).
        If None, a plain (unweighted) mean is returned.
    dim
        The dim along which to average `x`. If None, the average is taken
        over all elements. `dim=0` is a valid dimension and is honoured.

    Returns
    -------
    Tensor:
        The tensor with values averaged along the specified `dim`.

    Notes
    -----
    The sum of weights is clamped to a minimum of 1.0 before dividing. This
    makes the result 0 (instead of NaN) when all weights along `dim` are
    zero, but it also means weights that sum to less than 1 are not fully
    normalised.
    """
    if weights is None:
        # Call mean() without `dim` when reducing over everything: passing
        # dim=None explicitly is rejected by some torch versions.
        return x.mean() if dim is None else x.mean(dim=dim)

    # Pick zeros wherever the weight is zero so that non-finite values of `x`
    # at masked positions (e.g. NaN padding) cannot leak into the sum.
    weighted_tensor = torch.where(weights != 0, x * weights, torch.zeros_like(x))

    # Compare against None explicitly: `dim=0` is falsy but is a real dim.
    if dim is None:
        weighted_sum = weighted_tensor.sum()
        sum_weights = weights.sum()
    else:
        weighted_sum = weighted_tensor.sum(dim=dim)
        sum_weights = weights.sum(dim=dim)

    return weighted_sum / torch.clamp(sum_weights, min=1.0)