mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 16:32:27 +08:00
Merge pull request #7 from wassname/patch-2
ignore parameters with no gradient
This commit is contained in:
@@ -32,9 +32,10 @@ def summary(model, x, *args, **kwargs):
|
||||
|
||||
info["ksize"] = "-"
|
||||
info["inner"] = OrderedDict()
|
||||
info["params"], info["macs"] = 0, 0
|
||||
info["params_nt"], info["params"], info["macs"] = 0, 0, 0
|
||||
for name, param in module.named_parameters():
|
||||
info["params"] += param.nelement()
|
||||
info["params"] += param.nelement() * param.requires_grad
|
||||
info["params_nt"] += param.nelement() * (not param.requires_grad)
|
||||
|
||||
if name == "weight":
|
||||
ksize = list(param.size())
|
||||
@@ -76,33 +77,45 @@ def summary(model, x, *args, **kwargs):
|
||||
summary = OrderedDict()
|
||||
|
||||
model.apply(register_hook)
|
||||
|
||||
with torch.no_grad():
|
||||
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
try:
|
||||
with torch.no_grad():
|
||||
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
|
||||
finally:
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Use pandas to align the columns
|
||||
df = pd.DataFrame(summary).T
|
||||
df["Mult-Adds (M)"] = pd.to_numeric(df["macs"], errors="coerce")/1e6
|
||||
df["Params (K)"] = pd.to_numeric(df["params"], errors="coerce")/1e3
|
||||
|
||||
df["Mult-Adds"] = pd.to_numeric(df["macs"], errors="coerce")
|
||||
df["Params"] = pd.to_numeric(df["params"], errors="coerce")
|
||||
df["Non-trainable params"] = pd.to_numeric(df["params_nt"], errors="coerce")
|
||||
df = df.rename(columns=dict(
|
||||
ksize="Kernel Shape",
|
||||
out="Output Shape",
|
||||
))
|
||||
df.index.name = "Layer"
|
||||
df = df[["Kernel Shape", "Output Shape", "Params (K)", "Mult-Adds (M)"]]
|
||||
df_sum = df.sum()
|
||||
df.index.name = "Layer"
|
||||
|
||||
df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]]
|
||||
|
||||
|
||||
max_repr_width = max([len(row) for row in df.to_string().split("\n")])
|
||||
|
||||
print("="*max_repr_width)
|
||||
print(df.replace(np.nan, "-"))
|
||||
print("-"*max_repr_width)
|
||||
print("Params (K): ", df_sum["Params (K)"])
|
||||
print("Mult-Adds (M): ", df_sum["Mult-Adds (M)"])
|
||||
print("="*max_repr_width)
|
||||
with pd.option_context("display.max_rows", 600, "display.max_columns", 10, 'display.float_format', pd.io.formats.format.EngFormatter(use_eng_prefix=True)):
|
||||
print("="*max_repr_width)
|
||||
print(df.replace(np.nan, "-"))
|
||||
print("-"*max_repr_width)
|
||||
df_total = pd.DataFrame(
|
||||
{"Total params": (df_sum["Params"] + df_sum["params_nt"]),
|
||||
"Trainable params": df_sum["Params"],
|
||||
"Non-trainable params": df_sum["params_nt"],
|
||||
"Mult-Adds": df_sum["Mult-Adds"]
|
||||
},
|
||||
index=['Totals']
|
||||
).T
|
||||
print(df_total)
|
||||
print("="*max_repr_width)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
Reference in New Issue
Block a user