mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
Downsample cin_s/cin_t diagnostic via cin_split_every
Per-source cin (cin_s, cin_t) requires splitting each prompt's backward into student-only + teacher-only passes, which roughly doubles backward wall-time. With cin_s/cin_t empirically stable for 50 steps in #51 (cin_t ~0.37, cin_s ~0.18 with low variance), every-step is overkill. Add Config.cin_split_every: int = 1 (current behavior). Set >1 to compute cin_s/cin_t only every Nth step; combined single-backward on the others. cin_s/cin_t print as NaN on skipped steps. Projection + optimizer step unchanged (still uses combined grad). Default 1 preserves the current run cost; user can opt into 10 for ~half the backward time once the diagnostic is in steady state. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+71
-42
@@ -166,6 +166,11 @@ class Config:
|
||||
v_hack_extract_top_k: int = 12 # max k to save at extract; n_train_pairs caps it lower
|
||||
v_hack_k: int = 5 # load-time slice; k=1 = mean-diff, k=k_max = full
|
||||
v_hack_tau_axis: float = 0.0 # extract-time: zero axes where S_i/S_0 < tau_axis
|
||||
# Per-source cin diagnostic: split each prompt's backward into student-only
|
||||
# + teacher-only passes (~2x backward time). 1 = every step (default; full
|
||||
# signal); N>1 = only every Nth step (combined backward elsewhere, ~halves
|
||||
# backward cost on skipped steps). cin_s/cin_t print as `nan` on skipped.
|
||||
cin_split_every: int = 1
|
||||
out_tag: str = "" # suffix for saved artifact, e.g. "_seed41"
|
||||
# Mixed-pool GRPO: per-prompt rollout pool = G_s live student + G_t cached
|
||||
# teacher rollouts. Teacher pool is a dir of prompt_NNNN.jsonl.gz produced by
|
||||
@@ -571,6 +576,11 @@ def main(cfg: Config) -> int:
|
||||
# what the projection + optimizer step ultimately sees.
|
||||
step_grad_s: dict[str, torch.Tensor] = {}
|
||||
step_grad_t: dict[str, torch.Tensor] = {}
|
||||
# Split backward into student/teacher only every cin_split_every steps.
|
||||
# On split steps: 2 backwards per prompt, populates step_grad_s/_t.
|
||||
# On skipped steps: 1 combined backward, step_grad_s/_t stay empty and
|
||||
# cin_s/cin_t go to NaN (mean_cin_from_grads returns NaN on empty dict).
|
||||
split_this_step = (step % cfg.cin_split_every == 0)
|
||||
# Phase timers (per-step cumulative, seconds). Each GPU phase ends in a
|
||||
# CPU-blocking op (decode / .item()), so perf_counter is sync-accurate
|
||||
# without explicit cuda.synchronize. Tells us whether wall-time is
|
||||
@@ -750,46 +760,63 @@ def main(cfg: Config) -> int:
|
||||
kl = torch.exp(ref_logp - pol_logp) - (ref_logp - pol_logp) - 1.0
|
||||
per_tok_loss = per_tok_loss + beta * kl
|
||||
|
||||
# Split loss by source (student vs teacher) and run separate
|
||||
# backward passes. Linearity of backward + leaf .grad accumulator:
|
||||
# loss = loss_s + loss_t (since is_s_mask + is_t_mask = 1)
|
||||
# So grad_s + grad_t == full-batch grad; combined for projection.
|
||||
# Per-source split (loss_s + loss_t == full-batch loss because
|
||||
# is_s_v + is_t_v = 1 elementwise; backward is linear so
|
||||
# grad_s + grad_t == full-batch grad). Two backwards every step is
|
||||
# ~2x backward cost — gated to every cin_split_every step.
|
||||
is_s_v = torch.tensor(is_student, dtype=per_tok_loss.dtype,
|
||||
device=per_tok_loss.device).unsqueeze(1) # [G, 1]
|
||||
is_t_v = 1.0 - is_s_v
|
||||
if cfg.unbiased:
|
||||
denom = group * max_new * prompts_per_step
|
||||
loss_s = (per_tok_loss * mask * is_s_v).sum() / denom
|
||||
loss_t = (per_tok_loss * mask * is_t_v).sum() / denom
|
||||
if split_this_step:
|
||||
if cfg.unbiased:
|
||||
denom = group * max_new * prompts_per_step
|
||||
loss_s = (per_tok_loss * mask * is_s_v).sum() / denom
|
||||
loss_t = (per_tok_loss * mask * is_t_v).sum() / denom
|
||||
else:
|
||||
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||||
loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step)
|
||||
loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step)
|
||||
# Pass 1: student. retain_graph so the shared forward graph survives.
|
||||
loss_s.backward(retain_graph=True)
|
||||
for name, info in wrappers.items():
|
||||
gs = info["delta_S"].grad
|
||||
if gs is None:
|
||||
continue
|
||||
step_grad_s[name] = (step_grad_s[name] + gs.detach().clone()
|
||||
if name in step_grad_s
|
||||
else gs.detach().clone())
|
||||
model.zero_grad(set_to_none=True)
|
||||
# Pass 2: teacher.
|
||||
loss_t.backward()
|
||||
for name, info in wrappers.items():
|
||||
gt = info["delta_S"].grad
|
||||
if gt is None:
|
||||
continue
|
||||
step_grad_t[name] = (step_grad_t[name] + gt.detach().clone()
|
||||
if name in step_grad_t
|
||||
else gt.detach().clone())
|
||||
model.zero_grad(set_to_none=True)
|
||||
agg_loss += (loss_s + loss_t).item()
|
||||
else:
|
||||
# Per-sample mean across completion tokens, then mean across G_src
|
||||
# samples in this source, scaled to be additive with the other.
|
||||
n_s = max(1, int(is_s_v.sum().item()))
|
||||
n_t = max(1, int(is_t_v.sum().item()))
|
||||
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1) # [G]
|
||||
loss_s = (ptl_norm * is_s_v.squeeze(1)).sum() / (group * prompts_per_step)
|
||||
loss_t = (ptl_norm * is_t_v.squeeze(1)).sum() / (group * prompts_per_step)
|
||||
# Pass 1: student. retain_graph so the shared forward graph survives.
|
||||
loss_s.backward(retain_graph=True)
|
||||
for name, info in wrappers.items():
|
||||
gs = info["delta_S"].grad
|
||||
if gs is None:
|
||||
continue
|
||||
step_grad_s[name] = (step_grad_s[name] + gs.detach().clone()
|
||||
if name in step_grad_s
|
||||
else gs.detach().clone())
|
||||
model.zero_grad(set_to_none=True)
|
||||
# Pass 2: teacher.
|
||||
loss_t.backward()
|
||||
for name, info in wrappers.items():
|
||||
gt = info["delta_S"].grad
|
||||
if gt is None:
|
||||
continue
|
||||
step_grad_t[name] = (step_grad_t[name] + gt.detach().clone()
|
||||
if name in step_grad_t
|
||||
else gt.detach().clone())
|
||||
model.zero_grad(set_to_none=True)
|
||||
agg_loss += (loss_s + loss_t).item()
|
||||
# Combined single backward — cheaper, no per-source diagnostic.
|
||||
# Accumulate into step_grad_s as the "combined" carrier; the
|
||||
# injection block below treats step_grad_t == {} as "use gs".
|
||||
if cfg.unbiased:
|
||||
denom = group * max_new * prompts_per_step
|
||||
loss = (per_tok_loss * mask).sum() / denom
|
||||
else:
|
||||
ptl_norm = (per_tok_loss * mask).sum(1) / mask.sum(1).clamp_min(1)
|
||||
loss = ptl_norm.sum() / (group * prompts_per_step)
|
||||
loss.backward()
|
||||
for name, info in wrappers.items():
|
||||
g = info["delta_S"].grad
|
||||
if g is None:
|
||||
continue
|
||||
step_grad_s[name] = (step_grad_s[name] + g.detach().clone()
|
||||
if name in step_grad_s
|
||||
else g.detach().clone())
|
||||
model.zero_grad(set_to_none=True)
|
||||
agg_loss += loss.item()
|
||||
t_fb += time.perf_counter() - _tfb
|
||||
|
||||
# Inject combined grad (student + teacher) into leaf .grad before
|
||||
@@ -808,12 +835,14 @@ def main(cfg: Config) -> int:
|
||||
info["delta_S"].grad = gs + gt
|
||||
|
||||
# Per-source cin: project student-only and teacher-only grads into v_hack
|
||||
# subspace (cosine in delta_S grad-space; rows of V orthonormal so the
|
||||
# ratio is bounded in [0,1]). Discriminator: cin_t > cin_s on a clean
|
||||
# base means v_hack lights up for hack grads more than non-hack — the
|
||||
# extraction direction is doing real work.
|
||||
cin_s = mean_cin_from_grads(step_grad_s, v_hack)
|
||||
cin_t = mean_cin_from_grads(step_grad_t, v_hack)
|
||||
# subspace. Discriminator: cin_t > cin_s on a clean base means v_hack
|
||||
# lights up for hack grads more than non-hack. Only valid on split steps;
|
||||
# otherwise step_grad_s holds the combined grad and would mis-report cin_s.
|
||||
if split_this_step:
|
||||
cin_s = mean_cin_from_grads(step_grad_s, v_hack)
|
||||
cin_t = mean_cin_from_grads(step_grad_t, v_hack)
|
||||
else:
|
||||
cin_s = cin_t = float("nan")
|
||||
|
||||
# Diagnostic cos_in for both arms; projection only mutates grad if arm=projected.
|
||||
diag = project_delta_S_grad(
|
||||
|
||||
Reference in New Issue
Block a user