1#![deny(unsafe_code)]
13#![warn(missing_docs)]
14#![warn(rust_2018_idioms)]
15
16use std::cmp::Reverse;
17use std::collections::BinaryHeap;
18
19use ndarray::{ArrayView1, ArrayView2, ArrayViewMut2, Axis};
20use rayon::prelude::*;
21use thiserror::Error;
22
/// Magnitude threshold: an L2 norm (or product of two norms) at or below this
/// value is treated as zero by the normalisation and cosine helpers.
pub const EPS: f32 = 1e-12;
25
/// Result alias whose error type is [`VecNormError`].
pub type Result<T> = std::result::Result<T, VecNormError>;
28
29#[derive(Error, Debug)]
31pub enum VecNormError {
32 #[error("dimension mismatch: a={a:?}, b={b:?}")]
34 DimensionMismatch {
35 a: Vec<usize>,
37 b: Vec<usize>,
39 },
40 #[error("k ({k}) must be <= len ({len})")]
42 KTooLarge {
43 k: usize,
45 len: usize,
47 },
48 #[error("k must be > 0")]
50 KZero,
51}
52
53pub fn l2_normalize(matrix: &mut ArrayViewMut2<'_, f32>) {
56 matrix
57 .axis_iter_mut(Axis(0))
58 .into_par_iter()
59 .for_each(|mut row| {
60 let mut sum_sq = 0.0_f32;
61 for &x in row.iter() {
62 sum_sq += x * x;
63 }
64 let norm = sum_sq.sqrt();
65 if norm > EPS {
66 for x in row.iter_mut() {
67 *x /= norm;
68 }
69 } else {
70 for x in row.iter_mut() {
71 *x = 0.0;
72 }
73 }
74 });
75}
76
77pub fn l2_normalize_copy(matrix: &ArrayView2<'_, f32>) -> ndarray::Array2<f32> {
79 let mut out = matrix.to_owned();
80 l2_normalize(&mut out.view_mut());
81 out
82}
83
84pub fn cosine_similarity(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>) -> Result<f32> {
87 if a.len() != b.len() {
88 return Err(VecNormError::DimensionMismatch {
89 a: a.shape().to_vec(),
90 b: b.shape().to_vec(),
91 });
92 }
93 let mut dot = 0.0_f32;
94 let mut norm_a = 0.0_f32;
95 let mut norm_b = 0.0_f32;
96 for (&x, &y) in a.iter().zip(b.iter()) {
97 dot += x * y;
98 norm_a += x * x;
99 norm_b += y * y;
100 }
101 let denom = norm_a.sqrt() * norm_b.sqrt();
102 if denom <= EPS {
103 return Ok(0.0);
104 }
105 Ok(dot / denom)
106}
107
108pub fn dot_product(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>) -> Result<f32> {
111 if a.len() != b.len() {
112 return Err(VecNormError::DimensionMismatch {
113 a: a.shape().to_vec(),
114 b: b.shape().to_vec(),
115 });
116 }
117 let mut s = 0.0_f32;
118 for (&x, &y) in a.iter().zip(b.iter()) {
119 s += x * y;
120 }
121 Ok(s)
122}
123
124pub fn argmax(scores: &ArrayView1<'_, f32>) -> Result<(usize, f32)> {
127 if scores.is_empty() {
128 return Err(VecNormError::KZero);
129 }
130 let mut best_i = 0usize;
131 let mut best_v = scores[0];
132 for (i, &v) in scores.iter().enumerate().skip(1) {
133 if v > best_v {
134 best_v = v;
135 best_i = i;
136 }
137 }
138 Ok((best_i, best_v))
139}
140
141pub fn top_k_argmax(scores: &ArrayView1<'_, f32>, k: usize) -> Result<Vec<(usize, f32)>> {
144 if k == 0 {
145 return Err(VecNormError::KZero);
146 }
147 if k > scores.len() {
148 return Err(VecNormError::KTooLarge {
149 k,
150 len: scores.len(),
151 });
152 }
153 let mut heap: BinaryHeap<(Reverse<OrdFloat>, usize)> = BinaryHeap::with_capacity(k);
157 for (i, &s) in scores.iter().enumerate() {
158 let entry = (Reverse(OrdFloat(s)), i);
159 if heap.len() < k {
160 heap.push(entry);
161 } else if let Some(top) = heap.peek() {
162 if entry.0 < top.0 {
165 heap.pop();
166 heap.push(entry);
167 }
168 }
169 }
170 let mut out: Vec<(usize, f32)> = heap.into_iter().map(|(rs, i)| (i, rs.0 .0)).collect();
172 out.sort_by(|a, b| {
173 b.1.partial_cmp(&a.1)
174 .unwrap_or(std::cmp::Ordering::Equal)
175 .then(a.0.cmp(&b.0))
176 });
177 Ok(out)
178}
179
180pub fn batch_top_k_argmax(
183 scores: &ArrayView2<'_, f32>,
184 k: usize,
185 parallel: bool,
186) -> Result<Vec<Vec<(usize, f32)>>> {
187 if k == 0 {
188 return Err(VecNormError::KZero);
189 }
190 if k > scores.ncols() {
191 return Err(VecNormError::KTooLarge {
192 k,
193 len: scores.ncols(),
194 });
195 }
196 if parallel {
197 scores
198 .axis_iter(Axis(0))
199 .into_par_iter()
200 .map(|row| top_k_argmax(&row, k))
201 .collect()
202 } else {
203 scores
204 .axis_iter(Axis(0))
205 .map(|row| top_k_argmax(&row, k))
206 .collect()
207 }
208}
209
210pub fn cosine_distances(
215 a: &ArrayView2<'_, f32>,
216 b: &ArrayView2<'_, f32>,
217) -> Result<ndarray::Array2<f32>> {
218 if a.ncols() != b.ncols() {
219 return Err(VecNormError::DimensionMismatch {
220 a: a.shape().to_vec(),
221 b: b.shape().to_vec(),
222 });
223 }
224 let an = l2_normalize_copy(a);
225 let bn = l2_normalize_copy(b);
226 let n_a = an.nrows();
227 let n_b = bn.nrows();
228 let mut out = ndarray::Array2::<f32>::zeros((n_a, n_b));
229 out.axis_iter_mut(Axis(0))
230 .into_par_iter()
231 .enumerate()
232 .for_each(|(i, mut row)| {
233 for (j, cell) in row.iter_mut().enumerate() {
234 let mut dot = 0.0_f32;
235 for (&x, &y) in an.row(i).iter().zip(bn.row(j).iter()) {
236 dot += x * y;
237 }
238 *cell = 1.0 - dot;
239 }
240 });
241 Ok(out)
242}
243
/// Total-order wrapper around `f32` so scores can be stored in ordered collections.
///
/// Non-NaN values compare as usual (`-0.0` equals `0.0`). NaN is ordered below
/// every other value and all NaNs compare equal to one another.
#[derive(Debug, Clone, Copy)]
struct OrdFloat(f32);

