From be8002024eeba76526e60b9b5f211119fc090445 Mon Sep 17 00:00:00 2001 From: Mike Clark Date: Wed, 5 Jun 2019 05:27:28 +0000 Subject: [PATCH] include module path in name, return dataframe --- torchsummaryX/torchsummaryX.py | 77 +++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index db15310..666957f 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -1,12 +1,12 @@ from collections import OrderedDict import numpy as np +import pandas as pd import torch def summary(model, x, *args, **kwargs): """Summarize the given input model. Summarized information are 1) output shape, 2) kernel shape, 3) number of the parameters and 4) operations (Mult-Adds) - Args: model (Module): Model to summarize x (Tensor): Input tensor of the model with [N, C, H, W] shape @@ -16,8 +16,11 @@ def summary(model, x, *args, **kwargs): def register_hook(module): def hook(module, inputs, outputs): cls_name = str(module.__class__).split(".")[-1].split("'")[0] - module_idx = len(summary) - key = "{}_{}".format(module_idx, cls_name) + + # Lookup name in a dict that includes parents + for name, item in module_names.items(): + if item == module: + key = name info = OrderedDict() info["id"] = id(module) @@ -65,45 +68,51 @@ def summary(model, x, *args, **kwargs): # ignore Sequential and ModuleList if not module._modules: hooks.append(module.register_forward_hook(hook)) + + module_names = get_names_dict(model) hooks = [] summary = OrderedDict() model.apply(register_hook) + with torch.no_grad(): model(x) if not (kwargs or args) else model(x, *args, **kwargs) for hook in hooks: hook.remove() + + # Use pandas to align the columns + df = pd.DataFrame(summary).T + df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6 + df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-') + df = df.rename(columns=dict( + ksize='Kernel Shape', + out='Output Shape', + params='Params (K)', + )) + df.index.name = 'Layer' + df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']] + + return df + +def get_names_dict(model): + """Recursive walk to get names including path.""" + names = {} - print("-"*100) - print("{:<15} {:>20} {:>20} {:>20} {:>20}" - .format("Layer", "Kernel Shape", "Output Shape", - "# Params (K)", "# Mult-Adds (M)")) - print("="*100) - - total_params, total_macs = 0, 0 - for layer, info in summary.items(): - repr_ksize = str(info["ksize"]) - repr_out = str(info["out"]) - repr_params = info["params"] - repr_macs = info["macs"] - - if isinstance(repr_params, (int, float)): - total_params += repr_params - repr_params = "{0:,.2f}".format(repr_params/1000) - if isinstance(repr_macs, (int, float)): - total_macs += repr_macs - repr_macs = "{0:,.2f}".format(repr_macs/1000000) - - print("{:<15} {:>20} {:>20} {:>20} {:>20}" - .format(layer, repr_ksize, repr_out, repr_params, repr_macs)) - - # for RNN, describe inner weights (i.e. w_hh, w_ih) - for inner_name, inner_shape in info["inner"].items(): - print(" {:<13} {:>20}".format(inner_name, str(inner_shape))) - - print("="*100) - print("# Params: {0:,.2f}K".format(total_params/1000)) - print("# Mult-Adds: {0:,.2f}M".format(total_macs/1000000)) - print("-"*100) + def _get_names(module, parent_name=''): + for key, module in module.named_children(): + cls_name = str(module.__class__).split(".")[-1].split("'")[0] + num_named_children = len(list(module.named_children())) + if num_named_children>0: + name = parent_name + '.' + key if parent_name else key + else: + name = parent_name + '.' + cls_name + '_'+ key if parent_name else key +# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}') + names[name] = module + + if isinstance(module, torch.nn.Module): + _get_names(module, parent_name=name) + + _get_names(model) + return names