Apply lint and use no_grad

This commit is contained in:
Namhyuk Ahn
2019-02-03 16:36:39 +09:00
parent 51fdd3c638
commit 7e08acaf7a
+17 -17
View File
@@ -1,11 +1,10 @@
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
def summary(model, x, *args, **kwargs):
"""Summarize the given input model.
Summarized information are 1) output shape, 2) kernel shape,
Summarized information are 1) output shape, 2) kernel shape,
3) number of the parameters and 4) operations (Mult-Adds)
Args:
@@ -33,36 +32,36 @@ def summary(model, x, *args, **kwargs):
for name, param in module.named_parameters():
info["params"] += param.nelement()
if "weight" == name:
if name == "weight":
ksize = list(param.size())
# to make [in_shape, out_shape, ksize, ksize]
if len(ksize) > 1:
ksize[0], ksize[1] = ksize[1], ksize[0]
info["ksize"] = ksize
info["ksize"] = ksize
# ignore N, C when calculate Mult-Adds in ConvNd
if "Conv" in cls_name:
info["macs"] += int(param.nelement() * np.prod(info["out"][2:]))
else:
info["macs"] += param.nelement()
# RNN modules have inner weights such as weight_ih_l0
elif "weight" in name:
info["inner"][name] = list(param.size())
info["macs"] += param.nelement()
# if the current module is already-used, mark as "(recursive)"
# check if this module has params
if list(module.named_parameters()):
for v in summary.values():
if info["id"] == v["id"]:
info["params"] = "(recursive)"
if info["params"] == 0:
info["params"], info["macs"] = "-", "-"
summary[key] = info
# ignore Sequential and ModuleList
if not module._modules:
hooks.append(module.register_forward_hook(hook))
@@ -71,7 +70,8 @@ def summary(model, x, *args, **kwargs):
summary = OrderedDict()
model.apply(register_hook)
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
with torch.no_grad():
model(x) if not (kwargs or args) else model(x, *args, **kwargs)
for hook in hooks:
hook.remove()
@@ -84,25 +84,25 @@ def summary(model, x, *args, **kwargs):
total_params, total_macs = 0, 0
for layer, info in summary.items():
repr_ksize = str(info["ksize"])
repr_out = str(info["out"])
repr_ksize = str(info["ksize"])
repr_out = str(info["out"])
repr_params = info["params"]
repr_macs = info["macs"]
repr_macs = info["macs"]
if isinstance(repr_params, (int, float)):
total_params += repr_params
repr_params = "{0:,.2f}".format(repr_params/1000)
repr_params = "{0:,.2f}".format(repr_params/1000)
if isinstance(repr_macs, (int, float)):
total_macs += repr_macs
repr_macs = "{0:,.2f}".format(repr_macs/1000000)
repr_macs = "{0:,.2f}".format(repr_macs/1000000)
print("{:<15} {:>20} {:>20} {:>20} {:>20}"
.format(layer, repr_ksize, repr_out, repr_params, repr_macs))
# for RNN, describe inner weights (i.e. w_hh, w_ih)
for inner_name, inner_shape in info["inner"].items():
print(" {:<13} {:>20}".format(inner_name, str(inner_shape)))
print("="*100)
print("# Params: {0:,.2f}K".format(total_params/1000))
print("# Mult-Adds: {0:,.2f}M".format(total_macs/1000000))