diff --git a/src/vgrout/train.py b/src/vgrout/train.py index 0d9f357..eede2a9 100644 --- a/src/vgrout/train.py +++ b/src/vgrout/train.py @@ -413,17 +413,13 @@ def main(cfg: Config) -> int: # ── optimizer + schedule ── (A and B of both blocks; masks route grads) opt = torch.optim.AdamW( delta_params, lr=lr, weight_decay=cfg.weight_decay, betas=(adam_beta1, adam_beta2)) - # Fractional warmup preserves the intended schedule across preset lengths. - warmup_steps = max(1, int(cfg.warmup_frac * steps)) - sched = torch.optim.lr_scheduler.SequentialLR( - opt, - schedulers=[ - torch.optim.lr_scheduler.LinearLR(opt, start_factor=1e-3, end_factor=1.0, - total_iters=warmup_steps), - torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, steps - warmup_steps)), - ], - milestones=[warmup_steps], - ) + # OneCycle does warmup + cosine relaxation in one object: cos ramp from lr/div_factor + # up to lr over the first pct_start of steps (the explicit warmup), then cos anneal to + # ~0. cycle_momentum=False so it leaves the configured AdamW betas alone (else it would + # clobber adam_beta1). pct_start = warmup_frac keeps warmup fractional across presets. + sched = torch.optim.lr_scheduler.OneCycleLR( + opt, max_lr=lr, total_steps=steps, pct_start=cfg.warmup_frac, + anneal_strategy="cos", div_factor=25.0, final_div_factor=1e4, cycle_momentum=False) # ── generation config ── # Use the same sampling policy for training and evaluation.