diff --git a/llm_moral_foundations2/gather/cot.py b/llm_moral_foundations2/gather/cot.py index 29bf2b2..e201852 100644 --- a/llm_moral_foundations2/gather/cot.py +++ b/llm_moral_foundations2/gather/cot.py @@ -209,7 +209,7 @@ def gen_reasoning_trace( is_choice_token = True break - if is_choice_token or (i % fork_every == 0) or ((max_thinking_tokens is not None) and (i == max_thinking_tokens)): + if is_choice_token or (i % fork_every == 0): # _is_thinking = is_thinking(all_input_ids[0], tokenizer) logp_choices = force_forked_choice( model, @@ -256,15 +256,19 @@ def gen_reasoning_trace( if all_input_ids[:, nb_input_tokens:].eq(tokenizer.eos_token_id).any(dim=1).all(): if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens: break + full_strings = tokenizer.batch_decode(all_input_ids, skip_special_tokens=False) + token_strs = [tokenizer.batch_decode(i) for i in all_input_ids[:, nb_input_tokens:]] + # convert to one dataframe for each batch dfs = [pd.DataFrame(d) for d in data] # TODO I might want to remove everything after a tokenizer.eos_token_id - for df in dfs: + for i, df in enumerate(dfs): + df['token_strs'] = token_strs[i] df.attrs.update({ "max_new_tokens": max_new_tokens, "max_thinking_tokens": max_thinking_tokens, diff --git a/llm_moral_foundations2/helpers/plot_traj.py b/llm_moral_foundations2/helpers/plot_traj.py new file mode 100644 index 0000000..9b21505 --- /dev/null +++ b/llm_moral_foundations2/helpers/plot_traj.py @@ -0,0 +1,65 @@ + +import numpy as np +import matplotlib as mpl +from matplotlib import pyplot as plt +import html +from IPython.display import display, HTML +from cmap import Colormap +import pandas as pd + +# choose a map which is dark in the middle +cmap_RdGyGn = Colormap(['red', 'grey', 'green']).to_mpl() + + +def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=[]): + """ + Display colored tokens from a chain of thought. + + Args: + - key: the dataframe columns with the single score to visualise + """ + # v = df[key].abs().max() + v = df[key].abs() + if symmetric: + v = v.abs() + vmax = v.quantile(0.99) + norm = mpl.colors.CenteredNorm(0, vmax) + else: + vmin = v.quantile(0.01) + vmax = v.quantile(0.99) + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + # show colormap + a = np.array([0, 1])[None] + plt.figure(figsize=(9, 1.5)) + img = plt.imshow(a, cmap=cmap, norm=norm) + plt.gca().set_visible(False) + cax = plt.axes([0.1, 0.2, 0.8, 0.6]) + plt.colorbar(orientation="horizontal", cax=cax, label='rating') + plt.title(f"Judge rating along chain of thought. {title}") + plt.show() + + # TODO show \n as

+ htmls = f'

{title}

' + for n,row in df.iterrows(): + token, score = row['token_strs'], row[key] + # token = token.replace('Ġ', ' ').replace('Ċ', '\n') + + # html escape + token = html.escape(token) + # map score → RGBA → hex + hex_color = mpl.colors.to_hex(cmap(norm(score))) + if token.startswith("\n"): + htmls += "

" + if token in highlight_tokens: + token = f'!{token}!!' + h = f'{token}' + if token.endswith("\n"): + htmls += "

" + htmls += h + # render it inline + display(HTML(htmls)) + + +# display_rating_trace(df_traj.interpolate(method='nearest')) +# df_traj diff --git a/llm_moral_foundations2/load_model.py b/llm_moral_foundations2/load_model.py index 554ce76..c1d879b 100644 --- a/llm_moral_foundations2/load_model.py +++ b/llm_moral_foundations2/load_model.py @@ -39,7 +39,7 @@ def load_model(model_kwargs, device="auto"): id = model_kwargs.pop("id") model = AutoModelForCausalLM.from_pretrained( device_map=device, - torch_dtype=torch_dtype, + dtype=torch_dtype, pretrained_model_name_or_path=id, # **model_kwargs, ) diff --git a/llm_moral_foundations2/steering.py b/llm_moral_foundations2/steering.py index 55677d8..da8b843 100644 --- a/llm_moral_foundations2/steering.py +++ b/llm_moral_foundations2/steering.py @@ -2,15 +2,33 @@ # steering import random from repeng import DatasetEntry -import json +import json5 import torch import itertools from repeng.control import model_layer_list from repeng import ControlVector, ControlModel, DatasetEntry from loguru import logger +import contextlib from anycache import anycache from llm_moral_foundations2.config import project_dir +@contextlib.contextmanager +def control(model, vector, coeff): + """ + Usage: + with control(model, vector, coeff): + model.generate() + """ + if coeff==0: + model.reset() + else: + model.set_control(vector, coeff) + try: + yield + finally: + model.reset() + + def wrap_model(model): try: n_layers = len(model_layer_list(model)) @@ -39,6 +57,7 @@ def train_steering_vector(model, tokenizer, ds_name="scenario_engagement_dataset tokenizer, ds_steer, batch_size=batch_size, + method="pca_diff_weighted", ) return control_vector @@ -53,49 +72,46 @@ def find_last_non_whitespace_token(tokenizer, tokens): return t return t -def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're a {persona} making statements about the world."): +def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world."): + thinking_prefix = get_thinking_prefix(tokenizer) # Create dataset entries dataset = [] for i, _suffix in enumerate(suffixes): for j, (positive_persona, negative_persona) in enumerate(personas): + _entity = random.choice(entities) + positive_persona = positive_persona.format(entity=_entity) + negative_persona = negative_persona.format(entity=_entity) + if (thinking_prefix is not None) and random.random() < 0.5: + suffix = thinking_prefix + _suffix + else: + suffix = _suffix + "" + # tokens = tokenizer.tokenize(suffix, add_special_tokens=False)[:max_suffix_length] # Create multiple training examples with different truncations # for i in range(1, len(tokens), max(1, len(tokens) // 5)): # Using stride to reduce dataset size - for think in [0, -1, 1]: - suffix = _suffix + "" - # truncated = tokenizer.convert_tokens_to_string(tokens[i:]) - match think: - case -1: - suffix = "\n\n" + suffix - case 1: - suffix = "\n\n\n\n" + suffix + positive_prompt = tokenizer.apply_chat_template( + [{'role': 'user', 'content': template.format(persona=positive_persona)}, + {'role': 'assistant', 'content': suffix}], + tokenize=False, + # enable_thinking=False, + continue_final_message=True + ) + negative_prompt = tokenizer.apply_chat_template( + [{'role': 'user', 'content': template.format(persona=negative_persona)}, + {'role': 'assistant', 'content': suffix}], + tokenize=False, + # enable_thinking=False, + continue_final_message=True, - positive_prompt = tokenizer.apply_chat_template( - # f"Please talk about {persona}." - # f"Pretend you're an {persona} person making statements about the world. - # "Act as if you're extremely {persona}.", - [{'role': 'user', 'content': template.format(persona=positive_persona)}, - {'role': 'assistant', 'content': suffix}], - tokenize=False, - enable_thinking=False, - continue_final_message=True - ) - negative_prompt = tokenizer.apply_chat_template( - [{'role': 'user', 'content': template.format(persona=negative_persona)}, - {'role': 'assistant', 'content': suffix}], - tokenize=False, - enable_thinking=False, - continue_final_message=True, + ) + dataset.append( + DatasetEntry( + positive=positive_prompt, + negative=negative_prompt ) - - dataset.append( - DatasetEntry( - positive=positive_prompt, - negative=negative_prompt - ) - ) + ) # shuffle random.seed(42) @@ -105,28 +121,40 @@ def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're logger.info(f"Dataset example {i}:\n\npositive_prompt={positive_prompt}\n\nnegative_prompt={negative_prompt}") return dataset +def load_entities(): + with open(project_dir/"data/steering/_entities.json5") as f: + _entities = json5.load(f) + return _entities + def load_suffixes(collapse=True): # loads suffixes - with open(project_dir/"data/steering/_suffixes.json") as f: - suffixes = json.load(f) + with open(project_dir/"data/steering/_suffixes.json5") as f: + suffixes = json5.load(f) if collapse: suffixes = list(itertools.chain.from_iterable(suffixes.values())) return suffixes +def get_thinking_prefix(tokenizer): + think_token = tokenizer.convert_tokens_to_ids("") + if think_token: + return "\n\n" + else: + return "" + +def load_personas(name): + with open(project_dir/f"data/steering/{name}.json5") as f: + personas = json5.load(f) + return personas['personas'] def load_steering_ds(tokenizer, ds_name="scenario_engagement_dataset", verbose=False): - with open(project_dir/f"data/steering/{ds_name}.json") as f: - scenario_data = json.load(f) + with open(project_dir/f"data/steering/{ds_name}.json5") as f: + scenario_data = json5.load(f) - # loads suffixes - with open(project_dir/"data/steering/_suffixes.json") as f: - suffixes = json.load(f) - - # could also sample equally from each - suffixes = itertools.chain.from_iterable(suffixes.values()) + entities = load_entities() + suffixes = load_suffixes() # suffixes = scenario_data["suffixes"] personas = scenario_data["personas"] - dataset = make_dataset(tokenizer, personas, suffixes, verbose=verbose) + dataset = make_dataset(tokenizer, personas, suffixes, entities, verbose=verbose) return dataset diff --git a/pyproject.toml b/pyproject.toml index 839d9d0..6eee8c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "einops>=0.8.1", "hf-xet>=1.1.9", "jaxtyping>=0.3.2", + "json5>=0.12.1", "loguru>=0.7.3", "matplotlib>=3.10.1", "openrouter-wrapper", diff --git a/uv.lock b/uv.lock index de3b7fe..43a288c 100644 --- a/uv.lock +++ b/uv.lock @@ -1085,6 +1085,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, ] +[[package]] +name = "json5" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/ae/929aee9619e9eba9015207a9d2c1c54db18311da7eb4dcf6d41ad6f0eb67/json5-0.12.1.tar.gz", hash = "sha256:b2743e77b3242f8d03c143dd975a6ec7c52e2f2afe76ed934e53503dd4ad4990", size = 52191, upload-time = "2025-08-12T19:47:42.583Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/e2/05328bd2621be49a6fed9e3030b1e51a2d04537d3f816d211b9cc53c5262/json5-0.12.1-py3-none-any.whl", hash = "sha256:d9c9b3bc34a5f54d43c35e11ef7cb87d8bdd098c6ace87117a7b7e83e705c1d5", size = 36119, upload-time = "2025-08-12T19:47:41.131Z" }, +] + [[package]] name = "jupyter-client" version = "8.6.3" @@ -1245,6 +1254,7 @@ dependencies = [ { name = "einops" }, { name = "hf-xet" }, { name = "jaxtyping" }, + { name = "json5" }, { name = "loguru" }, { name = "matplotlib" }, { name = "openrouter-wrapper" }, @@ -1274,6 +1284,7 @@ requires-dist = [ { name = "einops", specifier = ">=0.8.1" }, { name = "hf-xet", specifier = ">=1.1.9" }, { name = "jaxtyping", specifier = ">=0.3.2" }, + { name = "json5", specifier = ">=0.12.1" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "matplotlib", specifier = ">=3.10.1" }, { name = "openrouter-wrapper", git = "https://github.com/wassname/openrouter_wrapper.git" },