This commit is contained in:
wassname
2024-06-14 14:57:49 +08:00
parent bccc2cad76
commit 102105657c
4 changed files with 707 additions and 173 deletions
+5 -7
View File
@@ -9,15 +9,13 @@ ROOT_DIR = Path(__file__).parent.parent
class ExtractConfig(Serializable):
"""Config for extracting hidden states from a language model."""
datasets: tuple[str, ...] = ("amazon_polarity", "glue:qqp", "glue:sst2", "super_glue:axb", "super_glue:axg", "super_glue:wsc.fixed")
datasets: tuple[str, ...] = ("amazon_polarity", "glue:sst2", "super_glue:axg")
"""datasets to use, e.g. `"super_glue:boolq"` or `"imdb"` `"glue:qqp` super_glue:rte super_glue:axg sst2 hans
note make sure it's an easy task (https://openreview.net/pdf?id=rJ4km2R5t7, https://super.gluebenchmark.com/leaderboard/) that in elk https://github.dev/EleutherAI/elk)
"""
note make sure it's an easy task (https://openreview.net/pdf?id=rJ4km2R5t7, https://super.gluebenchmark.com/leaderboard/) and in elk https://github.dev/EleutherAI/elk)
"""
datasets_ood: tuple[str, ...] = ( 'imdb', "super_glue:boolq")
"""Out Of Distribution datasets to use, e.g. `"super_glue:boolq"` or `"imdb"` `"glue:qnli"""
"""Out Of Distribution datasets to use, """
model: str = "failspy/Llama-3-8B-Instruct-abliterated"
"""You should use a smart and helpfull>harmless model or you will get hardly any lies."""
@@ -28,7 +26,7 @@ class ExtractConfig(Serializable):
max_tokens: int | None = 776
"""Maximum length of the input sequence passed to the tokenize encoder function"""
max_examples: tuple[int, int] = 1000
max_examples: tuple[int, int] = 4000
"""Maximum number of examples before truncation and filtering"""
seed: int = 42
+2 -2
View File
@@ -79,9 +79,9 @@ def select_choice_from_logits(logits, choiceids: List[List[int]]):
assert logits.ndim==1, f"expected logits to be 1d, got {logits.shape}"
assert logits.sum(0).abs()>2, 'pass in logits, not probs. you may have accidentally passed in a softmaxed values'
choiceids = [list(set(i)) for i in choiceids] # we don't want to double count
probs = torch.softmax(logits, 0) # shape [tokens, inferences)
probs = logits.softmax(0) # shape [tokens, inferences)
probs_c = torch.tensor([[probs[cc] for cc in c] for c in choiceids]).sum(1) # sum over alternate choices e.g. [['decrease', 'dec'],['inc', 'increase']]
assert probs_c.sum()<=1.01, f"expected probs to sum to 1, got {probs_c.sum()}"
assert probs_c.sum()<=1.01, f"expected probs to be less than 1, as it's a selection from probs that add to 1"
return probs_c
@@ -185,7 +185,7 @@ def load_prompts(
return keep
prompts1 = list(filter(prompt_ok, prompts))
prompts2 = prompt_sampler(prompts1, seed=42+j, num_truth=M, num_truth=M)
prompts2 = prompt_sampler(prompts1, seed=42+j, num_truth=M, num_lie=M)
for p in prompts2:
j += 1
yield p
@@ -307,7 +307,7 @@ def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train
N=n,
seed=seed,
num_shots=num_shots,
M=<n
M=M,
).with_format("torch")
datasets2.append(ds_tokens1)
ds_tokens = datasets.concatenate_datasets(datasets2)
+698 -162
View File
File diff suppressed because it is too large Load Diff