mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 19:32:05 +08:00
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import inspect
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
def get_module_forward_input_names(module: nn.Module):
    """Return the names of the explicit parameters of ``module.forward``.

    Variadic parameters (``*args`` and ``**kwargs``) are skipped. Every other
    parameter is kept in declaration order, including keyword-only ones.
    When ``module`` is an instance, ``module.forward`` is a bound method, so
    ``self`` is not part of the result.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    forward_signature = inspect.signature(module.forward)
    return [
        name
        for name, parameter in forward_signature.parameters.items()
        if parameter.kind not in variadic_kinds
    ]
|
|
|
|
|
|
def weighted_average(
    x: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None
) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given dim, masking
    values associated with weight zero, meaning that instead of
    `nan * 0 = nan` you will get `0 * 0 = 0`.

    Parameters
    ----------
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as `x`. Positions where the weight
        is zero are excluded from the numerator, even if `x` holds NaN there.
        If omitted, a plain (unweighted) mean is returned.
    dim
        The dim along which to average `x`. ``None`` (the default) averages
        over all elements. ``0`` is a valid dim and reduces along the first
        axis.

    Returns
    -------
    Tensor:
        The tensor with values averaged along the specified `dim`. In the
        weighted case the denominator is clamped to at least 1, so slices whose
        weights are all zero produce 0 rather than NaN.
    """
    if weights is None:
        # Older torch releases may reject ``mean(dim=None)``, so reduce over
        # everything with a bare ``mean()`` in that case.
        return x.mean() if dim is None else x.mean(dim=dim)

    # Select zeros where the weight is zero *after* multiplying, so NaN/inf
    # entries of ``x`` at masked positions do not leak into the sum
    # (nan * 0 == nan, but torch.where picks the zero instead).
    weighted_tensor = torch.where(weights != 0, x * weights, torch.zeros_like(x))

    # Compare against None explicitly: ``dim=0`` is falsy but is a real axis,
    # and a truthiness test would silently reduce over all elements instead.
    if dim is None:
        weighted_sum = weighted_tensor.sum()
        sum_weights = weights.sum()
    else:
        weighted_sum = weighted_tensor.sum(dim=dim)
        sum_weights = weights.sum(dim=dim)

    # Clamp the denominator to at least 1 so all-zero-weight slices give
    # 0 / 1 = 0 instead of 0 / 0 = nan. Consequently, weights summing to less
    # than 1 are not fully normalized.
    return weighted_sum / torch.clamp(sum_weights, min=1.0)
|