Skip to main content

vecnorm_core/
lib.rs

//! Pure-Rust core for `vecnorm`: bulk `f32` vector and matrix operations.
//!
//! - [`l2_normalize`] / [`l2_normalize_copy`] — row-wise unit-length scaling.
//!   Rows whose norm is `<= EPS` are set to all-zero instead of being divided
//!   by a (near-)zero norm.
//! - [`cosine_similarity`] / [`dot_product`] — a single pair of 1-D vectors.
//!   `cosine_similarity` returns 0 when either side has zero norm.
//! - [`cosine_distances`] — pairwise `1 - cos` matrix between two sets of rows.
//! - [`argmax`] — the single best element of a 1-D score vector.
//! - [`top_k_argmax`] / [`batch_top_k_argmax`] — partial-heap top-k that
//!   runs in `O(n log k)`. Tied scores are broken by ascending original
//!   index (deterministic).
11
12#![deny(unsafe_code)]
13#![warn(missing_docs)]
14#![warn(rust_2018_idioms)]
15
16use std::cmp::Reverse;
17use std::collections::BinaryHeap;
18
19use ndarray::{ArrayView1, ArrayView2, ArrayViewMut2, Axis};
20use rayon::prelude::*;
21use thiserror::Error;
22
/// Norm cutoff below which a vector is treated as all-zero.
///
/// [`l2_normalize`] overwrites any row whose L2 norm is `<= EPS` with zeros
/// instead of dividing by that (near-)zero norm, and [`cosine_similarity`]
/// uses it as its zero-norm cutoff (returning `0.0`).
pub const EPS: f32 = 1e-12;
25
26/// Crate-wide result alias.
27pub type Result<T> = std::result::Result<T, VecNormError>;
28
29/// All errors surfaced by `vecnorm-core`.
30#[derive(Error, Debug)]
31pub enum VecNormError {
32    /// Two arrays had incompatible shapes.
33    #[error("dimension mismatch: a={a:?}, b={b:?}")]
34    DimensionMismatch {
35        /// Shape of the first input.
36        a: Vec<usize>,
37        /// Shape of the second input.
38        b: Vec<usize>,
39    },
40    /// Caller asked for more elements than the input has.
41    #[error("k ({k}) must be <= len ({len})")]
42    KTooLarge {
43        /// Requested k.
44        k: usize,
45        /// Available length.
46        len: usize,
47    },
48    /// Caller passed `k = 0`.
49    #[error("k must be > 0")]
50    KZero,
51}
52
53/// L2-normalize `matrix` in place, row by row. Rows with norm below `EPS`
54/// are zeroed out (i.e. left unchanged at all-zero) to avoid NaN.
55pub fn l2_normalize(matrix: &mut ArrayViewMut2<'_, f32>) {
56    matrix
57        .axis_iter_mut(Axis(0))
58        .into_par_iter()
59        .for_each(|mut row| {
60            let mut sum_sq = 0.0_f32;
61            for &x in row.iter() {
62                sum_sq += x * x;
63            }
64            let norm = sum_sq.sqrt();
65            if norm > EPS {
66                for x in row.iter_mut() {
67                    *x /= norm;
68                }
69            } else {
70                for x in row.iter_mut() {
71                    *x = 0.0;
72                }
73            }
74        });
75}
76
77/// L2-normalize a copy. Same semantics as [`l2_normalize`].
78pub fn l2_normalize_copy(matrix: &ArrayView2<'_, f32>) -> ndarray::Array2<f32> {
79    let mut out = matrix.to_owned();
80    l2_normalize(&mut out.view_mut());
81    out
82}
83
84/// Cosine similarity between two 1-D vectors. Returns 0 if either side is
85/// all-zero.
86pub fn cosine_similarity(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>) -> Result<f32> {
87    if a.len() != b.len() {
88        return Err(VecNormError::DimensionMismatch {
89            a: a.shape().to_vec(),
90            b: b.shape().to_vec(),
91        });
92    }
93    let mut dot = 0.0_f32;
94    let mut norm_a = 0.0_f32;
95    let mut norm_b = 0.0_f32;
96    for (&x, &y) in a.iter().zip(b.iter()) {
97        dot += x * y;
98        norm_a += x * x;
99        norm_b += y * y;
100    }
101    let denom = norm_a.sqrt() * norm_b.sqrt();
102    if denom <= EPS {
103        return Ok(0.0);
104    }
105    Ok(dot / denom)
106}
107
108/// Inner product (dot product) of two 1-D vectors. No normalization.
109/// Errors on dim mismatch.
110pub fn dot_product(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>) -> Result<f32> {
111    if a.len() != b.len() {
112        return Err(VecNormError::DimensionMismatch {
113            a: a.shape().to_vec(),
114            b: b.shape().to_vec(),
115        });
116    }
117    let mut s = 0.0_f32;
118    for (&x, &y) in a.iter().zip(b.iter()) {
119        s += x * y;
120    }
121    Ok(s)
122}
123
124/// Single argmax: returns `(index, score)` of the largest element. Ties
125/// broken by ascending index. Errors on empty input.
126pub fn argmax(scores: &ArrayView1<'_, f32>) -> Result<(usize, f32)> {
127    if scores.is_empty() {
128        return Err(VecNormError::KZero);
129    }
130    let mut best_i = 0usize;
131    let mut best_v = scores[0];
132    for (i, &v) in scores.iter().enumerate().skip(1) {
133        if v > best_v {
134            best_v = v;
135            best_i = i;
136        }
137    }
138    Ok((best_i, best_v))
139}
140
141/// Top-k argmax over a 1-D score vector. Returns `(index, score)` pairs in
142/// descending order. Ties broken by ascending index.
143pub fn top_k_argmax(scores: &ArrayView1<'_, f32>, k: usize) -> Result<Vec<(usize, f32)>> {
144    if k == 0 {
145        return Err(VecNormError::KZero);
146    }
147    if k > scores.len() {
148        return Err(VecNormError::KTooLarge {
149            k,
150            len: scores.len(),
151        });
152    }
153    // Maintain a min-heap of size k. The smallest element on the heap is
154    // the threshold to beat. We compare on `(Reverse(score), idx)` so equal
155    // scores order ascending by index, which matches the stable convention.
156    let mut heap: BinaryHeap<(Reverse<OrdFloat>, usize)> = BinaryHeap::with_capacity(k);
157    for (i, &s) in scores.iter().enumerate() {
158        let entry = (Reverse(OrdFloat(s)), i);
159        if heap.len() < k {
160            heap.push(entry);
161        } else if let Some(top) = heap.peek() {
162            // Heap is a min-heap on score (because of Reverse); the *largest*
163            // Reverse-key is the smallest score on the heap.
164            if entry.0 < top.0 {
165                heap.pop();
166                heap.push(entry);
167            }
168        }
169    }
170    // Drain heap and sort descending.
171    let mut out: Vec<(usize, f32)> = heap.into_iter().map(|(rs, i)| (i, rs.0 .0)).collect();
172    out.sort_by(|a, b| {
173        b.1.partial_cmp(&a.1)
174            .unwrap_or(std::cmp::Ordering::Equal)
175            .then(a.0.cmp(&b.0))
176    });
177    Ok(out)
178}
179
180/// Batch top-k argmax over an `(n_rows, n_cols)` matrix. With `parallel = true`
181/// distributes rows across rayon's pool.
182pub fn batch_top_k_argmax(
183    scores: &ArrayView2<'_, f32>,
184    k: usize,
185    parallel: bool,
186) -> Result<Vec<Vec<(usize, f32)>>> {
187    if k == 0 {
188        return Err(VecNormError::KZero);
189    }
190    if k > scores.ncols() {
191        return Err(VecNormError::KTooLarge {
192            k,
193            len: scores.ncols(),
194        });
195    }
196    if parallel {
197        scores
198            .axis_iter(Axis(0))
199            .into_par_iter()
200            .map(|row| top_k_argmax(&row, k))
201            .collect()
202    } else {
203        scores
204            .axis_iter(Axis(0))
205            .map(|row| top_k_argmax(&row, k))
206            .collect()
207    }
208}
209
210/// Cosine distance matrix between two `(n_a, d)` and `(n_b, d)` matrices.
211/// Returns an `(n_a, n_b)` matrix where `out[i, j]` is the cosine distance
212/// `1 - cos(a_i, b_j)`. Inputs are not modified; this normalizes copies
213/// internally so accuracy is preserved on un-normalized inputs.
214pub fn cosine_distances(
215    a: &ArrayView2<'_, f32>,
216    b: &ArrayView2<'_, f32>,
217) -> Result<ndarray::Array2<f32>> {
218    if a.ncols() != b.ncols() {
219        return Err(VecNormError::DimensionMismatch {
220            a: a.shape().to_vec(),
221            b: b.shape().to_vec(),
222        });
223    }
224    let an = l2_normalize_copy(a);
225    let bn = l2_normalize_copy(b);
226    let n_a = an.nrows();
227    let n_b = bn.nrows();
228    let mut out = ndarray::Array2::<f32>::zeros((n_a, n_b));
229    out.axis_iter_mut(Axis(0))
230        .into_par_iter()
231        .enumerate()
232        .for_each(|(i, mut row)| {
233            for (j, cell) in row.iter_mut().enumerate() {
234                let mut dot = 0.0_f32;
235                for (&x, &y) in an.row(i).iter().zip(bn.row(j).iter()) {
236                    dot += x * y;
237                }
238                *cell = 1.0 - dot;
239            }
240        });
241    Ok(out)
242}
243
244// ---- internal: Ord-able f32 wrapper ----
245
/// Totally-ordered `f32` wrapper used as a heap key by [`top_k_argmax`].
///
/// Ordering rules:
/// - `NaN` is smaller than every non-`NaN` value (including `-inf`).
/// - All `NaN`s compare equal to each other.
/// - Non-`NaN` values use the usual numeric order, so `-0.0 == 0.0`.
///
/// `PartialEq` is implemented by hand rather than derived so that equality
/// agrees with `Ord`. The derived `f32` equality makes `NaN != NaN`, which
/// breaks `Eq`'s reflexivity and the `PartialEq`/`PartialOrd` consistency
/// that `BinaryHeap` and sorting assume.
#[derive(Debug, Clone, Copy)]
struct OrdFloat(f32);

