velesdb_core/index/hnsw/native/
distance.rs

1//! Distance computation engines for native HNSW.
2//!
3//! Provides trait abstraction for different distance computation backends:
4//! - CPU scalar (baseline)
5//! - CPU SIMD (AVX2/AVX-512/NEON)
6//! - GPU (future: CUDA/Vulkan compute)
7
8use crate::distance::DistanceMetric;
9
10/// Trait for distance computation engines.
11///
12/// This abstraction allows swapping between CPU, SIMD, and GPU backends
13/// without changing the HNSW algorithm implementation.
14pub trait DistanceEngine: Send + Sync {
15    /// Computes distance between two vectors.
16    fn distance(&self, a: &[f32], b: &[f32]) -> f32;
17
18    /// Batch distance computation (one query vs many candidates).
19    ///
20    /// Returns distances in the same order as candidates.
21    /// Default implementation calls `distance()` in a loop.
22    fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
23        candidates.iter().map(|c| self.distance(query, c)).collect()
24    }
25
26    /// Returns the metric type for this engine.
27    fn metric(&self) -> DistanceMetric;
28}
29
30/// CPU scalar distance computation (baseline, no SIMD).
31pub struct CpuDistance {
32    metric: DistanceMetric,
33}
34
35impl CpuDistance {
36    /// Creates a new CPU distance engine with the given metric.
37    #[must_use]
38    pub fn new(metric: DistanceMetric) -> Self {
39        Self { metric }
40    }
41}
42
43impl DistanceEngine for CpuDistance {
44    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
45        match self.metric {
46            DistanceMetric::Cosine => cosine_distance_scalar(a, b),
47            DistanceMetric::Euclidean => euclidean_distance_scalar(a, b),
48            DistanceMetric::DotProduct => dot_product_scalar(a, b),
49            DistanceMetric::Hamming => hamming_distance_scalar(a, b),
50            DistanceMetric::Jaccard => jaccard_distance_scalar(a, b),
51        }
52    }
53
54    fn metric(&self) -> DistanceMetric {
55        self.metric
56    }
57}
58
59/// SIMD-accelerated distance computation.
60///
61/// Uses AVX2/AVX-512 on x86_64, NEON on ARM.
62pub struct SimdDistance {
63    metric: DistanceMetric,
64}
65
66impl SimdDistance {
67    /// Creates a new SIMD-accelerated distance engine with the given metric.
68    #[must_use]
69    pub fn new(metric: DistanceMetric) -> Self {
70        Self { metric }
71    }
72}
73
74impl DistanceEngine for SimdDistance {
75    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
76        // Use our existing optimized SIMD functions for ALL metrics
77        match self.metric {
78            DistanceMetric::Cosine => 1.0 - crate::simd::cosine_similarity_fast(a, b),
79            DistanceMetric::Euclidean => crate::simd::euclidean_distance_fast(a, b),
80            DistanceMetric::DotProduct => -crate::simd::dot_product_fast(a, b), // Negate for distance
81            // PERF-2: Use SIMD implementations for Hamming/Jaccard
82            DistanceMetric::Hamming => crate::simd::hamming_distance_fast(a, b),
83            DistanceMetric::Jaccard => 1.0 - crate::simd::jaccard_similarity_fast(a, b),
84        }
85    }
86
87    fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
88        // PERF-2: Optimized batch distance with CPU prefetch hints
89        // Prefetch upcoming vectors to hide memory latency
90        let prefetch_distance = crate::simd::calculate_prefetch_distance(query.len());
91        let mut results = Vec::with_capacity(candidates.len());
92
93        for (i, candidate) in candidates.iter().enumerate() {
94            // Prefetch upcoming candidate vectors into L1 cache
95            if i + prefetch_distance < candidates.len() {
96                crate::simd::prefetch_vector(candidates[i + prefetch_distance]);
97            }
98            results.push(self.distance(query, candidate));
99        }
100
101        results
102    }
103
104    fn metric(&self) -> DistanceMetric {
105        self.metric
106    }
107}
108
109/// Native SIMD distance computation using core::arch intrinsics.
110///
111/// Uses AVX-512 native intrinsics on x86_64, NEON on ARM.
112/// Based on arXiv:2505.07621 "Bang for the Buck" recommendations.
113pub struct NativeSimdDistance {
114    metric: DistanceMetric,
115}
116
117impl NativeSimdDistance {
118    /// Creates a new native SIMD distance engine.
119    #[must_use]
120    pub fn new(metric: DistanceMetric) -> Self {
121        Self { metric }
122    }
123}
124
125impl DistanceEngine for NativeSimdDistance {
126    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
127        match self.metric {
128            DistanceMetric::Cosine => 1.0 - crate::simd_native::cosine_similarity_native(a, b),
129            DistanceMetric::Euclidean => crate::simd_native::euclidean_native(a, b),
130            DistanceMetric::DotProduct => -crate::simd_native::dot_product_native(a, b),
131            // Fall back to existing SIMD for Hamming/Jaccard
132            DistanceMetric::Hamming => crate::simd::hamming_distance_fast(a, b),
133            DistanceMetric::Jaccard => 1.0 - crate::simd::jaccard_similarity_fast(a, b),
134        }
135    }
136
137    fn batch_distance(&self, query: &[f32], candidates: &[&[f32]]) -> Vec<f32> {
138        match self.metric {
139            DistanceMetric::DotProduct => {
140                // Use optimized batch with prefetch
141                crate::simd_native::batch_dot_product_native(candidates, query)
142                    .into_iter()
143                    .map(|d| -d)
144                    .collect()
145            }
146            _ => candidates.iter().map(|c| self.distance(query, c)).collect(),
147        }
148    }
149
150    fn metric(&self) -> DistanceMetric {
151        self.metric
152    }
153}
154
155// =============================================================================
156// Scalar implementations (baseline for comparison)
157// =============================================================================
158
/// Scalar cosine distance, `1 - cos(a, b)`.
///
/// Only the overlapping prefix of `a` and `b` is used (the two slices are
/// zipped). Returns `1.0` when the product of the two norms is zero, e.g. when
/// either vector is all zeros or empty.
#[inline]
fn cosine_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0_f32;
    let mut norm_a = 0.0_f32;
    let mut norm_b = 0.0_f32;

    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Take the square roots separately. Multiplying the two squared norms
    // first can overflow to infinity or underflow to zero in f32 even though
    // each norm is representable, which would report distance 1.0 for
    // identical very large or very small vectors.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        1.0
    } else {
        1.0 - (dot / denom)
    }
}
178
/// Scalar Euclidean (L2) distance over the overlapping prefix of `a` and `b`.
#[inline]
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate the squared component differences, then take the root.
    let squared_sum: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    squared_sum.sqrt()
}
187
/// Scalar dot-product "distance": the negated dot product of `a` and `b`.
///
/// The sign is flipped so that a lower value means a better match.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    -dot
}
193
/// Scalar Hamming distance: the number of positions at which `a` and `b`
/// differ, returned as `f32`.
///
/// Components are compared by their raw bit patterns, so `0.0` and `-0.0`
/// count as different. Only the overlapping prefix of the slices is examined.
#[inline]
fn hamming_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut differing = 0_usize;
    for (x, y) in a.iter().zip(b) {
        if x.to_bits() != y.to_bits() {
            differing += 1;
        }
    }
    differing as f32
}
201
/// Scalar Jaccard distance: `1 - sum(min(a_i, b_i)) / sum(max(a_i, b_i))`.
///
/// Only the overlapping prefix of `a` and `b` is used. Returns `1.0` when the
/// sum of maxima is zero (for example, two all-zero vectors).
#[inline]
fn jaccard_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate the element-wise minima and maxima in a single pass.
    let (shared, total) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32), |(shared, total), (x, y)| {
            (shared + x.min(*y), total + x.max(*y))
        });

    if total == 0.0 {
        return 1.0;
    }
    1.0 - (shared / total)
}
218
#[cfg(test)]
#[allow(clippy::cast_precision_loss)]
mod tests {
    use super::*;

