Skip to main content

ruvector_diskann/
distance.rs

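//! Distance computations with SIMD acceleration and an optional GPU batch module
//!
//! `l2_squared` / `inner_product` dispatch: SimSIMD (if `simd` feature) → scalar.
//! The `gpu` feature adds a separate batch API in the `gpu` module; it is not part of that dispatch.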
4
/// Flat vector storage — contiguous memory for cache-friendly access.
///
/// Vectors are stored back to back in a single `Vec<f32>` slab:
/// `[v0_d0, v0_d1, ..., v1_d0, ...]`, so vector `i` occupies
/// `data[i * dim..(i + 1) * dim]`.
#[derive(Clone)]
pub struct FlatVectors {
    /// Concatenated components of every stored vector.
    pub data: Vec<f32>,
    /// Number of components per vector.
    pub dim: usize,
    /// Number of vectors pushed so far.
    pub count: usize,
}

impl FlatVectors {
    /// Creates an empty store for vectors of `dim` components.
    pub fn new(dim: usize) -> Self {
        Self {
            data: Vec::new(),
            dim,
            count: 0,
        }
    }

    /// Creates an empty store with room reserved for `n` vectors of `dim`
    /// components each.
    pub fn with_capacity(dim: usize, n: usize) -> Self {
        Self {
            data: Vec::with_capacity(n * dim),
            dim,
            count: 0,
        }
    }

    /// Appends one vector to the end of the slab.
    ///
    /// # Panics
    ///
    /// Panics if `vector.len() != self.dim`. The check is unconditional:
    /// a mis-sized vector would shift the offset of every vector stored
    /// after it, and `get` would then silently return the wrong data.
    #[inline]
    pub fn push(&mut self, vector: &[f32]) {
        assert_eq!(
            vector.len(),
            self.dim,
            "vector length must equal the store dimension"
        );
        self.data.extend_from_slice(vector);
        self.count += 1;
    }

    /// Returns the components of the vector at position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if the slot lies outside the stored data.
    #[inline]
    pub fn get(&self, idx: usize) -> &[f32] {
        let start = idx * self.dim;
        &self.data[start..start + self.dim]
    }

    /// Invalidates the vector at `idx` in place (lazy deletion).
    ///
    /// Despite the name, the slot is overwritten with `f32::NAN`, not `0.0`:
    /// any arithmetic on these components yields NaN. `count` is unchanged
    /// and the slot keeps its position.
    ///
    /// # Panics
    ///
    /// Panics if the slot lies outside the stored data.
    #[inline]
    pub fn zero_out(&mut self, idx: usize) {
        let start = idx * self.dim;
        self.data[start..start + self.dim].fill(f32::NAN);
    }

    /// Number of vectors pushed so far (invalidated slots included).
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when no vector has been pushed.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
}
61
62// ============================================================================
63// Distance functions — auto-dispatch based on features
64// ============================================================================
65
66/// L2 squared distance — dispatches to best available implementation
67#[inline]
68pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
69    debug_assert_eq!(a.len(), b.len());
70
71    #[cfg(feature = "simd")]
72    {
73        simd_l2_squared(a, b)
74    }
75
76    #[cfg(not(feature = "simd"))]
77    {
78        scalar_l2_squared(a, b)
79    }
80}
81
/// Portable L2² using four independent accumulators for instruction-level
/// parallelism.
///
/// The largest prefix whose length is a multiple of 16 is consumed four
/// lanes at a time (element `i` goes to accumulator `i % 4`); the remaining
/// tail is folded into the first accumulator. Only the first `a.len()`
/// elements of `b` are read.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn scalar_l2_squared(a: &[f32], b: &[f32]) -> f32 {
    // One up-front slice keeps the "b must cover a" contract of indexed access.
    let b = &b[..a.len()];
    let unrolled = a.len() - a.len() % 16;
    let (body_a, tail_a) = a.split_at(unrolled);
    let (body_b, tail_b) = b.split_at(unrolled);

    let mut acc = [0.0f32; 4];
    for (quad_a, quad_b) in body_a.chunks_exact(4).zip(body_b.chunks_exact(4)) {
        for lane in 0..4 {
            let diff = quad_a[lane] - quad_b[lane];
            acc[lane] += diff * diff;
        }
    }
    for (x, y) in tail_a.iter().zip(tail_b) {
        let diff = x - y;
        acc[0] += diff * diff;
    }

    acc[0] + acc[1] + acc[2] + acc[3]
}
113
/// L2² computed by SimSIMD's `sqeuclidean` kernel (NEON/AVX2/AVX-512 where
/// the hardware provides them).
///
/// The kernel already returns the squared distance; the result is narrowed
/// to `f32`. If SimSIMD returns `None`, the scalar implementation is used.
#[cfg(feature = "simd")]
#[inline]
pub fn simd_l2_squared(a: &[f32], b: &[f32]) -> f32 {
    match simsimd::SpatialSimilarity::sqeuclidean(a, b) {
        Some(dist) => dist as f32,
        None => scalar_l2_squared(a, b),
    }
}
123
124/// Inner product distance (negated for min-heap)
125#[inline]
126pub fn inner_product(a: &[f32], b: &[f32]) -> f32 {
127    debug_assert_eq!(a.len(), b.len());
128
129    #[cfg(feature = "simd")]
130    {
131        simsimd::SpatialSimilarity::inner(a, b)
132            .map(|d| -(d as f32))
133            .unwrap_or_else(|| scalar_inner_product(a, b))
134    }
135
136    #[cfg(not(feature = "simd"))]
137    {
138        scalar_inner_product(a, b)
139    }
140}
141
/// Portable negated dot product using four independent accumulators.
///
/// The largest prefix whose length is a multiple of 16 is consumed four
/// lanes at a time (element `i` goes to accumulator `i % 4`); the remaining
/// tail is folded into the first accumulator. The sum is negated before it
/// is returned. Only the first `a.len()` elements of `b` are read.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
fn scalar_inner_product(a: &[f32], b: &[f32]) -> f32 {
    // One up-front slice keeps the "b must cover a" contract of indexed access.
    let b = &b[..a.len()];
    let unrolled = a.len() - a.len() % 16;
    let (body_a, tail_a) = a.split_at(unrolled);
    let (body_b, tail_b) = b.split_at(unrolled);

    let mut acc = [0.0f32; 4];
    for (quad_a, quad_b) in body_a.chunks_exact(4).zip(body_b.chunks_exact(4)) {
        for lane in 0..4 {
            acc[lane] += quad_a[lane] * quad_b[lane];
        }
    }
    for (x, y) in tail_a.iter().zip(tail_b) {
        acc[0] += x * y;
    }

    -(acc[0] + acc[1] + acc[2] + acc[3])
}
167
/// PQ asymmetric distance from a precomputed lookup table.
///
/// `table` is flat with a stride of `k` entries per subspace: the entry for
/// subspace `i` and code `c` is `table[i * k + c]`. The result is the sum of
/// the entries selected by `codes`, one per subspace, in order.
///
/// # Panics
///
/// Panics if any lookup index `i * k + code` is outside `table`. The lookup
/// is bounds-checked on purpose: this is a safe function, so an undersized
/// table or an oversized code must not become an out-of-bounds read.
#[inline]
pub fn pq_asymmetric_distance(codes: &[u8], table: &[f32], k: usize) -> f32 {
    let mut dist = 0.0f32;
    for (subspace, &code) in codes.iter().enumerate() {
        dist += table[subspace * k + usize::from(code)];
    }
    dist
}
178
179// ============================================================================
180// Visited bitset — O(1) membership test, much faster than HashSet<u32>
181// ============================================================================
182
/// Visited-node tracker for graph search, backed by generation stamps.
///
/// Each node id has one `u64` slot holding the generation in which it was
/// last inserted. A node counts as visited when its slot equals the current
/// generation, so starting a new search only needs to bump the counter.
pub struct VisitedSet {
    /// Stamp of the current search. Starts at 1 so that the zero-initialised
    /// slots in `gens` read as "not visited".
    generation: u64,
    /// `gens[id]` is the generation in which `id` was last inserted.
    gens: Vec<u64>,
}

impl VisitedSet {
    /// Creates a tracker for node ids in `0..n`, with nothing visited.
    pub fn new(n: usize) -> Self {
        Self {
            generation: 1,
            gens: vec![0u64; n],
        }
    }

    /// Resets for a new search by advancing the generation counter; no slot
    /// is touched.
    #[inline]
    pub fn clear(&mut self) {
        self.generation += 1;
    }

    /// Marks node `id` as visited in the current search.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not below the `n` passed to [`VisitedSet::new`].
    #[inline]
    pub fn insert(&mut self, id: u32) {
        self.gens[id as usize] = self.generation;
    }

    /// Returns `true` if node `id` was inserted since the last `clear`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not below the `n` passed to [`VisitedSet::new`].
    #[inline]
    pub fn contains(&self, id: u32) -> bool {
        self.gens[id as usize] == self.generation
    }
}
217
218// ============================================================================
219// GPU distance computation (optional, feature-gated)
220// ============================================================================
221
/// Batch distance computation behind the `gpu` feature.
///
/// Computes distances from a single query to every stored vector. No GPU
/// kernel is wired up in this module yet: the batch routine runs on the CPU
/// in parallel via rayon.
#[cfg(feature = "gpu")]
pub mod gpu {
    use super::FlatVectors;

    /// GPU backend selection
    #[derive(Debug, Clone, Copy)]
    pub enum GpuBackend {
        /// Apple Metal (macOS/iOS)
        Metal,
        /// NVIDIA CUDA
        Cuda,
        /// Vulkan compute (cross-platform)
        Vulkan,
    }

    /// GPU distance computation context
    pub struct GpuDistanceContext {
        /// Backend chosen when the context was created.
        backend: GpuBackend,
        /// Batch size for GPU kernel launches (not read by the CPU fallback).
        batch_size: usize,
    }

    impl GpuDistanceContext {
        /// Creates a context with a backend chosen at compile time.
        ///
        /// The choice is made purely from the target OS: `Metal` on macOS,
        /// `Cuda` everywhere else. No hardware probing happens, `Vulkan` is
        /// never selected, and the call currently always returns `Some`.
        pub fn new() -> Option<Self> {
            #[cfg(target_os = "macos")]
            let backend = GpuBackend::Metal;
            #[cfg(not(target_os = "macos"))]
            let backend = GpuBackend::Cuda;

            Some(Self {
                backend,
                batch_size: 4096,
            })
        }

        /// Batch L2² distances: `query` against every vector in `vectors`.
        ///
        /// Returns at most `k` `(index, distance)` pairs in ascending order
        /// of distance; ties are broken by ascending index. Entries whose
        /// distance is NaN (for example slots invalidated with
        /// `FlatVectors::zero_out`) sort after every real distance instead
        /// of aborting the sort.
        pub fn batch_l2_squared(
            &self,
            query: &[f32],
            vectors: &FlatVectors,
            k: usize,
        ) -> Vec<(u32, f32)> {
            // Intended GPU pipeline:
            // 1. Upload query + vector slab to GPU memory
            // 2. Launch N threads, each computing one L2² distance
            // 3. Parallel top-k reduction on GPU
            // 4. Download k results
            //
            // For now, fall back to CPU parallel with rayon
            // (real Metal/CUDA shaders would be added via metal-rs or cuda-sys)
            use rayon::prelude::*;
            use std::cmp::Ordering;

            /// Total order: ascending distance, NaN last, then ascending id.
            fn nan_last(a: &(u32, f32), b: &(u32, f32)) -> Ordering {
                let by_distance = match (a.1.is_nan(), b.1.is_nan()) {
                    (true, true) => Ordering::Equal,
                    (true, false) => Ordering::Greater,
                    (false, true) => Ordering::Less,
                    // Neither side is NaN, so `partial_cmp` cannot return `None`.
                    (false, false) => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
                };
                by_distance.then_with(|| a.0.cmp(&b.0))
            }

            // NOTE(review): ids are `u32`, so `count` is narrowed here; stores
            // with more than `u32::MAX` vectors are not fully scanned.
            let mut dists: Vec<(u32, f32)> = (0..vectors.count as u32)
                .into_par_iter()
                .map(|i| {
                    let v = vectors.get(i as usize);
                    (i, super::scalar_l2_squared(query, v))
                })
                .collect();

            // Partition the k best to the front first so that only those k
            // entries need a full sort.
            if k < dists.len() {
                dists.select_nth_unstable_by(k, nan_last);
                dists.truncate(k);
            }
            dists.sort_unstable_by(nan_last);
            dists
        }

        /// Returns the backend chosen by [`GpuDistanceContext::new`].
        pub fn backend(&self) -> GpuBackend {
            self.backend
        }
    }
}
297
// Unit tests for the public distance helpers and containers in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// (4-1)² + (5-2)² + (6-3)² = 27.
    #[test]
    fn test_l2_squared() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        let got = l2_squared(&lhs, &rhs);
        assert!((got - 27.0).abs() < 1e-6);
    }

    /// A vector is at distance zero from itself.
    #[test]
    fn test_l2_identical() {
        let v = [1.0f32; 128];
        assert!(l2_squared(&v, &v) < 1e-10);
    }

    /// 1*4 + 2*5 + 3*6 = 32, returned negated.
    #[test]
    fn test_inner_product() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        let got = inner_product(&lhs, &rhs);
        assert!((got - (-32.0)).abs() < 1e-6);
    }

    /// Pushed vectors come back unchanged and in insertion order.
    #[test]
    fn test_flat_vectors() {
        let mut store = FlatVectors::new(3);
        store.push(&[1.0, 2.0, 3.0]);
        store.push(&[4.0, 5.0, 6.0]);
        assert_eq!(store.len(), 2);
        assert_eq!(store.get(0), &[1.0, 2.0, 3.0]);
        assert_eq!(store.get(1), &[4.0, 5.0, 6.0]);
    }

    /// Membership is tracked per search and forgotten by `clear`.
    #[test]
    fn test_visited_set() {
        let mut seen = VisitedSet::new(100);
        seen.insert(42);
        assert!(seen.contains(42));
        assert!(!seen.contains(43));
        seen.clear(); // O(1) reset
        assert!(!seen.contains(42));
        seen.insert(43);
        assert!(seen.contains(43));
    }

    /// Two subspaces with four centroids each (k = 4 for the test).
    #[test]
    fn test_pq_flat_table() {
        let table = [
            0.1f32, 0.2, 0.3, 0.4, // subspace 0
            0.5, 0.6, 0.7, 0.8, // subspace 1
        ];
        // code 1 from subspace 0, code 2 from subspace 1
        let codes = [1u8, 2u8];
        let dist = pq_asymmetric_distance(&codes, &table, 4);
        assert!((dist - (0.2 + 0.7)).abs() < 1e-6);
    }
}