From c6ac96d1028e74c879735cae890937ebedbe5c3a Mon Sep 17 00:00:00 2001 From: wassname Date: Mon, 1 Jul 2024 21:33:27 +0800 Subject: [PATCH] multiproc works --- .../prompts/prompt_loading.py | 29 ++++++++++++------ nbs/build.ipynb | 30 +++++++++++++++---- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/lie_elicitation_prompts/prompts/prompt_loading.py b/lie_elicitation_prompts/prompts/prompt_loading.py index 0a1b6b2..f8c13f3 100644 --- a/lie_elicitation_prompts/prompts/prompt_loading.py +++ b/lie_elicitation_prompts/prompts/prompt_loading.py @@ -3,6 +3,7 @@ Modified from https://github.com/EleutherAI/elk/blob/3bbe26c3858aac1b03e6f80628a Changed to record choices """ +import multiprocessing from collections import Counter from random import Random from typing import Any, Iterator, Literal, List, Dict @@ -137,7 +138,7 @@ def load_prompts( split_type: Literal["train", "val"] = "train", template_path: str | None = None, rank: int = 0, - world_size: int = 8, + world_size: int = 1, prompt_sampler = sample_n_true_y_false_prompts, N=np.inf, M:int=3 @@ -173,6 +174,7 @@ def load_prompts( ds = assert_type(Dataset, ds_dict[split_name]) if world_size > 1: ds = ds.shard(world_size, rank) + N = min(N, len(ds)) # load dataset templates if template_path is None: @@ -220,7 +222,7 @@ def load_prompts( # ) # fewshot_iter = iter(fewshot) fewshot_ds = FewShotDataset2( - ds_dict[train_name].shuffle(seed=seed), # TODO: not iterator + ds_dict[train_name].shuffle(seed=seed).select(range(1000)), # TODO: not iterator num_shots=num_shots, rng=rng, label_col=label_column, @@ -238,16 +240,15 @@ def load_prompts( # else: # if rank == 0: # logger.info("No label column found, not balancing") - N = min(N, len(ds)) + # ds1 = ds.select(range(N)).to_iterable_dataset() - def foo(example, i): + def foo(example, i, binarize, label_column, prompter, rng, sys_instructions, fewshot_ds,): prompts = _convert_to_prompts( example, binarize=binarize, label_column=label_column, - # label_choices=label_choices, # type: ignore[arg-type] prompter=prompter, rng=rng, sys_instructions=sys_instructions, @@ -261,11 +262,20 @@ def load_prompts( prompts2 = prompt_sampler(prompts1, seed=42+i, num_truth=M, num_lie=M) return {'prompts': prompts2} - ds1 = ds.select(range(N)).map(foo, with_indices=True, desc='convert_to_prompts', - num_proc=8, + ds1 = ds.select(range(N)).map(foo, with_indices=True, + + desc='convert_to_prompts', + fn_kwargs=dict( binarize=binarize, + label_column=label_column, + prompter=prompter, + rng=rng, + sys_instructions=sys_instructions, + fewshot_ds=fewshot_ds,), + # num_proc=1,# + num_proc=multiprocessing.cpu_count()//2, ) - return list(itertools.chain(*ds1['prompts'].tolist())) + return list(itertools.chain(*ds1['prompts'])) # j = 0 @@ -416,7 +426,7 @@ def _convert_to_prompts( -def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3): +def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train", seed=42, num_shots=1, M=3, num_proc=1,): datasets2 = [] n = N//len(dataset_names)+1 for ds_name in tqdm(dataset_names): @@ -427,6 +437,7 @@ def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train seed=seed, num_shots=num_shots, M=M, + num_proc=num_proc, ).with_format("torch") datasets2.append(ds_tokens1) ds_tokens = datasets.concatenate_datasets(datasets2) diff --git a/nbs/build.ipynb b/nbs/build.ipynb index 0c05dcd..0002631 100644 --- a/nbs/build.ipynb +++ b/nbs/build.ipynb @@ -49,7 +49,7 @@ { "data": { "text/plain": [ - "ExtractConfig(datasets=('amazon_polarity',), datasets_ood=('imdb', 'super_glue:boolq'), model='cognitivecomputations/dolphin-2.9.3-llama-3-8b', num_shots=2, max_tokens=444, max_examples=1000000, seed=42, repeats=3)" + "ExtractConfig(datasets=('amazon_polarity',), datasets_ood=('imdb', 'super_glue:boolq'), model='cognitivecomputations/dolphin-2.9.3-llama-3-8b', num_shots=2, max_tokens=444, max_examples=100006, seed=42, repeats=3)" ] }, "execution_count": 3, @@ -80,7 +80,11 @@ " # \"glue:sst2\",\n", " # \"super_glue:axg\",\n", " \n", - "), max_examples=1000000, max_tokens=444)\n", + "), \n", + "# max_examples=100000, \n", + "max_examples=100006, \n", + "\n", + "max_tokens=444)\n", "cfg\n", "# lie_elicitation_prompts/prompts/templates/liar" ] @@ -153,7 +157,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "03044a83e624464a94b8081127412d3e", + "model_id": "6e45dbc76f9446be98011de8a7cfd77c", "version_major": 2, "version_minor": 0 }, @@ -167,7 +171,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "803c89bd3ebc477c9cfc3f73f9ba4105", + "model_id": "a15de459703e45f38efffc1036063d90", "version_major": 2, "version_minor": 0 }, @@ -182,8 +186,23 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2024-07-01 19:52:58.549\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m193\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n" + "\u001b[32m2024-07-01 21:32:57.756\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlie_elicitation_prompts.prompts.prompt_loading\u001b[0m:\u001b[36mload_prompts\u001b[0m:\u001b[36m195\u001b[0m - \u001b[1mExtracting 11 variants of each prompt\u001b[0m\n", + "Parameter 'function'=.foo at 0x72bb8db6fb50> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n" ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6fc652a01ccd4b5d997698f0aaca787a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "convert_to_prompts (num_proc=12): 0%| | 0/100007 [00:00