Skip to main content

uni_sparse_vector/
ops.rs

//! Pure scoring/transform kernels over [`SparseVector`]. No graph/runtime
//! dependencies — this is the analogue of `uni-btic`'s interval math, kept in
//! the type crate so every layer (index, rerank, brute-force oracle) calls one
//! canonical implementation.
5
6use crate::sparse::SparseVector;
7
8/// Dot product of two sparse vectors via a linear merge-join over their
9/// (ascending) term ids. This is the SPLADE/learned-sparse scoring primitive
10/// and the exact ground truth a brute-force oracle uses.
11///
12/// O(|a| + |b|). Relies on the [`SparseVector`] sorted-index invariant.
13pub fn sparse_dot(a: &SparseVector, b: &SparseVector) -> f32 {
14    let (ai, av) = (a.indices(), a.values());
15    let (bi, bv) = (b.indices(), b.values());
16    let mut i = 0;
17    let mut j = 0;
18    let mut acc = 0.0f32;
19    while i < ai.len() && j < bi.len() {
20        match ai[i].cmp(&bi[j]) {
21            std::cmp::Ordering::Less => i += 1,
22            std::cmp::Ordering::Greater => j += 1,
23            std::cmp::Ordering::Equal => {
24                acc += av[i] * bv[j];
25                i += 1;
26                j += 1;
27            }
28        }
29    }
30    acc
31}
32
33/// Euclidean (L2) norm of the weights.
34pub fn l2_norm(v: &SparseVector) -> f32 {
35    v.values().iter().map(|w| w * w).sum::<f32>().sqrt()
36}
37
38/// Keep only the `k` terms with the largest absolute weight, preserving the
39/// ascending-index invariant. This is the universal query-side latency lever
40/// for learned-sparse retrieval — high-DF / low-weight query terms dominate the
41/// posting-scan cost, so dropping them trades a little recall for large speedups.
42///
43/// Returns the input unchanged when `k >= len`. Ties are broken by keeping the
44/// lower term id (deterministic).
45pub fn prune_top_k(v: &SparseVector, k: usize) -> SparseVector {
46    if k >= v.len() {
47        return v.clone();
48    }
49    if k == 0 {
50        return SparseVector::new(Vec::new(), Vec::new()).expect("empty vector is always valid");
51    }
52
53    // Rank positions by descending |weight|, tie-break by ascending term id.
54    let mut order: Vec<usize> = (0..v.len()).collect();
55    let values = v.values();
56    let indices = v.indices();
57    order.sort_by(|&x, &y| {
58        values[y]
59            .abs()
60            .partial_cmp(&values[x].abs())
61            .unwrap_or(std::cmp::Ordering::Equal)
62            .then(indices[x].cmp(&indices[y]))
63    });
64    order.truncate(k);
65    // Re-sort the kept positions by term id to restore the invariant.
66    order.sort_unstable();
67
68    let kept_indices = order.iter().map(|&p| indices[p]).collect();
69    let kept_values = order.iter().map(|&p| values[p]).collect();
70    SparseVector::new(kept_indices, kept_values).expect("subset of a valid vector is valid")
71}
72
#[cfg(test)]
mod tests {
    //! Unit tests for the sparse kernels, including edge cases (empty input,
    //! tie-breaking) that the public docs promise.
    use super::*;

    #[test]
    fn dot_disjoint_is_zero() {
        let a = SparseVector::new(vec![1, 3], vec![1.0, 1.0]).unwrap();
        let b = SparseVector::new(vec![2, 4], vec![1.0, 1.0]).unwrap();
        assert_eq!(sparse_dot(&a, &b), 0.0);
    }

    #[test]
    fn dot_full_overlap() {
        let a = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let b = SparseVector::new(vec![1, 2, 3], vec![4.0, 5.0, 6.0]).unwrap();
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert_eq!(sparse_dot(&a, &b), 32.0);
    }

