From a0af764402545822a7190a076d90ecaef372ee85 Mon Sep 17 00:00:00 2001 From: Namhyuk Ahn Date: Sun, 29 Jul 2018 01:55:45 +0900 Subject: [PATCH] Fix RNN --- torchsummaryX/torchsummaryX.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index fb35c40..694c229 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -38,14 +38,18 @@ def summary(model, x, *args, **kwargs): # to make [in_shape, out_shape, ksize, ksize] if len(ksize) > 1: ksize[0], ksize[1] = ksize[1], ksize[0] - info["ksize"] = ksize - info["macs"] += int(info["params"] * np.prod(info["out"][2:])) + + # ignore N, C when calculate Mult-Adds in ConvNd + if "Conv" in cls_name: + info["macs"] += int(param.nelement() * np.prod(info["out"][2:])) + else: + info["macs"] += param.nelement() + # RNN modules have inner weights such as weight_ih_l0 elif "weight" in name: info["inner"][name] = list(param.size()) - info["macs"] += int(param.nelement() * info["out"][0] * \ - np.prod(info["out"][3:])) + info["macs"] += param.nelement() # if the current module is already-used, mark as "(recursive)" # check if this module has params