Fix version compatibility issue with transformers>4.34.0 for flash-attention2 patch #2655

Open · wants to merge 1 commit into main
fastchat/train/llama2_flash_attn_monkey_patch.py: 131 changes (89 additions, 42 deletions)
@@ -2,31 +2,29 @@
from typing import Optional, Tuple

import torch
from flash_attn import __version__ as flash_attn_version
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
flash_attn_func,
flash_attn_varlen_kvpacked_func,
)

is_flash_attn_2_available = False
try:
from flash_attn import __version__ as flash_attn_version
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func # type: ignore

is_flash_attn_2_available = (
torch.cuda.is_available() and flash_attn_version >= "2.1.0"
)
except ImportError:
warnings.warn("Flash Attention2 not support.")

import transformers
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaModel,
rotate_half,
repeat_kv,
apply_rotary_pos_emb,
)
from transformers.utils import logging


def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
gather_indices = gather_indices.repeat(
1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
)
bsz = gather_indices.shape[0]
cos, sin = (
torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
for x in cos_sin
)
q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
return q, k
logger = logging.get_logger(__name__)


def forward(
@@ -37,56 +35,95 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
output_attentions = False

bsz, q_len, _ = hidden_states.size()
kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

q, k, v = (
op(hidden_states).view(bsz, q_len, nh, self.head_dim)
op(hidden_states).view(bsz, q_len, nh, self.head_dim).transpose(1, 2)
for op, nh in (
(self.q_proj, self.num_heads),
(self.k_proj, kv_heads),
(self.v_proj, kv_heads),
)
)
# shape: (b, s, num_heads, head_dim)
# shape: (b, num_heads, s, head_dim)

kv_seq_len = k.shape[1]
past_kv_len = 0
kv_seq_len = k.shape[-2]
if past_key_value is not None:
past_kv_len = past_key_value[0].shape[2]
kv_seq_len += past_kv_len
kv_seq_len += past_key_value[0].shape[-2]

cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
if past_key_value is not None:
assert (
flash_attn_version >= "2.1.0"
), "past_key_value support requires flash-attn >= 2.1.0"
# reuse k, v
k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
# reuse k, v, self_attention
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)

past_key_value = (k, v) if use_cache else None

# cast to half precision
input_dtype = q.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype

logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

q = q.to(target_dtype)
k = k.to(target_dtype)
v = v.to(target_dtype)

past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
if getattr(self, "num_key_value_groups", None):
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)

q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if attention_mask is None:
output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
bsz, q_len, -1
kv = torch.stack((k, v), dim=2)
attn_output = flash_attn_kvpacked_func(
q, kv, 0.0, softmax_scale=None, causal=True
)
else:
q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
q, indices_q, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
# We can skip concat and call unpad twice but seems better to call unpad only once.
kv, _, cu_k_lens, max_k = unpad_input(
torch.stack((k, v), dim=2), attention_mask
)
output_unpad = flash_attn_varlen_kvpacked_func(
attn_output_unpad = flash_attn_varlen_kvpacked_func(
q,
kv,
cu_q_lens,
Expand All @@ -97,10 +134,14 @@ def forward(
softmax_scale=None,
causal=True,
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()

return self.o_proj(output), None, past_key_value
if not output_attentions:
attn_weights = None

return self.o_proj(attn_output), attn_weights, past_key_value


# Disable the transformation of the attention mask in LlamaModel as flash attention
@@ -137,8 +178,14 @@ def replace_llama_attn_with_flash_attn():
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)

LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
LlamaAttention.forward = forward
if is_flash_attn_2_available:
if transformers.__version__ >= "4.35.0":
transformers.models.llama.modeling_llama.LlamaAttention = (
transformers.models.llama.modeling_llama.LlamaFlashAttention2
)
else:
LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
LlamaAttention.forward = forward


def test():
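
For context, a minimal usage sketch (not part of this PR) of how the patch is typically applied: call replace_llama_attn_with_flash_attn() before the model is constructed, so that either LlamaAttention is swapped for LlamaFlashAttention2 (transformers >= 4.35.0) or the patched forward is installed on older versions. The checkpoint name and dtype below are only illustrative.

# Usage sketch (illustrative, not taken from this PR).
import torch
from transformers import AutoModelForCausalLM

from fastchat.train.llama2_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

# Apply the monkey patch first; the helper is version-gated internally.
replace_llama_attn_with_flash_attn()

# Any Llama-2 checkpoint works here; the name is an example.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,  # flash-attn kernels expect fp16/bf16 inputs
).cuda()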