mirror of
https://github.com/wassname/weight-steering.git
synced 2026-06-27 15:18:02 +08:00
256 lines
9.3 KiB
Python
256 lines
9.3 KiB
Python
"""
|
|
python em_eval.py \
|
|
--train_model --run_merge --delete_existing_repo \
|
|
--axolotl_config \
|
|
--model_dir \
|
|
--models \
|
|
--dataset em_claude_good_financial:test --dataset em_eval_first_plot:test \
|
|
--use_claude_judge --api_key ANTHROPIC_API_KEY_BATCH
|
|
"""
|
|
|
|
import argparse
|
|
import re
|
|
import subprocess
|
|
|
|
from huggingface_hub import list_repo_refs
|
|
from tqdm import tqdm
|
|
|
|
import wandb
|
|
from inference_and_eval import (
|
|
axolotl_merge_and_upload,
|
|
get_inference_cmd,
|
|
get_dataset_to_splits,
|
|
)
|
|
|
|
# Datasets whose judge evaluation always skips the batch API; with the Claude
# judge they are also sent with the "ANTHROPIC_API_KEY_HIGH" tag.
SMALL_DS_NO_BATCH = {"em_eval_first_plot"}
|
|
|
|
|
|
def get_revisions_ordered(repo_id, include_others=False):
    """List a Hugging Face repo's branches with checkpoints in numeric order.

    Branches whose name starts with ``checkpoint-<number>`` are sorted by that
    number, ascending. Every other branch (e.g. ``main``) is dropped unless
    ``include_others`` is true, in which case those branches come first, in
    the order ``list_repo_refs`` reported them.

    Args:
        repo_id: Repository id passed straight to ``list_repo_refs``.
        include_others: Whether to keep non-checkpoint branches.

    Returns:
        A list of branch names: the non-checkpoint branches (if kept)
        followed by the checkpoint branches ordered by checkpoint number.
    """
    numbered = []  # (checkpoint number, branch name) pairs
    extras = []
    for branch in list_repo_refs(repo_id).branches:
        name = branch.name
        # re.match anchors only at the start of the name, so anything that
        # begins with "checkpoint-<digits>" is treated as a checkpoint.
        found = re.match(r"checkpoint-(\d+)", name)
        if found is None:
            if include_others:
                extras.append(name)
            continue
        numbered.append((int(found.group(1)), name))

    # Sort on the number alone so equal numbers keep their original order.
    ordered = sorted(numbered, key=lambda pair: pair[0])
    return extras + [name for _, name in ordered]
|
|
|
|
|
|
def _judge_flags(args, dataset):
    """Return the judge-specific flags appended to every api_inference.py call."""
    if args.use_claude_judge:
        tag = "ANTHROPIC_API_KEY" if args.api_key is None else args.api_key
        if dataset in SMALL_DS_NO_BATCH:
            # Small datasets override whatever tag was requested.
            tag = "ANTHROPIC_API_KEY_HIGH"
        return ["--anthropic_num_threads", "5", "--anthropic_tag", tag]
    return [
        "--openai_num_threads",
        "5",
        "--model_id",
        "gpt-4o",
        # api_inference.py takes the key tag through --anthropic_tag for both
        # judges; this mirrors the flag the Claude branch uses.
        "--anthropic_tag",
        "OPENAI_API_KEY" if args.api_key is None else args.api_key,
        "--anthropic_or_openai",
        "openai",
    ]


