mirror of
https://github.com/wassname/Brukino_AntiPaSTO_Appetizer.git
synced 2026-06-27 16:58:47 +08:00
Fix float16 overflow in curvature computation
This commit is contained in:
+2
-1
@@ -42,7 +42,8 @@ def compute_curvature(hidden_states):
|
||||
if hidden_states.shape[0] < 3:
|
||||
return torch.zeros(hidden_states.shape[0], device=hidden_states.device)
|
||||
|
||||
gamma = hidden_states
|
||||
# Cast to float32 to prevent float16 overflow when cubing
|
||||
gamma = hidden_states.to(torch.float32)
|
||||
d_gamma = torch.gradient(gamma, dim=0)[0]
|
||||
dd_gamma = torch.gradient(d_gamma, dim=0)[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user