mirror of
https://github.com/wassname/steer-heal-love.git
synced 2026-06-27 16:47:16 +08:00
fix: gemma-3-4b is multimodal, read num_hidden_layers via config.get_text_config()
Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
This commit is contained in:
@@ -50,7 +50,8 @@ def load_model(model_id: str, dtype: torch.dtype):
|
||||
attn_implementation=attn,
|
||||
)
|
||||
model.eval()
|
||||
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={model.config.num_hidden_layers})")
|
||||
n_layers = model.config.get_text_config().num_hidden_layers
|
||||
logger.info(f"loaded {model_id} (dtype={dtype}, attn={attn}, layers={n_layers})")
|
||||
return model, tok
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from steer_heal.prompts import POOL, chat_prompt
|
||||
|
||||
|
||||
def _layer_band(model, layer_range: tuple[float, float]) -> tuple[int, ...]:
|
||||
n = model.config.num_hidden_layers
|
||||
n = model.config.get_text_config().num_hidden_layers # nested for multimodal (gemma-3-4b)
|
||||
lo, hi = layer_range
|
||||
return tuple(range(int(lo * n), max(int(hi * n), int(lo * n) + 1)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user