This commit is contained in:
Namhyuk Ahn
2018-07-29 01:55:45 +09:00
parent e4cf386bed
commit a0af764402
+8 -4
View File
@@ -38,14 +38,18 @@ def summary(model, x, *args, **kwargs):
# to make [in_shape, out_shape, ksize, ksize]
if len(ksize) > 1:
ksize[0], ksize[1] = ksize[1], ksize[0]
info["ksize"] = ksize
info["macs"] += int(info["params"] * np.prod(info["out"][2:]))
# ignore N, C when calculate Mult-Adds in ConvNd
if "Conv" in cls_name:
info["macs"] += int(param.nelement() * np.prod(info["out"][2:]))
else:
info["macs"] += param.nelement()
# RNN modules have inner weights such as weight_ih_l0
elif "weight" in name:
info["inner"][name] = list(param.size())
info["macs"] += int(param.nelement() * info["out"][0] * \
np.prod(info["out"][3:]))
info["macs"] += param.nelement()
# if the current module is already-used, mark as "(recursive)"
# check if this module has params