mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Updates and some refactoring
This commit is contained in:
+17
-16
@@ -9,49 +9,51 @@ from torchvision import transforms
|
||||
import os
|
||||
import tensorboard_logger as tb_log
|
||||
|
||||
from models import PointnetCls as Pointnet
|
||||
from models.PointnetCls import model_fn_decorator
|
||||
from models import Pointnet2ClsMSG as Pointnet
|
||||
from models.Pointnet2Cls import model_fn_decorator
|
||||
from data import ModelNet40Cls
|
||||
import utils.pytorch_utils as pt_utils
|
||||
import utils.data_utils as d_utils
|
||||
import argparse
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Arg parser")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Arguments for cls training",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument(
|
||||
"-batch_size", type=int, default=128, help="Batch size [default: 128]")
|
||||
"-batch_size", type=int, default=16, help="Batch size")
|
||||
parser.add_argument(
|
||||
"-num_points",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Number of points to train with [default: 1024]")
|
||||
help="Number of points to train with")
|
||||
parser.add_argument(
|
||||
"-weight_decay", type=float, default=1e-5, help="L2 regularization coeff")
|
||||
parser.add_argument(
|
||||
"-lr",
|
||||
type=float,
|
||||
default=1e-2,
|
||||
help="Initial learning rate [default: 1e-2]")
|
||||
help="Initial learning rate")
|
||||
parser.add_argument(
|
||||
"-lr_decay",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Learning rate decay gamma [default: 0.7]")
|
||||
help="Learning rate decay gamma")
|
||||
parser.add_argument(
|
||||
"-decay_step",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Learning rate decay step [default: 20]")
|
||||
help="Learning rate decay step")
|
||||
parser.add_argument(
|
||||
"-bn_momentum",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Initial batch norm momentum [default: 0.5]")
|
||||
help="Initial batch norm momentum")
|
||||
parser.add_argument(
|
||||
"-bnm_decay",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Batch norm momentum decay gamma [default: 0.5]")
|
||||
help="Batch norm momentum decay gamma")
|
||||
parser.add_argument(
|
||||
"-checkpoint", type=str, default=None, help="Checkpoint to start from")
|
||||
parser.add_argument(
|
||||
@@ -74,8 +76,7 @@ if __name__ == "__main__":
|
||||
|
||||
transforms = transforms.Compose([
|
||||
d_utils.PointcloudToTensor(),
|
||||
d_utils.PointcloudRotate(x_axis=True),
|
||||
d_utils.PointcloudScale(),
|
||||
d_utils.PointcloudRotate(x_axis=True, z_axis=True),
|
||||
d_utils.PointcloudTranslate(),
|
||||
d_utils.PointcloudJitter()
|
||||
])
|
||||
@@ -99,7 +100,7 @@ if __name__ == "__main__":
|
||||
|
||||
tb_log.configure('runs/{}'.format(args.run_name))
|
||||
|
||||
model = Pointnet()
|
||||
model = Pointnet(input_channels=3, num_classes=40)
|
||||
model.cuda()
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
@@ -107,7 +108,7 @@ if __name__ == "__main__":
|
||||
bn_lbmd = lambda e: max(args.bn_momentum * args.bnm_decay**(e // args.decay_step), bnm_clip)
|
||||
|
||||
if args.checkpoint is not None:
|
||||
start_epoch, best_prec = pt_utils.load_checkpoint(
|
||||
start_epoch, best_loss = pt_utils.load_checkpoint(
|
||||
model, optimizer, filename=args.checkpoint.split(".")[0])
|
||||
|
||||
lr_scheduler = lr_sched.LambdaLR(
|
||||
@@ -118,7 +119,7 @@ if __name__ == "__main__":
|
||||
lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
|
||||
bnm_scheduler = pt_utils.BNMomentumScheduler(model, bn_lambda=bn_lbmd)
|
||||
|
||||
best_prec = 0.0
|
||||
best_loss = 1e10
|
||||
start_epoch = 1
|
||||
|
||||
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
|
||||
@@ -137,7 +138,7 @@ if __name__ == "__main__":
|
||||
args.epochs,
|
||||
train_loader,
|
||||
test_loader,
|
||||
best_prec=best_prec)
|
||||
best_loss=best_loss)
|
||||
|
||||
if start_epoch == args.epochs:
|
||||
_ = trainer.eval_epoch(start_epoch, test_loader)
|
||||
|
||||
Reference in New Issue
Block a user