1use ndarray::{s, Array2, Array3, ArrayD, ArrayView1, IxDyn};
2use std::fmt;
3
/// Errors returned by [`RotaryPositionEmbedding`] and [`RelativePositionBias`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (e.g. with `assert_eq!`) instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PositionError {
    /// RoPE rotates pairs of dimensions, so the head dimension must be even.
    HeadDimMustBeEven { head_dim: usize },
    /// A requested position lies outside the pre-computed cos/sin tables.
    ///
    /// `offset` is the offending position and `max` the highest cached position.
    SeqOffsetOutOfRange { offset: usize, max: usize },
    /// An array had an unexpected shape.
    ///
    /// `expected`/`got` may describe the whole shape or only the relevant part
    /// of it (for example just the last axis).
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },
}

impl fmt::Display for PositionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::HeadDimMustBeEven { head_dim } => {
                write!(f, "head_dim must be even for RoPE, got {head_dim}")
            }
            Self::SeqOffsetOutOfRange { offset, max } => write!(
                f,
                "seq_offset {offset} is out of range (max pre-computed = {max})"
            ),
            Self::ShapeMismatch { expected, got } => {
                write!(f, "Shape mismatch: expected {expected:?}, got {got:?}")
            }
        }
    }
}

impl std::error::Error for PositionError {}
43
44#[derive(Debug, Clone)]
53pub struct RotaryPositionEmbedding {
54 pub head_dim: usize,
56 pub base: f64,
58 pub max_seq_len: usize,
60 cos_cache: Array2<f64>,
62 sin_cache: Array2<f64>,
64}
65
66impl RotaryPositionEmbedding {
67 pub fn new(
71 head_dim: usize,
72 max_seq_len: usize,
73 base: f64,
74 ) -> std::result::Result<Self, PositionError> {
75 if !head_dim.is_multiple_of(2) {
76 return Err(PositionError::HeadDimMustBeEven { head_dim });
77 }
78 let (cos_cache, sin_cache) = Self::build_cos_sin_cache(head_dim, max_seq_len, base);
79 Ok(Self {
80 head_dim,
81 base,
82 max_seq_len,
83 cos_cache,
84 sin_cache,
85 })
86 }
87
88 fn build_cos_sin_cache(
90 head_dim: usize,
91 max_seq_len: usize,
92 base: f64,
93 ) -> (Array2<f64>, Array2<f64>) {
94 let half_dim = head_dim / 2;
95 let thetas: Vec<f64> = (0..half_dim)
97 .map(|i| base.powf(-(2.0 * i as f64) / head_dim as f64))
98 .collect();
99
100 let mut cos_cache = Array2::<f64>::zeros((max_seq_len, half_dim));
101 let mut sin_cache = Array2::<f64>::zeros((max_seq_len, half_dim));
102
103 for pos in 0..max_seq_len {
104 for (i, &theta) in thetas.iter().enumerate() {
105 let angle = pos as f64 * theta;
106 cos_cache[[pos, i]] = angle.cos();
107 sin_cache[[pos, i]] = angle.sin();
108 }
109 }
110
111 (cos_cache, sin_cache)
112 }
113
114 pub fn apply(
119 &self,
120 x: &ArrayD<f64>,
121 seq_offset: usize,
122 ) -> std::result::Result<ArrayD<f64>, PositionError> {
123 let shape = x.shape();
124 let ndim = shape.len();
125 if ndim < 1 {
126 return Err(PositionError::ShapeMismatch {
127 expected: vec![1],
128 got: shape.to_vec(),
129 });
130 }
131
132 let last_dim = shape[ndim - 1];
133 if last_dim != self.head_dim {
134 return Err(PositionError::ShapeMismatch {
135 expected: vec![self.head_dim],
136 got: vec![last_dim],
137 });
138 }
139
140 let seq_len = shape[0];
141 if seq_offset + seq_len > self.max_seq_len {
142 return Err(PositionError::SeqOffsetOutOfRange {
143 offset: seq_offset + seq_len - 1,
144 max: self.max_seq_len - 1,
145 });
146 }
147
148 let half_dim = self.head_dim / 2;
149
150 let total = x.len() / self.head_dim;
153 let x2 = x
154 .view()
155 .into_shape_with_order((total, self.head_dim))
156 .map_err(|_| PositionError::ShapeMismatch {
157 expected: vec![total, self.head_dim],
158 got: shape.to_vec(),
159 })?;
160
161 let x_first = x2.slice(s![.., ..half_dim]).to_owned();
163 let x_second = x2.slice(s![.., half_dim..]).to_owned();
164
165 let mut rotated = Array2::<f64>::zeros((total, self.head_dim));
167 rotated.slice_mut(s![.., ..half_dim]).assign(&(-&x_second));
168 rotated.slice_mut(s![.., half_dim..]).assign(&x_first);
169
170 let positions_per_token = total.checked_div(seq_len).unwrap_or(1);
174 let mut cos_expanded = Array2::<f64>::zeros((total, half_dim));
175 let mut sin_expanded = Array2::<f64>::zeros((total, half_dim));
176
177 for i in 0..total {
178 let pos = seq_offset + i / positions_per_token.max(1);
179 let capped_pos = pos.min(self.max_seq_len - 1);
180 cos_expanded
181 .slice_mut(s![i, ..])
182 .assign(&self.cos_cache.slice(s![capped_pos, ..]));
183 sin_expanded
184 .slice_mut(s![i, ..])
185 .assign(&self.sin_cache.slice(s![capped_pos, ..]));
186 }
187
188 let mut cos_full = Array2::<f64>::zeros((total, self.head_dim));
190 let mut sin_full = Array2::<f64>::zeros((total, self.head_dim));
191 cos_full.slice_mut(s![.., ..half_dim]).assign(&cos_expanded);
192 cos_full.slice_mut(s![.., half_dim..]).assign(&cos_expanded);
193 sin_full.slice_mut(s![.., ..half_dim]).assign(&sin_expanded);
194 sin_full.slice_mut(s![.., half_dim..]).assign(&sin_expanded);
195
196 let result2 = &x2 * &cos_full + &rotated * &sin_full;
198
199 let result = result2
201 .into_dyn()
202 .into_shape_with_order(IxDyn(shape))
203 .map_err(|_| PositionError::ShapeMismatch {
204 expected: shape.to_vec(),
205 got: vec![total, self.head_dim],
206 })?;
207
208 Ok(result)
209 }
210
211 pub fn rotate_half(x: &ArrayD<f64>) -> ArrayD<f64> {
215 let shape = x.shape();
216 let ndim = shape.len();
217 if ndim < 1 {
218 return x.to_owned();
219 }
220 let head_dim = shape[ndim - 1];
221 let half = head_dim / 2;
222 let total = x.len() / head_dim;
223
224 let x2 = x
225 .view()
226 .into_shape_with_order((total, head_dim))
227 .expect("rotate_half reshape");
228
229 let x_first = x2.slice(s![.., ..half]).to_owned();
230 let x_second = x2.slice(s![.., half..]).to_owned();
231
232 let mut out = Array2::<f64>::zeros((total, head_dim));
233 out.slice_mut(s![.., ..half]).assign(&(-&x_second));
234 out.slice_mut(s![.., half..]).assign(&x_first);
235
236 out.into_dyn()
237 .into_shape_with_order(IxDyn(shape))
238 .expect("rotate_half final reshape")
239 }
240
241 pub fn frequencies_at(&self, pos: usize) -> ArrayView1<'_, f64> {
243 let capped = pos.min(self.max_seq_len - 1);
244 self.cos_cache.slice(s![capped, ..])
245 }
246}
247
248#[derive(Debug, Clone)]
255pub struct RelativePositionBias {
256 pub num_heads: usize,
258 pub num_buckets: usize,
260 pub max_distance: usize,
262 pub bidirectional: bool,
264 biases: Array2<f64>,
266}
267
268impl RelativePositionBias {
269 pub fn new(
271 num_heads: usize,
272 num_buckets: usize,
273 max_distance: usize,
274 bidirectional: bool,
275 ) -> Self {
276 Self {
277 num_heads,
278 num_buckets,
279 max_distance,
280 bidirectional,
281 biases: Array2::<f64>::zeros((num_buckets, num_heads)),
282 }
283 }
284
285 pub fn compute_bias(&self, query_len: usize, key_len: usize) -> Array3<f64> {
290 let mut bias = Array3::<f64>::zeros((self.num_heads, query_len, key_len));
291
292 for q in 0..query_len {
293 for k in 0..key_len {
294 let relative_position = q as i32 - k as i32;
295 let bucket = Self::relative_position_bucket(
296 relative_position,
297 self.bidirectional,
298 self.num_buckets,
299 self.max_distance,
300 );
301 for h in 0..self.num_heads {
302 bias[[h, q, k]] = self.biases[[bucket, h]];
303 }
304 }
305 }
306
307 bias
308 }
309
310 fn relative_position_bucket(
315 relative_position: i32,
316 bidirectional: bool,
317 num_buckets: usize,
318 max_distance: usize,
319 ) -> usize {
320 let mut n = num_buckets;
321 let mut relative = relative_position;
322
323 if bidirectional {
324 n /= 2;
325 if relative_position > 0 {
327 let pos_bucket =
329 Self::distance_to_bucket(relative_position as usize, n, max_distance);
330 return (n + pos_bucket).min(num_buckets - 1);
331 }
332 relative = -relative;
333 } else {
334 relative = (-relative).max(0);
335 }
336
337 let distance = relative as usize;
338 Self::distance_to_bucket(distance, n, max_distance).min(num_buckets - 1)
339 }
340
341 fn distance_to_bucket(distance: usize, n: usize, max_distance: usize) -> usize {
343 if n == 0 {
344 return 0;
345 }
346 let max_exact = n / 2;
347 if distance < max_exact {
348 distance
350 } else {
351 let clamped = distance.min(max_distance);
353 let scale = (clamped as f64 / max_exact as f64).ln()
354 / (max_distance as f64 / max_exact as f64).ln().max(1e-10);
355 let bucket_offset = (scale * (n - max_exact) as f64) as usize;
356 (max_exact + bucket_offset).min(n - 1)
357 }
358 }
359
360 pub fn update_biases(
364 &mut self,
365 new_biases: Array2<f64>,
366 ) -> std::result::Result<(), PositionError> {
367 let expected = vec![self.num_buckets, self.num_heads];
368 let got = new_biases.shape().to_vec();
369 if got != expected {
370 return Err(PositionError::ShapeMismatch { expected, got });
371 }
372 self.biases = new_biases;
373 Ok(())
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor of the given shape filled with `fill`.
    fn make_tensor(shape: &[usize], fill: f64) -> ArrayD<f64> {
        ArrayD::from_elem(IxDyn(shape), fill)
    }

    #[test]
    fn test_rope_new_builds_cache() {
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid head_dim");
        assert_eq!(
            rope.cos_cache.shape(),
            &[16, 4],
            "cos_cache shape [max_seq, half_dim]"
        );
        assert_eq!(
            rope.sin_cache.shape(),
            &[16, 4],
            "sin_cache shape [max_seq, half_dim]"
        );
    }

    #[test]
    fn test_rope_apply_preserves_shape() {
        let rope = RotaryPositionEmbedding::new(8, 32, 10000.0).expect("valid");
        let x = make_tensor(&[4, 8], 1.0);
        let result = rope.apply(&x, 0).expect("apply should succeed");
        assert_eq!(
            result.shape(),
            x.shape(),
            "output shape must match input shape"
        );
    }

    #[test]
    fn test_rope_apply_at_position_zero_is_identity() {
        // cos(0) = 1 and sin(0) = 0, so position 0 must leave the vector untouched.
        let rope = RotaryPositionEmbedding::new(8, 4, 10000.0).expect("valid");
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 8]), (1..=8).map(f64::from).collect())
            .expect("build");
        let result = rope.apply(&x, 0).expect("apply should succeed");
        assert_eq!(result, x, "position 0 is the identity rotation");
    }

    #[test]
    fn test_rope_apply_rotates_pair_by_position() {
        // With head_dim = 2 the only frequency is base^0 = 1, so position p
        // rotates by p radians.
        let rope = RotaryPositionEmbedding::new(2, 4, 10000.0).expect("valid");
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![1.0, 0.0]).expect("build");
        let result = rope.apply(&x, 1).expect("apply should succeed");
        let flat: Vec<f64> = result.iter().copied().collect();
        assert!((flat[0] - 1.0_f64.cos()).abs() < 1e-12, "x' = cos(1)");
        assert!((flat[1] - 1.0_f64.sin()).abs() < 1e-12, "y' = sin(1)");
    }

    #[test]
    fn test_rope_apply_rejects_wrong_last_dim() {
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid");
        let x = make_tensor(&[2, 6], 1.0);
        assert!(
            matches!(rope.apply(&x, 0), Err(PositionError::ShapeMismatch { .. })),
            "last axis must equal head_dim"
        );
    }

    #[test]
    fn test_rope_apply_rejects_positions_past_cache() {
        let rope = RotaryPositionEmbedding::new(8, 4, 10000.0).expect("valid");
        let x = make_tensor(&[4, 8], 1.0);
        assert!(
            matches!(
                rope.apply(&x, 1),
                Err(PositionError::SeqOffsetOutOfRange { offset: 4, max: 3 })
            ),
            "positions 1..5 exceed a 4-entry cache"
        );
    }

    #[test]
    fn test_rope_frequencies_at_zero_are_ones() {
        let rope = RotaryPositionEmbedding::new(8, 16, 10000.0).expect("valid");
        assert!(
            rope.frequencies_at(0).iter().all(|&c| (c - 1.0).abs() < 1e-12),
            "cos(0 * theta) == 1 for every frequency"
        );
    }

    #[test]
    fn test_rope_rotate_half_correct() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        let x = ArrayD::from_shape_vec(IxDyn(&[1, 4]), data).expect("build");
        let rotated = RotaryPositionEmbedding::rotate_half(&x);
        let flat: Vec<f64> = rotated.iter().copied().collect();
        assert!(
            (flat[0] - (-3.0)).abs() < 1e-9,
            "first element should be -3"
        );
        assert!(
            (flat[1] - (-4.0)).abs() < 1e-9,
            "second element should be -4"
        );
        assert!((flat[2] - 1.0).abs() < 1e-9, "third element should be 1");
        assert!((flat[3] - 2.0).abs() < 1e-9, "fourth element should be 2");
    }

    #[test]
    fn test_rope_head_dim_odd_errors() {
        let result = RotaryPositionEmbedding::new(7, 16, 10000.0);
        assert!(
            matches!(result, Err(PositionError::HeadDimMustBeEven { .. })),
            "odd head_dim should produce HeadDimMustBeEven error"
        );
    }

    #[test]
    fn test_relative_position_bias_compute() {
        let rpb = RelativePositionBias::new(4, 32, 128, true);
        let bias = rpb.compute_bias(6, 10);
        assert_eq!(
            bias.shape(),
            &[4, 6, 10],
            "bias shape must be [num_heads, q_len, k_len]"
        );
    }

    #[test]
    fn test_relative_position_bias_direction_sensitive_for_bidirectional() {
        // In bidirectional mode, +d and -d must NOT share a bucket.
        let forward_bucket = RelativePositionBias::relative_position_bucket(5, true, 32, 64);
        let backward_bucket = RelativePositionBias::relative_position_bucket(-5, true, 32, 64);
        assert_ne!(
            forward_bucket, backward_bucket,
            "forward and backward positions should map to different buckets"
        );
    }

    #[test]
    fn test_relative_position_bucket_clamping() {
        let bucket = RelativePositionBias::relative_position_bucket(100000, false, 16, 128);
        assert!(bucket < 16, "bucket must be within [0, num_buckets)");
    }

    #[test]
    fn test_update_biases_checks_shape() {
        let mut rpb = RelativePositionBias::new(2, 8, 16, true);
        assert!(
            matches!(
                rpb.update_biases(Array2::zeros((2, 8))),
                Err(PositionError::ShapeMismatch { .. })
            ),
            "transposed table must be rejected"
        );
        assert!(rpb.update_biases(Array2::zeros((8, 2))).is_ok());
    }

    #[test]
    fn test_compute_bias_diagonal_uses_bucket_zero() {
        // Offset 0 (query == key) maps to bucket 0 in every mode.
        let mut rpb = RelativePositionBias::new(2, 8, 16, true);
        let table = Array2::from_shape_fn((8, 2), |(bucket, head)| (bucket * 10 + head) as f64);
        rpb.update_biases(table).expect("shape matches");
        let bias = rpb.compute_bias(3, 3);
        for h in 0..2 {
            for i in 0..3 {
                assert_eq!(bias[[h, i, i]], h as f64, "diagonal reads bucket 0");
            }
        }
    }
}