Merge pull request #10 from brianhie/pos_embed_interpolation
feat: support linear RoPE interpolation
Showing 2 changed files with 125 additions and 1 deletion.
@@ -0,0 +1,112 @@
""" | ||
Armin Thomas, Jan 2023. Modified by Eric Nguyen. | ||
Wrappers for linearly interpolated rope embeddings to use inside of MHA layers of Flash Attn. | ||
""" | ||

import copy

import torch
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import MHA


# simple wrapper for flash-attn RoPE with linear scaling:
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device,
        )
        self._linear_scaling_factor = scaling_factor

    # adapted from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # linear scaling:
                t = t / self._linear_scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # linear scaling:
                t = t / self._linear_scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)


# swap out RoPE of existing mha:
def swap_mha_rope(mha, new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding, kwargs_new_rope: dict = None):
    # determine mha dtype and device:
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
    device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
    # determine RoPE settings:
    kwargs_old_rope = dict(
        dim=mha.rotary_emb.dim,
        base=mha.rotary_emb.base,
        interleaved=mha.rotary_emb.interleaved,
        scale_base=mha.rotary_emb.scale_base,
        pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
        device=mha.rotary_emb.inv_freq.device,
    )
    # delete old RoPE:
    del mha.rotary_emb
    # create new RoPE:
    kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
    scaled_rope = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)
    # attach new RoPE to mha:
    mha.rotary_emb = scaled_rope
    # make sure the new RoPE is correctly registered:
    assert isinstance(mha.rotary_emb, new_rope)
    return mha
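A minimal usage sketch (not part of the diff) of how swap_mha_rope might be applied to a single flash-attn MHA layer. The MHA constructor arguments, the CUDA/bf16 setup, and the scaling_factor of 4.0 are illustrative assumptions; exact argument names can vary across flash-attn versions:

```python
import torch
from flash_attn.modules.mha import MHA

# Build one MHA layer with rotary embeddings enabled (rotary_emb_dim > 0),
# so that mha.rotary_emb exists and can be swapped out.
mha = MHA(
    embed_dim=512,
    num_heads=8,
    rotary_emb_dim=64,
    causal=True,
).to(device="cuda", dtype=torch.bfloat16)

# Stretch positions by 4x so the layer covers roughly 4x its trained context.
mha = swap_mha_rope(
    mha,
    new_rope=LinearlyScaledRotaryEmbedding,
    kwargs_new_rope={"scaling_factor": 4.0},
)
assert isinstance(mha.rotary_emb, LinearlyScaledRotaryEmbedding)

# The rotary kernels in flash-attn expect CUDA tensors in fp16/bf16.
x = torch.randn(2, 1024, 512, device="cuda", dtype=torch.bfloat16)
out = mha(x)  # (2, 1024, 512)
```

To rescale a whole model, the same swap would typically be applied to every attention block, e.g. by iterating over model.modules() and calling swap_mha_rope on each MHA instance.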