Skip to main content

tensorlogic_trustformers/kv_cache/
position.rs

1use ndarray::{s, Array2, Array3, ArrayD, ArrayView1, IxDyn};
2use std::fmt;
3
// ---------------------------------------------------------------------------
// Position errors
// ---------------------------------------------------------------------------
7
/// Errors arising from position encoding operations.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can assert
/// on the exact error value instead of only matching on the variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PositionError {
    /// `head_dim` must be even for RoPE.
    HeadDimMustBeEven { head_dim: usize },
    /// The requested positions run past the pre-computed cache.
    ///
    /// `offset` is the position that was requested and `max` the last
    /// position available in the cache.
    SeqOffsetOutOfRange { offset: usize, max: usize },
    /// Tensor shape mismatch.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
}

impl fmt::Display for PositionError {
    /// Render a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PositionError::HeadDimMustBeEven { head_dim } => {
                write!(f, "head_dim must be even for RoPE, got {}", head_dim)
            }
            PositionError::SeqOffsetOutOfRange { offset, max } => write!(
                f,
                "seq_offset {} is out of range (max pre-computed = {})",
                offset, max
            ),
            PositionError::ShapeMismatch { expected, got } => {
                write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
            }
        }
    }
}

// No underlying source error is wrapped, so the trait's default methods suffice.
impl std::error::Error for PositionError {}
43
// ---------------------------------------------------------------------------
// Rotary Position Embedding (RoPE)
// ---------------------------------------------------------------------------
47
48/// Rotary Position Embedding (RoPE) as introduced in Su et al. 2021.
49///
50/// Pre-computes cosine and sine caches up to `max_seq_len` positions and applies
51/// the rotation in-place to the last dimension of the input tensor.
52#[derive(Debug, Clone)]
53pub struct RotaryPositionEmbedding {
54    /// Dimension of each attention head.
55    pub head_dim: usize,
56    /// Base for the geometric frequency sequence (default 10000.0).
57    pub base: f64,
58    /// Maximum sequence length for which the cache was pre-computed.
59    pub max_seq_len: usize,
60    /// Pre-computed cosines: shape `[max_seq_len, head_dim / 2]`.
61    cos_cache: Array2<f64>,
62    /// Pre-computed sines: shape `[max_seq_len, head_dim / 2]`.
63    sin_cache: Array2<f64>,
64}
65
66impl RotaryPositionEmbedding {
67    /// Create a new RoPE module, pre-computing the cos/sin cache.
68    ///
69    /// Returns an error if `head_dim` is not even.
70    pub fn new(
71        head_dim: usize,
72        max_seq_len: usize,
73        base: f64,
74    ) -> std::result::Result<Self, PositionError> {
75        if !head_dim.is_multiple_of(2) {
76            return Err(PositionError::HeadDimMustBeEven { head_dim });
77        }
78        let (cos_cache, sin_cache) = Self::build_cos_sin_cache(head_dim, max_seq_len, base);
79        Ok(Self {
80            head_dim,
81            base,
82            max_seq_len,
83            cos_cache,
84            sin_cache,
85        })
86    }
87
88    /// Build the cos and sin frequency caches.
89    fn build_cos_sin_cache(
90        head_dim: usize,
91        max_seq_len: usize,
92        base: f64,
93    ) -> (Array2<f64>, Array2<f64>) {
94        let half_dim = head_dim / 2;
95        // θ_i = base^{-2i/d} for i in 0..half_dim
96        let thetas: Vec<f64> = (0..half_dim)
97            .map(|i| base.powf(-(2.0 * i as f64) / head_dim as f64))
98            .collect();
99
100        let mut cos_cache = Array2::<f64>::zeros((max_seq_len, half_dim));
101        let mut sin_cache = Array2::<f64>::zeros((max_seq_len, half_dim));
102
103        for pos in 0..max_seq_len {
104            for (i, &theta) in thetas.iter().enumerate() {
105                let angle = pos as f64 * theta;
106                cos_cache[[pos, i]] = angle.cos();
107                sin_cache[[pos, i]] = angle.sin();
108            }
109        }
110
111        (cos_cache, sin_cache)
112    }
113
114    /// Apply RoPE to the input tensor starting at `seq_offset`.
115    ///
116    /// `x` is expected to have shape `[seq_len, ..., head_dim]` where the last
117    /// axis is the head dimension (or any shape where the last axis == `head_dim`).
118    pub fn apply(
119        &self,
120        x: &ArrayD<f64>,
121        seq_offset: usize,
122    ) -> std::result::Result<ArrayD<f64>, PositionError> {
123        let shape = x.shape();
124        let ndim = shape.len();
125        if ndim < 1 {
126            return Err(PositionError::ShapeMismatch {
127                expected: vec![1],
128                got: shape.to_vec(),
129            });
130        }
131
132        let last_dim = shape[ndim - 1];
133        if last_dim != self.head_dim {
134            return Err(PositionError::ShapeMismatch {
135                expected: vec![self.head_dim],
136                got: vec![last_dim],
137            });
138        }
139
140        let seq_len = shape[0];
141        if seq_offset + seq_len > self.max_seq_len {
142            return Err(PositionError::SeqOffsetOutOfRange {
143                offset: seq_offset + seq_len - 1,
144                max: self.max_seq_len - 1,
145            });
146        }
147
148        let half_dim = self.head_dim / 2;
149
150        // Split x into first half and second half along last axis.
151        // For simplicity, work with a 2D view: [total_positions, head_dim].
152        let total = x.len() / self.head_dim;
153        let x2 = x
154            .view()
155            .into_shape_with_order((total, self.head_dim))
156            .map_err(|_| PositionError::ShapeMismatch {
157                expected: vec![total, self.head_dim],
158                got: shape.to_vec(),
159            })?;
160
161        // x_first: [total, half_dim], x_second: [total, half_dim]
162        let x_first = x2.slice(s![.., ..half_dim]).to_owned();
163        let x_second = x2.slice(s![.., half_dim..]).to_owned();
164
165        // rotate_half: [-x_second, x_first]
166        let mut rotated = Array2::<f64>::zeros((total, self.head_dim));
167        rotated.slice_mut(s![.., ..half_dim]).assign(&(-&x_second));
168        rotated.slice_mut(s![.., half_dim..]).assign(&x_first);
169
170        // Broadcast cos/sin for each position in [seq_offset, seq_offset + seq_len).
171        // Map each row of x2 to the corresponding position in the cache.
172        // Positions cycle through seq_len: row i -> seq_offset + (i / (total / seq_len))
173        let positions_per_token = total.checked_div(seq_len).unwrap_or(1);
174        let mut cos_expanded = Array2::<f64>::zeros((total, half_dim));
175        let mut sin_expanded = Array2::<f64>::zeros((total, half_dim));
176
177        for i in 0..total {
178            let pos = seq_offset + i / positions_per_token.max(1);
179            let capped_pos = pos.min(self.max_seq_len - 1);
180            cos_expanded
181                .slice_mut(s![i, ..])
182                .assign(&self.cos_cache.slice(s![capped_pos, ..]));
183            sin_expanded
184                .slice_mut(s![i, ..])
185                .assign(&self.sin_cache.slice(s![capped_pos, ..]));
186        }
187
188        // Repeat cos/sin to full head_dim by tiling: [total, half_dim] -> [total, head_dim]
189        let mut cos_full = Array2::<f64>::zeros((total, self.head_dim));
190        let mut sin_full = Array2::<f64>::zeros((total, self.head_dim));
191        cos_full.slice_mut(s![.., ..half_dim]).assign(&cos_expanded);
192        cos_full.slice_mut(s![.., half_dim..]).assign(&cos_expanded);
193        sin_full.slice_mut(s![.., ..half_dim]).assign(&sin_expanded);
194        sin_full.slice_mut(s![.., half_dim..]).assign(&sin_expanded);
195
196        // y = x * cos + rotate_half(x) * sin
197        let result2 = &x2 * &cos_full + &rotated * &sin_full;
198
199        // Reshape back to original shape.
200        let result = result2
201            .into_dyn()
202            .into_shape_with_order(IxDyn(shape))
203            .map_err(|_| PositionError::ShapeMismatch {
204                expected: shape.to_vec(),
205                got: vec![total, self.head_dim],
206            })?;
207
208        Ok(result)
209    }
210
211    /// Compute rotate_half: negate the first half and concatenate with the second half.
212    ///
213    /// Given `x` shaped `[..., head_dim]`, returns `[-x[..., head_dim/2:], x[..., :head_dim/2]]`.
214    pub fn rotate_half(x: &ArrayD<f64>) -> ArrayD<f64> {
215        let shape = x.shape();
216        let ndim = shape.len();
217        if ndim < 1 {
218            return x.to_owned();
219        }
220        let head_dim = shape[ndim - 1];
221        let half = head_dim / 2;
222        let total = x.len() / head_dim;
223
224        let x2 = x
225            .view()
226            .into_shape_with_order((total, head_dim))
227            .expect("rotate_half reshape");
228
229        let x_first = x2.slice(s![.., ..half]).to_owned();
230        let x_second = x2.slice(s![.., half..]).to_owned();
231
232        let mut out = Array2::<f64>::zeros((total, head_dim));
233        out.slice_mut(s![.., ..half]).assign(&(-&x_second));
234        out.slice_mut(s![.., half..]).assign(&x_first);
235
236        out.into_dyn()
237            .into_shape_with_order(IxDyn(shape))
238            .expect("rotate_half final reshape")
239    }
240
241    /// Return the pre-computed frequencies (cos values) at a specific position.
242    pub fn frequencies_at(&self, pos: usize) -> ArrayView1<'_, f64> {
243        let capped = pos.min(self.max_seq_len - 1);
244        self.cos_cache.slice(s![capped, ..])
245    }
246}
247
// ---------------------------------------------------------------------------
// Relative Position Bias (T5-style)
// ---------------------------------------------------------------------------
251
252/// T5-style relative position bias that adds a learned scalar bias to attention
253/// logits based on the relative distance between query and key positions.
254#[derive(Debug, Clone)]
255pub struct RelativePositionBias {
256    /// Number of attention heads.
257    pub num_heads: usize,
258    /// Number of learned buckets for distances.
259    pub num_buckets: usize,
260    /// Maximum distance to consider (beyond this, distances are clamped).
261    pub max_distance: usize,
262    /// If `true`, use separate buckets for forward and backward directions.
263    pub bidirectional: bool,
264    /// Learned bias table: shape `[num_buckets, num_heads]`.
265    biases: Array2<f64>,
266}
267
268impl RelativePositionBias {
269    /// Create a new relative position bias (zero-initialized).
270    pub fn new(
271        num_heads: usize,
272        num_buckets: usize,
273        max_distance: usize,
274        bidirectional: bool,
275    ) -> Self {
276        Self {
277            num_heads,
278            num_buckets,
279            max_distance,
280            bidirectional,
281            biases: Array2::<f64>::zeros((num_buckets, num_heads)),
282        }
283    }
284
285    /// Compute the attention bias matrix of shape `[num_heads, q_len, k_len]`.
286    ///
287    /// For each (q, k) pair the relative position `q - k` is mapped to a bucket
288    /// and the corresponding learned bias is looked up.
289    pub fn compute_bias(&self, query_len: usize, key_len: usize) -> Array3<f64> {
290        let mut bias = Array3::<f64>::zeros((self.num_heads, query_len, key_len));
291
292        for q in 0..query_len {
293            for k in 0..key_len {
294                let relative_position = q as i32 - k as i32;
295                let bucket = Self::relative_position_bucket(
296                    relative_position,
297                    self.bidirectional,
298                    self.num_buckets,
299                    self.max_distance,
300                );
301                for h in 0..self.num_heads {
302                    bias[[h, q, k]] = self.biases[[bucket, h]];
303                }
304            }
305        }
306
307        bias
308    }
309
310    /// Map a relative position to a bucket index.
311    ///
312    /// The first half of the buckets covers exact small distances linearly.
313    /// The second half covers larger distances logarithmically.
314    fn relative_position_bucket(
315        relative_position: i32,
316        bidirectional: bool,
317        num_buckets: usize,
318        max_distance: usize,
319    ) -> usize {
320        let mut n = num_buckets;
321        let mut relative = relative_position;
322
323        if bidirectional {
324            n /= 2;
325            // Positive distances get offset by n.
326            if relative_position > 0 {
327                // Offset into second half.
328                let pos_bucket =
329                    Self::distance_to_bucket(relative_position as usize, n, max_distance);
330                return (n + pos_bucket).min(num_buckets - 1);
331            }
332            relative = -relative;
333        } else {
334            relative = (-relative).max(0);
335        }
336
337        let distance = relative as usize;
338        Self::distance_to_bucket(distance, n, max_distance).min(num_buckets - 1)
339    }
340
341    /// Map an absolute distance to a bucket in `[0, n)`.
342    fn distance_to_bucket(distance: usize, n: usize, max_distance: usize) -> usize {
343        if n == 0 {
344            return 0;
345        }
346        let max_exact = n / 2;
347        if distance < max_exact {
348            // Linear range.
349            distance
350        } else {
351            // Logarithmic range.
352            let clamped = distance.min(max_distance);
353            let scale = (clamped as f64 / max_exact as f64).ln()
354                / (max_distance as f64 / max_exact as f64).ln().max(1e-10);
355            let bucket_offset = (scale * (n - max_exact) as f64) as usize;
356            (max_exact + bucket_offset).min(n - 1)
357        }
358    }
359
360    /// Update the learned bias table.
361    ///
362    /// `new_biases` must have shape `[num_buckets, num_heads]`.
363    pub fn update_biases(
364        &mut self,
365        new_biases: Array2<f64>,
366    ) -> std::result::Result<(), PositionError> {
367        let expected = vec![self.num_buckets, self.num_heads];
368        let got = new_biases.shape().to_vec();
369        if got != expected {
370            return Err(PositionError::ShapeMismatch { expected, got });
371        }
372        self.biases = new_biases;
373        Ok(())
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tensor of the given shape with every element set to `fill`.
    fn make_tensor(shape: &[usize], fill: f64) -> ArrayD<f64> {
        ArrayD::from_elem(IxDyn(shape), fill)
    }

    #[test]
    fn test_rope_new_builds_cache() {
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid head_dim");
        assert_eq!(
            rope.cos_cache.shape(),
            &[16, 4],
            "cos_cache shape [max_seq, half_dim]"
        );
        assert_eq!(
            rope.sin_cache.shape(),
            &[16, 4],
            "sin_cache shape [max_seq, half_dim]"
        );
    }

    #[test]
    fn test_rope_apply_preserves_shape() {
        let rope = RotaryPositionEmbedding::new(8, 32, 10000.0).expect("valid");
        let x = make_tensor(&[4, 8], 1.0);
        let result = rope.apply(&x, 0).expect("apply should succeed");
        assert_eq!(
            result.shape(),
            x.shape(),
            "output shape must match input shape"
        );
    }

    #[test]
    fn test_rope_apply_position_zero_is_identity() {
        // At position 0 every angle is 0, so cos = 1 and sin = 0.
        let rope = RotaryPositionEmbedding::new(4, 8, 10000.0).expect("valid");
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 4]), vec![1.0_f64, -2.0, 3.5, 0.25])
            .expect("build");
        let y = rope.apply(&x, 0).expect("apply should succeed");
        for (got, want) in y.iter().zip(x.iter()) {
            assert!(
                (got - want).abs() < 1e-12,
                "position 0 must leave the input unchanged"
            );
        }
    }

    #[test]
    fn test_rope_apply_preserves_row_norm() {
        // A rotation never changes the length of the vector it rotates.
        let rope = RotaryPositionEmbedding::new(4, 16, 10000.0).expect("valid");
        let data: Vec<f64> = (0..12).map(|v| f64::from(v) - 5.0).collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[3, 4]), data).expect("build");
        let y = rope.apply(&x, 5).expect("apply should succeed");

        let input: Vec<f64> = x.iter().copied().collect();
        let output: Vec<f64> = y.iter().copied().collect();
        for (row_in, row_out) in input.chunks(4).zip(output.chunks(4)) {
            let norm_in: f64 = row_in.iter().map(|v| v * v).sum();
            let norm_out: f64 = row_out.iter().map(|v| v * v).sum();
            assert!(
                (norm_in - norm_out).abs() < 1e-9,
                "rotation must preserve the squared norm of each row"
            );
        }
    }

    #[test]
    fn test_rope_apply_rejects_positions_past_cache() {
        let rope = RotaryPositionEmbedding::new(8, 8, 10000.0).expect("valid");
        let x = make_tensor(&[4, 8], 1.0);
        // Positions 6..10 are needed but only 0..8 are cached.
        assert!(
            matches!(
                rope.apply(&x, 6),
                Err(PositionError::SeqOffsetOutOfRange { .. })
            ),
            "positions past the cache should produce SeqOffsetOutOfRange"
        );
    }

    #[test]
    fn test_rope_rotate_half_correct() {
        // For a 4-D head_dim: [a, b, c, d] -> [-c, -d, a, b]
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 4]), data).expect("build");
        let rotated = RotaryPositionEmbedding::rotate_half(&x);
        let flat: Vec<f64> = rotated.iter().copied().collect();
        // First half: negated second half of input = [-3, -4]
        assert!(
            (flat[0] - (-3.0)).abs() < 1e-9,
            "first element should be -3"
        );
        assert!(
            (flat[1] - (-4.0)).abs() < 1e-9,
            "second element should be -4"
        );
        // Second half: first half of input = [1, 2]
        assert!((flat[2] - 1.0).abs() < 1e-9, "third element should be 1");
        assert!((flat[3] - 2.0).abs() < 1e-9, "fourth element should be 2");
    }

    #[test]
    fn test_rope_head_dim_odd_errors() {
        let result = RotaryPositionEmbedding::new(7, 16, 10000.0);
        assert!(
            matches!(result, Err(PositionError::HeadDimMustBeEven { .. })),
            "odd head_dim should produce HeadDimMustBeEven error"
        );
    }

    #[test]
    fn test_relative_position_bias_compute() {
        let rpb = RelativePositionBias::new(4, 32, 128, true);
        let bias = rpb.compute_bias(6, 10);
        assert_eq!(
            bias.shape(),
            &[4, 6, 10],
            "bias shape must be [num_heads, q_len, k_len]"
        );
    }

    #[test]
    fn test_relative_position_bias_symmetric_for_bidirectional() {
        // When bidirectional=true, offsets +5 and -5 have the same magnitude
        // but must land in different buckets (one per direction).
        let forward_bucket = RelativePositionBias::relative_position_bucket(5, true, 32, 64);
        let backward_bucket = RelativePositionBias::relative_position_bucket(-5, true, 32, 64);
        assert_ne!(
            forward_bucket, backward_bucket,
            "forward and backward positions should map to different buckets"
        );
    }

    #[test]
    fn test_relative_position_bucket_clamping() {
        // However large the distance, the bucket index must stay inside the table.
        let bucket = RelativePositionBias::relative_position_bucket(100000, false, 16, 128);
        assert!(bucket < 16, "bucket must be within [0, num_buckets)");
    }
}
472}