mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 16:32:27 +08:00
Fix bugs
This commit is contained in:
@@ -36,19 +36,18 @@ class Net(nn.Module):
|
||||
summary(Net(), torch.zeros((1, 1, 28, 28)))
|
||||
```
|
||||
```
|
||||
====================================================================================================
|
||||
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
||||
Layer
|
||||
0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.1440
|
||||
1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.3200
|
||||
2_conv2_drop - [1, 20, 8, 8] NaN NaN
|
||||
3_fc1 [320, 50] [1, 50] 16.05 0.0160
|
||||
4_fc2 [50, 10] [1, 10] 0.51 0.0005
|
||||
====================================================================================================
|
||||
Params (K) 21.8400
|
||||
Mult-Adds (M) 0.4805
|
||||
dtype: float64
|
||||
====================================================================================================
|
||||
========================================================================
|
||||
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
||||
Layer
|
||||
0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.144
|
||||
1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.32
|
||||
2_conv2_drop - [1, 20, 8, 8] - -
|
||||
3_fc1 [320, 50] [1, 50] 16.05 0.016
|
||||
4_fc2 [50, 10] [1, 10] 0.51 0.0005
|
||||
------------------------------------------------------------------------
|
||||
Params (K): 21.84
|
||||
Mult-Adds (M): 0.4805
|
||||
========================================================================
|
||||
```
|
||||
|
||||
RNN
|
||||
@@ -75,18 +74,16 @@ inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size]
|
||||
summary(Net(), inputs)
|
||||
```
|
||||
```
|
||||
====================================================================================================
|
||||
==================================================================
|
||||
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
||||
Layer
|
||||
Layer
|
||||
0_embedding [300, 20] [100, 1, 300] 6.00 0.006000
|
||||
1_encoder - [100, 1, 512] 3768.32 3.760128
|
||||
2_decoder [512, 20] [100, 1, 20] 10.26 0.010240
|
||||
====================================================================================================
|
||||
Params (K) 3784.580000
|
||||
Mult-Adds (M) 3.776368
|
||||
dtype: float64
|
||||
====================================================================================================
|
||||
|
||||
------------------------------------------------------------------
|
||||
Params (K): 3784.5800000000004
|
||||
Mult-Adds (M): 3.7763679999999997
|
||||
==================================================================
|
||||
```
|
||||
|
||||
Recursive NN
|
||||
@@ -103,23 +100,15 @@ class Net(nn.Module):
|
||||
summary(Net(), torch.zeros((1, 64, 28, 28)))
|
||||
```
|
||||
```
|
||||
----------------------------------------------------------------------------------------------------
|
||||
Layer Kernel Shape Output Shape # Params (K) # Mult-Adds (M)
|
||||
====================================================================================================
|
||||
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
||||
Layer
|
||||
0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376
|
||||
1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] NaN 28.901376
|
||||
====================================================================================================
|
||||
Kernel Shape [64, 64, 3, 3, 64, 64, 3, 3]
|
||||
Output Shape [1, 64, 28, 28, 1, 64, 28, 28]
|
||||
Params (K) 36.928
|
||||
Mult-Adds (M) 57.8028
|
||||
dtype: object
|
||||
====================================================================================================
|
||||
# Params: 36.93K
|
||||
# Mult-Adds: 57.80M
|
||||
----------------------------------------------------------------------------------------------------
|
||||
===================================================================
|
||||
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
||||
Layer
|
||||
0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376
|
||||
1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] - 28.901376
|
||||
-------------------------------------------------------------------
|
||||
Params (K): 36.928
|
||||
Mult-Adds (M): 57.802752
|
||||
===================================================================
|
||||
```
|
||||
|
||||
Multiple arguments
|
||||
|
||||
@@ -17,11 +17,11 @@ def summary(model, x, *args, **kwargs):
|
||||
def hook(module, inputs, outputs):
|
||||
cls_name = str(module.__class__).split(".")[-1].split("'")[0]
|
||||
module_idx = len(summary)
|
||||
|
||||
|
||||
# Lookup name in a dict that includes parents
|
||||
for name, item in module_names.items():
|
||||
if item == module:
|
||||
key = '{}_{}'.format(module_idx, name)
|
||||
key = "{}_{}".format(module_idx, name)
|
||||
|
||||
info = OrderedDict()
|
||||
info["id"] = id(module)
|
||||
@@ -69,55 +69,59 @@ def summary(model, x, *args, **kwargs):
|
||||
# ignore Sequential and ModuleList
|
||||
if not module._modules:
|
||||
hooks.append(module.register_forward_hook(hook))
|
||||
|
||||
|
||||
module_names = get_names_dict(model)
|
||||
|
||||
hooks = []
|
||||
summary = OrderedDict()
|
||||
|
||||
model.apply(register_hook)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
|
||||
# Use pandas to align the columns
|
||||
df = pd.DataFrame(summary).T
|
||||
df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6
|
||||
df['Params (K)'] = pd.to_numeric(df['params'], errors='coerce')/1e3
|
||||
df["Mult-Adds (M)"] = pd.to_numeric(df["macs"], errors="coerce")/1e6
|
||||
df["Params (K)"] = pd.to_numeric(df["params"], errors="coerce")/1e3
|
||||
df = df.rename(columns=dict(
|
||||
ksize='Kernel Shape',
|
||||
out='Output Shape',
|
||||
ksize="Kernel Shape",
|
||||
out="Output Shape",
|
||||
))
|
||||
df.index.name = 'Layer'
|
||||
df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']]
|
||||
|
||||
print("="*100)
|
||||
print(df.replace(np.nan, '-')
|
||||
print("="*100)
|
||||
print(df.sum())
|
||||
print("="*100)
|
||||
df.index.name = "Layer"
|
||||
df = df[["Kernel Shape", "Output Shape", "Params (K)", "Mult-Adds (M)"]]
|
||||
df_sum = df.sum()
|
||||
|
||||
max_repr_width = max([len(row) for row in df.to_string().split("\n")])
|
||||
|
||||
print("="*max_repr_width)
|
||||
print(df.replace(np.nan, "-"))
|
||||
print("-"*max_repr_width)
|
||||
print("Params (K): ", df_sum["Params (K)"])
|
||||
print("Mult-Adds (M): ", df_sum["Mult-Adds (M)"])
|
||||
print("="*max_repr_width)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def get_names_dict(model):
|
||||
"""Recursive walk to get names including path."""
|
||||
names = {}
|
||||
|
||||
def _get_names(module, parent_name=''):
|
||||
for key, module in module.named_children():
|
||||
cls_name = str(module.__class__).split(".")[-1].split("'")[0]
|
||||
num_named_children = len(list(module.named_children()))
|
||||
if num_named_children>0:
|
||||
name = parent_name + '.' + key if parent_name else key
|
||||
def _get_names(module, parent_name=""):
|
||||
for key, m in module.named_children():
|
||||
cls_name = str(m.__class__).split(".")[-1].split("'")[0]
|
||||
num_named_children = len(list(m.named_children()))
|
||||
if num_named_children > 0:
|
||||
name = parent_name + "." + key if parent_name else key
|
||||
else:
|
||||
name = parent_name + '.' + cls_name + '_'+ key if parent_name else key
|
||||
names[name] = module
|
||||
|
||||
if isinstance(module, torch.nn.Module):
|
||||
_get_names(module, parent_name=name)
|
||||
|
||||
name = parent_name + "." + cls_name + "_"+ key if parent_name else key
|
||||
names[name] = m
|
||||
|
||||
if isinstance(m, torch.nn.Module):
|
||||
_get_names(m, parent_name=name)
|
||||
|
||||
_get_names(model)
|
||||
return names
|
||||
|
||||
Reference in New Issue
Block a user