mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
proj: measure_only kwarg + train.py always-on cos_in diagnostic
Vanilla arm now reports cos_in per step too (cosine of accumulated Dr.GRPO grad with v_hack), as long as v_hack file is on disk. The projection action only mutates the gradient when arm=projected; vanilla just measures. This makes Phase 2 (pilot scale) directly inform Phase 3: vanilla cos_in trajectory says whether v_hack is even aligned with the GRPO direction, before we burn 65h on the full sweep.
This commit is contained in:
@@ -25,6 +25,7 @@ def project_delta_S_grad(
|
||||
wrappers: dict,
|
||||
v_hack: dict[str, torch.Tensor],
|
||||
preserve_magnitude: bool,
|
||||
measure_only: bool = False,
|
||||
) -> dict[str, float]:
|
||||
"""Per-module one-sided removal of v_hack-aligned component from delta_S.grad.
|
||||
|
||||
@@ -32,6 +33,9 @@ def project_delta_S_grad(
|
||||
If cos(g, v) > 0: g' = g - <g, v> v (remove projection onto v). Optionally
|
||||
rescale g' to ||g|| to preserve update magnitude. Else leave g untouched.
|
||||
|
||||
If `measure_only`: same cosine math, but the gradient is NOT mutated.
|
||||
Used by vanilla arm to report cos_in trajectory as a diagnostic.
|
||||
|
||||
Returns aggregate diagnostics: mean_cos_in, mean_cos_out, frac_fired.
|
||||
"""
|
||||
cos_in_list, cos_out_list, n_fired = [], [], 0
|
||||
@@ -53,7 +57,8 @@ def project_delta_S_grad(
|
||||
g_proj = g_proj * (gn / gp_n)
|
||||
cos_out = (g_proj @ v) / g_proj.norm().clamp_min(1e-12)
|
||||
cos_out_list.append(cos_out.item())
|
||||
info["delta_S"].grad = g_proj
|
||||
if not measure_only:
|
||||
info["delta_S"].grad = g_proj
|
||||
n_fired += 1
|
||||
else:
|
||||
cos_out_list.append(cos_in.item())
|
||||
|
||||
@@ -317,13 +317,15 @@ def main(cfg: Config) -> int:
|
||||
delta_params = [info["delta_S"] for info in wrappers.values()]
|
||||
logger.info(f"trainable delta_S: {sum(p.numel() for p in delta_params):,}")
|
||||
|
||||
# v_hack only needed for projected arm. Vanilla H4 sanity runs do not
|
||||
# require a precomputed v_hack and should not be blocked by missing one.
|
||||
if cfg.arm == "projected":
|
||||
# v_hack: loaded for both arms when the file is present, so vanilla also
|
||||
# reports cos_in as a diagnostic (no projection applied). If not present
|
||||
# and arm=vanilla, skip silently — H4 sanity runs without v_hack remain valid.
|
||||
v_hack = None
|
||||
if cfg.v_hack_path.exists():
|
||||
v_hack_cpu = load_v_hack(cfg.v_hack_path, model_name, wrappers)
|
||||
v_hack = {name: v.to(device) for name, v in v_hack_cpu.items()}
|
||||
else:
|
||||
v_hack = None
|
||||
elif cfg.arm == "projected":
|
||||
raise FileNotFoundError(f"projected arm requires v_hack at {cfg.v_hack_path}")
|
||||
opt = torch.optim.AdamW(
|
||||
delta_params, lr=cfg.lr, weight_decay=cfg.weight_decay,
|
||||
betas=(cfg.adam_beta1, cfg.adam_beta2),
|
||||
@@ -535,9 +537,12 @@ def main(cfg: Config) -> int:
|
||||
agg_loss += loss.item()
|
||||
t_fb += time.perf_counter() - _tfb
|
||||
|
||||
# One projection on accumulated grads (projected arm only).
|
||||
if cfg.arm == "projected":
|
||||
diag = project_delta_S_grad(wrappers, v_hack, cfg.preserve_magnitude)
|
||||
# Diagnostic cos_in for both arms; projection only mutates grad if arm=projected.
|
||||
if v_hack is not None:
|
||||
diag = project_delta_S_grad(
|
||||
wrappers, v_hack, cfg.preserve_magnitude,
|
||||
measure_only=(cfg.arm != "projected"),
|
||||
)
|
||||
else:
|
||||
diag = {"mean_cos_in": float("nan"), "mean_cos_out": float("nan"), "frac_fired": float("nan")}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user