mirror of
https://github.com/wassname/lie_elicitation_prompts.git
synced 2026-06-27 16:10:35 +08:00
454 lines
17 KiB
Python
"""
|
|
Modified from https://github.com/EleutherAI/elk/blob/3bbe26c3858aac1b03e6f80628a5056fae44db9c/elk/extraction/prompt_loading.py#L129
|
|
|
|
Changed to record choices
|
|
"""
|
|
from collections import Counter
|
|
from random import Random
|
|
from typing import Any, Iterator, Literal, List, Dict
|
|
from pathlib import Path
|
|
# import datasets
|
|
from datasets import ClassLabel, Dataset, Value, load_dataset
|
|
import yaml
|
|
import numpy as np
|
|
# from elk.promptsource.templates import env
|
|
from elk.promptsource import DatasetTemplates
|
|
from lie_elicitation_prompts.prompts.templates import LocalDatasetTemplates
|
|
from elk.utils import (
|
|
assert_type,
|
|
infer_label_column,
|
|
select_split,
|
|
)
|
|
import datasets
|
|
from tqdm.auto import tqdm
|
|
|
|
from elk.extraction.balanced_sampler import BalancedSampler, FewShotSampler
|
|
import pandas as pd
|
|
from loguru import logger
|
|
|
|
from lie_elicitation_prompts.helpers.ds import shuffle_dataset_by
|
|
|
|
|
|
# Directory named "templates" that sits next to this module; the YAML prompt
# templates (e.g. system.yaml) are read from here.
TEMPLATES_FOLDER_PATH = Path(__file__).parent.joinpath("templates")
|
|
|
|
def load_default_sys_instructions(path='system.yaml'):
    """Load system-instruction templates from a YAML file in ``TEMPLATES_FOLDER_PATH``.

    Args:
        path: file name, relative to ``TEMPLATES_FOLDER_PATH``.

    Returns:
        The ``yaml_dict["templates"]["instructed_to_lie"]`` entry of the document.
        ``_convert_to_prompts`` indexes it as ``{False/True: {name: system prompt}}``,
        so presumably that is its shape — TODO confirm against system.yaml.
    """
    f = TEMPLATES_FOLDER_PATH / path
    # Use a context manager so the file handle is closed as soon as parsing ends.
    # NOTE(review): FullLoader is only acceptable because this is a bundled
    # template file rather than user-supplied input.
    with f.open('r', encoding='utf-8') as fh:
        yaml_dict = yaml.load(fh, Loader=yaml.FullLoader)
    return yaml_dict["templates"]["instructed_to_lie"]
|
|
|
|
# Loaded once at import time from templates/system.yaml; used as the default
# `sys_instructions` argument by the prompt-building functions below.
default_sys_instructions = load_default_sys_instructions(path="system.yaml")
|
|
|
|
|
|
def sample_n_true_y_false_prompts(prompts, num_truth=3, num_lie=3, seed=42):
    """Sample a subset of truthful and instructed-to-lie prompts.

    Args:
        prompts: list of prompt dicts, each with a boolean ``instructed_to_lie`` key.
        num_truth: number of prompts with ``instructed_to_lie == False`` to keep.
        num_lie: number of prompts with ``instructed_to_lie == True`` to keep.
        seed: ``random_state`` passed to ``DataFrame.sample`` for reproducibility.

    Returns:
        List of prompt dicts: the sampled lie prompts followed by the sampled
        truthful prompts. If a group has fewer prompts than requested (e.g.
        after ``prompt_ok`` filtering), the whole group is kept instead of
        raising.
    """
    if not prompts:
        # An empty DataFrame has no `instructed_to_lie` column to filter on.
        return []
    df = pd.DataFrame(prompts)
    flags = df["instructed_to_lie"]

    def _take(group, n):
        # DataFrame.sample (without replacement) raises if n > len(group).
        return group.sample(min(int(n), len(group)), random_state=seed)

    df = pd.concat([
        _take(df[flags.eq(True)], num_lie),
        _take(df[flags.eq(False)], num_truth),
    ])
    return df.to_dict(orient="records")
