diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 1ba77231e1..a43c171bbd 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -15,6 +15,7 @@ __version__ = "0.14.1.dev0" from .auto import ( + MODEL_TYPE_TO_PEFT_MODEL_MAPPING, AutoPeftModel, AutoPeftModelForCausalLM, AutoPeftModelForFeatureExtraction, @@ -25,12 +26,13 @@ ) from .config import PeftConfig, PromptLearningConfig from .mapping import ( - MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, + PEFT_TYPE_TO_MIXED_MODEL_MAPPING, + PEFT_TYPE_TO_TUNER_MAPPING, get_peft_config, - get_peft_model, inject_adapter_in_model, ) +from .mapping_func import get_peft_model from .mixed_model import PeftMixedModel from .peft_model import ( PeftModel, @@ -112,6 +114,8 @@ __all__ = [ "MODEL_TYPE_TO_PEFT_MODEL_MAPPING", "PEFT_TYPE_TO_CONFIG_MAPPING", + "PEFT_TYPE_TO_MIXED_MODEL_MAPPING", + "PEFT_TYPE_TO_TUNER_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", "AdaLoraConfig", "AdaLoraModel", diff --git a/src/peft/auto.py b/src/peft/auto.py index 18933be61e..4f890994d7 100644 --- a/src/peft/auto.py +++ b/src/peft/auto.py @@ -29,7 +29,6 @@ ) from .config import PeftConfig -from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING from .peft_model import ( PeftModel, PeftModelForCausalLM, @@ -43,6 +42,16 @@ from .utils.other import check_file_exists_on_hf_hub +MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = { + "SEQ_CLS": PeftModelForSequenceClassification, + "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, + "CAUSAL_LM": PeftModelForCausalLM, + "TOKEN_CLS": PeftModelForTokenClassification, + "QUESTION_ANS": PeftModelForQuestionAnswering, + "FEATURE_EXTRACTION": PeftModelForFeatureExtraction, +} + + class _BaseAutoPeftModel: _target_class = None _target_peft_class = None diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 67c29bd901..a92b28f858 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -14,123 +14,23 @@ from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import torch -from peft.tuners.xlora.model import XLoraModel - -from .config import PeftConfig -from .mixed_model import PeftMixedModel -from .peft_model import ( - PeftModel, - PeftModelForCausalLM, - PeftModelForFeatureExtraction, - PeftModelForQuestionAnswering, - PeftModelForSeq2SeqLM, - PeftModelForSequenceClassification, - PeftModelForTokenClassification, -) -from .tuners import ( - AdaLoraConfig, - AdaLoraModel, - AdaptionPromptConfig, - BOFTConfig, - BOFTModel, - BoneConfig, - BoneModel, - CPTConfig, - CPTEmbedding, - FourierFTConfig, - FourierFTModel, - HRAConfig, - HRAModel, - IA3Config, - IA3Model, - LNTuningConfig, - LNTuningModel, - LoHaConfig, - LoHaModel, - LoKrConfig, - LoKrModel, - LoraConfig, - LoraModel, - MultitaskPromptTuningConfig, - OFTConfig, - OFTModel, - PolyConfig, - PolyModel, - PrefixTuningConfig, - PromptEncoderConfig, - PromptTuningConfig, - VBLoRAConfig, - VBLoRAModel, - VeraConfig, - VeraModel, - XLoraConfig, -) -from .tuners.tuners_utils import BaseTuner, BaseTunerLayer -from .utils import _prepare_prompt_learning_config -from .utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING +from .utils import PeftType if TYPE_CHECKING: - from transformers import PreTrainedModel - + from .config import PeftConfig + from .tuners.tuners_utils import BaseTuner -MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = { - "SEQ_CLS": PeftModelForSequenceClassification, - "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, - 
"CAUSAL_LM": PeftModelForCausalLM, - "TOKEN_CLS": PeftModelForTokenClassification, - "QUESTION_ANS": PeftModelForQuestionAnswering, - "FEATURE_EXTRACTION": PeftModelForFeatureExtraction, -} -PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = { - "ADAPTION_PROMPT": AdaptionPromptConfig, - "PROMPT_TUNING": PromptTuningConfig, - "PREFIX_TUNING": PrefixTuningConfig, - "P_TUNING": PromptEncoderConfig, - "LORA": LoraConfig, - "LOHA": LoHaConfig, - "LORAPLUS": LoraConfig, - "LOKR": LoKrConfig, - "ADALORA": AdaLoraConfig, - "BOFT": BOFTConfig, - "IA3": IA3Config, - "MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig, - "OFT": OFTConfig, - "POLY": PolyConfig, - "LN_TUNING": LNTuningConfig, - "VERA": VeraConfig, - "FOURIERFT": FourierFTConfig, - "XLORA": XLoraConfig, - "HRA": HRAConfig, - "VBLORA": VBLoRAConfig, - "CPT": CPTConfig, - "BONE": BoneConfig, -} - -PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = { - "LORA": LoraModel, - "LOHA": LoHaModel, - "LOKR": LoKrModel, - "ADALORA": AdaLoraModel, - "BOFT": BOFTModel, - "IA3": IA3Model, - "OFT": OFTModel, - "POLY": PolyModel, - "LN_TUNING": LNTuningModel, - "VERA": VeraModel, - "FOURIERFT": FourierFTModel, - "XLORA": XLoraModel, - "HRA": HRAModel, - "VBLORA": VBLoRAModel, - "CPT": CPTEmbedding, - "BONE": BoneModel, -} +# these will be filled by the register_peft_method function +PEFT_TYPE_TO_CONFIG_MAPPING: dict[PeftType, type[PeftConfig]] = {} +PEFT_TYPE_TO_TUNER_MAPPING: dict[PeftType, type[BaseTuner]] = {} +PEFT_TYPE_TO_MIXED_MODEL_MAPPING: dict[PeftType, type[BaseTuner]] = {} +PEFT_TYPE_TO_PREFIX_MAPPING: dict[PeftType, str] = {} def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig: @@ -144,107 +44,6 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig: return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) -def get_peft_model( - model: PreTrainedModel, - peft_config: PeftConfig, - adapter_name: str = "default", - mixed: bool = False, - autocast_adapter_dtype: bool = True, - revision: Optional[str] = None, - low_cpu_mem_usage: bool = False, -) -> PeftModel | PeftMixedModel: - """ - Returns a Peft model object from a model and a config, where the model will be modified in-place. - - Args: - model ([`transformers.PreTrainedModel`]): - Model to be wrapped. - peft_config ([`PeftConfig`]): - Configuration object containing the parameters of the Peft model. - adapter_name (`str`, `optional`, defaults to `"default"`): - The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). - mixed (`bool`, `optional`, defaults to `False`): - Whether to allow mixing different (compatible) adapter types. - autocast_adapter_dtype (`bool`, *optional*): - Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights - using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect - select PEFT tuners. - revision (`str`, `optional`, defaults to `main`): - The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for - the base model - low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): - Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as - False if you intend on training the model, unless the adapter weights will be replaced by different weights - before training starts. 
- """ - model_config = BaseTuner.get_model_config(model) - old_name = peft_config.base_model_name_or_path - new_name = model.__dict__.get("name_or_path", None) - peft_config.base_model_name_or_path = new_name - - # Especially in notebook environments there could be a case that a user wants to experiment with different - # configuration values. However, it is likely that there won't be any changes for new configs on an already - # initialized PEFT model. The best we can do is warn the user about it. - if any(isinstance(module, BaseTunerLayer) for module in model.modules()): - warnings.warn( - "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a " - "different config, make sure to call `.unload()` before." - ) - - if (old_name is not None) and (old_name != new_name): - warnings.warn( - f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. " - "Please ensure that the correct base model is loaded when loading this checkpoint." - ) - - if revision is not None: - if peft_config.revision is not None and peft_config.revision != revision: - warnings.warn( - f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}" - ) - peft_config.revision = revision - - if ( - (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"])) - and (peft_config.init_lora_weights == "eva") - and not low_cpu_mem_usage - ): - warnings.warn( - "lora with eva initialization used with low_cpu_mem_usage=False. " - "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization." - ) - - prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type) - if prefix and adapter_name in prefix: - warnings.warn( - f"Adapter name {adapter_name} should not be contained in the prefix {prefix}." - "This may lead to reinitialization of the adapter weights during loading." - ) - - if mixed: - # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it - return PeftMixedModel(model, peft_config, adapter_name=adapter_name) - - if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: - return PeftModel( - model, - peft_config, - adapter_name=adapter_name, - autocast_adapter_dtype=autocast_adapter_dtype, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - if peft_config.is_prompt_learning: - peft_config = _prepare_prompt_learning_config(peft_config, model_config) - return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]( - model, - peft_config, - adapter_name=adapter_name, - autocast_adapter_dtype=autocast_adapter_dtype, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - def inject_adapter_in_model( peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False ) -> torch.nn.Module: diff --git a/src/peft/mapping_func.py b/src/peft/mapping_func.py new file mode 100644 index 0000000000..279ce499d1 --- /dev/null +++ b/src/peft/mapping_func.py @@ -0,0 +1,129 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import Optional + +from transformers import PreTrainedModel + +from .auto import MODEL_TYPE_TO_PEFT_MODEL_MAPPING +from .config import PeftConfig +from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING +from .mixed_model import PeftMixedModel +from .peft_model import PeftModel +from .tuners.tuners_utils import BaseTuner, BaseTunerLayer +from .utils import _prepare_prompt_learning_config + + +def get_peft_model( + model: PreTrainedModel, + peft_config: PeftConfig, + adapter_name: str = "default", + mixed: bool = False, + autocast_adapter_dtype: bool = True, + revision: Optional[str] = None, + low_cpu_mem_usage: bool = False, +) -> PeftModel | PeftMixedModel: + """ + Returns a Peft model object from a model and a config, where the model will be modified in-place. + + Args: + model ([`transformers.PreTrainedModel`]): + Model to be wrapped. + peft_config ([`PeftConfig`]): + Configuration object containing the parameters of the Peft model. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + mixed (`bool`, `optional`, defaults to `False`): + Whether to allow mixing different (compatible) adapter types. + autocast_adapter_dtype (`bool`, *optional*): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights + using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect + select PEFT tuners. + revision (`str`, `optional`, defaults to `main`): + The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for + the base model + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Leave this setting as + False if you intend on training the model, unless the adapter weights will be replaced by different weights + before training starts. + """ + model_config = BaseTuner.get_model_config(model) + old_name = peft_config.base_model_name_or_path + new_name = model.__dict__.get("name_or_path", None) + peft_config.base_model_name_or_path = new_name + + # Especially in notebook environments there could be a case that a user wants to experiment with different + # configuration values. However, it is likely that there won't be any changes for new configs on an already + # initialized PEFT model. The best we can do is warn the user about it. + if any(isinstance(module, BaseTunerLayer) for module in model.modules()): + warnings.warn( + "You are trying to modify a model with PEFT for a second time. If you want to reload the model with a " + "different config, make sure to call `.unload()` before." + ) + + if (old_name is not None) and (old_name != new_name): + warnings.warn( + f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. " + "Please ensure that the correct base model is loaded when loading this checkpoint." 
+ ) + + if revision is not None: + if peft_config.revision is not None and peft_config.revision != revision: + warnings.warn( + f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}" + ) + peft_config.revision = revision + + if ( + (isinstance(peft_config, PEFT_TYPE_TO_CONFIG_MAPPING["LORA"])) + and (peft_config.init_lora_weights == "eva") + and not low_cpu_mem_usage + ): + warnings.warn( + "lora with eva initialization used with low_cpu_mem_usage=False. " + "Setting low_cpu_mem_usage=True can improve the maximum batch size possible for eva initialization." + ) + + prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(peft_config.peft_type) + if prefix and adapter_name in prefix: + warnings.warn( + f"Adapter name {adapter_name} should not be contained in the prefix {prefix}." + "This may lead to reinitialization of the adapter weights during loading." + ) + + if mixed: + # note: PeftMixedModel does not support autocast_adapter_dtype, so don't pass it + return PeftMixedModel(model, peft_config, adapter_name=adapter_name) + + if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: + return PeftModel( + model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + if peft_config.is_prompt_learning: + peft_config = _prepare_prompt_learning_config(peft_config, model_config) + return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type]( + model, + peft_config, + adapter_name=adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index 098bd0d4fd..c1f11c7aa9 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -27,25 +27,8 @@ from .config import PeftConfig from .peft_model import PeftModel -from .tuners import ( - AdaLoraModel, - IA3Model, - LoHaModel, - LoKrModel, - LoraModel, - MixedModel, -) -from .tuners.mixed import COMPATIBLE_TUNER_TYPES -from .utils import PeftType, _set_adapter, _set_trainable - - -PEFT_TYPE_TO_MODEL_MAPPING = { - PeftType.LORA: LoraModel, - PeftType.LOHA: LoHaModel, - PeftType.LOKR: LoKrModel, - PeftType.ADALORA: AdaLoraModel, - PeftType.IA3: IA3Model, -} +from .tuners import MixedModel +from .utils import _set_adapter, _set_trainable def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None: @@ -72,6 +55,8 @@ def make_inputs_require_grad(module, input, output): def _check_config_compatible(peft_config: PeftConfig) -> None: + from .tuners.mixed import COMPATIBLE_TUNER_TYPES + if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES: raise ValueError( f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. " @@ -440,7 +425,7 @@ def from_pretrained( Additional keyword arguments passed along to the specific PEFT configuration class. 
""" # note: adapted from PeftModel.from_pretrained - from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING + from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_MIXED_MODEL_MAPPING # load the config if config is None: @@ -459,7 +444,7 @@ def from_pretrained( raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") # note: this is different from PeftModel.from_pretrained - if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING: + if config.peft_type not in PEFT_TYPE_TO_MIXED_MODEL_MAPPING: raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.") if (getattr(model, "hf_device_map", None) is not None) and len( diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index a3c7535052..62061a84e8 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -38,36 +38,13 @@ from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput from transformers.utils import PushToHubMixin -from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer +from peft.utils.constants import DUMMY_MODEL_CONFIG from peft.utils.integrations import init_empty_weights from . import __version__ from .config import PeftConfig -from .tuners import ( - AdaLoraModel, - AdaptionPromptModel, - BOFTModel, - BoneModel, - CPTEmbedding, - FourierFTModel, - HRAModel, - IA3Model, - LNTuningModel, - LoHaModel, - LoKrModel, - LoraModel, - MultitaskPromptEmbedding, - OFTModel, - PolyModel, - PrefixEncoder, - PromptEmbedding, - PromptEncoder, - VBLoRAModel, - VeraModel, - XLoraConfig, - XLoraModel, -) -from .tuners.tuners_utils import BaseTuner, BaseTunerLayer +from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING, PEFT_TYPE_TO_TUNER_MAPPING from .utils import ( SAFETENSORS_WEIGHTS_NAME, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, @@ -88,30 +65,6 @@ ) -PEFT_TYPE_TO_MODEL_MAPPING = { - PeftType.LORA: LoraModel, - PeftType.LOHA: LoHaModel, - PeftType.LOKR: LoKrModel, - PeftType.PROMPT_TUNING: PromptEmbedding, - PeftType.P_TUNING: PromptEncoder, - PeftType.PREFIX_TUNING: PrefixEncoder, - PeftType.ADALORA: AdaLoraModel, - PeftType.BOFT: BOFTModel, - PeftType.ADAPTION_PROMPT: AdaptionPromptModel, - PeftType.IA3: IA3Model, - PeftType.OFT: OFTModel, - PeftType.POLY: PolyModel, - PeftType.LN_TUNING: LNTuningModel, - PeftType.VERA: VeraModel, - PeftType.FOURIERFT: FourierFTModel, - PeftType.XLORA: XLoraModel, - PeftType.HRA: HRAModel, - PeftType.VBLORA: VBLoRAModel, - PeftType.CPT: CPTEmbedding, - PeftType.BONE: BoneModel, -} - - class PeftModel(PushToHubMixin, torch.nn.Module): """ Base model encompassing various Peft methods. @@ -171,7 +124,7 @@ def __init__( self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage) else: self._peft_config = None - cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] + cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] ctx = init_empty_weights if low_cpu_mem_usage else nullcontext with ctx(): self.base_model = cls(model, {adapter_name: peft_config}, adapter_name) @@ -474,7 +427,8 @@ def from_pretrained( kwargs: (`optional`): Additional keyword arguments passed along to the specific PEFT configuration class. 
""" - from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING + from .auto import MODEL_TYPE_TO_PEFT_MODEL_MAPPING + from .tuners import XLoraConfig, XLoraModel # load the config if config is None: @@ -660,20 +614,17 @@ def _setup_prompt_encoder(self, adapter_name: str): break self.word_embeddings = word_embeddings + model_cls = PEFT_TYPE_TO_TUNER_MAPPING[config.peft_type] - if config.peft_type == PeftType.PROMPT_TUNING: - prompt_encoder = PromptEmbedding(config, self.word_embeddings) - elif config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: - prompt_encoder = MultitaskPromptEmbedding(config, self.word_embeddings) + if config.peft_type in (PeftType.PROMPT_TUNING, PeftType.MULTITASK_PROMPT_TUNING, PeftType.CPT): + prompt_encoder = model_cls(config, self.word_embeddings) elif config.peft_type == PeftType.P_TUNING: - prompt_encoder = PromptEncoder(config) + prompt_encoder = model_cls(config) elif config.peft_type == PeftType.PREFIX_TUNING: # prefix tuning now uses Cache but that won't work with gradient checkpointing if any(getattr(module, "gradient_checkpointing", False) for module in self.get_base_model().modules()): raise ValueError("Prefix tuning does not work with gradient checkpointing.") - prompt_encoder = PrefixEncoder(config) - elif config.peft_type == PeftType.CPT: - prompt_encoder = CPTEmbedding(config, self.word_embeddings) + prompt_encoder = model_cls(config) else: raise ValueError("Not supported") @@ -711,11 +662,13 @@ def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor: prompt_tokens = ( self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1).to(prompt_encoder.embedding.weight.device) ) + peft_type = self.peft_config[adapter_name].peft_type if self.peft_config[adapter_name].peft_type == PeftType.PREFIX_TUNING: prompt_tokens = prompt_tokens[:, : self.peft_config[adapter_name].num_virtual_tokens] if self.peft_config[adapter_name].peft_type == PeftType.MULTITASK_PROMPT_TUNING: - prompt_embeddings = super(MultitaskPromptEmbedding, prompt_encoder).forward(prompt_tokens) + prompt_embedding_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_type] + prompt_embeddings = super(prompt_embedding_cls, prompt_encoder).forward(prompt_tokens) else: prompt_embeddings = prompt_encoder(prompt_tokens) @@ -1794,9 +1747,7 @@ def forward( inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) - def _cpt_forward( - self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs - ): + def _cpt_forward(self, input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs): # Extract labels from kwargs labels = kwargs.pop("labels") device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0] @@ -1846,7 +1797,8 @@ def _cpt_forward( return base_model_output else: # Calculate the loss using the custom CPT loss function - base_model_output = CPTEmbedding.calculate_loss( + cpt_embedding = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] + base_model_output = cpt_embedding.calculate_loss( base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"] ) return base_model_output diff --git a/src/peft/tuners/adalora/__init__.py b/src/peft/tuners/adalora/__init__.py index bdec26b522..64d5f3e5ce 100644 --- a/src/peft/tuners/adalora/__init__.py +++ b/src/peft/tuners/adalora/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method from .config import AdaLoraConfig from .gptq import SVDQuantLinear @@ -23,6 +24,11 @@ __all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "RankAllocator", "SVDLinear", "SVDQuantLinear"] +register_peft_method( + name="adalora", config_cls=AdaLoraConfig, model_cls=AdaLoraModel, prefix="lora_", is_mixed_compatible=True +) + + def __getattr__(name): if (name == "SVDLinear8bitLt") and is_bnb_available(): from .bnb import SVDLinear8bitLt diff --git a/src/peft/tuners/adaption_prompt/__init__.py b/src/peft/tuners/adaption_prompt/__init__.py index 826e115d3c..68882a2226 100644 --- a/src/peft/tuners/adaption_prompt/__init__.py +++ b/src/peft/tuners/adaption_prompt/__init__.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import AdaptionPromptConfig from .layer import AdaptedAttention from .model import AdaptionPromptModel __all__ = ["AdaptedAttention", "AdaptionPromptConfig", "AdaptionPromptModel"] + +register_peft_method(name="adaption_prompt", config_cls=AdaptionPromptConfig, model_cls=AdaptionPromptModel) diff --git a/src/peft/tuners/boft/__init__.py b/src/peft/tuners/boft/__init__.py index 5b72b73951..c84b8358da 100644 --- a/src/peft/tuners/boft/__init__.py +++ b/src/peft/tuners/boft/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import BOFTConfig from .layer import BOFTLayer from .model import BOFTModel __all__ = ["BOFTConfig", "BOFTLayer", "BOFTModel"] + +register_peft_method(name="boft", config_cls=BOFTConfig, model_cls=BOFTModel) diff --git a/src/peft/tuners/bone/__init__.py b/src/peft/tuners/bone/__init__.py index d2a41552d5..f131e8c17d 100644 --- a/src/peft/tuners/bone/__init__.py +++ b/src/peft/tuners/bone/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import BoneConfig from .layer import BoneLayer, BoneLinear from .model import BoneModel __all__ = ["BoneConfig", "BoneLayer", "BoneLinear", "BoneModel"] + +register_peft_method(name="bone", config_cls=BoneConfig, model_cls=BoneModel) diff --git a/src/peft/tuners/cpt/__init__.py b/src/peft/tuners/cpt/__init__.py index f5018f89b1..fcd4de8598 100644 --- a/src/peft/tuners/cpt/__init__.py +++ b/src/peft/tuners/cpt/__init__.py @@ -13,8 +13,12 @@ # limitations under the License. +from peft.utils import register_peft_method + from .config import CPTConfig from .model import CPTEmbedding __all__ = ["CPTConfig", "CPTEmbedding"] + +register_peft_method(name="cpt", config_cls=CPTConfig, model_cls=CPTEmbedding) diff --git a/src/peft/tuners/fourierft/__init__.py b/src/peft/tuners/fourierft/__init__.py index 7646d9497c..dfe3f5d89e 100644 --- a/src/peft/tuners/fourierft/__init__.py +++ b/src/peft/tuners/fourierft/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from peft.utils import register_peft_method + from .config import FourierFTConfig from .layer import FourierFTLayer, FourierFTLinear from .model import FourierFTModel __all__ = ["FourierFTConfig", "FourierFTLayer", "FourierFTLinear", "FourierFTModel"] + +register_peft_method(name="fourierft", model_cls=FourierFTModel, config_cls=FourierFTConfig) diff --git a/src/peft/tuners/hra/__init__.py b/src/peft/tuners/hra/__init__.py index 11902f2ec1..8f5f6a5443 100644 --- a/src/peft/tuners/hra/__init__.py +++ b/src/peft/tuners/hra/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import HRAConfig from .layer import HRAConv2d, HRALayer, HRALinear from .model import HRAModel __all__ = ["HRAConfig", "HRAConv2d", "HRALayer", "HRALinear", "HRAModel"] + +register_peft_method(name="hra", config_cls=HRAConfig, model_cls=HRAModel) diff --git a/src/peft/tuners/ia3/__init__.py b/src/peft/tuners/ia3/__init__.py index f88f8411a0..21cab4d6d8 100644 --- a/src/peft/tuners/ia3/__init__.py +++ b/src/peft/tuners/ia3/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method from .config import IA3Config from .layer import Conv2d, Conv3d, IA3Layer, Linear @@ -21,6 +22,8 @@ __all__ = ["Conv2d", "Conv3d", "IA3Config", "IA3Layer", "IA3Model", "Linear"] +register_peft_method(name="ia3", config_cls=IA3Config, model_cls=IA3Model, is_mixed_compatible=True) + def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): diff --git a/src/peft/tuners/ln_tuning/__init__.py b/src/peft/tuners/ln_tuning/__init__.py index afaae73a44..8f90a8fb05 100644 --- a/src/peft/tuners/ln_tuning/__init__.py +++ b/src/peft/tuners/ln_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import LNTuningConfig from .model import LNTuningModel __all__ = ["LNTuningConfig", "LNTuningModel"] + +register_peft_method(name="ln_tuning", config_cls=LNTuningConfig, model_cls=LNTuningModel) diff --git a/src/peft/tuners/loha/__init__.py b/src/peft/tuners/loha/__init__.py index cc54826fe4..70dd1545bb 100644 --- a/src/peft/tuners/loha/__init__.py +++ b/src/peft/tuners/loha/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import LoHaConfig from .layer import Conv2d, Linear, LoHaLayer from .model import LoHaModel __all__ = ["Conv2d", "Linear", "LoHaConfig", "LoHaLayer", "LoHaModel"] + +register_peft_method(name="loha", config_cls=LoHaConfig, model_cls=LoHaModel, prefix="hada_", is_mixed_compatible=True) diff --git a/src/peft/tuners/lokr/__init__.py b/src/peft/tuners/lokr/__init__.py index 1b0c9f5438..f4fe0e92c6 100644 --- a/src/peft/tuners/lokr/__init__.py +++ b/src/peft/tuners/lokr/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
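# Illustrative sketch, not part of the patch: adapter prefixes previously hard-coded
# in utils.constants now come from the registry. Most methods get the default
# "<name>_" prefix; a few register their historical prefix explicitly so checkpoint
# key names stay unchanged.
from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
from peft.utils import PeftType

print(PEFT_TYPE_TO_PREFIX_MAPPING[PeftType.BOFT])     # boft_  (default, derived from the name)
print(PEFT_TYPE_TO_PREFIX_MAPPING[PeftType.ADALORA])  # lora_  (registered explicitly)
print(PEFT_TYPE_TO_PREFIX_MAPPING[PeftType.LOHA])     # hada_  (registered explicitly)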
+from peft.utils import register_peft_method + from .config import LoKrConfig from .layer import Conv2d, Linear, LoKrLayer from .model import LoKrModel __all__ = ["Conv2d", "Linear", "LoKrConfig", "LoKrLayer", "LoKrModel"] + +register_peft_method(name="lokr", config_cls=LoKrConfig, model_cls=LoKrModel, is_mixed_compatible=True) diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index de1f423cd0..779d4eec79 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available +from peft.utils import register_peft_method from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig from .eva import get_eva_state_dict, initialize_lora_eva_weights @@ -37,6 +38,8 @@ "initialize_lora_eva_weights", ] +register_peft_method(name="lora", config_cls=LoraConfig, model_cls=LoraModel, is_mixed_compatible=True) + def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): diff --git a/src/peft/tuners/multitask_prompt_tuning/__init__.py b/src/peft/tuners/multitask_prompt_tuning/__init__.py index d9f98d4a76..fe692a9337 100644 --- a/src/peft/tuners/multitask_prompt_tuning/__init__.py +++ b/src/peft/tuners/multitask_prompt_tuning/__init__.py @@ -12,8 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit from .model import MultitaskPromptEmbedding __all__ = ["MultitaskPromptEmbedding", "MultitaskPromptTuningConfig", "MultitaskPromptTuningInit"] + +register_peft_method( + name="multitask_prompt_tuning", config_cls=MultitaskPromptTuningConfig, model_cls=MultitaskPromptEmbedding +) diff --git a/src/peft/tuners/oft/__init__.py b/src/peft/tuners/oft/__init__.py index 1b8bdaa6cd..5df395090f 100644 --- a/src/peft/tuners/oft/__init__.py +++ b/src/peft/tuners/oft/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import OFTConfig from .layer import Conv2d, Linear, OFTLayer from .model import OFTModel __all__ = ["Conv2d", "Linear", "OFTConfig", "OFTLayer", "OFTModel"] + +register_peft_method(name="oft", config_cls=OFTConfig, model_cls=OFTModel) diff --git a/src/peft/tuners/p_tuning/__init__.py b/src/peft/tuners/p_tuning/__init__.py index 7dd3a6ba3e..9195c0d75d 100644 --- a/src/peft/tuners/p_tuning/__init__.py +++ b/src/peft/tuners/p_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import PromptEncoderConfig, PromptEncoderReparameterizationType from .model import PromptEncoder __all__ = ["PromptEncoder", "PromptEncoderConfig", "PromptEncoderReparameterizationType"] + +register_peft_method(name="p_tuning", config_cls=PromptEncoderConfig, model_cls=PromptEncoder) diff --git a/src/peft/tuners/poly/__init__.py b/src/peft/tuners/poly/__init__.py index b0f368695e..1c18933eba 100644 --- a/src/peft/tuners/poly/__init__.py +++ b/src/peft/tuners/poly/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from peft.utils import register_peft_method + from .config import PolyConfig from .layer import Linear, PolyLayer from .model import PolyModel __all__ = ["Linear", "PolyConfig", "PolyLayer", "PolyModel"] + +register_peft_method(name="poly", config_cls=PolyConfig, model_cls=PolyModel) diff --git a/src/peft/tuners/prefix_tuning/__init__.py b/src/peft/tuners/prefix_tuning/__init__.py index 8d2c7efa64..939f74d3f6 100644 --- a/src/peft/tuners/prefix_tuning/__init__.py +++ b/src/peft/tuners/prefix_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import PrefixTuningConfig from .model import PrefixEncoder __all__ = ["PrefixEncoder", "PrefixTuningConfig"] + +register_peft_method(name="prefix_tuning", config_cls=PrefixTuningConfig, model_cls=PrefixEncoder) diff --git a/src/peft/tuners/prompt_tuning/__init__.py b/src/peft/tuners/prompt_tuning/__init__.py index 6a9ceb48fa..c99ca6a26f 100644 --- a/src/peft/tuners/prompt_tuning/__init__.py +++ b/src/peft/tuners/prompt_tuning/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import PromptTuningConfig, PromptTuningInit from .model import PromptEmbedding __all__ = ["PromptEmbedding", "PromptTuningConfig", "PromptTuningInit"] + +register_peft_method(name="prompt_tuning", config_cls=PromptTuningConfig, model_cls=PromptEmbedding) diff --git a/src/peft/tuners/vblora/__init__.py b/src/peft/tuners/vblora/__init__.py index 558b9d80fd..8e71a08461 100644 --- a/src/peft/tuners/vblora/__init__.py +++ b/src/peft/tuners/vblora/__init__.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import VBLoRAConfig from .layer import Linear, VBLoRALayer from .model import VBLoRAModel __all__ = ["Linear", "VBLoRAConfig", "VBLoRALayer", "VBLoRAModel"] + +register_peft_method(name="vblora", config_cls=VBLoRAConfig, model_cls=VBLoRAModel) diff --git a/src/peft/tuners/vera/__init__.py b/src/peft/tuners/vera/__init__.py index 268ce3ef93..25c4a96619 100644 --- a/src/peft/tuners/vera/__init__.py +++ b/src/peft/tuners/vera/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method from .config import VeraConfig from .layer import Linear, VeraLayer @@ -22,6 +23,9 @@ __all__ = ["Linear", "VeraConfig", "VeraLayer", "VeraModel"] +register_peft_method(name="vera", config_cls=VeraConfig, model_cls=VeraModel, prefix="vera_lambda_") + + def __getattr__(name): if (name == "Linear8bitLt") and is_bnb_available(): from .bnb import Linear8bitLt diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py index e129620ce8..6b41937a16 100644 --- a/src/peft/tuners/vera/model.py +++ b/src/peft/tuners/vera/model.py @@ -99,7 +99,7 @@ class VeraModel(BaseTuner): - **peft_config** ([`VeraConfig`]): The configuration of the Vera model. 
""" - prefix: str = "vera_lambda" + prefix: str = "vera_lambda_" def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) diff --git a/src/peft/tuners/xlora/__init__.py b/src/peft/tuners/xlora/__init__.py index df41e1e611..6eae1f779b 100644 --- a/src/peft/tuners/xlora/__init__.py +++ b/src/peft/tuners/xlora/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from peft.utils import register_peft_method + from .config import XLoraConfig from .model import XLoraModel __all__ = ["XLoraConfig", "XLoraModel"] + +register_peft_method(name="xlora", config_cls=XLoraConfig, model_cls=XLoraModel) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 7e5acfc1a2..99a86e5b23 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -46,7 +46,7 @@ shift_tokens_right, transpose, ) -from .peft_types import PeftType, TaskType +from .peft_types import PeftType, TaskType, register_peft_method from .save_and_load import get_peft_model_state_dict, load_peft_weights, set_peft_model_state_dict @@ -84,6 +84,7 @@ "load_peft_weights", "map_cache_to_layer_device_map", "prepare_model_for_kbit_training", + "register_peft_method", "replace_lora_weights_loftq", "set_peft_model_state_dict", "shift_tokens_right", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 986be87bd6..e7c305926a 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -15,8 +15,6 @@ import torch from transformers import BloomPreTrainedModel -from .peft_types import PeftType - # needed for prefix-tuning of bloom model def bloom_model_postprocess_past_key_value(past_key_values): @@ -286,23 +284,6 @@ def starcoder_model_postprocess_past_key_value(past_key_values): "qwen2": ["q_proj", "v_proj"], } -PEFT_TYPE_TO_PREFIX_MAPPING = { - PeftType.IA3: "ia3_", - PeftType.LORA: "lora_", - PeftType.ADALORA: "lora_", - PeftType.LOHA: "hada_", - PeftType.LOKR: "lokr_", - PeftType.OFT: "oft_", - PeftType.POLY: "poly_", - PeftType.BOFT: "boft_", - PeftType.LN_TUNING: "ln_tuning_", - PeftType.VERA: "vera_lambda_", - PeftType.FOURIERFT: "fourierft_", - PeftType.HRA: "hra_", - PeftType.VBLORA: "vblora_", - PeftType.BONE: "bone_", -} - WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" diff --git a/src/peft/utils/hotswap.py b/src/peft/utils/hotswap.py index 3ff7caacce..b09249409c 100644 --- a/src/peft/utils/hotswap.py +++ b/src/peft/utils/hotswap.py @@ -18,9 +18,8 @@ import torch from peft.config import PeftConfig -from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING +from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING -from .constants import PEFT_TYPE_TO_PREFIX_MAPPING from .other import infer_device from .peft_types import PeftType from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index d2f1539074..aa7c548f03 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -13,6 +13,7 @@ # limitations under the License. 
import enum +from typing import Optional class PeftType(str, enum.Enum): @@ -84,3 +85,80 @@ class TaskType(str, enum.Enum): TOKEN_CLS = "TOKEN_CLS" QUESTION_ANS = "QUESTION_ANS" FEATURE_EXTRACTION = "FEATURE_EXTRACTION" + + +def register_peft_method( + *, name: str, config_cls, model_cls, prefix: Optional[str] = None, is_mixed_compatible=False +) -> None: + """ + Function to register a finetuning method like LoRA to be available in PEFT. + + This method takes care of registering the PEFT method's configuration class, the model class, and optionally the + prefix. + + Args: + name (str): + The name of the PEFT method. It must be unique. + config_cls: + The configuration class of the PEFT method. + model_cls: + The model class of the PEFT method. + prefix (Optional[str], optional): + The prefix of the PEFT method. It should be unique. If not provided, the name of the PEFT method is used as + the prefix. + is_mixed_compatible (bool, optional): + Whether the PEFT method is compatible with `PeftMixedModel`. If you're not sure, leave it as False + (default). + + Example: + + ```py + # inside of peft/tuners/my_peft_method/__init__.py + from peft.utils import register_peft_method + + register_peft_method(name="my_peft_method", config_cls=MyConfig, model_cls=MyModel) + ``` + """ + from peft.mapping import ( + PEFT_TYPE_TO_CONFIG_MAPPING, + PEFT_TYPE_TO_MIXED_MODEL_MAPPING, + PEFT_TYPE_TO_PREFIX_MAPPING, + PEFT_TYPE_TO_TUNER_MAPPING, + ) + + if name.endswith("_"): + raise ValueError(f"Please pass the name of the PEFT method without '_' suffix, got {name}.") + + if not name.islower(): + raise ValueError(f"The name of the PEFT method should be in lower case letters, got {name}.") + + if name.upper() not in list(PeftType): + raise ValueError(f"Unknown PEFT type {name.upper()}, please add an entry to peft.utils.peft_types.PeftType.") + + peft_type = getattr(PeftType, name.upper()) + + # model_cls can be None for prompt learning methods, which don't have dedicated model classes + if prefix is None: + prefix = name + "_" + + if ( + (peft_type in PEFT_TYPE_TO_CONFIG_MAPPING) + or (peft_type in PEFT_TYPE_TO_TUNER_MAPPING) + or (peft_type in PEFT_TYPE_TO_MIXED_MODEL_MAPPING) + ): + raise KeyError(f"There is already PEFT method called '{name}', please choose a unique name.") + + if prefix in PEFT_TYPE_TO_PREFIX_MAPPING: + raise KeyError(f"There is already a prefix called '{prefix}', please choose a unique prefix.") + + model_cls_prefix = getattr(model_cls, "prefix", None) + if (model_cls_prefix is not None) and (model_cls_prefix != prefix): + raise ValueError( + f"Inconsistent prefixes found: '{prefix}' and '{model_cls_prefix}' (they should be the same)." 
+ ) + + PEFT_TYPE_TO_PREFIX_MAPPING[peft_type] = prefix + PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] = config_cls + PEFT_TYPE_TO_TUNER_MAPPING[peft_type] = model_cls + if is_mixed_compatible: + PEFT_TYPE_TO_MIXED_MODEL_MAPPING[peft_type] = model_cls diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 44a1cad5ff..5337afeb9f 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -24,7 +24,8 @@ from packaging import version from safetensors.torch import load_file as safe_load_file -from .constants import PEFT_TYPE_TO_PREFIX_MAPPING +from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING + from .other import ( EMBEDDING_LAYER_NAMES, SAFETENSORS_WEIGHTS_NAME, @@ -133,12 +134,6 @@ def renamed_dora_weights(k): else: raise NotImplementedError - elif config.peft_type == PeftType.LOHA: - to_return = {k: state_dict[k] for k in state_dict if "hada_" in k} - - elif config.peft_type == PeftType.LOKR: - to_return = {k: state_dict[k] for k in state_dict if "lokr_" in k} - elif config.peft_type == PeftType.ADAPTION_PROMPT: to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} @@ -155,20 +150,9 @@ def renamed_dora_weights(k): prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) to_return["prompt_embeddings"] = prompt_embeddings - elif config.peft_type == PeftType.IA3: - to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} - - elif config.peft_type == PeftType.OFT: - to_return = {k: state_dict[k] for k in state_dict if "oft_" in k} - - elif config.peft_type == PeftType.POLY: - to_return = {k: state_dict[k] for k in state_dict if "poly_" in k} - - elif config.peft_type == PeftType.LN_TUNING: - to_return = {k: state_dict[k] for k in state_dict if "ln_tuning_" in k} - elif config.peft_type == PeftType.VERA: - to_return = {k: state_dict[k] for k in state_dict if "vera_lambda_" in k} + vera_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] + to_return = {k: state_dict[k] for k in state_dict if vera_prefix in k} if config.save_projection: # TODO: adding vera_A and vera_B to `self.get_base_layer` would # make name to match here difficult to predict. @@ -179,12 +163,8 @@ def renamed_dora_weights(k): ) to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name] to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name] - elif config.peft_type == PeftType.FOURIERFT: - to_return = {k: state_dict[k] for k in state_dict if "fourierft_" in k} elif config.peft_type == PeftType.XLORA: to_return = {k: state_dict[k] for k in state_dict if "internal_xlora_classifier" in k} - elif config.peft_type == PeftType.HRA: - to_return = {k: state_dict[k] for k in state_dict if "hra_" in k} elif config.peft_type == PeftType.VBLORA: to_return = {} # choose the most efficient dtype for indices @@ -208,8 +188,9 @@ def renamed_dora_weights(k): to_return["base_model.vblora_vector_bank." + adapter_name] = state_dict[ "base_model.vblora_vector_bank." 
+ adapter_name ] - elif config.peft_type == PeftType.BONE: - to_return = {k: state_dict[k] for k in state_dict if "bone_" in k} + elif config.peft_type in list(PeftType): + prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] + to_return = {k: state_dict[k] for k in state_dict if prefix in k} else: raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") @@ -363,7 +344,11 @@ def set_peft_model_state_dict( else: state_dict = peft_model_state_dict - if config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING: + if config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: + peft_model_state_dict = state_dict + elif config.peft_type == PeftType.XLORA: + peft_model_state_dict = state_dict + elif config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING: peft_model_state_dict = {} parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights: @@ -430,11 +415,6 @@ def renamed_dora_weights(k): return k peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()} - - elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: - peft_model_state_dict = state_dict - elif config.peft_type == PeftType.XLORA: - peft_model_state_dict = state_dict else: raise NotImplementedError diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index cbdf19a297..5ac6a66436 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -22,7 +22,7 @@ import torch from torch.testing import assert_close -from peft.mapping import get_peft_model +from peft import get_peft_model from peft.peft_model import PeftModel from peft.tuners.adaption_prompt import AdaptionPromptConfig from peft.utils.other import prepare_model_for_kbit_training diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 15ad34ea89..103fa6696d 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -55,11 +55,11 @@ inject_adapter_in_model, set_peft_model_state_dict, ) +from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING from peft.tuners.lora.config import CordaConfig from peft.tuners.lora.corda import preprocess_corda from peft.tuners.lora.layer import LoraLayer from peft.utils import infer_device -from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING from peft.utils.hotswap import hotswap_adapter diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index 4dc8832001..0b6f661aad 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -22,7 +22,7 @@ from parameterized import parameterized from torch.testing import assert_close -from peft.mapping import get_peft_model +from peft import get_peft_model from peft.peft_model import PeftModel from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit from peft.utils.other import WEIGHTS_NAME, prepare_model_for_kbit_training
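# Illustrative sketch, not part of the patch (the state_dict keys are made up):
# with every prefix available in PEFT_TYPE_TO_PREFIX_MAPPING, get_peft_model_state_dict
# can fall back to one generic filter instead of a separate elif branch per method.
from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
from peft.utils import PeftType

state_dict = {
    "base_model.model.layers.0.attn.bone_block.default": "adapter weight",  # kept
    "base_model.model.layers.0.attn.weight": "base weight",                 # dropped
}
prefix = PEFT_TYPE_TO_PREFIX_MAPPING[PeftType.BONE]  # "bone_"
to_return = {k: v for k, v in state_dict.items() if prefix in k}
print(list(to_return))  # only the key containing "bone_"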