// Equality is defined through `cmp` rather than derived: the derived impl
// would report `NaN != NaN` while `cmp` reports `Equal`, breaking the
// reflexivity required by `Eq` and the `PartialEq`/`Ord` consistency contract.
impl PartialEq for OrdFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdFloat {}

impl Ord for OrdFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Less,
            (false, true) => std::cmp::Ordering::Greater,
            // Neither side is NaN, so `partial_cmp` always yields `Some`;
            // the fallback only exists to avoid a panic path.
            (false, false) => self
                .0
                .partial_cmp(&other.0)
                .unwrap_or(std::cmp::Ordering::Equal),
        }
    }
}

impl PartialOrd for OrdFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, Array1, Array2};

    #[test]
    fn l2_normalize_basic() {
        let mut a = arr2(&[[3.0_f32, 4.0], [1.0, 0.0]]);
        l2_normalize(&mut a.view_mut());
        assert!((a[[0, 0]] - 0.6).abs() < 1e-6);
        assert!((a[[0, 1]] - 0.8).abs() < 1e-6);
        assert!((a[[1, 0]] - 1.0).abs() < 1e-6);
        assert!((a[[1, 1]] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_zero_row_left_zero() {
        let mut a = arr2(&[[0.0_f32, 0.0], [3.0, 4.0]]);
        l2_normalize(&mut a.view_mut());
        assert_eq!(a[[0, 0]], 0.0);
        assert_eq!(a[[0, 1]], 0.0);
        assert!(!a[[0, 0]].is_nan());
    }

    #[test]
    fn l2_normalize_copy_does_not_mutate_input() {
        let a = arr2(&[[3.0_f32, 4.0]]);
        let _ = l2_normalize_copy(&a.view());
        assert_eq!(a[[0, 0]], 3.0);
        assert_eq!(a[[0, 1]], 4.0);
    }

    #[test]
    fn cosine_basic() {
        let a = arr1(&[1.0_f32, 0.0]);
        let b = arr1(&[1.0_f32, 0.0]);
        let c = arr1(&[0.0_f32, 1.0]);
        assert!((cosine_similarity(&a.view(), &b.view()).unwrap() - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&a.view(), &c.view()).unwrap().abs() < 1e-6);
    }

    #[test]
    fn dot_product_basic() {
        let a = arr1(&[1.0_f32, 2.0, 3.0]);
        let b = arr1(&[4.0_f32, -5.0, 6.0]);
        assert!((dot_product(&a.view(), &b.view()).unwrap() - 12.0).abs() < 1e-6);
    }

    #[test]
    fn dot_product_dim_mismatch() {
        let a = arr1(&[1.0_f32, 0.0]);
        let b = arr1(&[1.0_f32]);
        assert!(dot_product(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn argmax_picks_largest() {
        let s = arr1(&[1.0_f32, 5.0, 3.0, 4.0, 2.0]);
        let (i, v) = argmax(&s.view()).unwrap();
        assert_eq!(i, 1);
        assert!((v - 5.0).abs() < 1e-6);
    }

    #[test]
    fn argmax_ties_pick_lowest_index() {
        let s = arr1(&[3.0_f32, 3.0, 3.0]);
        assert_eq!(argmax(&s.view()).unwrap().0, 0);
    }

    #[test]
    fn argmax_empty_rejected() {
        let s: ndarray::Array1<f32> = arr1(&[]);
        assert!(argmax(&s.view()).is_err());
    }

    #[test]
    fn cosine_zero_for_zero_vector() {
        let a = arr1(&[0.0_f32, 0.0]);
        let b = arr1(&[1.0_f32, 1.0]);
        assert_eq!(cosine_similarity(&a.view(), &b.view()).unwrap(), 0.0);
    }

    #[test]
    fn cosine_dim_mismatch() {
        let a = arr1(&[1.0_f32, 0.0]);
        let b = arr1(&[1.0_f32, 0.0, 1.0]);
        assert!(cosine_similarity(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn top_k_correct_order() {
        let s = arr1(&[1.0, 5.0, 3.0, 4.0, 2.0]);
        let r = top_k_argmax(&s.view(), 3).unwrap();
        assert_eq!(r, vec![(1, 5.0), (3, 4.0), (2, 3.0)]);
    }

    #[test]
    fn top_k_full_length_returns_full_sort() {
        let s = arr1(&[1.0, 5.0, 3.0]);
        let r = top_k_argmax(&s.view(), 3).unwrap();
        assert_eq!(r, vec![(1, 5.0), (2, 3.0), (0, 1.0)]);
    }

    #[test]
    fn top_k_ties_broken_by_lower_index() {
        let s = arr1(&[1.0, 1.0, 1.0]);
        let r = top_k_argmax(&s.view(), 2).unwrap();
        assert_eq!(r, vec![(0, 1.0), (1, 1.0)]);
    }

    #[test]
    fn top_k_zero_rejected() {
        let s = arr1(&[1.0, 2.0]);
        assert!(top_k_argmax(&s.view(), 0).is_err());
    }

    #[test]
    fn top_k_too_large_rejected() {
        let s = arr1(&[1.0, 2.0]);
        assert!(top_k_argmax(&s.view(), 3).is_err());
    }

    #[test]
    fn batch_top_k_serial_and_parallel_match() {
        let m = Array2::from_shape_fn((10, 50), |(i, j)| (i * 50 + j) as f32);
        let s = batch_top_k_argmax(&m.view(), 5, false).unwrap();
        let p = batch_top_k_argmax(&m.view(), 5, true).unwrap();
        assert_eq!(s, p);
        assert_eq!(s.len(), 10);
        assert_eq!(s[0][0], (49, 49.0));
    }

    #[test]
    fn batch_top_k_rejects_invalid_k() {
        let m = Array2::<f32>::zeros((2, 3));
        for parallel in [false, true] {
            assert!(batch_top_k_argmax(&m.view(), 0, parallel).is_err());
            assert!(batch_top_k_argmax(&m.view(), 4, parallel).is_err());
        }
    }

    #[test]
    fn cosine_distances_zero_diagonal() {
        let a = arr2(&[[1.0_f32, 0.0], [0.0, 1.0]]);
        let d = cosine_distances(&a.view(), &a.view()).unwrap();
        assert!(d[[0, 0]].abs() < 1e-6);
        assert!(d[[1, 1]].abs() < 1e-6);
        assert!((d[[0, 1]] - 1.0).abs() < 1e-6);
        assert!((d[[1, 0]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_distances_dim_mismatch() {
        let a = Array2::<f32>::zeros((4, 3));
        let b = Array2::<f32>::zeros((4, 5));
        assert!(cosine_distances(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn nan_in_top_k_does_not_panic() {
        let s = Array1::from(vec![1.0_f32, f32::NAN, 3.0]);
        // Not panicking is not enough: the NaN must rank below both real
        // scores and therefore be left out of the top two.
        let r = top_k_argmax(&s.view(), 2).unwrap();
        assert_eq!(r, vec![(2, 3.0), (0, 1.0)]);
    }
}