mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 18:07:44 +08:00
include module path in name, return dataframe
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
def summary(model, x, *args, **kwargs):
|
||||
"""Summarize the given input model.
|
||||
Summarized information are 1) output shape, 2) kernel shape,
|
||||
3) number of the parameters and 4) operations (Mult-Adds)
|
||||
|
||||
Args:
|
||||
model (Module): Model to summarize
|
||||
x (Tensor): Input tensor of the model with [N, C, H, W] shape
|
||||
@@ -16,8 +16,11 @@ def summary(model, x, *args, **kwargs):
|
||||
def register_hook(module):
|
||||
def hook(module, inputs, outputs):
|
||||
cls_name = str(module.__class__).split(".")[-1].split("'")[0]
|
||||
module_idx = len(summary)
|
||||
key = "{}_{}".format(module_idx, cls_name)
|
||||
|
||||
# Lookup name in a dict that includes parents
|
||||
for name, item in module_names.items():
|
||||
if item == module:
|
||||
key = name
|
||||
|
||||
info = OrderedDict()
|
||||
info["id"] = id(module)
|
||||
@@ -65,45 +68,51 @@ def summary(model, x, *args, **kwargs):
|
||||
# ignore Sequential and ModuleList
|
||||
if not module._modules:
|
||||
hooks.append(module.register_forward_hook(hook))
|
||||
|
||||
module_names = get_names_dict(model)
|
||||
|
||||
hooks = []
|
||||
summary = OrderedDict()
|
||||
|
||||
model.apply(register_hook)
|
||||
|
||||
with torch.no_grad():
|
||||
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Use pandas to align the columns
|
||||
df = pd.DataFrame(summary).T
|
||||
df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6
|
||||
df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-')
|
||||
df = df.rename(columns=dict(
|
||||
ksize='Kernel Shape',
|
||||
out='Output Shape',
|
||||
params='Params (K)',
|
||||
))
|
||||
df.index.name = 'Layer'
|
||||
df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']]
|
||||
|
||||
return df
|
||||
|
||||
def get_names_dict(model):
|
||||
"""Recursive walk to get names including path."""
|
||||
names = {}
|
||||
|
||||
print("-"*100)
|
||||
print("{:<15} {:>20} {:>20} {:>20} {:>20}"
|
||||
.format("Layer", "Kernel Shape", "Output Shape",
|
||||
"# Params (K)", "# Mult-Adds (M)"))
|
||||
print("="*100)
|
||||
|
||||
total_params, total_macs = 0, 0
|
||||
for layer, info in summary.items():
|
||||
repr_ksize = str(info["ksize"])
|
||||
repr_out = str(info["out"])
|
||||
repr_params = info["params"]
|
||||
repr_macs = info["macs"]
|
||||
|
||||
if isinstance(repr_params, (int, float)):
|
||||
total_params += repr_params
|
||||
repr_params = "{0:,.2f}".format(repr_params/1000)
|
||||
if isinstance(repr_macs, (int, float)):
|
||||
total_macs += repr_macs
|
||||
repr_macs = "{0:,.2f}".format(repr_macs/1000000)
|
||||
|
||||
print("{:<15} {:>20} {:>20} {:>20} {:>20}"
|
||||
.format(layer, repr_ksize, repr_out, repr_params, repr_macs))
|
||||
|
||||
# for RNN, describe inner weights (i.e. w_hh, w_ih)
|
||||
for inner_name, inner_shape in info["inner"].items():
|
||||
print(" {:<13} {:>20}".format(inner_name, str(inner_shape)))
|
||||
|
||||
print("="*100)
|
||||
print("# Params: {0:,.2f}K".format(total_params/1000))
|
||||
print("# Mult-Adds: {0:,.2f}M".format(total_macs/1000000))
|
||||
print("-"*100)
|
||||
def _get_names(module, parent_name=''):
|
||||
for key, module in module.named_children():
|
||||
cls_name = str(module.__class__).split(".")[-1].split("'")[0]
|
||||
num_named_children = len(list(module.named_children()))
|
||||
if num_named_children>0:
|
||||
name = parent_name + '.' + key if parent_name else key
|
||||
else:
|
||||
name = parent_name + '.' + cls_name + '_'+ key if parent_name else key
|
||||
# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}')
|
||||
names[name] = module
|
||||
|
||||
if isinstance(module, torch.nn.Module):
|
||||
_get_names(module, parent_name=name)
|
||||
|
||||
_get_names(model)
|
||||
return names
|
||||
|
||||
Reference in New Issue
Block a user