mirror of
https://github.com/wassname/lie_elicitation_prompts.git
synced 2026-06-27 16:10:35 +08:00
multiproc works
This commit is contained in:
@@ -3,6 +3,7 @@ Modified from https://github.com/EleutherAI/elk/blob/3bbe26c3858aac1b03e6f80628a
|
||||
|
||||
Changed to record choices
|
||||
"""
|
||||
import multiprocessing
|
||||
from collections import Counter
|
||||
from random import Random
|
||||
from typing import Any, Iterator, Literal, List, Dict
|
||||
@@ -137,7 +138,7 @@ def load_prompts(
|
||||
split_type: Literal["train", "val"] = "train",
|
||||
template_path: str | None = None,
|
||||
rank: int = 0,
|
||||
world_size: int = 8,
|
||||
world_size: int = 1,
|
||||
prompt_sampler = sample_n_true_y_false_prompts,
|
||||
N=np.inf,
|
||||
M:int=3
|
||||
@@ -173,6 +174,7 @@ def load_prompts(
|
||||
ds = assert_type(Dataset, ds_dict[split_name])
|
||||
if world_size > 1:
|
||||
ds = ds.shard(world_size, rank)
|
||||
N = min(N, len(ds))
|
||||
|
||||
# load dataset templates
|
||||
if template_path is None:
|
||||
@@ -220,7 +222,7 @@ def load_prompts(
|
||||
# )
|
||||
# fewshot_iter = iter(fewshot)
|
||||
fewshot_ds = FewShotDataset2(
|
||||
ds_dict[train_name].shuffle(seed=seed), # TODO: not iterator
|
||||
ds_dict[train_name].shuffle(seed=seed).select(range(1000)), # TODO: not iterator
|
||||
num_shots=num_shots,
|
||||
rng=rng,
|
||||
label_col=label_column,
|
||||
@@ -238,16 +240,15 @@ def load_prompts(
|
||||
# else:
|
||||
# if rank == 0:
|
||||
# logger.info("No label column found, not balancing")
|
||||
N = min(N, len(ds))
|
||||
|
||||
# ds1 = ds.select(range(N)).to_iterable_dataset()
|
||||
|
||||
|
||||
def foo(example, i):
|
||||
def foo(example, i, binarize, label_column, prompter, rng, sys_instructions, fewshot_ds,):
|
||||
prompts = _convert_to_prompts(
|
||||
example,
|
||||
binarize=binarize,
|
||||
label_column=label_column,
|
||||
# label_choices=label_choices, # type: ignore[arg-type]
|
||||
prompter=prompter,
|
||||
rng=rng,
|
||||
sys_instructions=sys_instructions,
|
||||
@@ -261,11 +262,20 @@ def load_prompts(
|
||||
prompts2 = prompt_sampler(prompts1, seed=42+i, num_truth=M, num_lie=M)
|
||||
return {'prompts': prompts2}
|
||||
|
||||
ds1 = ds.select(range(N)).map(foo, with_indices=True, desc='convert_to_prompts',
|
||||
num_proc=8,
|
||||
ds1 = ds.select(range(N)).map(foo, with_indices=True,
|
||||
|
||||
desc='convert_to_prompts',
|
||||
fn_kwargs=dict( binarize=binarize,
|
||||
label_column=label_column,
|
||||
prompter=prompter,
|
||||
rng=rng,
|
||||
sys_instructions=sys_instructions,
|
||||
fewshot_ds=fewshot_ds,),
|
||||
# num_proc=1,#
|
||||
num_proc=multiprocessing.cpu_count()//2,
|
||||
|
||||
)
|
||||
return list(itertools.chain(*ds1['prompts'].tolist()))
|
||||
return list(itertools.chain(*ds1['prompts']))
|
||||
|
||||
|
||||
# j = 0
|
||||
@@ -416,7 +426,7 @@ def _convert_to_prompts(
|
||||
|
||||
|
||||
|
||||
def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3):
|
||||
def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3, num_proc=1,):
|
||||
datasets2 = []
|
||||
n = N//len(dataset_names)+1
|
||||
for ds_name in tqdm(dataset_names):
|
||||
@@ -427,6 +437,7 @@ def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train
|
||||
seed=seed,
|
||||
num_shots=num_shots,
|
||||
M=M,
|
||||
num_proc=num_proc,
|
||||
).with_format("torch")
|
||||
datasets2.append(ds_tokens1)
|
||||
ds_tokens = datasets.concatenate_datasets(datasets2)
|
||||
|
||||
+25
-5
@@ -49,7 +49,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ExtractConfig(datasets=('amazon_polarity',), datasets_ood=('imdb', 'super_glue:boolq'), model='cognitivecomputations/dolphin-2.9.3-llama-3-8b', num_shots=2, max_tokens=444, max_examples=1000000, seed=42, repeats=3)"
|
||||
"ExtractConfig(datasets=('amazon_polarity',), datasets_ood=('imdb', 'super_glue:boolq'), model='cognitivecomputations/dolphin-2.9.3-llama-3-8b', num_shots=2, max_tokens=444, max_examples=100006, seed=42, repeats=3)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
@@ -80,7 +80,11 @@
|
||||
" # \"glue:sst2\",\n",
|
||||
" # \"super_glue:axg\",\n",
|
||||
" \n",
|
||||
"), max_examples=1000000, max_tokens=444)\n",
|
||||
"), \n",
|
||||
"# max_examples=100000, \n",
|
||||
"max_examples=100006, \n",
|
||||
"\n",
|
||||
"max_tokens=444)\n",
|
||||
"cfg\n",
|
||||
"# lie_elicitation_prompts/prompts/templates/liar"
|
||||
]
|
||||
@@ -153,7 +157,7 @@
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "03044a83e624464a94b8081127412d3e",
|
||||
"model_id": "6e45dbc76f9446be98011de8a7cfd77c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@@ -167,7 +171,7 @@
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "803c89bd3ebc477c9cfc3f73f9ba4105",
|
||||
"model_id": "a15de459703e45f38efffc1036063d90",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@@ -182,8 +186,23 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32m2024-07-01 19:52:58.549\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n"
|
||||
"\u001b[32m2024-07-01 21:32:57.756\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n",
|
||||
"Parameter 'function'=<function load_prompts.<locals>.foo at 0x72bb8db6fb50> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6fc652a01ccd4b5d997698f0aaca787a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"convert_to_prompts (num_proc=12): 0%| | 0/100007 [00:00<?, ? examples/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -194,6 +213,7 @@
|
||||
" seed=cfg.seed,\n",
|
||||
" num_shots=cfg.num_shots,\n",
|
||||
" M=cfg.repeats,\n",
|
||||
" # num_proc=4,\n",
|
||||
")\n",
|
||||
"ds_prompts"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user