From be8002024eeba76526e60b9b5f211119fc090445 Mon Sep 17 00:00:00 2001 From: Mike Clark Date: Wed, 5 Jun 2019 05:27:28 +0000 Subject: [PATCH 1/4] include module path in name, return dataframe --- torchsummaryX/torchsummaryX.py | 77 +++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index db15310..666957f 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -1,12 +1,12 @@ from collections import OrderedDict import numpy as np +import pandas as pd import torch def summary(model, x, *args, **kwargs): """Summarize the given input model. Summarized information are 1) output shape, 2) kernel shape, 3) number of the parameters and 4) operations (Mult-Adds) - Args: model (Module): Model to summarize x (Tensor): Input tensor of the model with [N, C, H, W] shape @@ -16,8 +16,11 @@ def summary(model, x, *args, **kwargs): def register_hook(module): def hook(module, inputs, outputs): cls_name = str(module.__class__).split(".")[-1].split("'")[0] - module_idx = len(summary) - key = "{}_{}".format(module_idx, cls_name) + + # Lookup name in a dict that includes parents + for name, item in module_names.items(): + if item == module: + key = name info = OrderedDict() info["id"] = id(module) @@ -65,45 +68,51 @@ def summary(model, x, *args, **kwargs): # ignore Sequential and ModuleList if not module._modules: hooks.append(module.register_forward_hook(hook)) + + module_names = get_names_dict(model) hooks = [] summary = OrderedDict() model.apply(register_hook) + with torch.no_grad(): model(x) if not (kwargs or args) else model(x, *args, **kwargs) for hook in hooks: hook.remove() + + # Use pandas to align the columns + df = pd.DataFrame(summary).T + df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6 + df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-') + df = df.rename(columns=dict( + ksize='Kernel Shape', + out='Output 
Shape', + params='Params (K)', + )) + df.index.name = 'Layer' + df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']] + + return df + +def get_names_dict(model): + """Recursive walk to get names including path.""" + names = {} - print("-"*100) - print("{:<15} {:>20} {:>20} {:>20} {:>20}" - .format("Layer", "Kernel Shape", "Output Shape", - "# Params (K)", "# Mult-Adds (M)")) - print("="*100) - - total_params, total_macs = 0, 0 - for layer, info in summary.items(): - repr_ksize = str(info["ksize"]) - repr_out = str(info["out"]) - repr_params = info["params"] - repr_macs = info["macs"] - - if isinstance(repr_params, (int, float)): - total_params += repr_params - repr_params = "{0:,.2f}".format(repr_params/1000) - if isinstance(repr_macs, (int, float)): - total_macs += repr_macs - repr_macs = "{0:,.2f}".format(repr_macs/1000000) - - print("{:<15} {:>20} {:>20} {:>20} {:>20}" - .format(layer, repr_ksize, repr_out, repr_params, repr_macs)) - - # for RNN, describe inner weights (i.e. w_hh, w_ih) - for inner_name, inner_shape in info["inner"].items(): - print(" {:<13} {:>20}".format(inner_name, str(inner_shape))) - - print("="*100) - print("# Params: {0:,.2f}K".format(total_params/1000)) - print("# Mult-Adds: {0:,.2f}M".format(total_macs/1000000)) - print("-"*100) + def _get_names(module, parent_name=''): + for key, module in module.named_children(): + cls_name = str(module.__class__).split(".")[-1].split("'")[0] + num_named_children = len(list(module.named_children())) + if num_named_children>0: + name = parent_name + '.' + key if parent_name else key + else: + name = parent_name + '.' 
+ cls_name + '_'+ key if parent_name else key +# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}') + names[name] = module + + if isinstance(module, torch.nn.Module): + _get_names(module, parent_name=name) + + _get_names(model) + return names From 24cdc4958c1d833ab10fc8d1c72d81fb86ef5d8f Mon Sep 17 00:00:00 2001 From: Mike Clark Date: Wed, 5 Jun 2019 05:39:12 +0000 Subject: [PATCH 2/4] Show changed outputs --- README.md | 55 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index e9c8ef1..1f14a14 100644 --- a/README.md +++ b/README.md @@ -36,18 +36,19 @@ class Net(nn.Module): summary(Net(), torch.zeros((1, 1, 28, 28))) ``` ``` ----------------------------------------------------------------------------------------------------- -Layer Kernel Shape Output Shape # Params (K) # Mult-Adds (M) ==================================================================================================== -0_Conv2d [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.14 -1_Conv2d [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.32 -2_Dropout2d - [1, 20, 8, 8] - - -3_Linear [320, 50] [1, 50] 16.05 0.02 -4_Linear [50, 10] [1, 10] 0.51 0.00 + Kernel Shape Output Shape Params (K) Mult-Adds (M) +Layer +0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.1440 +1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.3200 +2_conv2_drop - [1, 20, 8, 8] NaN NaN +3_fc1 [320, 50] [1, 50] 16.05 0.0160 +4_fc2 [50, 10] [1, 10] 0.51 0.0005 +==================================================================================================== +Params (K) 21.8400 +Mult-Adds (M) 0.4805 +dtype: float64 ==================================================================================================== -# Params: 21.84K -# Mult-Adds: 0.48M ----------------------------------------------------------------------------------------------------- ``` RNN @@ -74,20 +75,18 @@ inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size] 
summary(Net(), inputs) ``` ``` ----------------------------------------------------------------------------------------------------- -Layer Kernel Shape Output Shape # Params (K) # Mult-Adds (M) ==================================================================================================== -0_Embedding [300, 20] [100, 1, 300] 6.00 0.01 -1_LSTM - [100, 1, 512] 3,768.32 3.76 - weight_ih_l0 [2048, 300] - weight_hh_l0 [2048, 512] - weight_ih_l1 [2048, 512] - weight_hh_l1 [2048, 512] -2_Linear [512, 20] [100, 1, 20] 10.26 0.01 + Kernel Shape Output Shape Params (K) Mult-Adds (M) +Layer +0_embedding [300, 20] [100, 1, 300] 6.00 0.006000 +1_encoder - [100, 1, 512] 3768.32 3.760128 +2_decoder [512, 20] [100, 1, 20] 10.26 0.010240 ==================================================================================================== -# Params: 3,784.58K -# Mult-Adds: 3.78M ----------------------------------------------------------------------------------------------------- +Params (K) 3784.580000 +Mult-Adds (M) 3.776368 +dtype: float64 +==================================================================================================== + ``` Recursive NN @@ -107,8 +106,16 @@ summary(Net(), torch.zeros((1, 64, 28, 28))) ---------------------------------------------------------------------------------------------------- Layer Kernel Shape Output Shape # Params (K) # Mult-Adds (M) ==================================================================================================== -0_Conv2d [64, 64, 3, 3] [1, 64, 28, 28] 36.93 28.90 -1_Conv2d [64, 64, 3, 3] [1, 64, 28, 28] (recursive) 28.90 + Kernel Shape Output Shape Params (K) Mult-Adds (M) +Layer +0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376 +1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] NaN 28.901376 +==================================================================================================== +Kernel Shape [64, 64, 3, 3, 64, 64, 3, 3] +Output Shape [1, 64, 28, 28, 1, 64, 28, 28] +Params (K) 36.928 +Mult-Adds 
(M) 57.8028 +dtype: object ==================================================================================================== # Params: 36.93K # Mult-Adds: 57.80M From 2db8410f43906d742cc22d146cd43ff739a932da Mon Sep 17 00:00:00 2001 From: Mike Clark Date: Wed, 5 Jun 2019 05:39:44 +0000 Subject: [PATCH 3/4] fix duplicate keys --- torchsummaryX/torchsummaryX.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index 666957f..10e2eeb 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -16,11 +16,12 @@ def summary(model, x, *args, **kwargs): def register_hook(module): def hook(module, inputs, outputs): cls_name = str(module.__class__).split(".")[-1].split("'")[0] + module_idx = len(summary) # Lookup name in a dict that includes parents for name, item in module_names.items(): if item == module: - key = name + key = '{}_{}'.format(module_idx, name) info = OrderedDict() info["id"] = id(module) @@ -85,15 +86,20 @@ def summary(model, x, *args, **kwargs): # Use pandas to align the columns df = pd.DataFrame(summary).T df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6 - df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-') + df['Params (K)'] = pd.to_numeric(df['params'], errors='coerce')/1e3 df = df.rename(columns=dict( ksize='Kernel Shape', out='Output Shape', - params='Params (K)', )) df.index.name = 'Layer' df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']] + print("="*100) + print(df) + print("="*100) + print(df.sum()) + print("="*100) + return df def get_names_dict(model): @@ -107,8 +113,7 @@ def get_names_dict(model): if num_named_children>0: name = parent_name + '.' + key if parent_name else key else: - name = parent_name + '.' + cls_name + '_'+ key if parent_name else key -# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}') + name = parent_name + '.' 
+ cls_name + '_'+ key if parent_name else key names[name] = module if isinstance(module, torch.nn.Module): From 755069ccb3b3ecc9d7bd6abbb46be581f5d1f8fe Mon Sep 17 00:00:00 2001 From: Mike Clark Date: Wed, 5 Jun 2019 05:45:39 +0000 Subject: [PATCH 4/4] don't print nans --- torchsummaryX/torchsummaryX.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index 10e2eeb..fd4d76c 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -95,7 +95,7 @@ def summary(model, x, *args, **kwargs): df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']] print("="*100) - print(df) + print(df.replace(np.nan, '-')) print("="*100) print(df.sum()) print("="*100)