diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000..eb47c33
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,173 @@
+import torch
+import numpy as np
+
+
+def angle_axis(angle: float, axis: np.ndarray):
+    r"""Returns a 3x3 rotation matrix that performs a rotation around axis by angle
+
+    The matrix is built with Rodrigues' rotation formula and cast to float32
+    so it can be multiplied with float32 point tensors.
+
+    Parameters
+    ----------
+    angle : float
+        Angle to rotate by, in radians
+    axis: np.ndarray
+        Axis to rotate about (normalized internally, must be non-zero)
+
+    Returns
+    -------
+    torch.Tensor
+        3x3 rotation matrix (float32)
+
+    Raises
+    ------
+    ValueError
+        If axis has zero length
+    """
+    norm = np.linalg.norm(axis)
+    if norm == 0:
+        raise ValueError("axis must be a non-zero vector")
+    u = axis / norm
+    cosval, sinval = np.cos(angle), np.sin(angle)
+
+    # yapf: disable
+    cross_prod_mat = np.array([[0.0, -u[2], u[1]],
+                               [u[2], 0.0, -u[0]],
+                               [-u[1], u[0], 0.0]])
+
+    R = torch.from_numpy(
+        cosval * np.eye(3)
+        + sinval * cross_prod_mat
+        + (1.0 - cosval) * np.outer(u, u)
+    )
+    # yapf: enable
+    # NumPy builds the matrix in float64; cast so that ``points @ R.t()``
+    # does not fail on a float32/float64 dtype mismatch.
+    return R.float()
+
+
+class PointcloudScale(object):
+    """Scale the xyz columns by one factor drawn uniformly from [lo, hi)."""
+
+    def __init__(self, lo=0.8, hi=1.25):
+        self.lo, self.hi = lo, hi
+
+    def __call__(self, points):
+        scaler = np.random.uniform(self.lo, self.hi)
+        points[:, 0:3] *= scaler  # in place; extra columns are untouched
+        return points
+
+
+class PointcloudRotate(object):
+    """Rotate the cloud by a random angle in [0, 2*pi) about a fixed axis."""
+
+    def __init__(self, axis=np.array([0.0, 1.0, 0.0])):
+        self.axis = axis
+
+    def __call__(self, points):
+        rotation_angle = np.random.uniform() * 2 * np.pi
+        rotation_matrix = angle_axis(rotation_angle, self.axis)
+
+        normals = points.size(1) > 3
+        if not normals:
+            return points @ rotation_matrix.t()
+        else:
+            # Columns 3: are treated as normals (assumes exactly three of them).
+            pc_xyz = points[:, 0:3]
+            pc_normals = points[:, 3:]
+            points[:, 0:3] = pc_xyz @ rotation_matrix.t()
+            points[:, 3:] = pc_normals @ rotation_matrix.t()
+
+            return points
+
+
+class PointcloudRotatePerturbation(object):
+    """Apply a small random rotation about each of the x, y and z axes."""
+
+    def __init__(self, angle_sigma=0.06, angle_clip=0.18):
+        self.angle_sigma, self.angle_clip = angle_sigma, angle_clip
+
+    def _get_angles(self):
+        # Three Gaussian angles (radians), clipped to +/- angle_clip.
+        angles = np.clip(
+            self.angle_sigma * np.random.randn(3), -self.angle_clip,
+            self.angle_clip
+        )
+
+        return angles
+
+    def __call__(self, points):
+        angles = self._get_angles()
+        Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0]))
+        Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0]))
+        Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0]))
+
+        rotation_matrix = Rz @ Ry @ Rx
+
+        normals = points.size(1) > 3
+        if not normals:
+            return points @ rotation_matrix.t()
+        else:
+            pc_xyz = points[:, 0:3]
+            pc_normals = points[:, 3:]
+            points[:, 0:3] = pc_xyz @ rotation_matrix.t()
+            points[:, 3:] = pc_normals @ rotation_matrix.t()
+
+            return points
+
+
+class PointcloudJitter(object):
+    """Add clipped Gaussian noise to the xyz columns, in place."""
+
+    def __init__(self, std=0.01, clip=0.05):
+        self.std, self.clip = std, clip
+
+    def __call__(self, points):
+        jittered_data = points.new(points.size(0), 3).normal_(
+            mean=0.0, std=self.std
+        ).clamp_(-self.clip, self.clip)
+        points[:, 0:3] += jittered_data
+        return points
+
+
+class PointcloudTranslate(object):
+    """Shift the xyz columns by a uniform random offset, in place."""
+
+    def __init__(self, translate_range=0.1):
+        self.translate_range = translate_range
+
+    def __call__(self, points):
+        # NOTE(review): one scalar is drawn, so x, y and z all move by the
+        # same amount -- confirm a per-axis offset was not intended.
+        translation = np.random.uniform(
+            -self.translate_range, self.translate_range
+        )
+        points[:, 0:3] += translation
+        return points
+
+
+class PointcloudToTensor(object):
+    """Convert a NumPy array to a float32 tensor."""
+
+    def __call__(self, points):
+        return torch.from_numpy(points).float()
+
+
+class PointcloudRandomInputDropout(object):
+    """Overwrite a random subset of points with a copy of the first point."""
+
+    def __init__(self, max_dropout_ratio=0.875):
+        assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
+        self.max_dropout_ratio = max_dropout_ratio
+
+    def __call__(self, points):
+        pc = points.numpy()
+
+        # Per-call drop probability, uniform in [0, max_dropout_ratio).
+        dropout_ratio = np.random.random() * self.max_dropout_ratio
+        drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0]
+        if len(drop_idx) > 0:
+            pc[drop_idx] = pc[0]  # set to the first point
+
+        return torch.from_numpy(pc).float()
diff --git a/train_cls.py b/train_cls.py
index 74cd832..b0594c5 100644
--- a/train_cls.py
+++ b/train_cls.py
@@ -13,7 +13,7 @@ from models import Pointnet2ClsMSG as Pointnet
 from models.Pointnet2Cls import model_fn_decorator
 from data import ModelNet40Cls
 import utils.pytorch_utils as pt_utils
