mirror of
https://github.com/wassname/activation_store.git
synced 2026-06-27 18:03:14 +08:00
317 lines
9.0 KiB
Plaintext
317 lines
9.0 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%reload_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
"from datasets import load_dataset\n",
"from torch.utils.data import DataLoader\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from activation_store.collect import activation_store"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
"\n",
"# torch_dtype=\"auto\" defers the dtype choice to the checkpoint config; device_map=\"auto\" lets\n",
"# transformers choose device placement. Attention options: eager, sdpa, flash_attention_2, flex_attention.\n",
"model_kwargs = dict(torch_dtype=\"auto\", device_map=\"auto\", attn_implementation=\"eager\")\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load data and tokenize"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
"source": [
"N = 20\n",
"max_length = 256\n",
"\n",
"imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n",
"\n",
"\n",
"def proc(row):\n",
"    \"\"\"Tokenize one (prompt, chosen) pair as a user/assistant chat, truncated to `max_length` tokens.\"\"\"\n",
"    messages = [\n",
"        {\"role\": \"user\", \"content\": row['prompt']},\n",
"        {\"role\": \"assistant\", \"content\": row['chosen']},\n",
"    ]\n",
"    # `max_length` is ignored unless `truncation=True` is also passed; without it,\n",
"    # rows longer than max_length were kept at full length.\n",
"    return tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=True,\n",
"        add_generation_prompt=False,\n",
"        return_dict=True,\n",
"        truncation=True,\n",
"        max_length=max_length,\n",
"    )\n",
"\n",
"ds2 = imdb.map(proc).with_format(\"torch\")\n",
"# Keep only the tokenizer outputs. Sort them so the column order is deterministic:\n",
"# iteration order of a set of strings changes between Python processes (hash randomization).\n",
"new_cols = sorted(set(ds2.column_names) - set(imdb.column_names))\n",
"ds2 = ds2.select_columns(new_cols)\n",
"ds2"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Data loader"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
"source": [
"def collate_fn(examples):\n",
"    \"\"\"Pad a list of tokenized rows to the longest sequence in the batch and return PyTorch tensors.\"\"\"\n",
"    return tokenizer.pad(examples, padding=True, return_tensors=\"pt\")\n",
"\n",
"\n",
"# shuffle=False (the DataLoader default, made explicit) keeps batches in dataset order.\n",
"ds = DataLoader(ds2, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)\n",
"len(ds)  # number of batches"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Sanity check: one manual forward pass on the first batch.\n",
"# The collated batch is built on CPU while device_map=\"auto\" may have put the model on a GPU,\n",
"# so move the batch to model.device before calling the model.\n",
"with torch.no_grad():\n",
"    first_batch = next(iter(ds)).to(model.device)\n",
"    outputs = model(**first_batch)\n",
"outputs.keys()"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Collect activations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['model.layers.0.mlp.down_proj',\n",
|
|
" 'model.layers.1.mlp.down_proj',\n",
|
|
" 'model.layers.2.mlp.down_proj',\n",
|
|
" 'model.layers.3.mlp.down_proj',\n",
|
|
" 'model.layers.4.mlp.down_proj',\n",
|
|
" 'model.layers.5.mlp.down_proj',\n",
|
|
" 'model.layers.6.mlp.down_proj',\n",
|
|
" 'model.layers.7.mlp.down_proj',\n",
|
|
" 'model.layers.8.mlp.down_proj',\n",
|
|
" 'model.layers.9.mlp.down_proj',\n",
|
|
" 'model.layers.10.mlp.down_proj',\n",
|
|
" 'model.layers.11.mlp.down_proj',\n",
|
|
" 'model.layers.12.mlp.down_proj',\n",
|
|
" 'model.layers.13.mlp.down_proj',\n",
|
|
" 'model.layers.14.mlp.down_proj',\n",
|
|
" 'model.layers.15.mlp.down_proj',\n",
|
|
" 'model.layers.16.mlp.down_proj',\n",
|
|
" 'model.layers.17.mlp.down_proj',\n",
|
|
" 'model.layers.18.mlp.down_proj',\n",
|
|
" 'model.layers.19.mlp.down_proj',\n",
|
|
" 'model.layers.20.mlp.down_proj',\n",
|
|
" 'model.layers.21.mlp.down_proj',\n",
|
|
" 'model.layers.22.mlp.down_proj',\n",
|
|
" 'model.layers.23.mlp.down_proj']"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Choose layers to cache: each module whose name ends in 'mlp.down_proj'.\n",
"# endswith (rather than a substring test) avoids also matching child modules of down_proj.\n",
"layers = [name for name, _ in model.named_modules() if name.endswith('mlp.down_proj')]\n",
"assert layers, 'no modules named *mlp.down_proj found; adjust the pattern for this architecture'\n",
"layers"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[32m2025-02-15 21:14:24.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mcollect_act_to_disk\u001b[0m:\u001b[36m60\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/outputs/.ds/ds__7ae34f9e83796c91\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "f6ed6625c38544378d2d46969a8470c4",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"collecting hidden states: 0%| | 0/10 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Dataset({\n",
|
|
" features: ['act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],\n",
|
|
" num_rows: 20\n",
|
|
"})"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Collect activations for the modules in `layers` over every batch of `ds`.\n",
"# `ds_a` is a Dataset with one row per example; `f` is presumably the on-disk\n",
"# location of the cached dataset (see the log line) -- TODO confirm.\n",
"ds_a, f = activation_store(\n",
"    ds,\n",
"    model,\n",
"    layers=layers,\n",
")\n",
"ds_a"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"torch.Size([2, 453, 151936])"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Logits for the first two rows; presumably (rows, seq_len, vocab_size) -- TODO confirm axis meaning.\n",
"ds_a[:2]['logits'].shape"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
"source": [
"# Stored activation columns are named 'act-<module name>' (see the ds_a features listing),\n",
"# so build the key from the layer name instead of indexing with the bare module name.\n",
"ds_a[0:2][f'act-{layers[0]}'].shape"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|