|
|
|
|
|
|
def prompt_ok(prompt):
    """Return True if the prompt's two answer choices can be told apart early.

    We want answers that can be distinguished from the model's first token.
    The tokenizer is not available here, so as a proxy we require that the
    first three characters of the two choices differ and contain no spaces.

    ``answer_choices`` may hold plain strings (``["Yes", "No"]``) or, as built
    by ``_convert_to_prompts``, one list of accepted strings per choice
    (``[["Yes"], ["No"]]``). In the latter case, the first string of each list
    is checked.

    Args:
        prompt: dict with ``answer_choices``, ``ds_string`` and ``template_name``.

    Returns:
        True to keep the prompt; False (after logging a warning) to drop it.
    """
    answer_choices = prompt['answer_choices']

    def _text(choice):
        # Unwrap a per-choice list of accepted answers. Slicing the list itself
        # would compare whole lists and test list membership, not characters.
        if isinstance(choice, (list, tuple)):
            return choice[0] if choice else ""
        return choice

    keep = len(answer_choices) >= 2
    if keep:
        a = _text(answer_choices[0])[:3]
        b = _text(answer_choices[1])[:3]
        keep = (a != b) and (' ' not in a) and (' ' not in b)
    if not keep:
        logger.warning(f"removing prompt because it's answers are not unique in first 3 chars or contain space: {prompt['ds_string']} {prompt['template_name']} {prompt['answer_choices']}")
    return keep
|
|
|
|
import itertools
|
|
|
|
from itertools import cycle
|
|
from typing import Iterable, Optional, Iterator, List, Dict, Any
|
|
from random import Random
|
|
|
|
class FewShotDataset2:
    """A dataset of pre-computed few-shot batches that are as balanced as possible.

    The source dataset is read once and split by a binary label into negative
    (label 0) and positive (label 1) buffers. Batches of ``num_shots`` examples
    are then popped from the ends of those buffers, half from each class. When
    ``num_shots`` is odd, the extra slot goes to a randomly chosen class, decided
    per batch. Each batch is shuffled with ``rng``. Indices past the end wrap
    around.
    """

    def __init__(
        self,
        dataset: Iterable,
        num_shots: int,
        rng: Random,
        label_col: Optional[str] = None,
    ):
        """
        Args:
            dataset: iterable of example mappings; read once.
            num_shots: examples per batch; must be >= 1.
            rng: random generator used for odd-size rounding and batch shuffling.
            label_col: key holding the 0/1 (or bool) label of each example.

        Raises:
            ValueError: if ``num_shots < 1`` or a label is not 0/1.
        """
        if num_shots < 1:
            # Zero-sized batches would never consume the buffers, so batching
            # would never terminate.
            raise ValueError(f"num_shots must be >= 1, got {num_shots}")
        self.batches: List[List[Dict[str, Any]]] = []  # pre-computed batches
        self.num_shots = num_shots
        self.rng = rng
        self.label_col = label_col
        self._prepare_batches(dataset)

    def _prepare_batches(self, dataset):
        """Split ``dataset`` by label and fill ``self.batches``."""
        neg_buf, pos_buf = [], []
        # One pass over the data is sufficient to fill both buffers.
        for sample in dataset:
            label = sample[self.label_col]
            if label == 0:
                neg_buf.append(sample)
            elif label == 1:
                pos_buf.append(sample)
            else:
                raise ValueError(f"Expected label to be 0 or 1, got {label}")

        while True:
            # Round per batch. For odd num_shots, this alternates randomly between
            # an extra negative and an extra positive. Flooring both halves would
            # give empty batches (and never stop) for num_shots=1.
            neg_count, pos_count = self._stochastic_round_constrained(
                [self.num_shots / 2, self.num_shots / 2]
            )
            if len(neg_buf) < neg_count or len(pos_buf) < pos_count:
                break
            batch = [neg_buf.pop() for _ in range(neg_count)]
            batch += [pos_buf.pop() for _ in range(pos_count)]
            self.rng.shuffle(batch)
            self.batches.append(batch)

    def _stochastic_round_constrained(self, counts):
        """Round non-negative ``counts`` to ints summing to ``round(sum(counts))``.

        Each count is floored. The leftover units are then handed out one at a
        time, to entries chosen with probability proportional to their
        fractional parts (each entry gains at most one). Integer-valued inputs
        come back unchanged without consuming ``self.rng``.
        """
        rounded = [int(c) for c in counts]
        fractions = [c - r for c, r in zip(counts, rounded)]
        leftover = round(sum(counts)) - sum(rounded)
        for _ in range(leftover):
            idx = self.rng.choices(range(len(counts)), weights=fractions)[0]
            rounded[idx] += 1
            fractions[idx] = 0.0
        return tuple(rounded)

    def __getitem__(self, idx) -> List[Dict[str, Any]]:
        """Return batch ``idx``; indices >= ``len(self)`` wrap around.

        Raises:
            IndexError: if no batch could be built (e.g. a label class is missing).
        """
        if not self.batches:
            raise IndexError("FewShotDataset2 has no batches (not enough examples of each label)")
        if idx >= len(self.batches):
            idx = idx % len(self.batches)
        return self.batches[idx]

    def __len__(self) -> int:
        return len(self.batches)
