velesdb_core/index/hnsw/native/
distance.rs1use crate::distance::DistanceMetric;
9
10pub trait DistanceEngine: Send + Sync {
15 fn distance(&self, a: &[f32], b: &[f32]) -> f32;
17
18 fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
23 candidates.iter().map(|c| self.distance(query, c)).collect()
24 }
25
26 fn metric(&self) -> DistanceMetric;
28}
29
30pub struct CpuDistance {
32 metric: DistanceMetric,
33}
34
35impl CpuDistance {
36 #[must_use]
38 pub fn new(metric: DistanceMetric) -> Self {
39 Self { metric }
40 }
41}
42
43impl DistanceEngine for CpuDistance {
44 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
45 match self.metric {
46 DistanceMetric::Cosine => cosine_distance_scalar(a, b),
47 DistanceMetric::Euclidean => euclidean_distance_scalar(a, b),
48 DistanceMetric::DotProduct => dot_product_scalar(a, b),
49 DistanceMetric::Hamming => hamming_distance_scalar(a, b),
50 DistanceMetric::Jaccard => jaccard_distance_scalar(a, b),
51 }
52 }
53
54 fn metric(&self) -> DistanceMetric {
55 self.metric
56 }
57}
58
59pub struct SimdDistance {
63 metric: DistanceMetric,
64}
65
66impl SimdDistance {
67 #[must_use]
69 pub fn new(metric: DistanceMetric) -> Self {
70 Self { metric }
71 }
72}
73
74impl DistanceEngine for SimdDistance {
75 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
76 match self.metric {
78 DistanceMetric::Cosine => 1.0 - crate::simd::cosine_similarity_fast(a, b),
79 DistanceMetric::Euclidean => crate::simd::euclidean_distance_fast(a, b),
80 DistanceMetric::DotProduct => -crate::simd::dot_product_fast(a, b), DistanceMetric::Hamming => crate::simd::hamming_distance_fast(a, b),
83 DistanceMetric::Jaccard => 1.0 - crate::simd::jaccard_similarity_fast(a, b),
84 }
85 }
86
87 fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
88 let prefetch_distance = crate::simd::calculate_prefetch_distance(query.len());
91 let mut results = Vec::with_capacity(candidates.len());
92
93 for (i, candidate) in candidates.iter().enumerate() {
94 if i + prefetch_distance < candidates.len() {
96 crate::simd::prefetch_vector(candidates[i + prefetch_distance]);
97 }
98 results.push(self.distance(query, candidate));
99 }
100
101 results
102 }
103
104 fn metric(&self) -> DistanceMetric {
105 self.metric
106 }
107}
108
109pub struct NativeSimdDistance {
114 metric: DistanceMetric,
115}
116
117impl NativeSimdDistance {
118 #[must_use]
120 pub fn new(metric: DistanceMetric) -> Self {
121 Self { metric }
122 }
123}
124
125impl DistanceEngine for NativeSimdDistance {
126 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
127 match self.metric {
128 DistanceMetric::Cosine => 1.0 - crate::simd_native::cosine_similarity_native(a, b),
129 DistanceMetric::Euclidean => crate::simd_native::euclidean_native(a, b),
130 DistanceMetric::DotProduct => -crate::simd_native::dot_product_native(a, b),
131 DistanceMetric::Hamming => crate::simd::hamming_distance_fast(a, b),
133 DistanceMetric::Jaccard => 1.0 - crate::simd::jaccard_similarity_fast(a, b),
134 }
135 }
136
137 fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
138 match self.metric {
139 DistanceMetric::DotProduct => {
140 crate::simd_native::batch_dot_product_native(candidates, query)
142 .into_iter()
143 .map(|d| -d)
144 .collect()
145 }
146 _ => candidates.iter().map(|c| self.distance(query, c)).collect(),
147 }
148 }
149
150 fn metric(&self) -> DistanceMetric {
151 self.metric
152 }
153}
154
/// Scalar cosine distance, `1 - (a·b) / (‖a‖·‖b‖)`.
///
/// Elements are paired positionally. If the slices differ in length, the longer
/// one's tail is ignored. Returns `1.0` (the distance of orthogonal vectors) when
/// either input has zero magnitude, because the angle is undefined there.
///
/// The two norms are square-rooted *separately* and then multiplied. The previous
/// form, `sqrt(‖a‖² · ‖b‖²)`, overflowed to `inf` for components around 1e10 and
/// underflowed to `0` for components around 1e-12 in `f32`. That made identical
/// vectors report a distance of `1.0`, or `NaN` when a zero vector met an
/// overflowing one.
#[inline]
fn cosine_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f32;
    let mut norm_a_sq = 0.0_f32;
    let mut norm_b_sq = 0.0_f32;

    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    // Checking the squared norms first keeps `0 * sqrt(inf)` (= NaN) out of the
    // division. `denom == 0.0` additionally catches the product underflowing.
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 || denom == 0.0 {
        1.0
    } else {
        1.0 - dot / denom
    }
}
178
/// Straight-line (L2) distance between `a` and `b`.
///
/// Elements are paired positionally. If the slices differ in length, the longer
/// one's tail is ignored.
#[inline]
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
187
/// Negated dot product of `a` and `b`.
///
/// The sign is flipped so that a larger similarity yields a smaller distance.
/// Elements are paired positionally; extra trailing elements are ignored.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let similarity: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    -similarity
}
193
/// Number of positions whose `f32` bit patterns differ, returned as `f32`.
///
/// The comparison is bitwise, so values such as `0.0` and `-0.0` count as different.
/// Elements are paired positionally; extra trailing elements are ignored.
#[inline]
fn hamming_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mismatches = a
        .iter()
        .zip(b)
        .filter(|&(x, y)| x.to_bits() != y.to_bits())
        .count();
    mismatches as f32
}
201
/// Weighted Jaccard distance: `1 - Σ min(x, y) / Σ max(x, y)`.
///
/// Returns `1.0` when the summed maxima are exactly zero (e.g. empty or all-zero
/// inputs). Elements are paired positionally; extra trailing elements are ignored.
#[inline]
fn jaccard_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (overlap, combined) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32), |(lo, hi), (&x, &y)| {
            (lo + x.min(y), hi + x.max(y))
        });

    if combined == 0.0 {
        return 1.0;
    }
    1.0 - overlap / combined
}
218
#[cfg(test)]
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::*;

    /// `len` samples of `wave(i * step)` for `i` in `0..len`.
    fn sampled(len: usize, step: f32, wave: fn(f32) -> f32) -> Vec<f32> {
        (0..len).map(|i| wave(i as f32 * step)).collect()
    }

    /// A 0/1 indicator vector of length `len`, set to `1.0` wherever `on(i)` holds.
    fn indicator(len: usize, on: impl Fn(usize) -> bool) -> Vec<f32> {
        (0..len).map(|i| if on(i) { 1.0 } else { 0.0 }).collect()
    }

    /// Borrows owned vectors in the `&[&[f32]]` shape that `batch_distance` takes.
    fn as_slices(owned: &[Vec<f32>]) -> Vec<&[f32]> {
        owned.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_cosine_identical_vectors() {
        let v = [1.0, 2.0, 3.0];
        let dist = CpuDistance::new(DistanceMetric::Cosine).distance(&v, &v);
        assert!(
            dist.abs() < 1e-5,
            "Identical vectors should have distance ~0"
        );
    }

    #[test]
    fn test_euclidean_known_distance() {
        let origin = [0.0, 0.0, 0.0];
        let point = [3.0, 4.0, 0.0];
        let dist = CpuDistance::new(DistanceMetric::Euclidean).distance(&origin, &point);
        assert!((dist - 5.0).abs() < 1e-5, "3-4-5 triangle");
    }

    #[test]
    fn test_simd_matches_scalar() {
        let a = sampled(768, 0.01, f32::sin);
        let b = sampled(768, 0.02, f32::cos);

        let cpu_dist = CpuDistance::new(DistanceMetric::Cosine).distance(&a, &b);
        let simd_dist = SimdDistance::new(DistanceMetric::Cosine).distance(&a, &b);

        assert!(
            (cpu_dist - simd_dist).abs() < 1e-4,
            "SIMD should match scalar: cpu={cpu_dist}, simd={simd_dist}"
        );
    }

    #[test]
    fn test_simd_hamming_uses_simd_implementation() {
        let a = indicator(64, |i| i % 2 == 0);
        let b = indicator(64, |i| i % 3 == 0);

        let dist = SimdDistance::new(DistanceMetric::Hamming).distance(&a, &b);

        assert!(dist >= 0.0, "Hamming distance must be non-negative");
        assert!(dist <= 64.0, "Hamming distance cannot exceed vector length");
    }

    #[test]
    fn test_simd_jaccard_uses_simd_implementation() {
        // `a` is a strict prefix-subset of `b`: |a ∩ b| = 32, |a ∪ b| = 48.
        let a = indicator(64, |i| i < 32);
        let b = indicator(64, |i| i < 48);

        let dist = SimdDistance::new(DistanceMetric::Jaccard).distance(&a, &b);

        assert!(
            (0.0..=1.0).contains(&dist),
            "Jaccard distance must be in [0,1]"
        );

        let expected = 1.0 - (32.0 / 48.0);
        assert!(
            (dist - expected).abs() < 1e-4,
            "Jaccard distance: expected {expected}, got {dist}"
        );
    }

    #[test]
    fn test_simd_hamming_identical_vectors() {
        let v = indicator(32, |i| i % 2 == 0);
        let dist = SimdDistance::new(DistanceMetric::Hamming).distance(&v, &v);
        assert!(
            dist.abs() < 1e-5,
            "Identical vectors should have distance 0"
        );
    }

    #[test]
    fn test_simd_jaccard_identical_vectors() {
        let v = indicator(32, |i| i % 2 == 0);
        let dist = SimdDistance::new(DistanceMetric::Jaccard).distance(&v, &v);
        assert!(
            dist.abs() < 1e-5,
            "Identical vectors should have distance 0"
        );
    }

    #[test]
    fn test_batch_distance_with_prefetch() {
        let query = sampled(768, 0.01, f32::sin);
        let candidates: Vec<Vec<f32>> = (0..100)
            .map(|j| {
                (0..768)
                    .map(|i| ((i + j * 10) as f32 * 0.01).cos())
                    .collect()
            })
            .collect();

        let distances = SimdDistance::new(DistanceMetric::Cosine)
            .batch_distance(&query, &as_slices(&candidates));

        assert_eq!(distances.len(), 100, "Should return 100 distances");

        // Cosine distance is bounded to [0, 2].
        for (i, d) in distances.into_iter().enumerate() {
            assert!((0.0..=2.0).contains(&d), "Distance {i} = {d} out of range");
        }
    }

    #[test]
    fn test_batch_distance_consistency() {
        let simd = SimdDistance::new(DistanceMetric::Euclidean);

        let query: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let candidates: Vec<Vec<f32>> = (0..20)
            .map(|j| (0..128).map(|i| (i + j) as f32).collect())
            .collect();
        let refs = as_slices(&candidates);

        let batch_distances = simd.batch_distance(&query, &refs);

        // The batched path must agree with scoring each candidate on its own.
        for (i, (batch, candidate)) in batch_distances.iter().zip(&refs).enumerate() {
            let individual = simd.distance(&query, candidate);
            assert!(
                (batch - individual).abs() < 1e-6,
                "Mismatch at {i}: batch={batch}, individual={individual}"
            );
        }
    }

    #[test]
    fn test_batch_distance_empty() {
        let query = [1.0, 2.0, 3.0];
        let candidates: [&[f32]; 0] = [];

        let distances = SimdDistance::new(DistanceMetric::Cosine).batch_distance(&query, &candidates);
        assert!(distances.is_empty(), "Empty candidates should return empty");
    }

    #[test]
    fn test_native_simd_matches_simd() {
        let a = sampled(768, 0.01, f32::sin);
        let b = sampled(768, 0.02, f32::cos);

        let simd_dist = SimdDistance::new(DistanceMetric::Cosine).distance(&a, &b);
        let native_dist = NativeSimdDistance::new(DistanceMetric::Cosine).distance(&a, &b);

        assert!(
            (simd_dist - native_dist).abs() < 1e-3,
            "Native SIMD should match SIMD: simd={simd_dist}, native={native_dist}"
        );
    }

    #[test]
    fn test_native_simd_euclidean() {
        let origin = [0.0, 0.0, 0.0, 0.0];
        let point = [3.0, 4.0, 0.0, 0.0];

        let dist = NativeSimdDistance::new(DistanceMetric::Euclidean).distance(&origin, &point);
        assert!((dist - 5.0).abs() < 1e-5, "3-4-5 triangle: got {dist}");
    }

    #[test]
    fn test_native_simd_dot_product() {
        // Both vectors are non-negative, so the similarity is positive and the
        // negated "distance" must be negative.
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 * 0.1).collect();

        let dist = NativeSimdDistance::new(DistanceMetric::DotProduct).distance(&a, &b);
        assert!(dist < 0.0, "DotProduct distance should be negative");
    }
}