From 555dbbae3c65716a33a5a9939388b1800abb062a Mon Sep 17 00:00:00 2001 From: wassname <1103714+wassname@users.noreply.github.com> Date: Fri, 10 Apr 2026 08:42:05 +0800 Subject: [PATCH] Fix float16 overflow in curvature computation --- .gitignore | 1 + .python-version | 1 - experiment.ipynb | 13 +++++++------ experiment.py | 3 ++- 4 files changed, 10 insertions(+), 8 deletions(-) delete mode 100644 .python-version diff --git a/.gitignore b/.gitignore index 6bb16ff..8f3f9df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Environments .env .venv +venv/ env/ venv/ ENV/ diff --git a/.python-version b/.python-version deleted file mode 100644 index 24ee5b1..0000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.13 diff --git a/experiment.ipynb b/experiment.ipynb index 7c627bb..a6ca135 100644 --- a/experiment.ipynb +++ b/experiment.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "eeab401b", + "id": "2dc7c826", "metadata": {}, "source": [ "# Guided CoT Eval & Frenet-Serret Curvature\n", @@ -14,7 +14,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b57586b", + "id": "11ff7ad3", "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ { "cell_type": "code", "execution_count": null, - "id": "67394f45", + "id": "bf833680", "metadata": {}, "outputs": [], "source": [ @@ -50,7 +50,8 @@ " if hidden_states.shape[0] < 3:\n", " return torch.zeros(hidden_states.shape[0], device=hidden_states.device)\n", " \n", - " gamma = hidden_states\n", + " # Cast to float32 to prevent float16 overflow when cubing\n", + " gamma = hidden_states.to(torch.float32)\n", " d_gamma = torch.gradient(gamma, dim=0)[0]\n", " dd_gamma = torch.gradient(d_gamma, dim=0)[0]\n", " \n", @@ -64,7 +65,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d61d9ff", + "id": "227501af", "metadata": {}, "outputs": [], "source": [ @@ -115,7 +116,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14a46892", + "id": "7cea1129", "metadata": {}, "outputs": [], "source": [ diff --git a/experiment.py b/experiment.py index 7145907..16982bb 100644 --- a/experiment.py +++ b/experiment.py @@ -42,7 +42,8 @@ def compute_curvature(hidden_states): if hidden_states.shape[0] < 3: return torch.zeros(hidden_states.shape[0], device=hidden_states.device) - gamma = hidden_states + # Cast to float32 to prevent float16 overflow when cubing + gamma = hidden_states.to(torch.float32) d_gamma = torch.gradient(gamma, dim=0)[0] dd_gamma = torch.gradient(d_gamma, dim=0)[0]