    #[test]
    fn dot_partial_overlap() {
        let a = SparseVector::new(vec![1, 5, 9], vec![2.0, 3.0, 4.0]).unwrap();
        let b = SparseVector::new(vec![5, 9, 13], vec![10.0, 0.5, 1.0]).unwrap();
        // overlap on 5 (3*10=30) and 9 (4*0.5=2) => 32
        assert_eq!(sparse_dot(&a, &b), 32.0);
    }

    #[test]
    fn dot_is_commutative() {
        let a = SparseVector::new(vec![1, 4, 7], vec![1.5, -2.0, 3.0]).unwrap();
        let b = SparseVector::new(vec![4, 7, 8], vec![2.0, 1.0, 9.0]).unwrap();
        assert_eq!(sparse_dot(&a, &b), sparse_dot(&b, &a));
    }

    #[test]
    fn dot_with_empty_is_zero() {
        let a = SparseVector::new(vec![1, 2], vec![1.0, 1.0]).unwrap();
        let empty = SparseVector::new(vec![], vec![]).unwrap();
        assert_eq!(sparse_dot(&a, &empty), 0.0);
        assert_eq!(sparse_dot(&empty, &a), 0.0);
    }

    #[test]
    fn dot_with_self_is_squared_norm() {
        // 3^2 + 4^2 = 25 = 5^2, exact in f32.
        let v = SparseVector::new(vec![2, 8], vec![3.0, 4.0]).unwrap();
        assert_eq!(sparse_dot(&v, &v), l2_norm(&v) * l2_norm(&v));
    }

    #[test]
    fn l2_norm_basic() {
        let v = SparseVector::new(vec![1, 2], vec![3.0, 4.0]).unwrap();
        assert_eq!(l2_norm(&v), 5.0);
    }

    #[test]
    fn l2_norm_empty_is_zero() {
        let empty = SparseVector::new(vec![], vec![]).unwrap();
        assert_eq!(l2_norm(&empty), 0.0);
    }

    #[test]
    fn prune_keeps_largest_magnitude() {
        let v = SparseVector::new(vec![1, 2, 3, 4], vec![0.1, -5.0, 0.2, 3.0]).unwrap();
        let pruned = prune_top_k(&v, 2);
        // largest |w|: term 2 (5.0) and term 4 (3.0); re-sorted by index
        assert_eq!(pruned.indices(), &[2, 4]);
        assert_eq!(pruned.values(), &[-5.0, 3.0]);
    }

    #[test]
    fn prune_ties_keep_lower_term_id() {
        // All magnitudes equal: the documented tie-break keeps the lower ids.
        let v = SparseVector::new(vec![5, 6, 7], vec![1.0, -1.0, 1.0]).unwrap();
        let pruned = prune_top_k(&v, 2);
        assert_eq!(pruned.indices(), &[5, 6]);
        assert_eq!(pruned.values(), &[1.0, -1.0]);
    }

    #[test]
    fn prune_k_ge_len_is_identity() {
        let v = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();
        assert_eq!(prune_top_k(&v, 5), v);
        assert_eq!(prune_top_k(&v, 2), v);
    }

    #[test]
    fn prune_k_zero_is_empty() {
        let v = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();
        assert!(prune_top_k(&v, 0).is_empty());
    }

    #[test]
    fn prune_empty_input_is_empty() {
        let empty = SparseVector::new(vec![], vec![]).unwrap();
        assert!(prune_top_k(&empty, 0).is_empty());
        assert!(prune_top_k(&empty, 3).is_empty());
    }

    #[test]
    fn prune_preserves_dot_on_kept_terms() {
        // Pruning the query to its top terms can only drop contributions from
        // the removed terms; on a doc sharing only kept terms the score is exact.
        let q = SparseVector::new(vec![1, 2, 3], vec![10.0, 0.01, 9.0]).unwrap();
        let doc = SparseVector::new(vec![1, 3], vec![2.0, 2.0]).unwrap();
        let pruned = prune_top_k(&q, 2);
        assert_eq!(sparse_dot(&pruned, &doc), sparse_dot(&q, &doc));
    }
}