mirror of
https://github.com/wassname/torchsummaryX.git
synced 2026-06-27 16:32:27 +08:00
Merge branch 'master' of https://github.com/nmhkahn/torchsummaryX
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# torchsummaryX
|
||||
Improved visualization tool of [torchsummary](https://github.com/sksq96/pytorch-summary). Here, we visualize kernel size, output shape, # params, and also Mult-Adds. Moreover, torchsummaryX can handle RNN, Recursive NN, or model with multiple inputs.
|
||||
Improved visualization tool of [torchsummary](https://github.com/sksq96/pytorch-summary). Here, it visualizes kernel size, output shape, # params, and Mult-Adds. Also the torchsummaryX can handle RNN, Recursive NN, or model with multiple inputs.
|
||||
|
||||
## Usage
|
||||
```python
|
||||
@@ -57,22 +57,16 @@ class Net(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, embed_dim)
|
||||
self.encoder = nn.LSTM(embed_dim, hidden_dim,
|
||||
num_layers=num_layers)
|
||||
|
||||
self.decoder = nn.Linear(hidden_dim, vocab_size)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.size(0)
|
||||
|
||||
embed = self.embedding(x)
|
||||
out, hidden = self.encoder(embed)
|
||||
|
||||
out = self.decoder(out)
|
||||
out = out.view(-1, out.size(2))
|
||||
|
||||
return out, hidden
|
||||
inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size]
|
||||
summary(Net(), inputs)
|
||||
|
||||
Reference in New Issue
Block a user