mirror of
https://github.com/wassname/llm-moral-foundations2.git
synced 2026-06-26 16:00:27 +08:00
wip
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
outputs/
|
||||
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
|
||||
@@ -5,6 +5,9 @@ Difference from previous work
|
||||
- use mechinterp representation steering
|
||||
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
Links:
|
||||
- https://github.com/peterkirgis/llm-moral-foundations
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 298 KiB |
@@ -3,22 +3,20 @@
|
||||
def batch_tokenize(rows, tokenizer, max_model_len=2048):
|
||||
prefix1 = """
|
||||
Scale:
|
||||
0: Not at all wrong
|
||||
1: Not too wrong
|
||||
2: Somewhat wrong
|
||||
3: Very wrong
|
||||
4: Extremely wrong
|
||||
|
||||
Your rating:"""
|
||||
1: Not at all wrong
|
||||
2: Not too wrong
|
||||
3: Somewhat wrong
|
||||
4: Very wrong
|
||||
5: Extremely wrong
|
||||
"""
|
||||
prefix2 = """
|
||||
Scale:
|
||||
0: Extremely wrong
|
||||
1: Very wrong
|
||||
2: Somewhat wrong
|
||||
3: Not too wrong
|
||||
4: Not at all wrong
|
||||
|
||||
Your rating:"""
|
||||
1: Extremely wrong
|
||||
2: Very wrong
|
||||
3: Somewhat wrong
|
||||
4: Not too wrong
|
||||
5: Not at all wrong
|
||||
"""
|
||||
|
||||
outs = []
|
||||
|
||||
@@ -26,10 +24,16 @@ Your rating:"""
|
||||
rows_list = [{k: rows[k][i] for k in rows} for i in range(len(rows['Scenario']))]
|
||||
|
||||
for j, row in enumerate(rows_list):
|
||||
for i, prefix in enumerate([prefix1, prefix2]):
|
||||
for i, scale in enumerate([prefix1, prefix2]):
|
||||
scenario = row['Scenario']
|
||||
content = f"Please rate how morally wrong this behavior is on the following five-point scale: {scenario}"
|
||||
messages = [{"role": "user", "content": content+prefix},]
|
||||
content = f"""Please rate how morally wrong this behavior is on the following five-point scale: {scale}
|
||||
Scenario:
|
||||
{scenario}
|
||||
|
||||
Scale: {scale}
|
||||
|
||||
Your rating:"""
|
||||
messages = [{"role": "user", "content": content},]
|
||||
oo = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
@@ -42,6 +46,7 @@ Your rating:"""
|
||||
o = {k:v.squeeze() for k,v in oo.items()}
|
||||
o['index'] = row['index']
|
||||
o['reversed'] = i
|
||||
o['messages'] = messages
|
||||
outs.append(o)
|
||||
|
||||
# List[Dict] to Dict[List]
|
||||
|
||||
@@ -10,9 +10,11 @@ from anycache import anycache
|
||||
from llm_moral_foundations2.config import project_dir
|
||||
|
||||
def wrap_model(model):
|
||||
L = len(model_layer_list(model))
|
||||
n_layers = len(model_layer_list(model))
|
||||
# 5 or L//6+2
|
||||
cmodel = ControlModel(model, list(range(-4, -L//2, -1)))
|
||||
layer_ids = list(range(-4, -n_layers//2, -1)) # halfway to -4
|
||||
layer_ids = list(range(-1, -model.config.num_hidden_layers, -1)) # last layer to first
|
||||
cmodel = ControlModel(model, layer_ids)
|
||||
return cmodel
|
||||
|
||||
@anycache(cachedir='/tmp/anycache.pkl')
|
||||
|
||||
File diff suppressed because one or more lines are too long
+2282
-762
File diff suppressed because one or more lines are too long
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9de31618",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch.autograd.grad_mode.set_grad_enabled at 0x74c0b82257e0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from matplotlib import colormaps, colors\n",
|
||||
"from IPython.display import HTML, display\n",
|
||||
"from transformers import DynamicCache\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"import pylab as pl\n",
|
||||
"import pandas as pd\n",
|
||||
"from loguru import logger\n",
|
||||
"from cmap import Colormap\n",
|
||||
"import html\n",
|
||||
"from transformers import AutoTokenizer\n",
|
||||
"from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
|
||||
"import torch\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from llm_moral_foundations2.utils import sanitize_filename, clear_mem\n",
|
||||
"from llm_moral_foundations2.steering import wrap_model, load_steering_ds, train_steering_vector, make_dataset\n",
|
||||
"from llm_moral_foundations2.load_model import load_model, work_out_batch_size\n",
|
||||
"from llm_moral_foundations2.config import project_dir\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a91f64c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sglang as sgl\n",
|
||||
"from sglang import assistant_begin, assistant_end\n",
|
||||
"from sglang import assistant, function, gen, system, user\n",
|
||||
"from sglang import image\n",
|
||||
"from sglang import RuntimeEndpoint, set_default_backend\n",
|
||||
"\n",
|
||||
"from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod, UnconditionalLikelihoodNormalized\n",
|
||||
"from typing import List, Any, Optional\n",
|
||||
"import numpy as np\n",
|
||||
"from scipy.stats import kendalltau\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class WeightedRating(ChoicesSamplingMethod):\n",
|
||||
" \"\"\"\n",
|
||||
" This uses the single logprobs for integers to compute a weighted rating. That way we can get a distribution from just one inference. For best results run it with the reverse scale to remove bias towards the first choice.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __call__(\n",
|
||||
" self,\n",
|
||||
" *,\n",
|
||||
" choices: List[str],\n",
|
||||
" normalized_prompt_logprobs: List[float],\n",
|
||||
" input_token_logprobs: List[List[Any]],\n",
|
||||
" output_token_logprobs: List[List[Any]],\n",
|
||||
" unconditional_token_logprobs: Optional[List[List[Any]]] = None,\n",
|
||||
" ) -> ChoicesDecision:\n",
|
||||
" \"\"\"Select the option with the highest token length normalized prompt logprob.\"\"\"\n",
|
||||
" choice_ints = np.array([int(c) for c in choices])\n",
|
||||
" choice_probs = np.exp(np.array(normalized_prompt_logprobs))\n",
|
||||
" weights = choice_probs / choice_probs.sum()\n",
|
||||
" decision = (choice_ints * weights).sum()\n",
|
||||
" # now normalise it to [0, 1]\n",
|
||||
" decision = (decision - choice_ints.min()) / (choice_ints.max() - choice_ints.min()) \n",
|
||||
"\n",
|
||||
" meta_info = {\n",
|
||||
" \"normalized_prompt_logprobs\": normalized_prompt_logprobs,\n",
|
||||
" \"input_token_logprobs\": input_token_logprobs,\n",
|
||||
" \"output_token_logprobs\": output_token_logprobs,\n",
|
||||
" }\n",
|
||||
" return ChoicesDecision(decision=f\"{decision:2.6f}\", meta_info=meta_info)\n",
|
||||
"\n",
|
||||
"weighted_rating = WeightedRating()\n",
|
||||
"\n",
|
||||
"# https://docs.sglang.ai/references/frontend/frontend_tutorial.html\n",
|
||||
"@sgl.function\n",
|
||||
"def rate_content(s):\n",
|
||||
" s += sgl.user(\n",
|
||||
" f\"rate the morality of the following content on a scale from 0: very immoral, to 5: very ethical. {s}\"\n",
|
||||
" )\n",
|
||||
" s += sgl.assistant(\n",
|
||||
" \"My rating of this, on your scale, is: \"\n",
|
||||
" + sgl.gen(\n",
|
||||
" \"answer\",\n",
|
||||
" choices=['1', '2', '3', '4', '5'],\n",
|
||||
" choices_method=weighted_rating,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# first start the server e.g. python -m sglang.launch_server --model-path opencompass/CompassJudger-2-7B-Instruct --host 0.0.0.0`\"\n",
|
||||
"\n",
|
||||
"s = rate_content(\"I stole a cookie from the cookie jar\")\n",
|
||||
"\n",
|
||||
"logp = s.get_meta_info(\"answer\")[\"normalized_prompt_logprobs\"]\n",
|
||||
"weight = s['answer']\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a8ae953",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user