mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 16:32:27 +08:00
Fix RNN
This commit is contained in:
@@ -38,14 +38,18 @@ def summary(model, x, *args, **kwargs):
|
||||
# to make [in_shape, out_shape, ksize, ksize]
|
||||
if len(ksize) > 1:
|
||||
ksize[0], ksize[1] = ksize[1], ksize[0]
|
||||
|
||||
info["ksize"] = ksize
|
||||
info["macs"] += int(info["params"] * np.prod(info["out"][2:]))
|
||||
|
||||
# ignore N, C when calculate Mult-Adds in ConvNd
|
||||
if "Conv" in cls_name:
|
||||
info["macs"] += int(param.nelement() * np.prod(info["out"][2:]))
|
||||
else:
|
||||
info["macs"] += param.nelement()
|
||||
|
||||
# RNN modules have inner weights such as weight_ih_l0
|
||||
elif "weight" in name:
|
||||
info["inner"][name] = list(param.size())
|
||||
info["macs"] += int(param.nelement() * info["out"][0] * \
|
||||
np.prod(info["out"][3:]))
|
||||
info["macs"] += param.nelement()
|
||||
|
||||
# if the current module is already-used, mark as "(recursive)"
|
||||
# check if this module has params
|
||||
|
||||
Reference in New Issue
Block a user