
Commit

fix sp (#9795)
DesmonDay authored Jan 17, 2025
1 parent bd2d9d0 commit 465ce1d
Showing 2 changed files with 17 additions and 5 deletions.
12 changes: 12 additions & 0 deletions paddlenlp/peft/lora/loraga_utils.py
@@ -16,6 +16,13 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
+
 from paddlenlp.peft import LoRAModel
 from paddlenlp.peft.lora.lora_layers import (
     ColumnParallelLoRALinear,
@@ -83,6 +90,11 @@ def estimate_gradient(self, model: PretrainedModel):
     def _wrap_model(self, model):
         """Wrap Model without optimizer, support dp, tp and sharding"""
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
         in_sharding_parallel_mode = self.sharding is not None
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
10 changes: 5 additions & 5 deletions paddlenlp/trainer/trainer.py
@@ -414,11 +414,6 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
                     "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
                 )
 
-        if args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
-            )
-
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
@@ -1987,6 +1982,11 @@ def _wrap_model(self, model, training=True):
             else:
                 model, self.optimizer = decorated
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
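
Taken together, the hunks move sequence-parallel allreduce hook registration out of Trainer.__init__ and into _wrap_model (and into LoRAGATrainer._wrap_model), guarded so the hooks are registered only when both tensor parallelism and sequence parallelism are enabled. Below is a minimal sketch of the resulting pattern; the maybe_register_sp_hooks helper is illustrative and not part of the commit, while the imported utility and argument names come from the diff above.

# Minimal sketch of the pattern introduced by this commit.
try:
    # The utility may be absent in older Paddle builds, hence the guarded import.
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        register_sequence_parallel_allreduce_hooks,
    )
except ImportError:
    register_sequence_parallel_allreduce_hooks = None


def maybe_register_sp_hooks(model, args):
    """Illustrative helper: register sequence-parallel allreduce hooks on the
    wrapped model, but only when tensor parallelism and sequence parallelism
    are both enabled."""
    if register_sequence_parallel_allreduce_hooks is None:
        return
    if args.tensor_parallel_degree > 1 and args.sequence_parallel:
        register_sequence_parallel_allreduce_hooks(
            model,
            args.gradient_accumulation_steps,
            args.fuse_sequence_parallel_allreduce,
        )

Presumably the registration is relocated so the hooks attach to the model object that is actually trained, after decoration in _wrap_model; the commit itself only records the move and the added tensor_parallel_degree > 1 guard.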
