diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index 29bad77..cdba769 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -32,9 +32,10 @@ def summary(model, x, *args, **kwargs): info["ksize"] = "-" info["inner"] = OrderedDict() - info["params"], info["macs"] = 0, 0 + info["params_nt"], info["params"], info["macs"] = 0, 0, 0 for name, param in module.named_parameters(): - info["params"] += param.nelement() + info["params"] += param.nelement() * param.requires_grad + info["params_nt"] += param.nelement() * (not param.requires_grad) if name == "weight": ksize = list(param.size()) @@ -76,33 +77,45 @@ def summary(model, x, *args, **kwargs): summary = OrderedDict() model.apply(register_hook) - - with torch.no_grad(): - model(x) if not (kwargs or args) else model(x, *args, **kwargs) - - for hook in hooks: - hook.remove() + try: + with torch.no_grad(): + model(x) if not (kwargs or args) else model(x, *args, **kwargs) + finally: + for hook in hooks: + hook.remove() # Use pandas to align the columns df = pd.DataFrame(summary).T - df["Mult-Adds (M)"] = pd.to_numeric(df["macs"], errors="coerce")/1e6 - df["Params (K)"] = pd.to_numeric(df["params"], errors="coerce")/1e3 + + df["Mult-Adds"] = pd.to_numeric(df["macs"], errors="coerce") + df["Params"] = pd.to_numeric(df["params"], errors="coerce") + df["Non-trainable params"] = pd.to_numeric(df["params_nt"], errors="coerce") df = df.rename(columns=dict( ksize="Kernel Shape", out="Output Shape", )) - df.index.name = "Layer" - df = df[["Kernel Shape", "Output Shape", "Params (K)", "Mult-Adds (M)"]] df_sum = df.sum() + df.index.name = "Layer" + + df = df[["Kernel Shape", "Output Shape", "Params", "Mult-Adds"]] + max_repr_width = max([len(row) for row in df.to_string().split("\n")]) - print("="*max_repr_width) - print(df.replace(np.nan, "-")) - print("-"*max_repr_width) - print("Params (K): ", df_sum["Params (K)"]) - print("Mult-Adds (M): ", df_sum["Mult-Adds (M)"]) - print("="*max_repr_width) + with pd.option_context("display.max_rows", 600, "display.max_columns", 10, 'display.float_format', pd.io.formats.format.EngFormatter(use_eng_prefix=True)): + print("="*max_repr_width) + print(df.replace(np.nan, "-")) + print("-"*max_repr_width) + df_total = pd.DataFrame( + {"Total params": (df_sum["Params"] + df_sum["params_nt"]), + "Trainable params": df_sum["Params"], + "Non-trainable params": df_sum["params_nt"], + "Mult-Adds": df_sum["Mult-Adds"] + }, + index=['Totals'] + ).T + print(df_total) + print("="*max_repr_width) return df