mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:15:35 +08:00
refactor: drop shadowed-import + duplicate-definition cruft (-91 LOC)
Left over from the data.py/vhack.py/eval.py/tablelog.py module split.

In train.py the canonical imports already won at runtime; the earlier ones were dead shadows:
- ablate_quarantine, ref_logprobs_via_zero_delta: .eval wins (line 66), drop the .antipasto copy.
- load_v_hack/postprocess_v_hack: .vhack wins, drop .extract_vhack_grad.
- DATA/load_problems: .data wins, drop .problems.
- local setup_logging() was byte-identical to the .tablelog one already imported (with StepLogger); delete the local def + now-orphaned datetime import and LOGS_DIR const.
- problems.py stays: 6 scripts + derisk/regrade still import it.

antipasto.py: delete detach_antipasto (0 callers) and its own copies of ref_logprobs_via_zero_delta / ablate_quarantine (eval.py owns the canonical, better-worded versions incl. the SGTM TODO), plus now-unused contextmanager and per_token_logps imports.

docs: rm corrupted docs/spec/20260530_substrate_review_qwen.md (2-line API error dump, not a review).

Behavior-preserving (later imports already won at runtime).

Verified: just smoke (erase) + just smoke-routeV both exit 0, 0 tracebacks, all verify_* gates PASS.

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +0,0 @@
|
||||
400 Provider returned error
|
||||
{"error":{"message":"developer is not one of ['system', 'assistant', 'user', 'tool', 'function'] - 'messages.['0].role'","type":"invalid_request_error","param":null,"code":null},"request_id":"chatcmpl-8036d119-9aa0-981f-9113-99a865d3f90e"}
|
||||
[?2026h[r[?1006l[?1002l[?1000l[?1007h[?1049l[<999u[>4;0m[?2026l
|
||||
@@ -15,7 +15,6 @@ directly; no extra projection math at the gradient step).
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -23,8 +22,6 @@ from jaxtyping import Float
|
||||
from loguru import logger
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .proj import per_token_logps
|
||||
|
||||
|
||||
def svd_cached(
|
||||
W: Float[Tensor, "d_out d_in"],
|
||||
@@ -236,58 +233,3 @@ def wrap_model_with_antipasto(
|
||||
if not n.endswith(trainable):
|
||||
p.requires_grad_(False)
|
||||
return out
|
||||
|
||||
|
||||
def detach_antipasto(model: nn.Module, attached: dict) -> None:
    """Strip the AntiPaSTO hook and its registered tensors from each layer.

    For every entry in ``attached`` the stored hook handle is removed first,
    then any ``_antipasto_*`` buffers and parameters still registered on the
    entry's layer are deleted. Names that are not present are skipped.

    Args:
        model: Accepted for interface symmetry; not read by this function.
        attached: Mapping whose values are dicts carrying a ``"handle"``
            (an object with ``.remove()``) and a ``"layer"`` (a module
            exposing ``_buffers`` and ``_parameters``).
    """
    # (registry attribute on the layer, names to drop from that registry)
    cleanup_plan = (
        ("_buffers", ("_antipasto_U", "_antipasto_Vh")),
        ("_parameters", ("_antipasto_delta_S", "_antipasto_delta_S_hack")),
    )
    for entry in attached.values():
        entry["handle"].remove()
        target = entry["layer"]
        for registry_name, tensor_names in cleanup_plan:
            for tensor_name in tensor_names:
                registry = getattr(target, registry_name)
                if tensor_name in registry:
                    del registry[tensor_name]
|
||||
|
||||
|
||||
@torch.no_grad()
def ref_logprobs_via_zero_delta(
    model, merged: torch.Tensor, wrappers: dict, plen: int,
) -> torch.Tensor:
    """Reference-policy (π_ref) logprobs for the completion tokens.

    AntiPaSTO parameterizes W' = W + U diag(δS) Vᵀ, so with δS = 0 the adapter
    contributes nothing and one forward pass yields π_ref without a second
    model. Procedure: snapshot every δS, zero them, run the forward, and
    restore the snapshots in ``finally`` so they come back even on error.

    ``logits_to_keep=L_c + 1`` asks the model for logits on the completion
    side only (prompt-side logits never materialize, saving roughly
    plen/(plen+L_c) of lm_head memory); the final position is then dropped.

    Args:
        model: Callable returning an object with ``.logits``.
        merged: Token ids; columns from ``plen`` onward are the completion.
        wrappers: Mapping whose values hold a ``"delta_S"`` tensor/parameter.
        plen: Number of leading columns of ``merged`` treated as the prompt.

    Returns:
        The result of ``per_token_logps`` on the kept logits and the
        completion token ids.
    """
    backup = {name: info["delta_S"].data.clone() for name, info in wrappers.items()}
    try:
        for info in wrappers.values():
            info["delta_S"].data.zero_()
        completion_len = merged.shape[1] - plen
        outputs = model(merged, logits_to_keep=completion_len + 1)
        kept_logits = outputs.logits[:, :-1]
        completion_ids = merged[:, plen:]
        return per_token_logps(kept_logits, completion_ids)
    finally:
        for name, info in wrappers.items():
            info["delta_S"].data.copy_(backup[name])
|
||||
|
||||
|
||||
@contextmanager
def ablate_quarantine(wrappers: dict):
    """Temporarily zero the routing quarantine (δS_hack) inside a ``with`` block.

    This is the deploy-time ablation of the routed hack capability:
    save -> zero -> (eval) -> restore. The route/routeV deployment model IS
    this ablated state. The saved values are copied back on exit, including
    when the body raises.

    TODO(post-deploy-finetune): SGTM's ablate(trainable=True) reinits the forget
    weights to the retain-dims' std instead of zeroing, keeping the model
    finetunable after ablation (no dead hole). We zero because we only eval after
    deploy. See docs/grad_routing/sgtm_vs_ours.md.

    Args:
        wrappers: Mapping whose values hold a ``"delta_S_hack"`` tensor/parameter.
    """
    def _hack_data(info):
        # Underlying storage of the quarantine parameter for this wrapper.
        return info["delta_S_hack"].data

    # Snapshot everything before touching anything.
    backup = {name: _hack_data(info).clone() for name, info in wrappers.items()}
    for info in wrappers.values():
        _hack_data(info).zero_()
    try:
        yield
    finally:
        for name, info in wrappers.items():
            _hack_data(info).copy_(backup[name])
|
||||
|
||||
+3
-33
@@ -35,7 +35,6 @@ import random
|
||||
import time
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
@@ -55,10 +54,7 @@ from tabulate import tabulate
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
from .antipasto import (ablate_quarantine, ref_logprobs_via_zero_delta,
|
||||
wrap_model_with_antipasto, wrap_model_with_lora_frozen_b)
|
||||
from .extract_vhack_grad import load_v_hack, postprocess_v_hack
|
||||
from .problems import DATA, load_problems
|
||||
from .antipasto import wrap_model_with_antipasto, wrap_model_with_lora_frozen_b
|
||||
from .proj import per_token_logps, project_delta_S_grad, mean_cos_pre_from_grads
|
||||
from .rewards import EnvMode, compute_reward
|
||||
from .data import DATA, load_problems
|
||||
@@ -73,34 +69,8 @@ OUT_DIR = Path("out")
|
||||
# runs/<run_id>/. Read paths (v_hack, teacher pool) come in as explicit args.
|
||||
VHACK_DIR = OUT_DIR / "vhack"
|
||||
RUNS_DIR = OUT_DIR / "runs"
|
||||
LOGS_DIR = Path("logs")
|
||||
# DATA (the LeetCode dataset path) lives in problems.py, imported above.
|
||||
|
||||
|
||||
def setup_logging(run_id: str) -> Path:
    """Configure loguru for token-efficient output and return the log file path.

    Stdout gets a terse line (1-char level icon + message) at INFO and above,
    written through ``tqdm.write``; a verbose DEBUG-level log goes to a
    timestamped file under ``LOGS_DIR``. Existing loguru sinks are removed
    first.

    See /root/.claude/skills/token-efficient-logging/SKILL.md.

    Args:
        run_id: Identifier appended to the timestamp in the log file name.

    Returns:
        Path of the verbose log file.
    """
    LOGS_DIR.mkdir(exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%dT%H%M%S')
    verbose_log = LOGS_DIR / f"{stamp}_{run_id}.log"

    # Start from a clean slate so the default stderr sink is dropped.
    logger.remove()

    # Console sink: compact, colorized, INFO and above.
    logger.add(
        lambda msg: tqdm.write(msg, end=""),
        colorize=True,
        format="<level>{level.icon}</level> {message}",
        level="INFO",
    )
    # File sink: full detail, DEBUG and above.
    logger.add(
        verbose_log,
        format="{time:HH:mm:ss} | {level} | {message}",
        level="DEBUG",
    )

    # Single-character icons used by the console format.
    level_icons = (
        ("INFO", "I"),
        ("WARNING", "W"),
        ("ERROR", "E"),
        ("DEBUG", "D"),
    )
    for level_name, icon in level_icons:
        logger.level(level_name, icon=icon)

    return verbose_log
|
||||
# DATA (the LeetCode dataset path) lives in data.py, imported above.
|
||||
# setup_logging + StepLogger live in tablelog.py, imported above.
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
|
||||
Reference in New Issue
Block a user