multiproc works

This commit is contained in:
wassname
2024-07-01 21:33:27 +08:00
parent 4a7d1389f1
commit c6ac96d102
2 changed files with 45 additions and 14 deletions
@@ -3,6 +3,7 @@ Modified from https://github.com/EleutherAI/elk/blob/3bbe26c3858aac1b03e6f80628a
Changed to record choices
"""
import multiprocessing
from collections import Counter
from random import Random
from typing import Any, Iterator, Literal, List, Dict
@@ -137,7 +138,7 @@ def load_prompts(
split_type: Literal["train", "val"] = "train",
template_path: str | None = None,
rank: int = 0,
world_size: int = 8,
world_size: int = 1,
prompt_sampler = sample_n_true_y_false_prompts,
N=np.inf,
M:int=3
@@ -173,6 +174,7 @@ def load_prompts(
ds = assert_type(Dataset, ds_dict[split_name])
if world_size > 1:
ds = ds.shard(world_size, rank)
N = min(N, len(ds))
# load dataset templates
if template_path is None:
@@ -220,7 +222,7 @@ def load_prompts(
# )
# fewshot_iter = iter(fewshot)
fewshot_ds = FewShotDataset2(
ds_dict[train_name].shuffle(seed=seed), # TODO: not iterator
ds_dict[train_name].shuffle(seed=seed).select(range(1000)), # TODO: not iterator
num_shots=num_shots,
rng=rng,
label_col=label_column,
@@ -238,16 +240,15 @@ def load_prompts(
# else:
# if rank == 0:
# logger.info("No label column found, not balancing")
N = min(N, len(ds))
# ds1 = ds.select(range(N)).to_iterable_dataset()
def foo(example, i):
def foo(example, i, binarize, label_column, prompter, rng, sys_instructions, fewshot_ds,):
prompts = _convert_to_prompts(
example,
binarize=binarize,
label_column=label_column,
# label_choices=label_choices, # type: ignore[arg-type]
prompter=prompter,
rng=rng,
sys_instructions=sys_instructions,
@@ -261,11 +262,20 @@ def load_prompts(
prompts2 = prompt_sampler(prompts1, seed=42+i, num_truth=M, num_lie=M)
return {'prompts': prompts2}
ds1 = ds.select(range(N)).map(foo, with_indices=True, desc='convert_to_prompts',
num_proc=8,
ds1 = ds.select(range(N)).map(foo, with_indices=True,
desc='convert_to_prompts',
fn_kwargs=dict( binarize=binarize,
label_column=label_column,
prompter=prompter,
rng=rng,
sys_instructions=sys_instructions,
fewshot_ds=fewshot_ds,),
# num_proc=1,#
num_proc=multiprocessing.cpu_count()//2,
)
return list(itertools.chain(*ds1['prompts'].tolist()))
return list(itertools.chain(*ds1['prompts']))
# j = 0
@@ -416,7 +426,7 @@ def _convert_to_prompts(
def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3):
def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3, num_proc=1,):
datasets2 = []
n = N//len(dataset_names)+1
for ds_name in tqdm(dataset_names):
@@ -427,6 +437,7 @@ def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train
seed=seed,
num_shots=num_shots,
M=M,
num_proc=num_proc,
).with_format("torch")
datasets2.append(ds_tokens1)
ds_tokens = datasets.concatenate_datasets(datasets2)
+25 -5
View File
@@ -49,7 +49,7 @@
{
"data": {
"text/plain": [
"ExtractConfig(datasets=('amazon_polarity',), datasets_ood=('imdb', 'super_glue:boolq'), model='cognitivecomputations/dolphin-2.9.3-llama-3-8b', num_shots=2, max_tokens=444, max_examples=1000000, seed=42, repeats=3)"
"ExtractConfig(datasets=('amazon_polarity',), datasets_ood=('imdb', 'super_glue:boolq'), model='cognitivecomputations/dolphin-2.9.3-llama-3-8b', num_shots=2, max_tokens=444, max_examples=100006, seed=42, repeats=3)"
]
},
"execution_count": 3,
@@ -80,7 +80,11 @@
" # \"glue:sst2\",\n",
" # \"super_glue:axg\",\n",
" \n",
"), max_examples=1000000, max_tokens=444)\n",
"), \n",
"# max_examples=100000, \n",
"max_examples=100006, \n",
"\n",
"max_tokens=444)\n",
"cfg\n",
"# lie_elicitation_prompts/prompts/templates/liar"
]
@@ -153,7 +157,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "03044a83e624464a94b8081127412d3e",
"model_id": "6e45dbc76f9446be98011de8a7cfd77c",
"version_major": 2,
"version_minor": 0
},
@@ -167,7 +171,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "803c89bd3ebc477c9cfc3f73f9ba4105",
"model_id": "a15de459703e45f38efffc1036063d90",
"version_major": 2,
"version_minor": 0
},
@@ -182,8 +186,23 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-07-01 19:52:58.549\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n"
"\u001b[32m2024-07-01 21:32:57.756\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n",
"Parameter 'function'=<function load_prompts.<locals>.foo at 0x72bb8db6fb50> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6fc652a01ccd4b5d997698f0aaca787a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"convert_to_prompts (num_proc=12): 0%| | 0/100007 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
@@ -194,6 +213,7 @@
" seed=cfg.seed,\n",
" num_shots=cfg.num_shots,\n",
" M=cfg.repeats,\n",
" # num_proc=4,\n",
")\n",
"ds_prompts"
]