fix: gemma-3-4b is multimodal, read num_hidden_layers via config.get_text_config()

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
wassname
2026-06-04 10:44:56 +08:00
parent d8aca870b7
commit 0c15562c81
2 changed files with 3 additions and 2 deletions
+2 -1
View File
@@ -50,7 +50,8 @@ def load_model(model_id: str, dtype: torch.dtype):
attn_implementation=attn,
)
model.eval()
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={model.config.num_hidden_layers})")
n_layers = model.config.get_text_config().num_hidden_layers
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={n_layers})")
return model, tok
+1 -1
View File
@@ -9,7 +9,7 @@ from steer_heal.prompts import POOL, chat_prompt
def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
n = model.config.num_hidden_layers
n = model.config.get_text_config().num_hidden_layers # nested for multimodal (gemma-3-4b)
lo, hi = layer_range
return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1)))