diff --git a/llm_moral_foundations2/gather/cot.py b/llm_moral_foundations2/gather/cot.py
index 29bf2b2..e201852 100644
--- a/llm_moral_foundations2/gather/cot.py
+++ b/llm_moral_foundations2/gather/cot.py
@@ -209,7 +209,7 @@ def gen_reasoning_trace(
is_choice_token = True
break
- if is_choice_token or (i % fork_every == 0) or ((max_thinking_tokens is not None) and (i == max_thinking_tokens)):
+ if is_choice_token or (i % fork_every == 0):
# _is_thinking = is_thinking(all_input_ids[0], tokenizer)
logp_choices = force_forked_choice(
model,
@@ -256,15 +256,19 @@ def gen_reasoning_trace(
if all_input_ids[:, nb_input_tokens:].eq(tokenizer.eos_token_id).any(dim=1).all():
if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
break
+
full_strings = tokenizer.batch_decode(all_input_ids, skip_special_tokens=False)
+ token_strs = [tokenizer.batch_decode(i) for i in all_input_ids[:, nb_input_tokens:]]
+
# convert to one dataframe for each batch
dfs = [pd.DataFrame(d) for d in data]
# TODO I might want to remove everything after a tokenizer.eos_token_id
- for df in dfs:
+ for i, df in enumerate(dfs):
+ df['token_strs'] = token_strs[i]
df.attrs.update({
"max_new_tokens": max_new_tokens,
"max_thinking_tokens": max_thinking_tokens,
diff --git a/llm_moral_foundations2/helpers/plot_traj.py b/llm_moral_foundations2/helpers/plot_traj.py
new file mode 100644
index 0000000..9b21505
--- /dev/null
+++ b/llm_moral_foundations2/helpers/plot_traj.py
@@ -0,0 +1,65 @@
+
+import numpy as np
+import matplotlib as mpl
+from matplotlib import pyplot as plt
+import html
+from IPython.display import display, HTML
+from cmap import Colormap
+import pandas as pd
+
+# choose a map which is dark in the middle
+cmap_RdGyGn = Colormap(['red', 'grey', 'green']).to_mpl()
+
+
+def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=[]):
+ """
+ Display colored tokens from a chain of thought.
+
+ Args:
+ - key: the dataframe columns with the single score to visualise
+ """
+ # v = df[key].abs().max()
+ v = df[key].abs()
+ if symmetric:
+ v = v.abs()
+ vmax = v.quantile(0.99)
+ norm = mpl.colors.CenteredNorm(0, vmax)
+ else:
+ vmin = v.quantile(0.01)
+ vmax = v.quantile(0.99)
+ norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
+
+ # show colormap
+ a = np.array([0, 1])[None]
+ plt.figure(figsize=(9, 1.5))
+ img = plt.imshow(a, cmap=cmap, norm=norm)
+ plt.gca().set_visible(False)
+ cax = plt.axes([0.1, 0.2, 0.8, 0.6])
+ plt.colorbar(orientation="horizontal", cax=cax, label='rating')
+ plt.title(f"Judge rating along chain of thought. {title}")
+ plt.show()
+
+ # TODO show \n as
+ htmls = f'
{title}
'
+ for n,row in df.iterrows():
+ token, score = row['token_strs'], row[key]
+ # token = token.replace('Ġ', ' ').replace('Ċ', '\n')
+
+ # html escape
+ token = html.escape(token)
+ # map score → RGBA → hex
+ hex_color = mpl.colors.to_hex(cmap(norm(score)))
+ if token.startswith("\n"):
+ htmls += ""
+ if token in highlight_tokens:
+ token = f'!{token}!!'
+ h = f'{token}'
+ if token.endswith("\n"):
+ htmls += ""
+ htmls += h
+ # render it inline
+ display(HTML(htmls))
+
+
+# display_rating_trace(df_traj.interpolate(method='nearest'))
+# df_traj
diff --git a/llm_moral_foundations2/load_model.py b/llm_moral_foundations2/load_model.py
index 554ce76..c1d879b 100644
--- a/llm_moral_foundations2/load_model.py
+++ b/llm_moral_foundations2/load_model.py
@@ -39,7 +39,7 @@ def load_model(model_kwargs, device="auto"):
id = model_kwargs.pop("id")
model = AutoModelForCausalLM.from_pretrained(
device_map=device,
- torch_dtype=torch_dtype,
+ dtype=torch_dtype,
pretrained_model_name_or_path=id,
# **model_kwargs,
)
diff --git a/llm_moral_foundations2/steering.py b/llm_moral_foundations2/steering.py
index 55677d8..da8b843 100644
--- a/llm_moral_foundations2/steering.py
+++ b/llm_moral_foundations2/steering.py
@@ -2,15 +2,33 @@
# steering
import random
from repeng import DatasetEntry
-import json
+import json5
import torch
import itertools
from repeng.control import model_layer_list
from repeng import ControlVector, ControlModel, DatasetEntry
from loguru import logger
+import contextlib
from anycache import anycache
from llm_moral_foundations2.config import project_dir
+@contextlib.contextmanager
+def control(model, vector, coeff):
+ """
+ Usage:
+ with control(model, vector, coeff):
+ model.generate()
+ """
+ if coeff==0:
+ model.reset()
+ else:
+ model.set_control(vector, coeff)
+ try:
+ yield
+ finally:
+ model.reset()
+
+
def wrap_model(model):
try:
n_layers = len(model_layer_list(model))
@@ -39,6 +57,7 @@ def train_steering_vector(model, tokenizer, ds_name="scenario_engagement_dataset
tokenizer,
ds_steer,
batch_size=batch_size,
+ method="pca_diff_weighted",
)
return control_vector
@@ -53,49 +72,46 @@ def find_last_non_whitespace_token(tokenizer, tokens):
return t
return t
-def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're a {persona} making statements about the world."):
+def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world."):
+ thinking_prefix = get_thinking_prefix(tokenizer)
# Create dataset entries
dataset = []
for i, _suffix in enumerate(suffixes):
for j, (positive_persona, negative_persona) in enumerate(personas):
+ _entity = random.choice(entities)
+ positive_persona = positive_persona.format(entity=_entity)
+ negative_persona = negative_persona.format(entity=_entity)
+ if (thinking_prefix is not None) and random.random() < 0.5:
+ suffix = thinking_prefix + _suffix
+ else:
+ suffix = _suffix + ""
+
# tokens = tokenizer.tokenize(suffix, add_special_tokens=False)[:max_suffix_length]
# Create multiple training examples with different truncations
# for i in range(1, len(tokens), max(1, len(tokens) // 5)): # Using stride to reduce dataset size
- for think in [0, -1, 1]:
- suffix = _suffix + ""
- # truncated = tokenizer.convert_tokens_to_string(tokens[i:])
- match think:
- case -1:
- suffix = "\n\n" + suffix
- case 1:
- suffix = "\n\n\n\n" + suffix
+ positive_prompt = tokenizer.apply_chat_template(
+ [{'role': 'user', 'content': template.format(persona=positive_persona)},
+ {'role': 'assistant', 'content': suffix}],
+ tokenize=False,
+ # enable_thinking=False,
+ continue_final_message=True
+ )
+ negative_prompt = tokenizer.apply_chat_template(
+ [{'role': 'user', 'content': template.format(persona=negative_persona)},
+ {'role': 'assistant', 'content': suffix}],
+ tokenize=False,
+ # enable_thinking=False,
+ continue_final_message=True,
- positive_prompt = tokenizer.apply_chat_template(
- # f"Please talk about {persona}."
- # f"Pretend you're an {persona} person making statements about the world.
- # "Act as if you're extremely {persona}.",
- [{'role': 'user', 'content': template.format(persona=positive_persona)},
- {'role': 'assistant', 'content': suffix}],
- tokenize=False,
- enable_thinking=False,
- continue_final_message=True
- )
- negative_prompt = tokenizer.apply_chat_template(
- [{'role': 'user', 'content': template.format(persona=negative_persona)},
- {'role': 'assistant', 'content': suffix}],
- tokenize=False,
- enable_thinking=False,
- continue_final_message=True,
+ )
+ dataset.append(
+ DatasetEntry(
+ positive=positive_prompt,
+ negative=negative_prompt
)
-
- dataset.append(
- DatasetEntry(
- positive=positive_prompt,
- negative=negative_prompt
- )
- )
+ )
# shuffle
random.seed(42)
@@ -105,28 +121,40 @@ def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're
logger.info(f"Dataset example {i}:\n\npositive_prompt={positive_prompt}\n\nnegative_prompt={negative_prompt}")
return dataset
+def load_entities():
+ with open(project_dir/"data/steering/_entities.json5") as f:
+ _entities = json5.load(f)
+ return _entities
+
def load_suffixes(collapse=True):
# loads suffixes
- with open(project_dir/"data/steering/_suffixes.json") as f:
- suffixes = json.load(f)
+ with open(project_dir/"data/steering/_suffixes.json5") as f:
+ suffixes = json5.load(f)
if collapse:
suffixes = list(itertools.chain.from_iterable(suffixes.values()))
return suffixes
+def get_thinking_prefix(tokenizer):
+ think_token = tokenizer.convert_tokens_to_ids("")
+ if think_token:
+ return "\n\n"
+ else:
+ return ""
+
+def load_personas(name):
+ with open(project_dir/f"data/steering/{name}.json5") as f:
+ personas = json5.load(f)
+ return personas['personas']
def load_steering_ds(tokenizer, ds_name="scenario_engagement_dataset", verbose=False):
- with open(project_dir/f"data/steering/{ds_name}.json") as f:
- scenario_data = json.load(f)
+ with open(project_dir/f"data/steering/{ds_name}.json5") as f:
+ scenario_data = json5.load(f)
- # loads suffixes
- with open(project_dir/"data/steering/_suffixes.json") as f:
- suffixes = json.load(f)
-
- # could also sample equally from each
- suffixes = itertools.chain.from_iterable(suffixes.values())
+ entities = load_entities()
+ suffixes = load_suffixes()
# suffixes = scenario_data["suffixes"]
personas = scenario_data["personas"]
- dataset = make_dataset(tokenizer, personas, suffixes, verbose=verbose)
+ dataset = make_dataset(tokenizer, personas, suffixes, entities, verbose=verbose)
return dataset
diff --git a/pyproject.toml b/pyproject.toml
index 839d9d0..6eee8c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
"einops>=0.8.1",
"hf-xet>=1.1.9",
"jaxtyping>=0.3.2",
+ "json5>=0.12.1",
"loguru>=0.7.3",
"matplotlib>=3.10.1",
"openrouter-wrapper",
diff --git a/uv.lock b/uv.lock
index de3b7fe..43a288c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1085,6 +1085,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" },
]
+[[package]]
+name = "json5"
+version = "0.12.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/12/ae/929aee9619e9eba9015207a9d2c1c54db18311da7eb4dcf6d41ad6f0eb67/json5-0.12.1.tar.gz", hash = "sha256:b2743e77b3242f8d03c143dd975a6ec7c52e2f2afe76ed934e53503dd4ad4990", size = 52191, upload-time = "2025-08-12T19:47:42.583Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/85/e2/05328bd2621be49a6fed9e3030b1e51a2d04537d3f816d211b9cc53c5262/json5-0.12.1-py3-none-any.whl", hash = "sha256:d9c9b3bc34a5f54d43c35e11ef7cb87d8bdd098c6ace87117a7b7e83e705c1d5", size = 36119, upload-time = "2025-08-12T19:47:41.131Z" },
+]
+
[[package]]
name = "jupyter-client"
version = "8.6.3"
@@ -1245,6 +1254,7 @@ dependencies = [
{ name = "einops" },
{ name = "hf-xet" },
{ name = "jaxtyping" },
+ { name = "json5" },
{ name = "loguru" },
{ name = "matplotlib" },
{ name = "openrouter-wrapper" },
@@ -1274,6 +1284,7 @@ requires-dist = [
{ name = "einops", specifier = ">=0.8.1" },
{ name = "hf-xet", specifier = ">=1.1.9" },
{ name = "jaxtyping", specifier = ">=0.3.2" },
+ { name = "json5", specifier = ">=0.12.1" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "matplotlib", specifier = ">=3.10.1" },
{ name = "openrouter-wrapper", git = "https://github.com/wassname/openrouter_wrapper.git" },