mirror of
https://github.com/wassname/llm-moral-foundations2.git
synced 2026-06-27 16:10:07 +08:00
misc
This commit is contained in:
@@ -209,7 +209,7 @@ def gen_reasoning_trace(
|
||||
is_choice_token = True
|
||||
break
|
||||
|
||||
if is_choice_token or (i % fork_every == 0) or ((max_thinking_tokens is not None) and (i == max_thinking_tokens)):
|
||||
if is_choice_token or (i % fork_every == 0):
|
||||
# _is_thinking = is_thinking(all_input_ids[0], tokenizer)
|
||||
logp_choices = force_forked_choice(
|
||||
model,
|
||||
@@ -257,14 +257,18 @@ def gen_reasoning_trace(
|
||||
if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
|
||||
break
|
||||
|
||||
|
||||
full_strings = tokenizer.batch_decode(all_input_ids, skip_special_tokens=False)
|
||||
|
||||
token_strs = [tokenizer.batch_decode(i) for i in all_input_ids[:, nb_input_tokens:]]
|
||||
|
||||
# convert to one dataframe for each batch
|
||||
dfs = [pd.DataFrame(d) for d in data]
|
||||
|
||||
# TODO I might want to remove everything after a tokenizer.eos_token_id
|
||||
|
||||
for df in dfs:
|
||||
for i, df in enumerate(dfs):
|
||||
df['token_strs'] = token_strs[i]
|
||||
df.attrs.update({
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"max_thinking_tokens": max_thinking_tokens,
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
from matplotlib import pyplot as plt
|
||||
import html
|
||||
from IPython.display import display, HTML
|
||||
from cmap import Colormap
|
||||
import pandas as pd
|
||||
|
||||
# choose a map which is dark in the middle
|
||||
cmap_RdGyGn = Colormap(['red', 'grey', 'green']).to_mpl()
|
||||
|
||||
|
||||
def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric=False, highlight_tokens=()):
    """
    Display colored tokens from a chain of thought.

    Draws a horizontal colorbar legend with matplotlib, then renders one HTML
    ``<span>`` per row of ``df`` whose text colour encodes ``df[key]``.

    Args:
    - df: one row per token; needs a 'token_strs' column and a `key` column
    - title: appended to the figure title and used as the HTML heading
    - key: the dataframe column with the single score to visualise
    - cmap: matplotlib colormap applied to the normalised score
    - symmetric: if True, centre the colour scale on 0 with a half-range equal
      to the 99th percentile of |score|; otherwise scale between the 1st and
      99th percentiles of the raw (signed) scores
    - highlight_tokens: raw (unescaped) token strings to wrap in '!...!!'
    """
    scores = df[key]
    if symmetric:
        # Centre on zero; the half-range ignores the top 1% of magnitudes so
        # a single outlier does not wash out the rest of the scale.
        vmax = scores.abs().quantile(0.99)
        norm = mpl.colors.CenteredNorm(0, vmax)
    else:
        # Use the signed values here: taking abs() first would make every
        # negative score fall below vmin and render with the same colour.
        vmin = scores.quantile(0.01)
        vmax = scores.quantile(0.99)
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # show colormap: a hidden dummy image supplies the mappable for the colorbar
    a = np.array([0, 1])[None]
    plt.figure(figsize=(9, 1.5))
    img = plt.imshow(a, cmap=cmap, norm=norm)
    plt.gca().set_visible(False)
    cax = plt.axes([0.1, 0.2, 0.8, 0.6])
    plt.colorbar(img, orientation="horizontal", cax=cax, label='rating')
    plt.title(f"Judge rating along chain of thought. {title}")
    plt.show()

    # Collect fragments and join once instead of repeated string +=.
    parts = [f'<h3>{title}</h3>']
    for n, row in df.iterrows():
        raw_token, score = row['token_strs'], row[key]

        # Decide on highlighting before escaping, so tokens containing
        # '<', '&' or quotes can still match the caller's list.
        is_highlighted = raw_token in highlight_tokens

        # html escape
        token = html.escape(raw_token)
        # map score -> RGBA -> hex
        hex_color = mpl.colors.to_hex(cmap(norm(score)))

        if raw_token.startswith("\n"):
            parts.append("<p/>")
        if is_highlighted:
            token = f'!{token}!!'
        parts.append(f'<span title="{score} n={n}" style="color: {hex_color};">{token}</span>')
        if raw_token.endswith("\n"):
            # The paragraph break belongs after a token that ends the line.
            parts.append("<p/>")

    # render it inline
    display(HTML("".join(parts)))
|
||||
|
||||
|
||||
# display_rating_trace(df_traj.interpolate(method='nearest'))
|
||||
# df_traj
|
||||
@@ -39,7 +39,7 @@ def load_model(model_kwargs, device="auto"):
|
||||
id = model_kwargs.pop("id")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
device_map=device,
|
||||
torch_dtype=torch_dtype,
|
||||
dtype=torch_dtype,
|
||||
pretrained_model_name_or_path=id,
|
||||
# **model_kwargs,
|
||||
)
|
||||
|
||||
@@ -2,15 +2,33 @@
|
||||
# steering
|
||||
import random
|
||||
from repeng import DatasetEntry
|
||||
import json
|
||||
import json5
|
||||
import torch
|
||||
import itertools
|
||||
from repeng.control import model_layer_list
|
||||
from repeng import ControlVector, ControlModel, DatasetEntry
|
||||
from loguru import logger
|
||||
import contextlib
|
||||
from anycache import anycache
|
||||
from llm_moral_foundations2.config import project_dir
|
||||
|
||||
@contextlib.contextmanager
def control(model, vector, coeff):
    """
    Temporarily apply a steering vector to a control-wrapped model.

    On entry calls ``model.set_control(vector, coeff)``, or ``model.reset()``
    when ``coeff`` is 0. On exit always calls ``model.reset()``, including
    when the body raises or when installing the control itself fails.

    Usage:
        with control(model, vector, coeff):
            model.generate()
    """
    try:
        # Setup sits inside the try so a failure part-way through
        # set_control still triggers the reset in `finally`.
        if coeff == 0:
            model.reset()
        else:
            model.set_control(vector, coeff)
        yield
    finally:
        model.reset()
|
||||
|
||||
|
||||
def wrap_model(model):
|
||||
try:
|
||||
n_layers = len(model_layer_list(model))
|
||||
@@ -39,6 +57,7 @@ def train_steering_vector(model, tokenizer, ds_name="scenario_engagement_dataset
|
||||
tokenizer,
|
||||
ds_steer,
|
||||
batch_size=batch_size,
|
||||
method="pca_diff_weighted",
|
||||
)
|
||||
return control_vector
|
||||
|
||||
@@ -53,39 +72,36 @@ def find_last_non_whitespace_token(tokenizer, tokens):
|
||||
return t
|
||||
return t
|
||||
|
||||
def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're a {persona} making statements about the world."):
|
||||
def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world."):
|
||||
thinking_prefix = get_thinking_prefix(tokenizer)
|
||||
|
||||
# Create dataset entries
|
||||
dataset = []
|
||||
for i, _suffix in enumerate(suffixes):
|
||||
for j, (positive_persona, negative_persona) in enumerate(personas):
|
||||
_entity = random.choice(entities)
|
||||
positive_persona = positive_persona.format(entity=_entity)
|
||||
negative_persona = negative_persona.format(entity=_entity)
|
||||
if (thinking_prefix is not None) and random.random() < 0.5:
|
||||
suffix = thinking_prefix + _suffix
|
||||
else:
|
||||
suffix = _suffix + ""
|
||||
|
||||
# tokens = tokenizer.tokenize(suffix, add_special_tokens=False)[:max_suffix_length]
|
||||
# Create multiple training examples with different truncations
|
||||
# for i in range(1, len(tokens), max(1, len(tokens) // 5)): # Using stride to reduce dataset size
|
||||
for think in [0, -1, 1]:
|
||||
suffix = _suffix + ""
|
||||
# truncated = tokenizer.convert_tokens_to_string(tokens[i:])
|
||||
match think:
|
||||
case -1:
|
||||
suffix = "<think>\n\n" + suffix
|
||||
case 1:
|
||||
suffix = "<think>\n\n</think>\n\n" + suffix
|
||||
|
||||
positive_prompt = tokenizer.apply_chat_template(
|
||||
# f"Please talk about {persona}."
|
||||
# f"Pretend you're an {persona} person making statements about the world.
|
||||
# "Act as if you're extremely {persona}.",
|
||||
[{'role': 'user', 'content': template.format(persona=positive_persona)},
|
||||
{'role': 'assistant', 'content': suffix}],
|
||||
tokenize=False,
|
||||
enable_thinking=False,
|
||||
# enable_thinking=False,
|
||||
continue_final_message=True
|
||||
)
|
||||
negative_prompt = tokenizer.apply_chat_template(
|
||||
[{'role': 'user', 'content': template.format(persona=negative_persona)},
|
||||
{'role': 'assistant', 'content': suffix}],
|
||||
tokenize=False,
|
||||
enable_thinking=False,
|
||||
# enable_thinking=False,
|
||||
continue_final_message=True,
|
||||
|
||||
)
|
||||
@@ -105,28 +121,40 @@ def make_dataset(tokenizer, personas, suffixes, verbose=False, template="You're
|
||||
logger.info(f"Dataset example {i}:\n\npositive_prompt={positive_prompt}\n\nnegative_prompt={negative_prompt}")
|
||||
return dataset
|
||||
|
||||
def load_entities():
    """
    Load the steering entities from ``data/steering/_entities.json5``.

    Returns the parsed JSON5 content unchanged.
    NOTE(review): make_dataset passes this to random.choice, so it is
    presumably a non-empty list of strings -- confirm against the data file.
    """
    # Explicit encoding: otherwise the platform default is used, which
    # mis-decodes non-ASCII text on some systems.
    with open(project_dir / "data/steering/_entities.json5", encoding="utf-8") as f:
        return json5.load(f)
|
||||
|
||||
def load_suffixes(collapse=True):
    """
    Load the suffixes from ``data/steering/_suffixes.json5``.

    Args:
    - collapse: if True, flatten the values of the loaded mapping into one
      list; if False, return the parsed content as loaded.
    """
    # Only the .json5 file is read; the older .json read was redundant
    # because its result was immediately overwritten.
    with open(project_dir / "data/steering/_suffixes.json5", encoding="utf-8") as f:
        suffixes = json5.load(f)
    if collapse:
        # Materialise as a list so the result can be iterated more than once.
        suffixes = list(itertools.chain.from_iterable(suffixes.values()))
    return suffixes
|
||||
|
||||
def get_thinking_prefix(tokenizer):
    """
    Return the prefix that opens a thinking block for this tokenizer.

    Returns "<think>\\n\\n" when "<think>" resolves to a real vocabulary
    entry, otherwise "".
    """
    think_token = tokenizer.convert_tokens_to_ids("<think>")
    unk_token = getattr(tokenizer, "unk_token_id", None)
    # A plain truthiness test is wrong in both directions: id 0 is a valid
    # token, and an unknown token may come back as a non-zero unk id (or None).
    has_think = think_token is not None and think_token != unk_token
    return "<think>\n\n" if has_think else ""
|
||||
|
||||
def load_personas(name):
    """
    Load the 'personas' entry of ``data/steering/{name}.json5``.

    Args:
    - name: dataset file stem (without the ``.json5`` extension)

    NOTE(review): make_dataset unpacks each item as a (positive, negative)
    pair -- confirm against the data file.
    """
    # Explicit encoding: otherwise the platform default is used, which
    # mis-decodes non-ASCII text on some systems.
    with open(project_dir / f"data/steering/{name}.json5", encoding="utf-8") as f:
        personas = json5.load(f)
    return personas['personas']
|
||||
|
||||
def load_steering_ds(tokenizer, ds_name="scenario_engagement_dataset", verbose=False):
    """
    Build the steering dataset for ``ds_name``.

    Reads personas from ``data/steering/{ds_name}.json5``, loads the shared
    entities and (flattened) suffixes, and passes them to ``make_dataset``.

    Args:
    - tokenizer: forwarded to make_dataset
    - ds_name: dataset file stem under data/steering (without ``.json5``)
    - verbose: forwarded to make_dataset

    Returns whatever ``make_dataset`` returns.
    """
    with open(project_dir / f"data/steering/{ds_name}.json5", encoding="utf-8") as f:
        scenario_data = json5.load(f)

    entities = load_entities()
    # could also sample equally from each suffix group
    suffixes = load_suffixes()

    personas = scenario_data["personas"]
    # make_dataset requires `entities`, so it is called exactly once here,
    # with all four positional arguments.
    dataset = make_dataset(tokenizer, personas, suffixes, entities, verbose=verbose)

    return dataset
|
||||
|
||||
@@ -17,6 +17,7 @@ dependencies = [
|
||||
"einops>=0.8.1",
|
||||
"hf-xet>=1.1.9",
|
||||
"jaxtyping>=0.3.2",
|
||||
"json5>=0.12.1",
|
||||
"loguru>=0.7.3",
|
||||
"matplotlib>=3.10.1",
|
||||
"openrouter-wrapper",
|
||||
|
||||
@@ -1085,6 +1085,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "json5"
|
||||
version = "0.12.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/12/ae/929aee9619e9eba9015207a9d2c1c54db18311da7eb4dcf6d41ad6f0eb67/json5-0.12.1.tar.gz", hash = "sha256:b2743e77b3242f8d03c143dd975a6ec7c52e2f2afe76ed934e53503dd4ad4990", size = 52191, upload-time = "2025-08-12T19:47:42.583Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/85/e2/05328bd2621be49a6fed9e3030b1e51a2d04537d3f816d211b9cc53c5262/json5-0.12.1-py3-none-any.whl", hash = "sha256:d9c9b3bc34a5f54d43c35e11ef7cb87d8bdd098c6ace87117a7b7e83e705c1d5", size = 36119, upload-time = "2025-08-12T19:47:41.131Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jupyter-client"
|
||||
version = "8.6.3"
|
||||
@@ -1245,6 +1254,7 @@ dependencies = [
|
||||
{ name = "einops" },
|
||||
{ name = "hf-xet" },
|
||||
{ name = "jaxtyping" },
|
||||
{ name = "json5" },
|
||||
{ name = "loguru" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "openrouter-wrapper" },
|
||||
@@ -1274,6 +1284,7 @@ requires-dist = [
|
||||
{ name = "einops", specifier = ">=0.8.1" },
|
||||
{ name = "hf-xet", specifier = ">=1.1.9" },
|
||||
{ name = "jaxtyping", specifier = ">=0.3.2" },
|
||||
{ name = "json5", specifier = ">=0.12.1" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "matplotlib", specifier = ">=3.10.1" },
|
||||
{ name = "openrouter-wrapper", git = "https://github.com/wassname/openrouter_wrapper.git" },
|
||||
|
||||
Reference in New Issue
Block a user