diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index 666957f..10e2eeb 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -16,11 +16,12 @@ def summary(model, x, *args, **kwargs): def register_hook(module): def hook(module, inputs, outputs): cls_name = str(module.__class__).split(".")[-1].split("'")[0] + module_idx = len(summary) # Lookup name in a dict that includes parents for name, item in module_names.items(): if item == module: - key = name + key = '{}_{}'.format(module_idx, name) info = OrderedDict() info["id"] = id(module) @@ -85,15 +86,20 @@ def summary(model, x, *args, **kwargs): # Use pandas to align the columns df = pd.DataFrame(summary).T df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6 - df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-') + df['Params (K)'] = pd.to_numeric(df['params'], errors='coerce')/1e3 df = df.rename(columns=dict( ksize='Kernel Shape', out='Output Shape', - params='Params (K)', )) df.index.name = 'Layer' df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']] + print("="*100) + print(df) + print("="*100) + print(df.sum()) + print("="*100) + return df def get_names_dict(model): @@ -107,8 +113,7 @@ def get_names_dict(model): if num_named_children>0: name = parent_name + '.' + key if parent_name else key else: - name = parent_name + '.' + cls_name + '_'+ key if parent_name else key -# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}') + name = parent_name + '.' + cls_name + '_'+ key if parent_name else key names[name] = module if isinstance(module, torch.nn.Module):