mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 16:32:27 +08:00
fix duplicate keys
This commit is contained in:
@@ -16,11 +16,12 @@ def summary(model, x, *args, **kwargs):
|
||||
def register_hook(module):
|
||||
def hook(module, inputs, outputs):
|
||||
cls_name = str(module.__class__).split(".")[-1].split("'")[0]
|
||||
module_idx = len(summary)
|
||||
|
||||
# Lookup name in a dict that includes parents
|
||||
for name, item in module_names.items():
|
||||
if item == module:
|
||||
key = name
|
||||
key = '{}_{}'.format(module_idx, name)
|
||||
|
||||
info = OrderedDict()
|
||||
info["id"] = id(module)
|
||||
@@ -85,15 +86,20 @@ def summary(model, x, *args, **kwargs):
|
||||
# Use pandas to align the columns
|
||||
df = pd.DataFrame(summary).T
|
||||
df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6
|
||||
df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-')
|
||||
df['Params (K)'] = pd.to_numeric(df['params'], errors='coerce')/1e3
|
||||
df = df.rename(columns=dict(
|
||||
ksize='Kernel Shape',
|
||||
out='Output Shape',
|
||||
params='Params (K)',
|
||||
))
|
||||
df.index.name = 'Layer'
|
||||
df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']]
|
||||
|
||||
print("="*100)
|
||||
print(df)
|
||||
print("="*100)
|
||||
print(df.sum())
|
||||
print("="*100)
|
||||
|
||||
return df
|
||||
|
||||
def get_names_dict(model):
|
||||
@@ -107,8 +113,7 @@ def get_names_dict(model):
|
||||
if num_named_children>0:
|
||||
name = parent_name + '.' + key if parent_name else key
|
||||
else:
|
||||
name = parent_name + '.' + cls_name + '_'+ key if parent_name else key
|
||||
# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}')
|
||||
name = parent_name + '.' + cls_name + '_'+ key if parent_name else key
|
||||
names[name] = module
|
||||
|
||||
if isinstance(module, torch.nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user