This commit is contained in:
erikwijmans
2018-01-06 12:13:52 -05:00
parent 7e746ba72a
commit 5a5adc2b77
20 changed files with 650 additions and 494 deletions
+40 -20
View File
@@ -18,51 +18,62 @@ import argparse
parser = argparse.ArgumentParser(description="Arg parser")
parser.add_argument(
"-batch_size", type=int, default=32, help="Batch size [default: 32]")
"-batch_size", type=int, default=32, help="Batch size [default: 32]"
)
parser.add_argument(
"-num_points",
type=int,
default=2048,
help="Number of points to train with [default: 2048]")
help="Number of points to train with [default: 2048]"
)
parser.add_argument(
"-weight_decay",
type=float,
default=0,
help="L2 regularization coeff [default: 0.0]")
help="L2 regularization coeff [default: 0.0]"
)
parser.add_argument(
"-lr",
type=float,
default=1e-2,
help="Initial learning rate [default: 1e-2]")
help="Initial learning rate [default: 1e-2]"
)
parser.add_argument(
"-lr_decay",
type=float,
default=0.5,
help="Learning rate decay gamma [default: 0.5]")
help="Learning rate decay gamma [default: 0.5]"
)
parser.add_argument(
"-decay_step",
type=int,
default=20,
help="Learning rate decay step [default: 20]")
help="Learning rate decay step [default: 20]"
)
parser.add_argument(
"-bn_momentum",
type=float,
default=0.9,
help="Initial batch norm momentum [default: 0.9]")
help="Initial batch norm momentum [default: 0.9]"
)
parser.add_argument(
"-bn_decay",
type=float,
default=0.5,
help="Batch norm momentum decay gamma [default: 0.5]")
help="Batch norm momentum decay gamma [default: 0.5]"
)
parser.add_argument(
"-checkpoint", type=str, default=None, help="Checkpoint to start from")
"-checkpoint", type=str, default=None, help="Checkpoint to start from"
)
parser.add_argument(
"-epochs", type=int, default=200, help="Number of epochs to train for")
"-epochs", type=int, default=200, help="Number of epochs to train for"
)
parser.add_argument(
"-run_name",
type=str,
default="sem_seg_run_1",
help="Name for run in tensorboard_logger")
help="Name for run in tensorboard_logger"
)
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
@@ -74,13 +85,15 @@ if __name__ == "__main__":
tb_log.configure('runs/{}'.format(args.run_name))
test_set = Indoor3DSemSeg(
args.num_points, BASE_DIR, train=False, data_precent=0.01)
args.num_points, BASE_DIR, train=False, data_precent=0.01
)
test_loader = DataLoader(
test_set,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=2)
num_workers=2
)
train_set = Indoor3DSemSeg(args.num_points, BASE_DIR, data_precent=1.0)
train_loader = DataLoader(
@@ -88,12 +101,14 @@ if __name__ == "__main__":
batch_size=args.batch_size,
pin_memory=True,
num_workers=2,
shuffle=True)
shuffle=True
)
model = Pointnet(num_classes=13)
model.cuda()
optimizer = optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), lr_clip / args.lr)
bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), bnm_clip)
@@ -106,12 +121,15 @@ if __name__ == "__main__":
best_loss = 1e10
else:
start_epoch, best_loss = pt_utils.load_checkpoint(
model, optimizer, filename=args.checkpoint.split(".")[0])
model, optimizer, filename=args.checkpoint.split(".")[0]
)
lr_scheduler = lr_sched.LambdaLR(
optimizer, lr_lbmd, last_epoch=start_epoch)
optimizer, lr_lbmd, last_epoch=start_epoch
)
bnm_scheduler = pt_utils.BNMomentumScheduler(
model, bnm_lmbd, last_epoch=start_epoch)
model, bnm_lmbd, last_epoch=start_epoch
)
model_fn = model_fn_decorator(nn.CrossEntropyLoss())
@@ -123,14 +141,16 @@ if __name__ == "__main__":
best_name="sem_seg_best",
lr_scheduler=lr_scheduler,
bnm_scheduler=bnm_scheduler,
eval_frequency=10)
eval_frequency=10
)
trainer.train(
start_epoch,
args.epochs,
train_loader,
test_loader,
best_loss=best_loss)
best_loss=best_loss
)
if start_epoch == args.epochs:
test_loader.dataset.data_precent = 1.0