vecstasy/
hashvec.rs

1use std::hash::{Hash, Hasher};
2
3use crate::{SIMD_LANECOUNT, veclike::VecLike};
4
/// A wrapper type for `&[f32]` slices that provides
/// stable hashing and equality based on the IEEE‑754 bit patterns.
///
/// `f32` implements neither `Eq` nor `Hash`, so a bare `&[f32]` cannot be used
/// as a key in a `HashMap` or `HashSet`. `HashVec` closes that gap by comparing
/// floats by their raw bit representation and writing each element’s
/// `to_bits()` into the hasher. As a consequence:
///
/// - `0.0` and `-0.0` are distinguished (their bit patterns differ), and
/// - `NaN` values only compare equal if their bit patterns match.
///
/// `HashVec` only borrows its data; converting a slice with `From` copies nothing.
///
/// Implements `VecLike` by delegating to the slice implementation,
/// allowing distance, dot‑product, and normalization operations.
///
/// # Example
///
/// ```
/// use std::collections::hash_map::DefaultHasher;
/// use std::hash::{Hash, Hasher};
/// use vecstasy::HashVec;
/// use vecstasy::VecLike;
///
/// let data: &[f32] = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // len needs to be a multiple of 8
/// let hv = HashVec::from(data);
///
/// // Hash computation
/// let mut hasher = DefaultHasher::new();
/// hv.hash(&mut hasher);
/// let _hash_value = hasher.finish();
///
/// // Equality is bitwise: 0.0 and -0.0 produce different vectors.
/// let zeros: &[f32] = &[0.0; 8];
/// let neg_zeros: &[f32] = &[-0.0; 8];
/// assert_ne!(HashVec::from(zeros), HashVec::from(neg_zeros));
///
/// // VecLike operations
/// let normed: Vec<f32> = hv.normalized();
/// ```
#[derive(Debug, Clone)]
pub struct HashVec<'a> {
    internal: &'a [f32],
}
41
42impl<'a> From<&'a [f32]> for HashVec<'a> {
43    fn from(value: &'a [f32]) -> Self {
44        debug_assert!(
45            value.len() % SIMD_LANECOUNT == 0,
46            "You provided a vector that doesn't play nicely with SIMD"
47        );
48        HashVec { internal: value }
49    }
50}
51
52// hashing is done bit-wise, blocks of 32 bits at a time
53impl<'a> Hash for HashVec<'a> {
54    fn hash<H: Hasher>(&self, state: &mut H) {
55        for &val in self.internal {
56            state.write_u32(val.to_bits());
57        }
58    }
59}
60
61// equality of two vectors is defined bitwise.
62// this means that vectors like [-0.0] and [0.0] are NOT considered equal
63// for our purposes, even though the individual elements are.
64impl<'a> PartialEq for HashVec<'a> {
65    fn eq(&self, other: &Self) -> bool {
66        self.internal.len() == other.internal.len()
67            && self
68                .internal
69                .iter()
70                .zip(other.internal)
71                .all(|(a, b)| a.to_bits() == b.to_bits())
72    }
73}
74
75// vector equality is also reflexive (for all a, a == a), we can tell that to the compiler:
76impl<'a> Eq for HashVec<'a> {}
77
78// HashVec provides linear algebra ops, we just have to delegate to the internal slice.
79// the inline tags enable cross-crate inlining, which is a big deal here.
80impl<'a> VecLike for HashVec<'a> {
81    type Owned = Vec<f32>;
82
83    /// Computes the squared L2 (Euclidean) distance between `self` and `othr`.
84    ///
85    /// Operates on fixed‐size chunks; any trailing elements when the slice length
86    /// is not a multiple of the chunk size will be silently ignored in release mode.
87    ///
88    /// # Panics
89    /// - In debug mode, if `self.len() != othr.len()`.
90    /// - In debug mode, if the slice length is not a multiple of the internal chunk size.
91    #[inline]
92    fn l2_dist_squared(&self, other: &Self) -> f32 {
93        self.internal.l2_dist_squared(&other.internal)
94    }
95
96    /// Computes the dot product of `self` and `othr`.
97    ///
98    /// Operates on fixed‐size chunks; any trailing elements when the slice length
99    /// is not a multiple of the chunk size will be silently ignored in release mode.
100    ///
101    /// # Panics
102    /// - In debug mode, if `self.len() != othr.len()`.
103    /// - In debug mode, if the slice length is not a multiple of the internal chunk size.
104    #[inline]
105    fn dot(&self, other: &Self) -> f32 {
106        self.internal.dot(&other.internal)
107    }
108
109    /// Returns a normalized copy of the input slice.
110    ///
111    /// Operates on fixed‐size chunks; any trailing elements when the slice length
112    /// is not a multiple of the chunk size will be silently ignored in release mode.
113    ///
114    /// If the input norm is zero, returns a zero vector of the same length.
115    ///
116    /// # Panics
117    /// - In debug mode, if the slice length is not a multiple of the internal chunk size.
118    #[inline]
119    fn normalized(&self) -> Self::Owned {
120        self.internal.normalized()
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;

    use super::*;
    use quickcheck::{QuickCheck, TestResult};

    /// Absolute tolerance for results whose exact value is known (0 or 1).
    const TOLERANCE: f32 = 1e-6;

    /// Relative tolerance for comparing two sums of the same terms taken in a
    /// different order (SIMD lanes vs. a sequential fold). Rounding error in
    /// such sums scales with the magnitude of the result, so an absolute
    /// tolerance is only meaningful near zero.
    const REL_TOLERANCE: f32 = 1e-5;

    /// `actual` lies strictly within `TOLERANCE` of `target`.
    fn close(actual: f32, target: f32) -> bool {
        (target - actual).abs() < TOLERANCE
    }

    /// Both values are finite and differ by at most `REL_TOLERANCE` relative to
    /// `target`, with an absolute floor of `REL_TOLERANCE` when |target| < 1.
    fn close_rel(actual: f32, target: f32) -> bool {
        actual.is_finite()
            && target.is_finite()
            && (target - actual).abs() <= REL_TOLERANCE * target.abs().max(1.0)
    }

    fn is_valid_l2(suspect: f32) -> bool {
        suspect.is_finite() && suspect >= 0.0
    }

    /// Truncates `len` down to a multiple of 8, the chunk size these tests use.
    fn usable_len(len: usize) -> usize {
        len / 8 * 8
    }

    /// Reference squared L2 distance: a plain, sequential scalar sum.
    fn l2_spec(v1: &HashVec<'_>, v2: &HashVec<'_>) -> f32 {
        v1.internal
            .iter()
            .zip(v2.internal)
            .map(|(&x, &y)| {
                let diff = x - y;
                diff * diff
            })
            .sum()
    }

    /// Hashes `v` as a `HashVec` with a freshly created `DefaultHasher`.
    fn hash_of(v: &[f32]) -> u64 {
        let mut hasher = DefaultHasher::new();
        HashVec::from(v).hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn self_sim_is_zero() {
        fn qc_self_sim_is_zero(totest: Vec<f32>) -> TestResult {
            let usable = &totest[..usable_len(totest.len())];
            if usable.iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }
            let testvec = HashVec::from(usable);
            let selfsim = testvec.l2_dist_squared(&testvec).sqrt();
            TestResult::from_bool(is_valid_l2(selfsim) && close(selfsim, 0.0))
        }

        QuickCheck::new()
            .tests(10_000)
            // Require at least 500 non-discarded cases, i.e. about 5% of the
            // default 10_000-case budget must satisfy the precondition, so an
            // over-eager discard filter cannot make the property pass vacuously.
            .min_tests_passed(500)
            .quickcheck(qc_self_sim_is_zero as fn(Vec<f32>) -> TestResult);
    }

    #[test]
    // verifies the claim in the documentation of l2_dist_squared
    // i.e. dist(u,v) < dist(w, x) ⇔ dist(u,v) ** 2 < dist(w,x) ** 2
    fn squared_invariant() {
        fn qc_squared_invariant(u: Vec<f32>, v: Vec<f32>, w: Vec<f32>, x: Vec<f32>) -> TestResult {
            // No NaN filtering needed: a NaN squared distance stays NaN after
            // sqrt, so both sides of each comparison are false together.
            let all_vecs = [u, v, w, x];
            let min_length = usable_len(all_vecs.iter().map(Vec::len).min().unwrap_or(0));
            let hvs: Vec<HashVec<'_>> = all_vecs
                .iter()
                .map(|data| HashVec::from(&data[..min_length]))
                .collect();

            let d1_squared = hvs[0].l2_dist_squared(&hvs[1]);
            let d2_squared = hvs[2].l2_dist_squared(&hvs[3]);
            let (d1_root, d2_root) = (d1_squared.sqrt(), d2_squared.sqrt());

            let sanity_check1 = (d1_squared < d2_squared) == (d1_root < d2_root);
            let sanity_check2 = (d1_squared <= d2_squared) == (d1_root <= d2_root);
            TestResult::from_bool(sanity_check1 && sanity_check2)
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(500)
            .quickcheck(
                qc_squared_invariant as fn(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) -> TestResult,
            );
    }

    #[test]
    fn simd_matches_spec() {
        fn qc_simd_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
            let min_length = usable_len(u.len().min(v.len()));
            let u_hv = HashVec::from(&u[..min_length]);
            let v_hv = HashVec::from(&v[..min_length]);
            let simd = u_hv.l2_dist_squared(&v_hv);
            let spec = l2_spec(&u_hv, &v_hv);

            // The SIMD kernel adds the same squared differences in a different
            // order than the sequential spec, so finite results are compared
            // with a tolerance proportional to their magnitude.
            let agrees = if simd.is_infinite() {
                spec.is_infinite()
            } else if simd.is_nan() {
                spec.is_nan()
            } else {
                close_rel(simd, spec)
            };
            TestResult::from_bool(agrees)
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(500)
            .quickcheck(qc_simd_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
    }

    #[test]
    fn normalization_gives_unit_l2_norm() {
        fn qc_normalized(input: Vec<f32>) -> TestResult {
            if input.len() < 8 {
                return TestResult::discard();
            }
            // Clamping keeps squares well inside f32 range; NaN survives
            // clamp, so it is filtered out afterwards.
            let clamped: Vec<f32> = input[..usable_len(input.len())]
                .iter()
                .map(|x| x.clamp(-1e6, 1e6))
                .collect();

            if clamped.iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }

            let hv = HashVec::from(clamped.as_slice());
            let norm = hv.normalized();
            let normhv = HashVec::from(norm.as_slice());
            let self_dot = normhv.dot(&normhv);

            if clamped.iter().all(|&x| x == 0.0) {
                TestResult::from_bool(close(self_dot, 0.0))
            } else {
                TestResult::from_bool(close(self_dot, 1.0))
            }
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(500)
            .quickcheck(qc_normalized as fn(Vec<f32>) -> TestResult);

        // The all-zero vector must map to a zero vector, not NaNs.
        assert!(!qc_normalized(vec![0.0; 8]).is_failure());
    }

    #[test]
    fn dot_product_matches_spec() {
        fn qc_dot_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
            let usable = usable_len(u.len().min(v.len()));
            if usable == 0 {
                return TestResult::discard();
            }

            let u: Vec<f32> = u[..usable].iter().map(|x| x.clamp(-1e3, 1e3)).collect();
            let v: Vec<f32> = v[..usable].iter().map(|x| x.clamp(-1e3, 1e3)).collect();

            if u.iter().any(|x| !x.is_finite()) || v.iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }

            let uv = HashVec::from(u.as_slice());
            let vv = HashVec::from(v.as_slice());

            let spec_dot: f32 = u.iter().zip(&v).map(|(&a, &b)| a * b).sum::<f32>().abs();
            let impl_dot = uv.dot(&vv).abs();

            TestResult::from_bool(0.99 * spec_dot <= impl_dot && impl_dot <= 1.01 * spec_dot)
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(500)
            .quickcheck(qc_dot_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
    }

    #[test]
    fn hash_consistent_for_equal_inputs() {
        fn qc_equal_vecs_hash_same(v: Vec<f32>) -> TestResult {
            let v = &v[..usable_len(v.len())];

            if v.iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }

            TestResult::from_bool(hash_of(v) == hash_of(v))
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(500)
            .quickcheck(qc_equal_vecs_hash_same as fn(Vec<f32>) -> TestResult);
    }

    #[test]
    fn different_vectors_likely_hash_differently() {
        let a = vec![1.0_f32; 8];
        let mut b = vec![1.0_f32; 8];
        b[0] = 2.0;
        assert_ne!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn equality_is_bitwise() {
        // 0.0 == -0.0 as floats, but their bit patterns differ.
        let pos = [0.0_f32; 8];
        let neg = [-0.0_f32; 8];
        assert_ne!(HashVec::from(&pos[..]), HashVec::from(&neg[..]));

        // NaN != NaN as floats, but identical NaN bit patterns compare equal.
        let nans_a = [f32::NAN; 8];
        let nans_b = nans_a;
        assert_eq!(HashVec::from(&nans_a[..]), HashVec::from(&nans_b[..]));
    }

    #[test]
    fn equality_works_as_expected() {
        fn qc_eq_correctness(v: Vec<f32>) -> TestResult {
            let usable = usable_len(v.len());
            if usable == 0 || v[..usable].iter().any(|x| !x.is_finite()) {
                return TestResult::discard();
            }

            let slice = &v[..usable];
            let hv1 = HashVec::from(slice);
            let hv2 = HashVec::from(slice);

            // Reflexivity: comparing a value with itself is the point here.
            #[allow(clippy::eq_op)]
            let reflexivity = hv1 == hv1;

            // Symmetry
            let symmetry = hv1 == hv2 && hv2 == hv1;

            // Inequality after flipping the lowest bit of one element.
            let mut modified = slice.to_vec();
            modified[0] = f32::from_bits(modified[0].to_bits().wrapping_add(1));
            let hv3 = HashVec::from(modified.as_slice());

            let unequal = hv1 != hv3;
            TestResult::from_bool(unequal && reflexivity && symmetry)
        }

        QuickCheck::new()
            .tests(10_000)
            .min_tests_passed(500)
            .quickcheck(qc_eq_correctness as fn(Vec<f32>) -> TestResult);
    }
}