mirror of
https://github.com/wassname/lora-lite.git
synced 2026-06-27 18:05:16 +08:00
7e024b4734
- variant.py: fix mislabeled "legacy entry" (make() is the live param path); drop unused near_one init branch - config.py: drop "replaces older LoraLiteConfig" history narration - antipasto_ablate.py: aspirational "should warm-start" comment -> tracked FIXME - antipasto_rot.py: cut "kept as separate variant" / "why antipasto dropped rotation" ramble - benchmark: merge duplicate antipasto/corda/asvd cfg branch - README: fill HRA row (test 59.2 / valid 70.0) - track antipasto_asvd.py (was imported+registered but uncommitted) Co-Authored-By: Claudypoo <288921227+claudypoo@users.noreply.github.com>
69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
"""AdapterConfig: per-variant typed dataclass.

Each variant ships its own subclass under `variants/*.py` (e.g. `DeLoRAConfig`),
adding strongly-typed fields so the knobs are discoverable via IDE / dataclass
introspection rather than stringly-typed dict lookups.

Wire-up:
  - `AdapterConfig` holds the universal fields (variant name, rank, alpha,
    dtype, targeting filters).
  - Subclasses override the `variant` default and add new fields.
  - `register_config(cls)` adds the subclass to `_CONFIG_REGISTRY` so
    `from_dict` can route to the right type at load time.

Save format:
  to_dict() emits a flat dict including `variant`; from_dict() uses that
  field to look up the right subclass.
"""
|
|
from dataclasses import dataclass, asdict
|
|
from typing import Literal
|
|
import torch
|
|
|
|
Role = Literal["reader", "writer", "inner"]
|
|
|
|
|
|
@dataclass
|
|
class AdapterConfig:
|
|
"""Base config. Subclass per variant; do not instantiate directly."""
|
|
# variant name (subclass overrides default)
|
|
variant: str = "?"
|
|
|
|
# rank-style hyperparams shared across most variants
|
|
r: int = 8
|
|
alpha: float | int = 16.0
|
|
dtype: torch.dtype = torch.bfloat16
|
|
|
|
# targeting
|
|
target_roles: tuple[Role, ...] = ("reader", "writer")
|
|
target_names: tuple[str, ...] = ()
|
|
exclude_names: tuple[str, ...] = ("lm_head", "embed_tokens")
|
|
layers: tuple[int, ...] | None = None
|
|
|
|
def to_dict(self) -> dict:
|
|
d = asdict(self)
|
|
d["dtype"] = str(self.dtype).removeprefix("torch.")
|
|
return d
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> "AdapterConfig":
|
|
d = dict(d)
|
|
name = d["variant"]
|
|
sub = _CONFIG_REGISTRY[name]
|
|
d["dtype"] = getattr(torch, d["dtype"])
|
|
return sub(**d)
|
|
|
|
|
|
# Registry of variant_name -> config subclass. Populated by `register_config`
# decorators in each `variants/*.py` module at import time.
_CONFIG_REGISTRY: dict[str, type[AdapterConfig]] = {}


def register_config(cls: type[AdapterConfig]) -> type[AdapterConfig]:
    """Decorator: register `cls` under its `variant` default value.

    Args:
        cls: A dataclass with a `variant` field whose default is the name to
            register it under.

    Returns:
        `cls` unchanged, so this can be used as a class decorator.

    Raises:
        TypeError: the `variant` field has no string default, so there is no
            name to register under.
        ValueError: a different class is already registered for that name.
    """
    name = cls.__dataclass_fields__["variant"].default
    # A field declared without a default reports the dataclasses.MISSING
    # sentinel here; refuse it rather than using the sentinel as a key.
    if not isinstance(name, str):
        raise TypeError(
            f"{cls.__name__}.variant needs a string default to be registered"
        )

    existing = _CONFIG_REGISTRY.get(name)
    if existing is cls:
        # Registering the very same class again is harmless; keep it a no-op.
        return cls
    if existing is not None:
        raise ValueError(f"config for variant {name!r} already registered")

    _CONFIG_REGISTRY[name] = cls
    return cls
|
|
|