proj: add gate_mode=reverse (flip sign of hack-ward component)

Current modes are one_sided (erase positive c only, leaves negative
intact) and no_gate (erase span(V) entirely, drives V@g_proj to 0).
Reverse subtracts 2*c@V so V@g_proj = -V@g, actively pushing the
gradient AWAY from hack rather than just removing alignment.

Smoke confirms: cos_pre=+0.726 -> cos_post=-0.726 (clean flip).
Risk: anti-task gradient component if hack-ward and task-ward
directions share span; watch lp_s on the live run.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
wassname
2026-05-28 09:21:05 +00:00
parent 646edfc7af
commit 3efd9e69a8
2 changed files with 9 additions and 2 deletions
+8 -1
View File
@@ -82,7 +82,8 @@ def _project_one_module(
normalized by ||g||). Positive = grad pushes toward hack; negative = grad
pushes toward safe. Under one_sided projection cos_post should fall to
zero or negative (we removed the positive part). Under no_gate cos_post
is approximately zero by construction.
is approximately zero by construction. Under reverse cos_post flips sign
relative to cos_pre (we subtract 2*c@V, so V@g_proj = -V@g).
"""
gn = g.norm()
if gn < 1e-12:
@@ -96,6 +97,12 @@ def _project_one_module(
mask = (c > 0).to(c.dtype)
c_use = c * mask
fired = bool((c_use != 0).any())
elif gate_mode == "reverse":
# Subtract 2*c@V: V@g_proj = V@g - 2*(V V^T) c = c - 2c = -c.
# Flips the sign of the gradient component in span(V); pushes
# actively away from hack rather than just removing.
c_use = 2 * c
fired = True
else:
raise ValueError(f"unknown gate_mode={gate_mode!r}")
if not fired:
+1 -1
View File
@@ -151,7 +151,7 @@ class Config:
grad_clip: float = 1.0 # global L2 clip on delta_S grads; set high (e.g. 500) to effectively disable
seed: int = 41
preserve_magnitude: bool = True
gate_mode: Literal["one_sided", "no_gate"] = "one_sided"
gate_mode: Literal["one_sided", "no_gate", "reverse"] = "one_sided"
unbiased: bool = True # Dr.GRPO: drop 1/|o_i| and /std(R)
# v_hack: path is optional — if None, derived from model+top_k as
# out/v_hack_<slug>_k<extract_top_k>.safetensors. If file missing, train.py