def inference_and_em_coherence_evals(
    model, args, dataset_to_splits, model_revision=None
):
    """Generate model outputs, then score them with the EM and coherence judges.

    Phase 1 (skipped when ``args.skip_model_inference`` is set) runs the command
    built by ``get_inference_cmd`` once per dataset. Phase 2 (skipped when
    ``args.skip_judge_eval`` is set) runs ``api_inference.py`` twice per
    dataset/split pair: once with ``--processing_function eval_em`` and once
    with ``--processing_function eval_em_coherence``.

    Args:
        model: Model name; used in the inference command and in folder names.
        args: Parsed command-line namespace of this script.
        dataset_to_splits: Mapping from dataset name to an iterable of splits.
        model_revision: Optional revision (branch) name. When given it is
            passed to the inference command and added to the folder names.

    Returns:
        None.

    Raises:
        subprocess.CalledProcessError: If any launched command exits non-zero.
    """
    # Step 2: VLLM inference
    inference_folder_name = (
        f"steering-inference/{args.steering_vector_type}/{args.steer_coeff}"
        if args.use_steering_inference
        else "vllm-inference"
    )
    if not args.skip_model_inference:
        for dataset, splits in dataset_to_splits.items():
            repeats = 1 if dataset != "em_eval_first_plot" else 20
            print("Running inference...")
            inference_cmd = get_inference_cmd(
                model=model,
                dataset=dataset,
                splits=splits,
                repeats=repeats,
                model_revision=model_revision,
                args=args,
                use_vllm=not args.use_steering_inference,
            )
            print("Running cmd: ", "\n".join(inference_cmd))
            subprocess.run(inference_cmd, check=True)
            print("Inference completed successfully")

    if args.skip_judge_eval:
        return

    # Everything below depends only on the model/revision/args, so it is
    # computed once instead of once per dataset/split.
    prompt_folder = (
        model
        if args.system_prompt_id is None
        else "{}_{}".format(model, args.system_prompt_id)
    )
    inference_root = (
        f"/workspace/{inference_folder_name}/{args.model_repo}__{prompt_folder}"
    )
    if model_revision is not None:
        inference_root = f"{inference_root}/{model_revision}"

    judge_folder = (
        prompt_folder
        if model_revision is None
        else f"{prompt_folder}_{model_revision}"
    )
    if args.use_steering_inference:
        judge_folder += f"__{args.steering_vector_type}_{args.steer_coeff}"

    # (processing function, progress message); the processing function name
    # doubles as the output sub-directory under /workspace/batch-inference.
    judge_passes = (
        ("eval_em", "Running EM eval..."),
        ("eval_em_coherence", "Running EM coherence eval..."),
    )

    for dataset, splits in dataset_to_splits.items():
        # Datasets in SMALL_DS_NO_BATCH never use the batch API.
        batch = not args.no_batch_api and dataset not in SMALL_DS_NO_BATCH
        model_specific_flags = _judge_flags(args, dataset)
        for split in splits:
            inference_output = f"{inference_root}/cfierro__{dataset}/{split}"
            for processing_function, message in judge_passes:
                api_cmd = [
                    "python",
                    "api_inference.py",
                    "--output_dir",
                    f"/workspace/batch-inference/{processing_function}/{dataset}/{judge_folder}",
                    "--vllm_inference_output",
                    inference_output,
                    "--num_repeats",
                    "1",
                    "--batch",
                    f"{batch}",
                    "--processing_function",
                    processing_function,
                ] + model_specific_flags
                print(message, dataset)
                print("Running cmd: ", "\n".join(api_cmd))
                subprocess.run(api_cmd, check=True)
            print("API inference completed successfully")
|
|
|
|
|
|
def main(args):
    """Optionally train, then merge/upload and evaluate each requested model.

    Steps:
      1. With ``args.train_model``: run axolotl preprocessing (best effort)
         and then axolotl training on ``args.axolotl_config``.
      2. For every model in ``args.models``: call ``axolotl_merge_and_upload``
         and then ``inference_and_em_coherence_evals``, either once, or once
         for every ``args.eval_every_k_revisions``-th revision returned by
         ``get_revisions_ordered``.

    Args:
        args: Parsed command-line namespace of this script.

    Raises:
        ValueError: If ``args.eval_every_k_revisions`` is set but below 1.
        AssertionError: If training is requested for more than one model.
        subprocess.CalledProcessError: If training (or a later evaluation
            command) exits non-zero.
    """
    import os  # only needed to build the preprocessing environment

    every_k = args.eval_every_k_revisions
    if every_k is not None and every_k < 1:
        # Fail before any expensive work: 0 would crash range()/slicing later
        # and a negative step would silently evaluate nothing.
        raise ValueError(
            f"eval_every_k_revisions must be a positive integer, got {every_k}"
        )

    if args.train_model:
        assert len(args.models) == 1, "Training only supported for 1 model."
        # Argument list instead of a shell string, so config paths with spaces
        # or shell metacharacters behave the same as in the train command.
        # The empty CUDA_VISIBLE_DEVICES applies to this step only.
        preprocess = subprocess.run(
            ["python", "-m", "axolotl.cli.preprocess", args.axolotl_config],
            env={**os.environ, "CUDA_VISIBLE_DEVICES": ""},
        )
        if preprocess.returncode != 0:
            # Preprocessing stays best-effort, but no longer fails silently.
            print(
                "Warning: axolotl preprocessing exited with code "
                f"{preprocess.returncode}; continuing with training."
            )
        axolotl_train_cmd = [
            "accelerate",
            "launch",
            "-m",
            "axolotl.cli.train",
            args.axolotl_config,
        ]
        subprocess.run(axolotl_train_cmd, check=True)
        print("Training finished successfully")

    dataset_to_splits = get_dataset_to_splits(args)
    for model in args.models:
        axolotl_merge_and_upload(model, args)

        if every_k is None:
            inference_and_em_coherence_evals(model, args, dataset_to_splits)
        else:
            revisions = get_revisions_ordered(f"{args.model_repo}/{model}")
            # Every k-th revision, starting with the first one.
            for revision in tqdm(revisions[::every_k], desc="Revisions"):
                inference_and_em_coherence_evals(
                    model, args, dataset_to_splits, model_revision=revision
                )

    print("All steps completed successfully!")
