diff --git a/README.md b/README.md index ef07fc1..636bd9b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # torchsummaryX -Improved visualization tool of [torchsummary](https://github.com/sksq96/pytorch-summary). Here, we visualize kernel size, output shape, # params, and also Mult-Adds. Moreover, torchsummaryX can handle RNN, Recursive NN, or model with multiple inputs. +Improved visualization tool of [torchsummary](https://github.com/sksq96/pytorch-summary). Here, it visualizes kernel size, output shape, # params, and Mult-Adds. Also the torchsummaryX can handle RNN, Recursive NN, or model with multiple inputs. ## Usage ```python @@ -57,22 +57,16 @@ class Net(nn.Module): super().__init__() self.hidden_dim = hidden_dim - self.embedding = nn.Embedding(vocab_size, embed_dim) self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers) - self.decoder = nn.Linear(hidden_dim, vocab_size) def forward(self, x): - batch_size = x.size(0) - embed = self.embedding(x) out, hidden = self.encoder(embed) - out = self.decoder(out) out = out.view(-1, out.size(2)) - return out, hidden inputs = torch.zeros((100, 1), dtype=torch.long) # [length, batch_size] summary(Net(), inputs)