Files
erikwijmans 5a5adc2b77 Updates
2018-01-06 12:13:52 -05:00

79 lines
2.1 KiB
Python

import torch
from enum import Enum
PDist2Order = Enum('PDist2Order', 'd_first d_second')
def pdist2(
X: torch.Tensor,
Z: torch.Tensor = None,
order: PDist2Order = PDist2Order.d_second
) -> torch.Tensor:
r""" Calculates the pairwise distance between X and Z
D[b, i, j] = l2 distance X[b, i] and Z[b, j]
Parameters
---------
X : torch.Tensor
X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d
Z: torch.Tensor
Z is a (B, M, d) tensor. If Z is None, then Z = X
Returns
-------
torch.Tensor
Distance matrix is size (B, N, M)
"""
if order == PDist2Order.d_second:
if X.dim() == 2:
X = X.unsqueeze(0)
if Z is None:
Z = X
G = X @ Z.transpose(-2, -1)
S = (X * X).sum(-1, keepdim=True)
R = S.transpose(-2, -1)
else:
if Z.dim() == 2:
Z = Z.unsqueeze(0)
G = X @ Z.transpose(-2, -1)
S = (X * X).sum(-1, keepdim=True)
R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1)
else:
if X.dim() == 2:
X = X.unsqueeze(0)
if Z is None:
Z = X
G = X.transpose(-2, -1) @ Z
R = (X * X).sum(-2, keepdim=True)
S = R.transpose(-2, -1)
else:
if Z.dim() == 2:
Z = Z.unsqueeze(0)
G = X.transpose(-2, -1) @ Z
S = (X * X).sum(-2, keepdim=True).transpose(-2, -1)
R = (Z * Z).sum(-2, keepdim=True)
return torch.abs(R + S - 2 * G).squeeze(0)
def pdist2_slow(X, Z=None):
if Z is None: Z = X
D = torch.zeros(X.size(0), X.size(2), Z.size(2))
for b in range(D.size(0)):
for i in range(D.size(1)):
for j in range(D.size(2)):
D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j])
return D
if __name__ == "__main__":
X = torch.randn(2, 3, 5)
Z = torch.randn(2, 3, 3)
print(pdist2(X, order=PDist2Order.d_first))
print(pdist2_slow(X))
print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X)))