|
|
|
|
def load_prompts(
    ds_string: str,
    *,
    sys_instructions: Dict[bool, Dict[str, str]] = default_sys_instructions,
    binarize: bool = True,
    num_shots: int = 0,
    seed: int = 42,
    split_type: Literal["train", "val"] = "train",
    template_path: str | None = None,
    rank: int = 0,
    world_size: int = 8,
    prompt_sampler=sample_n_true_y_false_prompts,
    N=np.inf,
    M: int = 3,
    num_proc: int = 8,
) -> List[dict]:
    """Load a dataset full of prompts generated from the specified dataset.

    Args:
        ds_string: Name of HF dataset to use, e.g. `"super_glue:boolq"` or `"imdb"`.
            If it is an existing path, it is used as `template_path`, and the
            dataset name is derived from the file stem (with "-" replaced by "/").
        sys_instructions: Maps `instructed_to_lie` (False/True) to
            `{name: system prompt}`.
        binarize: Whether to binarize the dataset labels for multi-class datasets.
            Non multiple-choice templates are dropped when binarizing.
        num_shots: The number of examples to use in few-shot prompts. If zero,
            prompts are zero-shot. Few-shot examples come from the train split.
        seed: The seed to use for prompt randomization.
        split_type: Whether to use the train or val split of the dataset.
        template_path: Path to feed into `LocalDatasetTemplates` for loading templates.
        rank: The rank of the current process. Defaults to 0.
        world_size: The number of processes. Defaults to 8. When > 1, only shard
            `rank` of `world_size` shards of the split is used.
        prompt_sampler: when given an unbalanced set of true and false prompts,
            picks `M` of each. It is called with `seed + example_index`.
        N: maximum number of source examples to convert (default: all).
        M: how many true and false forms of each example to keep.
        num_proc: worker processes for `Dataset.map`.

    Returns:
        A list of prompt dictionaries.

    Raises:
        ValueError: if no usable templates are found.
    """
    if Path(ds_string).exists():
        template_path = ds_string
        ds_string = Path(template_path).stem.replace('-', '/')
    ds_name, _, config_name = ds_string.partition(":")

    # load dataset, take split, optionally shard it
    # NOTE(review): trust_remote_code=True executes dataset scripts from the Hub.
    ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None, trust_remote_code=True))
    split_name = select_split(ds_dict, split_type)
    ds = assert_type(Dataset, ds_dict[split_name])
    if world_size > 1:
        ds = ds.shard(world_size, rank)

    # load dataset templates
    if template_path is None:
        prompter = DatasetTemplates(ds_name, config_name)
    else:
        prompter = LocalDatasetTemplates(template_path)

    # If the prompt template says to binarize, we should
    binarize = binarize or prompter.binarize
    if binarize:
        n = prompter.drop_non_mc_templates()
        if n > 0:
            logger.debug(f"dropped {n} templates from {ds_string} because they are not multiple choice")

    num_templates = len(prompter.templates)
    if num_templates == 0:
        raise ValueError(f"No usable prompt templates found for {ds_string}")
    if rank == 0:
        logger.info(f"Extracting {num_templates} variants of each prompt")

    # load labels
    label_column = prompter.label_column or infer_label_column(ds.features)

    # if we are providing examples, we need to sample them randomly
    rng = Random(seed)
    if num_shots > 0:
        train_name = select_split(ds_dict, "train")
        fewshot_ds = FewShotDataset2(
            ds_dict[train_name].shuffle(seed=seed),
            num_shots=num_shots,
            rng=rng,
            label_col=label_column,
        )
    else:
        fewshot_ds = None

    # N defaults to np.inf ("all"); int() so range() also accepts float N.
    n_examples = int(min(N, len(ds)))

    def _example_to_prompts(example, i):
        prompts = _convert_to_prompts(
            example,
            binarize=binarize,
            label_column=label_column,
            prompter=prompter,
            rng=rng,
            sys_instructions=sys_instructions,
            fewshot_ds=fewshot_ds,
            i=i,
        )
        prompts = [{'ds_string': ds_string, 'example_i': i, **p} for p in prompts]
        prompts = [p for p in prompts if prompt_ok(p)]
        # Per-example seed derived from `seed` (identical to 42 + i for the default seed).
        sampled = prompt_sampler(prompts, seed=seed + i, num_truth=M, num_lie=M)
        return {'prompts': sampled}

    ds1 = ds.select(range(n_examples)).map(
        _example_to_prompts,
        with_indices=True,
        desc='convert_to_prompts',
        num_proc=num_proc,
    )
    # Flatten the per-example prompt lists. Iterating the column (rather than
    # calling `.tolist()`) works whether `ds1["prompts"]` is a list or a lazy column.
    return list(itertools.chain.from_iterable(ds1['prompts']))
