mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 18:07:44 +08:00
b5200b41dd
I thought perhaps we could add an example of the outputs on a large model. I noticed that some of these summary packages break down in the presence of large models, even though these are the ones you most need to summarise.
14 KiB
14 KiB
torchsummaryX
An improved visualization tool based on torchsummary. For each layer it shows the kernel size, output shape, number of parameters (# params), and Mult-Adds. torchsummaryX can also handle RNNs, recursive NNs, and models with multiple inputs.
Usage
Run `pip install torchsummaryX`, then:
from torchsummaryX import summary
# the dummy input's shape, dtype and device must match what your_model expects
summary(your_model, torch.zeros((1, 3, 224, 224)))
Args:
model (Module): Model to summarize.
x (Tensor): Input tensor of the model, with shape [N, C, H, W]; its dtype and device have to match the model.
args, kwargs: Other arguments used in the `model.forward` function.
Examples
CNN for MNIST
class Net(nn.Module):
    """Small CNN for 1-channel 28x28 (MNIST-style) images.

    ``forward`` returns log-probabilities over 10 classes, shape [N, 10].
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # 28 -> conv(5) -> 24 -> pool -> 12 -> conv(5) -> 8 -> pool -> 4,
        # so each sample flattens to 20 * 4 * 4 = 320 features.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# one 1-channel 28x28 image, laid out as [N, C, H, W]
summary(Net(), torch.zeros((1, 1, 28, 28)))
========================================================================
Kernel Shape Output Shape Params (K) Mult-Adds (M)
Layer
0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.144
1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.32
2_conv2_drop - [1, 20, 8, 8] - -
3_fc1 [320, 50] [1, 50] 16.05 0.016
4_fc2 [50, 10] [1, 10] 0.51 0.0005
------------------------------------------------------------------------
Params (K): 21.84
Mult-Adds (M): 0.4805
========================================================================
RNN
class Net(nn.Module):
    """Embedding -> multi-layer LSTM -> linear decoder back to the vocabulary.

    Expects integer token ids of shape [length, batch_size].
    """

    def __init__(self,
                 vocab_size=20, embed_dim=300,
                 hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim,
                               num_layers=num_layers)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return (logits of shape [length * batch_size, vocab_size], LSTM hidden state)."""
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        # Merge the sequence and batch dims:
        # [length, batch_size, vocab_size] -> [length * batch_size, vocab_size].
        out = out.view(-1, out.size(2))
        return out, hidden
inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size]
# integer (torch.long) token ids, since they are fed directly to the embedding layer
summary(Net(), inputs)
==================================================================
Kernel Shape Output Shape Params (K) Mult-Adds (M)
Layer
0_embedding [300, 20] [100, 1, 300] 6.00 0.006000
1_encoder - [100, 1, 512] 3768.32 3.760128
2_decoder [512, 20] [100, 1, 20] 10.26 0.010240
------------------------------------------------------------------
Params (K): 3784.5800000000004
Mult-Adds (M): 3.7763679999999997
==================================================================
Recursive NN
class Net(nn.Module):
    """Applies the same 3x3 convolution twice, reusing its weights."""

    def __init__(self):
        super().__init__()
        # stride 1 with padding 1 keeps H and W unchanged and in/out channels
        # are equal, so conv1 can be applied to its own output.
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv1(out)
        return out
