diff --git a/src/projected_grpo/train.py b/src/projected_grpo/train.py index bcbd53d..14bd73e 100644 --- a/src/projected_grpo/train.py +++ b/src/projected_grpo/train.py @@ -166,6 +166,11 @@ class Config: v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis + # Per-source cin diagnostic: split each prompt's backward into student-only + # and teacher-only passes (~2x backward time). Must be >= 1: 1 = every step + # (default; full signal); N>1 = only every Nth step (combined backward + # elsewhere, ~halves backward cost). cin_s/cin_t print as `nan` on skipped. + cin_split_every: int = 1 out_tag: str = "" # suffix for saved artifact, e.g. "_seed41" # Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached # teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by @@ -571,6 +576,11 @@ def main(cfg: Config) -> int: # what the projection + optimizer step ultimately sees. step_grad_s: dict[str, torch.Tensor] = {} step_grad_t: dict[str, torch.Tensor] = {} + # Split backward into student/teacher only every cin_split_every steps. + # On split steps: 2 backwards per prompt, populates step_grad_s/_t. + # On skipped steps: 1 combined backward accumulated into step_grad_s + # (step_grad_t stays empty); cin_s/cin_t are set to NaN explicitly. + split_this_step = (step % cfg.cin_split_every == 0) # Phase timers (per-step cumulative, seconds). Each GPU phase ends in a # CPU-blocking op (decode / .item()), so perf_counter is sync-accurate # without explicit cuda.synchronize. Tells us whether wall-time is @@ -750,46 +760,63 @@ def main(cfg: Config) -> int: kl = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0 per_tok_loss = per_tok_loss + beta * kl - # Split loss by source (student vs teacher) and run separate - # backward passes. 
Linearity of backward + leaf .grad accumulator: - # loss = loss_s + loss_t (since is_s_mask + is_t_mask = 1) - # So grad_s + grad_t == full-batch grad; combined for projection. + # Per-source split (loss_s + loss_t == full-batch loss because + # is_s_v + is_t_v = 1 elementwise; backward is linear so + # grad_s + grad_t == full-batch grad). Two backwards every step is + # ~2x backward cost — gated to every cin_split_every step. is_s_v = torch.tensor(is_student, dtype=per_tok_loss.dtype, device=per_tok_loss.device).unsqueeze(1) # [G, 1] is_t_v = 1.0 - is_s_v - if cfg.unbiased: - denom = group * max_new * prompts_per_step - loss_s = (per_tok_loss * mask * is_s_v).sum() / denom - loss_t = (per_tok_loss * mask * is_t_v).sum() / denom + if split_this_step: + if cfg.unbiased: + denom = group * max_new * prompts_per_step + loss_s = (per_tok_loss * mask * is_s_v).sum() / denom + loss_t = (per_tok_loss * mask * is_t_v).sum() / denom + else: + ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1) + loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step) + loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step) + # Pass 1: student. retain_graph so the shared forward graph survives. + loss_s.backward(retain_graph=True) + for name, info in wrappers.items(): + gs = info["delta_S"].grad + if gs is None: + continue + step_grad_s[name] = (step_grad_s[name] + gs.detach().clone() + if name in step_grad_s + else gs.detach().clone()) + model.zero_grad(set_to_none=True) + # Pass 2: teacher. + loss_t.backward() + for name, info in wrappers.items(): + gt = info["delta_S"].grad + if gt is None: + continue + step_grad_t[name] = (step_grad_t[name] + gt.detach().clone() + if name in step_grad_t + else gt.detach().clone()) + model.zero_grad(set_to_none=True) + agg_loss += (loss_s + loss_t).item() else: - # Per-sample mean across completion tokens, then mean across G_src - # samples in this source, scaled to be additive with the other. 
- n_s = max(1, int(is_s_v.sum().item())) - n_t = max(1, int(is_t_v.sum().item())) - ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1) # [G] - loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step) - loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step) - # Pass 1: student. retain_graph so the shared forward graph survives. - loss_s.backward(retain_graph=True) - for name, info in wrappers.items(): - gs = info["delta_S"].grad - if gs is None: - continue - step_grad_s[name] = (step_grad_s[name] + gs.detach().clone() - if name in step_grad_s - else gs.detach().clone()) - model.zero_grad(set_to_none=True) - # Pass 2: teacher. - loss_t.backward() - for name, info in wrappers.items(): - gt = info["delta_S"].grad - if gt is None: - continue - step_grad_t[name] = (step_grad_t[name] + gt.detach().clone() - if name in step_grad_t - else gt.detach().clone()) - model.zero_grad(set_to_none=True) - agg_loss += (loss_s + loss_t).item() + # Combined single backward — cheaper, no per-source diagnostic. + # Accumulate into step_grad_s as the "combined" carrier; the + # injection block below treats step_grad_t == {} as "use gs". 
+ if cfg.unbiased: + denom = group * max_new * prompts_per_step + loss = (per_tok_loss * mask).sum() / denom + else: + ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1) + loss = ptl_norm.sum() / (group * prompts_per_step) + loss.backward() + for name, info in wrappers.items(): + g = info["delta_S"].grad + if g is None: + continue + step_grad_s[name] = (step_grad_s[name] + g.detach().clone() + if name in step_grad_s + else g.detach().clone()) + model.zero_grad(set_to_none=True) + agg_loss += loss.item() t_fb += time.perf_counter() - _tfb # Inject combined grad (student + teacher) into leaf .grad before @@ -808,12 +835,14 @@ def main(cfg: Config) -> int: info["delta_S"].grad = gs + gt # Per-source cin: project student-only and teacher-only grads into v_hack - # subspace (cosine in delta_S grad-space; rows of V orthonormal so the - # ratio is bounded in [0,1]). Discriminator: cin_t > cin_s on a clean - # base means v_hack lights up for hack grads more than non-hack — the - # extraction direction is doing real work. - cin_s = mean_cin_from_grads(step_grad_s, v_hack) - cin_t = mean_cin_from_grads(step_grad_t, v_hack) + # subspace. Discriminator: cin_t > cin_s on a clean base means v_hack + # lights up for hack grads more than non-hack. Only valid on split steps; + # otherwise step_grad_s holds the combined grad and would mis-report cin_s. + if split_this_step: + cin_s = mean_cin_from_grads(step_grad_s, v_hack) + cin_t = mean_cin_from_grads(step_grad_t, v_hack) + else: + cin_s = cin_t = float("nan") # Diagnostic cos_in for both arms; projection only mutates grad if arm=projected. diag = project_delta_S_grad(