tiny_recursive_rs/layers/positional.rs

//! Rotary Positional Embeddings (RoPE)
//!
//! Based on the Python implementation in `layers.py`.

use candle_core::{Device, Result, Tensor};

/// Rotates half the hidden dims of the input
///
/// Splits the tensor along the last dimension and rotates:
/// [x1, x2] -> [-x2, x1]
fn rotate_half(x: &Tensor) -> Result<Tensor> {
    let last_dim = x.dims().len() - 1;
    let dim_size = x.dim(last_dim)?;
    let half = dim_size / 2;

    // Split into two halves
    let x1 = x.narrow(last_dim, 0, half)?;
    let x2 = x.narrow(last_dim, half, half)?;

    // Concatenate [-x2, x1]
    Tensor::cat(&[&x2.neg()?, &x1], last_dim)
}

/// Apply rotary positional embeddings to query and key tensors
///
/// # Arguments
/// * `q` - Query tensor [batch, seq_len, num_heads, head_dim]
/// * `k` - Key tensor [batch, seq_len, num_heads, head_dim]
/// * `cos` - Cosine embeddings [seq_len, head_dim]
/// * `sin` - Sine embeddings [seq_len, head_dim]
///
/// # Returns
/// Tuple of (rotated_q, rotated_k) with same shapes as inputs
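///
/// # Example
///
/// A minimal usage sketch, assuming `q`/`k` already have the layout
/// documented above and a `RotaryEmbedding` (defined below) provides the
/// cos/sin tables:
///
/// ```ignore
/// let rope = RotaryEmbedding::new(head_dim, max_len, 10000.0, &device)?;
/// let (cos, sin) = rope.forward_with_len(seq_len)?;
/// let (q, k) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;
/// ```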
pub fn apply_rotary_pos_emb(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
) -> Result<(Tensor, Tensor)> {
    let orig_dtype = q.dtype();

    // Convert to same dtype as cos/sin for computation
    let q = if q.dtype() != cos.dtype() {
        q.to_dtype(cos.dtype())?
    } else {
        q.clone()
    };

    let k = if k.dtype() != cos.dtype() {
        k.to_dtype(cos.dtype())?
    } else {
        k.clone()
    };

    // Reshape cos/sin to broadcast: [seq_len, head_dim] -> [1, seq_len, 1, head_dim]
    let cos = cos.unsqueeze(0)?.unsqueeze(2)?;
    let sin = sin.unsqueeze(0)?.unsqueeze(2)?;

    // Apply rotation: q_embed = (q * cos) + (rotate_half(q) * sin)
    let q_rotated = rotate_half(&q)?;
    let q_embed = q.broadcast_mul(&cos)?.add(&q_rotated.broadcast_mul(&sin)?)?;

    let k_rotated = rotate_half(&k)?;
    let k_embed = k.broadcast_mul(&cos)?.add(&k_rotated.broadcast_mul(&sin)?)?;

    // Convert back to original dtype
    let q_embed = if q_embed.dtype() != orig_dtype {
        q_embed.to_dtype(orig_dtype)?
    } else {
        q_embed
    };

    let k_embed = if k_embed.dtype() != orig_dtype {
        k_embed.to_dtype(orig_dtype)?
    } else {
        k_embed
    };

    Ok((q_embed, k_embed))
}

/// Rotary Positional Embedding layer
///
/// Precomputes cos/sin embeddings for all positions up to max_position_embeddings
pub struct RotaryEmbedding {
    cos_cached: Tensor,
    sin_cached: Tensor,
}

impl RotaryEmbedding {
    /// Create new RoPE embeddings
    ///
    /// # Arguments
    /// * `dim` - Dimension of the embeddings (must be even)
    /// * `max_position_embeddings` - Maximum sequence length
    /// * `base` - Base for inverse frequencies (typically 10000.0)
    /// * `device` - Device to create tensors on
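    ///
    /// # Example
    ///
    /// An illustrative sketch; the values are arbitrary placeholders:
    ///
    /// ```ignore
    /// let device = Device::Cpu;
    /// // 64-dim heads, positions 0..2048, the conventional base of 10000.0
    /// let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device)?;
    /// let (cos, sin) = rope.forward_with_len(16)?; // each of shape [16, 64]
    /// ```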
    pub fn new(
        dim: usize,
        max_position_embeddings: usize,
        base: f32,
        device: &Device,
    ) -> Result<Self> {
        // Compute inverse frequencies: 1.0 / (base^(i/dim)) for i in [0, 2, 4, ..., dim-2]
        let inv_freq: Vec<f32> = (0..dim)
            .step_by(2)
            .map(|i| {
                let exponent = i as f32 / dim as f32;
                1.0 / base.powf(exponent)
            })
            .collect();

        let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;

        // Create position indices: [0, 1, 2, ..., max_position_embeddings-1]
        let t: Vec<f32> = (0..max_position_embeddings).map(|i| i as f32).collect();
        let t = Tensor::new(t.as_slice(), device)?;

        // Compute outer product: freqs[i, j] = t[i] * inv_freq[j]
        // Shape: [max_position_embeddings, dim/2]
        let freqs = t.unsqueeze(1)?.broadcast_mul(&inv_freq.unsqueeze(0)?)?;

        // Concatenate freqs with itself to get the full embedding
        // Shape: [max_position_embeddings, dim]
        let emb = Tensor::cat(&[&freqs, &freqs], 1)?;

        // Cache cos and sin
        let cos_cached = emb.cos()?;
        let sin_cached = emb.sin()?;

        Ok(Self {
            cos_cached,
            sin_cached,
        })
    }

    /// Get the cached cos/sin embeddings
    ///
    /// # Returns
    /// Tuple of (cos, sin) tensors with shape [max_position_embeddings, dim]
    pub fn forward(&self) -> Result<(Tensor, Tensor)> {
        Ok((self.cos_cached.clone(), self.sin_cached.clone()))
    }

    /// Get cos/sin embeddings for a specific sequence length
    ///
    /// # Arguments
    /// * `seq_len` - Length of the sequence (must be <= max_position_embeddings)
    ///
    /// # Returns
    /// Tuple of (cos, sin) tensors with shape [seq_len, dim]
    pub fn forward_with_len(&self, seq_len: usize) -> Result<(Tensor, Tensor)> {
        let cos = self.cos_cached.narrow(0, 0, seq_len)?;
        let sin = self.sin_cached.narrow(0, 0, seq_len)?;
        Ok((cos, sin))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rotate_half() -> Result<()> {
        let device = Device::Cpu;

        // Simple test: [1, 2, 3, 4] -> [-3, -4, 1, 2]
        let x = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?.reshape((1, 4))?;
        let rotated = rotate_half(&x)?;

        let expected = Tensor::new(&[-3.0f32, -4.0, 1.0, 2.0], &device)?.reshape((1, 4))?;

        let diff = rotated.sub(&expected)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-6, "rotate_half failed");

        Ok(())
    }

    #[test]
    fn test_rotary_embedding_shape() -> Result<()> {
        let device = Device::Cpu;

        let rope = RotaryEmbedding::new(64, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward()?;

        assert_eq!(cos.dims(), &[512, 64]);
        assert_eq!(sin.dims(), &[512, 64]);

        Ok(())
    }

    #[test]
    fn test_rotary_embedding_with_len() -> Result<()> {
        let device = Device::Cpu;

        let rope = RotaryEmbedding::new(64, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(128)?;

        assert_eq!(cos.dims(), &[128, 64]);
        assert_eq!(sin.dims(), &[128, 64]);

        Ok(())
    }

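    // A sanity check of a property implied by the math above: at position 0
    // the rotation angle is zero (cos = 1, sin = 0), so applying RoPE should
    // be the identity. A minimal sketch, assuming the shape conventions used
    // in this module.
    #[test]
    fn test_rope_identity_at_position_zero() -> Result<()> {
        let device = Device::Cpu;

        let q = Tensor::randn(0f32, 1.0, (1, 1, 2, 8), &device)?;
        let k = Tensor::randn(0f32, 1.0, (1, 1, 2, 8), &device)?;

        let rope = RotaryEmbedding::new(8, 16, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(1)?;
        let (q_embed, k_embed) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;

        let dq = q_embed.sub(&q)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        let dk = k_embed.sub(&k)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(dq < 1e-6 && dk < 1e-6, "RoPE at position 0 should be the identity");

        Ok(())
    }
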
    #[test]
    fn test_apply_rotary_pos_emb_shape() -> Result<()> {
        let device = Device::Cpu;

        // Create dummy q, k tensors: [batch, seq_len, num_heads, head_dim]
        let q = Tensor::randn(0f32, 1.0, (2, 16, 8, 64), &device)?;
        let k = Tensor::randn(0f32, 1.0, (2, 16, 8, 64), &device)?;

        // Create RoPE embeddings
        let rope = RotaryEmbedding::new(64, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(16)?;

        // Apply RoPE
        let (q_embed, k_embed) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;

        // Shapes should be preserved
        assert_eq!(q_embed.dims(), q.dims());
        assert_eq!(k_embed.dims(), k.dims());

        Ok(())
    }
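
    // A sketch of a property test for RoPE's defining behavior: since each
    // rotation R(m) satisfies R(m)^T R(n) = R(n - m), the q·k dot product
    // after rotation should depend only on the position offset when the
    // underlying content vectors are identical at every position. Positions
    // and tolerances here are arbitrary choices.
    #[test]
    fn test_rope_relative_position_property() -> Result<()> {
        let device = Device::Cpu;
        let dim = 8;
        let seq_len = 16;

        // Broadcast one random head vector across all positions
        let base_q = Tensor::randn(0f32, 1.0, (1, 1, 1, dim), &device)?;
        let base_k = Tensor::randn(0f32, 1.0, (1, 1, 1, dim), &device)?;
        let q = base_q.broadcast_as((1, seq_len, 1, dim))?.contiguous()?;
        let k = base_k.broadcast_as((1, seq_len, 1, dim))?.contiguous()?;

        let rope = RotaryEmbedding::new(dim, 32, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(seq_len)?;
        let (q_embed, k_embed) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;

        // Dot product of q at position m with k at position n
        let score = |m: usize, n: usize| -> Result<f32> {
            let qm = q_embed.narrow(1, m, 1)?.flatten_all()?;
            let kn = k_embed.narrow(1, n, 1)?.flatten_all()?;
            qm.mul(&kn)?.sum_all()?.to_scalar::<f32>()
        };

        // Same offset (3) at different absolute positions -> same score
        let s1 = score(2, 5)?;
        let s2 = score(9, 12)?;
        assert!((s1 - s2).abs() < 1e-4, "score should depend only on offset");

        Ok(())
    }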
}