This commit is contained in:
wassname
2024-07-01 21:40:20 +08:00
parent c6ac96d102
commit 1ad6640231
2 changed files with 182 additions and 172 deletions
+140 -128
View File
@@ -8,10 +8,12 @@ from collections import Counter
from random import Random
from typing import Any, Iterator, Literal, List, Dict
from pathlib import Path
# import datasets
from datasets import ClassLabel, Dataset, Value, load_dataset
import yaml
import numpy as np
# from elk.promptsource.templates import env
from elk.promptsource import DatasetTemplates
from lie_elicitation_prompts.prompts.templates import LocalDatasetTemplates
@@ -28,17 +30,19 @@ import pandas as pd
from loguru import logger
from lie_elicitation_prompts.helpers.ds import shuffle_dataset_by
from elk.utils.math_util import stochastic_round_constrained
# Local path to the folder containing the templates
TEMPLATES_FOLDER_PATH = Path(__file__).parent / "templates"
def load_default_sys_instructions(path="system.yaml"):
    """Load the default system-instruction templates from the templates folder.

    Args:
        path: File name (relative to ``TEMPLATES_FOLDER_PATH``) of the YAML
            file to read.

    Returns:
        The value stored under ``templates -> instructed_to_lie`` in the YAML
        document.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the expected ``templates`` / ``instructed_to_lie`` keys
            are missing.
    """
    f = TEMPLATES_FOLDER_PATH / path
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with f.open("r") as fh:
        # NOTE(review): FullLoader is kept for compatibility; the file is a
        # template shipped with the package, not untrusted input.
        yaml_dict = yaml.load(fh, Loader=yaml.FullLoader)
    return yaml_dict["templates"]["instructed_to_lie"]
