Files
activation_store/nbs/example.ipynb
T
wassname 8692769bce tidy
2025-05-03 20:35:31 +08:00

380 lines
10 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# GPU selection has to be in the environment before CUDA is initialised, so this cell runs before torch is imported.\n",
"# setdefault keeps any value already exported by the launcher; the fallback (\"1\") is the second GPU in\n",
"# PCI bus order, so export CUDA_VISIBLE_DEVICES=0 yourself on a single-GPU machine.\n",
"os.environ.setdefault(\"CUDA_DEVICE_ORDER\", \"PCI_BUS_ID\")\n",
"os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"1\")\n",
"os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
"# Uncomment to make CUDA kernel launches synchronous while debugging:\n",
"# os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from activation_store.collect import activation_store\n",
"\n",
"import torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    model_name,\n",
"    torch_dtype=torch.bfloat16,\n",
"    device_map=\"auto\",\n",
"    attn_implementation=\"eager\",  # other options: flex_attention, flash_attention_2, sdpa\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"if tokenizer.pad_token_id is None:\n",
"    # No dedicated pad token: fall back to EOS so batches can still be padded.\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"# Pad and truncate on the left so the final position of every sequence stays a real token.\n",
"# The attribute must be spelled `padding_side`; any other name is stored on the object but never read.\n",
"tokenizer.padding_side = \"left\"\n",
"tokenizer.truncation_side = \"left\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data and tokenize"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['input_ids', 'attention_mask'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N = 20\n",
"max_length = 256\n",
"\n",
"imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n",
"\n",
"\n",
"def proc(row):\n",
"    \"\"\"Render one dataset row as a user/assistant chat and tokenize it.\n",
"\n",
"    Reads `row['prompt']` (user turn) and `row['chosen']` (assistant turn) and returns the\n",
"    tokenizer output as a dict, truncated to at most `max_length` tokens.\n",
"    \"\"\"\n",
"    messages = [\n",
"        {\"role\": \"user\", \"content\": row['prompt']},\n",
"        {\"role\": \"assistant\", \"content\": row['chosen']},\n",
"    ]\n",
"    return tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=True,\n",
"        add_generation_prompt=False,\n",
"        return_dict=True,\n",
"        truncation=True,  # max_length is only enforced when truncation is enabled\n",
"        max_length=max_length,\n",
"    )\n",
"\n",
"\n",
"ds2 = imdb.map(proc).with_format(\"torch\")\n",
"# Keep only the columns added by tokenization; an ordered list (not a set) keeps the column order deterministic.\n",
"new_cols = [col for col in ds2.column_names if col not in imdb.column_names]\n",
"ds2 = ds2.select_columns(new_cols)\n",
"ds2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data loader"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<torch.utils.data.dataloader.DataLoader object at 0x7c988e1ef290>\n"
]
}
],
"source": [
"from torch.utils.data import DataLoader\n",
"from transformers.data import DataCollatorForLanguageModeling\n",
"\n",
"# mlm=False selects causal (non-masked) language-model collation; batching options are gathered in one place.\n",
"collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
"loader_kwargs = dict(batch_size=4, num_workers=0, collate_fn=collate_fn)\n",
"ds = DataLoader(ds2, **loader_kwargs)\n",
"print(ds)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Collect activations"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mlp.down_proj': ['model.layers.0.mlp.down_proj',\n",
" 'model.layers.1.mlp.down_proj',\n",
" 'model.layers.2.mlp.down_proj',\n",
" 'model.layers.3.mlp.down_proj',\n",
" 'model.layers.4.mlp.down_proj',\n",
" 'model.layers.5.mlp.down_proj',\n",
" 'model.layers.6.mlp.down_proj',\n",
" 'model.layers.7.mlp.down_proj',\n",
" 'model.layers.8.mlp.down_proj',\n",
" 'model.layers.9.mlp.down_proj',\n",
" 'model.layers.10.mlp.down_proj',\n",
" 'model.layers.11.mlp.down_proj',\n",
" 'model.layers.12.mlp.down_proj',\n",
" 'model.layers.13.mlp.down_proj',\n",
" 'model.layers.14.mlp.down_proj',\n",
" 'model.layers.15.mlp.down_proj',\n",
" 'model.layers.16.mlp.down_proj',\n",
" 'model.layers.17.mlp.down_proj',\n",
" 'model.layers.18.mlp.down_proj',\n",
" 'model.layers.19.mlp.down_proj',\n",
" 'model.layers.20.mlp.down_proj',\n",
" 'model.layers.21.mlp.down_proj',\n",
" 'model.layers.22.mlp.down_proj',\n",
" 'model.layers.23.mlp.down_proj'],\n",
" 'self_attn': ['model.layers.0.self_attn',\n",
" 'model.layers.1.self_attn',\n",
" 'model.layers.2.self_attn',\n",
" 'model.layers.3.self_attn',\n",
" 'model.layers.4.self_attn',\n",
" 'model.layers.5.self_attn',\n",
" 'model.layers.6.self_attn',\n",
" 'model.layers.7.self_attn',\n",
" 'model.layers.8.self_attn',\n",
" 'model.layers.9.self_attn',\n",
" 'model.layers.10.self_attn',\n",
" 'model.layers.11.self_attn',\n",
" 'model.layers.12.self_attn',\n",
" 'model.layers.13.self_attn',\n",
" 'model.layers.14.self_attn',\n",
" 'model.layers.15.self_attn',\n",
" 'model.layers.16.self_attn',\n",
" 'model.layers.17.self_attn',\n",
" 'model.layers.18.self_attn',\n",
" 'model.layers.19.self_attn',\n",
" 'model.layers.20.self_attn',\n",
" 'model.layers.21.self_attn',\n",
" 'model.layers.22.self_attn',\n",
" 'model.layers.23.self_attn'],\n",
" 'mlp.up_proj': ['model.layers.0.mlp.up_proj',\n",
" 'model.layers.1.mlp.up_proj',\n",
" 'model.layers.2.mlp.up_proj',\n",
" 'model.layers.3.mlp.up_proj',\n",
" 'model.layers.4.mlp.up_proj',\n",
" 'model.layers.5.mlp.up_proj',\n",
" 'model.layers.6.mlp.up_proj',\n",
" 'model.layers.7.mlp.up_proj',\n",
" 'model.layers.8.mlp.up_proj',\n",
" 'model.layers.9.mlp.up_proj',\n",
" 'model.layers.10.mlp.up_proj',\n",
" 'model.layers.11.mlp.up_proj',\n",
" 'model.layers.12.mlp.up_proj',\n",
" 'model.layers.13.mlp.up_proj',\n",
" 'model.layers.14.mlp.up_proj',\n",
" 'model.layers.15.mlp.up_proj',\n",
" 'model.layers.16.mlp.up_proj',\n",
" 'model.layers.17.mlp.up_proj',\n",
" 'model.layers.18.mlp.up_proj',\n",
" 'model.layers.19.mlp.up_proj',\n",
" 'model.layers.20.mlp.up_proj',\n",
" 'model.layers.21.mlp.up_proj',\n",
" 'model.layers.22.mlp.up_proj',\n",
" 'model.layers.23.mlp.up_proj']}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Choose which modules to cache: group name -> suffix that a module's qualified name must end with.\n",
"group_suffixes = {\n",
"    'mlp.down_proj': 'mlp.down_proj',\n",
"    'self_attn': '.self_attn',\n",
"    'mlp.up_proj': 'mlp.up_proj',\n",
"}\n",
"module_names = [name for name, _ in model.named_modules()]\n",
"layer_groups = {\n",
"    group: [name for name in module_names if name.endswith(suffix)]\n",
"    for group, suffix in group_suffixes.items()\n",
"}\n",
"layer_groups"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2025-03-14 16:42:30.982\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m134\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0a492af45854b2c83f6d10d87c6d42a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"collecting activations: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Collect activations for the selected layer groups. In the saved run this logs the parquet\n",
"# dataset being created and returns its path.\n",
"f = activation_store(\n",
"    ds,\n",
"    model,\n",
"    layers=layer_groups,\n",
")\n",
"f"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f1bcb56397af43ac82f4cf4761acdd87",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['mlp.down_proj', 'self_attn', 'mlp.up_proj', 'loss', 'logits', 'hidden_states'],\n",
" num_rows: 20\n",
"})"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import Dataset, load_dataset\n",
"\n",
"# Alternative way to open the same file:\n",
"# ds_a = load_dataset(\"parquet\", split='train', data_files=str(f), keep_in_memory=False)\n",
"activations_raw = Dataset.from_parquet(str(f))\n",
"ds_a = activations_raw.with_format(\"torch\")\n",
"ds_a"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 25, 1, 896])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Axes: [batch, layers, tokens, hidden_size]\n",
"first_two = ds_a[0:2]\n",
"first_two['hidden_states'].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}