Skip to main content

uni_sparse_vector/
ops.rs

1//! Pure scoring/transform kernels over [`SparseVector`]. No graph/runtime
2//! dependencies — this is the analogue of `uni-btic`'s interval math, kept in
3//! the type crate so every layer (index, rerank, brute-force oracle) calls one
4//! canonical implementation.
5
6use crate::sparse::SparseVector;
7use std::cmp::Ordering;
8
9/// Dot product of two sparse vectors via a linear merge-join over their
10/// (ascending) term ids. This is the SPLADE/learned-sparse scoring primitive
11/// and the exact ground truth a brute-force oracle uses.
12///
13/// O(|a| + |b|). Relies on the [`SparseVector`] sorted-index invariant.
14pub fn sparse_dot(a: &SparseVector, b: &SparseVector) -> f32 {
15    let (ai, av) = (a.indices(), a.values());
16    let (bi, bv) = (b.indices(), b.values());
17    let mut i = 0;
18    let mut j = 0;
19    let mut acc = 0.0f32;
20    while i < ai.len() && j < bi.len() {
21        match ai[i].cmp(&bi[j]) {
22            Ordering::Less => i += 1,
23            Ordering::Greater => j += 1,
24            Ordering::Equal => {
25                acc += av[i] * bv[j];
26                i += 1;
27                j += 1;
28            }
29        }
30    }
31    acc
32}
33
34/// Euclidean (L2) norm of the weights.
35pub fn l2_norm(v: &SparseVector) -> f32 {
36    v.values().iter().map(|w| w * w).sum::<f32>().sqrt()
37}
38
39/// Keep only the `k` terms with the largest absolute weight, preserving the
40/// ascending-index invariant. This is the universal query-side latency lever
41/// for learned-sparse retrieval — high-DF / low-weight query terms dominate the
42/// posting-scan cost, so dropping them trades a little recall for large speedups.
43///
44/// Returns the input unchanged when `k >= len`. Ties are broken by keeping the
45/// lower term id (deterministic).
46pub fn prune_top_k(v: &SparseVector, k: usize) -> SparseVector {
47    if k >= v.len() {
48        return v.clone();
49    }
50    if k == 0 {
51        return SparseVector::new(Vec::new(), Vec::new()).expect("empty vector is always valid");
52    }
53
54    // Rank positions by descending |weight|, tie-break by ascending term id.
55    let mut order: Vec<usize> = (0..v.len()).collect();
56    let values = v.values();
57    let indices = v.indices();
58    order.sort_by(|&x, &y| {
59        values[y]
60            .abs()
61            .partial_cmp(&values[x].abs())
62            .unwrap_or(Ordering::Equal)
63            .then(indices[x].cmp(&indices[y]))
64    });
65    order.truncate(k);
66    // Re-sort the kept positions by term id to restore the invariant.
67    order.sort_unstable();
68
69    let kept_indices = order.iter().map(|&p| indices[p]).collect();
70    let kept_values = order.iter().map(|&p| values[p]).collect();
71    SparseVector::new(kept_indices, kept_values).expect("subset of a valid vector is valid")
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for building a vector from literal term ids and weights;
    // panics if the constructor rejects the input.
    macro_rules! sv {
        ($ids:expr, $weights:expr) => {
            SparseVector::new($ids, $weights).unwrap()
        };
    }

    #[test]
    fn dot_disjoint_is_zero() {
        // No term id appears on both sides, so nothing is accumulated.
        let odd_terms = sv!(vec![1, 3], vec![1.0, 1.0]);
        let even_terms = sv!(vec![2, 4], vec![1.0, 1.0]);
        assert_eq!(sparse_dot(&odd_terms, &even_terms), 0.0);
    }

    #[test]
    fn dot_full_overlap() {
        let lhs = sv!(vec![1, 2, 3], vec![1.0, 2.0, 3.0]);
        let rhs = sv!(vec![1, 2, 3], vec![4.0, 5.0, 6.0]);
        // Every term matches: 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32.
        assert_eq!(sparse_dot(&lhs, &rhs), 32.0);
    }

    #[test]
    fn dot_partial_overlap() {
        let lhs = sv!(vec![1, 5, 9], vec![2.0, 3.0, 4.0]);
        let rhs = sv!(vec![5, 9, 13], vec![10.0, 0.5, 1.0]);
        // Shared terms are 5 (3*10 = 30) and 9 (4*0.5 = 2), totalling 32.
        assert_eq!(sparse_dot(&lhs, &rhs), 32.0);
    }

    #[test]
    fn dot_is_commutative() {
        let lhs = sv!(vec![1, 4, 7], vec![1.5, -2.0, 3.0]);
        let rhs = sv!(vec![4, 7, 8], vec![2.0, 1.0, 9.0]);
        // Swapping the arguments must not change the score.
        assert_eq!(sparse_dot(&lhs, &rhs), sparse_dot(&rhs, &lhs));
    }

    #[test]
    fn dot_with_empty_is_zero() {
        let populated = sv!(vec![1, 2], vec![1.0, 1.0]);
        let nothing = sv!(vec![], vec![]);
        assert_eq!(sparse_dot(&populated, &nothing), 0.0);
    }

    #[test]
    fn l2_norm_basic() {
        // Classic 3-4-5 triangle.
        let v = sv!(vec![1, 2], vec![3.0, 4.0]);
        assert_eq!(l2_norm(&v), 5.0);
    }

    #[test]
    fn prune_keeps_largest_magnitude() {
        let full = sv!(vec![1, 2, 3, 4], vec![0.1, -5.0, 0.2, 3.0]);
        let top_two = prune_top_k(&full, 2);
        // Largest magnitudes belong to term 2 (|-5.0|) and term 4 (3.0); the
        // survivors come back in term-id order with their signs intact.
        assert_eq!(top_two.indices(), &[2, 4]);
        assert_eq!(top_two.values(), &[-5.0, 3.0]);
    }

    #[test]
    fn prune_k_ge_len_is_identity() {
        let v = sv!(vec![1, 2], vec![1.0, 2.0]);
        // Both k > len and k == len leave the vector untouched.
        assert_eq!(prune_top_k(&v, 5), v);
        assert_eq!(prune_top_k(&v, 2), v);
    }

    #[test]
    fn prune_k_zero_is_empty() {
        let v = sv!(vec![1, 2], vec![1.0, 2.0]);
        assert!(prune_top_k(&v, 0).is_empty());
    }

    #[test]
    fn prune_preserves_dot_on_kept_terms() {
        // Pruning only removes the contributions of the dropped terms, so a
        // document that touches none of them must score exactly the same
        // against the pruned query as against the full one.
        let query = sv!(vec![1, 2, 3], vec![10.0, 0.01, 9.0]);
        let doc = sv!(vec![1, 3], vec![2.0, 2.0]);
        let pruned_query = prune_top_k(&query, 2);
        assert_eq!(sparse_dot(&pruned_query, &doc), sparse_dot(&query, &doc));
    }
}