{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPU selection: set here, before torch is imported in the next cell.\n",
    "# NOTE: device index \"1\" is machine-specific; adjust for your hardware.\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
    "# os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from activation_store.collect import activation_store\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    "    attn_implementation=\"eager\", # flex_attention flash_attention_2 sdpa eager\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Fall back to EOS as the pad token when the tokenizer defines none.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "# Pad and truncate on the left so the last real tokens end every row.\n",
    "# (Fixed typo: `paddding_side` only created an unused attribute.)\n",
    "tokenizer.padding_side = \"left\"\n",
    "tokenizer.truncation_side = \"left\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data and tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       " features: ['input_ids', 'attention_mask'],\n",
       " num_rows: 20\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N = 20\n",
    "max_length = 256\n",
    "\n",
    "imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)\n",
    "\n",
    "\n",
    "def proc(row):\n",
    "    \"\"\"Render one row as a user/assistant chat and tokenize it.\n",
    "\n",
    "    Uses `row['prompt']` as the user turn and `row['chosen']` as the\n",
    "    assistant turn. Returns the tokenizer's dict output, truncated to at\n",
    "    most `max_length` tokens.\n",
    "    \"\"\"\n",
    "    messages = [\n",
    "        {\"role\":\"user\", \"content\": row['prompt'] },\n",
    "        {\"role\":\"assistant\", \"content\": row['chosen'] }\n",
    "    ]\n",
    "    # truncation=True is required; without it max_length is not applied.\n",
    "    return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True, truncation=True, max_length=max_length)\n",
    "\n",
    "ds2 = imdb.map(proc).with_format(\"torch\")\n",
    "# Keep only the columns added by tokenization, as an ordered list\n",
    "# (select_columns is documented to take a str or a list of str).\n",
    "new_cols = [c for c in ds2.column_names if c not in imdb.column_names]\n",
    "ds2 = ds2.select_columns(new_cols)\n",
    "ds2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from transformers.data import DataCollatorForLanguageModeling\n",
    "collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
    "ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)\n",
    "print(ds)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mlp.down_proj': ['model.layers.0.mlp.down_proj',\n",
       " 'model.layers.1.mlp.down_proj',\n",
       " 'model.layers.2.mlp.down_proj',\n",
       " 'model.layers.3.mlp.down_proj',\n",
       " 'model.layers.4.mlp.down_proj',\n",
       " 'model.layers.5.mlp.down_proj',\n",
       " 'model.layers.6.mlp.down_proj',\n",
       " 'model.layers.7.mlp.down_proj',\n",
       " 'model.layers.8.mlp.down_proj',\n",
       " 'model.layers.9.mlp.down_proj',\n",
       " 'model.layers.10.mlp.down_proj',\n",
       " 'model.layers.11.mlp.down_proj',\n",
       " 'model.layers.12.mlp.down_proj',\n",
       " 'model.layers.13.mlp.down_proj',\n",
       " 'model.layers.14.mlp.down_proj',\n",
       " 'model.layers.15.mlp.down_proj',\n",
       " 'model.layers.16.mlp.down_proj',\n",
       " 'model.layers.17.mlp.down_proj',\n",
       " 'model.layers.18.mlp.down_proj',\n",
       " 'model.layers.19.mlp.down_proj',\n",
       " 'model.layers.20.mlp.down_proj',\n",
       " 
'model.layers.21.mlp.down_proj',\n", " 'model.layers.22.mlp.down_proj',\n", " 'model.layers.23.mlp.down_proj'],\n", " 'self_attn': ['model.layers.0.self_attn',\n", " 'model.layers.1.self_attn',\n", " 'model.layers.2.self_attn',\n", " 'model.layers.3.self_attn',\n", " 'model.layers.4.self_attn',\n", " 'model.layers.5.self_attn',\n", " 'model.layers.6.self_attn',\n", " 'model.layers.7.self_attn',\n", " 'model.layers.8.self_attn',\n", " 'model.layers.9.self_attn',\n", " 'model.layers.10.self_attn',\n", " 'model.layers.11.self_attn',\n", " 'model.layers.12.self_attn',\n", " 'model.layers.13.self_attn',\n", " 'model.layers.14.self_attn',\n", " 'model.layers.15.self_attn',\n", " 'model.layers.16.self_attn',\n", " 'model.layers.17.self_attn',\n", " 'model.layers.18.self_attn',\n", " 'model.layers.19.self_attn',\n", " 'model.layers.20.self_attn',\n", " 'model.layers.21.self_attn',\n", " 'model.layers.22.self_attn',\n", " 'model.layers.23.self_attn'],\n", " 'mlp.up_proj': ['model.layers.0.mlp.up_proj',\n", " 'model.layers.1.mlp.up_proj',\n", " 'model.layers.2.mlp.up_proj',\n", " 'model.layers.3.mlp.up_proj',\n", " 'model.layers.4.mlp.up_proj',\n", " 'model.layers.5.mlp.up_proj',\n", " 'model.layers.6.mlp.up_proj',\n", " 'model.layers.7.mlp.up_proj',\n", " 'model.layers.8.mlp.up_proj',\n", " 'model.layers.9.mlp.up_proj',\n", " 'model.layers.10.mlp.up_proj',\n", " 'model.layers.11.mlp.up_proj',\n", " 'model.layers.12.mlp.up_proj',\n", " 'model.layers.13.mlp.up_proj',\n", " 'model.layers.14.mlp.up_proj',\n", " 'model.layers.15.mlp.up_proj',\n", " 'model.layers.16.mlp.up_proj',\n", " 'model.layers.17.mlp.up_proj',\n", " 'model.layers.18.mlp.up_proj',\n", " 'model.layers.19.mlp.up_proj',\n", " 'model.layers.20.mlp.up_proj',\n", " 'model.layers.21.mlp.up_proj',\n", " 'model.layers.22.mlp.up_proj',\n", " 'model.layers.23.mlp.up_proj']}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# choose layers to cache\n", "layer_groups = {\n", " 
'mlp.down_proj': [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')],\n", " 'self_attn': [k for k,v in model.named_modules() if k.endswith('.self_attn')],\n", " 'mlp.up_proj': [k for k,v in model.named_modules() if k.endswith('mlp.up_proj')],\n", "}\n", "layer_groups" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2025-03-14 16:42:30.982\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mactivation_store.collect\u001b[0m:\u001b[36mactivation_store\u001b[0m:\u001b[36m134\u001b[0m - \u001b[1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__0e7d5dbf1c73cf7d.parquet\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e0a492af45854b2c83f6d10d87c6d42a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "collecting activations: 0%| | 0/5 [00:00