diff --git a/litgpt/model.py b/litgpt/model.py index bc5bf6e89c..c694f63dfd 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -464,29 +464,18 @@ def build_rope_cache( theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) if extra_config is not None: - # Extract configuration parameters orig_context_len = extra_config["original_max_seq_len"] factor = extra_config["factor"] low_freq_factor = extra_config["low_freq_factor"] high_freq_factor = extra_config["high_freq_factor"] - # Compute wavelength thresholds - low_freq_wavelen = orig_context_len / low_freq_factor - high_freq_wavelen = orig_context_len / high_freq_factor - - # Compute wavelengths corresponding to the inverse frequencies wavelen = 2 * torch.pi / theta - - # Compute ratio across all elements ratio = orig_context_len / wavelen - - # Compute smooth_factor and clamp between 0 and 1 smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) # Compute adjusted_theta without masked indexing adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta - theta = adjusted_theta # Create position indices `[0, 1, ..., seq_len - 1]`