mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 16:47:38 +08:00
let it work with modules with are layers
This commit is contained in:
@@ -3,7 +3,10 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def summary(model, x, *args, **kwargs):
|
# Some modules do the computation themselves using parameters or the parameters of children, treat these as layers
|
||||||
|
layer_modules = (torch.nn.MultiheadAttention, )
|
||||||
|
|
||||||
|
def summary(model, x, *args, layer_modules=layer_modules, print_summary=True, **kwargs):
|
||||||
"""Summarize the given input model.
|
"""Summarize the given input model.
|
||||||
Summarized information are 1) output shape, 2) kernel shape,
|
Summarized information are 1) output shape, 2) kernel shape,
|
||||||
3) number of the parameters and 4) operations (Mult-Adds)
|
3) number of the parameters and 4) operations (Mult-Adds)
|
||||||
@@ -19,9 +22,12 @@ def summary(model, x, *args, **kwargs):
|
|||||||
module_idx = len(summary)
|
module_idx = len(summary)
|
||||||
|
|
||||||
# Lookup name in a dict that includes parents
|
# Lookup name in a dict that includes parents
|
||||||
|
module_name = str(module_idx)
|
||||||
for name, item in module_names.items():
|
for name, item in module_names.items():
|
||||||
if item == module:
|
if item == module:
|
||||||
key = "{}_{}".format(module_idx, name)
|
module_name = name
|
||||||
|
break
|
||||||
|
key = "{}_{}".format(module_idx, name)
|
||||||
|
|
||||||
info = OrderedDict()
|
info = OrderedDict()
|
||||||
info["id"] = id(module)
|
info["id"] = id(module)
|
||||||
@@ -71,8 +77,8 @@ def summary(model, x, *args, **kwargs):
|
|||||||
|
|
||||||
summary[key] = info
|
summary[key] = info
|
||||||
|
|
||||||
# ignore Sequential and ModuleList
|
# ignore Sequential and ModuleList and other containers
|
||||||
if not module._modules:
|
if isinstance(module, layer_modules) or not module._modules:
|
||||||
hooks.append(module.register_forward_hook(hook))
|
hooks.append(module.register_forward_hook(hook))
|
||||||
|
|
||||||
module_names = get_names_dict(model)
|
module_names = get_names_dict(model)
|
||||||
@@ -84,6 +90,12 @@ def summary(model, x, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
# This can be usefull for debugging
|
||||||
|
print("Failed to run torchsummaryX.summary, printing sizes of executed layers:")
|
||||||
|
df = pd.DataFrame(summary).T
|
||||||
|
print(df)
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
for hook in hooks:
|
for hook in hooks:
|
||||||
hook.remove()
|
hook.remove()
|
||||||
@@ -104,27 +116,29 @@ def summary(model, x, *args, **kwargs):
|
|||||||
df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
|
df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
|
||||||
max_repr_width = max([len(row) for row in df.to_string().split("\n")])
|
max_repr_width = max([len(row) for row in df.to_string().split("\n")])
|
||||||
|
|
||||||
option = pd.option_context(
|
df_total = pd.DataFrame(
|
||||||
"display.max_rows", 600,
|
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
|
||||||
"display.max_columns", 10,
|
"Trainable params": df_sum["Params"],
|
||||||
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
|
"Non-trainable params": df_sum["params_nt"],
|
||||||
)
|
"Mult-Adds": df_sum["Mult-Adds"]
|
||||||
with option:
|
},
|
||||||
print("="*max_repr_width)
|
index=['Totals']
|
||||||
print(df.replace(np.nan, "-"))
|
).T
|
||||||
print("-"*max_repr_width)
|
|
||||||
df_total = pd.DataFrame(
|
if print_summary:
|
||||||
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
|
option = pd.option_context(
|
||||||
"Trainable params": df_sum["Params"],
|
"display.max_rows", 600,
|
||||||
"Non-trainable params": df_sum["params_nt"],
|
"display.max_columns", 10,
|
||||||
"Mult-Adds": df_sum["Mult-Adds"]
|
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
|
||||||
},
|
)
|
||||||
index=['Totals']
|
with option:
|
||||||
).T
|
print("="*max_repr_width)
|
||||||
print(df_total)
|
print(df.replace(np.nan, "-"))
|
||||||
print("="*max_repr_width)
|
print("-"*max_repr_width)
|
||||||
|
print(df_total)
|
||||||
|
print("="*max_repr_width)
|
||||||
|
|
||||||
return df
|
return df, df_total
|
||||||
|
|
||||||
def get_names_dict(model):
|
def get_names_dict(model):
|
||||||
"""Recursive walk to get names including path."""
|
"""Recursive walk to get names including path."""
|
||||||
|
|||||||
Reference in New Issue
Block a user