mirror of
https://github.com/wassname/llm-moral-foundations2.git
synced 2026-06-27 16:10:07 +08:00
separate into guided and not guided reasoning traces
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from typing import List, Optional, Any, Dict
|
||||
from jaxtyping import Float, Int
|
||||
from jaxtyping import Float, Int, Bool
|
||||
from torch import nn, Tensor, functional as F
|
||||
from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer
|
||||
import pandas as pd
|
||||
@@ -16,16 +16,17 @@ from llm_moral_foundations2.hf import clone_dynamic_cache, symlog
|
||||
def force_forked_choice(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
# inputs: Int[Tensor, "b s"],
|
||||
choice_ids: List[List[int]],
|
||||
attention_mask: Optional[Int[Tensor, "b s"]] = None,
|
||||
forcing_text="\n\nchoice:",
|
||||
forcing_text="\n\nchoice: ",
|
||||
unthink_s = "</think>",
|
||||
kv_cache: Optional[DynamicCache] = None,
|
||||
think=False,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
) -> Float[Tensor, "b c"]:
|
||||
"""
|
||||
Force the model to produce a specific rating by modifying the input.
|
||||
Force the model to produce a logprob distribution over choices by modifying the input.
|
||||
This uses a cloned kv_cache so it can fork from a generation process
|
||||
Args:
|
||||
- think: Whether to exit thinking
|
||||
@@ -38,40 +39,52 @@ def force_forked_choice(
|
||||
kv_cache = clone_dynamic_cache(kv_cache)
|
||||
|
||||
# modify inputs to force rating
|
||||
s = forcing_text
|
||||
attn_k_shape = kv_cache.layers[0].values.shape
|
||||
bs = attn_k_shape[0]
|
||||
|
||||
# might not be needed in thinking only models
|
||||
if think:
|
||||
s = ". I need to give the answer</think>\n\n" + s
|
||||
|
||||
# bs = kv_cache.key_cache[0].shape[0]
|
||||
bs = kv_cache.layers[0].values.shape[0]
|
||||
|
||||
input_ids = tokenizer.encode(s, return_tensors="pt", add_special_tokens=False).to(model.device).repeat((bs, 1))
|
||||
input_ids = tokenizer.encode(forcing_text, return_tensors="pt", add_special_tokens=False).to(model.device).repeat((bs, 1))
|
||||
|
||||
# note that when using kv_cache we do not need paste inputs, but we do need paste attention mask
|
||||
if attention_mask is not None:
|
||||
new_attn_mask = torch.ones_like(input_ids).long()
|
||||
attention_mask = torch.cat([attention_mask, new_attn_mask], dim=1)
|
||||
|
||||
|
||||
# I need to handle a batch of which some are thinking, some are not. Ideally by masking out this prefix
|
||||
|
||||
unthink_ids = tokenizer.encode(unthink_s, return_tensors="pt", add_special_tokens=False).to(model.device).repeat((bs, 1))
|
||||
if attention_mask is None:
|
||||
# Note attentions mask need to cover the cache, and inputs
|
||||
attention_mask = torch.ones((bs, attn_k_shape[2] + input_ids.shape[1]), dtype=torch.long, device=model.device)
|
||||
|
||||
# always insert unthink, but sometimes mask it
|
||||
input_ids = torch.concat([unthink_ids, input_ids], dim=1)
|
||||
attention_mask = torch.cat([torch.ones_like(unthink_ids).long() * think, attention_mask], dim=1)
|
||||
|
||||
# cache_position = attn_k_shape
|
||||
|
||||
o = model(
|
||||
input_ids=input_ids, attention_mask=attention_mask, return_dict=True, past_key_values=kv_cache, use_cache=True
|
||||
input_ids=input_ids, attention_mask=attention_mask, return_dict=True, past_key_values=kv_cache, use_cache=True,
|
||||
# cache_position=cache_position,
|
||||
**kwargs
|
||||
)
|
||||
logprobs = o.logits[:, -1].log_softmax(dim=-1).float()
|
||||
|
||||
if verbose:
|
||||
bi = 0
|
||||
# print("-" * 20 + "force rating outputs" + "-" * 20)
|
||||
# out_string = tokenizer.decode(o.logits.argmax(dim=-1)[bi], skip_special_tokens=True)#[-1]
|
||||
# print("decode(outputs)", out_string)
|
||||
# print("-" * 80)
|
||||
|
||||
# Also print top 10 tokens so I can debug low prob mass
|
||||
top_k = logprobs.topk(10, dim=-1)
|
||||
print(f"Top 10 tokens for batch {bi} after forcing:")
|
||||
print(f"Forcing text: `{forcing_text}`")
|
||||
print(f"Input IDs: `{tokenizer.decode(input_ids[bi], skip_special_tokens=False)}`")
|
||||
attn_mask_input = attention_mask[bi, -input_ids.shape[1]:]
|
||||
print(f"Input IDs: `{tokenizer.decode(input_ids[bi]*attn_mask_input, skip_special_tokens=False)}`")
|
||||
for token_id, prob in zip(top_k.indices[bi], top_k.values[bi]):
|
||||
print(f"Token: {tokenizer.decode([token_id])}, Logprob: {prob.item()}")
|
||||
print(f"Token: `{tokenizer.decode([token_id])}`, Logprob: {prob.item()}")
|
||||
|
||||
print(f"KV cache size {attn_k_shape}")
|
||||
|
||||
print("-" * 80)
|
||||
|
||||
if choice_ids is None:
|
||||
@@ -84,9 +97,6 @@ def force_forked_choice(
|
||||
choice_group_lprobs = logprobs[:, choice_group]
|
||||
choice_lprobs[:, i] = torch.logsumexp(choice_group_lprobs, dim=-1).detach().cpu()
|
||||
|
||||
probmass = choice_lprobs.exp().sum(dim=-1)
|
||||
if probmass.mean()<0.75:
|
||||
logger.warning(f"Low probability mass detected: {probmass.mean():.2f}. Check your model's interaction with prompt, and choice tokens, and forcing text.")
|
||||
return choice_lprobs
|
||||
|
||||
|
||||
@@ -117,27 +127,49 @@ def get_last_token_id_pos(all_input_ids: Int[Tensor, "s"], token_id, tokenizer)
|
||||
|
||||
# return last_think_pos > last_unthink_pos
|
||||
|
||||
def gen_reasoning_trace(
|
||||
|
||||
def is_thinking(
    all_input_ids: Int[Tensor, "b s"],
    tokenizer: PreTrainedTokenizer,
) -> Bool[Tensor, "b"]:
    """Batched check of whether each sequence is currently in a thinking state.

    A row is "thinking" when its last `<think>` token occurs after its last
    `</think>` token. Rows that contain neither token are not thinking. When
    the sequence dimension is empty, every row is reported as thinking.

    Args:
        all_input_ids: token ids, one row per sample.
        tokenizer: only `convert_tokens_to_ids` (and `unk_token_id`, if the
            tokenizer has one) are used, to resolve the two marker tokens.

    Returns:
        Bool tensor with one entry per row, on the device of `all_input_ids`.
    """
    bs, seq_len = all_input_ids.shape[0], all_input_ids.shape[1]
    device = all_input_ids.device
    if seq_len == 0:
        return torch.ones(bs, dtype=torch.bool, device=device)

    unk_id = getattr(tokenizer, "unk_token_id", None)
    positions = torch.arange(seq_len, device=device)
    never = torch.full_like(positions, -1)

    def last_position(token: str) -> Tensor:
        """Index of the last occurrence of `token` in each row, or -1."""
        token_id = tokenizer.convert_tokens_to_ids(token)
        # A marker missing from the vocabulary resolves to None or to the unk
        # id; treat it as never occurring rather than matching unknown tokens.
        if token_id is None or token_id == unk_id:
            return torch.full((bs,), -1, dtype=torch.long, device=device)
        hits = all_input_ids == token_id
        return torch.where(hits, positions, never).max(dim=1).values

    # The comparison already yields a bool tensor of shape (b,).
    return last_position("<think>") > last_position("</think>")
|
||||
|
||||
def gen_reasoning_trace_guided(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
# messages: List[Dict[str, str]],
|
||||
input_ids: Tensor,
|
||||
device,
|
||||
verbose=False,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
max_new_tokens: int = 130,
|
||||
min_new_tokens: int = 1,
|
||||
forcing_text: str = "\n\nchoice:",
|
||||
min_thinking_tokens: int = 1,
|
||||
max_thinking_tokens: Optional[int] = None,
|
||||
fork_every: int = 10,
|
||||
banned_token_ids: Optional[Int[Tensor, "d"]] = None,
|
||||
choice_token_ids: Optional[Int[Tensor, "c"]] = None,
|
||||
do_sample=False,
|
||||
end_think_s = "</think>",
|
||||
):
|
||||
"""
|
||||
A modified generate that will
|
||||
- stop thinking half way through
|
||||
- stop thinking after N tokens
|
||||
- fork the generation process and force and answer (cached) every `fork_every` steps
|
||||
- avoid banned tokens (by default all special tokens including </think>)
|
||||
"""
|
||||
@@ -151,7 +183,7 @@ def gen_reasoning_trace(
|
||||
|
||||
all_input_ids = input_ids.clone()
|
||||
|
||||
input_ids = input_ids.to(device)
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
if verbose:
|
||||
inputs_decoded = tokenizer.decode(input_ids[0], skip_special_tokens=False)
|
||||
@@ -184,7 +216,7 @@ def gen_reasoning_trace(
|
||||
|
||||
logits = o.logits[:, -1].clone()
|
||||
logits[:, banned_token_ids] = -float("inf")
|
||||
if len(data)<min_thinking_tokens:
|
||||
if len(data)<max_thinking_tokens:
|
||||
eot_token_id = tokenizer.convert_tokens_to_ids(["</think>"])
|
||||
logits[:, eot_token_id] = -float("inf")
|
||||
for proc in logits_processors:
|
||||
@@ -210,19 +242,20 @@ def gen_reasoning_trace(
|
||||
break
|
||||
|
||||
if is_choice_token or (i % fork_every == 0):
|
||||
# _is_thinking = is_thinking(all_input_ids[0], tokenizer)
|
||||
logp_choices = force_forked_choice(
|
||||
model,
|
||||
tokenizer,
|
||||
# input_ids,
|
||||
attention_mask=attn_mask,
|
||||
kv_cache=kv_cache,
|
||||
think=True, # always add </think> anyway
|
||||
think=i>max_thinking_tokens, # always add </think> anyway
|
||||
# verbose=i in [5, max_new_tokens // 2 + 5],
|
||||
choice_ids=choice_token_ids,
|
||||
verbose=verbose,
|
||||
forcing_text=forcing_text
|
||||
)
|
||||
# if probmass.mean()<0.75:
|
||||
# logger.warning(f"Low probability mass detected: {probmass.mean():.2f} at token position {i}. Check your model's interaction with prompt, and choice tokens, and forcing text.")
|
||||
else:
|
||||
logp_choices = None
|
||||
|
||||
@@ -237,25 +270,31 @@ def gen_reasoning_trace(
|
||||
)
|
||||
|
||||
if (max_thinking_tokens is not None) and (i == max_thinking_tokens):
|
||||
|
||||
# end thinking
|
||||
think_token_id = convert_tokens_to_longs("</think>", tokenizer).to(input_ids.device).repeat((input_ids.shape[0], 1))
|
||||
input_ids = torch.cat([input_ids, think_token_id], dim=1)
|
||||
new_input_ids = tokenizer.encode(end_think_s, return_tensors="pt", add_special_tokens=False).to(input_ids.device).repeat((bs, 1))
|
||||
input_ids = torch.cat([input_ids, new_input_ids], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.cat([attn_mask, torch.ones_like(think_token_id).long()], dim=1)
|
||||
attn_mask = torch.cat([attn_mask, torch.ones_like(new_input_ids).long()], dim=1)
|
||||
|
||||
new_strs = tokenizer.batch_decode(new_input_ids[0], skip_special_tokens=False)
|
||||
for j in range(bs):
|
||||
data[j].append(
|
||||
{
|
||||
"token": "</think>",
|
||||
"ii": i + 0.5,
|
||||
}
|
||||
for k in range(len(new_strs)):
|
||||
data[j].append(
|
||||
{
|
||||
"token": new_strs[k],
|
||||
"ii": i + 0.5,
|
||||
}
|
||||
)
|
||||
|
||||
all_input_ids = torch.cat([all_input_ids, input_ids], dim=1)
|
||||
|
||||
# stop once all samples in the batch has produced an eos, after the inputs
|
||||
if all_input_ids[:, nb_input_tokens:].eq(tokenizer.eos_token_id).any(dim=1).all():
|
||||
if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
|
||||
break
|
||||
# # stop once all samples in the batch has produced an eos, after the inputs
|
||||
# if all_input_ids[:, nb_input_tokens:].eq(tokenizer.eos_token_id).any(dim=1).all():
|
||||
# if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
|
||||
# if (max_thinking_tokens is None) or (i >= max_thinking_tokens):
|
||||
# # TODO replace eos, with </think> if it's trying to end without stopping thinking
|
||||
# break
|
||||
|
||||
|
||||
full_strings = tokenizer.batch_decode(all_input_ids, skip_special_tokens=False)
|
||||
@@ -275,3 +314,78 @@ def gen_reasoning_trace(
|
||||
"model_name": model.config._name_or_path,
|
||||
})
|
||||
return dfs, full_strings
|
||||
|
||||
|
||||
def gen_reasoning_trace(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    input_ids: Tensor,
    verbose=False,
    attn_mask: Optional[Tensor] = None,
    max_new_tokens: int = 130,
    min_new_tokens: int = 1,
    forcing_text: str = "\n\nchoice: ",
    fork_every: int = 10,
    choice_token_ids: Optional[Int[Tensor, "c"]] = None,
    **kwargs
):
    """
    Generate freely with `model.generate`, then replay the trace and, at every
    position that is a multiple of `fork_every`, fork from a cropped copy of the
    kv cache and force an answer with `force_forked_choice`.

    Args:
    - input_ids: prompt token ids, one row per sample
    - attn_mask: optional attention mask for `input_ids`
    - max_new_tokens, min_new_tokens: forwarded to `model.generate`
    - forcing_text: text appended in each fork to elicit the choice
    - fork_every: fork whenever the absolute token position is a multiple of this
    - choice_token_ids: forwarded to `force_forked_choice` as `choice_ids`
    - kwargs: forwarded to `model.generate`

    Returns:
    - dfs: one DataFrame per sample, one row per generated position, with
      columns `token`, `token_strs`, `logp_choices` (None between forks), `ii`
    - full_strings: decoded prompt + generation for each sample
    """
    out = model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        use_cache=True,
        return_dict_in_generate=True,
        **kwargs
    )
    bs, nb_input_tokens = input_ids.shape[0], input_ids.shape[1]
    sequences = out.sequences
    kv_cache = out.past_key_values
    # per-sample list of single-token strings, indexed by absolute position
    out_token_s = [tokenizer.batch_decode(s, skip_special_tokens=False) for s in sequences]
    full_strings = tokenizer.batch_decode(sequences, skip_special_tokens=False)

    verbose_steps = [fork_every, fork_every*2, max_new_tokens-max_new_tokens%fork_every]

    data = [[] for _ in range(bs)]
    for ti in range(nb_input_tokens, sequences.shape[1]):
        logp_choices = None
        if ti % fork_every == 0:
            # clone and crop the cache so the fork only sees the first `ti` tokens
            kv_cache2 = clone_dynamic_cache(kv_cache, crop=ti)

            # force_forked_choice multiplies this flag with a (b, n) tensor, so
            # give it an explicit trailing axis to broadcast per row.
            think = is_thinking(sequences[:, :ti], tokenizer).unsqueeze(1)

            # The prompt mask only covers the prompt; the cropped cache also
            # holds the tokens generated so far, so extend the mask over them.
            fork_mask = attn_mask
            if attn_mask is not None:
                generated = attn_mask.new_ones((bs, ti - attn_mask.shape[1]))
                fork_mask = torch.cat([attn_mask, generated], dim=1)

            logp_choices = force_forked_choice(
                model,
                tokenizer,
                attention_mask=fork_mask,
                kv_cache=kv_cache2,
                think=think,
                choice_ids=choice_token_ids,
                verbose=verbose and (ti in verbose_steps),
                forcing_text=forcing_text
            )

        for j in range(bs):
            token_str = out_token_s[j][ti]
            data[j].append(
                {
                    "token": token_str,
                    "token_strs": token_str,
                    "logp_choices": logp_choices[j].numpy() if (logp_choices is not None) else None,
                    "ii": ti,
                }
            )

    dfs = [pd.DataFrame(d) for d in data]

    for df in dfs:
        df.attrs.update({
            "max_new_tokens": max_new_tokens,
            "model_name": model.config._name_or_path,
        })

    # TODO check mas df_traj['probmass'].max()>0.5
    return dfs, full_strings
|
||||
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
cmap_RdGyGn = Colormap(['red', 'grey', 'green']).to_mpl()
|
||||
|
||||
|
||||
def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=[]):
|
||||
def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=[], s_key="token_strs"):
|
||||
"""
|
||||
Display colored tokens from a chain of thought.
|
||||
|
||||
@@ -19,14 +19,17 @@ def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_R
|
||||
- key: the dataframe columns with the single score to visualise
|
||||
"""
|
||||
# v = df[key].abs().max()
|
||||
v = df[key].abs()
|
||||
v = df[key]
|
||||
if symmetric:
|
||||
v = v.abs()
|
||||
vmax = v.quantile(0.99)
|
||||
# vmax = v.quantile(0.99)
|
||||
vmax = max(v)
|
||||
norm = mpl.colors.CenteredNorm(0, vmax)
|
||||
else:
|
||||
vmin = v.quantile(0.01)
|
||||
vmax = v.quantile(0.99)
|
||||
vmin = min(v)
|
||||
vmax = max(v)
|
||||
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
|
||||
|
||||
# show colormap
|
||||
@@ -42,8 +45,8 @@ def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_R
|
||||
# TODO show \n as <p>
|
||||
htmls = f'<h3>{title}</h3>'
|
||||
for n,row in df.iterrows():
|
||||
token, score = row['token_strs'], row[key]
|
||||
# token = token.replace('Ġ', ' ').replace('Ċ', '\n')
|
||||
token, score = row[s_key], row[key]
|
||||
token = token.replace('Ġ', ' ').replace('Ċ', '\n').replace("▁", " ")
|
||||
|
||||
# html escape
|
||||
token = html.escape(token)
|
||||
@@ -53,7 +56,7 @@ def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_R
|
||||
htmls += "<p/>"
|
||||
if token in highlight_tokens:
|
||||
token = f'!{token}!!'
|
||||
h = f'<span title="{score} n={n}" style="color: {hex_color};">{token}</span>'
|
||||
h = f'<span title="{score} n={n}" style="color: {hex_color};">{token} </span>'
|
||||
if token.endswith("\n"):
|
||||
htmls += "<p/>"
|
||||
htmls += h
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from typing import Optional
|
||||
from transformers import DynamicCache
|
||||
import numpy as np
|
||||
|
||||
def clone_dynamic_cache(kv_cache, crop: Optional[int]=None):
    """Return an independent copy of `kv_cache`, optionally cropped.

    Args:
        kv_cache: cache to copy; `None` or an empty cache yields a fresh
            empty `DynamicCache`.
        crop: if given, keep only the first `crop` entries along dim 2 of
            every key and value tensor; `None` keeps everything.

    Returns:
        A new `DynamicCache` whose tensors do not share storage with the input.
    """
    if (kv_cache is None) or len(kv_cache)==0:
        return DynamicCache()
    # Legacy format is one (key, value) pair per layer; per the original notes
    # each tensor is [batch, num_heads, seq, head_dim].
    # Slice BEFORE cloning so only the kept positions are copied; cloning first
    # would leave a view that keeps the full-size copy alive.
    lyrs = tuple(
        (k[:, :, :crop].clone(), v[:, :, :crop].clone())
        for k, v in kv_cache.to_legacy_cache()
    )
    return DynamicCache.from_legacy_cache(lyrs)
|
||||
|
||||
def symlog(x):
    """Symmetric log transform: sign(x) * log(1 + |x|), applied elementwise."""
    magnitude = np.log1p(np.abs(x))
    return np.sign(x) * magnitude
|
||||
|
||||
@@ -72,8 +72,11 @@ def find_last_non_whitespace_token(tokenizer, tokens):
|
||||
return t
|
||||
return t
|
||||
|
||||
def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world."):
|
||||
thinking_prefix = get_thinking_prefix(tokenizer)
|
||||
def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world.", include_thinking=False):
|
||||
if include_thinking:
|
||||
thinking_prefix = get_thinking_prefix(tokenizer)
|
||||
else:
|
||||
thinking_prefix = None
|
||||
|
||||
# Create dataset entries
|
||||
dataset = []
|
||||
@@ -118,6 +121,7 @@ def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, templat
|
||||
random.shuffle(dataset)
|
||||
if verbose:
|
||||
for i in range(3):
|
||||
positive_prompt, negative_prompt = dataset[i].positive, dataset[i].negative
|
||||
logger.info(f"Dataset example {i}:\n\npositive_prompt={positive_prompt}\n\nnegative_prompt={negative_prompt}")
|
||||
return dataset
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -25,6 +25,7 @@ dependencies = [
|
||||
"repeng",
|
||||
"seaborn>=0.13.2",
|
||||
"srsly>=2.5.1",
|
||||
"tabulate>=0.9.0",
|
||||
"torch>=2.6.0",
|
||||
"transformers>=4.51.3",
|
||||
"umap-learn>=0.5.9.post2",
|
||||
|
||||
Reference in New Issue
Block a user