Files
weight-steering/axolotl_plugin_models_with_mlp_bias.py
T
2025-11-05 10:13:04 +01:00

126 lines
4.3 KiB
Python

from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from transformers import (
AutoConfig,
AutoModelForCausalLM,
LlamaConfig,
Qwen2Config,
)
from models_with_mlp_bias import (
LlamaMLPBiasConfig,
LlamaMLPWithBiasForCausalLM,
Qwen2MLPBiasConfig,
Qwen2MLPWithBiasForCausalLM,
)
class MLPBiasPlugin(BasePlugin):
"""
Plugin to patch AutoModelForCausalLM.from_pretrained to use bias-enabled models.
"""
def __init__(self):
super().__init__()
self._original_from_pretrained = None
def pre_model_load(self, cfg: DictDefault):
"""
Patch AutoModelForCausalLM.from_pretrained before model loading.
"""
print("=" * 80)
print("Patching AutoModelForCausalLM.from_pretrained for MLP bias...")
print("=" * 80)
# Store original - get the actual function, not the bound method
self._original_from_pretrained = AutoModelForCausalLM.from_pretrained.__func__
@classmethod
def patched_from_pretrained(
cls, pretrained_model_name_or_path, *model_args, **kwargs
):
# Get the config
config = kwargs.get("config")
if config is None:
# Load config if not provided
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=kwargs.get("trust_remote_code", False),
)
# Check model type and use our custom class if applicable
if isinstance(config, Qwen2Config):
print("✓ Detected Qwen2 model, using Qwen2MLPWithBiasForCausalLM")
# Update model_type using the config class
config.model_type = Qwen2MLPBiasConfig.model_type
# Update config in kwargs
kwargs["config"] = config
# Bypass AutoModel and use our class directly
return Qwen2MLPWithBiasForCausalLM.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
elif isinstance(config, LlamaConfig):
print("✓ Detected Llama model, using LlamaMLPWithBiasForCausalLM")
# Update model_type using the config class
config.model_type = LlamaMLPBiasConfig.model_type
# Update config in kwargs
kwargs["config"] = config
return LlamaMLPWithBiasForCausalLM.from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
else:
raise Exception("Model not supported.")
# Apply patch - this modifies the class globally
AutoModelForCausalLM.from_pretrained = patched_from_pretrained
print("✓ AutoModelForCausalLM.from_pretrained patched globally")
print("=" * 80)
def post_model_load(self, cfg: DictDefault, model):
"""
Verify the model was loaded with bias and correct config.
"""
print("=" * 80)
print("Verifying model with MLP bias...")
print("=" * 80)
print(f"Model class: {model.__class__.__name__}")
print(f"Model config type: {model.config.model_type}")
# Check if bias exists
bias = model.model.layers[0].mlp.down_proj.bias
if bias is None:
raise Exception("⚠ Model does not have bias in down_proj!")
# Count total bias parameters
total_bias = sum(
layer.mlp.down_proj.bias.numel() for layer in model.model.layers
)
print(f"✓ Number of layers: {len(model.model.layers)}")
print(f"✓ Layer 0 down_proj.bias shape: {bias.shape}")
print(f"✓ Layer 0 down_proj.bias mean: {bias.mean().item():.6f}")
print(f"✓ Total bias parameters: {total_bias:,}")
print("=" * 80)
def post_train_unload(self, cfg: DictDefault):
"""
Restore original from_pretrained.
"""
if self._original_from_pretrained:
@classmethod
def restored(cls, *args, **kwargs):
return self._original_from_pretrained(cls, *args, **kwargs)
AutoModelForCausalLM.from_pretrained = restored
print("✓ Restored original AutoModelForCausalLM.from_pretrained")