From 7e08acaf7af97a5bf3cf40a378a056eff991d8ac Mon Sep 17 00:00:00 2001 From: Namhyuk Ahn Date: Sun, 3 Feb 2019 16:36:39 +0900 Subject: [PATCH] Apply lint and use no_grad --- torchsummaryX/torchsummaryX.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/torchsummaryX/torchsummaryX.py b/torchsummaryX/torchsummaryX.py index 7e0af2c..db15310 100644 --- a/torchsummaryX/torchsummaryX.py +++ b/torchsummaryX/torchsummaryX.py @@ -1,11 +1,10 @@ +from collections import OrderedDict import numpy as np import torch -import torch.nn as nn -from collections import OrderedDict def summary(model, x, *args, **kwargs): """Summarize the given input model. - Summarized information are 1) output shape, 2) kernel shape, + Summarized information are 1) output shape, 2) kernel shape, 3) number of the parameters and 4) operations (Mult-Adds) Args: @@ -33,36 +32,36 @@ def summary(model, x, *args, **kwargs): for name, param in module.named_parameters(): info["params"] += param.nelement() - if "weight" == name: + if name == "weight": ksize = list(param.size()) # to make [in_shape, out_shape, ksize, ksize] if len(ksize) > 1: ksize[0], ksize[1] = ksize[1], ksize[0] - info["ksize"] = ksize + info["ksize"] = ksize # ignore N, C when calculate Mult-Adds in ConvNd if "Conv" in cls_name: info["macs"] += int(param.nelement() * np.prod(info["out"][2:])) else: info["macs"] += param.nelement() - + # RNN modules have inner weights such as weight_ih_l0 elif "weight" in name: info["inner"][name] = list(param.size()) info["macs"] += param.nelement() - + # if the current module is already-used, mark as "(recursive)" # check if this module has params if list(module.named_parameters()): for v in summary.values(): if info["id"] == v["id"]: info["params"] = "(recursive)" - + if info["params"] == 0: info["params"], info["macs"] = "-", "-" summary[key] = info - + # ignore Sequential and ModuleList if not module._modules: hooks.append(module.register_forward_hook(hook)) @@ -71,7 +70,8 @@ def summary(model, x, *args, **kwargs): summary = OrderedDict() model.apply(register_hook) - model(x) if not (kwargs or args) else model(x, *args, **kwargs) + with torch.no_grad(): + model(x) if not (kwargs or args) else model(x, *args, **kwargs) for hook in hooks: hook.remove() @@ -84,25 +84,25 @@ def summary(model, x, *args, **kwargs): total_params, total_macs = 0, 0 for layer, info in summary.items(): - repr_ksize = str(info["ksize"]) - repr_out = str(info["out"]) + repr_ksize = str(info["ksize"]) + repr_out = str(info["out"]) repr_params = info["params"] - repr_macs = info["macs"] + repr_macs = info["macs"] if isinstance(repr_params, (int, float)): total_params += repr_params - repr_params = "{0:,.2f}".format(repr_params/1000) + repr_params = "{0:,.2f}".format(repr_params/1000) if isinstance(repr_macs, (int, float)): total_macs += repr_macs - repr_macs = "{0:,.2f}".format(repr_macs/1000000) - + repr_macs = "{0:,.2f}".format(repr_macs/1000000) + print("{:<15} {:>20} {:>20} {:>20} {:>20}" .format(layer, repr_ksize, repr_out, repr_params, repr_macs)) # for RNN, describe inner weights (i.e. w_hh, w_ih) for inner_name, inner_shape in info["inner"].items(): print(" {:<13} {:>20}".format(inner_name, str(inner_shape))) - + print("="*100) print("# Params: {0:,.2f}K".format(total_params/1000)) print("# Mult-Adds: {0:,.2f}M".format(total_macs/1000000))