Complete as good as the torch for numpy users

This commit is contained in:
Kentaro Wada
2017-05-21 06:30:40 +09:00
parent 713698d91d
commit 10ee9c1e0b
3 changed files with 256 additions and 17 deletions
+112
View File
@@ -8,6 +8,7 @@
| Numpy | PyTorch |
|:-------------|:---------------------|
| `np.ndarray` | `torch.Tensor` |
| `np.float32` | `torch.FloatTensor` |
| `np.float64` | `torch.DoubleTensor` |
| `np.int8` | `torch.CharTensor` |
@@ -16,6 +17,7 @@
| `np.int32` | `torch.IntTensor` |
| `np.int64` | `torch.LongTensor` |
## Constructors
### Ones and zeros
@@ -31,4 +33,114 @@
| `np.zeros` | `torch.zeros` |
| `np.zeros_like` | `torch.zeros(x.size()).type(x.type())` |
### From existing data
| Numpy | PyTorch |
|:-----------------------------|:------------------------------------|
| `np.array([[1, 2], [3, 4]])` | `torch.Tensor([[1, 2], [3, 4])` |
| `x.copy()` | `x.clone()` |
| `np.fromfile(file)` | `torch.Tensor(torch.Storage(file))` |
| `np.frombuffer` | |
| `np.fromfunction` | |
| `np.fromiter` | |
| `np.fromstring` | |
| `np.loadtxt` | |
| `np.concatenate` | `torch.cat` |
### Numerical ranges
| Numpy | PyTorch |
|:-----------------------|:--------------------------|
| `np.arange(10)` | `torch.range(0, 9)` |
| `np.arange(2, 3, 0.1)` | `torch.range(2, 2.9, 10)` |
| `np.linspace` | `torch.linspace` |
| `np.logspace` | `np.logspace` |
### Building matrices
| Numpy | PyTorch |
|:----------|:-------------|
| `np.diag` | `torch.diag` |
| `np.tril` | `torch.tril` |
| `np.triu` | `torch.triu` |
### Attributes
| Numpy | PyTorch |
|:------------|:---------------|
| `x.shape` | `x.size()` |
| `x.strides` | `x.stride()` |
| `x.ndim` | `x.dim()` |
| `x.data` | `x.data()` |
| `x.size` | `x.nelement()` |
| `x.dtype` | `x.type()` |
### Indexing
| Numpy | PyTorch |
|:----------------------|:-------------------------------|
| `x[0]` | `x[0]` |
| `x[:, 0]` | `x[:, 0]` |
| `x[indices]` | `x[torch.LongTensor(indices)]` |
| `np.take(x, indices)` | `x[torch.LongTensor(indices)]` |
| `x[x != 0]` | `x[x != 0]` |
### Shape manipulation
| Numpy | PyTorch |
|:-------------------|:-----------------|
| `x.reshape` | `x.view` |
| `x.resize` | `x.resize_` |
| | `x.resize_as_` |
| `x.transpose` | `x.permute` |
| `x.flatten()` | `x.view(-1)` |
| `x.squeeze` | `x.squeeze` |
| `x[:, np.newaxis]` | `x.unsqueeze(1)` |
### Item selection and manipulation
| Numpy | PyTorch |
|:-------------|:-----------------------------------------|
| `np.put` | |
| `x.repeat` | |
| `x.tile` | `x.repeat` |
| `np.choose` | |
| `np.sort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.argsort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.nonzero` | `torch.nonzero` |
| `np.where` | `torch.nonzero` |
### Calculation
| Numpy | PyTorch |
|:------------|:--------------------------------------|
| `x.min` | `mins, indices = torch.min(x, [dim])` |
| `x.argmin` | `mins, indices = torch.min(x, [dim])` |
| `x.max` | `maxs, indices = torch.max(x, [dim])` |
| `x.argmax` | `maxs, indices = torch.max(x, [dim])` |
| `x.clip` | |
| `x.round` | `y.round` |
| | `y.floor` |
| `x.trace` | `y.trace` |
| `x.sum` | `y.sum` |
| `x.cumsum` | `y.cumsum` |
| `x.mean` | `x.mean` |
| `x.std` | `x.std` |
| `x.prod` | `x.prod` |
| `x.cumprod` | `x.cumprod` |
| `x.all` | `(y == 1).sum() == y.nelement()` |
| `x.any` | `(y == 1).sum() > 0` |
### Arithmetic and comparison operations
| Numpy | PyTorch |
|:--------|:----------|
| `x.lt` | `x.lt` |
| `x.le` | `x.le` |
| `x.gt` | `x.gt` |
| `x.ge` | `x.ge` |
| `x.eq` | `x.eq` |
| `x.ne` | `x.ne` |
+139 -12
View File
@@ -1,4 +1,6 @@
types:
- numpy: np.ndarray
pytorch: torch.Tensor
- numpy: np.float32
pytorch: torch.FloatTensor
- numpy: np.float64
@@ -32,15 +34,140 @@ constructors:
pytorch: torch.zeros
- numpy: np.zeros_like
pytorch: torch.zeros(x.size()).type(x.type())
# - numpy: x.astype(np.int32)
# pytorch: x.type(torch.IntTensor)
#
# - numpy: y = x.copy()
# pytorch: y = x.clone()
#
# - numpy: x.shape
# pytorch: x.size()
#
# - numpy: x.size
# pytorch: x.nelement()
from existing data:
- numpy: np.array([[1, 2], [3, 4]])
pytorch: torch.Tensor([[1, 2], [3, 4])
- numpy: x.copy()
pytorch: x.clone()
- numpy: np.fromfile(file)
pytorch: torch.Tensor(torch.Storage(file))
- numpy: np.frombuffer
pytorch:
- numpy: np.fromfunction
pytorch:
- numpy: np.fromiter
pytorch:
- numpy: np.fromstring
pytorch:
- numpy: np.loadtxt
pytorch:
- numpy: np.concatenate
pytorch: torch.cat
numerical ranges:
- numpy: np.arange(10)
pytorch: torch.range(0, 9)
- numpy: np.arange(2, 3, 0.1)
pytorch: torch.range(2, 2.9, 10)
- numpy: np.linspace
pytorch: torch.linspace
- numpy: np.logspace
pytorch: np.logspace
building matrices:
- numpy: np.diag
pytorch: torch.diag
- numpy: np.tril
pytorch: torch.tril
- numpy: np.triu
pytorch: torch.triu
attributes:
- numpy: x.shape
pytorch: x.size()
- numpy: x.strides
pytorch: x.stride()
- numpy: x.ndim
pytorch: x.dim()
- numpy: x.data
pytorch: x.data()
- numpy: x.size
pytorch: x.nelement()
- numpy: x.dtype
pytorch: x.type()
indexing:
- numpy: x[0]
pytorch: x[0]
- numpy: x[:, 0]
pytorch: x[:, 0]
- numpy: x[indices]
pytorch: x[torch.LongTensor(indices)]
- numpy: np.take(x, indices)
pytorch: x[torch.LongTensor(indices)]
- numpy: x[x != 0]
pytorch: x[x != 0]
shape manipulation:
- numpy: x.reshape
pytorch: x.view
- numpy: x.resize
pytorch: x.resize_
- numpy:
pytorch: x.resize_as_
- numpy: x.transpose
pytorch: x.permute
- numpy: x.flatten()
pytorch: x.view(-1)
- numpy: x.squeeze
pytorch: x.squeeze
- numpy: x[:, np.newaxis]
pytorch: x.unsqueeze(1)
item selection and manipulation:
- numpy: np.put
pytorch:
- numpy: x.repeat
pytorch:
- numpy: x.tile
pytorch: x.repeat
- numpy: np.choose
pytorch:
- numpy: np.sort
pytorch: sorted, indices = torch.sort(x, [dim])
- numpy: np.argsort
pytorch: sorted, indices = torch.sort(x, [dim])
- numpy: np.nonzero
pytorch: torch.nonzero
- numpy: np.where
pytorch: torch.nonzero
calculation:
- numpy: x.min
pytorch: mins, indices = torch.min(x, [dim])
- numpy: x.argmin
pytorch: mins, indices = torch.min(x, [dim])
- numpy: x.max
pytorch: maxs, indices = torch.max(x, [dim])
- numpy: x.argmax
pytorch: maxs, indices = torch.max(x, [dim])
- numpy: x.clip
pytorch:
- numpy: x.round
pytorch: y.round
- numpy:
pytorch: y.floor
- numpy: x.trace
pytorch: y.trace
- numpy: x.sum
pytorch: y.sum
- numpy: x.cumsum
pytorch: y.cumsum
- numpy: x.mean
pytorch: x.mean
- numpy: x.std
pytorch: x.std
- numpy: x.prod
pytorch: x.prod
- numpy: x.cumprod
pytorch: x.cumprod
- numpy: x.all
pytorch: (y == 1).sum() == y.nelement()
- numpy: x.any
pytorch: (y == 1).sum() > 0
arithmetic and comparison operations:
- numpy: x.lt
pytorch: x.lt
- numpy: x.le
pytorch: x.le
- numpy: x.gt
pytorch: x.gt
- numpy: x.ge
pytorch: x.ge
- numpy: x.eq
pytorch: x.eq
- numpy: x.ne
pytorch: x.ne
+5 -5
View File
@@ -31,13 +31,13 @@ def get_section(title, data, h=2):
headers = ['Numpy', 'PyTorch']
rows = []
for d in data:
rows.append([
'`' + d['numpy'] + '`',
'`' + d['pytorch'] + '`',
])
numpy = '`' + d['numpy'] + '`' if d['numpy'] is not None else ''
pytorch = '`' + d['pytorch'] + '`' if d['pytorch'] is not None else ''
rows.append([numpy, pytorch])
content = '%s %s\n\n' % ('#' * h, title.capitalize())
content += tabulate.tabulate(rows, headers=headers, tablefmt='pipe') + '\n'
content += tabulate.tabulate(rows, headers=headers, tablefmt='pipe')
content += '\n\n'
return content