Skip to main content

unsloth_rs/kernels/
rope.rs

// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Rotary Position Embedding (`RoPE`) implementation.
//!
//! `RoPE` encodes position information directly into the query and key vectors by
//! rotating them, so that the dot product between a query and a key depends on their
//! relative position.
//!
//! ## Why `RoPE`?
//!
//! Unlike absolute position embeddings, `RoPE`:
//! - Naturally encodes relative positions through rotation
//! - Can be extended to sequences longer than those seen during training with
//!   frequency-scaling techniques (e.g. position interpolation); these are not
//!   implemented in this module
//! - Is used by modern LLMs like `LLaMA`, Mistral, and others
//!
//! ## Implementation Notes
//!
//! - Pre-computes cos/sin caches up to `max_seq_len` for efficiency
//! - Rotates in pairs using the "rotate-half" layout: element `i` of the first half of
//!   `head_dim` is paired with element `i` of the second half
//! - Uses the standard rotation formula: [x1*cos - x2*sin, x2*cos + x1*sin]
21
22use candle_core::{Device, Tensor};
23
24use crate::error::Result;
25
26/// Rotary position embedding.
27///
28/// Applies rotary embeddings to query and key tensors for position encoding.
29pub struct RotaryEmbedding {
30    /// Cosine cache [`max_seq_len`, `head_dim/2`]
31    cos_cache: Tensor,
32    /// Sine cache [`max_seq_len`, `head_dim/2`]  
33    sin_cache: Tensor,
34    /// Head dimension
35    head_dim: usize,
36}
37
38impl RotaryEmbedding {
39    /// Create rotary embeddings.
40    ///
41    /// # Arguments
42    /// * `head_dim` - Dimension per attention head
43    /// * `max_seq_len` - Maximum sequence length to cache
44    /// * `base` - Base for frequency computation (typically 10000)
45    /// * `device` - Device for tensors
46    pub fn new(head_dim: usize, max_seq_len: usize, base: f32, device: &Device) -> Result<Self> {
47        // Compute inverse frequencies
48        let inv_freq: Vec<f32> = (0..head_dim)
49            .step_by(2)
50            .map(|i| 1.0 / base.powf(i as f32 / head_dim as f32))
51            .collect();
52
53        let inv_freq = Tensor::from_vec(inv_freq, (head_dim / 2,), device)?;
54
55        // Compute position indices
56        let positions: Vec<f32> = (0..max_seq_len).map(|i| i as f32).collect();
57        let positions = Tensor::from_vec(positions, (max_seq_len, 1), device)?;
58
59        // Compute frequencies: [max_seq_len, head_dim/2]
60        let freqs = positions.matmul(&inv_freq.unsqueeze(0)?)?;
61
62        // Compute cos and sin caches
63        let cos_cache = freqs.cos()?;
64        let sin_cache = freqs.sin()?;
65
66        Ok(Self {
67            cos_cache,
68            sin_cache,
69            head_dim,
70        })
71    }
72
73    /// Apply rotary embedding to query and key tensors.
74    ///
75    /// # Arguments
76    /// * `q` - Query tensor [batch, `num_heads`, `seq_len`, `head_dim`]
77    /// * `k` - Key tensor [batch, `num_kv_heads`, `seq_len`, `head_dim`]
78    /// * `position_ids` - Position indices [batch, `seq_len`]
79    ///
80    /// # Returns
81    /// Tuple of (`rotated_q`, `rotated_k`)
82    pub fn forward(
83        &self,
84        q: &Tensor,
85        k: &Tensor,
86        _position_ids: &Tensor,
87    ) -> Result<(Tensor, Tensor)> {
88        let device = q.device();
89
90        if device.is_cuda() {
91            self.forward_cuda(q, k)
92        } else {
93            self.forward_cpu(q, k)
94        }
95    }
96
97    /// CPU reference implementation for `RoPE`.
98    fn forward_cpu(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
99        let seq_len = q.dim(2)?;
100
101        // Get cos/sin for positions
102        let cos = self.cos_cache.narrow(0, 0, seq_len)?;
103        let sin = self.sin_cache.narrow(0, 0, seq_len)?;
104
105        let q_rotated = self.apply_rotary(q, &cos, &sin)?;
106        let k_rotated = self.apply_rotary(k, &cos, &sin)?;
107
108        Ok((q_rotated, k_rotated))
109    }
110
111    /// CUDA implementation.
112    ///
113    /// Uses Candle's CUDA backend for GPU acceleration.
114    /// The algorithm is the same as the CPU implementation.
115    fn forward_cuda(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
116        tracing::debug!("Using CUDA RoPE path for Q shape {:?}", q.shape());
117        self.forward_cpu(q, k)
118    }
119
120    fn apply_rotary(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
121        let half_dim = self.head_dim / 2;
122
123        // Split into two halves
124        let x1 = x.narrow(3, 0, half_dim)?;
125        let x2 = x.narrow(3, half_dim, half_dim)?;
126
127        // Apply rotation: [x1, x2] -> [x1*cos - x2*sin, x2*cos + x1*sin]
128        let rotated_x1 = (x1.broadcast_mul(cos)? - x2.broadcast_mul(sin)?)?;
129        let rotated_x2 = (x2.broadcast_mul(cos)? + x1.broadcast_mul(sin)?)?;
130
131        // Concatenate
132        Tensor::cat(&[&rotated_x1, &rotated_x2], 3).map_err(Into::into)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    /// Flatten a tensor into a host `Vec<f32>` for element-wise checks.
    fn to_host(t: &Tensor) -> Vec<f32> {
        t.flatten_all().unwrap().to_vec1::<f32>().unwrap()
    }

    #[test]
    fn test_rope_creation() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device);
        assert!(rope.is_ok());
    }

    #[test]
    fn test_rope_rejects_odd_head_dim() {
        let device = Device::Cpu;
        assert!(RotaryEmbedding::new(63, 16, 10000.0, &device).is_err());
    }

    #[test]
    fn test_rope_preserves_shape() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device).unwrap();

        let q = Tensor::zeros(&[1, 12, 10, 64], DType::F32, &device).unwrap();
        let k = Tensor::zeros(&[1, 12, 10, 64], DType::F32, &device).unwrap();
        let pos = Tensor::zeros(&[1, 10], DType::I64, &device).unwrap();

        let (q_rot, k_rot) = rope.forward(&q, &k, &pos).unwrap();

        assert_eq!(q_rot.shape().dims(), &[1, 12, 10, 64]);
        assert_eq!(k_rot.shape().dims(), &[1, 12, 10, 64]);
    }

    #[test]
    fn test_rope_rotates_by_position_angle() {
        let device = Device::Cpu;
        // head_dim = 2 gives a single rotation pair whose inverse frequency is 1/base^0 = 1,
        // so position p rotates by exactly p radians.
        let rope = RotaryEmbedding::new(2, 8, 10000.0, &device).unwrap();

        // [batch=1, heads=1, seq_len=2, head_dim=2]; both tokens hold the vector (1, 0).
        let x = Tensor::from_vec(vec![1.0f32, 0.0, 1.0, 0.0], (1, 1, 2, 2), &device).unwrap();
        let pos = Tensor::from_vec(vec![0i64, 1], (1, 2), &device).unwrap();

        let (q_rot, k_rot) = rope.forward(&x, &x, &pos).unwrap();

        let expected = [1.0f32, 0.0, 1.0f32.cos(), 1.0f32.sin()];
        for rotated in [to_host(&q_rot), to_host(&k_rot)] {
            for (got, want) in rotated.iter().zip(expected.iter()) {
                assert!((got - want).abs() < 1e-5, "got {got}, want {want}");
            }
        }
    }

    #[test]
    fn test_rope_preserves_vector_norm() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(4, 16, 10000.0, &device).unwrap();

        let data = vec![
            0.3f32, -1.2, 2.5, 0.7, -0.4, 1.1, 0.0, -2.2, 1.9, 0.6, -0.8, 1.4,
        ];
        let x = Tensor::from_vec(data.clone(), (1, 1, 3, 4), &device).unwrap();
        let pos = Tensor::from_vec(vec![0i64, 1, 2], (1, 3), &device).unwrap();

        let (q_rot, _) = rope.forward(&x, &x, &pos).unwrap();
        let rotated = to_host(&q_rot);

        // A rotation of each (x1, x2) pair leaves every token vector's L2 norm unchanged.
        for (before, after) in data.chunks(4).zip(rotated.chunks(4)) {
            let norm_before: f32 = before.iter().map(|v| v * v).sum();
            let norm_after: f32 = after.iter().map(|v| v * v).sum();
            assert!((norm_before - norm_after).abs() < 1e-4);
        }
    }
}