impl PartialEq for OrdFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdFloat {}

impl Ord for OrdFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Neither side is NaN, so `partial_cmp` always returns `Some`.
            (false, false) => self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal),
        }
    }
}

impl PartialOrd for OrdFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
275
#[cfg(test)]
mod tests {
    // Unit tests for every public entry point. The edge-case tests at the end
    // (tiny-norm rows, tie thresholds, k validation in batch mode, rectangular
    // distance matrices) cover paths the basic tests do not reach.
    use super::*;
    use ndarray::{arr1, arr2, Array1, Array2};

    #[test]
    fn l2_normalize_basic() {
        let mut a = arr2(&[[3.0_f32, 4.0], [1.0, 0.0]]);
        l2_normalize(&mut a.view_mut());
        // Row 0 norm 5 -> [0.6, 0.8]
        assert!((a[[0, 0]] - 0.6).abs() < 1e-6);
        assert!((a[[0, 1]] - 0.8).abs() < 1e-6);
        // Row 1 norm 1 -> [1.0, 0.0]
        assert!((a[[1, 0]] - 1.0).abs() < 1e-6);
        assert!((a[[1, 1]] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_zero_row_left_zero() {
        let mut a = arr2(&[[0.0_f32, 0.0], [3.0, 4.0]]);
        l2_normalize(&mut a.view_mut());
        assert_eq!(a[[0, 0]], 0.0);
        assert_eq!(a[[0, 1]], 0.0);
        assert!(!a[[0, 0]].is_nan());
    }

    #[test]
    fn l2_normalize_copy_does_not_mutate_input() {
        let a = arr2(&[[3.0_f32, 4.0]]);
        let _ = l2_normalize_copy(&a.view());
        assert_eq!(a[[0, 0]], 3.0);
        assert_eq!(a[[0, 1]], 4.0);
    }

    #[test]
    fn cosine_basic() {
        let a = arr1(&[1.0_f32, 0.0]);
        let b = arr1(&[1.0_f32, 0.0]);
        let c = arr1(&[0.0_f32, 1.0]);
        assert!((cosine_similarity(&a.view(), &b.view()).unwrap() - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&a.view(), &c.view()).unwrap().abs() < 1e-6);
    }

    #[test]
    fn dot_product_basic() {
        let a = arr1(&[1.0_f32, 2.0, 3.0]);
        let b = arr1(&[4.0_f32, -5.0, 6.0]);
        // 1*4 + 2*(-5) + 3*6 = 4 - 10 + 18 = 12.
        assert!((dot_product(&a.view(), &b.view()).unwrap() - 12.0).abs() < 1e-6);
    }

    #[test]
    fn dot_product_dim_mismatch() {
        let a = arr1(&[1.0_f32, 0.0]);
        let b = arr1(&[1.0_f32]);
        assert!(dot_product(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn argmax_picks_largest() {
        let s = arr1(&[1.0_f32, 5.0, 3.0, 4.0, 2.0]);
        let (i, v) = argmax(&s.view()).unwrap();
        assert_eq!(i, 1);
        assert!((v - 5.0).abs() < 1e-6);
    }

    #[test]
    fn argmax_ties_pick_lowest_index() {
        let s = arr1(&[3.0_f32, 3.0, 3.0]);
        assert_eq!(argmax(&s.view()).unwrap().0, 0);
    }

    #[test]
    fn argmax_empty_rejected() {
        let s: ndarray::Array1<f32> = arr1(&[]);
        assert!(argmax(&s.view()).is_err());
    }

    #[test]
    fn cosine_zero_for_zero_vector() {
        let a = arr1(&[0.0_f32, 0.0]);
        let b = arr1(&[1.0_f32, 1.0]);
        assert_eq!(cosine_similarity(&a.view(), &b.view()).unwrap(), 0.0);
    }

    #[test]
    fn cosine_dim_mismatch() {
        let a = arr1(&[1.0_f32, 0.0]);
        let b = arr1(&[1.0_f32, 0.0, 1.0]);
        assert!(cosine_similarity(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn top_k_correct_order() {
        let s = arr1(&[1.0, 5.0, 3.0, 4.0, 2.0]);
        let r = top_k_argmax(&s.view(), 3).unwrap();
        assert_eq!(r, vec![(1, 5.0), (3, 4.0), (2, 3.0)]);
    }

    #[test]
    fn top_k_full_length_returns_full_sort() {
        let s = arr1(&[1.0, 5.0, 3.0]);
        let r = top_k_argmax(&s.view(), 3).unwrap();
        assert_eq!(r, vec![(1, 5.0), (2, 3.0), (0, 1.0)]);
    }

    #[test]
    fn top_k_ties_broken_by_lower_index() {
        let s = arr1(&[1.0, 1.0, 1.0]);
        let r = top_k_argmax(&s.view(), 2).unwrap();
        assert_eq!(r, vec![(0, 1.0), (1, 1.0)]);
    }

    #[test]
    fn top_k_zero_rejected() {
        let s = arr1(&[1.0, 2.0]);
        assert!(top_k_argmax(&s.view(), 0).is_err());
    }

    #[test]
    fn top_k_too_large_rejected() {
        let s = arr1(&[1.0, 2.0]);
        assert!(top_k_argmax(&s.view(), 3).is_err());
    }

    #[test]
    fn batch_top_k_serial_and_parallel_match() {
        let m = Array2::from_shape_fn((10, 50), |(i, j)| (i * 50 + j) as f32);
        let s = batch_top_k_argmax(&m.view(), 5, false).unwrap();
        let p = batch_top_k_argmax(&m.view(), 5, true).unwrap();
        assert_eq!(s, p);
        assert_eq!(s.len(), 10);
        // First row: top-5 of [0..50) is [49, 48, 47, 46, 45].
        assert_eq!(s[0][0], (49, 49.0));
    }

    #[test]
    fn cosine_distances_zero_diagonal() {
        let a = arr2(&[[1.0_f32, 0.0], [0.0, 1.0]]);
        let d = cosine_distances(&a.view(), &a.view()).unwrap();
        // Diagonal is cosine to self == 0 distance.
        assert!(d[[0, 0]].abs() < 1e-6);
        assert!(d[[1, 1]].abs() < 1e-6);
        // Off-diagonal: orthogonal == 1 distance.
        assert!((d[[0, 1]] - 1.0).abs() < 1e-6);
        assert!((d[[1, 0]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_distances_dim_mismatch() {
        let a = Array2::<f32>::zeros((4, 3));
        let b = Array2::<f32>::zeros((4, 5));
        assert!(cosine_distances(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn nan_in_top_k_does_not_panic() {
        let s = Array1::from(vec![1.0_f32, f32::NAN, 3.0]);
        // We don't promise NaN handling, but we promise no panic.
        let r = top_k_argmax(&s.view(), 2);
        assert!(r.is_ok());
    }

    #[test]
    fn l2_normalize_tiny_row_is_zeroed() {
        // Norm ~1e-13 is below EPS, so the row is zeroed rather than scaled.
        let mut a = arr2(&[[1e-13_f32, 0.0], [0.0, 2.0]]);
        l2_normalize(&mut a.view_mut());
        assert_eq!(a[[0, 0]], 0.0);
        assert_eq!(a[[0, 1]], 0.0);
        assert!((a[[1, 1]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn argmax_negative_values() {
        let s = arr1(&[-3.0_f32, -1.0, -2.0]);
        assert_eq!(argmax(&s.view()).unwrap(), (1, -1.0));
    }

    #[test]
    fn top_k_one_matches_argmax() {
        let s = arr1(&[2.0_f32, 7.0, 7.0, 1.0]);
        let top = top_k_argmax(&s.view(), 1).unwrap();
        assert_eq!(top, vec![argmax(&s.view()).unwrap()]);
        assert_eq!(top, vec![(1, 7.0)]);
    }

    #[test]
    fn top_k_ties_at_threshold_keep_lower_indices() {
        // Three 3.0s compete for the last two slots; the earliest two win.
        let s = arr1(&[5.0_f32, 3.0, 3.0, 3.0, 1.0]);
        let r = top_k_argmax(&s.view(), 3).unwrap();
        assert_eq!(r, vec![(0, 5.0), (1, 3.0), (2, 3.0)]);
    }

    #[test]
    fn batch_top_k_rejects_bad_k() {
        let m = Array2::<f32>::zeros((3, 4));
        assert!(batch_top_k_argmax(&m.view(), 0, false).is_err());
        assert!(batch_top_k_argmax(&m.view(), 5, true).is_err());
    }

    #[test]
    fn cosine_distances_rectangular_matches_similarity() {
        let a = arr2(&[[1.0_f32, 2.0, 3.0], [0.0, 0.0, 1.0]]);
        let b = arr2(&[[4.0_f32, 5.0, 6.0], [0.0, 0.0, 2.0], [1.0, 0.0, 0.0]]);
        let d = cosine_distances(&a.view(), &b.view()).unwrap();
        assert_eq!(d.dim(), (2, 3));
        for i in 0..2 {
            for j in 0..3 {
                let cos = cosine_similarity(&a.row(i), &b.row(j)).unwrap();
                assert!((d[[i, j]] - (1.0 - cos)).abs() < 1e-5);
            }
        }
    }
}