1use crate::sparse::SparseVector;
7
8pub fn sparse_dot(a: &SparseVector, b: &SparseVector) -> f32 {
14 let (ai, av) = (a.indices(), a.values());
15 let (bi, bv) = (b.indices(), b.values());
16 let mut i = 0;
17 let mut j = 0;
18 let mut acc = 0.0f32;
19 while i < ai.len() && j < bi.len() {
20 match ai[i].cmp(&bi[j]) {
21 std::cmp::Ordering::Less => i += 1,
22 std::cmp::Ordering::Greater => j += 1,
23 std::cmp::Ordering::Equal => {
24 acc += av[i] * bv[j];
25 i += 1;
26 j += 1;
27 }
28 }
29 }
30 acc
31}
32
33pub fn l2_norm(v: &SparseVector) -> f32 {
35 v.values().iter().map(|w| w * w).sum::<f32>().sqrt()
36}
37
38pub fn prune_top_k(v: &SparseVector, k: usize) -> SparseVector {
46 if k >= v.len() {
47 return v.clone();
48 }
49 if k == 0 {
50 return SparseVector::new(Vec::new(), Vec::new()).expect("empty vector is always valid");
51 }
52
53 let mut order: Vec<usize> = (0..v.len()).collect();
55 let values = v.values();
56 let indices = v.indices();
57 order.sort_by(|&x, &y| {
58 values[y]
59 .abs()
60 .partial_cmp(&values[x].abs())
61 .unwrap_or(std::cmp::Ordering::Equal)
62 .then(indices[x].cmp(&indices[y]))
63 });
64 order.truncate(k);
65 order.sort_unstable();
67
68 let kept_indices = order.iter().map(|&p| indices[p]).collect();
69 let kept_values = order.iter().map(|&p| values[p]).collect();
70 SparseVector::new(kept_indices, kept_values).expect("subset of a valid vector is valid")
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    // ---- sparse_dot -------------------------------------------------------

    // No index appears in both vectors, so no product is ever accumulated.
    #[test]
    fn dot_disjoint_is_zero() {
        let odd = SparseVector::new(vec![1, 3], vec![1.0, 1.0]).unwrap();
        let even = SparseVector::new(vec![2, 4], vec![1.0, 1.0]).unwrap();
        let got = sparse_dot(&odd, &even);
        assert_eq!(got, 0.0);
    }

    // Identical index sets: 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn dot_full_overlap() {
        let lhs = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let rhs = SparseVector::new(vec![1, 2, 3], vec![4.0, 5.0, 6.0]).unwrap();
        let got = sparse_dot(&lhs, &rhs);
        assert_eq!(got, 32.0);
    }

    // Only indices 5 and 9 are shared: 3*10 + 4*0.5 = 32.
    #[test]
    fn dot_partial_overlap() {
        let lhs = SparseVector::new(vec![1, 5, 9], vec![2.0, 3.0, 4.0]).unwrap();
        let rhs = SparseVector::new(vec![5, 9, 13], vec![10.0, 0.5, 1.0]).unwrap();
        let got = sparse_dot(&lhs, &rhs);
        assert_eq!(got, 32.0);
    }

    // Swapping the operands must not change the result.
    #[test]
    fn dot_is_commutative() {
        let lhs = SparseVector::new(vec![1, 4, 7], vec![1.5, -2.0, 3.0]).unwrap();
        let rhs = SparseVector::new(vec![4, 7, 8], vec![2.0, 1.0, 9.0]).unwrap();
        let forward = sparse_dot(&lhs, &rhs);
        let backward = sparse_dot(&rhs, &lhs);
        assert_eq!(forward, backward);
    }

    // An empty operand ends the merge walk immediately.
    #[test]
    fn dot_with_empty_is_zero() {
        let populated = SparseVector::new(vec![1, 2], vec![1.0, 1.0]).unwrap();
        let nothing = SparseVector::new(vec![], vec![]).unwrap();
        let got = sparse_dot(&populated, &nothing);
        assert_eq!(got, 0.0);
    }

    // ---- l2_norm ----------------------------------------------------------

    // Classic 3-4-5 triangle.
    #[test]
    fn l2_norm_basic() {
        let triangle = SparseVector::new(vec![1, 2], vec![3.0, 4.0]).unwrap();
        let got = l2_norm(&triangle);
        assert_eq!(got, 5.0);
    }

    // ---- prune_top_k ------------------------------------------------------

    // Ranking is by absolute value, so -5.0 outranks 3.0 and both outrank
    // the small positive entries.
    #[test]
    fn prune_keeps_largest_magnitude() {
        let original = SparseVector::new(vec![1, 2, 3, 4], vec![0.1, -5.0, 0.2, 3.0]).unwrap();
        let top_two = prune_top_k(&original, 2);
        assert_eq!(top_two.indices(), &[2, 4]);
        assert_eq!(top_two.values(), &[-5.0, 3.0]);
    }

    // Asking for at least as many entries as exist returns the input as-is.
    #[test]
    fn prune_k_ge_len_is_identity() {
        let original = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();
        let more_than_len = prune_top_k(&original, 5);
        assert_eq!(more_than_len, original);
        let exactly_len = prune_top_k(&original, 2);
        assert_eq!(exactly_len, original);
    }

    // k == 0 drops everything.
    #[test]
    fn prune_k_zero_is_empty() {
        let original = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();
        let none_kept = prune_top_k(&original, 0);
        assert!(none_kept.is_empty());
    }

    // The pruned-away entry (index 2) has no counterpart in the document,
    // so the dot product is the same before and after pruning.
    #[test]
    fn prune_preserves_dot_on_kept_terms() {
        let query = SparseVector::new(vec![1, 2, 3], vec![10.0, 0.01, 9.0]).unwrap();
        let document = SparseVector::new(vec![1, 3], vec![2.0, 2.0]).unwrap();
        let pruned_query = prune_top_k(&query, 2);
        let with_pruning = sparse_dot(&pruned_query, &document);
        let without_pruning = sparse_dot(&query, &document);
        assert_eq!(with_pruning, without_pruning);
    }
}