diff --git a/README.md b/README.md index 1f14a14..d6ba22b 100644 --- a/README.md +++ b/README.md @@ -36,19 +36,18 @@ class Net(nn.Module): summary(Net(), torch.zeros((1, 1, 28, 28))) ``` ``` -==================================================================================================== - Kernel Shape Output Shape Params (K) Mult-Adds (M) -Layer -0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.1440 -1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.3200 -2_conv2_drop - [1, 20, 8, 8] NaN NaN -3_fc1 [320, 50] [1, 50] 16.05 0.0160 -4_fc2 [50, 10] [1, 10] 0.51 0.0005 -==================================================================================================== -Params (K) 21.8400 -Mult-Adds (M) 0.4805 -dtype: float64 -==================================================================================================== +======================================================================== + Kernel Shape Output Shape Params (K) Mult-Adds (M) +Layer +0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.144 +1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.32 +2_conv2_drop - [1, 20, 8, 8] - - +3_fc1 [320, 50] [1, 50] 16.05 0.016 +4_fc2 [50, 10] [1, 10] 0.51 0.0005 +------------------------------------------------------------------------ +Params (K): 21.84 +Mult-Adds (M): 0.4805 +======================================================================== ``` RNN @@ -75,18 +74,16 @@ inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size] summary(Net(), inputs) ``` ``` -==================================================================================================== +================================================================== Kernel Shape Output Shape Params (K) Mult-Adds (M) -Layer +Layer 0_embedding [300, 20] [100, 1, 300] 6.00 0.006000 1_encoder - [100, 1, 512] 3768.32 3.760128 2_decoder [512, 20] [100, 1, 20] 10.26 0.010240 -==================================================================================================== -Params (K) 3784.580000 -Mult-Adds (M) 3.776368 -dtype: float64 -==================================================================================================== - +------------------------------------------------------------------ +Params (K): 3784.5800000000004 +Mult-Adds (M): 3.7763679999999997 +================================================================== ``` Recursive NN @@ -103,23 +100,15 @@ class Net(nn.Module): summary(Net(), torch.zeros((1, 64, 28, 28))) ``` ``` ----------------------------------------------------------------------------------------------------- -Layer Kernel Shape Output Shape # Params (K) # Mult-Adds (M) -==================================================================================================== - Kernel Shape Output Shape Params (K) Mult-Adds (M) -Layer -0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376 -1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] NaN 28.901376 -==================================================================================================== -Kernel Shape [64, 64, 3, 3, 64, 64, 3, 3] -Output Shape [1, 64, 28, 28, 1, 64, 28, 28] -Params (K) 36.928 -Mult-Adds (M) 57.8028 -dtype: object -==================================================================================================== -# Params: 36.93K -# Mult-Adds: 57.80M ----------------------------------------------------------------------------------------------------- +=================================================================== + Kernel Shape Output Shape Params (K) Mult-Adds (M) +Layer +0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376 +1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] - 28.901376 +------------------------------------------------------------------- +Params (K): 36.928 +Mult-Adds (M): 57.802752 +=================================================================== ``` Multiple arguments diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index fd4d76c..29bad77 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -17,11 +17,11 @@ def summary(model, x, *args, **kwargs): def hook(module, inputs, outputs): cls_name = str(module.__class__).split(".")[-1].split("'")[0] module_idx = len(summary) - + # Lookup name in a dict that includes parents for name, item in module_names.items(): if item == module: - key = '{}_{}'.format(module_idx, name) + key = "{}_{}".format(module_idx, name) info = OrderedDict() info["id"] = id(module) @@ -69,55 +69,59 @@ def summary(model, x, *args, **kwargs): # ignore Sequential and ModuleList if not module._modules: hooks.append(module.register_forward_hook(hook)) - + module_names = get_names_dict(model) hooks = [] summary = OrderedDict() model.apply(register_hook) - + with torch.no_grad(): model(x) if not (kwargs or args) else model(x, *args, **kwargs) for hook in hooks: hook.remove() - + # Use pandas to align the columns df = pd.DataFrame(summary).T - df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6 - df['Params (K)'] = pd.to_numeric(df['params'], errors='coerce')/1e3 + df["Mult-Adds (M)"] = pd.to_numeric(df["macs"], errors="coerce")/1e6 + df["Params (K)"] = pd.to_numeric(df["params"], errors="coerce")/1e3 df = df.rename(columns=dict( - ksize='Kernel Shape', - out='Output Shape', + ksize="Kernel Shape", + out="Output Shape", )) - df.index.name = 'Layer' - df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']] - - print("="*100) - print(df.replace(np.nan, '-') - print("="*100) - print(df.sum()) - print("="*100) + df.index.name = "Layer" + df = df[["Kernel Shape", "Output Shape", "Params (K)", "Mult-Adds (M)"]] + df_sum = df.sum() + + max_repr_width = max([len(row) for row in df.to_string().split("\n")]) + + print("="*max_repr_width) + print(df.replace(np.nan, "-")) + print("-"*max_repr_width) + print("Params (K): ", df_sum["Params (K)"]) + print("Mult-Adds (M): ", df_sum["Mult-Adds (M)"]) + print("="*max_repr_width) return df - + def get_names_dict(model): """Recursive walk to get names including path.""" names = {} - def _get_names(module, parent_name=''): - for key, module in module.named_children(): - cls_name = str(module.__class__).split(".")[-1].split("'")[0] - num_named_children = len(list(module.named_children())) - if num_named_children>0: - name = parent_name + '.' + key if parent_name else key + def _get_names(module, parent_name=""): + for key, m in module.named_children(): + cls_name = str(m.__class__).split(".")[-1].split("'")[0] + num_named_children = len(list(m.named_children())) + if num_named_children > 0: + name = parent_name + "." + key if parent_name else key else: - name = parent_name + '.' + cls_name + '_'+ key if parent_name else key - names[name] = module - - if isinstance(module, torch.nn.Module): - _get_names(module, parent_name=name) - + name = parent_name + "." + cls_name + "_"+ key if parent_name else key + names[name] = m + + if isinstance(m, torch.nn.Module): + _get_names(m, parent_name=name) + _get_names(model) return names