Fix float16 overflow in curvature computation

This commit is contained in:
wassname
2026-04-10 08:42:05 +08:00
parent 439f51099f
commit 555dbbae3c
4 changed files with 10 additions and 8 deletions
+2 -1
View File
@@ -42,7 +42,8 @@ def compute_curvature(hidden_states):
if hidden_states.shape[0] < 3:
return torch.zeros(hidden_states.shape[0], device=hidden_states.device)
-gamma = hidden_states
+# Cast to float32 to prevent float16 overflow when cubing
+gamma = hidden_states.to(torch.float32)
d_gamma = torch.gradient(gamma, dim=0)[0]
dd_gamma = torch.gradient(d_gamma, dim=0)[0]