Let it work with modules that are layers

This commit is contained in:
wassname
2019-11-19 15:23:40 +08:00
parent 1915788513
commit 006b7e06c4
+38 -24
View File
@@ -3,7 +3,10 @@ import numpy as np
import pandas as pd
import torch
def summary(model, x, *args, **kwargs):
# Some modules do the computation themselves, using their own parameters or the parameters of their children; treat these as layers
layer_modules = (torch.nn.MultiheadAttention, )
def summary(model, x, *args, layer_modules=layer_modules, print_summary=True, **kwargs):
"""Summarize the given input model.
Summarized information are 1) output shape, 2) kernel shape,
3) number of the parameters and 4) operations (Mult-Adds)
@@ -19,9 +22,12 @@ def summary(model, x, *args, **kwargs):
module_idx = len(summary)
# Lookup name in a dict that includes parents
module_name = str(module_idx)
for name, item in module_names.items():
if item == module:
key = "{}_{}".format(module_idx, name)
module_name = name
break
key = "{}_{}".format(module_idx, name)
info = OrderedDict()
info["id"] = id(module)
@@ -71,8 +77,8 @@ def summary(model, x, *args, **kwargs):
summary[key] = info
# ignore Sequential and ModuleList
if not module._modules:
# ignore Sequential and ModuleList and other containers
if isinstance(module, layer_modules) or not module._modules:
hooks.append(module.register_forward_hook(hook))
module_names = get_names_dict(model)
@@ -84,6 +90,12 @@ def summary(model, x, *args, **kwargs):
try:
with torch.no_grad():
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
except Exception:
# This can be useful for debugging
print("Failed to run torchsummaryX.summary, printing sizes of executed layers:")
df = pd.DataFrame(summary).T
print(df)
raise
finally:
for hook in hooks:
hook.remove()
@@ -104,27 +116,29 @@ def summary(model, x, *args, **kwargs):
df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
max_repr_width = max([len(row) for row in df.to_string().split("\n")])
option = pd.option_context(
"display.max_rows", 600,
"display.max_columns", 10,
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
)
with option:
print("="*max_repr_width)
print(df.replace(np.nan, "-"))
print("-"*max_repr_width)
df_total = pd.DataFrame(
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
"Trainable params": df_sum["Params"],
"Non-trainable params": df_sum["params_nt"],
"Mult-Adds": df_sum["Mult-Adds"]
},
index=['Totals']
).T
print(df_total)
print("="*max_repr_width)
df_total = pd.DataFrame(
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
"Trainable params": df_sum["Params"],
"Non-trainable params": df_sum["params_nt"],
"Mult-Adds": df_sum["Mult-Adds"]
},
index=['Totals']
).T
return df
if print_summary:
option = pd.option_context(
"display.max_rows", 600,
"display.max_columns", 10,
"display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True)
)
with option:
print("="*max_repr_width)
print(df.replace(np.nan, "-"))
print("-"*max_repr_width)
print(df_total)
print("="*max_repr_width)
return df, df_total
def get_names_dict(model):
"""Recursive walk to get names including path."""