mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 15:17:36 +08:00
let it work with modules with are layers
This commit is contained in:
@@ -3,7 +3,10 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
def summary(model, x, *args, **kwargs):
|
||||
# Some modules do the computation themselves using parameters or the parameters of children, treat these as layers
|
||||
layer_modules = (torch.nn.MultiheadAttention, )
|
||||
|
||||
def summary(model, x, *args, layer_modules=layer_modules, print_summary=True, **kwargs):
|
||||
"""Summarize the given input model.
|
||||
Summarized information are 1) output shape, 2) kernel shape,
|
||||
3) number of the parameters and 4) operations (Mult-Adds)
|
||||
@@ -19,9 +22,12 @@ def summary(model, x, *args, **kwargs):
|
||||
module_idx = len(summary)
|
||||
|
||||
# Lookup name in a dict that includes parents
|
||||
module_name = str(module_idx)
|
||||
for name, item in module_names.items():
|
||||
if item == module:
|
||||
key = "{}_{}".format(module_idx, name)
|
||||
module_name = name
|
||||
break
|
||||
key = "{}_{}".format(module_idx, name)
|
||||
|
||||
info = OrderedDict()
|
||||
info["id"] = id(module)
|
||||
@@ -71,8 +77,8 @@ def summary(model, x, *args, **kwargs):
|
||||
|
||||
summary[key] = info
|
||||
|
||||
# ignore Sequential and ModuleList
|
||||
if not module._modules:
|
||||
# ignore Sequential and ModuleList and other containers
|
||||
if isinstance(module, layer_modules) or not module._modules:
|
||||
hooks.append(module.register_forward_hook(hook))
|
||||
|
||||
module_names = get_names_dict(model)
|
||||
@@ -84,6 +90,12 @@ def summary(model, x, *args, **kwargs):
|
||||
try:
|
||||
with torch.no_grad():
|
||||
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
||||
except Exception:
|
||||
# This can be usefull for debugging
|
||||
print("Failed to run torchsummaryX.summary, printing sizes of executed layers:")
|
||||
df = pd.DataFrame(summary).T
|
||||
print(df)
|
||||
raise
|
||||
finally:
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
@@ -104,27 +116,29 @@ def summary(model, x, *args, **kwargs):
|
||||
df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
|
||||
max_repr_width = max([len(row) for row in df.to_string().split("\n")])
|
||||
|
||||
option = pd.option_context(
|
||||
"display.max_rows", 600,
|
||||
"display.max_columns", 10,
|
||||
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
|
||||
)
|
||||
with option:
|
||||
print("="*max_repr_width)
|
||||
print(df.replace(np.nan, "-"))
|
||||
print("-"*max_repr_width)
|
||||
df_total = pd.DataFrame(
|
||||
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
|
||||
"Trainable params": df_sum["Params"],
|
||||
"Non-trainable params": df_sum["params_nt"],
|
||||
"Mult-Adds": df_sum["Mult-Adds"]
|
||||
},
|
||||
index=['Totals']
|
||||
).T
|
||||
print(df_total)
|
||||
print("="*max_repr_width)
|
||||
df_total = pd.DataFrame(
|
||||
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
|
||||
"Trainable params": df_sum["Params"],
|
||||
"Non-trainable params": df_sum["params_nt"],
|
||||
"Mult-Adds": df_sum["Mult-Adds"]
|
||||
},
|
||||
index=['Totals']
|
||||
).T
|
||||
|
||||
if print_summary:
|
||||
option = pd.option_context(
|
||||
"display.max_rows", 600,
|
||||
"display.max_columns", 10,
|
||||
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
|
||||
)
|
||||
with option:
|
||||
print("="*max_repr_width)
|
||||
print(df.replace(np.nan, "-"))
|
||||
print("-"*max_repr_width)
|
||||
print(df_total)
|
||||
print("="*max_repr_width)
|
||||
|
||||
return df
|
||||
return df, df_total
|
||||
|
||||
def get_names_dict(model):
|
||||
"""Recursive walk to get names including path."""
|
||||
|
||||
Reference in New Issue
Block a user