mirror of
https://github.com/wassname/detect_bs_text.git
synced 2026-06-27 18:42:43 +08:00
wip
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
import frontmatter
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import pandas as pd
|
||||
|
||||
def load_md(f: Path = Path("../samples/"), max_chars=2000):
|
||||
sample_files = sorted(f.glob('*.md'))
|
||||
samples = []
|
||||
for f in sample_files:
|
||||
logger.debug(f)
|
||||
with open(f, "r") as file:
|
||||
post = frontmatter.load(file)
|
||||
samples.append(dict(content=post.content[:max_chars], f=f, **post.metadata))
|
||||
return samples
|
||||
|
||||
def load_md_df(f, max_chars=2000):
|
||||
samples = load_md(f, max_chars=max_chars)
|
||||
df = pd.DataFrame(samples)
|
||||
df = df[['title', 'f', 'content', 'url', 'novelty', 'date']]
|
||||
df['date'] = pd.to_datetime(df['date'], utc=True)
|
||||
df['in_training'] = df.date < '2024-01-01'
|
||||
return df
|
||||
@@ -94,6 +94,7 @@ def perplexity_compute(
|
||||
(loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch)
|
||||
)
|
||||
# remove all the masked ones
|
||||
assert nll_batch.size(0) == 1
|
||||
nll_batch = nll_batch[shift_attention_mask_batch == 1][None, :] # FIXME only for batch_size=1
|
||||
perplexity_batch = torch.exp(nll_batch.mean(1)).cpu().numpy()
|
||||
|
||||
@@ -101,3 +102,46 @@ def perplexity_compute(
|
||||
nlls += nll_batch.cpu().numpy().tolist()
|
||||
|
||||
return {"perplexities": np.array(ppls), "nlls": np.array(nlls)}
|
||||
|
||||
|
||||
# modified from https://github.dev/huggingface/evaluate/blob/8dfe05784099fb9af55b8e77793205a3b7c86465/measurements/perplexity/perplexity.py#L154
|
||||
import evaluate
|
||||
from evaluate import logging
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
def perplexity_compute_ds(
|
||||
ds, model, tokenizer, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
|
||||
):
|
||||
model = model.to(device)
|
||||
|
||||
|
||||
ds = ds.with_format('pt')
|
||||
dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0, collate_fn=tokenizer.pad, pin_memory=True)
|
||||
|
||||
ppls = []
|
||||
nlls = []
|
||||
loss_fct = CrossEntropyLoss(reduction="none")
|
||||
for b in dl:
|
||||
input_ids = b['input_ids'].to(device)
|
||||
attention_mask = b['attention_mask'].to(device)
|
||||
|
||||
labels = input_ids
|
||||
|
||||
with torch.no_grad():
|
||||
out_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
|
||||
|
||||
shift_logits = out_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_attention_mask_batch = attention_mask[..., 1:].contiguous()
|
||||
|
||||
nll_batch = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
|
||||
assert nll_batch.size(0) == 1
|
||||
nll_batch = nll_batch[shift_attention_mask_batch == 1][None, :] # FIXME only for batch_size=1
|
||||
|
||||
perplexity_batch = torch.exp(nll_batch.sum(1) / shift_attention_mask_batch.sum(1)).cpu().numpy()
|
||||
|
||||
ppls += perplexity_batch.tolist()
|
||||
nlls += nll_batch.cpu().numpy().tolist()
|
||||
|
||||
return {"perplexities": torch.tensor(ppls), "mean_perplexity": torch.tensor(ppls).mean(), "nlls": torch.tensor(nlls)}
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -470,7 +470,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 144,
|
||||
"execution_count": 147,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -486,7 +486,27 @@
|
||||
"../samples/2025_lw_the-intelligence-curse.md 1.2061121463775635\n",
|
||||
"../samples/2025_lw_the-intelligence-curse.md 1.2061121463775635\n",
|
||||
"../samples/2025_lw_human-study-on-ai-spear-phishing-campaigns.md 0.9995136260986328\n",
|
||||
"../samples/2025_lw_human-study-on-ai-spear-phishing-campaigns.md 0.9995136260986328\n"
|
||||
"../samples/2025_lw_human-study-on-ai-spear-phishing-campaigns.md 0.9995136260986328\n",
|
||||
"../samples/2025_lw_the-subset-parity-learning-problem-much-more-than-you-wanted.md 0.9548193216323853\n",
|
||||
"../samples/2025_lw_the-subset-parity-learning-problem-much-more-than-you-wanted.md 0.9548193216323853\n",
|
||||
"../samples/2025_lw_2024-in-ai-predictions.md 0.8065339922904968\n",
|
||||
"../samples/2025_lw_2024-in-ai-predictions.md 0.8065339922904968\n",
|
||||
"../samples/2025_lw_debating-buying-nvda-in-2019.md 0.7926478385925293\n",
|
||||
"../samples/2025_lw_debating-buying-nvda-in-2019.md 0.7926478385925293\n",
|
||||
"../samples/2025_lw_review-planecrash.md 0.689734160900116\n",
|
||||
"../samples/2025_lw_review-planecrash.md 0.689734160900116\n",
|
||||
"../samples/2024_lw_by-default-capital-will-matter-more-than-ever-after-agi.md 0.6629015207290649\n",
|
||||
"../samples/2024_lw_by-default-capital-will-matter-more-than-ever-after-agi.md 0.6629015207290649\n",
|
||||
"../samples/2025_lw_the-field-of-ai-alignment-a-postmortem-and-what-to-do-about.md 0.5714353919029236\n",
|
||||
"../samples/2025_lw_the-field-of-ai-alignment-a-postmortem-and-what-to-do-about.md 0.5714353919029236\n",
|
||||
"../samples/2024_lw_the-plan-2024-update.md 0.542655885219574\n",
|
||||
"../samples/2024_lw_the-plan-2024-update.md 0.542655885219574\n",
|
||||
"../samples/2025_lw_comment-on-death-and-the-gorgon.md 0.5308915376663208\n",
|
||||
"../samples/2025_lw_comment-on-death-and-the-gorgon.md 0.5308915376663208\n",
|
||||
"../samples/2025_lw_my-agi-safety-research-2024-review-25-plans.md 0.49594494700431824\n",
|
||||
"../samples/2025_lw_my-agi-safety-research-2024-review-25-plans.md 0.49594494700431824\n",
|
||||
"../samples/2025_lw_preference-inversion.md 0.48199906945228577\n",
|
||||
"../samples/2025_lw_preference-inversion.md 0.48199906945228577\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -495,7 +515,7 @@
|
||||
" md = markdownify(row[\"htmlBody\"]).strip()\n",
|
||||
"\n",
|
||||
" return f\"\"\"---\n",
|
||||
"title: {row['title']}\n",
|
||||
"title: \"{row['title'].replace('\"', \"'\")}\"\n",
|
||||
"date: {row['modifiedAt']}\n",
|
||||
"url: {row['pageUrl']}\n",
|
||||
"novelty: {row['novelty']}\n",
|
||||
@@ -507,7 +527,7 @@
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in range(5):\n",
|
||||
"for i in range(15):\n",
|
||||
" for ii in [i, -i-1]:\n",
|
||||
" row = df.iloc[i]\n",
|
||||
" s = to_markdown(row)\n",
|
||||
|
||||
Reference in New Issue
Block a user