mirror of
https://github.com/wassname/pytorch-for-numpy-users.git
synced 2026-06-27 16:10:21 +08:00
Complete as good as the torch for numpy users
This commit is contained in:
+139
-12
@@ -1,4 +1,6 @@
|
||||
types:
|
||||
- numpy: np.ndarray
|
||||
pytorch: torch.Tensor
|
||||
- numpy: np.float32
|
||||
pytorch: torch.FloatTensor
|
||||
- numpy: np.float64
|
||||
@@ -32,15 +34,140 @@ constructors:
|
||||
pytorch: torch.zeros
|
||||
- numpy: np.zeros_like
|
||||
pytorch: torch.zeros(x.size()).type(x.type())
|
||||
|
||||
# - numpy: x.astype(np.int32)
|
||||
# pytorch: x.type(torch.IntTensor)
|
||||
#
|
||||
# - numpy: y = x.copy()
|
||||
# pytorch: y = x.clone()
|
||||
#
|
||||
# - numpy: x.shape
|
||||
# pytorch: x.size()
|
||||
#
|
||||
# - numpy: x.size
|
||||
# pytorch: x.nelement()
|
||||
from existing data:
|
||||
- numpy: np.array([[1, 2], [3, 4]])
|
||||
pytorch: torch.Tensor([[1, 2], [3, 4])
|
||||
- numpy: x.copy()
|
||||
pytorch: x.clone()
|
||||
- numpy: np.fromfile(file)
|
||||
pytorch: torch.Tensor(torch.Storage(file))
|
||||
- numpy: np.frombuffer
|
||||
pytorch:
|
||||
- numpy: np.fromfunction
|
||||
pytorch:
|
||||
- numpy: np.fromiter
|
||||
pytorch:
|
||||
- numpy: np.fromstring
|
||||
pytorch:
|
||||
- numpy: np.loadtxt
|
||||
pytorch:
|
||||
- numpy: np.concatenate
|
||||
pytorch: torch.cat
|
||||
numerical ranges:
|
||||
- numpy: np.arange(10)
|
||||
pytorch: torch.range(0, 9)
|
||||
- numpy: np.arange(2, 3, 0.1)
|
||||
pytorch: torch.range(2, 2.9, 10)
|
||||
- numpy: np.linspace
|
||||
pytorch: torch.linspace
|
||||
- numpy: np.logspace
|
||||
pytorch: np.logspace
|
||||
building matrices:
|
||||
- numpy: np.diag
|
||||
pytorch: torch.diag
|
||||
- numpy: np.tril
|
||||
pytorch: torch.tril
|
||||
- numpy: np.triu
|
||||
pytorch: torch.triu
|
||||
attributes:
|
||||
- numpy: x.shape
|
||||
pytorch: x.size()
|
||||
- numpy: x.strides
|
||||
pytorch: x.stride()
|
||||
- numpy: x.ndim
|
||||
pytorch: x.dim()
|
||||
- numpy: x.data
|
||||
pytorch: x.data()
|
||||
- numpy: x.size
|
||||
pytorch: x.nelement()
|
||||
- numpy: x.dtype
|
||||
pytorch: x.type()
|
||||
indexing:
|
||||
- numpy: x[0]
|
||||
pytorch: x[0]
|
||||
- numpy: x[:, 0]
|
||||
pytorch: x[:, 0]
|
||||
- numpy: x[indices]
|
||||
pytorch: x[torch.LongTensor(indices)]
|
||||
- numpy: np.take(x, indices)
|
||||
pytorch: x[torch.LongTensor(indices)]
|
||||
- numpy: x[x != 0]
|
||||
pytorch: x[x != 0]
|
||||
shape manipulation:
|
||||
- numpy: x.reshape
|
||||
pytorch: x.view
|
||||
- numpy: x.resize
|
||||
pytorch: x.resize_
|
||||
- numpy:
|
||||
pytorch: x.resize_as_
|
||||
- numpy: x.transpose
|
||||
pytorch: x.permute
|
||||
- numpy: x.flatten()
|
||||
pytorch: x.view(-1)
|
||||
- numpy: x.squeeze
|
||||
pytorch: x.squeeze
|
||||
- numpy: x[:, np.newaxis]
|
||||
pytorch: x.unsqueeze(1)
|
||||
item selection and manipulation:
|
||||
- numpy: np.put
|
||||
pytorch:
|
||||
- numpy: x.repeat
|
||||
pytorch:
|
||||
- numpy: x.tile
|
||||
pytorch: x.repeat
|
||||
- numpy: np.choose
|
||||
pytorch:
|
||||
- numpy: np.sort
|
||||
pytorch: sorted, indices = torch.sort(x, [dim])
|
||||
- numpy: np.argsort
|
||||
pytorch: sorted, indices = torch.sort(x, [dim])
|
||||
- numpy: np.nonzero
|
||||
pytorch: torch.nonzero
|
||||
- numpy: np.where
|
||||
pytorch: torch.nonzero
|
||||
calculation:
|
||||
- numpy: x.min
|
||||
pytorch: mins, indices = torch.min(x, [dim])
|
||||
- numpy: x.argmin
|
||||
pytorch: mins, indices = torch.min(x, [dim])
|
||||
- numpy: x.max
|
||||
pytorch: maxs, indices = torch.max(x, [dim])
|
||||
- numpy: x.argmax
|
||||
pytorch: maxs, indices = torch.max(x, [dim])
|
||||
- numpy: x.clip
|
||||
pytorch:
|
||||
- numpy: x.round
|
||||
pytorch: y.round
|
||||
- numpy:
|
||||
pytorch: y.floor
|
||||
- numpy: x.trace
|
||||
pytorch: y.trace
|
||||
- numpy: x.sum
|
||||
pytorch: y.sum
|
||||
- numpy: x.cumsum
|
||||
pytorch: y.cumsum
|
||||
- numpy: x.mean
|
||||
pytorch: x.mean
|
||||
- numpy: x.std
|
||||
pytorch: x.std
|
||||
- numpy: x.prod
|
||||
pytorch: x.prod
|
||||
- numpy: x.cumprod
|
||||
pytorch: x.cumprod
|
||||
- numpy: x.all
|
||||
pytorch: (y == 1).sum() == y.nelement()
|
||||
- numpy: x.any
|
||||
pytorch: (y == 1).sum() > 0
|
||||
arithmetic and comparison operations:
|
||||
- numpy: x.lt
|
||||
pytorch: x.lt
|
||||
- numpy: x.le
|
||||
pytorch: x.le
|
||||
- numpy: x.gt
|
||||
pytorch: x.gt
|
||||
- numpy: x.ge
|
||||
pytorch: x.ge
|
||||
- numpy: x.eq
|
||||
pytorch: x.eq
|
||||
- numpy: x.ne
|
||||
pytorch: x.ne
|
||||
|
||||
Reference in New Issue
Block a user