This commit is contained in:
wassname
2025-01-05 16:03:24 +08:00
parent 03be90700c
commit 07803d2c90
5 changed files with 265 additions and 719 deletions
+1
View File
@@ -1,6 +1,7 @@
.env
lightning_logs/
outputs/
.anycache/
*.arrow
squad_*
+9 -5
View File
@@ -66,9 +66,10 @@ def perplexity_compute(
), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
ppls = []
nlls = []
loss_fct = CrossEntropyLoss(reduction="none")
for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
for start_index in range(0, len(encoded_texts), batch_size):
end_index = min(start_index + batch_size, len(encoded_texts))
encoded_batch = encoded_texts[start_index:end_index]
attn_mask = attn_masks[start_index:end_index]
@@ -89,11 +90,14 @@ def perplexity_compute(
shift_labels = labels[..., 1:].contiguous()
shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
perplexity_batch = torch.exp(
(loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
/ shift_attention_mask_batch.sum(1)
nll_batch = (
(loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch)
)
# remove all the masked ones
nll_batch = nll_batch[shift_attention_mask_batch == 1][None, :] # FIXME only for batch_size=1
perplexity_batch = torch.exp(nll_batch.mean(1)).cpu().numpy()
ppls += perplexity_batch.tolist()
nlls += nll_batch.cpu().numpy().tolist()
return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
return {"perplexities": np.array(ppls), "nlls": np.array(nlls)}
File diff suppressed because it is too large Load Diff
Generated
+44 -1
View File
@@ -1317,6 +1317,27 @@ qtconsole = ["qtconsole"]
test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"]
test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"]
[[package]]
name = "ipywidgets"
version = "8.1.5"
description = "Jupyter interactive widgets"
optional = false
python-versions = ">=3.7"
files = [
{file = "ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245"},
{file = "ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17"},
]
[package.dependencies]
comm = ">=0.1.3"
ipython = ">=6.1.0"
jupyterlab-widgets = ">=3.0.12,<3.1.0"
traitlets = ">=4.3.1"
widgetsnbextension = ">=4.0.12,<4.1.0"
[package.extras]
test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
[[package]]
name = "jedi"
version = "0.19.2"
@@ -1491,6 +1512,17 @@ traitlets = ">=5.3"
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"]
test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"]
[[package]]
name = "jupyterlab-widgets"
version = "3.0.13"
description = "Jupyter interactive widgets for JupyterLab"
optional = false
python-versions = ">=3.7"
files = [
{file = "jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54"},
{file = "jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed"},
]
[[package]]
name = "kiwisolver"
version = "1.4.8"
@@ -4101,6 +4133,17 @@ files = [
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]
[[package]]
name = "widgetsnbextension"
version = "4.0.13"
description = "Jupyter interactive widgets for Jupyter Notebook"
optional = false
python-versions = ">=3.7"
files = [
{file = "widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71"},
{file = "widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6"},
]
[[package]]
name = "win32-setctime"
version = "1.2.0"
@@ -4346,4 +4389,4 @@ propcache = ">=0.2.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "2a7966819d0850ce36c5e826aa1ca53d08d8c223fde4126b48b49b39fa055cc4"
content-hash = "f92d5fdffcee2de350804b17ff73bfa204a83470db68187c1dc78ee49e8b7eb7"
+2
View File
@@ -30,6 +30,8 @@ matplotlib = "^3.8.0"
python-frontmatter = "^1.0.1"
loguru = "^0.7.2"
anycache = "^2.2.0"
ipywidgets = "^8.1.5"
ipykernel = "^6.29.5"
[[tool.poetry.source]]
name = "pytorch"