Downsample cin_s/cin_t diagnostic via cin_split_every

Per-source cin (cin_s, cin_t) requires splitting each prompt's backward
into student-only + teacher-only passes, which roughly doubles backward
wall-time. With cin_s/cin_t empirically stable for 50 steps in #51
(cin_t ~0.37, cin_s ~0.18 with low variance), computing them on every
step is overkill.

Add Config.cin_split_every: int = 1 (current behavior). Set it >1 to
compute cin_s/cin_t only every Nth step; the other steps run a single
combined backward. cin_s/cin_t print as NaN on skipped steps. The
projection and optimizer step are unchanged (they still use the
combined grad).

The default of 1 preserves the current run cost; users can opt into
e.g. cin_split_every=10 for roughly half the backward time once the
diagnostic is in steady state.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-27 09:14:30 +00:00
parent ff26cbe089
commit 9ba7b818a9
+71 -42
View File
@@ -166,6 +166,11 @@ class Config:
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis
# Per-source cin diagnostic: split each prompt's backward into student-only
# + teacher-only passes (~2x backward time). 1 = every step (default; full
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
# backward cost on skipped steps). cin_s/cin_t print as `nan` on skipped.
cin_split_every: int = 1
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
@@ -571,6 +576,11 @@ def main(cfg: Config) -> int:
# what the projection + optimizer step ultimately sees.
step_grad_s: dict[str, torch.Tensor] = {}
step_grad_t: dict[str, torch.Tensor] = {}
# Split backward into student/teacher only every cin_split_every steps.
# On split steps: 2 backwards per prompt, populates step_grad_s/_t.
# On skipped steps: 1 combined backward; the combined grad is accumulated
# into step_grad_s (step_grad_t stays empty), and cin_s/cin_t are set to
# NaN explicitly because step_grad_s no longer holds a student-only grad.
split_this_step = (step % cfg.cin_split_every == 0)
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
# without explicit cuda.synchronize. Tells us whether wall-time is
@@ -750,46 +760,63 @@ def main(cfg: Config) -> int:
kl = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
per_tok_loss = per_tok_loss + beta * kl
# Split loss by source (student vs teacher) and run separate
# backward passes. Linearity of backward + leaf .grad accumulator:
# loss = loss_s + loss_t (since is_s_mask + is_t_mask = 1)
# So grad_s + grad_t == full-batch grad; combined for projection.
# Per-source split (loss_s + loss_t == full-batch loss because
# is_s_v + is_t_v = 1 elementwise; backward is linear so
# grad_s + grad_t == full-batch grad). Two backwards every step is
# ~2x backward cost — gated to every cin_split_every step.
is_s_v = torch.tensor(is_student, dtype=per_tok_loss.dtype,
device=per_tok_loss.device).unsqueeze(1) # [G, 1]
is_t_v = 1.0 - is_s_v
if cfg.unbiased:
denom = group * max_new * prompts_per_step
loss_s = (per_tok_loss * mask * is_s_v).sum() / denom
loss_t = (per_tok_loss * mask * is_t_v).sum() / denom
if split_this_step:
if cfg.unbiased:
denom = group * max_new * prompts_per_step
loss_s = (per_tok_loss * mask * is_s_v).sum() / denom
loss_t = (per_tok_loss * mask * is_t_v).sum() / denom
else:
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)
loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step)
loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step)
# Pass 1: student. retain_graph so the shared forward graph survives.
loss_s.backward(retain_graph=True)
for name, info in wrappers.items():
gs = info["delta_S"].grad
if gs is None:
continue
step_grad_s[name] = (step_grad_s[name] + gs.detach().clone()
if name in step_grad_s
else gs.detach().clone())
model.zero_grad(set_to_none=True)
# Pass 2: teacher.
loss_t.backward()
for name, info in wrappers.items():
gt = info["delta_S"].grad
if gt is None:
continue
step_grad_t[name] = (step_grad_t[name] + gt.detach().clone()
if name in step_grad_t
else gt.detach().clone())
model.zero_grad(set_to_none=True)
agg_loss += (loss_s + loss_t).item()
else:
# Per-sample mean across completion tokens, then mean across G_src
# samples in this source, scaled to be additive with the other.
n_s = max(1, int(is_s_v.sum().item()))
n_t = max(1, int(is_t_v.sum().item()))
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1) # [G]
loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step)
loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step)
# Pass 1: student. retain_graph so the shared forward graph survives.
loss_s.backward(retain_graph=True)
for name, info in wrappers.items():
gs = info["delta_S"].grad
if gs is None:
continue
step_grad_s[name] = (step_grad_s[name] + gs.detach().clone()
if name in step_grad_s
else gs.detach().clone())
model.zero_grad(set_to_none=True)
# Pass 2: teacher.
loss_t.backward()
for name, info in wrappers.items():
gt = info["delta_S"].grad
if gt is None:
continue
step_grad_t[name] = (step_grad_t[name] + gt.detach().clone()
if name in step_grad_t
else gt.detach().clone())
model.zero_grad(set_to_none=True)
agg_loss += (loss_s + loss_t).item()
# Combined single backward — cheaper, no per-source diagnostic.
# Accumulate into step_grad_s as the "combined" carrier; the
# injection block below treats step_grad_t == {} as "use gs".
if cfg.unbiased:
denom = group * max_new * prompts_per_step
loss = (per_tok_loss * mask).sum() / denom
else:
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)
loss = ptl_norm.sum() / (group * prompts_per_step)
loss.backward()
for name, info in wrappers.items():
g = info["delta_S"].grad
if g is None:
continue
step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
if name in step_grad_s
else g.detach().clone())
model.zero_grad(set_to_none=True)
agg_loss += loss.item()
t_fb += time.perf_counter() - _tfb
# Inject combined grad (student + teacher) into leaf .grad before
@@ -808,12 +835,14 @@ def main(cfg: Config) -> int:
info["delta_S"].grad = gs + gt
# Per-source cin: project student-only and teacher-only grads into v_hack
# subspace (cosine in delta_S grad-space; rows of V orthonormal so the
# ratio is bounded in [0,1]). Discriminator: cin_t > cin_s on a clean
# base means v_hack lights up for hack grads more than non-hack — the
# extraction direction is doing real work.
cin_s = mean_cin_from_grads(step_grad_s, v_hack)
cin_t = mean_cin_from_grads(step_grad_t, v_hack)
# subspace. Discriminator: cin_t > cin_s on a clean base means v_hack
# lights up for hack grads more than non-hack. Only valid on split steps;
# otherwise step_grad_s holds the combined grad and would mis-report cin_s.
if split_this_step:
cin_s = mean_cin_from_grads(step_grad_s, v_hack)
cin_t = mean_cin_from_grads(step_grad_t, v_hack)
else:
cin_s = cin_t = float("nan")
# Diagnostic cos_in for both arms; projection only mutates grad if arm=projected.
diag = project_delta_S_grad(