|
|
|
|
|
|
def cast_example_label_to_bool(e, label_column='label'):
    """Return a copy of example ``e`` with its binary label converted to ``bool``.

    The input mapping is not modified. Modifying it in place would also rewrite
    cached few-shot examples shared across calls.

    Args:
        e: example mapping containing ``label_column``.
        label_column: key of the label; its value must lie in ``[0, 1]``.

    Returns:
        A new ``dict`` equal to ``e``, except that ``label_column`` holds a ``bool``.

    Raises:
        ValueError: if the label is outside ``[0, 1]`` (multi-class labels are
            not supported).
    """
    label = e[label_column]
    if not 0 <= label <= 1:
        raise ValueError(f"Expected a binary label in [0, 1] for {label_column!r}, got {label!r}")
    return {**e, label_column: bool(label)}
|
|
|
|
|
|
def _convert_to_prompts(
    example: dict[str, Any],
    prompter: DatasetTemplates,
    binarize: bool,
    label_column: str,
    # label_choices: list[bool | int | str],
    rng: Random,
    sys_instructions: Dict[bool, Dict[str, str]] = default_sys_instructions,
    fewshot_ds: FewShotDataset2 | None = None,
    i: int = 0,
) -> list:
    """Build every chat-prompt variant for one dataset example.

    For each template with two fixed answer choices, one prompt is made per
    system instruction in ``sys_instructions[False]`` (truthful) and in
    ``sys_instructions[True]`` (instructed to lie). When instructed to lie, the
    label is flipped before the template is applied, so ``answer`` is the lie
    the model is asked to give; ``label_true`` always holds the real label.

    Args:
        example: dataset row containing a binary ``label_column``.
        prompter: templates to apply.
        binarize: not used in this function; kept for API compatibility.
        label_column: name of the label field.
        rng: not used in this function; kept for API compatibility.
        sys_instructions: maps ``instructed_to_lie`` -> ``{name: system prompt}``.
        fewshot_ds: optional few-shot batches. Batch ``i + j`` is used for
            template ``j``, the same batch for the truthful and lying variants.
        i: index of ``example``, used to pick few-shot batches.

    Returns:
        List of dicts with keys ``answer``, ``messages``, ``answer_choices``,
        ``template_name``, ``label_true``, ``instructed_to_lie``, ``sys_instr_name``.

    Raises:
        ValueError: if two variants render to the same (system + question, answer).
    """
    # FIXME: make mc compat
    example = cast_example_label_to_bool(example, label_column)
    prompts = []
    templates = list(prompter.templates.values())

    # For sanity checking that prompts are unique
    prompt_counter = Counter()

    ds_name = prompter.dataset_name
    if prompter.subset_name is not None:
        ds_name += ':' + prompter.subset_name

    for j, template in enumerate(templates):
        answer_choices = template.get_fixed_answer_choices_list()
        # Check for None before calling len() on it.
        if answer_choices is None:
            logger.info(f"skipping ds_name={ds_name} template={template.name} because it has no fixed answer choices")
            continue
        assert len(answer_choices) <= 2, 'should be binary'
        if len(answer_choices) < 2:
            logger.info(f"skipping ds_name={ds_name} template={template.name} because it has fewer than 2 answer choices: {answer_choices}")
            continue

        # skip prompts where the responses are similar in the first token
        if answer_choices[0][:3] == answer_choices[1][:3]:
            logger.info(f"skipping prompt because it's answers are not unique (for the first token): {template.name} {answer_choices}")
            continue
        # Each choice becomes a list of accepted answer strings.
        answer_choices = [[c] for c in answer_choices]

        for instructed_to_lie in [False, True]:
            for sys_instr_name, sys_instr in sys_instructions[instructed_to_lie].items():
                instructed_example = example.copy()
                if instructed_to_lie:
                    # FIXME: make multichoice compat
                    instructed_example[label_column] = not bool(instructed_example[label_column])

                q, a = template.apply(instructed_example)
                messages = [dict(role='user', content=q.strip())]
                prompt_counter[(sys_instr + q, a)] += 1

                if fewshot_ds is not None:
                    # Same few-shot batch for the truthful and lying variants.
                    # Copy each example before casting so cached batches are never mutated.
                    fewshot_examples = [
                        cast_example_label_to_bool(dict(e), label_column)
                        for e in fewshot_ds[i + j]
                    ]
                    if instructed_to_lie:
                        # FIXME: make multichoice compat
                        fewshot_examples = [{**e, label_column: not bool(e[label_column])} for e in fewshot_examples]
                    for e in fewshot_examples:
                        # check negation kept the labels boolean
                        assert isinstance(e[label_column], bool), 'labels should be bool'

                    fewshot_texts = []
                    # Distinct loop names so the target example's `q, a` are not
                    # overwritten by the last few-shot example.
                    for shot_q, shot_a in map(template.apply, fewshot_examples):
                        shot_answer = shot_a.strip()
                        fewshot_texts.append(dict(role='user', content=shot_q.strip()))
                        fewshot_texts.append(dict(role='assistant', content=shot_answer))
                        # Extra trailing text in an answer is OK, but extra preceding text is not.
                        assert any(
                            shot_answer.startswith(choice)
                            for accepted in answer_choices
                            for choice in accepted
                        ), f"fewshot response `{shot_answer}` has extra preceeding text compared to allowed choices: {answer_choices}. template is: {template.name}"
                    messages = fewshot_texts + messages

                messages = [dict(role='system', content=sys_instr)] + messages

                prompts.append(dict(
                    # Strip whitespace from the answer to make it easier to
                    # compare with the model's output
                    answer=a.strip(),
                    messages=messages,
                    answer_choices=answer_choices,
                    template_name=template.name,
                    label_true=example[label_column],
                    instructed_to_lie=instructed_to_lie,
                    sys_instr_name=sys_instr_name,
                ))

    # Sanity check: variants should be unique (skip when every template was skipped).
    if prompt_counter:
        ((maybe_dup, dup_count),) = prompt_counter.most_common(1)
        if dup_count > 1:
            raise ValueError(f'Prompt duplicated {dup_count} times! "{maybe_dup}"')

    return prompts
