diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 405277b6b5..277996b376 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -979,18 +979,21 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None: # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave # For now, empty layers_pattern means any layer pattern is ok if layers_pattern is None or len(layers_pattern) == 0: - layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) + match = re.match(r".*\.[^.]*\.(?P\d+)\.", key) else: layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern for pattern in layers_pattern: - layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) - if layer_index is not None: + match = re.match(rf"(.*\.)?{pattern}\.(?P\d+)\.", key) + if match is not None: break + if match: + layer_index = match.groupdict().get("idx") + if layer_index is None: target_module_found = False else: - layer_index = int(layer_index.group(1)) + layer_index = int(layer_index) if isinstance(layer_indexes, int): target_module_found = layer_index == layer_indexes else: diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 06a47deb26..e713828c22 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python3 - -# coding=utf-8 # Copyright 2023-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -92,6 +89,10 @@ ("foo.bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True), ("foo.bar.1.baz", ["baz", "spam"], [1], ["bar"], True), ("foo.bar.1.baz", ["baz", "spam"], [0, 1, 2], ["bar"], True), + ("bar.1.baz", ["baz"], [0, 2], ["bar"], False), + ("bar.1.baz", ["baz"], [0, 1, 2], ["foo"], False), + ("bar.1.baz", ["baz"], [0, 2], ["bar"], False), + ("bar.1.baz", ["baz"], [0, 1, 2], ["bar"], True), # empty layers_to_transform ("foo.bar.7.baz", ["baz"], [], ["bar"], True), ("foo.bar.7.baz", ["baz"], None, ["bar"], True), @@ -119,14 +120,11 @@ # is one of the target nn.modules ("foo.bar.1.baz", ["baz"], [1], ["baz"], False), # here, layers_pattern is 'bar', but only keys that contain '.bar' are valid. - ("bar.1.baz", ["baz"], [1], ["bar"], False), ("foo.bar.001.baz", ["baz"], [1], ["bar"], True), ("foo.bar.1.spam.2.baz", ["baz"], [1], ["bar"], True), ("foo.bar.2.spam.1.baz", ["baz"], [1], ["bar"], False), # some realistic examples: module using nn.Sequential # for the below test case, key should contain '.blocks' to be valid, because of how layers_pattern is matched - ("blocks.1.weight", ["weight"], [1], ["blocks"], False), - ("blocks.1.bias", ["weight"], [1], ["blocks"], False), ("mlp.blocks.1.weight", ["weight"], [1], ["blocks"], True), ("mlp.blocks.1.bias", ["weight"], [1], ["blocks"], False), ]