# conv1 runs twice in forward, so it is listed twice below; its parameters
# are counted only on the first row
summary(Net(), torch.zeros((1, 64, 28, 28)))
===================================================================
Kernel Shape Output Shape Params (K) Mult-Adds (M)
Layer
0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376
1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] - 28.901376
-------------------------------------------------------------------
Params (K): 36.928
Mult-Adds (M): 57.802752
===================================================================
Multiple arguments
class Net(nn.Module):
    """Weight-sharing conv model whose forward takes two extra arguments."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x, args1, args2):
        # args1 / args2 are intentionally unused: they only demonstrate that
        # extra positional and keyword arguments reach forward().
        out = self.conv1(x)
        out = self.conv1(out)
        return out
# extra positional and keyword arguments are passed through to Net.forward
summary(Net(), torch.zeros((1, 64, 28, 28)), "args1", args2="args2")
===================================================================
Kernel Shape Output Shape Params (K) Mult-Adds (M)
Layer
0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376
1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] - 28.901376
-------------------------------------------------------------------
Params (K): 36.928
Mult-Adds (M): 57.802752
===================================================================
Large models with long layer names
import torchvision
# resnet18 has nested submodules, which produce the long dotted layer names below
model = torchvision.models.resnet18()
# a batch of 4 three-channel 224x224 images, laid out as [N, C, H, W]
summary(model, torch.zeros(4, 3, 224, 224))
Layer
0_conv1 [3, 64, 7, 7] [4, 64, 112, 112]
1_bn1 [64] [4, 64, 112, 112]
2_relu - [4, 64, 112, 112]
3_maxpool - [4, 64, 56, 56]
4_layer1.0.Conv2d_conv1 [64, 64, 3, 3] [4, 64, 56, 56]
5_layer1.0.BatchNorm2d_bn1 [64] [4, 64, 56, 56]
6_layer1.0.ReLU_relu - [4, 64, 56, 56]
7_layer1.0.Conv2d_conv2 [64, 64, 3, 3] [4, 64, 56, 56]
8_layer1.0.BatchNorm2d_bn2 [64] [4, 64, 56, 56]
9_layer1.0.ReLU_relu - [4, 64, 56, 56]
10_layer1.1.Conv2d_conv1 [64, 64, 3, 3] [4, 64, 56, 56]
11_layer1.1.BatchNorm2d_bn1 [64] [4, 64, 56, 56]
12_layer1.1.ReLU_relu - [4, 64, 56, 56]
13_layer1.1.Conv2d_conv2 [64, 64, 3, 3] [4, 64, 56, 56]
14_layer1.1.BatchNorm2d_bn2 [64] [4, 64, 56, 56]
15_layer1.1.ReLU_relu - [4, 64, 56, 56]
16_layer2.0.Conv2d_conv1 [64, 128, 3, 3] [4, 128, 28, 28]
17_layer2.0.BatchNorm2d_bn1 [128] [4, 128, 28, 28]
18_layer2.0.ReLU_relu - [4, 128, 28, 28]
19_layer2.0.Conv2d_conv2 [128, 128, 3, 3] [4, 128, 28, 28]
20_layer2.0.BatchNorm2d_bn2 [128] [4, 128, 28, 28]
21_layer2.0.downsample.Conv2d_0 [64, 128, 1, 1] [4, 128, 28, 28]
22_layer2.0.downsample.BatchNorm2d_1 [128] [4, 128, 28, 28]
23_layer2.0.ReLU_relu - [4, 128, 28, 28]
24_layer2.1.Conv2d_conv1 [128, 128, 3, 3] [4, 128, 28, 28]
25_layer2.1.BatchNorm2d_bn1 [128] [4, 128, 28, 28]
26_layer2.1.ReLU_relu - [4, 128, 28, 28]
27_layer2.1.Conv2d_conv2 [128, 128, 3, 3] [4, 128, 28, 28]
28_layer2.1.BatchNorm2d_bn2 [128] [4, 128, 28, 28]
29_layer2.1.ReLU_relu - [4, 128, 28, 28]
30_layer3.0.Conv2d_conv1 [128, 256, 3, 3] [4, 256, 14, 14]
31_layer3.0.BatchNorm2d_bn1 [256] [4, 256, 14, 14]
32_layer3.0.ReLU_relu - [4, 256, 14, 14]
33_layer3.0.Conv2d_conv2 [256, 256, 3, 3] [4, 256, 14, 14]
34_layer3.0.BatchNorm2d_bn2 [256] [4, 256, 14, 14]
35_layer3.0.downsample.Conv2d_0 [128, 256, 1, 1] [4, 256, 14, 14]
36_layer3.0.downsample.BatchNorm2d_1 [256] [4, 256, 14, 14]
37_layer3.0.ReLU_relu - [4, 256, 14, 14]
38_layer3.1.Conv2d_conv1 [256, 256, 3, 3] [4, 256, 14, 14]
39_layer3.1.BatchNorm2d_bn1 [256] [4, 256, 14, 14]
40_layer3.1.ReLU_relu - [4, 256, 14, 14]
41_layer3.1.Conv2d_conv2 [256, 256, 3, 3] [4, 256, 14, 14]
42_layer3.1.BatchNorm2d_bn2 [256] [4, 256, 14, 14]
43_layer3.1.ReLU_relu - [4, 256, 14, 14]
44_layer4.0.Conv2d_conv1 [256, 512, 3, 3] [4, 512, 7, 7]
45_layer4.0.BatchNorm2d_bn1 [512] [4, 512, 7, 7]
46_layer4.0.ReLU_relu - [4, 512, 7, 7]
47_layer4.0.Conv2d_conv2 [512, 512, 3, 3] [4, 512, 7, 7]
48_layer4.0.BatchNorm2d_bn2 [512] [4, 512, 7, 7]
49_layer4.0.downsample.Conv2d_0 [256, 512, 1, 1] [4, 512, 7, 7]
50_layer4.0.downsample.BatchNorm2d_1 [512] [4, 512, 7, 7]
51_layer4.0.ReLU_relu - [4, 512, 7, 7]
52_layer4.1.Conv2d_conv1 [512, 512, 3, 3] [4, 512, 7, 7]
53_layer4.1.BatchNorm2d_bn1 [512] [4, 512, 7, 7]
54_layer4.1.ReLU_relu - [4, 512, 7, 7]
55_layer4.1.Conv2d_conv2 [512, 512, 3, 3] [4, 512, 7, 7]
56_layer4.1.BatchNorm2d_bn2 [512] [4, 512, 7, 7]
57_layer4.1.ReLU_relu - [4, 512, 7, 7]
58_avgpool - [4, 512, 1, 1]
59_fc [512, 1000] [4, 1000]
Params (K) Mult-Adds (M)
Layer
0_conv1 9.408 118.014
1_bn1 0.128 6.4e-05
2_relu - -
3_maxpool - -
4_layer1.0.Conv2d_conv1 36.864 115.606
5_layer1.0.BatchNorm2d_bn1 0.128 6.4e-05
6_layer1.0.ReLU_relu - -
7_layer1.0.Conv2d_conv2 36.864 115.606
8_layer1.0.BatchNorm2d_bn2 0.128 6.4e-05
9_layer1.0.ReLU_relu - -
10_layer1.1.Conv2d_conv1 36.864 115.606
11_layer1.1.BatchNorm2d_bn1 0.128 6.4e-05
12_layer1.1.ReLU_relu - -
13_layer1.1.Conv2d_conv2 36.864 115.606
14_layer1.1.BatchNorm2d_bn2 0.128 6.4e-05
15_layer1.1.ReLU_relu - -
16_layer2.0.Conv2d_conv1 73.728 57.8028
17_layer2.0.BatchNorm2d_bn1 0.256 0.000128
18_layer2.0.ReLU_relu - -
19_layer2.0.Conv2d_conv2 147.456 115.606
20_layer2.0.BatchNorm2d_bn2 0.256 0.000128
21_layer2.0.downsample.Conv2d_0 8.192 6.42253
22_layer2.0.downsample.BatchNorm2d_1 0.256 0.000128
23_layer2.0.ReLU_relu - -
24_layer2.1.Conv2d_conv1 147.456 115.606
25_layer2.1.BatchNorm2d_bn1 0.256 0.000128
26_layer2.1.ReLU_relu - -
27_layer2.1.Conv2d_conv2 147.456 115.606
28_layer2.1.BatchNorm2d_bn2 0.256 0.000128
29_layer2.1.ReLU_relu - -
30_layer3.0.Conv2d_conv1 294.912 57.8028
31_layer3.0.BatchNorm2d_bn1 0.512 0.000256
32_layer3.0.ReLU_relu - -
33_layer3.0.Conv2d_conv2 589.824 115.606
34_layer3.0.BatchNorm2d_bn2 0.512 0.000256
35_layer3.0.downsample.Conv2d_0 32.768 6.42253
36_layer3.0.downsample.BatchNorm2d_1 0.512 0.000256
37_layer3.0.ReLU_relu - -
38_layer3.1.Conv2d_conv1 589.824 115.606
39_layer3.1.BatchNorm2d_bn1 0.512 0.000256
40_layer3.1.ReLU_relu - -
41_layer3.1.Conv2d_conv2 589.824 115.606
42_layer3.1.BatchNorm2d_bn2 0.512 0.000256
43_layer3.1.ReLU_relu - -
44_layer4.0.Conv2d_conv1 1179.65 57.8028
45_layer4.0.BatchNorm2d_bn1 1.024 0.000512
46_layer4.0.ReLU_relu - -
47_layer4.0.Conv2d_conv2 2359.3 115.606
48_layer4.0.BatchNorm2d_bn2 1.024 0.000512
49_layer4.0.downsample.Conv2d_0 131.072 6.42253
50_layer4.0.downsample.BatchNorm2d_1 1.024 0.000512
51_layer4.0.ReLU_relu - -
52_layer4.1.Conv2d_conv1 2359.3 115.606
53_layer4.1.BatchNorm2d_bn1 1.024 0.000512
54_layer4.1.ReLU_relu - -
55_layer4.1.Conv2d_conv2 2359.3 115.606
56_layer4.1.BatchNorm2d_bn2 1.024 0.000512
57_layer4.1.ReLU_relu - -
58_avgpool - -
59_fc 513 0.512
----------------------------------------------------------------------------------------------------
Params (K): 11689.511999999999
Mult-Adds (M): 1814.0781440000007
====================================================================================================