|
|
|
|
|
|
|
|
def load_preproc_datasets(dataset_names: List[str], N: int, split_type: str = "train", seed=42, num_shots=1, M=3):
    """Load several datasets as prompts and concatenate them.

    Args:
        dataset_names: dataset strings (or local template paths) for
            ``load_preproc_dataset``.
        N: approximate total number of source examples; each dataset gets
            ``N // len(dataset_names) + 1``.
        split_type: split to load from every dataset.
        seed: forwarded to ``load_preproc_dataset``.
        num_shots: forwarded to ``load_preproc_dataset``.
        M: forwarded to ``load_preproc_dataset``.

    Returns:
        The concatenated ``datasets.Dataset``, in ``torch`` format.

    Raises:
        ValueError: if ``dataset_names`` is empty.
    """
    if not dataset_names:
        raise ValueError("dataset_names must contain at least one dataset")
    n_per_dataset = N // len(dataset_names) + 1
    loaded = []
    for ds_name in tqdm(dataset_names):
        # ds_name may also be a path to a local template file
        ds_tokens1 = load_preproc_dataset(
            ds_name,
            N=n_per_dataset,
            # Forward the requested split; otherwise every dataset loads "train".
            split_type=split_type,
            seed=seed,
            num_shots=num_shots,
            M=M,
        ).with_format("torch")
        loaded.append(ds_tokens1)
    return datasets.concatenate_datasets(loaded)
