separate into guided and non-guided reasoning traces

This commit is contained in:
wassname
2025-09-16 16:13:03 +08:00
parent 6c3f92e183
commit f05ffe1e9a
6 changed files with 450 additions and 1210 deletions
+158 -44
View File
@@ -1,6 +1,6 @@
import torch
from typing import List, Optional, Any, Dict
from jaxtyping import Float, Int
from jaxtyping import Float, Int, Bool
from torch import nn, Tensor, functional as F
from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer
import pandas as pd
@@ -16,16 +16,17 @@ from llm_moral_foundations2.hf import clone_dynamic_cache, symlog
def force_forked_choice(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
# inputs: Int[Tensor, "b s"],
choice_ids: List[List[int]],
attention_mask: Optional[Int[Tensor, "b s"]] = None,
forcing_text="\n\nchoice:",
forcing_text="\n\nchoice: ",
unthink_s = "</think>",
kv_cache: Optional[DynamicCache] = None,
think=False,
verbose=False,
**kwargs
) -> Float[Tensor, "b c"]:
"""
Force the model to produce a specific rating by modifying the input.
Force the model to produce a logprob distribution over choices by modifying the input.
This uses a cloned kv_cache so it can fork from a generation process
Args:
- think: Whether to exit thinking
@@ -38,40 +39,52 @@ def force_forked_choice(
kv_cache = clone_dynamic_cache(kv_cache)
# modify inputs to force rating
s = forcing_text
attn_k_shape = kv_cache.layers[0].values.shape
bs = attn_k_shape[0]
# might not be needed in thinking only models
if think:
s = ". I need to give the answer</think>\n\n" + s
# bs = kv_cache.key_cache[0].shape[0]
bs = kv_cache.layers[0].values.shape[0]
input_ids = tokenizer.encode(s, return_tensors="pt", add_special_tokens=False).to(model.device).repeat((bs, 1))
input_ids = tokenizer.encode(forcing_text, return_tensors="pt", add_special_tokens=False).to(model.device).repeat((bs, 1))
# note that when using kv_cache we do not need paste inputs, but we do need paste attention mask
if attention_mask is not None:
new_attn_mask = torch.ones_like(input_ids).long()
attention_mask = torch.cat([attention_mask, new_attn_mask], dim=1)
# I need to handle a batch of which some are thinking, some are not. Ideally by masking out this prefix
unthink_ids = tokenizer.encode(unthink_s, return_tensors="pt", add_special_tokens=False).to(model.device).repeat((bs, 1))
if attention_mask is None:
# Note attentions mask need to cover the cache, and inputs
attention_mask = torch.ones((bs, attn_k_shape[2] + input_ids.shape[1]), dtype=torch.long, device=model.device)
# always insert unthink, but sometimes mask it
input_ids = torch.concat([unthink_ids, input_ids], dim=1)
attention_mask = torch.cat([torch.ones_like(unthink_ids).long() * think, attention_mask], dim=1)
# cache_position = attn_k_shape
o = model(
input_ids=input_ids, attention_mask=attention_mask, return_dict=True, past_key_values=kv_cache, use_cache=True
input_ids=input_ids, attention_mask=attention_mask, return_dict=True, past_key_values=kv_cache, use_cache=True,
# cache_position=cache_position,
**kwargs
)
logprobs = o.logits[:, -1].log_softmax(dim=-1).float()
if verbose:
bi = 0
# print("-" * 20 + "force rating outputs" + "-" * 20)
# out_string = tokenizer.decode(o.logits.argmax(dim=-1)[bi], skip_special_tokens=True)#[-1]
# print("decode(outputs)", out_string)
# print("-" * 80)
# Also print top 10 tokens so I can debug low prob mass
top_k = logprobs.topk(10, dim=-1)
print(f"Top 10 tokens for batch {bi} after forcing:")
print(f"Forcing text: `{forcing_text}`")
print(f"Input IDs: `{tokenizer.decode(input_ids[bi], skip_special_tokens=False)}`")
attn_mask_input = attention_mask[bi, -input_ids.shape[1]:]
print(f"Input IDs: `{tokenizer.decode(input_ids[bi]*attn_mask_input, skip_special_tokens=False)}`")
for token_id, prob in zip(top_k.indices[bi], top_k.values[bi]):
print(f"Token: {tokenizer.decode([token_id])}, Logprob: {prob.item()}")
print(f"Token: `{tokenizer.decode([token_id])}`, Logprob: {prob.item()}")
print(f"KV cache size {attn_k_shape}")
print("-" * 80)
if choice_ids is None:
@@ -84,9 +97,6 @@ def force_forked_choice(
choice_group_lprobs = logprobs[:, choice_group]
choice_lprobs[:, i] = torch.logsumexp(choice_group_lprobs, dim=-1).detach().cpu()
probmass = choice_lprobs.exp().sum(dim=-1)
if probmass.mean()<0.75:
logger.warning(f"Low probability mass detected: {probmass.mean():.2f}. Check your model's interaction with prompt, and choice tokens, and forcing text.")
return choice_lprobs
@@ -117,27 +127,49 @@ def get_last_token_id_pos(all_input_ids: Int[Tensor, "s"], token_id, tokenizer)
# return last_think_pos > last_unthink_pos
def gen_reasoning_trace(
def is_thinking(
    all_input_ids: Int[Tensor, "b s"],
    tokenizer: PreTrainedTokenizer,
) -> Bool[Tensor, "b"]:
    """Report, per batch row, whether the sequence is currently inside a think block.

    A row counts as "thinking" when its last ``<think>`` token appears after
    its last ``</think>`` token. Rows that contain neither token are reported
    as not thinking.

    Args:
        all_input_ids: Token ids, one row per sample.
        tokenizer: Used only to look up the ids of ``<think>`` and ``</think>``.

    Returns:
        A boolean tensor with one entry per row, on the same device as
        ``all_input_ids``.

    Notes:
        - A zero-length sequence is reported as thinking for every row (this
          matches the previous behaviour of this function).
        - If the tokenizer cannot map a tag (the lookup returns ``None`` or
          the tokenizer's ``unk_token_id``), the tag is treated as never
          occurring instead of crashing or matching unrelated unknown tokens.
    """
    bs, seq_len = all_input_ids.shape[0], all_input_ids.shape[1]
    device = all_input_ids.device
    if seq_len == 0:
        return torch.ones(bs, dtype=torch.bool, device=device)

    # Column index of every position, broadcast over the batch; -1 marks "absent".
    positions = torch.arange(seq_len, device=device).expand(bs, seq_len)
    absent = torch.full_like(positions, -1)
    unk_id = getattr(tokenizer, "unk_token_id", None)

    def _last_position(token: str) -> torch.Tensor:
        """Return the index of the last occurrence of `token` per row, or -1."""
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id is None or token_id == unk_id:
            # The tag is not a real vocabulary entry for this tokenizer.
            return absent[:, 0]
        hits = all_input_ids == token_id
        return torch.where(hits, positions, absent).max(dim=1).values

    # The comparison already yields a bool tensor of shape (b,).
    return _last_position("<think>") > _last_position("</think>")
def gen_reasoning_trace_guided(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
# messages: List[Dict[str, str]],
input_ids: Tensor,
device,
verbose=False,
attn_mask: Optional[Tensor] = None,
max_new_tokens: int = 130,
min_new_tokens: int = 1,
forcing_text: str = "\n\nchoice:",
min_thinking_tokens: int = 1,
max_thinking_tokens: Optional[int] = None,
fork_every: int = 10,
banned_token_ids: Optional[Int[Tensor, "d"]] = None,
choice_token_ids: Optional[Int[Tensor, "c"]] = None,
do_sample=False,
end_think_s = "</think>",
):
"""
A modified generate that will
- stop thinking half way through
- stop thinking after N tokens
- fork the generation process and force and answer (cached) every `fork_every` steps
- avoid banned tokens (by default all special tokens including </think>)
"""
@@ -151,7 +183,7 @@ def gen_reasoning_trace(
all_input_ids = input_ids.clone()
input_ids = input_ids.to(device)
input_ids = input_ids.to(model.device)
if verbose:
inputs_decoded = tokenizer.decode(input_ids[0], skip_special_tokens=False)
@@ -184,7 +216,7 @@ def gen_reasoning_trace(
logits = o.logits[:, -1].clone()
logits[:, banned_token_ids] = -float("inf")
if len(data)<min_thinking_tokens:
if len(data)<max_thinking_tokens:
eot_token_id = tokenizer.convert_tokens_to_ids(["</think>"])
logits[:, eot_token_id] = -float("inf")
for proc in logits_processors:
@@ -210,19 +242,20 @@ def gen_reasoning_trace(
break
if is_choice_token or (i % fork_every == 0):
# _is_thinking = is_thinking(all_input_ids[0], tokenizer)
logp_choices = force_forked_choice(
model,
tokenizer,
# input_ids,
attention_mask=attn_mask,
kv_cache=kv_cache,
think=True, # always add </think> anyway
think=i>max_thinking_tokens, # always add </think> anyway
# verbose=i in [5, max_new_tokens // 2 + 5],
choice_ids=choice_token_ids,
verbose=verbose,
forcing_text=forcing_text
)
# if probmass.mean()<0.75:
# logger.warning(f"Low probability mass detected: {probmass.mean():.2f} at token position {i}. Check your model's interaction with prompt, and choice tokens, and forcing text.")
else:
logp_choices = None
@@ -237,25 +270,31 @@ def gen_reasoning_trace(
)
if (max_thinking_tokens is not None) and (i == max_thinking_tokens):
# end thinking
think_token_id = convert_tokens_to_longs("</think>", tokenizer).to(input_ids.device).repeat((input_ids.shape[0], 1))
input_ids = torch.cat([input_ids, think_token_id], dim=1)
new_input_ids = tokenizer.encode(end_think_s, return_tensors="pt", add_special_tokens=False).to(input_ids.device).repeat((bs, 1))
input_ids = torch.cat([input_ids, new_input_ids], dim=1)
if attn_mask is not None:
attn_mask = torch.cat([attn_mask, torch.ones_like(think_token_id).long()], dim=1)
attn_mask = torch.cat([attn_mask, torch.ones_like(new_input_ids).long()], dim=1)
new_strs = tokenizer.batch_decode(new_input_ids[0], skip_special_tokens=False)
for j in range(bs):
data[j].append(
{
"token": "</think>",
"ii": i + 0.5,
}
for k in range(len(new_strs)):
data[j].append(
{
"token": new_strs[k],
"ii": i + 0.5,
}
)
all_input_ids = torch.cat([all_input_ids, input_ids], dim=1)
# stop once all samples in the batch has produced an eos, after the inputs
if all_input_ids[:, nb_input_tokens:].eq(tokenizer.eos_token_id).any(dim=1).all():
if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
break
# # stop once all samples in the batch has produced an eos, after the inputs
# if all_input_ids[:, nb_input_tokens:].eq(tokenizer.eos_token_id).any(dim=1).all():
# if all_input_ids.shape[1] >= nb_input_tokens + min_new_tokens:
# if (max_thinking_tokens is None) or (i >= max_thinking_tokens):
# # TODO replace eos, with </think> if it's trying to end without stopping thinking
# break
full_strings = tokenizer.batch_decode(all_input_ids, skip_special_tokens=False)
@@ -275,3 +314,78 @@ def gen_reasoning_trace(
"model_name": model.config._name_or_path,
})
return dfs, full_strings
def gen_reasoning_trace(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    input_ids: Tensor,
    verbose=False,
    attn_mask: Optional[Tensor] = None,
    max_new_tokens: int = 130,
    min_new_tokens: int = 1,
    forcing_text: str = "\n\nchoice: ",
    fork_every: int = 10,
    choice_token_ids: Optional[Int[Tensor, "c"]] = None,
    **kwargs
):
    """
    Generate normally, then replay the trace and force an answer at regular points.

    The model first generates freely with ``model.generate``. Afterwards, for
    every sequence position ``ti`` (absolute index, prompt included) that is a
    multiple of `fork_every`, the returned kv cache is cloned and cropped to
    ``ti`` and `force_forked_choice` is called on it to read out the choice
    log-probabilities at that point of the trace.

    Args:
        model: Model used for generation and for the forced forks.
        tokenizer: Tokenizer matching `model`.
        input_ids: Prompt token ids, one row per sample.
        verbose: If true, print fork debugging output for the first two forks
            and the last fork of the trace.
        attn_mask: Optional attention mask for `input_ids`.
        max_new_tokens: Passed to ``model.generate``.
        min_new_tokens: Passed to ``model.generate``.
        forcing_text: Text appended at each fork to elicit the choice.
        fork_every: Fork at positions where ``ti % fork_every == 0``.
        choice_token_ids: Passed through as ``choice_ids`` to `force_forked_choice`.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        ``(dfs, full_strings)``: one DataFrame per sample with a row per
        generated position (columns ``token``, ``token_strs``,
        ``logp_choices``, ``ii``), and the decoded full sequences.
    """
    out = model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        use_cache=True,
        return_dict_in_generate=True,
        **kwargs
    )
    bs, prompt_len = input_ids.shape[0], input_ids.shape[1]
    total_len = out.sequences.shape[1]
    kv_cache = out.past_key_values

    # Per-sample list of single-token strings, plus the decoded full sequences.
    out_token_s = [tokenizer.batch_decode(s, skip_special_tokens=False) for s in out.sequences]
    full_strings = tokenizer.batch_decode(out.sequences, skip_special_tokens=False)

    # `ti` is an absolute position (prompt included), so the forks to print are
    # selected from the positions that are actually forked: first two and last.
    verbose_positions = set()
    if verbose:
        fork_positions = [ti for ti in range(prompt_len, total_len) if ti % fork_every == 0]
        verbose_positions = set(fork_positions[:2] + fork_positions[-1:])

    data = [[] for _ in range(bs)]
    for ti in range(prompt_len, total_len):
        if ti % fork_every == 0:
            # clone and crop cache
            kv_cache2 = clone_dynamic_cache(kv_cache, crop=ti)

            # The caller's mask only covers the prompt, while the cropped cache
            # holds `ti` positions: extend it with ones for the generated tokens.
            fork_mask = attn_mask
            if attn_mask is not None:
                n_generated = ti - attn_mask.shape[1]
                generated_mask = torch.ones(
                    (bs, n_generated), dtype=attn_mask.dtype, device=attn_mask.device
                )
                fork_mask = torch.cat([attn_mask, generated_mask], dim=1)

            # NOTE(review): `think` is a per-row bool tensor here, while
            # force_forked_choice defaults it to a plain bool — confirm the
            # callee broadcasts it correctly over the batch.
            think = is_thinking(out.sequences[:, :ti], tokenizer)
            logp_choices = force_forked_choice(
                model,
                tokenizer,
                attention_mask=fork_mask,
                kv_cache=kv_cache2,
                think=think,
                choice_ids=choice_token_ids,
                verbose=ti in verbose_positions,
                forcing_text=forcing_text
            )
        else:
            logp_choices = None

        for j in range(bs):
            data[j].append(
                {
                    "token": out_token_s[j][ti],
                    "token_strs": out_token_s[j][ti],
                    # detach/cpu are no-ops for CPU tensors and keep .numpy()
                    # safe if the fork returns a tensor on the model device.
                    "logp_choices": logp_choices[j].detach().cpu().numpy() if (logp_choices is not None) else None,
                    "ii": ti,
                }
            )

    dfs = [pd.DataFrame(d) for d in data]
    for df in dfs:
        df.attrs.update({
            "max_new_tokens": max_new_tokens,
            "model_name": model.config._name_or_path,
        })
    # TODO check prob mass, e.g. df_traj['probmass'].max()>0.5
    return dfs, full_strings
+9 -6
View File
@@ -11,7 +11,7 @@ import pandas as pd
cmap_RdGyGn = Colormap(['red', 'grey', 'green']).to_mpl()
def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=[]):
def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_RdGyGn, symmetric = False, highlight_tokens=[], s_key="token_strs"):
"""
Display colored tokens from a chain of thought.
@@ -19,14 +19,17 @@ def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_R
- key: the dataframe columns with the single score to visualise
"""
# v = df[key].abs().max()
v = df[key].abs()
v = df[key]
if symmetric:
v = v.abs()
vmax = v.quantile(0.99)
# vmax = v.quantile(0.99)
vmax = max(v)
norm = mpl.colors.CenteredNorm(0, vmax)
else:
vmin = v.quantile(0.01)
vmax = v.quantile(0.99)
vmin = min(v)
vmax = max(v)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
# show colormap
@@ -42,8 +45,8 @@ def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_R
# TODO show \n as <p>
htmls = f'<h3>{title}</h3>'
for n,row in df.iterrows():
token, score = row['token_strs'], row[key]
# token = token.replace('Ġ', ' ').replace('Ċ', '\n')
token, score = row[s_key], row[key]
token = token.replace('Ġ', ' ').replace('Ċ', '\n').replace("", " ")
# html escape
token = html.escape(token)
@@ -53,7 +56,7 @@ def display_rating_trace(df: pd.DataFrame, title="", key='act_prob', cmap=cmap_R
htmls += "<p/>"
if token in highlight_tokens:
token = f'!{token}!!'
h = f'<span title="{score} n={n}" style="color: {hex_color};">{token}</span>'
h = f'<span title="{score} n={n}" style="color: {hex_color};">{token} </span>'
if token.endswith("\n"):
htmls += "<p/>"
htmls += h
+9 -5
View File
@@ -1,13 +1,17 @@
from typing import Optional
from transformers import DynamicCache
import numpy as np
def clone_dynamic_cache(kv_cache, crop: Optional[int]=None):
    """Return an independent copy of `kv_cache`, optionally cropped in sequence length.

    Args:
        kv_cache: The cache to copy. ``None`` or an empty cache yields a fresh
            empty ``DynamicCache``.
        crop: If given, keep only the first `crop` positions along dim 2 of
            every key/value tensor; ``None`` keeps everything.

    Returns:
        A new ``DynamicCache`` whose tensors are clones, so updating it does
        not modify the tensors of `kv_cache`.
    """
    if (kv_cache is None) or len(kv_cache)==0:
        return DynamicCache()
    # Legacy format: one (k, v) pair per layer, e.g.
    # k.shape and v.shape [batch, 8=num_heads, seq=623, 128]
    # Slice before cloning so only the kept positions are copied and the
    # uncropped tensor is not retained by the copy.
    lyrs = tuple(
        (k[:, :, :crop].clone(), v[:, :, :crop].clone())
        for k, v in kv_cache.to_legacy_cache()
    )
    return DynamicCache.from_legacy_cache(lyrs)
def symlog(x):
    """Symmetric log transform: ``sign(x) * log(1 + |x|)``.

    Compresses large magnitudes while keeping the sign; works on scalars and
    on anything NumPy can treat as an array.
    """
    magnitude = np.log1p(np.abs(x))
    return np.sign(x) * magnitude
+6 -2
View File
@@ -72,8 +72,11 @@ def find_last_non_whitespace_token(tokenizer, tokens):
return t
return t
def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world."):
thinking_prefix = get_thinking_prefix(tokenizer)
def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, template="You're {persona}, acting in the world.", include_thinking=False):
if include_thinking:
thinking_prefix = get_thinking_prefix(tokenizer)
else:
thinking_prefix = None
# Create dataset entries
dataset = []
@@ -118,6 +121,7 @@ def make_dataset(tokenizer, personas, suffixes, entities, verbose=False, templat
random.shuffle(dataset)
if verbose:
for i in range(3):
positive_prompt, negative_prompt = dataset[i].positive, dataset[i].negative
logger.info(f"Dataset example {i}:\n\npositive_prompt={positive_prompt}\n\nnegative_prompt={negative_prompt}")
return dataset
File diff suppressed because one or more lines are too long
+1
View File
@@ -25,6 +25,7 @@ dependencies = [
"repeng",
"seaborn>=0.13.2",
"srsly>=2.5.1",
"tabulate>=0.9.0",
"torch>=2.6.0",
"transformers>=4.51.3",
"umap-learn>=0.5.9.post2",