1use std::hash::{Hash, Hasher};
2
3use crate::{SIMD_LANECOUNT, veclike::VecLike};
4
/// Borrowed, non-owning view over a slice of `f32` components.
///
/// The `Hash`, `PartialEq` and `Eq` impls in this file operate on the raw bit
/// patterns of the components (`f32::to_bits`), so a `HashVec` can be hashed
/// and compared for total equality even though `f32` itself cannot.
#[derive(Debug, Clone)]
pub struct HashVec<'a> {
    /// The wrapped components. `From<&[f32]>` debug-asserts that the length is
    /// a multiple of `SIMD_LANECOUNT`; building the struct literally skips
    /// that check.
    internal: &'a [f32],
}
41
42impl<'a> From<&'a [f32]> for HashVec<'a> {
43 fn from(value: &'a [f32]) -> Self {
44 debug_assert!(
45 value.len() % SIMD_LANECOUNT == 0,
46 "You provided a vector that doesn't play nicely with SIMD"
47 );
48 HashVec { internal: value }
49 }
50}
51
52impl<'a> Hash for HashVec<'a> {
54 fn hash<H: Hasher>(&self, state: &mut H) {
55 for &val in self.internal {
56 state.write_u32(val.to_bits());
57 }
58 }
59}
60
61impl<'a> PartialEq for HashVec<'a> {
65 fn eq(&self, other: &Self) -> bool {
66 self.internal.len() == other.internal.len()
67 && self
68 .internal
69 .iter()
70 .zip(other.internal)
71 .all(|(a, b)| a.to_bits() == b.to_bits())
72 }
73}
74
75impl<'a> Eq for HashVec<'a> {}
77
/// `VecLike` for [`HashVec`], implemented by forwarding every operation to the
/// wrapped slice.
// NOTE(review): the slice-level `VecLike` impl that these calls resolve to is
// defined outside this file section; its numeric behaviour is not visible here.
impl<'a> VecLike for HashVec<'a> {
    /// Owned type returned by `normalized`.
    type Owned = Vec<f32>;

    /// Squared L2 distance between `self` and `other`, as computed by the
    /// wrapped slices' `l2_dist_squared`.
    #[inline]
    fn l2_dist_squared(&self, other: &Self) -> f32 {
        self.internal.l2_dist_squared(&other.internal)
    }

    /// Dot product of `self` and `other`, as computed by the wrapped slices'
    /// `dot`.
    #[inline]
    fn dot(&self, other: &Self) -> f32 {
        self.internal.dot(&other.internal)
    }

