let it work with modules with are layers

This commit is contained in:
wassname
2019-11-19 15:23:40 +08:00
parent 1915788513
commit 006b7e06c4
+38 -24
View File
@@ -3,7 +3,10 @@ import numpy as np
import pandas as pd import pandas as pd
import torch import torch
def summary(model, x, *args, **kwargs): # Some modules do the computation themselves using parameters or the parameters of children, treat these as layers
layer_modules = (torch.nn.MultiheadAttention, )
def summary(model, x, *args, layer_modules=layer_modules, print_summary=True, **kwargs):
"""Summarize the given input model. """Summarize the given input model.
Summarized information are 1) output shape, 2) kernel shape, Summarized information are 1) output shape, 2) kernel shape,
3) number of the parameters and 4) operations (Mult-Adds) 3) number of the parameters and 4) operations (Mult-Adds)
@@ -19,9 +22,12 @@ def summary(model, x, *args, **kwargs):
module_idx = len(summary) module_idx = len(summary)
# Lookup name in a dict that includes parents # Lookup name in a dict that includes parents
module_name = str(module_idx)
for name, item in module_names.items(): for name, item in module_names.items():
if item == module: if item == module:
key = "{}_{}".format(module_idx, name) module_name = name
break
key = "{}_{}".format(module_idx, name)
info = OrderedDict() info = OrderedDict()
info["id"] = id(module) info["id"] = id(module)
@@ -71,8 +77,8 @@ def summary(model, x, *args, **kwargs):
summary[key] = info summary[key] = info
# ignore Sequential and ModuleList # ignore Sequential and ModuleList and other containers
if not module._modules: if isinstance(module, layer_modules) or not module._modules:
hooks.append(module.register_forward_hook(hook)) hooks.append(module.register_forward_hook(hook))
module_names = get_names_dict(model) module_names = get_names_dict(model)
@@ -84,6 +90,12 @@ def summary(model, x, *args, **kwargs):
try: try:
with torch.no_grad(): with torch.no_grad():
model(x) if not (kwargs or args) else model(x, *args, **kwargs) model(x) if not (kwargs or args) else model(x, *args, **kwargs)
except Exception:
# This can be usefull for debugging
print("Failed to run torchsummaryX.summary, printing sizes of executed layers:")
df = pd.DataFrame(summary).T
print(df)
raise
finally: finally:
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()
@@ -104,27 +116,29 @@ def summary(model, x, *args, **kwargs):
df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]] df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
max_repr_width = max([len(row) for row in df.to_string().split("\n")]) max_repr_width = max([len(row) for row in df.to_string().split("\n")])
option = pd.option_context( df_total = pd.DataFrame(
"display.max_rows", 600, {"Total params": (df_sum["Params"] + df_sum["params_nt"]),
"display.max_columns", 10, "Trainable params": df_sum["Params"],
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True) "Non-trainable params": df_sum["params_nt"],
) "Mult-Adds": df_sum["Mult-Adds"]
with option: },
print("="*max_repr_width) index=['Totals']
print(df.replace(np.nan, "-")) ).T
print("-"*max_repr_width)
df_total = pd.DataFrame( if print_summary:
{"Total params": (df_sum["Params"] + df_sum["params_nt"]), option = pd.option_context(
"Trainable params": df_sum["Params"], "display.max_rows", 600,
"Non-trainable params": df_sum["params_nt"], "display.max_columns", 10,
"Mult-Adds": df_sum["Mult-Adds"] "display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
}, )
index=['Totals'] with option:
).T print("="*max_repr_width)
print(df_total) print(df.replace(np.nan, "-"))
print("="*max_repr_width) print("-"*max_repr_width)
print(df_total)
print("="*max_repr_width)
return df return df, df_total
def get_names_dict(model): def get_names_dict(model):
"""Recursive walk to get names including path.""" """Recursive walk to get names including path."""