|
|
|
|
|
|
def load_preproc_dataset(
    ds_name: str,
    N: int,
    split_type: str = "train",
    seed=42,
    num_shots=1,
    sys_instructions=default_sys_instructions,
    M=3,
    num_proc=1,
) -> Dataset:
    """Build a shuffled ``datasets.Dataset`` of prompts for a single dataset.

    ``load_prompts`` is run via ``Dataset.from_generator``. The result is then
    passed through ``shuffle_dataset_by``, with ``target='label_true'`` and
    ``random_state=seed``.

    Args:
        ds_name: dataset string or local template path, passed as ``ds_string``.
        N: maximum number of source examples to convert.
        split_type: dataset split to read.
        seed: seed for prompt generation and for the final shuffle.
        num_shots: few-shot examples per prompt.
        sys_instructions: ``{instructed_to_lie: {name: system prompt}}``.
        M: truthful/lying variants kept per example.
        num_proc: forwarded to ``Dataset.from_generator``.
    """
    generator_kwargs = {
        "ds_string": ds_name,
        "num_shots": num_shots,
        "split_type": split_type,
        "sys_instructions": sys_instructions,
        "seed": seed,
        "N": N,
        "M": M,
    }
    prompts_ds = Dataset.from_generator(
        load_prompts,
        gen_kwargs=generator_kwargs,
        keep_in_memory=False,
        num_proc=num_proc,
    )
    return shuffle_dataset_by(
        prompts_ds,
        target='label_true',
        random_state=seed,
        stratify_columns=[],
    )
|