    /// Returns the result of `normalized` on the wrapped slice as an owned
    /// `Vec<f32>`; `self` is left untouched.
    #[inline]
    fn normalized(&self) -> Self::Owned {
        self.internal.normalized()
    }
}
123
124#[cfg(test)]
125mod tests {
126 use std::vec;
127
128 use super::*;
129 use quickcheck::{QuickCheck, TestResult};
130
    /// Absolute tolerance used by `close` when comparing `f32` results.
    const TOLERANCE: f32 = 1e-6;
132
133 fn close(actual: f32, target: f32) -> bool {
134 (target - actual).abs() < TOLERANCE
135 }
136
137 fn is_valid_l2(suspect: f32) -> bool {
138 suspect.is_finite() && suspect >= 0.0
139 }
140
141 fn l2_spec<'a>(v1: HashVec<'a>, v2: HashVec<'a>) -> f32 {
142 v1.internal
143 .iter()
144 .zip(v2.internal.iter())
145 .map(|(&x, &y)| {
146 let diff = x - y;
147 diff * diff
148 })
149 .sum()
150 }
151
152 #[test]
153 fn self_sim_is_zero() {
154 fn qc_self_sim_is_zero(totest: Vec<f32>) -> TestResult {
155 let usable_length = totest.len() / 8 * 8;
156 if totest[0..usable_length].iter().any(|x| !x.is_finite()) {
157 return TestResult::discard();
158 }
159 let testvec = HashVec {
160 internal: &totest[0..usable_length],
161 };
162 let selfsim = testvec.l2_dist_squared(&testvec).sqrt();
163 let to_check = is_valid_l2(selfsim) && close(selfsim, 0.0);
164 return TestResult::from_bool(to_check);
165 }
166
167 QuickCheck::new()
168 .tests(10_000)
169 .min_tests_passed(500)
172 .quickcheck(qc_self_sim_is_zero as fn(Vec<f32>) -> TestResult);
173 }
174
175 #[test]
176 fn squared_invariant() {
179 fn qc_squared_invariant(u: Vec<f32>, v: Vec<f32>, w: Vec<f32>, x: Vec<f32>) -> TestResult {
180 let all_vecs = [u, v, w, x]; let min_length = all_vecs.iter().map(|x| x.len()).min().unwrap() / 8 * 8;
182 let all_vectors: Vec<HashVec> = all_vecs
183 .iter()
184 .map(|vec| HashVec::from(&vec[..min_length]))
185 .collect();
186
187 let d1_squared = all_vectors[0].l2_dist_squared(&all_vectors[1]);
188 let d2_squared = all_vectors[2].l2_dist_squared(&all_vectors[3]);
189
190 let d1_root = all_vectors[0].l2_dist_squared(&all_vectors[1]).sqrt();
191 let d2_root = all_vectors[2].l2_dist_squared(&all_vectors[3]).sqrt();
192
193 let sanity_check1 = (d1_squared < d2_squared) == (d1_root < d2_root);
194 let sanity_check2 = (d1_squared <= d2_squared) == (d1_root <= d2_root);
195 TestResult::from_bool(sanity_check1 && sanity_check2)
196 }
197
198 QuickCheck::new()
199 .tests(10_000)
200 .min_tests_passed(500)
201 .quickcheck(
202 qc_squared_invariant as fn(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) -> TestResult,
203 );
204 }
205
206 #[test]
207 fn simd_matches_spec() {
208 fn qc_simd_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
209 let min_length = u.len().min(v.len()) / 8 * 8;
210 let (u_f32v, v_f32v) = (
211 HashVec::from(&u[0..min_length]),
212 HashVec::from(&v[0..min_length]),
213 );
214 let simd = u_f32v.l2_dist_squared(&v_f32v);
215 let spec = l2_spec(u_f32v, v_f32v);
216
217 if simd.is_infinite() {
218 TestResult::from_bool(spec.is_infinite())
219 } else if simd.is_nan() {
220 TestResult::from_bool(spec.is_nan())
221 } else {
222 TestResult::from_bool(close(simd, spec))
223 }
224 }
225
226 QuickCheck::new()
227 .tests(10_000)
228 .min_tests_passed(500)
229 .quickcheck(qc_simd_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
230 }
231
232 #[test]
233 fn normalization_gives_unit_l2_norm() {
234 fn qc_normalized(vec: Vec<f32>) -> TestResult {
235 if vec.len() < 8 {
236 return TestResult::discard();
237 }
238 let usable = vec.len() / 8 * 8;
239 let vec: Vec<f32> = vec[..usable]
240 .iter()
241 .cloned()
242 .map(|x| x.clamp(-1e6, 1e6))
243 .collect();
244
245 if vec.iter().any(|x| !x.is_finite()) {
246 return TestResult::discard();
247 }
248
249 let hv = HashVec::from(vec.as_slice());
250 let norm = hv.normalized();
251 let normhv = HashVec::from(norm.as_slice());
252 let self_dot = normhv.dot(&normhv);
253
254 if vec.iter().all(|&x| x == 0.0) {
255 TestResult::from_bool(close(self_dot, 0.0))
256 } else {
257 TestResult::from_bool(close(self_dot, 1.0))
258 }
259 }
260
261 QuickCheck::new()
262 .tests(10_000)
263 .min_tests_passed(500)
264 .quickcheck(qc_normalized as fn(Vec<f32>) -> TestResult);
265
266 assert!(!qc_normalized(vec![0.0; 8]).is_failure());
267 }
268
269 #[test]
270 fn dot_product_matches_spec() {
271 fn qc_dot_matches_spec(u: Vec<f32>, v: Vec<f32>) -> TestResult {
272 let usable = u.len().min(v.len()) / 8 * 8;
273 if usable == 0 {
274 return TestResult::discard();
275 }
276
277 let u: Vec<f32> = u[..usable].iter().map(|x| x.clamp(-1e3, 1e3)).collect();
278 let v: Vec<f32> = v[..usable].iter().map(|x| x.clamp(-1e3, 1e3)).collect();
279
280 if u.iter().any(|x| !x.is_finite()) || v.iter().any(|x| !x.is_finite()) {
281 return TestResult::discard();
282 }
283
284 let uv = HashVec::from(u.as_slice());
285 let vv = HashVec::from(v.as_slice());
286
287 let spec_dot: f32 = u.iter().zip(&v).map(|(&a, &b)| a * b).sum::<f32>().abs();
288 let impl_dot = uv.dot(&vv).abs();
289
290 TestResult::from_bool(0.99 * spec_dot <= impl_dot && impl_dot <= 1.01 * spec_dot)
291 }
292
293 QuickCheck::new()
294 .tests(10_000)
295 .min_tests_passed(500)
296 .quickcheck(qc_dot_matches_spec as fn(Vec<f32>, Vec<f32>) -> TestResult);
297 }
298
299 #[test]
300 fn hash_consistent_for_equal_inputs() {
301 use std::collections::hash_map::DefaultHasher;
302
303 fn hash_of(v: &[f32]) -> u64 {
304 let mut hasher = DefaultHasher::new();
305 HashVec::from(v).hash(&mut hasher);
306 hasher.finish()
307 }
308
309 fn qc_equal_vecs_hash_same(v: Vec<f32>) -> TestResult {
310 let usable = v.len() / 8 * 8;
311 let v = &v[..usable];
312
313 if v.iter().any(|x| !x.is_finite()) {
314 return TestResult::discard();
315 }
316
317 let h1 = hash_of(v);
318 let h2 = hash_of(v);
319 TestResult::from_bool(h1 == h2)
320 }
321
322 QuickCheck::new()
323 .tests(10_000)
324 .min_tests_passed(500)
325 .quickcheck(qc_equal_vecs_hash_same as fn(Vec<f32>) -> TestResult);
326 }
327
328 #[test]
329 fn different_vectors_likely_hash_differently() {
330 use std::collections::hash_map::DefaultHasher;
331
332 fn hash_of(v: &[f32]) -> u64 {
333 let mut hasher = DefaultHasher::new();
334 HashVec::from(v).hash(&mut hasher);
335 hasher.finish()
336 }
337
338 let a = vec![1.0_f32; 8];
339 let mut b = vec![1.0_f32; 8];
340 b[0] = 2.0;
341 let ha = hash_of(&a);
342 let hb = hash_of(&b);
343 assert_ne!(ha, hb);
344 }
345
346 #[test]
347 fn equality_works_as_expected() {
348 fn qc_eq_correctness(v: Vec<f32>) -> TestResult {
349 let usable = v.len() / 8 * 8;
350 if usable == 0 || v[..usable].iter().any(|x| !x.is_finite()) {
351 return TestResult::discard();
352 }
353
354 let slice = &v[..usable];
355 let hv1 = HashVec::from(slice);
356 let hv2 = HashVec::from(slice);
357
358 let reflexivity = hv1 == hv1;
360
361 let symmetry = hv1 == hv2 && hv2 == hv1;
363
364 let mut modified = slice.to_vec();
366 modified[0] = f32::from_bits(modified[0].to_bits().wrapping_add(1)); let hv3 = HashVec::from(modified.as_slice());
368
369 let unequal = hv1 != hv3;
370 TestResult::from_bool(unequal && reflexivity && symmetry)
371 }
372
373 QuickCheck::new()
374 .tests(10_000)
375 .min_tests_passed(500)
376 .quickcheck(qc_eq_correctness as fn(Vec<f32>) -> TestResult);
377 }
378}