lora-lite/src/lora_lite/config.py
wassname 7e024b4734 comment hygiene + HRA row: shorten docstrings, drop dead init branch, track asvd
- variant.py: fix mislabeled "legacy entry" (make() is the live param path); drop unused near_one init branch
- config.py: drop "replaces older LoraLiteConfig" history narration
- antipasto_ablate.py: aspirational "should warm-start" comment -> tracked FIXME
- antipasto_rot.py: cut "kept as separate variant" / "why antipasto dropped rotation" ramble
- benchmark: merge duplicate antipasto/corda/asvd cfg branch
- README: fill HRA row (test 59.2 / valid 70.0)
- track antipasto_asvd.py (was imported+registered but uncommitted)

Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
2026-06-17 11:16:07 +08:00


"""AdapterConfig: per-variant typed dataclass.

Each variant ships its own subclass under `variants/*.py` (e.g. `DeLoRAConfig`),
adding strongly-typed fields so the knobs are discoverable via IDE / dataclass
introspection rather than stringly-typed dict lookups.

Wire-up:
- `AdapterConfig` holds the universal fields (variant name, rank, alpha,
  dtype, targeting filters).
- Subclasses override the `variant` default and add new fields.
- `register_config(cls)` adds the subclass to `_CONFIG_REGISTRY`, keyed by its
  `variant` default, so `from_dict` can route to the right type at load time.

Save format:
to_dict() emits a flat dict including `variant`, with `dtype` stored as its
short name (e.g. "bfloat16"); from_dict() uses `variant` to look up the
registered subclass and converts `dtype` back to a `torch.dtype`.
"""

from dataclasses import dataclass, asdict
from typing import Literal

import torch

Role = Literal["reader", "writer", "inner"]


@dataclass
class AdapterConfig:
    """Base config. Subclass per variant; do not instantiate directly."""

    # variant name (subclass overrides default)
    variant: str = "?"

    # rank-style hyperparams shared across most variants
    r: int = 8
    alpha: float | int = 16.0
    dtype: torch.dtype = torch.bfloat16

    # targeting
    target_roles: tuple[Role, ...] = ("reader", "writer")
    target_names: tuple[str, ...] = ()
    exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens")
    layers: tuple[int, ...] | None = None

    def to_dict(self) -> dict:
        d = asdict(self)
        # store the dtype by its short name (e.g. "bfloat16") so the dict holds plain values
        d["dtype"] = str(self.dtype).removeprefix("torch.")
        return d

    @classmethod
    def from_dict(cls, d: dict) -> "AdapterConfig":
        d = dict(d)  # copy so the caller's dict is not mutated
        name = d["variant"]
        # route to the registered subclass, regardless of which class this is called on
        sub = _CONFIG_REGISTRY[name]
        d["dtype"] = getattr(torch, d["dtype"])
        return sub(**d)


# Registry of variant_name -> config subclass. Populated by `register_config`
# decorators in each `variants/*.py` module at import time.
_CONFIG_REGISTRY: dict[str, type[AdapterConfig]] = {}


def register_config(cls: type[AdapterConfig]) -> type[AdapterConfig]:
    """Decorator: register `cls` under its `variant` default value.

    List it above `@dataclass` so it runs after the subclass has its own
    `__dataclass_fields__`; otherwise it reads the inherited base fields.
    """
    name = cls.__dataclass_fields__["variant"].default
    if name in _CONFIG_REGISTRY:
        raise ValueError(f"config for variant {name!r} already registered")
    _CONFIG_REGISTRY[name] = cls
    return cls