Source code for mattertune.recipes.lora

from __future__ import annotations

import importlib.util
import logging
from typing import Any, Literal

import nshconfig as C
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from typing_extensions import final, override

from ..util import optional_import_error_message
from .base import RecipeConfigBase, recipe_registry

log = logging.getLogger(__name__)


class PeftConfig(C.Config):
    peft_type: str | None = None
    """Type of PEFT method being used."""

    task_type: str | None = None
    """Type of task being performed."""

    inference_mode: bool = False
    """Whether to use inference mode."""


class LoraConfig(PeftConfig):
    r: int = 8
    """LoRA attention dimension (rank)."""

    target_modules: list[str] | str | None = None
    """Names of modules to apply LoRA to. Can be a list of module names, a regex pattern, or 'all-linear'."""

    lora_alpha: int = 8
    """Alpha parameter for LoRA scaling."""

    lora_dropout: float = 0.0
    """Dropout probability for LoRA layers."""

    fan_in_fan_out: bool = False
    """Set True if target layer stores weights as (fan_in, fan_out)."""

    bias: Literal["none", "all", "lora_only"] = "none"
    """Bias type for LoRA. Controls which biases are updated during training."""

    use_rslora: bool = False
    """Whether to use Rank-Stabilized LoRA which sets adapter scaling to lora_alpha/sqrt(r)."""

    modules_to_save: list[str] | None = None
    """Additional modules to be trained and saved besides LoRA layers."""

    init_lora_weights: bool | Literal["gaussian"] = True
    """Initialization method for LoRA weights."""

    layers_to_transform: list[int] | int | None = None
    """Specific layer indices to apply LoRA transformation to."""

    layers_pattern: list[str] | str | None = None
    """Layer pattern name used with layers_to_transform."""

    rank_pattern: dict[str, Any] = {}
    """Mapping of layer names/patterns to custom ranks different from default r."""

    alpha_pattern: dict[str, Any] = {}
    """Mapping of layer names/patterns to custom alphas different from default lora_alpha."""

    def __post_init__(self):
        self.peft_type = "LORA"

        # Deduplicate target_modules if it is given as a list
        self.target_modules = (
            list(set(self.target_modules))
            if isinstance(self.target_modules, list)
            else self.target_modules
        )

        # Validate target_modules and layers configurations
        if isinstance(self.target_modules, str):
            if self.layers_to_transform is not None:
                raise ValueError(
                    "layers_to_transform cannot be used when target_modules is a str"
                )
            if self.layers_pattern is not None:
                raise ValueError(
                    "layers_pattern cannot be used when target_modules is a str"
                )

    def _to_peft_config(self):
        """Convert this configuration to a PEFT LoraConfig instance."""
        with optional_import_error_message("peft"):
            from peft.tuners.lora import LoraConfig as PeftLoraConfig  # type: ignore[reportMissingImports] # noqa

        return PeftLoraConfig(
            r=self.r,
            target_modules=self.target_modules,
            lora_alpha=self.lora_alpha,
            lora_dropout=self.lora_dropout,
            fan_in_fan_out=self.fan_in_fan_out,
            bias=self.bias,
            use_rslora=self.use_rslora,
            modules_to_save=self.modules_to_save,
            init_lora_weights=self.init_lora_weights,
            layers_to_transform=self.layers_to_transform,
            layers_pattern=self.layers_pattern,
            rank_pattern=self.rank_pattern,
            alpha_pattern=self.alpha_pattern,
            inference_mode=self.inference_mode,
            task_type=self.task_type,
        )
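

# --- Example (illustrative, not part of the module) -------------------------
# A minimal sketch of building a LoraConfig and converting it to a PEFT config.
# The target module names ("q_proj", "v_proj") are hypothetical placeholders;
# the actual names depend on the backbone being fine-tuned, and the `peft`
# package must be installed for `_to_peft_config()` to succeed.
#
#     lora_cfg = LoraConfig(
#         r=16,
#         lora_alpha=32,
#         lora_dropout=0.05,
#         target_modules=["q_proj", "v_proj"],  # hypothetical layer names
#     )
#     peft_cfg = lora_cfg._to_peft_config()  # returns a peft.tuners.lora.LoraConfig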


@recipe_registry.register
class LoRARecipeConfig(RecipeConfigBase):
    """
    Recipe for applying Low-Rank Adaptation (LoRA) to a model.

    LoRA is a method for fine-tuning pre-trained models via the injection of
    low-rank "adapter" weights into the model's linear layers. This allows for
    efficient fine-tuning of large models on small datasets, while preserving
    the pre-trained weights in the backbone.

    Reference: https://arxiv.org/abs/2106.09685
    """

    name: Literal["lora"] = "lora"
    """Discriminator for the LoRA recipe."""

    lora: LoraConfig
    """LoRA configuration."""

    @override
    @classmethod
    def ensure_dependencies(cls):
        # Make sure the "peft" package is installed
        if importlib.util.find_spec("peft") is None:
            raise ImportError(
                "LoRARecipe requires the 'peft' package. To install it, run 'pip install peft'."
            )

    @override
    def create_lightning_callback(self):
        return LoRACallback(self)
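

# --- Example (illustrative, not part of the module) -------------------------
# A minimal sketch of configuring the LoRA recipe. How the recipe config is
# passed to the rest of MatterTune (e.g. via a top-level tuner configuration)
# is not shown here; only the classes defined in this module are used, and the
# target module names are hypothetical.
#
#     recipe = LoRARecipeConfig(
#         lora=LoraConfig(r=8, target_modules=["q_proj", "v_proj"]),
#     )
#     LoRARecipeConfig.ensure_dependencies()         # raises ImportError if `peft` is missing
#     callback = recipe.create_lightning_callback()  # -> LoRACallback, attached by the trainer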


@final
class LoRACallback(Callback):
    @override
    def __init__(self, config: LoRARecipeConfig):
        super().__init__()
        self.config = config

    @override
    def setup(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        stage: str,
    ) -> None:
        from ..finetune.base import FinetuneModuleBase

        assert isinstance(
            pl_module, FinetuneModuleBase
        ), f"LoRARecipe requires a FinetuneModuleBase, got {type(pl_module)}"

        with optional_import_error_message("peft"):
            import peft  # type: ignore[reportMissingImports] # noqa

        # Convert the configuration to a PEFT LoraConfig instance
        lora = self.config.lora._to_peft_config()

        # Apply LoRA to the pre-trained backbone
        pl_module.apply_callable_to_backbone(
            lambda backbone: peft.inject_adapter_in_model(lora, backbone)
        )
        log.info("LoRA layers injected into the model")
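

# --- Example (illustrative, not part of the module) -------------------------
# A standalone sketch of what the callback does under the hood: injecting LoRA
# adapter layers into an existing torch module with `peft.inject_adapter_in_model`.
# The toy backbone and the module name "proj" are assumptions for illustration.
#
#     import torch.nn as nn
#     from peft import LoraConfig as PeftLoraConfig, inject_adapter_in_model
#
#     backbone = nn.Sequential()
#     backbone.add_module("proj", nn.Linear(64, 64))
#
#     peft_cfg = PeftLoraConfig(r=8, lora_alpha=16, target_modules=["proj"])
#     backbone = inject_adapter_in_model(peft_cfg, backbone)
#     # "proj" is now wrapped with a LoRA layer; only the adapter weights are trainable.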