-import utils.data_utils as d_utils
+import data.data_utils as d_utils
 import argparse
 
 torch.backends.cudnn.enabled = True
@@ -85,9 +85,12 @@ if __name__ == "__main__":
 
     transforms = transforms.Compose([
         d_utils.PointcloudToTensor(),
-        d_utils.PointcloudRotate(x_axis=True, z_axis=True),
+        d_utils.PointcloudScale(),
+        d_utils.PointcloudRotate(),
+        d_utils.PointcloudRotatePerturbation(),
         d_utils.PointcloudTranslate(),
-        d_utils.PointcloudJitter()
+        d_utils.PointcloudJitter(),
+        d_utils.PointcloudRandomInputDropout()
     ])
 
     test_set = ModelNet40Cls(
diff --git a/utils/data_utils.py b/utils/data_utils.py
deleted file mode 100644
index 71c368d..0000000
--- a/utils/data_utils.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import torch
-import numpy as np
-
-
-class PointcloudScale(object):
-
-    def __init__(self, mean=2.0, std=1.0, clip=1.8):
-        self.mean, self.std, self.clip = mean, std, clip
-
-    def __call__(self, points):
-        scaler = points.new(1).normal_(
-            mean=self.mean, std=self.std
-        ).clamp_(max(self.mean - self.clip, 0.01), self.mean + self.clip)
-        return scaler * points
-
-
-class PointcloudRotate(object):
-
-    def __init__(self, x_axis=False, z_axis=True):
-        assert x_axis or z_axis
-        self.x, self.z = x_axis, z_axis
-
-    def _get_angles(self):
-        rotation_angle = np.random.uniform() * 2 * np.pi
-        cosval = np.cos(rotation_angle)
-        sinval = np.sin(rotation_angle)
-
-        return cosval, sinval
-
-    def __call__(self, points):
-        if self.z:
-            sinval, cosval = self._get_angles()
-            Rz = points.new([[cosval, sinval, 0], [-sinval, cosval, 0],
-                             [0, 0, 1]])
-        else:
-            Rz = torch.eye(3)
-
-        if self.x:
-            sinval, cosval = self._get_angles()
-            Rx = points.new([[1, 0, 0], [0, cosval, sinval],
-                             [0, -sinval, cosval]])
-        else:
-            Rx = torch.eye(3)
-
-        rot_mat = Rx @ Rz
-
-        return points @ rot_mat
-
-
-class PointcloudJitter(object):
-
-    def __init__(self, std=0.01, clip=0.03):
-        self.std, self.clip = std, clip
-
-    def __call__(self, points):
-        jittered_data = points.new(*points.size()).normal_(
-            mean=0.0, std=self.std
-        ).clamp_(-self.clip, self.clip)
-        return points + jittered_data
-
-
-class PointcloudTranslate(object):
-
-    def __init__(self, std=1.0, clip=3.0):
-        self.std, self.clip = std, clip
-
-    def __call__(self, points):
-        translation = points.new(3).normal_(
-            mean=0.0, std=self.std
-        ).clamp_(-self.clip, self.clip)
-        return points + translation
-
-
-class PointcloudToTensor(object):
-
-    def __call__(self, points):
-        return torch.from_numpy(points).float()