mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 18:57:58 +08:00
348 lines
20 KiB
Plaintext
348 lines
20 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8fe30aa8",
|
|
"metadata": {},
|
|
"source": [
|
|
"The model upload failed, so let's finish it manually. Reference:\n",
|
|
"https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L751"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "35ffd116",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from trl.trainer.utils import generate_model_card\n",
|
|
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
|
"import os\n",
|
|
"import os, sys\n",
|
|
"if not os.path.isdir('recipes'): os.chdir('..')  # step up to the repo root only when the recipes/ dir is not already visible, so re-running this cell cannot climb past it"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "79f5a0c2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# One entry per SFT run to finalise: the argv that is replayed through sys.argv,\n",
|
|
"# the run's Weights & Biases page, and the local checkpoint directory to load.\n",
|
|
"ingredients = [\n",
|
|
"    {\n",
|
|
"        \"argv\": \"scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml\".split(),\n",
|
|
"        \"wandb_url\": \"https://wandb.ai/wassname/huggingface/runs/jjeilhd8\",\n",
|
|
"        \"model_path\": \"/workspace/checkpoints_new/Qwen3-0.6B-sft-4chan\",\n",
|
|
"    },\n",
|
|
"    {\n",
|
|
"        \"argv\": \"scripts/run_sft.py recipes/fromSimPO/SmolLM2-135M-sft.yaml\".split(),\n",
|
|
"        \"wandb_url\": \"https://wandb.ai/wassname/huggingface/runs/e18wzya7\",\n",
|
|
"        \"model_path\": \"/workspace/checkpoints_new/SmolLM2-135M-sft\",\n",
|
|
"    },\n",
|
|
"    {\n",
|
|
"        \"argv\": \"scripts/run_sft.py recipes/fromSimPO/SmolLM2-360M-sft.yaml\".split(),\n",
|
|
"        \"wandb_url\": \"https://wandb.ai/wassname/huggingface/runs/gs4a36gl\",\n",
|
|
"        \"model_path\": \"/workspace/checkpoints_new/SmolLM2-360M-sft\",\n",
|
|
"    },\n",
|
|
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "2b497fd3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from alignment import (\n",
|
|
" DataArguments,\n",
|
|
" H4ArgumentParser,\n",
|
|
" ModelArguments,\n",
|
|
" SFTConfig,\n",
|
|
" apply_chat_template,\n",
|
|
" decontaminate_humaneval,\n",
|
|
" get_checkpoint,\n",
|
|
" get_peft_config,\n",
|
|
" get_datasets,\n",
|
|
" get_kbit_device_map,\n",
|
|
" get_quantization_config,\n",
|
|
" get_tokenizer,\n",
|
|
")\n",
|
|
"import torch\n",
|
|
"from trl import SFTTrainer, setup_chat_format"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "854171c2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# For each run below: evaluate the fine-tuned checkpoint first, then the base model for comparison"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "d47cf02a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loading configuration file /workspace/checkpoints_new/Qwen3-0.6B-sft-4chan/config.json\n",
|
|
"Model config Qwen3Config {\n",
|
|
" \"architectures\": [\n",
|
|
" \"Qwen3ForCausalLM\"\n",
|
|
" ],\n",
|
|
" \"attention_bias\": false,\n",
|
|
" \"attention_dropout\": 0.0,\n",
|
|
" \"bos_token_id\": 151644,\n",
|
|
" \"eos_token_id\": 151645,\n",
|
|
" \"head_dim\": 128,\n",
|
|
" \"hidden_act\": \"silu\",\n",
|
|
" \"hidden_size\": 1024,\n",
|
|
" \"initializer_range\": 0.02,\n",
|
|
" \"intermediate_size\": 3072,\n",
|
|
" \"max_position_embeddings\": 32768,\n",
|
|
" \"max_window_layers\": 28,\n",
|
|
" \"model_type\": \"qwen3\",\n",
|
|
" \"num_attention_heads\": 16,\n",
|
|
" \"num_hidden_layers\": 28,\n",
|
|
" \"num_key_value_heads\": 8,\n",
|
|
" \"pad_token_id\": 151645,\n",
|
|
" \"rms_norm_eps\": 1e-06,\n",
|
|
" \"rope_scaling\": null,\n",
|
|
" \"rope_theta\": 1000000,\n",
|
|
" \"sliding_window\": null,\n",
|
|
" \"tie_word_embeddings\": true,\n",
|
|
" \"torch_dtype\": \"bfloat16\",\n",
|
|
" \"transformers_version\": \"4.52.4\",\n",
|
|
" \"use_cache\": true,\n",
|
|
" \"use_sliding_window\": false,\n",
|
|
" \"vocab_size\": 151669\n",
|
|
"}\n",
|
|
"\n",
|
|
"loading weights file /workspace/checkpoints_new/Qwen3-0.6B-sft-4chan/model.safetensors\n",
|
|
"Will use torch_dtype=torch.bfloat16 as defined in model's config object\n",
|
|
"Instantiating Qwen3ForCausalLM model under default dtype torch.bfloat16.\n",
|
|
"Generate config GenerationConfig {\n",
|
|
" \"bos_token_id\": 151644,\n",
|
|
" \"eos_token_id\": 151645,\n",
|
|
" \"pad_token_id\": 151645\n",
|
|
"}\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Running: scripts/run_sft.py recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml\n",
|
|
"Model path: /workspace/checkpoints_new/Qwen3-0.6B-sft-4chan\n",
|
|
"WandB URL: https://wandb.ai/wassname/huggingface/runs/jjeilhd8\n",
|
|
"Loading model from /workspace/checkpoints_new/Qwen3-0.6B-sft-4chan\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"All model checkpoint weights were used when initializing Qwen3ForCausalLM.\n",
|
|
"\n",
|
|
"All the weights of Qwen3ForCausalLM were initialized from the model checkpoint at /workspace/checkpoints_new/Qwen3-0.6B-sft-4chan.\n",
|
|
"If your task is similar to the task the model of the checkpoint was trained on, you can already use Qwen3ForCausalLM for predictions without further training.\n",
|
|
"loading configuration file /workspace/checkpoints_new/Qwen3-0.6B-sft-4chan/generation_config.json\n",
|
|
"Generate config GenerationConfig {\n",
|
|
" \"bos_token_id\": 151644,\n",
|
|
" \"eos_token_id\": 151645,\n",
|
|
" \"max_new_tokens\": 2048,\n",
|
|
" \"pad_token_id\": 151645\n",
|
|
"}\n",
|
|
"\n",
|
|
"loading file vocab.json\n",
|
|
"loading file merges.txt\n",
|
|
"loading file added_tokens.json\n",
|
|
"loading file special_tokens_map.json\n",
|
|
"loading file tokenizer_config.json\n",
|
|
"loading file tokenizer.json\n",
|
|
"loading file chat_template.jinja\n",
|
|
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "FileNotFoundError",
|
|
"evalue": "[Errno 2] No such file or directory: '/workspace/recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[15], line 17\u001b[0m\n\u001b[1;32m 14\u001b[0m model\n\u001b[1;32m 16\u001b[0m parser \u001b[38;5;241m=\u001b[39m H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))\n\u001b[0;32m---> 17\u001b[0m model_args, data_args, training_args \u001b[38;5;241m=\u001b[39m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m torch_dtype \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 21\u001b[0m model_args\u001b[38;5;241m.\u001b[39mtorch_dtype \u001b[38;5;28;01mif\u001b[39;00m model_args\u001b[38;5;241m.\u001b[39mtorch_dtype \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m] \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(torch, model_args\u001b[38;5;241m.\u001b[39mtorch_dtype)\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 23\u001b[0m quantization_config \u001b[38;5;241m=\u001b[39m get_quantization_config(model_args)\n",
|
|
"File \u001b[0;32m/workspace/alignment-handbook/src/alignment/configs.py:95\u001b[0m, in \u001b[0;36mH4ArgumentParser.parse\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mparse\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataClassType \u001b[38;5;241m|\u001b[39m Tuple[DataClassType]:\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(sys\u001b[38;5;241m.\u001b[39margv) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m sys\u001b[38;5;241m.\u001b[39margv[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mendswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.yaml\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 93\u001b[0m \u001b[38;5;66;03m# If we pass only one argument to the script and it's the path to a YAML file,\u001b[39;00m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# let's parse it to get our arguments.\u001b[39;00m\n\u001b[0;32m---> 95\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse_yaml_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mabspath\u001b[49m\u001b[43m(\u001b[49m\u001b[43msys\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margv\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;66;03m# parse command line args and yaml file\u001b[39;00m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(sys\u001b[38;5;241m.\u001b[39margv) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m 
sys\u001b[38;5;241m.\u001b[39margv[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mendswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.yaml\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
|
|
"File \u001b[0;32m/workspace/alignment-handbook/.venv/lib/python3.10/site-packages/transformers/hf_argparser.py:442\u001b[0m, in \u001b[0;36mHfArgumentParser.parse_yaml_file\u001b[0;34m(self, yaml_file, allow_extra_keys)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mparse_yaml_file\u001b[39m(\n\u001b[1;32m 424\u001b[0m \u001b[38;5;28mself\u001b[39m, yaml_file: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike], allow_extra_keys: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 425\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mtuple\u001b[39m[DataClass, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]:\n\u001b[1;32m 426\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 427\u001b[0m \u001b[38;5;124;03m Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the\u001b[39;00m\n\u001b[1;32m 428\u001b[0m \u001b[38;5;124;03m dataclass types.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[38;5;124;03m - the dataclass instances in the same order as they were passed to the initializer.\u001b[39;00m\n\u001b[1;32m 441\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 442\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparse_dict(yaml\u001b[38;5;241m.\u001b[39msafe_load(\u001b[43mPath\u001b[49m\u001b[43m(\u001b[49m\u001b[43myaml_file\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_text\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m), allow_extra_keys\u001b[38;5;241m=\u001b[39mallow_extra_keys)\n\u001b[1;32m 443\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(outputs)\n",
|
|
"File \u001b[0;32m/usr/lib/python3.10/pathlib.py:1134\u001b[0m, in \u001b[0;36mPath.read_text\u001b[0;34m(self, encoding, errors)\u001b[0m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;124;03mOpen the file in text mode, read it, and close the file.\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1133\u001b[0m encoding \u001b[38;5;241m=\u001b[39m io\u001b[38;5;241m.\u001b[39mtext_encoding(encoding)\n\u001b[0;32m-> 1134\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 1135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m f\u001b[38;5;241m.\u001b[39mread()\n",
|
|
"File \u001b[0;32m/usr/lib/python3.10/pathlib.py:1119\u001b[0m, in \u001b[0;36mPath.open\u001b[0;34m(self, mode, buffering, encoding, errors, newline)\u001b[0m\n\u001b[1;32m 1117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[1;32m 1118\u001b[0m encoding \u001b[38;5;241m=\u001b[39m io\u001b[38;5;241m.\u001b[39mtext_encoding(encoding)\n\u001b[0;32m-> 1119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_accessor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuffering\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1120\u001b[0m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/workspace/recipes/fromSimPO/Qwen3-0.6B_fourchan.yaml'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for row in ingredients:\n",
|
|
" argv = row['argv']\n",
|
|
" sys.argv=argv\n",
|
|
" model_path = row['model_path']\n",
|
|
" wandb_url = row['wandb_url']\n",
|
|
" print(f\"Running: {' '.join(argv)}\")\n",
|
|
" print(f\"Model path: {model_path}\")\n",
|
|
" print(f\"WandB URL: {wandb_url}\")\n",
|
|
"\n",
|
|
"    # Parse the recipe before touching any weights: a wrong or missing YAML path\n",
|
|
"    # (resolved against the current working directory) then fails immediately\n",
|
|
"    # instead of after the checkpoint has been loaded.\n",
|
|
"    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))\n",
|
|
"    model_args, data_args, training_args = parser.parse()\n",
|
|
"\n",
|
|
"    print(f\"Loading model from {model_path}\")\n",
|
|
"    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=\"auto\")\n",
|
|
"    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
|
|
"\n",
|
|
"\n",
|
|
" torch_dtype = (\n",
|
|
" model_args.torch_dtype if model_args.torch_dtype in [\"auto\", None] else getattr(torch, model_args.torch_dtype)\n",
|
|
" )\n",
|
|
" quantization_config = get_quantization_config(model_args)\n",
|
|
" model_kwargs = dict(\n",
|
|
" revision=model_args.model_revision,\n",
|
|
" trust_remote_code=model_args.trust_remote_code,\n",
|
|
" attn_implementation=model_args.attn_implementation,\n",
|
|
" torch_dtype=torch_dtype,\n",
|
|
" use_cache=False if training_args.gradient_checkpointing else True,\n",
|
|
" device_map=get_kbit_device_map() if quantization_config is not None else None,\n",
|
|
" quantization_config=quantization_config,\n",
|
|
" )\n",
|
|
" # training_args.eval_strategy = None\n",
|
|
"\n",
|
|
" raw_datasets = get_datasets(\n",
|
|
" data_args,\n",
|
|
" splits=data_args.dataset_splits,\n",
|
|
" configs=data_args.dataset_configs,\n",
|
|
" columns_to_keep=[\"messages\", \"chosen\", \"rejected\", \"prompt\", \"completion\", \"label\"],\n",
|
|
" )\n",
|
|
" eval_dataset = raw_datasets[\"test\"]\n",
|
|
" train_dataset = raw_datasets[\"train\"]\n",
|
|
"\n",
|
|
"\n",
|
|
" trainer = SFTTrainer(\n",
|
|
" model=model,\n",
|
|
" # model_init_kwargs=model_kwargs,\n",
|
|
" args=training_args,\n",
|
|
" train_dataset=eval_dataset,\n",
|
|
" eval_dataset=eval_dataset,\n",
|
|
" dataset_text_field=\"text\",\n",
|
|
" max_seq_length=training_args.max_seq_length,\n",
|
|
" tokenizer=tokenizer,\n",
|
|
" # packing=True,\n",
|
|
" # peft_config=get_peft_config(model_args),\n",
|
|
" # dataset_kwargs=training_args.dataset_kwargs,\n",
|
|
" )\n",
|
|
"\n",
|
|
" dataset_name = list(data_args.dataset_mixer.keys()) # data_args.dataset_name\n",
|
|
" base_model = trainer.model.config._name_or_path\n",
|
|
"\n",
|
|
" # trainer.create_model_card(**kwargs)\n",
|
|
"\n",
|
|
" model_card = generate_model_card(\n",
|
|
" base_model=base_model,\n",
|
|
" model_name=training_args.hub_model_id,\n",
|
|
" hub_model_id=trainer.hub_model_id,\n",
|
|
" dataset_name=dataset_name,\n",
|
|
" tags=[\"alignment-handbook\"],\n",
|
|
" wandb_url=wandb_url,\n",
|
|
" # comet_url=get_comet_experiment_url(),\n",
|
|
" trainer_name=\"SFT\",\n",
|
|
" )\n",
|
|
" model_card.save(os.path.join(trainer.args.output_dir, \"README.md\"))\n",
|
|
"\n",
|
|
"\n",
|
|
" trainer.model.config.use_cache = True\n",
|
|
" trainer.model.config.save_pretrained(training_args.output_dir)\n",
|
|
"\n",
|
|
"\n",
|
|
" if training_args.do_eval:\n",
|
|
" # logger.info(\"*** Evaluate ***\")\n",
|
|
" metrics = trainer.evaluate()\n",
|
|
" metrics[\"eval_samples\"] = len(eval_dataset)\n",
|
|
" trainer.log_metrics(\"eval\", metrics)\n",
|
|
" trainer.save_metrics(\"eval\", metrics)\n",
|
|
"\n",
|
|
" trainer.push_to_hub()\n",
|
|
" \n",
|
|
"    print(\"eval base model\")\n",
|
|
" base_model_path=model_args.model_name_or_path\n",
|
|
" model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map=\"auto\", torch_dtype=\"auto\")\n",
|
|
" tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)\n",
|
|
"\n",
|
|
"\n",
|
|
" trainer = SFTTrainer(\n",
|
|
" model=model,\n",
|
|
" # model_init_kwargs=model_kwargs,\n",
|
|
" args=training_args,\n",
|
|
" train_dataset=eval_dataset,\n",
|
|
" eval_dataset=eval_dataset,\n",
|
|
" dataset_text_field=\"text\",\n",
|
|
" max_seq_length=training_args.max_seq_length,\n",
|
|
" tokenizer=tokenizer,\n",
|
|
" # packing=True,\n",
|
|
" # peft_config=get_peft_config(model_args),\n",
|
|
" # dataset_kwargs=training_args.dataset_kwargs,\n",
|
|
" )\n",
|
|
" metrics = trainer.evaluate()\n",
|
|
" metrics[\"eval_samples\"] = len(eval_dataset)\n",
|
|
" trainer.log_metrics(\"eval\", metrics)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ba9df2f4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b07630f5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|