This commit is contained in:
wassname
2025-01-06 12:12:19 +08:00
parent 07803d2c90
commit 75ced83574
5 changed files with 4524 additions and 1007 deletions
+22
View File
@@ -0,0 +1,22 @@
import frontmatter
from pathlib import Path
from loguru import logger
import pandas as pd
def load_md(f: Path = Path("../samples/"), max_chars=2000):
    """Load every ``*.md`` file in a directory and parse its front matter.

    Args:
        f: Directory to search (non-recursive). A plain string path is
            accepted as well as a ``Path``.
        max_chars: Maximum number of characters of each post body to keep.

    Returns:
        A list with one dict per file, in sorted path order. Each dict holds
        ``content`` (the truncated body), ``f`` (the path of the file) and
        every front-matter key of that file.
    """
    samples = []
    # Use a distinct loop name so the ``f`` argument is not clobbered.
    for path in sorted(Path(f).glob('*.md')):
        logger.debug(path)
        # Explicit UTF-8 keeps reads independent of the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as file:
            post = frontmatter.load(file)
        samples.append(dict(content=post.content[:max_chars], f=path, **post.metadata))
    return samples
def load_md_df(f, max_chars=2000, cutoff='2024-01-01'):
    """Load the markdown samples into a DataFrame and flag pre-cutoff posts.

    Args:
        f: Directory passed through to ``load_md``.
        max_chars: Maximum number of body characters kept per sample.
        cutoff: Posts dated strictly before this moment get
            ``in_training=True``. Accepts anything ``pd.Timestamp`` can
            parse; a timezone-naive value is interpreted as UTC.

    Returns:
        A DataFrame with the columns ``title``, ``f``, ``content``, ``url``,
        ``novelty``, ``date`` (converted to UTC datetimes) and the boolean
        ``in_training``.

    Raises:
        KeyError: If any of the expected front-matter columns is missing.
    """
    samples = load_md(f, max_chars=max_chars)
    columns = ['title', 'f', 'content', 'url', 'novelty', 'date']
    df = pd.DataFrame(samples)[columns]
    df['date'] = pd.to_datetime(df['date'], utc=True)

    # Normalise the cutoff to a tz-aware timestamp so the comparison against
    # the UTC ``date`` column is well defined for strings and datetimes alike.
    cutoff_ts = pd.Timestamp(cutoff)
    if cutoff_ts.tzinfo is None:
        cutoff_ts = cutoff_ts.tz_localize('UTC')
    df['in_training'] = df['date'] < cutoff_ts
    return df
