Merge pull request #10 from brianhie/pos_embed_interpolation
feat: support linear RoPE interpolation
Zymrael authored Feb 18, 2024
2 parents b0bb978 + 48b5be1 commit 350b4a4
Showing 2 changed files with 125 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/model.py
@@ -17,6 +17,11 @@
except ImportError:
    "flash_attn not installed"

try:
    from src.positional_embeddings import swap_mha_rope
except ImportError:
    "could not import swap_mha_rope from src.positional_embeddings"


class AttentionBlock(nn.Module):
    def __init__(self, config, layer_idx) -> None:
@@ -44,6 +49,13 @@ def __init__(self, config, layer_idx) -> None:
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)

        # check if using interpolated rotary pos emb from config, and swap the rope emb
        if config.get("use_interpolated_rotary_pos_emb", False):
            swap_mha_rope(
                mha=self.inner_mha_cls,
                kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
            )

        if self.config.get("smeared_gqa", False):
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
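
For reference, not part of this commit: the hunk above reads two config keys. A minimal sketch of the corresponding entries, with key names taken from the diff; the dict shown here and the value 4.0 are placeholders, since the repository's actual config format is not shown on this page.

# Illustrative only: config entries consumed by AttentionBlock.__init__ above.
rope_interpolation_cfg = {
    # enables the swap_mha_rope call during attention-block construction
    "use_interpolated_rotary_pos_emb": True,
    # forwarded to LinearlyScaledRotaryEmbedding as scaling_factor (placeholder value)
    "rotary_emb_scaling_factor": 4.0,
}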
@@ -327,7 +339,7 @@ def __init__(self, config):
        self.config = config
        self.embedding_layer = VocabParallelEmbedding(config)
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None
-        self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)
+        self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)

        if config.get("use_flashfft", "True"):
            try:
112 changes: 112 additions & 0 deletions src/positional_embeddings.py
@@ -0,0 +1,112 @@
"""
Armin Thomas, Jan 2023. Modified by Eric Nguyen.
Wrappers for linearly interpolated RoPE embeddings, for use inside the MHA layers of flash-attn.
"""

import copy

import torch
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import MHA


# simple wrapper for flash-attn RoPE with linear scaling:
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device,
        )
        self._linear_scaling_factor = scaling_factor

    # adapted from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # linear scaling:
                t = t / self._linear_scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # linear scaling:
                t = t / self._linear_scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
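
Not part of the commit, but a worked illustration of what the t = t / self._linear_scaling_factor step above does: with a hypothetical scaling factor of 4, position 4092 is rotated by the same angles as position 1023 without scaling, so a 4x longer context is mapped back onto the position range the base model was trained on. A self-contained sketch in plain torch; the dim and base values are assumptions for illustration.

import torch

dim, base, scaling_factor = 64, 10000.0, 4.0  # assumed values for illustration
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

t_plain = torch.arange(4096, dtype=torch.float32)   # original position ids
t_interp = t_plain / scaling_factor                 # linearly interpolated position ids
freqs_plain = torch.outer(t_plain, inv_freq)
freqs_interp = torch.outer(t_interp, inv_freq)

# position 4092 with interpolation gets the rotation angles of position 1023 without it
assert torch.allclose(freqs_interp[4092], freqs_plain[1023])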


# swap out RoPE of existing mha:
def swap_mha_rope(mha, new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding, kwargs_new_rope: dict = None):
    # determine mha dtype and device:
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
    device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
    # determine RoPE settings:
    kwargs_old_rope = dict(
        dim=mha.rotary_emb.dim,
        base=mha.rotary_emb.base,
        interleaved=mha.rotary_emb.interleaved,
        scale_base=mha.rotary_emb.scale_base,
        pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
        device=mha.rotary_emb.inv_freq.device,
    )
    # delete old RoPE:
    del mha.rotary_emb
    # create new RoPE:
    kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
    scaled_rope = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)
    # attach new RoPE to mha:
    mha.rotary_emb = scaled_rope
    # make sure the new RoPE is correctly registered:
    assert isinstance(mha.rotary_emb, new_rope)
    return mha
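
A usage sketch, not part of this commit: it assumes flash_attn is installed and that MHA accepts the constructor arguments shown (embed_dim, num_heads, rotary_emb_dim); the dimensions and the scaling factor are placeholder values.

from flash_attn.modules.mha import MHA

from src.positional_embeddings import LinearlyScaledRotaryEmbedding, swap_mha_rope

# build a self-attention MHA with a 64-dim rotary embedding (placeholder sizes)
mha = MHA(embed_dim=512, num_heads=8, rotary_emb_dim=64)

# replace its stock RoPE with the linearly scaled variant (4x interpolation)
mha = swap_mha_rope(mha, kwargs_new_rope={"scaling_factor": 4.0})
assert isinstance(mha.rotary_emb, LinearlyScaledRotaryEmbedding)

Because swap_mha_rope rebuilds the embedding from the old one's dim, base, interleaved, scale_base, and pos_idx_in_fp32 settings, the swap preserves everything about the original RoPE except the added position scaling.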
