Files
Pointnet2_PyTorch/train_cls.py
T
erikwijmans 5a5adc2b77 Updates
2018-01-06 12:13:52 -05:00

160 lines
4.4 KiB
Python

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms
import os
import tensorboard_logger as tb_log
from models import Pointnet2ClsMSG as Pointnet
from models.Pointnet2Cls import model_fn_decorator
from data import ModelNet40Cls
import utils.pytorch_utils as pt_utils
import utils.data_utils as d_utils
import argparse
def parse_args(argv=None):
    """Parse the command-line options for classification training.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected.

    Returns:
        argparse.Namespace with the attributes ``batch_size``, ``num_points``,
        ``weight_decay``, ``lr``, ``lr_decay``, ``decay_step``,
        ``bn_momentum``, ``bnm_decay``, ``checkpoint``, ``epochs`` and
        ``run_name``.
    """
    parser = argparse.ArgumentParser(
        description="Arguments for cls training",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Data / batching options.
    parser.add_argument("-batch_size", type=int, default=16, help="Batch size")
    parser.add_argument(
        "-num_points",
        type=int,
        default=1024,
        help="Number of points to train with"
    )

    # Optimizer options and learning-rate decay schedule.
    parser.add_argument(
        "-weight_decay",
        type=float,
        default=1e-5,
        help="L2 regularization coeff"
    )
    parser.add_argument(
        "-lr", type=float, default=1e-2, help="Initial learning rate"
    )
    parser.add_argument(
        "-lr_decay", type=float, default=0.7, help="Learning rate decay gamma"
    )
    parser.add_argument(
        "-decay_step", type=int, default=20, help="Learning rate decay step"
    )

    # Batch-norm momentum decay schedule.
    parser.add_argument(
        "-bn_momentum",
        type=float,
        default=0.5,
        help="Initial batch norm momentum"
    )
    parser.add_argument(
        "-bnm_decay",
        type=float,
        default=0.5,
        help="Batch norm momentum decay gamma"
    )

    # Run control.
    parser.add_argument(
        "-checkpoint", type=str, default=None, help="Checkpoint to start from"
    )
    parser.add_argument(
        "-epochs", type=int, default=200, help="Number of epochs to train for"
    )
    parser.add_argument(
        "-run_name",
        type=str,
        default="cls_run_1",
        help="Name for run in tensorboard_logger"
    )

    return parser.parse_args(argv)
# Floors used by the decay schedules in the training script below: the
# learning-rate multiplier never drops below lr_clip / args.lr, and the
# batch-norm momentum value never drops below bnm_clip.
lr_clip = 1e-5
bnm_clip = 1e-2


def checkpoint_stem(path):
    """Return *path* with the extension(s) of its file name removed.

    Only the final path component is truncated at its first dot, so dots in
    directory names (``./ckpt/...``, ``../runs/...``, ``runs.v2/...``) are
    preserved. A file name without any dot is returned unchanged.

    Args:
        path: Checkpoint path as given on the command line.

    Returns:
        The path with everything from the first dot of the file name onward
        dropped, e.g. ``"ckpt/cls_best.pth.tar"`` -> ``"ckpt/cls_best"``.
    """
    directory, filename = os.path.split(path)
    return os.path.join(directory, filename.split(".")[0])


if __name__ == "__main__":
    args = parse_args()

    # Dataset root: a "data" directory next to this script.
    BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')

    # Bound to its own name so the imported `transforms` module is not
    # shadowed by the composed pipeline.
    cloud_transforms = transforms.Compose([
        d_utils.PointcloudToTensor(),
        d_utils.PointcloudRotate(x_axis=True, z_axis=True),
        d_utils.PointcloudTranslate(),
        d_utils.PointcloudJitter()
    ])

    # NOTE(review): the evaluation split reuses the training transform
    # pipeline and is shuffled -- confirm this is intended.
    test_set = ModelNet40Cls(
        args.num_points, BASE_DIR, transforms=cloud_transforms, train=False
    )
    test_loader = DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    train_set = ModelNet40Cls(
        args.num_points, BASE_DIR, transforms=cloud_transforms
    )
    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    tb_log.configure('runs/{}'.format(args.run_name))

    model = Pointnet(input_channels=3, num_classes=40)
    model.cuda()
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    def lr_lbmd(e):
        """Learning-rate multiplier for epoch *e* (step decay with a floor)."""
        return max(args.lr_decay**(e // args.decay_step), lr_clip / args.lr)

    def bn_lbmd(e):
        """Batch-norm momentum value for epoch *e* (step decay with a floor)."""
        return max(
            args.bn_momentum * args.bnm_decay**(e // args.decay_step),
            bnm_clip
        )

    if args.checkpoint is not None:
        # The checkpoint is passed without its extension(s); presumably
        # load_checkpoint appends its own suffix -- confirm in pt_utils.
        start_epoch, best_loss = pt_utils.load_checkpoint(
            model, optimizer, filename=checkpoint_stem(args.checkpoint)
        )

        # Resume both schedules from the epoch stored in the checkpoint.
        lr_scheduler = lr_sched.LambdaLR(
            optimizer, lr_lambda=lr_lbmd, last_epoch=start_epoch
        )
        bnm_scheduler = pt_utils.BNMomentumScheduler(
            model, bn_lambda=bn_lbmd, last_epoch=start_epoch
        )
    else:
        lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
        bnm_scheduler = pt_utils.BNMomentumScheduler(model, bn_lambda=bn_lbmd)

        # Fresh run: large sentinel so the first evaluated loss counts as best.
        best_loss = 1e10
        start_epoch = 1

    model_fn = model_fn_decorator(nn.CrossEntropyLoss())

    trainer = pt_utils.Trainer(
        model,
        model_fn,
        optimizer,
        checkpoint_name="cls_checkpoint",
        best_name="cls_best",
        lr_scheduler=lr_scheduler,
        bnm_scheduler=bnm_scheduler
    )

    trainer.train(
        start_epoch,
        args.epochs,
        train_loader,
        test_loader,
        best_loss=best_loss
    )

    # When the run starts at the final epoch, run one evaluation pass
    # explicitly on the test loader.
    if start_epoch == args.epochs:
        _ = trainer.eval_epoch(start_epoch, test_loader)