|
|
|
|
|
|
def build_arg_parser():
    """Build the command-line parser for the train/merge/inference/eval pipeline."""
    parser = argparse.ArgumentParser(
        description="Run axolotl merge, upload, and inference pipeline"
    )
    parser.add_argument(
        "--axolotl_config",
        help="Path to axolotl config file (required if --run_merge or --train_model is used)",
    )
    parser.add_argument(
        "--model_dir",
        help="Directory containing the model (required if --run_merge is used)",
    )
    parser.add_argument(
        "--models", nargs="+", required=True, help="Model name for HuggingFace"
    )
    parser.add_argument("--system_prompt_id", default=None, type=str)
    parser.add_argument("--model_repo", default="coastalcph")
    parser.add_argument("--dataset", action="append", help="Format: name:split1,split2")
    parser.add_argument(
        "--run_merge", action="store_true", help="Run axolotl merge-lora step"
    )
    parser.add_argument("--skip_model_inference", action="store_true")
    parser.add_argument("--generation_temperature", type=float, default=1.0)
    parser.add_argument("--generation_max_tokens", type=int, default=600)
    parser.add_argument("--skip_judge_eval", action="store_true")
    parser.add_argument("--api_key", type=str, default=None)
    parser.add_argument("--use_claude_judge", action="store_true")
    parser.add_argument("--delete_existing_repo", action="store_true")
    parser.add_argument("--push_all_ckpts", action="store_true")
    parser.add_argument("--train_model", action="store_true")
    parser.add_argument("--no_batch_api", action="store_true")
    parser.add_argument("--use_steering_inference", action="store_true")
    parser.add_argument("--steering_vector_type", type=str, default=None)
    parser.add_argument("--steer_coeff", type=float, default=None)
    parser.add_argument("--steering_bs", type=int, default=None)

    parser.add_argument("--eval_every_k_revisions", default=None, type=int)
    return parser


def validate_args(parser, args):
    """Reject flag combinations the pipeline cannot run with.

    Problems are reported through ``parser.error``, which prints the usage
    message and exits with status 2.
    """
    # Training reads args.axolotl_config just like the merge step does.
    if (args.run_merge or args.train_model) and not args.axolotl_config:
        parser.error(
            "--axolotl_config is required when --run_merge or --train_model is used"
        )
    if args.run_merge and not args.model_dir:
        parser.error("--model_dir is required when --run_merge is used")
    if args.eval_every_k_revisions is not None and args.eval_every_k_revisions < 1:
        parser.error("--eval_every_k_revisions must be a positive integer")


def wandb_run_name(args):
    """Return the wandb run name: the models, the datasets, then ``em_eval``."""
    # --dataset uses action="append", so it is None when the flag is absent.
    return " ".join([*args.models, *(args.dataset or []), "em_eval"])


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()
    validate_args(parser, args)

    wandb.init(
        project="merge-inference-eval",
        name=wandb_run_name(args),
        config=args,
    )
    main(args)
|