Tidy up RoPE (#1786)
rasbt authored Oct 10, 2024
1 parent 46c4337 commit 6bcee99
Showing 1 changed file with 0 additions and 11 deletions.
litgpt/model.py: 11 changes (0 additions & 11 deletions)
@@ -464,29 +464,18 @@ def build_rope_cache(
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

    if extra_config is not None:
        # Extract configuration parameters
        orig_context_len = extra_config["original_max_seq_len"]
        factor = extra_config["factor"]
        low_freq_factor = extra_config["low_freq_factor"]
        high_freq_factor = extra_config["high_freq_factor"]

        # Compute wavelength thresholds
        low_freq_wavelen = orig_context_len / low_freq_factor
        high_freq_wavelen = orig_context_len / high_freq_factor

        # Compute wavelengths corresponding to the inverse frequencies
        wavelen = 2 * torch.pi / theta

        # Compute ratio across all elements
        ratio = orig_context_len / wavelen

        # Compute smooth_factor and clamp between 0 and 1
        smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

        # Compute adjusted_theta without masked indexing
        adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

        theta = adjusted_theta

    # Create position indices `[0, 1, ..., seq_len - 1]`
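
For context, the hunk above touches the Llama 3.1-style RoPE frequency adjustment that runs when extra_config is provided. Below is a minimal, self-contained sketch of that adjustment on its own; the function name adjusted_rope_theta and the example config values are illustrative assumptions, not litgpt's API or defaults, and only torch is required.

import torch

def adjusted_rope_theta(n_elem: int, base: float, extra_config: dict) -> torch.Tensor:
    # Standard RoPE inverse frequencies
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))

    orig_context_len = extra_config["original_max_seq_len"]
    factor = extra_config["factor"]
    low_freq_factor = extra_config["low_freq_factor"]
    high_freq_factor = extra_config["high_freq_factor"]

    # Wavelength of each rotary component and its ratio to the original context length
    wavelen = 2 * torch.pi / theta
    ratio = orig_context_len / wavelen

    # Interpolation weight in [0, 1]: 0 selects theta / factor, 1 keeps theta unchanged
    smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

    # Blend scaled and unscaled frequencies without masked indexing
    return (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

# Illustrative values (assumptions, not litgpt defaults)
cfg = {"original_max_seq_len": 8192, "factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0}
print(adjusted_rope_theta(n_elem=128, base=500_000.0, extra_config=cfg))

This is the same computation the diff shows inside build_rope_cache, just lifted out of the function for readability.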
