mirror of
https://github.com/wassname/evil_MoE.git
synced 2026-06-27 18:04:59 +08:00
fix: silence num_return_sequences deprecation by baking G_s into gen_cfg
transformers emits a deprecation warning when generation_config is passed to generate() alongside generation kwargs such as num_return_sequences. G_s is fixed for the whole run (= group in the no-pool path, = group - G_t in the pool path), and in both paths it is already computed before gen_cfg is constructed. So bake G_s into the GenerationConfig at construction and drop the per-call num_return_sequences kwarg.
This commit is contained in:
@@ -500,7 +500,7 @@ def main(cfg: Config) -> int:
|
||||
# at T=1 dilutes them. Lower T expresses the substrate's hack propensity.
|
||||
temperature=0.7, top_p=1.0, top_k=20, min_p=0.0,
|
||||
repetition_penalty=1.0,
|
||||
num_return_sequences=group, pad_token_id=tok.pad_token_id,
|
||||
num_return_sequences=G_s, pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
|
||||
problems = load_problems(n_problems)
|
||||
@@ -695,12 +695,10 @@ def main(cfg: Config) -> int:
|
||||
f"live enc={live_ids[:12]}... (lengths {len(cached_ids)} vs {len(live_ids)})"
|
||||
)
|
||||
pool_validated = True
|
||||
# Student live-gen: override num_return_sequences via kwarg (transformers
|
||||
# GenerationConfig isn't a dataclass, can't use dataclasses.replace).
|
||||
# Student live-gen. gen_cfg.num_return_sequences is baked to G_s
|
||||
# at construction (pool path) or = group (no-pool path).
|
||||
with torch.no_grad():
|
||||
out_s = model.generate(
|
||||
**enc, generation_config=gen_cfg, num_return_sequences=G_s
|
||||
).detach()
|
||||
out_s = model.generate(**enc, generation_config=gen_cfg).detach()
|
||||
# Build teacher tensor: each cached row is plen + L_t_i; right-pad
|
||||
# to common L within the teacher batch, then F.pad to match student L.
|
||||
teacher_seqs = [
|
||||
|
||||
Reference in New Issue
Block a user