Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code for positional embedding interpolation #10

Merged
merged 2 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
except ImportError:
"flash_attn not installed"

try:
from src.positional_embeddings import swap_mha_rope
except ImportError:
"could not import swap_mha_rope from src.positional_embeddings"


class AttentionBlock(nn.Module):
def __init__(self, config, layer_idx) -> None:
Expand Down Expand Up @@ -44,6 +49,13 @@ def __init__(self, config, layer_idx) -> None:
use_flash_attn=self.config.use_flash_attn,
).to(dtype=dtype)

# check if using interpolated rotary pos emb from config, and swap the rope emb
if config.get("use_interpolated_rotary_pos_emb", False):
swap_mha_rope(
mha=self.inner_mha_cls,
kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
)

if self.config.get("smeared_gqa", False):
self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
Expand Down Expand Up @@ -322,7 +334,7 @@ def __init__(self, config):
self.config = config
self.embedding_layer = VocabParallelEmbedding(config)
self.norm = RMSNorm(config) if config.get("final_norm", True) else None
self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)
self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)

if config.get("use_flashfft", "False"):
from flashfftconv import FlashFFTConv
Expand Down
112 changes: 112 additions & 0 deletions src/positional_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Armin Thomas, Jan 2023. Modified by Eric Nguyen.

Wrappers for linearly interpolated rope embeddings to use inside of MHA layers of Flash Attn.

"""

import copy

import torch
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import MHA


# simple wrapper for flash-attn RoPE with linear scaling:
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
dim: int,
scaling_factor: float = 1.0,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
super().__init__(
dim=dim,
base=base,
interleaved=interleaved,
scale_base=scale_base,
pos_idx_in_fp32=pos_idx_in_fp32,
device=device,
)
self._linear_scaling_factor = scaling_factor

# adpated from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368

Check failure on line 38 in src/positional_embeddings.py

View workflow job for this annotation

GitHub Actions / codespell

adpated ==> adapted
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# linear scaling:
t = t / self._linear_scaling_factor
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# linear scaling:
t = t / self._linear_scaling_factor
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)


# swap out RoPE of existing mha:
def swap_mha_rope(mha, new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding, kwargs_new_rope: dict = None):
# determine mha dtype and device:
dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
# determine RoPE settings:
kwargs_old_rope = dict(
dim=mha.rotary_emb.dim,
base=mha.rotary_emb.base,
interleaved=mha.rotary_emb.interleaved,
scale_base=mha.rotary_emb.scale_base,
pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
device=mha.rotary_emb.inv_freq.device,
)
# delete old RoPE:
del mha.rotary_emb
# create new RoPE:
kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
scaled_rope = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)
# attach new RoPE to mha:
mha.rotary_emb = scaled_rope
# make new sure RoPE is correctly registered:
assert isinstance(mha.rotary_emb, new_rope)
return mha
Loading