Kentaro Wada 80169c48ae Fix typo
2017-05-21 08:09:10 +09:00

PyTorch for Numpy users.

PyTorch version of Torch for Numpy users.

Types

| Numpy | PyTorch |
|-------|---------|
| `np.ndarray` | `torch.Tensor` |
| `np.float32` | `torch.FloatTensor` |
| `np.float64` | `torch.DoubleTensor` |
| `np.int8` | `torch.CharTensor` |
| `np.uint8` | `torch.ByteTensor` |
| `np.int16` | `torch.ShortTensor` |
| `np.int32` | `torch.IntTensor` |
| `np.int64` | `torch.LongTensor` |
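
The type mapping can be checked by converting a Numpy array of each dtype and asking the tensor for its type. This is a minimal sketch, assuming a CPU build of PyTorch where `torch.from_numpy` and `Tensor.type()` are available; `torch_type_of` is a hypothetical helper name, not part of either library.

```python
import numpy as np
import torch


def torch_type_of(np_dtype):
    """Hypothetical helper: name of the tensor type that a Numpy dtype maps to."""
    # from_numpy keeps the dtype of the array; type() returns the type as a string.
    return torch.from_numpy(np.zeros(1, dtype=np_dtype)).type()


for np_dtype in (np.float32, np.float64, np.int8, np.uint8,
                 np.int16, np.int32, np.int64):
    print(np.dtype(np_dtype).name, "->", torch_type_of(np_dtype))
```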

Constructors

Ones and zeros

| Numpy | PyTorch |
|-------|---------|
| `np.empty((2, 3))` | `torch.Tensor(2, 3)` |
| `np.empty_like(x)` | `x.new(x.size()).type(x.type())` |
| `np.eye` | `torch.eye` |
| `np.identity` | `torch.eye` |
| `np.ones` | `torch.ones` |
| `np.ones_like` | `torch.ones(x.size()).type(x.type())` |
| `np.zeros` | `torch.zeros` |
| `np.zeros_like` | `torch.zeros(x.size()).type(x.type())` |
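
A short sketch of the `*_like` pattern from the table, assuming `x` is an existing CPU tensor. The size is taken from `x.size()` and the element type is copied with `.type(x.type())`.

```python
import numpy as np
import torch

# A double-precision tensor standing in for an existing array.
x = torch.zeros(2, 3).double()

# Same shape and same element type as x, filled with ones / zeros.
ones_like = torch.ones(x.size()).type(x.type())
zeros_like = torch.zeros(x.size()).type(x.type())

# torch.eye covers both np.eye and np.identity.
identity = torch.eye(3)
```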

From existing data

| Numpy | PyTorch |
|-------|---------|
| `np.array([[1, 2], [3, 4]])` | `torch.Tensor([[1, 2], [3, 4]])` |
| `x.copy()` | `x.clone()` |
| `np.fromfile(file)` | `torch.Tensor(torch.Storage(file))` |
| `np.frombuffer` | |
| `np.fromfunction` | |
| `np.fromiter` | |
| `np.fromstring` | |
| `np.loadtxt` | |
| `np.concatenate` | `torch.cat` |
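
The following sketch illustrates three rows of the table: building a tensor from nested lists, copying it, and concatenating. It assumes the default behaviour that `torch.Tensor(...)` produces a float tensor.

```python
import numpy as np
import torch

a = np.array([[1, 2], [3, 4]], dtype=np.float32)
x = torch.Tensor([[1, 2], [3, 4]])

# clone() makes an independent copy, like ndarray.copy().
y = x.clone()
y[0, 0] = 100  # does not touch x

# Concatenate along the first dimension (axis=0 in Numpy terms).
stacked = torch.cat([x, x], 0)
```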

Numerical ranges

| Numpy | PyTorch |
|-------|---------|
| `np.arange(10)` | `torch.arange(10)` |
| `np.arange(2, 3, 0.1)` | `torch.arange(2, 3, 0.1)` |
| `np.linspace` | `torch.linspace` |
| `np.logspace` | `torch.logspace` |
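
A minimal sketch comparing the range constructors, assuming a PyTorch version that provides `torch.arange`. Like `np.arange`, it excludes the end point, whereas `torch.range` includes it.

```python
import numpy as np
import torch

# Integer range: the end point (10) is excluded in both libraries.
np_ints = np.arange(10)
pt_ints = torch.arange(10)

# Five evenly spaced points from 0 to 1, end point included.
np_lin = np.linspace(0, 1, 5)
pt_lin = torch.linspace(0, 1, 5)

# Three points from 10**0 to 10**2 on a log scale.
np_log = np.logspace(0, 2, 3)
pt_log = torch.logspace(0, 2, 3)
```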

Building matrices

| Numpy | PyTorch |
|-------|---------|
| `np.diag` | `torch.diag` |
| `np.tril` | `torch.tril` |
| `np.triu` | `torch.triu` |
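
For a 2-D input the three functions behave the same way in both libraries, as this small sketch suggests (the 3x3 input is an arbitrary example):

```python
import numpy as np
import torch

a = np.arange(1, 10, dtype=np.float32).reshape(3, 3)
x = torch.from_numpy(a)

diagonal = torch.diag(x)  # main diagonal of a 2-D tensor
lower = torch.tril(x)     # zero out everything above the diagonal
upper = torch.triu(x)     # zero out everything below the diagonal
```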

Attributes

| Numpy | PyTorch |
|-------|---------|
| `x.shape` | `x.size()` |
| `x.strides` | `x.stride()` |
| `x.ndim` | `x.dim()` |
| `x.data` | `x.storage()` |
| `x.size` | `x.nelement()` |
| `x.dtype` | `x.type()` |
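
One difference worth keeping in mind: Numpy reports strides in bytes, while PyTorch reports them in elements. A small sketch with a 2x3 `float32` array (4 bytes per element):

```python
import numpy as np
import torch

a = np.zeros((2, 3), dtype=np.float32)
x = torch.from_numpy(a)

shape = tuple(x.size())  # (2, 3), same as a.shape
ndim = x.dim()           # 2, same as a.ndim
count = x.nelement()     # 6, same as a.size

# a.strides is (12, 4) in bytes; x.stride() is (3, 1) in elements.
byte_strides = tuple(s * a.itemsize for s in x.stride())
```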

Indexing

| Numpy | PyTorch |
|-------|---------|
| `x[0]` | `x[0]` |
| `x[:, 0]` | `x[:, 0]` |
| `x[indices]` | `x[torch.LongTensor(indices)]` |
| `np.take(x, indices)` | `torch.take(x, torch.LongTensor(indices))` |
| `x[x != 0]` | `x[x != 0]` |
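
A sketch of index-array and mask indexing on a 3x4 example. Note that `np.take` without an `axis` argument indexes the flattened array, which is what `torch.take` does as well, so it differs from `x[indices]` for inputs with more than one dimension.

```python
import numpy as np
import torch

a = np.arange(12).reshape(3, 4)
x = torch.from_numpy(a)
indices = [0, 2]

first_column = x[:, 0]                             # same slicing syntax
rows = x[torch.LongTensor(indices)]                # rows 0 and 2
flat = torch.take(x, torch.LongTensor(indices))    # elements 0 and 2 of the flattened tensor
nonzero_values = x[x != 0]                         # boolean mask selection
```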

Shape manipulation

| Numpy | PyTorch |
|-------|---------|
| `x.reshape` | `x.view` |
| `x.resize` | `x.resize_` |
| | `x.resize_as_` |
| `x.transpose` | `x.permute` |
| `x.flatten()` | `x.view(-1)` |
| `x.squeeze` | `x.squeeze` |
| `x[:, np.newaxis]` | `x.unsqueeze(1)` |
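
The sketch below walks through the shape operations on a contiguous 2x3 tensor. `view` only works when the memory layout allows it, so the example applies it to the contiguous original rather than to the permuted result.

```python
import numpy as np
import torch

a = np.arange(6, dtype=np.float32).reshape(2, 3)
x = torch.from_numpy(a)

reshaped = x.view(3, 2)        # a.reshape(3, 2)
transposed = x.permute(1, 0)   # a.transpose(1, 0)
flat = x.view(-1)              # a.flatten()
expanded = x.unsqueeze(1)      # a[:, np.newaxis], shape (2, 1, 3)
squeezed = expanded.squeeze(1) # back to shape (2, 3)
```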

Item selection and manipulation

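| Numpy | PyTorch |
|-------|---------|
| `np.put` | |
| `x.repeat` | |
| `np.tile` | `x.repeat` |
| `np.choose` | |
| `np.sort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.argsort` | `sorted, indices = torch.sort(x, [dim])` |
| `np.nonzero` | `torch.nonzero` |
| `np.where` | `torch.nonzero` |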
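
A brief sketch of the rows above on 1-D inputs. `torch.sort` returns values and indices together, so it covers both `np.sort` and `np.argsort`. `torch.nonzero` returns one row per nonzero element rather than a tuple of index arrays, so the example takes its first column to compare with Numpy.

```python
import numpy as np
import torch

a = np.array([3, 1, 2])
x = torch.from_numpy(a)

# One call gives both the sorted values and the argsort indices.
sorted_values, sort_indices = torch.sort(x)

# Tensor.repeat tiles the whole tensor, like np.tile (not like np.repeat).
tiled = x.repeat(2)

b = np.array([0, 4, 0, 5])
y = torch.from_numpy(b)
nonzero_positions = torch.nonzero(y)[:, 0]
```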

Calculation

| Numpy | PyTorch |
|-------|---------|
| `x.min` | `mins, indices = torch.min(x, [dim])` |
| `x.argmin` | `mins, indices = torch.min(x, [dim])` |
| `x.max` | `maxs, indices = torch.max(x, [dim])` |
| `x.argmax` | `maxs, indices = torch.max(x, [dim])` |
| `x.clip` | `x.clamp` |
| `x.round` | `x.round` |
| `np.floor(x)` | `x.floor()` |
| `np.ceil(x)` | `x.ceil()` |
| `x.trace` | `x.trace` |
| `x.sum` | `x.sum` |
| `x.cumsum` | `x.cumsum` |
| `x.mean` | `x.mean` |
| `x.std` | `x.std` |
| `x.prod` | `x.prod` |
| `x.cumprod` | `x.cumprod` |
| `x.all` | `(x != 0).sum() == x.nelement()` |
| `x.any` | `(x != 0).sum() > 0` |
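
A sketch of the reductions on a 2x3 example. When a dimension is given, `torch.max` returns the values and their positions in one call, which covers both `x.max` and `x.argmax`. One caveat: `x.std()` in PyTorch uses the unbiased estimator by default, which corresponds to `ddof=1` in Numpy rather than Numpy's default of `ddof=0`.

```python
import numpy as np
import torch

a = np.array([[1., 5., 3.], [4., 2., 6.]], dtype=np.float32)
x = torch.from_numpy(a)

# Row-wise maximum and the column index where it occurs.
maxs, max_indices = torch.max(x, 1)

total = x.sum().item()
clipped = x.clamp(2, 5)       # a.clip(2, 5)
running = x.cumsum(1)         # a.cumsum(axis=1)
spread = x.std().item()       # matches a.std(ddof=1)
```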

Arithmetic and comparison operations

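| Numpy | PyTorch |
|-------|---------|
| `x.__lt__` | `x.lt` |
| `x.__le__` | `x.le` |
| `x.__gt__` | `x.gt` |
| `x.__ge__` | `x.ge` |
| `x.__eq__` | `x.eq` |
| `x.__ne__` | `x.ne` |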
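
The comparison methods work element-wise, like the corresponding Numpy operators. In this sketch the results are converted with `bool(...)` because the element type of the mask depends on the PyTorch version (an assumption made so the comparison with Numpy holds either way).

```python
import numpy as np
import torch

a = np.array([1, 2, 3])
b = np.array([3, 2, 1])
x = torch.from_numpy(a)
y = torch.from_numpy(b)

less = [bool(v) for v in x.lt(y).tolist()]       # a < b
equal = [bool(v) for v in x.eq(y).tolist()]      # a == b
not_equal = [bool(v) for v in x.ne(y).tolist()]  # a != b
```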