default_sys_instructions = load_default_sys_instructions()
@@ -49,9 +53,16 @@ def sample_n_true_y_false_prompts(prompts, num_truth=3, num_lie=3, seed=42):
# restrict to template where the choices are a single token
# m = df.answer_choices.map(answer_len)<=2
# df = df[m]
df = pd.concat([
df.query("instructed_to_lie==True").sample(int(num_truth), random_state=seed),
df.query("instructed_to_lie==False").sample(int(num_lie), random_state=seed)])
df = pd.concat(
[
df.query("instructed_to_lie==True").sample(
int(num_truth), random_state=seed
),
df.query("instructed_to_lie==False").sample(
int(num_lie), random_state=seed
),
]
)
return df.to_dict(orient="records")
@@ -59,20 +70,24 @@ def prompt_ok(prompt):
"""we want answers where we can distinguish them from the first token
we don't have access to the tokenizer here, so we just make sure the first 3 characters are different and that they contain no spaces
"""
answer_choices = prompt['answer_choices']
answer_choices = prompt["answer_choices"]
a = answer_choices[0][:3]
b = answer_choices[1][:3]
keep = (a != b) and (' ' not in a) and (' ' not in b)
keep = (a != b) and (" " not in a) and (" " not in b)
if not keep:
logger.warning(f"removing prompt because it's answers are not unique in first 3 chars or contain space: {prompt['ds_string']} {prompt['template_name']} {prompt['answer_choices']}")
logger.warning(
f"removing prompt because it's answers are not unique in first 3 chars or contain space: {prompt['ds_string']} {prompt['template_name']} {prompt['answer_choices']}"
)
return keep
import itertools
from itertools import cycle
from typing import Iterable, Optional, Iterator, List, Dict, Any
from random import Random
class FewShotDataset2:
"""A dataset that pre-computes few-shot examples that are as balanced as possible."""
@@ -102,8 +117,8 @@ class FewShotDataset2:
else:
raise ValueError(f"Expected label to be 0 or 1, got {label}")
neg_count, pos_count = self._stochastic_round_constrained(
[self.num_shots / 2, self.num_shots / 2]
neg_count, pos_count = stochastic_round_constrained(
[self.num_shots / 2, self.num_shots / 2], self.rng
)
while len(neg_buf) >= neg_count and len(pos_buf) >= pos_count:
batch = []
@@ -115,11 +130,6 @@ class FewShotDataset2:
self.rng.shuffle(batch)
self.batches.append(batch)
def _stochastic_round_constrained(self, counts):
# Placeholder for the stochastic_round_constrained function
# This should be replaced with the actual implementation
return int(counts[0]), int(counts[1])
def __getitem__(self, idx) -> List[Dict[str, Any]]:
if idx >= len(self.batches):
idx = idx % len(self.batches)
@@ -128,6 +138,37 @@ class FewShotDataset2:
def __len__(self) -> int:
return len(self.batches)
def _to_prompts(
    example,
    i,
    binarize,
    label_column,
    prompter,
    rng,
    sys_instructions,
    fewshot_ds,
    ds_string,
    prompt_sampler,
    M,
):
    """Turn one dataset row into a sampled list of prompt variants.

    Intended as the per-row function handed to ``Dataset.map`` with
    ``with_indices=True``; everything except ``example`` and ``i`` is expected
    to arrive through ``fn_kwargs``.

    Args:
        example: The dataset row to convert.
        i: Index of the row; recorded as ``example_i`` and used to derive the
            sampling seed.
        binarize, label_column, prompter, rng, sys_instructions, fewshot_ds:
            Forwarded unchanged to ``_convert_to_prompts``.
        ds_string: Dataset identifier recorded on every prompt.
        prompt_sampler: Callable invoked as
            ``prompt_sampler(prompts, seed=..., num_truth=M, num_lie=M)``.
        M: Number of prompts requested from the sampler for each of
            ``num_truth`` and ``num_lie``.

    Returns:
        A dict ``{"prompts": [...]}`` holding the sampled prompts.
    """
    generated = _convert_to_prompts(
        example,
        binarize=binarize,
        label_column=label_column,
        prompter=prompter,
        rng=rng,
        sys_instructions=sys_instructions,
        fewshot_ds=fewshot_ds,
        i=i,
    )

    # Tag each prompt with its origin, then keep only those whose answer
    # choices pass the `prompt_ok` check.
    usable = []
    for variant in generated:
        tagged = {"ds_string": ds_string, "example_i": i, **variant}
        if prompt_ok(tagged):
            usable.append(tagged)

    # Seed depends on the row index so each row gets its own reproducible draw.
    sampled = prompt_sampler(usable, seed=42 + i, num_truth=M, num_lie=M)
    return {"prompts": sampled}
def load_prompts(
ds_string: str,
*,
@@ -141,7 +182,7 @@ def load_prompts(
world_size: int = 1,
prompt_sampler=sample_n_true_y_false_prompts,
N=np.inf,
M:int=3
M: int = 3,
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified dataset.
@@ -163,11 +204,13 @@ def load_prompts(
"""
if Path(ds_string).exists():
template_path = ds_string
ds_string = Path(template_path).stem.replace('-', '/')
ds_string = Path(template_path).stem.replace("-", "/")
ds_name, _, config_name = ds_string.partition(":")
# load dataset
ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None, trust_remote_code=True))
ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, trust_remote_code=True)
)
# take split
split_name = select_split(ds_dict, split_type)
@@ -187,7 +230,9 @@ def load_prompts(
if binarize:
n = prompter.drop_non_mc_templates()
if n > 0:
logger.debug(f"dropped {n} templates from {ds_string} because they are not multiple choice")
logger.debug(
f"dropped {n} templates from {ds_string} because they are not multiple choice"
)
num_templates = len(prompter.templates)
assert num_templates > 0
@@ -197,32 +242,15 @@ def load_prompts(
# load labels
label_column = prompter.label_column or infer_label_column(ds.features)
# label_feature = ds.features[label_column]
# if isinstance(label_feature, ClassLabel):
# label_choices = [label_feature.str2int(label) for label in label_feature.names]
# elif isinstance(label_feature, Value) and label_feature.dtype == "bool":
# label_choices = [False, True]
# else:
# # Which classes are actually present in this split of the dataset?
# # This is shockingly fast since it uses an optimized Apache Arrow primitive.
# label_choices = sorted(ds.unique(label_column))
# if rank == 0:
# logger.info(f"Using the following pseudo-labels: {label_choices}")
# if we are providing few-shot examples, we need to sample them randomly
rng = Random(seed)
if num_shots > 0:
train_name = select_split(ds_dict, "train")
# fewshot = FewShotSampler(
# ds_dict[train_name].shuffle(seed=seed), # TODO: not iterator
# num_shots=num_shots,
# rng=rng,
# label_col=label_column,
# )
# fewshot_iter = iter(fewshot)
fewshot_ds = FewShotDataset2(
ds_dict[train_name].shuffle(seed=seed).select(range(1000)), # TODO: not iterator
ds_dict[train_name]
.shuffle(seed=seed)
.select(range(1000)), # just use a random 1000 repeatedly
num_shots=num_shots,
rng=rng,
label_col=label_column,
@@ -230,80 +258,29 @@ def load_prompts(
else:
fewshot_ds = None
# here we sample in a balanced way in our main dataset
# if label_column in ds.features:
# ds = BalancedSampler(
# ds.to_iterable_dataset(),
# set(label_choices),
# label_col=label_column,
# )
# else:
# if rank == 0:
# logger.info("No label column found, not balancing")
# ds1 = ds.select(range(N)).to_iterable_dataset()
def foo(example, i, binarize, label_column, prompter, rng, sys_instructions, fewshot_ds,):
prompts = _convert_to_prompts(
example,
ds1 = ds.select(range(N)).map(
_to_prompts,
with_indices=True,
desc="convert_to_prompts",
fn_kwargs=dict(
binarize=binarize,
label_column=label_column,
prompter=prompter,
rng=rng,
sys_instructions=sys_instructions,
fewshot_ds=fewshot_ds,
i=i,
)
prompts = [{'ds_string': ds_string, 'example_i':i, **p} for p in prompts]
prompts1 = list(filter(prompt_ok, prompts))
prompts2 = prompt_sampler(prompts1, seed=42+i, num_truth=M, num_lie=M)
return {'prompts': prompts2}
ds1 = ds.select(range(N)).map(foo, with_indices=True,
desc='convert_to_prompts',
fn_kwargs=dict( binarize=binarize,
label_column=label_column,
prompter=prompter,
rng=rng,
sys_instructions=sys_instructions,
fewshot_ds=fewshot_ds,),
ds_string=ds_string,
prompt_sampler=prompt_sampler,
M=M,
),
# num_proc=1,#
num_proc=multiprocessing.cpu_count() // 2,
)
return list(itertools.chain(*ds1['prompts']))
return list(itertools.chain(*ds1["prompts"]))
# j = 0
# for i, example in enumerate(tqdm(ds1, desc='ds', total=min(N, len(ds)))):
# if j>N:
# break
# prompts = _convert_to_prompts(
# example,
# binarize=binarize,
# label_column=label_column,
# # label_choices=label_choices, # type: ignore[arg-type]
# prompter=prompter,
# rng=rng,
# sys_instructions=sys_instructions,
# fewshot_iter=fewshot_iter,
# )
# prompts = [{'ds_string': ds_string, 'example_i':i, **p} for p in prompts]
# prompts1 = list(filter(prompt_ok, prompts))
# prompts2 = prompt_sampler(prompts1, seed=42+j, num_truth=M, num_lie=M)
# for p in prompts2:
# j += 1
# yield p
def cast_example_label_to_bool(e, label_column='label'):
def cast_example_label_to_bool(e, label_column="label"):
assert e[label_column] >= 0
assert e[label_column] <= 1
e[label_column] = bool(e[label_column])
@@ -323,7 +300,6 @@ def _convert_to_prompts(
) -> list:
"""Prompt-generating function to pass to `IterableDataset.map`."""
# FIXME: make mc compat
example = cast_example_label_to_bool(example, label_column)
prompts = []
@@ -335,7 +311,7 @@ def _convert_to_prompts(
ds_name = prompter.dataset_name
if prompter.subset_name is not None:
ds_name += ':' + prompter.subset_name
ds_name += ":" + prompter.subset_name
# FIXME: not used?
# if binarize:
@@ -351,71 +327,89 @@ def _convert_to_prompts(
# FIXME: the original elk is a bit confused between label_choices, and prompt_answer choices. It
for j, template in enumerate(templates):
answer_choices = template.get_fixed_answer_choices_list()
assert len(answer_choices) <= 2, 'should be binary'
assert len(answer_choices) <= 2, "should be binary"
if answer_choices is None:
logger.info(f"skipping ds_name={ds_name} template={template.name} because it has no fixed answer choices")
logger.info(
f"skipping ds_name={ds_name} template={template.name} because it has no fixed answer choices"
)
continue
# skip prompts where the responses are similar in the first token
if answer_choices[0][:3] == answer_choices[1][:3]:
logger.info(f"skipping prompt because it's answers are not unique (for the first token): {template.name} {answer_choices}")
logger.info(
f"skipping prompt because it's answers are not unique (for the first token): {template.name} {answer_choices}"
)
continue
answer_choices = [[c] for c in answer_choices]
for instructed_to_lie in [False, True]:
for sys_instr_name, sys_instr in sys_instructions[instructed_to_lie].items():
for sys_instr_name, sys_instr in sys_instructions[
instructed_to_lie
].items():
instructed_example = example.copy()
if instructed_to_lie:
# FIXME: make multichoice compat
instructed_example[label_column] = not bool(instructed_example[label_column])
instructed_example[label_column] = not bool(
instructed_example[label_column]
)
q, a = template.apply(instructed_example)
messages = [
dict(role='user', content=q.strip())
]
messages = [dict(role="user", content=q.strip())]
prompt_counter[(sys_instr + q, a)] += 1
if fewshot_ds is not None:
# same example for true and false
fewshot_examples = fewshot_ds[i + j]
# FIXME: make mc compat
fewshot_examples = [cast_example_label_to_bool(e, label_column).copy() for e in fewshot_examples]
fewshot_examples = [
cast_example_label_to_bool(e, label_column).copy()
for e in fewshot_examples
]
if instructed_to_lie:
# FIXME: make multichoice compat
fewshot_examples = [{**e, label_column: not bool(e[label_column])} for e in fewshot_examples]
fewshot_examples = [
{**e, label_column: not bool(e[label_column])}
for e in fewshot_examples
]
for e in fewshot_examples:
# arg, check negation worked
assert e[label_column] >= 0
assert e[label_column] < 2
assert isinstance(e[label_column], bool), 'labels should be bool'
assert isinstance(
e[label_column], bool
), "labels should be bool"
fewshot_texts = []
for q, a in map(template.apply, fewshot_examples):
fewshot_texts.append(dict(role='user', content=q.strip()))
fewshot_texts.append(dict(role='assistant', content=a.strip()))
fewshot_texts.append(dict(role="user", content=q.strip()))
fewshot_texts.append(dict(role="assistant", content=a.strip()))
# some of the answers have extra trailing text, that's OK. But extra preceding text is not, so let's check for that
aa = a.strip()
assert any([any([aa.startswith(a) for a in ac]) for ac in answer_choices]), f"fewshot response `{aa}` has extra preceeding text compared to allowed choices: {answer_choices}. template is: {template.name}"
assert any(
[
any([aa.startswith(a) for a in ac])
for ac in answer_choices
]
), f"fewshot response `{aa}` has extra preceeding text compared to allowed choices: {answer_choices}. template is: {template.name}"
messages = fewshot_texts + messages
messages = [dict(role='system', content=sys_instr)] + messages
messages = [dict(role="system", content=sys_instr)] + messages
prompts.append(dict(
prompts.append(
dict(
# Strip whitespace from the answer to make it easier to
# compare with the model's output
answer=a.strip(),
messages=messages,
answer_choices=answer_choices,
template_name=template.name,
label_true=example[label_column],
instructed_to_lie=instructed_to_lie,
sys_instr_name=sys_instr_name,
# example_idx=example['idx'],
))
)
)
# Sanity check: variants should be unique
((maybe_dup, dup_count),) = prompt_counter.most_common(1)
@@ -425,8 +419,15 @@ def _convert_to_prompts(
return prompts
def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3, num_proc=1,):
def load_preproc_datasets(
dataset_names: List[str],
N: int,
split_type: str = "train",
seed=42,
num_shots=1,
M=3,
num_proc=1,
):
datasets2 = []
n = N // len(dataset_names) + 1
for ds_name in tqdm(dataset_names):
@@ -445,7 +446,16 @@ def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train
return ds_tokens
def load_preproc_dataset(ds_name: str, N:int, split_type:str="train", seed=42, num_shots=1, sys_instructions=default_sys_instructions, M=3, num_proc=1,) -> Dataset:
def load_preproc_dataset(
ds_name: str,
N: int,
split_type: str = "train",
seed=42,
num_shots=1,
sys_instructions=default_sys_instructions,
M=3,
num_proc=1,
) -> Dataset:
ds_prompts = Dataset.from_generator(
load_prompts,
gen_kwargs=dict(
@@ -460,5 +470,7 @@ def load_preproc_dataset(ds_name: str, N:int, split_type:str="train", seed=42, n
keep_in_memory=False,
num_proc=num_proc,
)
ds_prompts = shuffle_dataset_by(ds_prompts, target='label_true', random_state=seed, stratify_columns=[])
ds_prompts = shuffle_dataset_by(
ds_prompts, target="label_true", random_state=seed, stratify_columns=[]
)
return ds_prompts
+4 -6
View File
@@ -157,7 +157,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e45dbc76f9446be98011de8a7cfd77c",
"model_id": "c0624f7c0bce4443a712104b24852563",
"version_major": 2,
"version_minor": 0
},
@@ -171,7 +171,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a15de459703e45f38efffc1036063d90",
"model_id": "a37c29ab88df472ebc25f24b8cc768d7",
"version_major": 2,
"version_minor": 0
},
@@ -186,14 +186,13 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-07-01 21:32:57.756\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n",
"Parameter 'function'=<function load_prompts.<locals>.foo at 0x72bb8db6fb50> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
"\u001b[32m2024-07-01 21:39:02.979\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m240\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6fc652a01ccd4b5d997698f0aaca787a",
"model_id": "45a7fda96e96431e83f2d629d5b6f06d",
"version_major": 2,
"version_minor": 0
},
@@ -213,7 +212,6 @@
" seed=cfg.seed,\n",
" num_shots=cfg.num_shots,\n",
" M=cfg.repeats,\n",
" # num_proc=4,\n",
")\n",
"ds_prompts"
]