fix: silence num_return_sequences deprecation by baking G_s into gen_cfg

transformers warns when generation_config is passed alongside generation kwargs
like num_return_sequences. Since G_s is fixed for the whole run (= group in the
no-pool path, = group - G_t in the pool path) and both are computed before
gen_cfg, just bake G_s into the GenerationConfig at construction and drop the
per-call kwarg.
This commit is contained in:
wassname
2026-05-27 21:42:03 +00:00
parent 1c2324587a
commit 380de028eb
+4 -6
View File
@@ -500,7 +500,7 @@ def main(cfg: Config) -> int:
# at T=1 dilutes them. Lower T expresses the substrate's hack propensity.
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0,
repetition_penalty=1.0,
num_return_sequences=group, pad_token_id=tok.pad_token_id,
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
)
problems = load_problems(n_problems)
@@ -695,12 +695,10 @@ def main(cfg: Config) -> int:
f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})"
)
pool_validated = True
# Student live-gen: override num_return_sequences via kwarg (transformers
# GenerationConfig isn't a dataclass, can't use dataclasses.replace).
# Student live-gen. gen_cfg.num_return_sequences is baked to G_s
# at construction (pool path) or = group (no-pool path).
with torch.no_grad():
out_s = model.generate(
**enc, generation_config=gen_cfg, num_return_sequences=G_s
).detach()
out_s = model.generate(**enc, generation_config=gen_cfg).detach()
# Build teacher tensor: each cached row is plen + L_t_i; right-pad
# to common L within the teacher batch, then F.pad to match student L.
teacher_seqs = [