    /// Samples `f(i * step)` for `i` in `0..len`.
    fn sampled(len: usize, step: f32, f: fn(f32) -> f32) -> Vec<f32> {
        (0..len).map(|i| f(i as f32 * step)).collect()
    }

    /// Builds a vector of `len` entries holding `1.0` where `is_set(i)` is
    /// true and `0.0` elsewhere.
    fn binary_vector(len: usize, is_set: impl Fn(usize) -> bool) -> Vec<f32> {
        (0..len)
            .map(|i| if is_set(i) { 1.0 } else { 0.0 })
            .collect()
    }

    /// Borrows every owned vector as a slice, in order.
    fn as_slices(vectors: &[Vec<f32>]) -> Vec<&[f32]> {
        vectors.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_cosine_identical_vectors() {
        let v = [1.0_f32, 2.0, 3.0];
        let dist = CpuDistance::new(DistanceMetric::Cosine).distance(&v, &v);
        assert!(
            dist.abs() < 1e-5,
            "Identical vectors should have distance ~0"
        );
    }

    #[test]
    fn test_euclidean_known_distance() {
        let origin = [0.0_f32, 0.0, 0.0];
        let point = [3.0_f32, 4.0, 0.0];
        let dist = CpuDistance::new(DistanceMetric::Euclidean).distance(&origin, &point);
        assert!((dist - 5.0).abs() < 1e-5, "3-4-5 triangle");
    }

    #[test]
    fn test_simd_matches_scalar() {
        let a = sampled(768, 0.01, f32::sin);
        let b = sampled(768, 0.02, f32::cos);

        let cpu_dist = CpuDistance::new(DistanceMetric::Cosine).distance(&a, &b);
        let simd_dist = SimdDistance::new(DistanceMetric::Cosine).distance(&a, &b);

        assert!(
            (cpu_dist - simd_dist).abs() < 1e-4,
            "SIMD should match scalar: cpu={cpu_dist}, simd={simd_dist}"
        );
    }

    // =========================================================================
    // TDD Tests for PERF-2: Hamming/Jaccard SIMD + batch_distance optimization
    // =========================================================================

    #[test]
    fn test_simd_hamming_uses_simd_implementation() {
        // Binary-like vectors (0.0 or 1.0) with different periodic patterns.
        let a = binary_vector(64, |i| i % 2 == 0);
        let b = binary_vector(64, |i| i % 3 == 0);

        let dist = SimdDistance::new(DistanceMetric::Hamming).distance(&a, &b);

        // Sanity bounds only: the distance must lie between 0 and the length.
        assert!(dist >= 0.0, "Hamming distance must be non-negative");
        assert!(dist <= 64.0, "Hamming distance cannot exceed vector length");
    }

    #[test]
    fn test_simd_jaccard_uses_simd_implementation() {
        // Two nested "sets": the first 32 and the first 48 positions.
        let a = binary_vector(64, |i| i < 32);
        let b = binary_vector(64, |i| i < 48);

        let dist = SimdDistance::new(DistanceMetric::Jaccard).distance(&a, &b);

        // Jaccard distance = 1 - similarity, so it must stay within [0, 1].
        assert!(
            (0.0..=1.0).contains(&dist),
            "Jaccard distance must be in [0,1]"
        );

        // Intersection = 32, union = 48 => similarity 32/48, distance 1/3.
        let expected = 1.0 - (32.0 / 48.0);
        assert!(
            (dist - expected).abs() < 1e-4,
            "Jaccard distance: expected {expected}, got {dist}"
        );
    }

    #[test]
    fn test_simd_hamming_identical_vectors() {
        let v = binary_vector(32, |i| i % 2 == 0);

        let dist = SimdDistance::new(DistanceMetric::Hamming).distance(&v, &v);
        assert!(
            dist.abs() < 1e-5,
            "Identical vectors should have distance 0"
        );
    }

    #[test]
    fn test_simd_jaccard_identical_vectors() {
        let v = binary_vector(32, |i| i % 2 == 0);

        let dist = SimdDistance::new(DistanceMetric::Jaccard).distance(&v, &v);
        assert!(
            dist.abs() < 1e-5,
            "Identical vectors should have distance 0"
        );
    }

    #[test]
    fn test_batch_distance_with_prefetch() {
        let simd = SimdDistance::new(DistanceMetric::Cosine);

        let query = sampled(768, 0.01, f32::sin);
        // 100 phase-shifted cosine waves, each 768 samples long.
        let candidates: Vec<Vec<f32>> = (0..100_usize)
            .map(|offset| {
                (0..768_usize)
                    .map(|i| ((i + offset * 10) as f32 * 0.01).cos())
                    .collect()
            })
            .collect();
        let candidate_refs = as_slices(&candidates);

        let distances = simd.batch_distance(&query, &candidate_refs);

        assert_eq!(distances.len(), 100, "Should return 100 distances");

        // Every cosine distance must fall inside [0, 2].
        for (i, &d) in distances.iter().enumerate() {
            assert!((0.0..=2.0).contains(&d), "Distance {i} = {d} out of range");
        }
    }

    #[test]
    fn test_batch_distance_consistency() {
        let simd = SimdDistance::new(DistanceMetric::Euclidean);

        let query: Vec<f32> = (0..128_usize).map(|i| i as f32).collect();
        let candidates: Vec<Vec<f32>> = (0..20_usize)
            .map(|shift| (0..128_usize).map(|i| (i + shift) as f32).collect())
            .collect();
        let candidate_refs = as_slices(&candidates);

        let batch_distances = simd.batch_distance(&query, &candidate_refs);

        // Each batch result must agree (within 1e-6) with the distance
        // computed individually for the same candidate.
        for (i, (batch, candidate)) in batch_distances.iter().zip(&candidate_refs).enumerate() {
            let individual = simd.distance(&query, candidate);
            assert!(
                (batch - individual).abs() < 1e-6,
                "Mismatch at {i}: batch={batch}, individual={individual}"
            );
        }
    }

    #[test]
    fn test_batch_distance_empty() {
        let simd = SimdDistance::new(DistanceMetric::Cosine);
        let query = [1.0_f32, 2.0, 3.0];
        let no_candidates: Vec<&[f32]> = Vec::new();

        let distances = simd.batch_distance(&query, &no_candidates);
        assert!(distances.is_empty(), "Empty candidates should return empty");
    }

    // =========================================================================
    // Tests for NativeSimdDistance (AVX-512/NEON intrinsics)
    // =========================================================================

    #[test]
    fn test_native_simd_matches_simd() {
        let a = sampled(768, 0.01, f32::sin);
        let b = sampled(768, 0.02, f32::cos);

        let simd_dist = SimdDistance::new(DistanceMetric::Cosine).distance(&a, &b);
        let native_dist = NativeSimdDistance::new(DistanceMetric::Cosine).distance(&a, &b);

        assert!(
            (simd_dist - native_dist).abs() < 1e-3,
            "Native SIMD should match SIMD: simd={simd_dist}, native={native_dist}"
        );
    }

    #[test]
    fn test_native_simd_euclidean() {
        let origin = [0.0_f32, 0.0, 0.0, 0.0];
        let point = [3.0_f32, 4.0, 0.0, 0.0];

        let dist = NativeSimdDistance::new(DistanceMetric::Euclidean).distance(&origin, &point);
        assert!((dist - 5.0).abs() < 1e-5, "3-4-5 triangle: got {dist}");
    }

    #[test]
    fn test_native_simd_dot_product() {
        let a: Vec<f32> = (0..128_usize).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128_usize).map(|i| (128 - i) as f32 * 0.1).collect();

        let dist = NativeSimdDistance::new(DistanceMetric::DotProduct).distance(&a, &b);
        // The engine reports the negated dot product, so a positive dot
        // product must come back as a negative distance.
        assert!(dist < 0.0, "DotProduct distance should be negative");
    }
}
445        assert!(dist < 0.0, "DotProduct distance should be negative");
446    }
447}