mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-30 12:22:49 +08:00
b5200b41dd
I thought perhaps we could add an example of the outputs on a large model. I noticed that some of these summary packages break down in the presence of large models, even though these are the ones you most need to summarise.
276 lines
14 KiB
Markdown
276 lines
14 KiB
Markdown
# torchsummaryX
|
|
An improved visualization tool based on [torchsummary](https://github.com/sksq96/pytorch-summary). It reports the kernel size, output shape, number of parameters, and Mult-Adds of each layer. torchsummaryX can also handle RNNs, recursive NNs, and models with multiple inputs.
|
|
|
|
## Usage
|
|
Install with `pip install torchsummaryX`, then:
|
|
|
|
```python
|
|
from torchsummaryX import summary
|
|
summary(your_model, torch.zeros((1, 3, 224, 224)))
|
|
```
|
|
Args:
|
|
- `model` (Module): Model to summarize
|
|
- `x` (Tensor): Input tensor of the model, e.g. with shape [N, C, H, W]. Its dtype and device must match the model's.
|
|
- `args, kwargs`: Additional positional and keyword arguments passed through to `model.forward`
|
|
|
|
## Examples
|
|
CNN for MNIST
|
|
```python
|
|
class Net(nn.Module):
|
|
def __init__(self):
|
|
super(Net, self).__init__()
|
|
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
|
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
|
self.conv2_drop = nn.Dropout2d()
|
|
self.fc1 = nn.Linear(320, 50)
|
|
self.fc2 = nn.Linear(50, 10)
|
|
|
|
def forward(self, x):
|
|
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
|
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
|
x = x.view(-1, 320)
|
|
x = F.relu(self.fc1(x))
|
|
x = F.dropout(x, training=self.training)
|
|
x = self.fc2(x)
|
|
return F.log_softmax(x, dim=1)
|
|
summary(Net(), torch.zeros((1, 1, 28, 28)))
|
|
```
|
|
```
|
|
========================================================================
|
|
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
|
Layer
|
|
0_conv1 [1, 10, 5, 5] [1, 10, 24, 24] 0.26 0.144
|
|
1_conv2 [10, 20, 5, 5] [1, 20, 8, 8] 5.02 0.32
|
|
2_conv2_drop - [1, 20, 8, 8] - -
|
|
3_fc1 [320, 50] [1, 50] 16.05 0.016
|
|
4_fc2 [50, 10] [1, 10] 0.51 0.0005
|
|
------------------------------------------------------------------------
|
|
Params (K): 21.84
|
|
Mult-Adds (M): 0.4805
|
|
========================================================================
|
|
```
|
|
|
|
RNN
|
|
```python
|
|
class Net(nn.Module):
|
|
def __init__(self,
|
|
vocab_size=20, embed_dim=300,
|
|
hidden_dim=512, num_layers=2):
|
|
super().__init__()
|
|
|
|
self.hidden_dim = hidden_dim
|
|
self.embedding = nn.Embedding(vocab_size, embed_dim)
|
|
self.encoder = nn.LSTM(embed_dim, hidden_dim,
|
|
num_layers=num_layers)
|
|
self.decoder = nn.Linear(hidden_dim, vocab_size)
|
|
|
|
def forward(self, x):
|
|
embed = self.embedding(x)
|
|
out, hidden = self.encoder(embed)
|
|
out = self.decoder(out)
|
|
out = out.view(-1, out.size(2))
|
|
return out, hidden
|
|
inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size]
|
|
summary(Net(), inputs)
|
|
```
|
|
```
|
|
==================================================================
|
|
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
|
Layer
|
|
0_embedding [300, 20] [100, 1, 300] 6.00 0.006000
|
|
1_encoder - [100, 1, 512] 3768.32 3.760128
|
|
2_decoder [512, 20] [100, 1, 20] 10.26 0.010240
|
|
------------------------------------------------------------------
|
|
Params (K): 3784.5800000000004
|
|
Mult-Adds (M): 3.7763679999999997
|
|
==================================================================
|
|
```
|
|
|
|
Recursive NN
|
|
```python
|
|
class Net(nn.Module):
|
|
def __init__(self):
|
|
super(Net, self).__init__()
|
|
self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)
|
|
|
|
def forward(self, x):
|
|
out = self.conv1(x)
|
|
out = self.conv1(out)
|
|
return out
|
|
summary(Net(), torch.zeros((1, 64, 28, 28)))
|
|
```
|
|
```
|
|
===================================================================
|
|
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
|
Layer
|
|
0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376
|
|
1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] - 28.901376
|
|
-------------------------------------------------------------------
|
|
Params (K): 36.928
|
|
Mult-Adds (M): 57.802752
|
|
===================================================================
|
|
```
|
|
|
|
Multiple arguments
|
|
```python
|
|
class Net(nn.Module):
|
|
def __init__(self):
|
|
super(Net, self).__init__()
|
|
self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)
|
|
|
|
def forward(self, x, args1, args2):
|
|
out = self.conv1(x)
|
|
out = self.conv1(out)
|
|
return out
|
|
summary(Net(), torch.zeros((1, 64, 28, 28)), "args1", args2="args2")
|
|
```
|
|
|
|
```
|
|
===================================================================
|
|
Kernel Shape Output Shape Params (K) Mult-Adds (M)
|
|
Layer
|
|
0_conv1 [64, 64, 3, 3] [1, 64, 28, 28] 36.928 28.901376
|
|
1_conv1 [64, 64, 3, 3] [1, 64, 28, 28] - 28.901376
|
|
-------------------------------------------------------------------
|
|
Params (K): 36.928
|
|
Mult-Adds (M): 57.802752
|
|
===================================================================
|
|
```
|
|
|
|
Large models with long layer names
|
|
```python
|
|
import torchvision
|
|
model = torchvision.models.resnet18()
|
|
summary(model, torch.zeros(4, 3, 224, 224))
|
|
```
|
|
```
|
|
Layer
|
|
0_conv1 [3, 64, 7, 7] [4, 64, 112, 112]
|
|
1_bn1 [64] [4, 64, 112, 112]
|
|
2_relu - [4, 64, 112, 112]
|
|
3_maxpool - [4, 64, 56, 56]
|
|
4_layer1.0.Conv2d_conv1 [64, 64, 3, 3] [4, 64, 56, 56]
|
|
5_layer1.0.BatchNorm2d_bn1 [64] [4, 64, 56, 56]
|
|
6_layer1.0.ReLU_relu - [4, 64, 56, 56]
|
|
7_layer1.0.Conv2d_conv2 [64, 64, 3, 3] [4, 64, 56, 56]
|
|
8_layer1.0.BatchNorm2d_bn2 [64] [4, 64, 56, 56]
|
|
9_layer1.0.ReLU_relu - [4, 64, 56, 56]
|
|
10_layer1.1.Conv2d_conv1 [64, 64, 3, 3] [4, 64, 56, 56]
|
|
11_layer1.1.BatchNorm2d_bn1 [64] [4, 64, 56, 56]
|
|
12_layer1.1.ReLU_relu - [4, 64, 56, 56]
|
|
13_layer1.1.Conv2d_conv2 [64, 64, 3, 3] [4, 64, 56, 56]
|
|
14_layer1.1.BatchNorm2d_bn2 [64] [4, 64, 56, 56]
|
|
15_layer1.1.ReLU_relu - [4, 64, 56, 56]
|
|
16_layer2.0.Conv2d_conv1 [64, 128, 3, 3] [4, 128, 28, 28]
|
|
17_layer2.0.BatchNorm2d_bn1 [128] [4, 128, 28, 28]
|
|
18_layer2.0.ReLU_relu - [4, 128, 28, 28]
|
|
19_layer2.0.Conv2d_conv2 [128, 128, 3, 3] [4, 128, 28, 28]
|
|
20_layer2.0.BatchNorm2d_bn2 [128] [4, 128, 28, 28]
|
|
21_layer2.0.downsample.Conv2d_0 [64, 128, 1, 1] [4, 128, 28, 28]
|
|
22_layer2.0.downsample.BatchNorm2d_1 [128] [4, 128, 28, 28]
|
|
23_layer2.0.ReLU_relu - [4, 128, 28, 28]
|
|
24_layer2.1.Conv2d_conv1 [128, 128, 3, 3] [4, 128, 28, 28]
|
|
25_layer2.1.BatchNorm2d_bn1 [128] [4, 128, 28, 28]
|
|
26_layer2.1.ReLU_relu - [4, 128, 28, 28]
|
|
27_layer2.1.Conv2d_conv2 [128, 128, 3, 3] [4, 128, 28, 28]
|
|
28_layer2.1.BatchNorm2d_bn2 [128] [4, 128, 28, 28]
|
|
29_layer2.1.ReLU_relu - [4, 128, 28, 28]
|
|
30_layer3.0.Conv2d_conv1 [128, 256, 3, 3] [4, 256, 14, 14]
|
|
31_layer3.0.BatchNorm2d_bn1 [256] [4, 256, 14, 14]
|
|
32_layer3.0.ReLU_relu - [4, 256, 14, 14]
|
|
33_layer3.0.Conv2d_conv2 [256, 256, 3, 3] [4, 256, 14, 14]
|
|
34_layer3.0.BatchNorm2d_bn2 [256] [4, 256, 14, 14]
|
|
35_layer3.0.downsample.Conv2d_0 [128, 256, 1, 1] [4, 256, 14, 14]
|
|
36_layer3.0.downsample.BatchNorm2d_1 [256] [4, 256, 14, 14]
|
|
37_layer3.0.ReLU_relu - [4, 256, 14, 14]
|
|
38_layer3.1.Conv2d_conv1 [256, 256, 3, 3] [4, 256, 14, 14]
|
|
39_layer3.1.BatchNorm2d_bn1 [256] [4, 256, 14, 14]
|
|
40_layer3.1.ReLU_relu - [4, 256, 14, 14]
|
|
41_layer3.1.Conv2d_conv2 [256, 256, 3, 3] [4, 256, 14, 14]
|
|
42_layer3.1.BatchNorm2d_bn2 [256] [4, 256, 14, 14]
|
|
43_layer3.1.ReLU_relu - [4, 256, 14, 14]
|
|
44_layer4.0.Conv2d_conv1 [256, 512, 3, 3] [4, 512, 7, 7]
|
|
45_layer4.0.BatchNorm2d_bn1 [512] [4, 512, 7, 7]
|
|
46_layer4.0.ReLU_relu - [4, 512, 7, 7]
|
|
47_layer4.0.Conv2d_conv2 [512, 512, 3, 3] [4, 512, 7, 7]
|
|
48_layer4.0.BatchNorm2d_bn2 [512] [4, 512, 7, 7]
|
|
49_layer4.0.downsample.Conv2d_0 [256, 512, 1, 1] [4, 512, 7, 7]
|
|
50_layer4.0.downsample.BatchNorm2d_1 [512] [4, 512, 7, 7]
|
|
51_layer4.0.ReLU_relu - [4, 512, 7, 7]
|
|
52_layer4.1.Conv2d_conv1 [512, 512, 3, 3] [4, 512, 7, 7]
|
|
53_layer4.1.BatchNorm2d_bn1 [512] [4, 512, 7, 7]
|
|
54_layer4.1.ReLU_relu - [4, 512, 7, 7]
|
|
55_layer4.1.Conv2d_conv2 [512, 512, 3, 3] [4, 512, 7, 7]
|
|
56_layer4.1.BatchNorm2d_bn2 [512] [4, 512, 7, 7]
|
|
57_layer4.1.ReLU_relu - [4, 512, 7, 7]
|
|
58_avgpool - [4, 512, 1, 1]
|
|
59_fc [512, 1000] [4, 1000]
|
|
|
|
Params (K) Mult-Adds (M)
|
|
Layer
|
|
0_conv1 9.408 118.014
|
|
1_bn1 0.128 6.4e-05
|
|
2_relu - -
|
|
3_maxpool - -
|
|
4_layer1.0.Conv2d_conv1 36.864 115.606
|
|
5_layer1.0.BatchNorm2d_bn1 0.128 6.4e-05
|
|
6_layer1.0.ReLU_relu - -
|
|
7_layer1.0.Conv2d_conv2 36.864 115.606
|
|
8_layer1.0.BatchNorm2d_bn2 0.128 6.4e-05
|
|
9_layer1.0.ReLU_relu - -
|
|
10_layer1.1.Conv2d_conv1 36.864 115.606
|
|
11_layer1.1.BatchNorm2d_bn1 0.128 6.4e-05
|
|
12_layer1.1.ReLU_relu - -
|
|
13_layer1.1.Conv2d_conv2 36.864 115.606
|
|
14_layer1.1.BatchNorm2d_bn2 0.128 6.4e-05
|
|
15_layer1.1.ReLU_relu - -
|
|
16_layer2.0.Conv2d_conv1 73.728 57.8028
|
|
17_layer2.0.BatchNorm2d_bn1 0.256 0.000128
|
|
18_layer2.0.ReLU_relu - -
|
|
19_layer2.0.Conv2d_conv2 147.456 115.606
|
|
20_layer2.0.BatchNorm2d_bn2 0.256 0.000128
|
|
21_layer2.0.downsample.Conv2d_0 8.192 6.42253
|
|
22_layer2.0.downsample.BatchNorm2d_1 0.256 0.000128
|
|
23_layer2.0.ReLU_relu - -
|
|
24_layer2.1.Conv2d_conv1 147.456 115.606
|
|
25_layer2.1.BatchNorm2d_bn1 0.256 0.000128
|
|
26_layer2.1.ReLU_relu - -
|
|
27_layer2.1.Conv2d_conv2 147.456 115.606
|
|
28_layer2.1.BatchNorm2d_bn2 0.256 0.000128
|
|
29_layer2.1.ReLU_relu - -
|
|
30_layer3.0.Conv2d_conv1 294.912 57.8028
|
|
31_layer3.0.BatchNorm2d_bn1 0.512 0.000256
|
|
32_layer3.0.ReLU_relu - -
|
|
33_layer3.0.Conv2d_conv2 589.824 115.606
|
|
34_layer3.0.BatchNorm2d_bn2 0.512 0.000256
|
|
35_layer3.0.downsample.Conv2d_0 32.768 6.42253
|
|
36_layer3.0.downsample.BatchNorm2d_1 0.512 0.000256
|
|
37_layer3.0.ReLU_relu - -
|
|
38_layer3.1.Conv2d_conv1 589.824 115.606
|
|
39_layer3.1.BatchNorm2d_bn1 0.512 0.000256
|
|
40_layer3.1.ReLU_relu - -
|
|
41_layer3.1.Conv2d_conv2 589.824 115.606
|
|
42_layer3.1.BatchNorm2d_bn2 0.512 0.000256
|
|
43_layer3.1.ReLU_relu - -
|
|
44_layer4.0.Conv2d_conv1 1179.65 57.8028
|
|
45_layer4.0.BatchNorm2d_bn1 1.024 0.000512
|
|
46_layer4.0.ReLU_relu - -
|
|
47_layer4.0.Conv2d_conv2 2359.3 115.606
|
|
48_layer4.0.BatchNorm2d_bn2 1.024 0.000512
|
|
49_layer4.0.downsample.Conv2d_0 131.072 6.42253
|
|
50_layer4.0.downsample.BatchNorm2d_1 1.024 0.000512
|
|
51_layer4.0.ReLU_relu - -
|
|
52_layer4.1.Conv2d_conv1 2359.3 115.606
|
|
53_layer4.1.BatchNorm2d_bn1 1.024 0.000512
|
|
54_layer4.1.ReLU_relu - -
|
|
55_layer4.1.Conv2d_conv2 2359.3 115.606
|
|
56_layer4.1.BatchNorm2d_bn2 1.024 0.000512
|
|
57_layer4.1.ReLU_relu - -
|
|
58_avgpool - -
|
|
59_fc 513 0.512
|
|
----------------------------------------------------------------------------------------------------
|
|
Params (K): 11689.511999999999
|
|
Mult-Adds (M): 1814.0781440000007
|
|
====================================================================================================
|
|
```
|