mirror of
https://github.com/wassname/lie_elicitation_prompts.git
synced 2026-06-27 16:10:35 +08:00
works
This commit is contained in:
@@ -9,15 +9,13 @@ ROOT_DIR = Path(__file__).parent.parent
|
||||
class ExtractConfig(Serializable):
|
||||
"""Config for extracting hidden states from a language model."""
|
||||
|
||||
datasets: tuple[str, ...] = ("amazon_polarity", "glue:qqp", "glue:sst2", "super_glue:axb", "super_glue:axg", "super_glue:wsc.fixed")
|
||||
datasets: tuple[str, ...] = ("amazon_polarity", "glue:sst2", "super_glue:axg")
|
||||
"""datasets to use, e.g. `"super_glue:boolq"` or `"imdb"` `"glue:qqp` super_glue:rte super_glue:axg sst2 hans
|
||||
note make sure it's an easy task (https://openreview.net/pdf?id=rJ4km2R5t7, https://super.gluebenchmark.com/leaderboard/) that in elk https://github.dev/EleutherAI/elk)
|
||||
|
||||
"""
|
||||
|
||||
note make sure it's an easy task (https://openreview.net/pdf?id=rJ4km2R5t7, https://super.gluebenchmark.com/leaderboard/) and in elk https://github.dev/EleutherAI/elk)
|
||||
"""
|
||||
|
||||
datasets_ood: tuple[str, ...] = ( 'imdb', "super_glue:boolq")
|
||||
"""Out Of Distribution datasets to use, e.g. `"super_glue:boolq"` or `"imdb"` `"glue:qnli"""
|
||||
"""Out Of Distribution datasets to use, """
|
||||
|
||||
model: str = "failspy/Llama-3-8B-Instruct-abliterated"
|
||||
"""You should use a smart and helpfull>harmless model or you will get hardly any lies."""
|
||||
@@ -28,7 +26,7 @@ class ExtractConfig(Serializable):
|
||||
max_tokens: int | None = 776
|
||||
"""Maximum length of the input sequence passed to the tokenize encoder function"""
|
||||
|
||||
max_examples: tuple[int, int] = 1000
|
||||
max_examples: tuple[int, int] = 4000
|
||||
"""Maximum number of examples before truncation and filtering"""
|
||||
|
||||
seed: int = 42
|
||||
|
||||
@@ -79,9 +79,9 @@ def select_choice_from_logits(logits, choiceids: List[List[int]]):
|
||||
assert logits.ndim==1, f"expected logits to be 1d, got {logits.shape}"
|
||||
assert logits.sum(0).abs()>2, 'pass in logits, not probs. you may have accidentally passed in a softmaxed values'
|
||||
choiceids = [list(set(i)) for i in choiceids] # we don't want to double count
|
||||
probs = torch.softmax(logits, 0) # shape [tokens, inferences)
|
||||
probs = logits.softmax(0) # shape [tokens, inferences)
|
||||
probs_c = torch.tensor([[probs[cc] for cc in c] for c in choiceids]).sum(1) # sum over alternate choices e.g. [['decrease', 'dec'],['inc', 'increase']]
|
||||
assert probs_c.sum()<=1.01, f"expected probs to sum to 1, got {probs_c.sum()}"
|
||||
assert probs_c.sum()<=1.01, f"expected probs to be less than 1, as it's a selection from probs that add to 1"
|
||||
return probs_c
|
||||
|
||||
|
||||
|
||||
@@ -185,7 +185,7 @@ def load_prompts(
|
||||
return keep
|
||||
|
||||
prompts1 = list(filter(prompt_ok, prompts))
|
||||
prompts2 = prompt_sampler(prompts1, seed=42+j, num_truth=M, num_truth=M)
|
||||
prompts2 = prompt_sampler(prompts1, seed=42+j, num_truth=M, num_lie=M)
|
||||
for p in prompts2:
|
||||
j += 1
|
||||
yield p
|
||||
@@ -307,7 +307,7 @@ def load_preproc_datasets(dataset_names: List[str], N:int, split_type:str="train
|
||||
N=n,
|
||||
seed=seed,
|
||||
num_shots=num_shots,
|
||||
M=<n
|
||||
M=M,
|
||||
).with_format("torch")
|
||||
datasets2.append(ds_tokens1)
|
||||
ds_tokens = datasets.concatenate_datasets(datasets2)
|
||||
|
||||
+698
-162
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user