From 1e8bc60492c5873b7e3e23909fa82be654bcf845 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 13 Jan 2025 15:07:42 +0100
Subject: [PATCH] Refactor: PEFT method registration function (#2282)

Goal

The goal of this refactor is the following: right now, when a new PEFT method is added, a new directory is created in src/peft/tuners/ with a config, model, etc. This is fine and self-contained. However, in addition to that, a couple of other places in the PEFT code base need to be touched for the new PEFT method to become usable.

As an example, take the recently added Bone method (#2172). Ignoring tests, docs, and examples, we have the additions to src/peft/tuners/bone, but we also need to:

1. Add an entry to PEFT_TYPE_TO_CONFIG_MAPPING in mapping.py.
2. Add an entry to PEFT_TYPE_TO_TUNER_MAPPING in mapping.py.
3. Add an entry to PEFT_TYPE_TO_MODEL_MAPPING in peft_model.py.
4. Add an entry to PEFT_TYPE_TO_PREFIX_MAPPING in utils/constants.py.
5. Add some code to get_peft_model_state_dict in utils/save_and_load.py.

With the changes in this PR, all of these steps can be omitted.

On top of that, we also have the re-imports in peft/__init__.py and peft/tuners/__init__.py, but those are still required (I'm hesitant to mess with the import system). Furthermore, it is still required to add an entry to PeftType in utils/peft_types.py. Since this is an enum, it can't easily be generated automatically. Therefore, adding a new PEFT method is still not 100% self-contained.

Changes in this PR

With this PR, less book-keeping is required. Instead of the 5 steps described above, contributors now only need to call

    # example for the Bone method
    register_peft_method(name="bone", config_cls=BoneConfig, model_cls=BoneModel)

in the __init__.py of their PEFT method. In addition to registering the method, this call also performs a couple of sanity checks (e.g. no duplicate names or prefixes, and the registered prefix being consistent with the model class's prefix attribute). Moreover, since so much book-keeping is removed, this PR reduces the overall number of lines of code (at the moment +317, -343).

Implementation

The real difficulty of this task is that the module structure in PEFT is really messy, easily resulting in circular imports. This has been an issue in the past and was especially painful here. For this reason, some things had to be moved around:

- MODEL_TYPE_TO_PEFT_MODEL_MAPPING now lives in auto.py instead of mapping.py.
- PEFT_TYPE_TO_PREFIX_MAPPING has been moved from constants.py to mapping.py.
- get_peft_model had to be moved out of mapping.py and now lives in its own module, mapping_func.py (better name suggestions welcome). This should be safe, as the function is re-imported into the main PEFT namespace, which all examples use.

The PEFT_TYPE_TO_MODEL_MAPPING dict could be removed completely, as it was basically redundant with PEFT_TYPE_TO_TUNER_MAPPING. get_peft_model_state_dict could also be simplified, as a lot of code was almost duplicated. Finally, there were a few hard-coded instances in peft_model.py like:

    elif config.peft_type == PeftType.P_TUNING:
        prompt_encoder = PromptEncoder(config)

Now, instead of hard-coding the model class, I just do model_cls = PEFT_TYPE_TO_TUNER_MAPPING[config.peft_type].
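
For illustration, this is roughly what the __init__.py of a newly added PEFT method would contain under the new scheme. The method name my_method and the MyMethodConfig / MyMethodModel classes are hypothetical placeholders; the call itself mirrors the register_peft_method signature introduced in this patch:

    # src/peft/tuners/my_method/__init__.py -- hypothetical example
    from peft.utils import register_peft_method

    from .config import MyMethodConfig
    from .model import MyMethodModel

    __all__ = ["MyMethodConfig", "MyMethodModel"]

    # Registering fills PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_TUNER_MAPPING and
    # PEFT_TYPE_TO_PREFIX_MAPPING; the prefix defaults to "<name>_" when not passed.
    register_peft_method(
        name="my_method",           # must be lower case and already present as PeftType.MY_METHOD
        config_cls=MyMethodConfig,
        model_cls=MyMethodModel,
        prefix="my_method_",        # optional; must match MyMethodModel.prefix if that attribute exists
        is_mixed_compatible=False,  # set to True only if the method works with PeftMixedModel
    )

As described above, the corresponding PeftType entry and the re-imports in peft/__init__.py and peft/tuners/__init__.py are still needed.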
--- src/peft/__init__.py | 8 +- src/peft/auto.py | 11 +- src/peft/mapping.py | 219 +----------------- src/peft/mapping_func.py | 129 +++++++++++ src/peft/mixed_model.py | 27 +-- src/peft/peft_model.py | 82 ++----- src/peft/tuners/adalora/__init__.py | 6 + src/peft/tuners/adaption_prompt/__init__.py | 4 + src/peft/tuners/boft/__init__.py | 4 + src/peft/tuners/bone/__init__.py | 4 + src/peft/tuners/cpt/__init__.py | 4 + src/peft/tuners/fourierft/__init__.py | 4 + src/peft/tuners/hra/__init__.py | 4 + src/peft/tuners/ia3/__init__.py | 3 + src/peft/tuners/ln_tuning/__init__.py | 4 + src/peft/tuners/loha/__init__.py | 4 + src/peft/tuners/lokr/__init__.py | 4 + src/peft/tuners/lora/__init__.py | 3 + .../multitask_prompt_tuning/__init__.py | 6 + src/peft/tuners/oft/__init__.py | 4 + src/peft/tuners/p_tuning/__init__.py | 4 + src/peft/tuners/poly/__init__.py | 4 + src/peft/tuners/prefix_tuning/__init__.py | 4 + src/peft/tuners/prompt_tuning/__init__.py | 4 + src/peft/tuners/vblora/__init__.py | 4 + src/peft/tuners/vera/__init__.py | 4 + src/peft/tuners/vera/model.py | 2 +- src/peft/tuners/xlora/__init__.py | 4 + src/peft/utils/__init__.py | 3 +- src/peft/utils/constants.py | 19 -- src/peft/utils/hotswap.py | 3 +- src/peft/utils/peft_types.py | 78 +++++++ src/peft/utils/save_and_load.py | 44 +--- tests/test_adaption_prompt.py | 2 +- tests/test_initialization.py | 2 +- tests/test_multitask_prompt_tuning.py | 2 +- 36 files changed, 360 insertions(+), 357 deletions(-) create mode 100644 src/peft/mapping_func.py diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 1ba77231e1..a43c171bbd 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -15,6 +15,7 @@ __version__ = "0.14.1.dev0" from .auto import ( + MODEL_TYPE_TO_PEFT_MODEL_MAPPING, AutoPeftModel, AutoPeftModelForCausalLM, AutoPeftModelForFeatureExtraction, @@ -25,12 +26,13 @@ ) from .config import PeftConfig, PromptLearningConfig from .mapping import ( - MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, + PEFT_TYPE_TO_MIXED_MODEL_MAPPING, + PEFT_TYPE_TO_TUNER_MAPPING, get_peft_config, - get_peft_model, inject_adapter_in_model, ) +from .mapping_func import get_peft_model from .mixed_model import PeftMixedModel from .peft_model import ( PeftModel, @@ -112,6 +114,8 @@ __all__ = [ "MODEL_TYPE_TO_PEFT_MODEL_MAPPING", "PEFT_TYPE_TO_CONFIG_MAPPING", + "PEFT_TYPE_TO_MIXED_MODEL_MAPPING", + "PEFT_TYPE_TO_TUNER_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", "AdaLoraConfig", "AdaLoraModel", diff --git a/src/peft/auto.py b/src/peft/auto.py index 18933be61e..4f890994d7 100644 --- a/src/peft/auto.py +++ b/src/peft/auto.py @@ -29,7 +29,6 @@ ) from .config import PeftConfig -from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING from .peft_model import ( PeftModel, PeftModelForCausalLM, @@ -43,6 +42,16 @@ from .utils.other import check_file_exists_on_hf_hub +MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = { + "SEQ_CLS": PeftModelForSequenceClassification, + "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, + "CAUSAL_LM": PeftModelForCausalLM, + "TOKEN_CLS": PeftModelForTokenClassification, + "QUESTION_ANS": PeftModelForQuestionAnswering, + "FEATURE_EXTRACTION": PeftModelForFeatureExtraction, +} + + class _BaseAutoPeftModel: _target_class = None _target_peft_class = None diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 67c29bd901..a92b28f858 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -14,123 +14,23 @@ from __future__ import annotations -import warnings -from typing import 
TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import torch -from peft.tuners.xlora.model import XLoraModel - -from .config import PeftConfig -from .mixed_model import PeftMixedModel -from .peft_model import ( - PeftModel, - PeftModelForCausalLM, - PeftModelForFeatureExtraction, - PeftModelForQuestionAnswering, - PeftModelForSeq2SeqLM, - PeftModelForSequenceClassification, - PeftModelForTokenClassification, -) -from .tuners import ( - AdaLoraConfig, - AdaLoraModel, - AdaptionPromptConfig, - BOFTConfig, - BOFTModel, - BoneConfig, - BoneModel, - CPTConfig, - CPTEmbedding, - FourierFTConfig, - FourierFTModel, - HRAConfig, - HRAModel, - IA3Config, - IA3Model, - LNTuningConfig, - LNTuningModel, - LoHaConfig, - LoHaModel, - LoKrConfig, - LoKrModel, - LoraConfig, - LoraModel, - MultitaskPromptTuningConfig, - OFTConfig, - OFTModel, - PolyConfig, - PolyModel, - PrefixTuningConfig, - PromptEncoderConfig, - PromptTuningConfig, - VBLoRAConfig, - VBLoRAModel, - VeraConfig, - VeraModel, - XLoraConfig, -) -from .tuners.tuners_utils import BaseTuner, BaseTunerLayer -from .utils import _prepare_prompt_learning_config -from .utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING +from .utils import PeftType if TYPE_CHECKING: - from transformers import PreTrainedModel - + from .config import PeftConfig + from .tuners.tuners_utils import BaseTuner -MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = { - "SEQ_CLS": PeftModelForSequenceClassification, - "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, - "CAUSAL_LM": PeftModelForCausalLM, - "TOKEN_CLS": PeftModelForTokenClassification, - "QUESTION_ANS": PeftModelForQuestionAnswering, - "FEATURE_EXTRACTION": PeftModelForFeatureExtraction, -} -PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = { - "ADAPTION_PROMPT": AdaptionPromptConfig, - "PROMPT_TUNING": PromptTuningConfig, - "PREFIX_TUNING": PrefixTuningConfig, - "P_TUNING": PromptEncoderConfig, - "LORA": LoraConfig, - "LOHA": LoHaConfig, - "LORAPLUS": LoraConfig, - "LOKR": LoKrConfig, - "ADALORA": AdaLoraConfig, - "BOFT": BOFTConfig, - "IA3": IA3Config, - "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, - "OFT": OFTConfig, - "POLY": PolyConfig, - "LN_TUNING": LNTuningConfig, - "VERA": VeraConfig, - "FOURIERFT": FourierFTConfig, - "XLORA": XLoraConfig, - "HRA": HRAConfig, - "VBLORA": VBLoRAConfig, - "CPT": CPTConfig, - "BONE": BoneConfig, -} - -PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = { - "LORA": LoraModel, - "LOHA": LoHaModel, - "LOKR": LoKrModel, - "ADALORA": AdaLoraModel, - "BOFT": BOFTModel, - "IA3": IA3Model, - "OFT": OFTModel, - "POLY": PolyModel, - "LN_TUNING": LNTuningModel, - "VERA": VeraModel, - "FOURIERFT": FourierFTModel, - "XLORA": XLoraModel, - "HRA": HRAModel, - "VBLORA": VBLoRAModel, - "CPT": CPTEmbedding, - "BONE": BoneModel, -} +# these will be filled by the register_peft_method function +PEFT_TYPE_TO_CONFIG_MAPPING: dict[PeftType, type[PeftConfig]] = {} +PEFT_TYPE_TO_TUNER_MAPPING: dict[PeftType, type[BaseTuner]] = {} +PEFT_TYPE_TO_MIXED_MODEL_MAPPING: dict[PeftType, type[BaseTuner]] = {} +PEFT_TYPE_TO_PREFIX_MAPPING: dict[PeftType, str] = {} def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig: @@ -144,107 +44,6 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig: return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) -def get_peft_model( - model: PreTrainedModel, - peft_config: PeftConfig, - adapter_name: str = "default", - mixed: bool = False, - autocast_adapter_dtype: bool = True, - revision: 
Optional[str] = None, - low_cpu_mem_usage: bool = False, -) -> PeftModel | PeftMixedModel: - """ - Returns a Peft model object from a model and a config, where the model will be modified in-place. - - Args: - model ([`transformers.PreTrainedModel`]): - Model to be wrapped. - peft_config ([`PeftConfig`]): - Configuration object containing the parameters of the Peft model. - adapter_name (`str`, `optional`, defaults to `"default"`): - The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). - mixed (`bool`, `optional`, defaults to `False`): - Whether to allow mixing different (compatible) adapter types. - autocast_adapter_dtype (`bool`, *optional*): - Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights - using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect - select PEFT tuners. - revision (`str`, `optional`, defaults to `main`): - The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for - the base model - low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): - Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as - False if you intend on training the model, unless the adapter weights will be replaced by different weights - before training starts. - """ - model_config = BaseTuner.get_model_config(model) - old_name = peft_config.base_model_name_or_path - new_name = model.__dict__.get("name_or_path", None) - peft_config.base_model_name_or_path = new_name - - # Especially in notebook environments there could be a case that a user wants to experiment with different - # configuration values. However, it is likely that there won't be any changes for new configs on an already - # initialized PEFT model. The best we can do is warn the user about it. - if any(isinstance(module, BaseTunerLayer) for module in model.modules()): - warnings.warn( - "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a " - "different config, make sure to call `.unload()` before." - ) - - if (old_name is not None) and (old_name != new_name): - warnings.warn( - f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. " - "Please ensure that the correct base model is loaded when loading this checkpoint." - ) - - if revision is not None: - if peft_config.revision is not None and peft_config.revision != revision: - warnings.warn( - f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}" - ) - peft_config.revision = revision - - if ( - (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"])) - and (peft_config.init_lora_weights == "eva") - and not low_cpu_mem_usage - ): - warnings.warn( - "lora with eva initialization used with low_cpu_mem_usage=False. " - "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization." - ) - - prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type) - if prefix and adapter_name in prefix: - warnings.warn( - f"Adapter name {adapter_name} should not be contained in the prefix {prefix}." - "This may lead to reinitialization of the adapter weights during loading." 
- ) - - if mixed: - # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it - return PeftMixedModel(model, peft_config, adapter_name=adapter_name) - - if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: - return PeftModel( - model, - peft_config, - adapter_name=adapter_name, - autocast_adapter_dtype=autocast_adapter_dtype, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - if peft_config.is_prompt_learning: - peft_config = _prepare_prompt_learning_config(peft_config, model_config) - return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]( - model, - peft_config, - adapter_name=adapter_name, - autocast_adapter_dtype=autocast_adapter_dtype, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - def inject_adapter_in_model( peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False ) -> torch.nn.Module: diff --git a/src/peft/mapping_func.py b/src/peft/mapping_func.py new file mode 100644 index 0000000000..279ce499d1 --- /dev/null +++ b/src/peft/mapping_func.py @@ -0,0 +1,129 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import Optional + +from transformers import PreTrainedModel + +from .auto import MODEL_TYPE_TO_PEFT_MODEL_MAPPING +from .config import PeftConfig +from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING +from .mixed_model import PeftMixedModel +from .peft_model import PeftModel +from .tuners.tuners_utils import BaseTuner, BaseTunerLayer +from .utils import _prepare_prompt_learning_config + + +def get_peft_model( + model: PreTrainedModel, + peft_config: PeftConfig, + adapter_name: str = "default", + mixed: bool = False, + autocast_adapter_dtype: bool = True, + revision: Optional[str] = None, + low_cpu_mem_usage: bool = False, +) -> PeftModel | PeftMixedModel: + """ + Returns a Peft model object from a model and a config, where the model will be modified in-place. + + Args: + model ([`transformers.PreTrainedModel`]): + Model to be wrapped. + peft_config ([`PeftConfig`]): + Configuration object containing the parameters of the Peft model. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + mixed (`bool`, `optional`, defaults to `False`): + Whether to allow mixing different (compatible) adapter types. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. + revision (`str`, `optional`, defaults to `main`): + The revision of the base model. 
If this isn't set, the saved peft model will load the `main` revision for + the base model + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as + False if you intend on training the model, unless the adapter weights will be replaced by different weights + before training starts. + """ + model_config = BaseTuner.get_model_config(model) + old_name = peft_config.base_model_name_or_path + new_name = model.__dict__.get("name_or_path", None) + peft_config.base_model_name_or_path = new_name + + # Especially in notebook environments there could be a case that a user wants to experiment with different + # configuration values. However, it is likely that there won't be any changes for new configs on an already + # initialized PEFT model. The best we can do is warn the user about it. + if any(isinstance(module, BaseTunerLayer) for module in model.modules()): + warnings.warn( + "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a " + "different config, make sure to call `.unload()` before." + ) + + if (old_name is not None) and (old_name != new_name): + warnings.warn( + f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. " + "Please ensure that the correct base model is loaded when loading this checkpoint." + ) + + if revision is not None: + if peft_config.revision is not None and peft_config.revision != revision: + warnings.warn( + f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}" + ) + peft_config.revision = revision + + if ( + (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"])) + and (peft_config.init_lora_weights == "eva") + and not low_cpu_mem_usage + ): + warnings.warn( + "lora with eva initialization used with low_cpu_mem_usage=False. " + "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization." + ) + + prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type) + if prefix and adapter_name in prefix: + warnings.warn( + f"Adapter name {adapter_name} should not be contained in the prefix {prefix}." + "This may lead to reinitialization of the adapter weights during loading." 
+ ) + + if mixed: + # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it + return PeftMixedModel(model, peft_config, adapter_name=adapter_name) + + if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: + return PeftModel( + model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + if peft_config.is_prompt_learning: + peft_config = _prepare_prompt_learning_config(peft_config, model_config) + return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]( + model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index 098bd0d4fd..c1f11c7aa9 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -27,25 +27,8 @@ from .config import PeftConfig from .peft_model import PeftModel -from .tuners import ( - AdaLoraModel, - IA3Model, - LoHaModel, - LoKrModel, - LoraModel, - MixedModel, -) -from .tuners.mixed import COMPATIBLE_TUNER_TYPES -from .utils import PeftType, _set_adapter, _set_trainable - - -PEFT_TYPE_TO_MODEL_MAPPING = { - PeftType.LORA: LoraModel, - PeftType.LOHA: LoHaModel, - PeftType.LOKR: LoKrModel, - PeftType.ADALORA: AdaLoraModel, - PeftType.IA3: IA3Model, -} +from .tuners import MixedModel +from .utils import _set_adapter, _set_trainable def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None: @@ -72,6 +55,8 @@ def make_inputs_require_grad(module, input, output): def _check_config_compatible(peft_config: PeftConfig) -> None: + from .tuners.mixed import COMPATIBLE_TUNER_TYPES + if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES: raise ValueError( f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. " @@ -440,7 +425,7 @@ def from_pretrained( Additional keyword arguments passed along to the specific PEFT configuration class. """ # note: adapted from PeftModel.from_pretrained - from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING + from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_MIXED_MODEL_MAPPING # load the config if config is None: @@ -459,7 +444,7 @@ def from_pretrained( raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") # note: this is different from PeftModel.from_pretrained - if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING: + if config.peft_type not in PEFT_TYPE_TO_MIXED_MODEL_MAPPING: raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.") if (getattr(model, "hf_device_map", None) is not None) and len( diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index a3c7535052..62061a84e8 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -38,36 +38,13 @@ from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput from transformers.utils import PushToHubMixin -from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer +from peft.utils.constants import DUMMY_MODEL_CONFIG from peft.utils.integrations import init_empty_weights from . 
import __version__ from .config import PeftConfig -from .tuners import ( - AdaLoraModel, - AdaptionPromptModel, - BOFTModel, - BoneModel, - CPTEmbedding, - FourierFTModel, - HRAModel, - IA3Model, - LNTuningModel, - LoHaModel, - LoKrModel, - LoraModel, - MultitaskPromptEmbedding, - OFTModel, - PolyModel, - PrefixEncoder, - PromptEmbedding, - PromptEncoder, - VBLoRAModel, - VeraModel, - XLoraConfig, - XLoraModel, -) -from .tuners.tuners_utils import BaseTuner, BaseTunerLayer +from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING, PEFT_TYPE_TO_TUNER_MAPPING from .utils import ( SAFETENSORS_WEIGHTS_NAME, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, @@ -88,30 +65,6 @@ ) -PEFT_TYPE_TO_MODEL_MAPPING = { - PeftType.LORA: LoraModel, - PeftType.LOHA: LoHaModel, - PeftType.LOKR: LoKrModel, - PeftType.PROMPT_TUNING: PromptEmbedding, - PeftType.P_TUNING: PromptEncoder, - PeftType.PREFIX_TUNING: PrefixEncoder, - PeftType.ADALORA: AdaLoraModel, - PeftType.BOFT: BOFTModel, - PeftType.ADAPTION_PROMPT: AdaptionPromptModel, - PeftType.IA3: IA3Model, - PeftType.OFT: OFTModel, - PeftType.POLY: PolyModel, - PeftType.LN_TUNING: LNTuningModel, - PeftType.VERA: VeraModel, - PeftType.FOURIERFT: FourierFTModel, - PeftType.XLORA: XLoraModel, - PeftType.HRA: HRAModel, - PeftType.VBLORA: VBLoRAModel, - PeftType.CPT: CPTEmbedding, - PeftType.BONE: BoneModel, -} - - class PeftModel(PushToHubMixin, torch.nn.Module): """ Base model encompassing various Peft methods. @@ -171,7 +124,7 @@ def __init__( self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage) else: self._peft_config = None - cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] + cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] ctx = init_empty_weights if low_cpu_mem_usage else nullcontext with ctx(): self.base_model = cls(model, {adapter_name: peft_config}, adapter_name) @@ -474,7 +427,8 @@ def from_pretrained( kwargs: (`optional`): Additional keyword arguments passed along to the specific PEFT configuration class. 
""" - from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING + from .auto import MODEL_TYPE_TO_PEFT_MODEL_MAPPING + from .tuners import XLoraConfig, XLoraModel # load the config if config is None: @@ -660,20 +614,17 @@ def _setup_prompt_encoder(self, adapter_name: str): break self.word_embeddings = word_embeddings + model_cls = PEFT_TYPE_TO_TUNER_MAPPING[config.peft_type] - if config.peft_type == PeftType.PROMPT_TUNING: - prompt_encoder = PromptEmbedding(config, self.word_embeddings) - elif config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: - prompt_encoder = MultitaskPromptEmbedding(config, self.word_embeddings) + if config.peft_type in (PeftType.PROMPT_TUNING, PeftType.MULTITASK_PROMPT_TUNING, PeftType.CPT): + prompt_encoder = model_cls(config, self.word_embeddings) elif config.peft_type == PeftType.P_TUNING: - prompt_encoder = PromptEncoder(config) + prompt_encoder = model_cls(config) elif config.peft_type == PeftType.PREFIX_TUNING: # prefix tuning now uses Cache but that won't work with gradient checkpointing if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()): raise ValueError("Prefix tuning does not work with gradient checkpointing.") - prompt_encoder = PrefixEncoder(config) - elif config.peft_type == PeftType.CPT: - prompt_encoder = CPTEmbedding(config, self.word_embeddings) + prompt_encoder = model_cls(config) else: raise ValueError("Not supported") @@ -711,11 +662,13 @@ def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor: prompt_tokens = ( self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1).to(prompt_encoder.embedding.weight.device) ) + peft_type = self.peft_config[adapter_name].peft_type if self.peft_config[adapter_name].peft_type == PeftType.PREFIX_TUNING: prompt_tokens = prompt_tokens[:, : self.peft_config[adapter_name].num_virtual_tokens] if self.peft_config[adapter_name].peft_type == PeftType.MULTITASK_PROMPT_TUNING: - prompt_embeddings = super(MultitaskPromptEmbedding, prompt_encoder).forward(prompt_tokens) + prompt_embedding_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_type] + prompt_embeddings = super(prompt_embedding_cls, prompt_encoder).forward(prompt_tokens) else: prompt_embeddings = prompt_encoder(prompt_tokens) @@ -1794,9 +1747,7 @@ def forward( inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) - def _cpt_forward( - self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs - ): + def _cpt_forward(self, input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs): # Extract labels from kwargs labels = kwargs.pop("labels") device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0] @@ -1846,7 +1797,8 @@ def _cpt_forward( return base_model_output else: # Calculate the loss using the custom CPT loss function - base_model_output = CPTEmbedding.calculate_loss( + cpt_embedding = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] + base_model_output = cpt_embedding.calculate_loss( base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"] ) return base_model_output diff --git a/src/peft/tuners/adalora/__init__.py b/src/peft/tuners/adalora/__init__.py index bdec26b522..64d5f3e5ce 100644 --- a/src/peft/tuners/adalora/__init__.py +++ b/src/peft/tuners/adalora/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method from .config import AdaLoraConfig from .gptq import SVDQuantLinear @@ -23,6 +24,11 @@ __all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "RankAllocator", "SVDLinear", "SVDQuantLinear"] +register_peft_method( + name="adalora", config_cls=AdaLoraConfig, model_cls=AdaLoraModel, prefix="lora_", is_mixed_compatible=True +) + + def __getattr__(name): if (name == "SVDLinear8bitLt") and is_bnb_available(): from .bnb import SVDLinear8bitLt diff --git a/src/peft/tuners/adaption_prompt/__init__.py b/src/peft/tuners/adaption_prompt/__init__.py index 826e115d3c..68882a2226 100644 --- a/src/peft/tuners/adaption_prompt/__init__.py +++ b/src/peft/tuners/adaption_prompt/__init__.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import AdaptionPromptConfig from .layer import AdaptedAttention from .model import AdaptionPromptModel __all__ = ["AdaptedAttention", "AdaptionPromptConfig", "AdaptionPromptModel"] + +register_peft_method(name="adaption_prompt", config_cls=AdaptionPromptConfig, model_cls=AdaptionPromptModel) diff --git a/src/peft/tuners/boft/__init__.py b/src/peft/tuners/boft/__init__.py index 5b72b73951..c84b8358da 100644 --- a/src/peft/tuners/boft/__init__.py +++ b/src/peft/tuners/boft/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import BOFTConfig from .layer import BOFTLayer from .model import BOFTModel __all__ = ["BOFTConfig", "BOFTLayer", "BOFTModel"] + +register_peft_method(name="boft", config_cls=BOFTConfig, model_cls=BOFTModel) diff --git a/src/peft/tuners/bone/__init__.py b/src/peft/tuners/bone/__init__.py index d2a41552d5..f131e8c17d 100644 --- a/src/peft/tuners/bone/__init__.py +++ b/src/peft/tuners/bone/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import BoneConfig from .layer import BoneLayer, BoneLinear from .model import BoneModel __all__ = ["BoneConfig", "BoneLayer", "BoneLinear", "BoneModel"] + +register_peft_method(name="bone", config_cls=BoneConfig, model_cls=BoneModel) diff --git a/src/peft/tuners/cpt/__init__.py b/src/peft/tuners/cpt/__init__.py index f5018f89b1..fcd4de8598 100644 --- a/src/peft/tuners/cpt/__init__.py +++ b/src/peft/tuners/cpt/__init__.py @@ -13,8 +13,12 @@ # limitations under the License. +from peft.utils import register_peft_method + from .config import CPTConfig from .model import CPTEmbedding __all__ = ["CPTConfig", "CPTEmbedding"] + +register_peft_method(name="cpt", config_cls=CPTConfig, model_cls=CPTEmbedding) diff --git a/src/peft/tuners/fourierft/__init__.py b/src/peft/tuners/fourierft/__init__.py index 7646d9497c..dfe3f5d89e 100644 --- a/src/peft/tuners/fourierft/__init__.py +++ b/src/peft/tuners/fourierft/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from peft.utils import register_peft_method + from .config import FourierFTConfig from .layer import FourierFTLayer, FourierFTLinear from .model import FourierFTModel __all__ = ["FourierFTConfig", "FourierFTLayer", "FourierFTLinear", "FourierFTModel"] + +register_peft_method(name="fourierft", model_cls=FourierFTModel, config_cls=FourierFTConfig) diff --git a/src/peft/tuners/hra/__init__.py b/src/peft/tuners/hra/__init__.py index 11902f2ec1..8f5f6a5443 100644 --- a/src/peft/tuners/hra/__init__.py +++ b/src/peft/tuners/hra/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import HRAConfig from .layer import HRAConv2d, HRALayer, HRALinear from .model import HRAModel __all__ = ["HRAConfig", "HRAConv2d", "HRALayer", "HRALinear", "HRAModel"] + +register_peft_method(name="hra", config_cls=HRAConfig, model_cls=HRAModel) diff --git a/src/peft/tuners/ia3/__init__.py b/src/peft/tuners/ia3/__init__.py index f88f8411a0..21cab4d6d8 100644 --- a/src/peft/tuners/ia3/__init__.py +++ b/src/peft/tuners/ia3/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method from .config import IA3Config from .layer import Conv2d, Conv3d, IA3Layer, Linear @@ -21,6 +22,8 @@ __all__ = ["Conv2d", "Conv3d", "IA3Config", "IA3Layer", "IA3Model", "Linear"] +register_peft_method(name="ia3", config_cls=IA3Config, model_cls=IA3Model, is_mixed_compatible=True) + def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): diff --git a/src/peft/tuners/ln_tuning/__init__.py b/src/peft/tuners/ln_tuning/__init__.py index afaae73a44..8f90a8fb05 100644 --- a/src/peft/tuners/ln_tuning/__init__.py +++ b/src/peft/tuners/ln_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import LNTuningConfig from .model import LNTuningModel __all__ = ["LNTuningConfig", "LNTuningModel"] + +register_peft_method(name="ln_tuning", config_cls=LNTuningConfig, model_cls=LNTuningModel) diff --git a/src/peft/tuners/loha/__init__.py b/src/peft/tuners/loha/__init__.py index cc54826fe4..70dd1545bb 100644 --- a/src/peft/tuners/loha/__init__.py +++ b/src/peft/tuners/loha/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import LoHaConfig from .layer import Conv2d, Linear, LoHaLayer from .model import LoHaModel __all__ = ["Conv2d", "Linear", "LoHaConfig", "LoHaLayer", "LoHaModel"] + +register_peft_method(name="loha", config_cls=LoHaConfig, model_cls=LoHaModel, prefix="hada_", is_mixed_compatible=True) diff --git a/src/peft/tuners/lokr/__init__.py b/src/peft/tuners/lokr/__init__.py index 1b0c9f5438..f4fe0e92c6 100644 --- a/src/peft/tuners/lokr/__init__.py +++ b/src/peft/tuners/lokr/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from peft.utils import register_peft_method + from .config import LoKrConfig from .layer import Conv2d, Linear, LoKrLayer from .model import LoKrModel __all__ = ["Conv2d", "Linear", "LoKrConfig", "LoKrLayer", "LoKrModel"] + +register_peft_method(name="lokr", config_cls=LoKrConfig, model_cls=LoKrModel, is_mixed_compatible=True) diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index de1f423cd0..779d4eec79 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available +from peft.utils import register_peft_method from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig from .eva import get_eva_state_dict, initialize_lora_eva_weights @@ -37,6 +38,8 @@ "initialize_lora_eva_weights", ] +register_peft_method(name="lora", config_cls=LoraConfig, model_cls=LoraModel, is_mixed_compatible=True) + def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): diff --git a/src/peft/tuners/multitask_prompt_tuning/__init__.py b/src/peft/tuners/multitask_prompt_tuning/__init__.py index d9f98d4a76..fe692a9337 100644 --- a/src/peft/tuners/multitask_prompt_tuning/__init__.py +++ b/src/peft/tuners/multitask_prompt_tuning/__init__.py @@ -12,8 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit from .model import MultitaskPromptEmbedding __all__ = ["MultitaskPromptEmbedding", "MultitaskPromptTuningConfig", "MultitaskPromptTuningInit"] + +register_peft_method( + name="multitask_prompt_tuning", config_cls=MultitaskPromptTuningConfig, model_cls=MultitaskPromptEmbedding +) diff --git a/src/peft/tuners/oft/__init__.py b/src/peft/tuners/oft/__init__.py index 1b8bdaa6cd..5df395090f 100644 --- a/src/peft/tuners/oft/__init__.py +++ b/src/peft/tuners/oft/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import OFTConfig from .layer import Conv2d, Linear, OFTLayer from .model import OFTModel __all__ = ["Conv2d", "Linear", "OFTConfig", "OFTLayer", "OFTModel"] + +register_peft_method(name="oft", config_cls=OFTConfig, model_cls=OFTModel) diff --git a/src/peft/tuners/p_tuning/__init__.py b/src/peft/tuners/p_tuning/__init__.py index 7dd3a6ba3e..9195c0d75d 100644 --- a/src/peft/tuners/p_tuning/__init__.py +++ b/src/peft/tuners/p_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import PromptEncoderConfig, PromptEncoderReparameterizationType from .model import PromptEncoder __all__ = ["PromptEncoder", "PromptEncoderConfig", "PromptEncoderReparameterizationType"] + +register_peft_method(name="p_tuning", config_cls=PromptEncoderConfig, model_cls=PromptEncoder) diff --git a/src/peft/tuners/poly/__init__.py b/src/peft/tuners/poly/__init__.py index b0f368695e..1c18933eba 100644 --- a/src/peft/tuners/poly/__init__.py +++ b/src/peft/tuners/poly/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from peft.utils import register_peft_method + from .config import PolyConfig from .layer import Linear, PolyLayer from .model import PolyModel __all__ = ["Linear", "PolyConfig", "PolyLayer", "PolyModel"] + +register_peft_method(name="poly", config_cls=PolyConfig, model_cls=PolyModel) diff --git a/src/peft/tuners/prefix_tuning/__init__.py b/src/peft/tuners/prefix_tuning/__init__.py index 8d2c7efa64..939f74d3f6 100644 --- a/src/peft/tuners/prefix_tuning/__init__.py +++ b/src/peft/tuners/prefix_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import PrefixTuningConfig from .model import PrefixEncoder __all__ = ["PrefixEncoder", "PrefixTuningConfig"] + +register_peft_method(name="prefix_tuning", config_cls=PrefixTuningConfig, model_cls=PrefixEncoder) diff --git a/src/peft/tuners/prompt_tuning/__init__.py b/src/peft/tuners/prompt_tuning/__init__.py index 6a9ceb48fa..c99ca6a26f 100644 --- a/src/peft/tuners/prompt_tuning/__init__.py +++ b/src/peft/tuners/prompt_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import PromptTuningConfig, PromptTuningInit from .model import PromptEmbedding __all__ = ["PromptEmbedding", "PromptTuningConfig", "PromptTuningInit"] + +register_peft_method(name="prompt_tuning", config_cls=PromptTuningConfig, model_cls=PromptEmbedding) diff --git a/src/peft/tuners/vblora/__init__.py b/src/peft/tuners/vblora/__init__.py index 558b9d80fd..8e71a08461 100644 --- a/src/peft/tuners/vblora/__init__.py +++ b/src/peft/tuners/vblora/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import VBLoRAConfig from .layer import Linear, VBLoRALayer from .model import VBLoRAModel __all__ = ["Linear", "VBLoRAConfig", "VBLoRALayer", "VBLoRAModel"] + +register_peft_method(name="vblora", config_cls=VBLoRAConfig, model_cls=VBLoRAModel) diff --git a/src/peft/tuners/vera/__init__.py b/src/peft/tuners/vera/__init__.py index 268ce3ef93..25c4a96619 100644 --- a/src/peft/tuners/vera/__init__.py +++ b/src/peft/tuners/vera/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method from .config import VeraConfig from .layer import Linear, VeraLayer @@ -22,6 +23,9 @@ __all__ = ["Linear", "VeraConfig", "VeraLayer", "VeraModel"] +register_peft_method(name="vera", config_cls=VeraConfig, model_cls=VeraModel, prefix="vera_lambda_") + + def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): from .bnb import Linear8bitLt diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py index e129620ce8..6b41937a16 100644 --- a/src/peft/tuners/vera/model.py +++ b/src/peft/tuners/vera/model.py @@ -99,7 +99,7 @@ class VeraModel(BaseTuner): - **peft_config** ([`VeraConfig`]): The configuration of the Vera model. 
""" - prefix: str = "vera_lambda" + prefix: str = "vera_lambda_" def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) diff --git a/src/peft/tuners/xlora/__init__.py b/src/peft/tuners/xlora/__init__.py index df41e1e611..6eae1f779b 100644 --- a/src/peft/tuners/xlora/__init__.py +++ b/src/peft/tuners/xlora/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import XLoraConfig from .model import XLoraModel __all__ = ["XLoraConfig", "XLoraModel"] + +register_peft_method(name="xlora", config_cls=XLoraConfig, model_cls=XLoraModel) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 7e5acfc1a2..99a86e5b23 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -46,7 +46,7 @@ shift_tokens_right, transpose, ) -from .peft_types import PeftType, TaskType +from .peft_types import PeftType, TaskType, register_peft_method from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict @@ -84,6 +84,7 @@ "load_peft_weights", "map_cache_to_layer_device_map", "prepare_model_for_kbit_training", + "register_peft_method", "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 986be87bd6..e7c305926a 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -15,8 +15,6 @@ import torch from transformers import BloomPreTrainedModel -from .peft_types import PeftType - # needed for prefix-tuning of bloom model def bloom_model_postprocess_past_key_value(past_key_values): @@ -286,23 +284,6 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen2": ["q_proj", "v_proj"], } -PEFT_TYPE_TO_PREFIX_MAPPING = { - PeftType.IA3: "ia3_", - PeftType.LORA: "lora_", - PeftType.ADALORA: "lora_", - PeftType.LOHA: "hada_", - PeftType.LOKR: "lokr_", - PeftType.OFT: "oft_", - PeftType.POLY: "poly_", - PeftType.BOFT: "boft_", - PeftType.LN_TUNING: "ln_tuning_", - PeftType.VERA: "vera_lambda_", - PeftType.FOURIERFT: "fourierft_", - PeftType.HRA: "hra_", - PeftType.VBLORA: "vblora_", - PeftType.BONE: "bone_", -} - WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" diff --git a/src/peft/utils/hotswap.py b/src/peft/utils/hotswap.py index 3ff7caacce..b09249409c 100644 --- a/src/peft/utils/hotswap.py +++ b/src/peft/utils/hotswap.py @@ -18,9 +18,8 @@ import torch from peft.config import PeftConfig -from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING +from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING -from .constants import PEFT_TYPE_TO_PREFIX_MAPPING from .other import infer_device from .peft_types import PeftType from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index d2f1539074..aa7c548f03 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -13,6 +13,7 @@ # limitations under the License. 
import enum +from typing import Optional class PeftType(str, enum.Enum): @@ -84,3 +85,80 @@ class TaskType(str, enum.Enum): TOKEN_CLS = "TOKEN_CLS" QUESTION_ANS = "QUESTION_ANS" FEATURE_EXTRACTION = "FEATURE_EXTRACTION" + + +def register_peft_method( + *, name: str, config_cls, model_cls, prefix: Optional[str] = None, is_mixed_compatible=False +) -> None: + """ + Function to register a finetuning method like LoRA to be available in PEFT. + + This method takes care of registering the PEFT method's configuration class, the model class, and optionally the + prefix. + + Args: + name (str): + The name of the PEFT method. It must be unique. + config_cls: + The configuration class of the PEFT method. + model_cls: + The model class of the PEFT method. + prefix (Optional[str], optional): + The prefix of the PEFT method. It should be unique. If not provided, the name of the PEFT method is used as + the prefix. + is_mixed_compatible (bool, optional): + Whether the PEFT method is compatible with `PeftMixedModel`. If you're not sure, leave it as False + (default). + + Example: + + ```py + # inside of peft/tuners/my_peft_method/__init__.py + from peft.utils import register_peft_method + + register_peft_method(name="my_peft_method", config_cls=MyConfig, model_cls=MyModel) + ``` + """ + from peft.mapping import ( + PEFT_TYPE_TO_CONFIG_MAPPING, + PEFT_TYPE_TO_MIXED_MODEL_MAPPING, + PEFT_TYPE_TO_PREFIX_MAPPING, + PEFT_TYPE_TO_TUNER_MAPPING, + ) + + if name.endswith("_"): + raise ValueError(f"Please pass the name of the PEFT method without '_' suffix, got {name}.") + + if not name.islower(): + raise ValueError(f"The name of the PEFT method should be in lower case letters, got {name}.") + + if name.upper() not in list(PeftType): + raise ValueError(f"Unknown PEFT type {name.upper()}, please add an entry to peft.utils.peft_types.PeftType.") + + peft_type = getattr(PeftType, name.upper()) + + # model_cls can be None for prompt learning methods, which don't have dedicated model classes + if prefix is None: + prefix = name + "_" + + if ( + (peft_type in PEFT_TYPE_TO_CONFIG_MAPPING) + or (peft_type in PEFT_TYPE_TO_TUNER_MAPPING) + or (peft_type in PEFT_TYPE_TO_MIXED_MODEL_MAPPING) + ): + raise KeyError(f"There is already PEFT method called '{name}', please choose a unique name.") + + if prefix in PEFT_TYPE_TO_PREFIX_MAPPING: + raise KeyError(f"There is already a prefix called '{prefix}', please choose a unique prefix.") + + model_cls_prefix = getattr(model_cls, "prefix", None) + if (model_cls_prefix is not None) and (model_cls_prefix != prefix): + raise ValueError( + f"Inconsistent prefixes found: '{prefix}' and '{model_cls_prefix}' (they should be the same)." 
+ ) + + PEFT_TYPE_TO_PREFIX_MAPPING[peft_type] = prefix + PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] = config_cls + PEFT_TYPE_TO_TUNER_MAPPING[peft_type] = model_cls + if is_mixed_compatible: + PEFT_TYPE_TO_MIXED_MODEL_MAPPING[peft_type] = model_cls diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 44a1cad5ff..5337afeb9f 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -24,7 +24,8 @@ from packaging import version from safetensors.torch import load_file as safe_load_file -from .constants import PEFT_TYPE_TO_PREFIX_MAPPING +from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING + from .other import ( EMBEDDING_LAYER_NAMES, SAFETENSORS_WEIGHTS_NAME, @@ -133,12 +134,6 @@ def renamed_dora_weights(k): else: raise NotImplementedError - elif config.peft_type == PeftType.LOHA: - to_return = {k: state_dict[k] for k in state_dict if "hada_" in k} - - elif config.peft_type == PeftType.LOKR: - to_return = {k: state_dict[k] for k in state_dict if "lokr_" in k} - elif config.peft_type == PeftType.ADAPTION_PROMPT: to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} @@ -155,20 +150,9 @@ def renamed_dora_weights(k): prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) to_return["prompt_embeddings"] = prompt_embeddings - elif config.peft_type == PeftType.IA3: - to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} - - elif config.peft_type == PeftType.OFT: - to_return = {k: state_dict[k] for k in state_dict if "oft_" in k} - - elif config.peft_type == PeftType.POLY: - to_return = {k: state_dict[k] for k in state_dict if "poly_" in k} - - elif config.peft_type == PeftType.LN_TUNING: - to_return = {k: state_dict[k] for k in state_dict if "ln_tuning_" in k} - elif config.peft_type == PeftType.VERA: - to_return = {k: state_dict[k] for k in state_dict if "vera_lambda_" in k} + vera_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] + to_return = {k: state_dict[k] for k in state_dict if vera_prefix in k} if config.save_projection: # TODO: adding vera_A and vera_B to `self.get_base_layer` would # make name to match here difficult to predict. @@ -179,12 +163,8 @@ def renamed_dora_weights(k): ) to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name] to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name] - elif config.peft_type == PeftType.FOURIERFT: - to_return = {k: state_dict[k] for k in state_dict if "fourierft_" in k} elif config.peft_type == PeftType.XLORA: to_return = {k: state_dict[k] for k in state_dict if "internal_xlora_classifier" in k} - elif config.peft_type == PeftType.HRA: - to_return = {k: state_dict[k] for k in state_dict if "hra_" in k} elif config.peft_type == PeftType.VBLORA: to_return = {} # choose the most efficient dtype for indices @@ -208,8 +188,9 @@ def renamed_dora_weights(k): to_return["base_model.vblora_vector_bank." + adapter_name] = state_dict[ "base_model.vblora_vector_bank." 
+ adapter_name ] - elif config.peft_type == PeftType.BONE: - to_return = {k: state_dict[k] for k in state_dict if "bone_" in k} + elif config.peft_type in list(PeftType): + prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] + to_return = {k: state_dict[k] for k in state_dict if prefix in k} else: raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") @@ -363,7 +344,11 @@ def set_peft_model_state_dict( else: state_dict = peft_model_state_dict - if config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING: + if config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: + peft_model_state_dict = state_dict + elif config.peft_type == PeftType.XLORA: + peft_model_state_dict = state_dict + elif config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING: peft_model_state_dict = {} parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights: @@ -430,11 +415,6 @@ def renamed_dora_weights(k): return k peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()} - - elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: - peft_model_state_dict = state_dict - elif config.peft_type == PeftType.XLORA: - peft_model_state_dict = state_dict else: raise NotImplementedError diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index cbdf19a297..5ac6a66436 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -22,7 +22,7 @@ import torch from torch.testing import assert_close -from peft.mapping import get_peft_model +from peft import get_peft_model from peft.peft_model import PeftModel from peft.tuners.adaption_prompt import AdaptionPromptConfig from peft.utils.other import prepare_model_for_kbit_training diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 15ad34ea89..103fa6696d 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -55,11 +55,11 @@ inject_adapter_in_model, set_peft_model_state_dict, ) +from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING from peft.tuners.lora.config import CordaConfig from peft.tuners.lora.corda import preprocess_corda from peft.tuners.lora.layer import LoraLayer from peft.utils import infer_device -from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING from peft.utils.hotswap import hotswap_adapter diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index 4dc8832001..0b6f661aad 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -22,7 +22,7 @@ from parameterized import parameterized from torch.testing import assert_close -from peft.mapping import get_peft_model +from peft import get_peft_model from peft.peft_model import PeftModel from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit from peft.utils.other import WEIGHTS_NAME, prepare_model_for_kbit_training