mirror of
https://github.com/wassname/Pointnet2_PyTorch.git
synced 2026-06-27 16:00:07 +08:00
Updates
This commit is contained in:
+40
-20
@@ -18,51 +18,62 @@ import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Arg parser")
|
||||
parser.add_argument(
|
||||
"-batch_size", type=int, default=32, help="Batch size [default: 32]")
|
||||
"-batch_size", type=int, default=32, help="Batch size [default: 32]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-num_points",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of points to train with [default: 2048]")
|
||||
help="Number of points to train with [default: 2048]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-weight_decay",
|
||||
type=float,
|
||||
default=0,
|
||||
help="L2 regularization coeff [default: 0.0]")
|
||||
help="L2 regularization coeff [default: 0.0]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-lr",
|
||||
type=float,
|
||||
default=1e-2,
|
||||
help="Initial learning rate [default: 1e-2]")
|
||||
help="Initial learning rate [default: 1e-2]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-lr_decay",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Learning rate decay gamma [default: 0.5]")
|
||||
help="Learning rate decay gamma [default: 0.5]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-decay_step",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Learning rate decay step [default: 20]")
|
||||
help="Learning rate decay step [default: 20]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-bn_momentum",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Initial batch norm momentum [default: 0.9]")
|
||||
help="Initial batch norm momentum [default: 0.9]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-bn_decay",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Batch norm momentum decay gamma [default: 0.5]")
|
||||
help="Batch norm momentum decay gamma [default: 0.5]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-checkpoint", type=str, default=None, help="Checkpoint to start from")
|
||||
"-checkpoint", type=str, default=None, help="Checkpoint to start from"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-epochs", type=int, default=200, help="Number of epochs to train for")
|
||||
"-epochs", type=int, default=200, help="Number of epochs to train for"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-run_name",
|
||||
type=str,
|
||||
default="sem_seg_run_1",
|
||||
help="Name for run in tensorboard_logger")
|
||||
help="Name for run in tensorboard_logger"
|
||||
)
|
||||
|
||||
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
|
||||
|
||||
@@ -74,13 +85,15 @@ if __name__ == "__main__":
|
||||
tb_log.configure('runs/{}'.format(args.run_name))
|
||||
|
||||
test_set = Indoor3DSemSeg(
|
||||
args.num_points, BASE_DIR, train=False, data_precent=0.01)
|
||||
args.num_points, BASE_DIR, train=False, data_precent=0.01
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_set,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
num_workers=2)
|
||||
num_workers=2
|
||||
)
|
||||
|
||||
train_set = Indoor3DSemSeg(args.num_points, BASE_DIR, data_precent=1.0)
|
||||
train_loader = DataLoader(
|
||||
@@ -88,12 +101,14 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
num_workers=2,
|
||||
shuffle=True)
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
model = Pointnet(num_classes=13)
|
||||
model.cuda()
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), lr_clip / args.lr)
|
||||
bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), bnm_clip)
|
||||
@@ -106,12 +121,15 @@ if __name__ == "__main__":
|
||||
best_loss = 1e10
|
||||
else:
|
||||
start_epoch, best_loss = pt_utils.load_checkpoint(
|
||||
model, optimizer, filename=args.checkpoint.split(".")[0])
|
||||
model, optimizer, filename=args.checkpoint.split(".")[0]
|
||||
)
|
||||
|
||||
lr_scheduler = lr_sched.LambdaLR(
|
||||
optimizer, lr_lbmd, last_epoch=start_epoch)
|
||||
optimizer, lr_lbmd, last_epoch=start_epoch
|
||||
)
|
||||
bnm_scheduler = pt_utils.BNMomentumScheduler(
|
||||
model, bnm_lmbd, last_epoch=start_epoch)
|
||||
model, bnm_lmbd, last_epoch=start_epoch
|
||||
)
|
||||
|
||||
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
|
||||
|
||||
@@ -123,14 +141,16 @@ if __name__ == "__main__":
|
||||
best_name="sem_seg_best",
|
||||
lr_scheduler=lr_scheduler,
|
||||
bnm_scheduler=bnm_scheduler,
|
||||
eval_frequency=10)
|
||||
eval_frequency=10
|
||||
)
|
||||
|
||||
trainer.train(
|
||||
start_epoch,
|
||||
args.epochs,
|
||||
train_loader,
|
||||
test_loader,
|
||||
best_loss=best_loss)
|
||||
best_loss=best_loss
|
||||
)
|
||||
|
||||
if start_epoch == args.epochs:
|
||||
test_loader.dataset.data_precent = 1.0
|
||||
|
||||
Reference in New Issue
Block a user