+44
View File
@@ -94,6 +94,7 @@ def perplexity_compute(
(loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch)
)
# remove all the masked ones
assert nll_batch.size(0) == 1
nll_batch = nll_batch[shift_attention_mask_batch == 1][None, :] # FIXME only for batch_size=1
perplexity_batch = torch.exp(nll_batch.mean(1)).cpu().numpy()
@@ -101,3 +102,46 @@ def perplexity_compute(
nlls += nll_batch.cpu().numpy().tolist()
return {"perplexities": np.array(ppls), "nlls": np.array(nlls)}
# modified from https://github.dev/huggingface/evaluate/blob/8dfe05784099fb9af55b8e77793205a3b7c86465/measurements/perplexity/perplexity.py#L154
import evaluate
from evaluate import logging
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
def perplexity_compute_ds(
    ds, model, tokenizer, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
):
    """Compute per-sequence perplexity and per-token NLL over a tokenized dataset.

    Args:
        ds: Dataset exposing ``with_format('pt')`` whose rows provide
            ``input_ids`` and ``attention_mask`` (presumably a Hugging Face
            ``datasets.Dataset`` that is already tokenized -- confirm with caller).
        model: Causal language model; called with ``input_ids`` and
            ``attention_mask`` and expected to return an object with ``.logits``.
        tokenizer: Object whose ``pad`` method is used as the collate function.
        batch_size: Accepted for interface compatibility but not used; see the
            note on the DataLoader below.
        add_start_token: Accepted for interface compatibility; not used here.
        device: Device the model and each batch are moved to.
        max_length: Accepted for interface compatibility; not used here.

    Returns:
        A dict with
        ``perplexities``: 1-D tensor, one value per sequence;
        ``mean_perplexity``: scalar tensor, mean of ``perplexities``;
        ``nlls``: 2-D tensor ``(num_sequences, longest_sequence - 1)`` of
        per-token negative log-likelihoods. Rows of shorter sequences are
        right-padded with ``NaN`` so sequences of different lengths can be
        returned together.
    """
    # Local import: only this function needs it.
    from torch.nn.utils.rnn import pad_sequence

    model = model.to(device)
    ds = ds.with_format('pt')
    # NOTE(review): the loader intentionally stays at one sequence per step, as
    # before. Forwarding ``batch_size`` would put padded positions into the
    # forward pass, which may change results depending on the tokenizer's
    # padding side -- verify before enabling real batching.
    dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0, collate_fn=tokenizer.pad, pin_memory=True)

    loss_fct = CrossEntropyLoss(reduction="none")
    ppls = []
    nlls = []
    for b in dl:
        input_ids = b['input_ids'].to(device)
        attention_mask = b['attention_mask'].to(device)

        with torch.no_grad():
            out_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Position t predicts token t + 1, so drop the last logit and the first label.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_mask = attention_mask[..., 1:].contiguous()

        # Cast to float32 so half-precision models (bfloat16 in particular,
        # which numpy cannot represent) are reduced and exported safely.
        token_nll = loss_fct(shift_logits.transpose(1, 2), shift_labels).float() * shift_mask

        # Masked positions contribute zero to the sum, so sum / count is the
        # mean NLL over real tokens for each row.
        ppls += torch.exp(token_nll.sum(1) / shift_mask.sum(1)).cpu().tolist()

        # Keep only unmasked positions, row by row, so this stays correct for
        # any number of rows in the batch.
        keep = shift_mask == 1
        nlls += [row[row_keep].cpu() for row, row_keep in zip(token_nll, keep)]

    perplexities = torch.tensor(ppls)
    if nlls:
        # Sequences generally differ in length; NaN-pad instead of building a
        # ragged tensor (which raises).
        nll_tensor = pad_sequence(nlls, batch_first=True, padding_value=float("nan"))
    else:
        nll_tensor = torch.tensor(nlls)
    return {"perplexities": perplexities, "mean_perplexity": perplexities.mean(), "nlls": nll_tensor}
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large Load Diff
+24 -4
View File
@@ -470,7 +470,7 @@
},
{
"cell_type": "code",
"execution_count": 144,
"execution_count": 147,
"metadata": {},
"outputs": [
{
@@ -486,7 +486,27 @@
"../samples/2025_lw_the-intelligence-curse.md 1.2061121463775635\n",
"../samples/2025_lw_the-intelligence-curse.md 1.2061121463775635\n",
"../samples/2025_lw_human-study-on-ai-spear-phishing-campaigns.md 0.9995136260986328\n",
"../samples/2025_lw_human-study-on-ai-spear-phishing-campaigns.md 0.9995136260986328\n"
"../samples/2025_lw_human-study-on-ai-spear-phishing-campaigns.md 0.9995136260986328\n",
"../samples/2025_lw_the-subset-parity-learning-problem-much-more-than-you-wanted.md 0.9548193216323853\n",
"../samples/2025_lw_the-subset-parity-learning-problem-much-more-than-you-wanted.md 0.9548193216323853\n",
"../samples/2025_lw_2024-in-ai-predictions.md 0.8065339922904968\n",
"../samples/2025_lw_2024-in-ai-predictions.md 0.8065339922904968\n",
"../samples/2025_lw_debating-buying-nvda-in-2019.md 0.7926478385925293\n",
"../samples/2025_lw_debating-buying-nvda-in-2019.md 0.7926478385925293\n",
"../samples/2025_lw_review-planecrash.md 0.689734160900116\n",
"../samples/2025_lw_review-planecrash.md 0.689734160900116\n",
"../samples/2024_lw_by-default-capital-will-matter-more-than-ever-after-agi.md 0.6629015207290649\n",
"../samples/2024_lw_by-default-capital-will-matter-more-than-ever-after-agi.md 0.6629015207290649\n",
"../samples/2025_lw_the-field-of-ai-alignment-a-postmortem-and-what-to-do-about.md 0.5714353919029236\n",
"../samples/2025_lw_the-field-of-ai-alignment-a-postmortem-and-what-to-do-about.md 0.5714353919029236\n",
"../samples/2024_lw_the-plan-2024-update.md 0.542655885219574\n",
"../samples/2024_lw_the-plan-2024-update.md 0.542655885219574\n",
"../samples/2025_lw_comment-on-death-and-the-gorgon.md 0.5308915376663208\n",
"../samples/2025_lw_comment-on-death-and-the-gorgon.md 0.5308915376663208\n",
"../samples/2025_lw_my-agi-safety-research-2024-review-25-plans.md 0.49594494700431824\n",
"../samples/2025_lw_my-agi-safety-research-2024-review-25-plans.md 0.49594494700431824\n",
"../samples/2025_lw_preference-inversion.md 0.48199906945228577\n",
"../samples/2025_lw_preference-inversion.md 0.48199906945228577\n"
]
}
],
@@ -495,7 +515,7 @@
" md = markdownify(row[\"htmlBody\"]).strip()\n",
"\n",
" return f\"\"\"---\n",
"title: {row['title']}\n",
"title: \"{row['title'].replace('\"', \"'\")}\"\n",
"date: {row['modifiedAt']}\n",
"url: {row['pageUrl']}\n",
"novelty: {row['novelty']}\n",
@@ -507,7 +527,7 @@
"\"\"\"\n",
"\n",
"\n",
"for i in range(5):\n",
"for i in range(15):\n",
" for ii in [i, -i-1]:\n",
" row = df.iloc[i]\n",
" s = to_markdown(row)\n",