This commit is contained in:
wassname
2025-09-14 19:52:45 +08:00
parent 49c4341afc
commit 6c3f92e183
6 changed files with 156 additions and 47 deletions
+6 -2
View File
@@ -209,7 +209,7 @@ def gen_reasoning_trace(
is_choice_token = True
break
if is_choice_token or (i % fork_every == 0) or ((max_thinking_tokens is not None) and (i == max_thinking_tokens)):
if is_choice_token or (i % fork_every == 0):
# _is_thinking = is_thinking(all_input_ids[0], tokenizer)
logp_choices = force_forked_choice(
model,
@@ -257,14 +257,18 @@ def gen_reasoning_trace(
if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
break
full_strings = tokenizer.batch_decode(all_input_ids, skip_special_tokens=False)
token_strs = [tokenizer.batch_decode(i) for i in all_input_ids[:, nb_input_tokens:]]
# convert to one dataframe for each batch
dfs = [pd.DataFrame(d) for d in data]
# TODO I might want to remove everything after a tokenizer.eos_token_id
for df in dfs:
for i, df in enumerate(dfs):
df['token_strs'] = token_strs[i]
df.attrs.update({
"max_new_tokens": max_new_tokens,
"max_thinking_tokens": max_thinking_tokens,
@@ -0,0 +1,65 @@
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
import html
from IPython.display import display, HTML
from cmap import Colormap
import pandas as pd
# choose a map which is dark in the middle
cmap_RdGyGn = Colormap(['red', 'grey', 'green']).to_mpl()
def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=None):
    """
    Display colored tokens from a chain of thought.

    First draws a horizontal colorbar for the chosen colormap/normalisation,
    then renders every row of ``df`` inline as an HTML ``<span>`` whose text
    color encodes that row's score. Hovering a token shows its score and row
    index.

    Args:
    - df: one row per token; must contain a ``token_strs`` column (the token
      text) and the ``key`` column
    - title: text appended to the colorbar title and used as the HTML heading
    - key: the dataframe columns with the single score to visualise
    - cmap: matplotlib colormap used to turn a normalised score into a color
    - symmetric: if True, center the color scale on 0 with a half-range equal
      to the 99th percentile of ``|score|``; otherwise scale between the 1st
      and 99th percentile of the raw scores
    - highlight_tokens: optional collection of token strings (unescaped) to
      mark by wrapping them as ``!token!!``
    """
    # Avoid a shared mutable default; an empty tuple supports `in` just as well.
    highlight = () if highlight_tokens is None else highlight_tokens

    scores = df[key]
    if symmetric:
        # Magnitudes only: the scale is mirrored around zero.
        vmax = scores.abs().quantile(0.99)
        norm = mpl.colors.CenteredNorm(0, vmax)
    else:
        # Use the raw (signed) scores so negative values keep their place on
        # the scale; quantiles clip the influence of outliers.
        vmin = scores.quantile(0.01)
        vmax = scores.quantile(0.99)
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # show colormap: a hidden dummy image provides the mappable for the colorbar
    a = np.array([0, 1])[None]
    plt.figure(figsize=(9, 1.5))
    img = plt.imshow(a, cmap=cmap, norm=norm)
    plt.gca().set_visible(False)
    cax = plt.axes([0.1, 0.2, 0.8, 0.6])
    plt.colorbar(img, orientation="horizontal", cax=cax, label='rating')
    plt.title(f"Judge rating along chain of thought. {title}")
    plt.show()

    # Collect fragments and join once instead of repeated string concatenation.
    pieces = [f'<h3>{title}</h3>']
    for n, row in df.iterrows():
        raw_token, score = row['token_strs'], row[key]

        # html escape the text that is shown; all tests below use the raw token
        text = html.escape(raw_token)
        if raw_token in highlight:
            text = f'!{text}!!'

        # map score → RGBA → hex
        hex_color = mpl.colors.to_hex(cmap(norm(score)))

        # A leading newline breaks the paragraph before the token ...
        if raw_token.startswith("\n"):
            pieces.append("<p/>")
        pieces.append(
            f'<span title="{score} n={n}" style="color: {hex_color};">{text}</span>'
        )
        # ... and a trailing newline breaks it after the token.
        if raw_token.endswith("\n"):
            pieces.append("<p/>")

    # render it inline
    display(HTML("".join(pieces)))
# display_rating_trace(df_traj.interpolate(method='nearest'))
# df_traj
+1 -1
View File
@@ -39,7 +39,7 @@ def load_model(model_kwargs, device="auto"):
id = model_kwargs.pop("id")
model = AutoModelForCausalLM.from_pretrained(
device_map=device,
torch_dtype=torch_dtype,
dtype=torch_dtype,
pretrained_model_name_or_path=id,
# **model_kwargs,
)
+55 -27
View File
@@ -2,15 +2,33 @@
# steering
import random
from repeng import DatasetEntry
import json
import json5
import torch
import itertools
from repeng.control import model_layer_list
from repeng import ControlVector, ControlModel, DatasetEntry
from loguru import logger
import contextlib
from anycache import anycache
from llm_moral_foundations2.config import project_dir
@contextlib.contextmanager
def control(model, vector, coeff):
    """
    Temporarily steer ``model`` with ``vector`` scaled by ``coeff``.

    On entry the control is applied via ``model.set_control(vector, coeff)``;
    a coefficient equal to 0 instead clears any existing control with
    ``model.reset()``. On exit, whether normal or via an exception,
    ``model.reset()`` is called.

    Usage:
        with control(model, vector, coeff):
            model.generate()
    """
    steering_disabled = coeff == 0
    if steering_disabled:
        model.reset()
    else:
        model.set_control(vector, coeff)

    # Register the cleanup only after setup succeeded, so a failing
    # set_control()/reset() above does not trigger a second reset.
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(model.reset)
        yield
def wrap_model(model):
try:
n_layers = len(model_layer_list(model))
@@ -39,6 +57,7 @@ def train_steering_vector(model, tokenizer, ds_name="scenario_engagement_dataset
tokenizer,
ds_steer,
batch_size=batch_size,
method="pca_diff_weighted",
)
return control_vector
@@ -53,39 +72,36 @@ def find_last_non_whitespace_token(tokenizer, tokens):
return t
return t
def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're a {persona} making statements about the world."):
def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world."):
thinking_prefix = get_thinking_prefix(tokenizer)
# Create dataset entries
dataset = []
for i, _suffix in enumerate(suffixes):
for j, (positive_persona, negative_persona) in enumerate(personas):
_entity = random.choice(entities)
positive_persona = positive_persona.format(entity=_entity)
negative_persona = negative_persona.format(entity=_entity)
if (thinking_prefix is not None) and random.random() < 0.5:
suffix = thinking_prefix + _suffix
else:
suffix = _suffix + ""
# tokens = tokenizer.tokenize(suffix, add_special_tokens=False)[:max_suffix_length]
# Create multiple training examples with different truncations
# for i in range(1, len(tokens), max(1, len(tokens) // 5)): # Using stride to reduce dataset size
for think in [0, -1, 1]:
suffix = _suffix + ""
# truncated = tokenizer.convert_tokens_to_string(tokens[i:])
match think:
case -1:
suffix = "<think>\n\n" + suffix
case 1:
suffix = "<think>\n\n</think>\n\n" + suffix
positive_prompt = tokenizer.apply_chat_template(
# f"Please talk about {persona}."
# f"Pretend you're an {persona} person making statements about the world.
# "Act as if you're extremely {persona}.",
[{'role': 'user', 'content': template.format(persona=positive_persona)},
{'role': 'assistant', 'content': suffix}],
tokenize=False,
enable_thinking=False,
# enable_thinking=False,
continue_final_message=True
)
negative_prompt = tokenizer.apply_chat_template(
[{'role': 'user', 'content': template.format(persona=negative_persona)},
{'role': 'assistant', 'content': suffix}],
tokenize=False,
enable_thinking=False,
# enable_thinking=False,
continue_final_message=True,
)
@@ -105,28 +121,40 @@ def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're
logger.info(f"Dataset example {i}:\n\npositive_prompt={positive_prompt}\n\nnegative_prompt={negative_prompt}")
return dataset
def load_entities():
    """Load and return the parsed contents of ``data/steering/_entities.json5``.

    The path is resolved relative to ``project_dir``. The file is read as
    UTF-8 explicitly so the result does not depend on the platform's default
    locale encoding.
    """
    with open(project_dir/"data/steering/_entities.json5", encoding="utf-8") as f:
        return json5.load(f)
def load_suffixes(collapse=True):
# loads suffixes
with open(project_dir/"data/steering/_suffixes.json") as f:
suffixes = json.load(f)
with open(project_dir/"data/steering/_suffixes.json5") as f:
suffixes = json5.load(f)
if collapse:
suffixes = list(itertools.chain.from_iterable(suffixes.values()))
return suffixes
def get_thinking_prefix(tokenizer):
    """Return the ``"<think>\\n\\n"`` prefix if the tokenizer knows ``<think>``.

    Args:
        tokenizer: object exposing ``convert_tokens_to_ids`` and, optionally,
            ``unk_token_id``.

    Returns:
        ``"<think>\\n\\n"`` when ``<think>`` maps to a real vocabulary id,
        otherwise the empty string.
    """
    think_token = tokenizer.convert_tokens_to_ids("<think>")
    unk_token = getattr(tokenizer, "unk_token_id", None)
    # A missing token comes back as None, or as the unk id when the tokenizer
    # defines one; neither means "<think>" exists. Compare explicitly instead
    # of relying on truthiness, which would also reject a valid id of 0.
    if think_token is None or think_token == unk_token:
        return ""
    return "<think>\n\n"
def load_personas(name):
    """Return the ``personas`` entry of ``data/steering/{name}.json5``.

    Args:
        name: file stem of the JSON5 dataset under ``project_dir/data/steering``.

    Returns:
        The value stored under the top-level ``"personas"`` key.

    Raises:
        KeyError: if the file has no ``"personas"`` key.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(project_dir/f"data/steering/{name}.json5", encoding="utf-8") as f:
        personas = json5.load(f)
    return personas['personas']
def load_steering_ds(tokenizer, ds_name="scenario_engagement_dataset", verbose=False):
with open(project_dir/f"data/steering/{ds_name}.json") as f:
scenario_data = json.load(f)
with open(project_dir/f"data/steering/{ds_name}.json5") as f:
scenario_data = json5.load(f)
# loads suffixes
with open(project_dir/"data/steering/_suffixes.json") as f:
suffixes = json.load(f)
# could also sample equally from each
suffixes = itertools.chain.from_iterable(suffixes.values())
entities = load_entities()
suffixes = load_suffixes()
# suffixes = scenario_data["suffixes"]
personas = scenario_data["personas"]
dataset = make_dataset(tokenizer, personas, suffixes, verbose=verbose)
dataset = make_dataset(tokenizer, personas, suffixes, entities, verbose=verbose)
return dataset
+1
View File
@@ -17,6 +17,7 @@ dependencies = [
"einops>=0.8.1",
"hf-xet>=1.1.9",
"jaxtyping>=0.3.2",
"json5>=0.12.1",
"loguru>=0.7.3",
"matplotlib>=3.10.1",
"openrouter-wrapper",
Generated
+11
View File
@@ -1085,6 +1085,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" },
]
[[package]]
name = "json5"
version = "0.12.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/12/ae/929aee9619e9eba9015207a9d2c1c54db18311da7eb4dcf6d41ad6f0eb67/json5-0.12.1.tar.gz", hash = "sha256:b2743e77b3242f8d03c143dd975a6ec7c52e2f2afe76ed934e53503dd4ad4990", size = 52191, upload-time = "2025-08-12T19:47:42.583Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/e2/05328bd2621be49a6fed9e3030b1e51a2d04537d3f816d211b9cc53c5262/json5-0.12.1-py3-none-any.whl", hash = "sha256:d9c9b3bc34a5f54d43c35e11ef7cb87d8bdd098c6ace87117a7b7e83e705c1d5", size = 36119, upload-time = "2025-08-12T19:47:41.131Z" },
]
[[package]]
name = "jupyter-client"
version = "8.6.3"
@@ -1245,6 +1254,7 @@ dependencies = [
{ name = "einops" },
{ name = "hf-xet" },
{ name = "jaxtyping" },
{ name = "json5" },
{ name = "loguru" },
{ name = "matplotlib" },
{ name = "openrouter-wrapper" },
@@ -1274,6 +1284,7 @@ requires-dist = [
{ name = "einops", specifier = ">=0.8.1" },
{ name = "hf-xet", specifier = ">=1.1.9" },
{ name = "jaxtyping", specifier = ">=0.3.2" },
{ name = "json5", specifier = ">=0.12.1" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "matplotlib", specifier = ">=3.10.1" },
{ name = "openrouter-wrapper", git = "https://github.com/wassname/openrouter_wrapper.git" },