From 0c15562c815242259ada4646bfccca5f7ce4d3f3 Mon Sep 17 00:00:00 2001
From: wassname <1103714+wassname@users.noreply.github.com>
Date: Thu, 4 Jun 2026 10:44:56 +0800
Subject: [PATCH] fix: gemma-3-4b is multimodal, read num_hidden_layers via
 config.get_text_config()

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
---
 src/steer_heal/run.py      | 3 ++-
 src/steer_heal/steering.py | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/steer_heal/run.py b/src/steer_heal/run.py
index f3b1398..58d0ecc 100644
--- a/src/steer_heal/run.py
+++ b/src/steer_heal/run.py
@@ -50,7 +50,8 @@ def load_model(model_id: str, dtype: torch.dtype):
         attn_implementation=attn,
     )
     model.eval()
-    logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={model.config.num_hidden_layers})")
+    n_layers = model.config.get_text_config().num_hidden_layers
+    logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={n_layers})")
     return model, tok
 
 
diff --git a/src/steer_heal/steering.py b/src/steer_heal/steering.py
index 1854578..978c262 100644
--- a/src/steer_heal/steering.py
+++ b/src/steer_heal/steering.py
@@ -9,7 +9,7 @@ from steer_heal.prompts import POOL, chat_prompt
 
 
 def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
-    n = model.config.num_hidden_layers
+    n = model.config.get_text_config().num_hidden_layers  # nested for multimodal (gemma-3-4b)
     lo, hi = layer_range
     return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1)))
 