This commit is contained in:
Namhyuk Ahn
2018-07-29 01:56:00 +09:00
+1 -7
View File
@@ -1,5 +1,5 @@
# torchsummaryX
Improved visualization tool of [torchsummary](https://github.com/sksq96/pytorch-summary). Here, we visualize kernel size, output shape, # params, and also Mult-Adds. Moreover, torchsummaryX can handle RNN, Recursive NN, or model with multiple inputs.
Improved visualization tool of [torchsummary](https://github.com/sksq96/pytorch-summary). Here, it visualizes kernel size, output shape, # params, and Mult-Adds. Also the torchsummaryX can handle RNN, Recursive NN, or model with multiple inputs.
## Usage
```python
@@ -57,22 +57,16 @@ class Net(nn.Module):
super().__init__()
self.hidden_dim = hidden_dim
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.encoder = nn.LSTM(embed_dim, hidden_dim,
num_layers=num_layers)
self.decoder = nn.Linear(hidden_dim, vocab_size)
def forward(self, x):
batch_size = x.size(0)
embed = self.embedding(x)
out, hidden = self.encoder(embed)
out = self.decoder(out)
out = out.view(-1, out.size(2))
return out, hidden
inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size]
summary(Net(), inputs)