mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 16:45:42 +08:00
fix(routeV): keep=bulk gate + deterministic teachers + deploy-mode generation
Three fixes after the deploy-solve=0 run (all user-confirmed): 1. Gate co-adaptation: t_lo sat at quantile(route_quantile), so keep was only the bottom ~5% and ~95% of rollouts landed in mid (both blocks train, qmass~0.5). Move the keep/mid boundary up: t_lo=quantile(1-2q), t_hi=quantile(1-q), so keep is the BULK, mid+rout are the top 2q. Three-zone absorption preserved (not deleted). 2. Deterministic teachers: every teacher-phase prompt is drawn from the both-pool-covered set and gets EXACTLY teacher_n_per_prompt hack + N solve (constant count, no flip/coverage drops). Replaces mix_ratio*_even_split (count varied per step). No flip in the teacher phase (solve teacher carries solve pressure). mix_ratio>0 stays the on/off switch. Removed dead _even_split. 3. Deploy-mode generation: student rollouts generate under ablate_quarantine, so the behavior policy = the shipped deployed-only model -- the quarantine's learned hack can't saturate the rollout distribution and starve honest solve advantage. For clean-gated rollouts gen and train forward now match. Also: FastConfig lr 1e-4->5e-4 (random-init lora2r needs more lr in the short budget). AGENTS.md: don't bake unconfirmed theories into comments; don't inflate diagnosis confidence across turns. Smoke + smoke-solvemix green; all verify gates pass. Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -72,6 +72,8 @@ Inherit global rules from `~/.claude/CLAUDE.md`.
|
|||||||
diverse / less overlapping / less off-distribution), i.e. the "Hack A predicts Hack B"
|
diverse / less overlapping / less off-distribution), i.e. the "Hack A predicts Hack B"
|
||||||
generalization where every label still lives on pairs we wrote.
|
generalization where every label still lives on pairs we wrote.
|
||||||
- do not overconfidently diagnose. If you can't think of 3+ plausible hypotheses - including bugs, subtle failures, and you being wrong about concepts - then you have lost perspective and have narrow vision
|
- do not overconfidently diagnose. If you can't think of 3+ plausible hypotheses - including bugs, subtle failures, and you being wrong about concepts - then you have lost perspective and have narrow vision
|
||||||
|
- DON'T act on a new theory/diagnosis until I confirm it: don't keep building on it, and don't write it into code comments, docstrings, or docs. An unconfirmed theory baked into a comment misleads every future agent who reads it as fact. State it to me, wait for the OK, THEN write.
|
||||||
|
- DON'T inflate a diagnosis's confidence (maybe -> probably -> definitely) across turns or in writing. Keep the hedge you started with unless new evidence justifies the change, and name that evidence. Confidence creep in comments/docs is how a guess becomes "known" with no one having checked.
|
||||||
- I'm often AFK, so don't stop to ask me a question when you know the likely answer, when I've already indicated or asked for it, or when there is only one answer (e.g. "waiting for your go-ahead"). I'd rather you just commit and go ahead
|
- I'm often AFK, so don't stop to ask me a question when you know the likely answer, when I've already indicated or asked for it, or when there is only one answer (e.g. "waiting for your go-ahead"). I'd rather you just commit and go ahead
|
||||||
|
|
||||||
## The adapter setup (shapes), and why "same position = shrinkage" is subtle
|
## The adapter setup (shapes), and why "same position = shrinkage" is subtle
|
||||||
|
|||||||
+69
-72
@@ -106,14 +106,6 @@ def _build_v_grad(raw_grads: dict, names, k: int, device) -> dict[str, Float[tor
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _even_split(total: int, parts: int) -> list[int]:
|
|
||||||
"""Distribute `total` items across `parts` buckets as evenly as possible, extras first.
|
|
||||||
_even_split(8,4)=[2,2,2,2]; _even_split(2,4)=[1,1,0,0]. Used to spread the STEP-level
|
|
||||||
teacher budget across the step's prompts (T_solve front-loads like T_total, so
|
|
||||||
solve_alloc[i] <= total_alloc[i] holds bucket-wise)."""
|
|
||||||
base, extra = divmod(total, parts)
|
|
||||||
return [base + (1 if i < extra else 0) for i in range(parts)]
|
|
||||||
|
|
||||||
|
|
||||||
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
|
def _sample_rows(rows: list[dict] | None, n: int, rng: torch.Generator) -> list[dict]:
|
||||||
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
|
"""Draw n teacher rollouts from a prompt's pool (with replacement if the pool is short)."""
|
||||||
@@ -398,9 +390,8 @@ def main(cfg: Config) -> int:
|
|||||||
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
|
avg_hack = sum(int(r["hacked"]) for v in teacher_pool.values() for r in v) / sum(len(v) for v in teacher_pool.values())
|
||||||
logger.info(
|
logger.info(
|
||||||
f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
|
f"teacher pool: {len(teacher_pool)} prompts, ~{n_rollouts_per:.1f} rollouts/prompt, "
|
||||||
f"cached hack_rate={avg_hack:.2%}. STEP-level mix_ratio={cfg.mix_ratio} -> "
|
f"cached hack_rate={avg_hack:.2%}. Deterministic: {cfg.teacher_n_per_prompt} hack "
|
||||||
f"{round(prompts_per_step * group * cfg.mix_ratio)} teachers across "
|
f"teacher(s) per teacher-phase prompt (constant count, no mix_ratio budget).")
|
||||||
f"{prompts_per_step} prompts/step (rest of {prompts_per_step * group} gens are student).")
|
|
||||||
|
|
||||||
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
|
# ── solve-teacher pool (symmetric honest demos) ── same schema/loader as the
|
||||||
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
|
# hack pool; the G_t teacher slots split solve_mix_frac solve / rest hack.
|
||||||
@@ -462,10 +453,20 @@ def main(cfg: Config) -> int:
|
|||||||
problems = all_problems[:n_problems]
|
problems = all_problems[:n_problems]
|
||||||
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
|
mode_desc = "per-problem partition" if partition is not None else f"single env_mode={cfg.env_mode}"
|
||||||
logger.info(f"loaded {len(problems)} seeded-shuffle problems from {DATA.name} -- {mode_desc}")
|
logger.info(f"loaded {len(problems)} seeded-shuffle problems from {DATA.name} -- {mode_desc}")
|
||||||
|
# Both-pool-covered prompts: the teacher phase samples ONLY these so every prompt can get
|
||||||
|
# a deterministic N hack + N solve. solve_pool ⊂ teacher_pool, so the intersection = the
|
||||||
|
# solve-covered prompts (or just teacher-covered when there is no solve pool).
|
||||||
|
covered_problems = [p for p in problems
|
||||||
|
if (not teacher_pool or p["problem_id"] in teacher_pool)
|
||||||
|
and (not solve_pool or p["problem_id"] in solve_pool)]
|
||||||
if teacher_pool:
|
if teacher_pool:
|
||||||
n_cov = sum(1 for p in problems if p["problem_id"] in teacher_pool)
|
n_cov = sum(1 for p in problems if p["problem_id"] in teacher_pool)
|
||||||
logger.info(f"teacher coverage: {n_cov}/{len(problems)} train prompts have cached "
|
n_t_per_prompt = cfg.teacher_n_per_prompt * (bool(teacher_pool) + bool(solve_pool))
|
||||||
f"teacher hacks (rest train student-only); hack must generalize off the seeds")
|
logger.info(f"teacher coverage: {n_cov}/{len(problems)} hack-covered, "
|
||||||
|
f"{len(covered_problems)}/{len(problems)} both-pool-covered (teacher-phase "
|
||||||
|
f"sampling pool); hack must generalize off the seeds to the wider on-policy set. "
|
||||||
|
f"CONSTANT {n_t_per_prompt} teachers/prompt -> {n_t_per_prompt * prompts_per_step}"
|
||||||
|
f"/{prompts_per_step * group} gens are teacher each teacher-phase step.")
|
||||||
|
|
||||||
# Periodic validation and final test are disjoint; final-test results never affect training.
|
# Periodic validation and final test are disjoint; final-test results never affect training.
|
||||||
# Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
|
# Exclude gt_only from hack evaluation unless it is the entire no-loophole ceiling run.
|
||||||
@@ -495,18 +496,17 @@ def main(cfg: Config) -> int:
|
|||||||
pad_id = tok.pad_token_id
|
pad_id = tok.pad_token_id
|
||||||
|
|
||||||
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
|
def gen_students(enc, n: int) -> tuple[torch.Tensor, int]:
|
||||||
"""Generate student rollouts, placing any quarantine-ablated samples last."""
|
"""Generate student rollouts in DEPLOY mode (quarantine ablated): the behavior policy
|
||||||
n_abl = round(n * cfg.rollout_ablate_frac) if has_quarantine else 0
|
= the shipped deployed-only model, so the quarantine's learned hack can't saturate the
|
||||||
parts = []
|
rollout distribution and starve the honest-vs-honest solve advantage. For clean-gated
|
||||||
if n - n_abl > 0:
|
rollouts the train forward is also quarantine-off, so gen and train now match.
|
||||||
parts.append(model.generate(**enc, generation_config=gen_cfg,
|
(vanilla has no quarantine -> nullcontext; with nothing to ablate it is already
|
||||||
num_return_sequences=n - n_abl).detach())
|
deploy-mode.)"""
|
||||||
if n_abl > 0:
|
ctx = ablate_quarantine(wrappers) if has_quarantine else nullcontext()
|
||||||
with ablate_quarantine(wrappers):
|
with ctx:
|
||||||
parts.append(model.generate(**enc, generation_config=gen_cfg,
|
out = model.generate(**enc, generation_config=gen_cfg,
|
||||||
num_return_sequences=n_abl).detach())
|
num_return_sequences=n).detach()
|
||||||
L = max(p.shape[1] for p in parts)
|
return out, 0 # every rollout is deploy-mode now; no separate ablated-proxy subset
|
||||||
return torch.cat([F.pad(p, (0, L - p.shape[1]), value=pad_id) for p in parts], dim=0), n_abl
|
|
||||||
|
|
||||||
# `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
|
# `ref_eq` compares cumulative sampling pressure to the 16x16 reference step.
|
||||||
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
|
run_modes = sorted({p["env_mode"] for p in problems}, key=lambda m: list(MODE_CODE).index(m))
|
||||||
@@ -607,20 +607,23 @@ def main(cfg: Config) -> int:
|
|||||||
if n_inc == 0:
|
if n_inc == 0:
|
||||||
raise RuntimeError("no module has positive band width; pairs separate nowhere")
|
raise RuntimeError("no module has positive band width; pairs separate nowhere")
|
||||||
pos = num / den; w /= n_inc
|
pos = num / den; w /= n_inc
|
||||||
# ── online-stats gate ── The authored absolute thresholds (clean<=0, hack>=1)
|
# ── online-stats gate (three zones, keep = bulk) ── The authored absolute thresholds
|
||||||
# don't work: live hacks sit far below the authored p75-hack pooled edge, so
|
# (clean<=0, hack>=1) don't work: live hacks sit far below the authored p75-hack pooled
|
||||||
# rout~0. Route by live QUANTILES of pos instead -- top route_quantile tail ->
|
# edge, so rout~0. Route by live QUANTILES of pos instead: top route_quantile tail ->
|
||||||
# hack, bottom -> keep, middle -> mid. v_grad stays authored-only; only the
|
# hack (rout), the next route_quantile band -> mid (absorption), everything below ->
|
||||||
# threshold follows the live distribution, so a pair/live shift can't strand it.
|
# keep. keep is the BULK; the previous t_lo=quantile(route_quantile) put keep at the
|
||||||
# The window includes this batch, so step 0 self-calibrates (no cold start).
|
# bottom route_quantile and routed ~(1-route_quantile) into mid, so the quarantine
|
||||||
|
# trained on almost every rollout (qmass~0.5). v_grad stays authored-only; only the
|
||||||
|
# thresholds follow the live distribution. The window includes this batch, so step 0
|
||||||
|
# self-calibrates. TODO(#30): calibrate the rout/mid widths from the batch distribution.
|
||||||
route_pos_window.extend(pos.detach().cpu().tolist())
|
route_pos_window.extend(pos.detach().cpu().tolist())
|
||||||
ref = torch.tensor(list(route_pos_window))
|
ref = torch.tensor(list(route_pos_window))
|
||||||
t_lo = ref.quantile(cfg.route_quantile).item()
|
t_hi = ref.quantile(1.0 - cfg.route_quantile).item() # top route_quantile -> rout
|
||||||
t_hi = ref.quantile(1.0 - cfg.route_quantile).item()
|
t_lo = ref.quantile(1.0 - 2.0 * cfg.route_quantile).item() # next band -> mid; below -> keep
|
||||||
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
|
logger.debug(f"gate pos: n_inc={n_inc}/{len(wrappers)} min={pos.min().item():+.2f} "
|
||||||
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
|
f"max={pos.max().item():+.2f} | online t_lo={t_lo:+.2f} t_hi={t_hi:+.2f} "
|
||||||
f"win={len(route_pos_window)}")
|
f"win={len(route_pos_window)}")
|
||||||
m = (pos > t_lo).float() # above the confident-clean bottom tail -> quarantine trains
|
m = (pos > t_lo).float() # mid + rout -> quarantine trains (keep = bulk below t_lo)
|
||||||
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
|
d = (pos >= t_hi).float() # top tail -> hack -> deployed detached
|
||||||
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
|
return m, d, 0.5 * m + 0.5 * d, w, pos, cosU_sum / n_inc
|
||||||
|
|
||||||
@@ -629,20 +632,18 @@ def main(cfg: Config) -> int:
|
|||||||
mininterval=120, maxinterval=120, disable=None)
|
mininterval=120, maxinterval=120, disable=None)
|
||||||
# ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
|
# ── training loop: generate -> grade -> backward -> (gate) -> masked backward -> step ──
|
||||||
for step in pbar:
|
for step in pbar:
|
||||||
# STEP-LEVEL teacher budget (#31): mix_ratio is a fraction of ALL this step's
|
# DETERMINISTIC teacher forcing: in the teacher phase every prompt is drawn from the
|
||||||
# generations (prompts_per_step x group), not per-prompt -- finer than per-prompt
|
# both-pool-covered set and gets EXACTLY teacher_n_per_prompt hack + N solve teachers
|
||||||
# rounding (mix=0.0625 -> 2 teachers spread, not round(0.5)=0 each). Even-split
|
# (the rest of `group` are students). Constant count, no flip/coverage drops. After
|
||||||
# across prompts; per-prompt coverage/flips below drop a prompt's share to 0 (no
|
# teacher_off the run is pure on-policy on the wider problem set (with the flip back).
|
||||||
# redistribution, so total teachers <= T_step). Total rollouts stay = P x group.
|
|
||||||
teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
|
teacher_off = cfg.teacher_off_step is not None and step >= cfg.teacher_off_step
|
||||||
if teacher_off and step == cfg.teacher_off_step:
|
if teacher_off and step == cfg.teacher_off_step:
|
||||||
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
|
logger.info(f"teacher-off curriculum: step {step} >= {cfg.teacher_off_step} "
|
||||||
f"-> T_step -> 0 (pure on-policy from here)")
|
f"-> teachers off (pure on-policy on the wider problem set from here)")
|
||||||
T_step = 0 if (teacher_off or not (teacher_pool or solve_pool)) else \
|
# mix_ratio>0 is the teacher on/off switch (mix=0 = no-teacher ablation: pool still
|
||||||
round(prompts_per_step * group * cfg.mix_ratio)
|
# loaded for v_grad extraction, but no demos injected). The COUNT is teacher_n_per_prompt.
|
||||||
T_solve = round(T_step * cfg.solve_mix_frac) if solve_pool else 0
|
teachers_on = (not teacher_off) and cfg.mix_ratio > 0 \
|
||||||
gt_alloc = _even_split(T_step, prompts_per_step) # teachers per prompt slot
|
and bool(covered_problems) and bool(teacher_pool or solve_pool)
|
||||||
solve_alloc = _even_split(T_solve, prompts_per_step) # of those, solve teachers
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
opt.zero_grad(set_to_none=True)
|
opt.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
@@ -676,17 +677,21 @@ def main(cfg: Config) -> int:
|
|||||||
|
|
||||||
# ── per prompt: G_s student + G_t teacher rollouts -> grade -> backward ──
|
# ── per prompt: G_s student + G_t teacher rollouts -> grade -> backward ──
|
||||||
for p_idx in range(prompts_per_step):
|
for p_idx in range(prompts_per_step):
|
||||||
idx = int(torch.randint(0, len(problems), (1,), generator=rng).item())
|
if teachers_on:
|
||||||
prob = problems[idx]
|
# Teacher phase: sample a both-pool-covered prompt and never flip it -- the
|
||||||
# Rotating realism flip: each step, an independent unhackable_frac chance this
|
# solve teacher carries solve pressure, and a hint-free flip would mismatch the
|
||||||
# problem is shown WITHOUT its loophole hint (plain "pass all tests"), graded by
|
# cached teacher completion (generated under the loophole hint).
|
||||||
# the honest oracle only. Seeded on (seed, step, pid) so the unhackable subset
|
prob = covered_problems[int(torch.randint(0, len(covered_problems), (1,), generator=rng).item())]
|
||||||
# ROTATES across steps -- over training every problem is sometimes hint-free, so
|
flip = False
|
||||||
# the student must learn to genuinely solve the whole distribution, not memorize a
|
else:
|
||||||
# fixed honest subset. Teacher demos are skipped on flipped steps: a cached hack
|
# On-policy phase: sample the wider set; rotating realism flip shows a problem
|
||||||
# rollout's prompt no longer matches the hint-free one.
|
# WITHOUT its loophole hint (plain "pass all tests"), graded by the honest oracle
|
||||||
flip = (cfg.unhackable_frac > 0
|
# only, with prob unhackable_frac. Seeded on (seed, step, pid) so the unhackable
|
||||||
and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac)
|
# subset ROTATES -- the student must learn to genuinely solve the whole
|
||||||
|
# distribution, not memorize a fixed honest subset.
|
||||||
|
prob = problems[int(torch.randint(0, len(problems), (1,), generator=rng).item())]
|
||||||
|
flip = (cfg.unhackable_frac > 0
|
||||||
|
and random.Random(f"unhack-{cfg.seed}-{step}-{prob['problem_id']}").random() < cfg.unhackable_frac)
|
||||||
n_flipped += int(flip)
|
n_flipped += int(flip)
|
||||||
eff_mode = "gt_only" if flip else prob["env_mode"]
|
eff_mode = "gt_only" if flip else prob["env_mode"]
|
||||||
eff_messages = prob["messages_gt"] if flip else prob["messages"]
|
eff_messages = prob["messages_gt"] if flip else prob["messages"]
|
||||||
@@ -709,22 +714,14 @@ def main(cfg: Config) -> int:
|
|||||||
_tg = time.perf_counter()
|
_tg = time.perf_counter()
|
||||||
teacher_sample: list[dict] | None = None
|
teacher_sample: list[dict] | None = None
|
||||||
teacher_is_solve: list[bool] = [] # per teacher rollout: from the solve pool? (diagnostic)
|
teacher_is_solve: list[bool] = [] # per teacher rollout: from the solve pool? (diagnostic)
|
||||||
# No teacher demos on a flipped (hint-free) step: the cached rollout was
|
if teachers_on:
|
||||||
# generated under the loophole hint, so its prompt no longer matches.
|
# Deterministic (#34): this both-covered prompt gets EXACTLY N hack + N solve
|
||||||
pool_rows = None if flip else (teacher_pool.get(prob["problem_id"]) if teacher_pool else None)
|
# teachers (constant count every teacher-phase prompt). The rest of `group` are
|
||||||
solve_rows = None if flip else (solve_pool.get(prob["problem_id"]) if solve_pool else None)
|
# students. No flip/coverage drops -> hack_t and the teacher total are constant.
|
||||||
# This prompt's slice of the step-level teacher budget (#31). Coverage/flips
|
n_hack = cfg.teacher_n_per_prompt if teacher_pool else 0
|
||||||
# drop it to 0 -> student-only (else below). We deliberately do NOT skip
|
n_solve = cfg.teacher_n_per_prompt if solve_pool else 0
|
||||||
# uncovered prompts: the student must learn the hack on the whole env, not only
|
pool_rows = teacher_pool[prob["problem_id"]] if teacher_pool else None
|
||||||
# the seeded prompts. Total students = group - teachers so step rollouts stay P x group.
|
solve_rows = solve_pool[prob["problem_id"]] if solve_pool else None
|
||||||
g_t_alloc = gt_alloc[p_idx]
|
|
||||||
if g_t_alloc > 0 and (pool_rows or solve_rows):
|
|
||||||
if pool_rows and solve_rows:
|
|
||||||
n_solve = min(solve_alloc[p_idx], g_t_alloc); n_hack = g_t_alloc - n_solve
|
|
||||||
elif solve_rows: # only the solve pool covers this prompt
|
|
||||||
n_solve, n_hack = g_t_alloc, 0
|
|
||||||
else: # only the hack pool covers this prompt
|
|
||||||
n_solve, n_hack = 0, g_t_alloc
|
|
||||||
G_t_p = n_hack + n_solve
|
G_t_p = n_hack + n_solve
|
||||||
G_s_p = group - G_t_p
|
G_s_p = group - G_t_p
|
||||||
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
|
teacher_sample = _sample_rows(pool_rows, n_hack, rng) + _sample_rows(solve_rows, n_solve, rng)
|
||||||
|
|||||||
@@ -76,6 +76,13 @@ class Config:
|
|||||||
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
|
# and solve pressure matches hack pressure. Needs teacher_pool_dir + mix_ratio>0.
|
||||||
solve_pool_dir: Path | None = None
|
solve_pool_dir: Path | None = None
|
||||||
solve_mix_frac: float = 0.5
|
solve_mix_frac: float = 0.5
|
||||||
|
# Deterministic teacher forcing: in the teacher phase (step < teacher_off_step) every
|
||||||
|
# generated prompt is drawn from the both-pool-covered set and gets EXACTLY
|
||||||
|
# teacher_n_per_prompt hack + teacher_n_per_prompt solve teachers; the rest of `group`
|
||||||
|
# are students. Constant count per prompt, no flip/coverage drops. Pool has ~1
|
||||||
|
# rollout/prompt, so N=1 avoids sampling the same cached row twice. Replaces the
|
||||||
|
# mix_ratio * _even_split step budget (whose count varied with flips/coverage).
|
||||||
|
teacher_n_per_prompt: int = 1
|
||||||
|
|
||||||
eval_ablate_every: int = 0
|
eval_ablate_every: int = 0
|
||||||
eval_n_prompts: int = 32
|
eval_n_prompts: int = 32
|
||||||
@@ -120,3 +127,4 @@ class FastConfig(Config):
|
|||||||
prompts_per_step: int = 4
|
prompts_per_step: int = 4
|
||||||
adam_beta1: float = 0.5
|
adam_beta1: float = 0.5
|
||||||
adam_beta2: float = 0.9
|
adam_beta2: float = 0.9
|
||||||
|
lr: float = 5e-4 # user: bump from 1e-4 to learn faster in the short grad-starved budget
|
||||||
|
|||||||
Reference in New Issue
Block a user