diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index e709cc3..528a0e2 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -3,7 +3,10 @@ import numpy as np import pandas as pd import torch -def summary(model, x, *args, **kwargs): +# Some modules do the computation themselves using parameters or the parameters of children, treat these as layers +layer_modules = (torch.nn.MultiheadAttention, ) + +def summary(model, x, *args, layer_modules=layer_modules, print_summary=True, **kwargs): """Summarize the given input model. Summarized information are 1) output shape, 2) kernel shape, 3) number of the parameters and 4) operations (Mult-Adds) @@ -19,9 +22,12 @@ def summary(model, x, *args, **kwargs): module_idx = len(summary) # Lookup name in a dict that includes parents + module_name = str(module_idx) for name, item in module_names.items(): if item == module: - key = "{}_{}".format(module_idx, name) + module_name = name + break + key = "{}_{}".format(module_idx, name) info = OrderedDict() info["id"] = id(module) @@ -71,8 +77,8 @@ def summary(model, x, *args, **kwargs): summary[key] = info - # ignore Sequential and ModuleList - if not module._modules: + # ignore Sequential and ModuleList and other containers + if isinstance(module, layer_modules) or not module._modules: hooks.append(module.register_forward_hook(hook)) module_names = get_names_dict(model) @@ -84,6 +90,12 @@ def summary(model, x, *args, **kwargs): try: with torch.no_grad(): model(x) if not (kwargs or args) else model(x, *args, **kwargs) + except Exception: + # This can be usefull for debugging + print("Failed to run torchsummaryX.summary, printing sizes of executed layers:") + df = pd.DataFrame(summary).T + print(df) + raise finally: for hook in hooks: hook.remove() @@ -104,27 +116,29 @@ def summary(model, x, *args, **kwargs): df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]] max_repr_width = max([len(row) for row in df.to_string().split("\n")]) - option = pd.option_context( - "display.max_rows", 600, - "display.max_columns", 10, - "display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True) - ) - with option: - print("="*max_repr_width) - print(df.replace(np.nan, "-")) - print("-"*max_repr_width) - df_total = pd.DataFrame( - {"Total params": (df_sum["Params"] + df_sum["params_nt"]), - "Trainable params": df_sum["Params"], - "Non-trainable params": df_sum["params_nt"], - "Mult-Adds": df_sum["Mult-Adds"] - }, - index=['Totals'] - ).T - print(df_total) - print("="*max_repr_width) + df_total = pd.DataFrame( + {"Total params": (df_sum["Params"] + df_sum["params_nt"]), + "Trainable params": df_sum["Params"], + "Non-trainable params": df_sum["params_nt"], + "Mult-Adds": df_sum["Mult-Adds"] + }, + index=['Totals'] + ).T + + if print_summary: + option = pd.option_context( + "display.max_rows", 600, + "display.max_columns", 10, + "display.float_format", pd.io.formats.format.EngFormatter(use_eng_prefix=True) + ) + with option: + print("="*max_repr_width) + print(df.replace(np.nan, "-")) + print("-"*max_repr_width) + print(df_total) + print("="*max_repr_width) - return df + return df, df_total def get_names_dict(model): """Recursive walk to get names including path."""