fix duplicate keys

This commit is contained in:
Mike Clark
2019-06-05 05:39:44 +00:00
committed by GitHub
parent 24cdc4958c
commit 2db8410f43
+10 -5
View File
@@ -16,11 +16,12 @@ def summary(model, x, *args, **kwargs):
def register_hook(module):
def hook(module, inputs, outputs):
cls_name = str(module.__class__).split(".")[-1].split("'")[0]
module_idx = len(summary)
# Lookup name in a dict that includes parents
for name, item in module_names.items():
if item == module:
key = name
key = '{}_{}'.format(module_idx, name)
info = OrderedDict()
info["id"] = id(module)
@@ -85,15 +86,20 @@ def summary(model, x, *args, **kwargs):
# Use pandas to align the columns
df = pd.DataFrame(summary).T
df['Mult-Adds (M)'] = pd.to_numeric(df['macs'], errors='coerce')/1e6
df['Mult-Adds (M)'] = df['Mult-Adds (M)'].replace(np.nan, '-')
df['Params (K)'] = pd.to_numeric(df['params'], errors='coerce')/1e3
df = df.rename(columns=dict(
ksize='Kernel Shape',
out='Output Shape',
params='Params (K)',
))
df.index.name = 'Layer'
df = df[['Kernel Shape', 'Output Shape', 'Params (K)', 'Mult-Adds (M)']]
print("="*100)
print(df)
print("="*100)
print(df.sum())
print("="*100)
return df
def get_names_dict(model):
@@ -107,8 +113,7 @@ def get_names_dict(model):
if num_named_children>0:
name = parent_name + '.' + key if parent_name else key
else:
name = parent_name + '.' + cls_name + '_'+ key if parent_name else key
# print(f'{parent_name}, key={key}, nc={num_named_children}, name={name}')
name = parent_name + '.' + cls_name + '_'+ key if parent_name else key
names[name] = module
if isinstance(module, torch.nn.Module):