1use crate::sparse::SparseVector;
7use std::cmp::Ordering;
8
9pub fn sparse_dot(a: &SparseVector, b: &SparseVector) -> f32 {
15 let (ai, av) = (a.indices(), a.values());
16 let (bi, bv) = (b.indices(), b.values());
17 let mut i = 0;
18 let mut j = 0;
19 let mut acc = 0.0f32;
20 while i < ai.len() && j < bi.len() {
21 match ai[i].cmp(&bi[j]) {
22 Ordering::Less => i += 1,
23 Ordering::Greater => j += 1,
24 Ordering::Equal => {
25 acc += av[i] * bv[j];
26 i += 1;
27 j += 1;
28 }
29 }
30 }
31 acc
32}
33
34pub fn l2_norm(v: &SparseVector) -> f32 {
36 v.values().iter().map(|w| w * w).sum::<f32>().sqrt()
37}
38
39pub fn prune_top_k(v: &SparseVector, k: usize) -> SparseVector {
47 if k >= v.len() {
48 return v.clone();
49 }
50 if k == 0 {
51 return SparseVector::new(Vec::new(), Vec::new()).expect("empty vector is always valid");
52 }
53
54 let mut order: Vec<usize> = (0..v.len()).collect();
56 let values = v.values();
57 let indices = v.indices();
58 order.sort_by(|&x, &y| {
59 values[y]
60 .abs()
61 .partial_cmp(&values[x].abs())
62 .unwrap_or(Ordering::Equal)
63 .then(indices[x].cmp(&indices[y]))
64 });
65 order.truncate(k);
66 order.sort_unstable();
68
69 let kept_indices = order.iter().map(|&p| indices[p]).collect();
70 let kept_values = order.iter().map(|&p| values[p]).collect();
71 SparseVector::new(kept_indices, kept_values).expect("subset of a valid vector is valid")
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Vectors that share no index have a dot product of zero.
    #[test]
    fn dot_disjoint_is_zero() {
        let lhs = SparseVector::new(vec![1, 3], vec![1.0, 1.0]).unwrap();
        let rhs = SparseVector::new(vec![2, 4], vec![1.0, 1.0]).unwrap();

        let product = sparse_dot(&lhs, &rhs);

        assert_eq!(product, 0.0);
    }

    /// Every index matches, so every pair of values contributes.
    #[test]
    fn dot_full_overlap() {
        let lhs = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let rhs = SparseVector::new(vec![1, 2, 3], vec![4.0, 5.0, 6.0]).unwrap();

        let product = sparse_dot(&lhs, &rhs);

        // 1*4 + 2*5 + 3*6
        assert_eq!(product, 32.0);
    }

    /// Only the shared indices (5 and 9) contribute.
    #[test]
    fn dot_partial_overlap() {
        let lhs = SparseVector::new(vec![1, 5, 9], vec![2.0, 3.0, 4.0]).unwrap();
        let rhs = SparseVector::new(vec![5, 9, 13], vec![10.0, 0.5, 1.0]).unwrap();

        let product = sparse_dot(&lhs, &rhs);

        // 3*10 + 4*0.5
        assert_eq!(product, 32.0);
    }

    /// Swapping the operands must not change the result.
    #[test]
    fn dot_is_commutative() {
        let lhs = SparseVector::new(vec![1, 4, 7], vec![1.5, -2.0, 3.0]).unwrap();
        let rhs = SparseVector::new(vec![4, 7, 8], vec![2.0, 1.0, 9.0]).unwrap();

        let forward = sparse_dot(&lhs, &rhs);
        let backward = sparse_dot(&rhs, &lhs);

        assert_eq!(forward, backward);
    }

    /// An operand with no stored entries yields zero.
    #[test]
    fn dot_with_empty_is_zero() {
        let populated = SparseVector::new(vec![1, 2], vec![1.0, 1.0]).unwrap();
        let nothing = SparseVector::new(vec![], vec![]).unwrap();

        let product = sparse_dot(&populated, &nothing);

        assert_eq!(product, 0.0);
    }

    /// The classic 3-4-5 triangle: sqrt(9 + 16) = 5.
    #[test]
    fn l2_norm_basic() {
        let three_four = SparseVector::new(vec![1, 2], vec![3.0, 4.0]).unwrap();

        let norm = l2_norm(&three_four);

        assert_eq!(norm, 5.0);
    }

    /// Ranking is by absolute value, so -5.0 outranks 3.0 and both survive.
    #[test]
    fn prune_keeps_largest_magnitude() {
        let original = SparseVector::new(vec![1, 2, 3, 4], vec![0.1, -5.0, 0.2, 3.0]).unwrap();

        let top_two = prune_top_k(&original, 2);

        // Survivors keep their original order and their original signs.
        assert_eq!(top_two.indices(), &[2, 4]);
        assert_eq!(top_two.values(), &[-5.0, 3.0]);
    }

    /// Asking for at least as many entries as exist returns the input unchanged.
    #[test]
    fn prune_k_ge_len_is_identity() {
        let original = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();

        assert_eq!(prune_top_k(&original, 5), original);
        assert_eq!(prune_top_k(&original, 2), original);
    }

    /// Keeping zero entries produces an empty vector.
    #[test]
    fn prune_k_zero_is_empty() {
        let original = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();

        let pruned = prune_top_k(&original, 0);

        assert!(pruned.is_empty());
    }

    /// Pruning away an entry the document does not touch leaves the score intact.
    #[test]
    fn prune_preserves_dot_on_kept_terms() {
        // Index 2 has the smallest magnitude, so it is the entry that gets dropped.
        let query = SparseVector::new(vec![1, 2, 3], vec![10.0, 0.01, 9.0]).unwrap();
        // The document only has indices 1 and 3, both of which survive pruning.
        let document = SparseVector::new(vec![1, 3], vec![2.0, 2.0]).unwrap();

        let pruned_query = prune_top_k(&query, 2);

        assert_eq!(
            sparse_dot(&pruned_query, &document),
            sparse_dot(&query, &document)
        );
    }
}