diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index 29bad77..5870598 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -34,7 +34,7 @@ def summary(model, x, *args, **kwargs): info["inner"] = OrderedDict() info["params"], info["macs"] = 0, 0 for name, param in module.named_parameters(): - info["params"] += param.nelement() + info["params"] += param.nelement() * param.requires_grad if name == "weight": ksize = list(param.size())