diff --git a/fastchat/train/llama2_flash_attn_monkey_patch.py b/fastchat/train/llama2_flash_attn_monkey_patch.py index c1fe51c91..704aba572 100644 --- a/fastchat/train/llama2_flash_attn_monkey_patch.py +++ b/fastchat/train/llama2_flash_attn_monkey_patch.py @@ -2,31 +2,29 @@ from typing import Optional, Tuple import torch -from flash_attn import __version__ as flash_attn_version -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_varlen_kvpacked_func, -) + +is_flash_attn_2_available = False +try: + from flash_attn import __version__ as flash_attn_version + from flash_attn.bert_padding import pad_input, unpad_input # type: ignore + from flash_attn.flash_attn_interface import flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func # type: ignore + + is_flash_attn_2_available = ( + torch.cuda.is_available() and flash_attn_version >= "2.1.0" + ) +except ImportError: + warnings.warn("Flash Attention2 not support.") + +import transformers from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaModel, - rotate_half, + repeat_kv, + apply_rotary_pos_emb, ) +from transformers.utils import logging - -def apply_rotary_pos_emb(q, k, cos_sin, position_ids): - gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] - gather_indices = gather_indices.repeat( - 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] - ) - bsz = gather_indices.shape[0] - cos, sin = ( - torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) - for x in cos_sin - ) - q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) - return q, k +logger = logging.get_logger(__name__) def forward( @@ -37,56 +35,95 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") if output_attentions: warnings.warn( "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." ) + output_attentions = False bsz, q_len, _ = hidden_states.size() kv_heads = getattr(self, "num_key_value_heads", self.num_heads) q, k, v = ( - op(hidden_states).view(bsz, q_len, nh, self.head_dim) + op(hidden_states).view(bsz, q_len, nh, self.head_dim).transpose(1, 2) for op, nh in ( (self.q_proj, self.num_heads), (self.k_proj, kv_heads), (self.v_proj, kv_heads), ) ) - # shape: (b, s, num_heads, head_dim) + # shape: (b, num_heads, s, head_dim) - kv_seq_len = k.shape[1] - past_kv_len = 0 + kv_seq_len = k.shape[-2] if past_key_value is not None: - past_kv_len = past_key_value[0].shape[2] - kv_seq_len += past_kv_len + kv_seq_len += past_key_value[0].shape[-2] - cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) - q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape if past_key_value is not None: assert ( flash_attn_version >= "2.1.0" ), "past_key_value support requires flash-attn >= 2.1.0" - # reuse k, v - k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) - v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) + + past_key_value = (k, v) if use_cache else None + + # cast to half precision + input_dtype = q.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + q = q.to(target_dtype) + k = k.to(target_dtype) + v = v.to(target_dtype) - past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None + if getattr(self, "num_key_value_groups", None): + k = repeat_kv(k, self.num_key_value_groups) + v = repeat_kv(v, self.num_key_value_groups) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if attention_mask is None: - output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( - bsz, q_len, -1 + kv = torch.stack((k, v), dim=2) + attn_output = flash_attn_kvpacked_func( + q, kv, 0.0, softmax_scale=None, causal=True ) else: - q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) + logger.warning_once("Padded sequences are less efficient in FlashAttention.") + q, indices_q, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) # We can skip concat and call unpad twice but seems better to call unpad only once. kv, _, cu_k_lens, max_k = unpad_input( torch.stack((k, v), dim=2), attention_mask ) - output_unpad = flash_attn_varlen_kvpacked_func( + attn_output_unpad = flash_attn_varlen_kvpacked_func( q, kv, cu_q_lens, @@ -97,10 +134,14 @@ def forward( softmax_scale=None, causal=True, ) - output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) - output = pad_input(output_unpad, indices, bsz, q_len) + attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - return self.o_proj(output), None, past_key_value + if not output_attentions: + attn_weights = None + + return self.o_proj(attn_output), attn_weights, past_key_value # Disable the transformation of the attention mask in LlamaModel as flash attention @@ -137,8 +178,14 @@ def replace_llama_attn_with_flash_attn(): "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" ) - LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - LlamaAttention.forward = forward + if is_flash_attn_2_available: + if transformers.__version__ >= "4.35.0": + transformers.models.llama.modeling_llama.LlamaAttention = ( + transformers.models.llama.modeling_llama.LlamaFlashAttention2 + ) + else: + LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + LlamaAttention.forward = forward def test():