Skip to main content

sklears_simd/
search.rs

1//! SIMD-optimized search algorithms
2//!
3//! This module provides vectorized implementations of search algorithms
4//! including binary search, linear search, and approximate nearest neighbor search.
5
6use crate::distance::euclidean_distance;
7
8#[cfg(feature = "no-std")]
9use alloc::collections::{BTreeMap as HashMap, BTreeSet as HashSet};
10#[cfg(feature = "no-std")]
11use alloc::vec::Vec;
12#[cfg(not(feature = "no-std"))]
13use std::collections::{HashMap, HashSet};
14
15#[cfg(feature = "no-std")]
16use core::cmp::Ordering;
17#[cfg(not(feature = "no-std"))]
18use std::cmp::Ordering;
19
20/// SIMD-optimized binary search for sorted f32 arrays
21/// Returns the index where target is found, or None if not found
22pub fn binary_search_f32_simd(arr: &[f32], target: f32) -> Option<usize> {
23    if arr.is_empty() {
24        return None;
25    }
26
27    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
28    {
29        if crate::simd_feature_detected!("avx2") && arr.len() >= 16 {
30            return unsafe { binary_search_avx2(arr, target) };
31        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 8 {
32            return unsafe { binary_search_sse2(arr, target) };
33        }
34    }
35
36    binary_search_scalar(arr, target)
37}
38
/// Plain binary search over a sorted `f32` slice.
///
/// Returns the index of an element equal to `target`, or `None` if there is none.
/// The search gives up with `None` as soon as a probed element and `target` are
/// unordered (one of them is NaN).
fn binary_search_scalar(arr: &[f32], target: f32) -> Option<usize> {
    // Half-open search interval [lo, hi).
    let (mut lo, mut hi) = (0usize, arr.len());

    while lo < hi {
        let probe = lo + (hi - lo) / 2;

        // `partial_cmp` yields `None` for NaN; `?` turns that into "not found".
        match arr[probe].partial_cmp(&target)? {
            Ordering::Less => lo = probe + 1,
            Ordering::Greater => hi = probe,
            Ordering::Equal => return Some(probe),
        }
    }

    None
}
56
/// Binary search over a sorted `f32` slice that finishes with a single 4-lane SSE2
/// equality compare once the remaining range fits into one vector.
///
/// Returns the index of an element equal to `target`, or `None` if there is none or a
/// probed element and `target` are unordered (NaN).
///
/// The final vector window always starts at `left` and is four elements wide, so it may
/// extend past `right`; a match anywhere in the window is reported. On sorted input
/// every element at or beyond `right` is strictly greater than `target`, so such a
/// match cannot occur there.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn binary_search_sse2(arr: &[f32], target: f32) -> Option<usize> {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 4;

    let mut left = 0;
    let mut right = arr.len();
    let target_vec = _mm_set1_ps(target);

    while left < right {
        // Once the range fits in one vector (and a full vector can be loaded without
        // running off the slice), resolve the search with a single compare.
        if right - left <= LANES && left + LANES <= arr.len() {
            // SAFETY: `left + LANES <= arr.len()`, so all four lanes lie inside `arr`.
            // The pointer is derived from the whole slice, not from one element.
            let window = _mm_loadu_ps(arr.as_ptr().add(left));
            let mask = _mm_movemask_ps(_mm_cmpeq_ps(window, target_vec));

            if mask == 0 {
                // No lane matched, and the window covers the whole remaining range,
                // so the target is absent.
                return None;
            }
            // Lowest set bit = first matching lane.
            return Some(left + mask.trailing_zeros() as usize);
        }

        // Regular binary search step.
        let mid = left + (right - left) / 2;
        match arr[mid].partial_cmp(&target) {
            Some(Ordering::Equal) => return Some(mid),
            Some(Ordering::Less) => left = mid + 1,
            Some(Ordering::Greater) => right = mid,
            None => return None, // NaN involved: ordering undefined
        }
    }

    None
}
99
/// Binary search over a sorted `f32` slice that finishes with a single 8-lane AVX
/// equality compare once the remaining range fits into one vector.
///
/// Returns the index of an element equal to `target`, or `None` if there is none or a
/// probed element and `target` are unordered (NaN).
///
/// The final vector window always starts at `left` and is eight elements wide, so it
/// may extend past `right`; a match anywhere in the window is reported. On sorted input
/// every element at or beyond `right` is strictly greater than `target`, so such a
/// match cannot occur there.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn binary_search_avx2(arr: &[f32], target: f32) -> Option<usize> {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let mut left = 0;
    let mut right = arr.len();
    let target_vec = _mm256_set1_ps(target);

    while left < right {
        // Once the range fits in one vector (and a full vector can be loaded without
        // running off the slice), resolve the search with a single compare.
        if right - left <= LANES && left + LANES <= arr.len() {
            // SAFETY: `left + LANES <= arr.len()`, so all eight lanes lie inside `arr`.
            // The pointer is derived from the whole slice, not from one element.
            let window = _mm256_loadu_ps(arr.as_ptr().add(left));
            // Ordered, non-signaling equality: lanes holding NaN never match.
            let mask = _mm256_movemask_ps(_mm256_cmp_ps::<_CMP_EQ_OQ>(window, target_vec));

            if mask == 0 {
                // No lane matched, and the window covers the whole remaining range,
                // so the target is absent.
                return None;
            }
            // Lowest set bit = first matching lane.
            return Some(left + mask.trailing_zeros() as usize);
        }

        // Regular binary search step.
        let mid = left + (right - left) / 2;
        match arr[mid].partial_cmp(&target) {
            Some(Ordering::Equal) => return Some(mid),
            Some(Ordering::Less) => left = mid + 1,
            Some(Ordering::Greater) => right = mid,
            None => return None, // NaN involved: ordering undefined
        }
    }

    None
}
142
143/// SIMD-optimized linear search for unsorted arrays
144/// Returns the first index where target is found, or None if not found
145pub fn linear_search_f32_simd(arr: &[f32], target: f32) -> Option<usize> {
146    if arr.is_empty() {
147        return None;
148    }
149
150    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
151    {
152        if crate::simd_feature_detected!("avx2") {
153            return unsafe { linear_search_avx2(arr, target) };
154        } else if crate::simd_feature_detected!("sse2") {
155            return unsafe { linear_search_sse2(arr, target) };
156        }
157    }
158
159    linear_search_scalar(arr, target)
160}
161
/// Scalar linear search: index of the first element that compares equal to `target`.
///
/// Uses `==` on `f32`, so a NaN target (or NaN elements) never matches.
fn linear_search_scalar(arr: &[f32], target: f32) -> Option<usize> {
    arr.iter().position(|&candidate| candidate == target)
}
170
/// Linear search that compares four `f32` lanes at a time with SSE2.
///
/// Returns the index of the first element equal to `target`, or `None` if no element
/// compares equal. NaN never matches.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn linear_search_sse2(arr: &[f32], target: f32) -> Option<usize> {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 4;

    let target_vec = _mm_set1_ps(target);
    let mut chunks = arr.chunks_exact(LANES);

    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
        // SAFETY: `chunks_exact` guarantees `chunk` holds exactly `LANES` floats, and
        // the pointer covers the whole chunk.
        let lanes = _mm_loadu_ps(chunk.as_ptr());
        let mask = _mm_movemask_ps(_mm_cmpeq_ps(lanes, target_vec));

        if mask != 0 {
            // Lowest set bit = first matching lane within this chunk.
            return Some(chunk_idx * LANES + mask.trailing_zeros() as usize);
        }
    }

    // Fewer than `LANES` elements are left; scan them one by one.
    let tail = chunks.remainder();
    let tail_start = arr.len() - tail.len();
    tail.iter()
        .position(|&value| value == target)
        .map(|offset| tail_start + offset)
}
206
/// Linear search that compares eight `f32` lanes at a time with AVX.
///
/// Returns the index of the first element equal to `target`, or `None` if no element
/// compares equal. NaN never matches.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn linear_search_avx2(arr: &[f32], target: f32) -> Option<usize> {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let target_vec = _mm256_set1_ps(target);
    let mut chunks = arr.chunks_exact(LANES);

    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
        // SAFETY: `chunks_exact` guarantees `chunk` holds exactly `LANES` floats, and
        // the pointer covers the whole chunk.
        let lanes = _mm256_loadu_ps(chunk.as_ptr());
        // Ordered, non-signaling equality: lanes holding NaN never match.
        let mask = _mm256_movemask_ps(_mm256_cmp_ps::<_CMP_EQ_OQ>(lanes, target_vec));

        if mask != 0 {
            // Lowest set bit = first matching lane within this chunk.
            return Some(chunk_idx * LANES + mask.trailing_zeros() as usize);
        }
    }

    // Fewer than `LANES` elements are left; scan them one by one.
    let tail = chunks.remainder();
    let tail_start = arr.len() - tail.len();
    tail.iter()
        .position(|&value| value == target)
        .map(|offset| tail_start + offset)
}
242
/// One hit returned by the nearest-neighbor and range searches in this module.
///
/// The type is two plain scalars, so it is `Copy` and can be passed around by value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NearestNeighborResult {
    /// Position of the matching point in the `points` slice that was searched.
    pub index: usize,
    /// Distance between the query and that point, as computed by the search.
    pub distance: f32,
}
249
250/// SIMD-optimized k-nearest neighbors search
251/// Returns the k nearest neighbors to the query point
252pub fn k_nearest_neighbors_simd(
253    points: &[&[f32]],
254    query: &[f32],
255    k: usize,
256) -> Vec<NearestNeighborResult> {
257    if points.is_empty() || k == 0 {
258        return Vec::new();
259    }
260
261    let k = k.min(points.len());
262    let mut distances: Vec<(usize, f32)> = Vec::with_capacity(points.len());
263
264    // Compute distances to all points
265    for (i, point) in points.iter().enumerate() {
266        let distance = euclidean_distance(query, point);
267        distances.push((i, distance));
268    }
269
270    // Sort by distance and take k smallest
271    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
272
273    distances
274        .into_iter()
275        .take(k)
276        .map(|(index, distance)| NearestNeighborResult { index, distance })
277        .collect()
278}
279
280/// SIMD-optimized approximate nearest neighbor search using LSH (Locality Sensitive Hashing)
281/// This is a simplified implementation for demonstration
282pub struct LSHTable {
283    tables: Vec<LSHHashTable>,
284    #[allow(dead_code)] // Stored for introspection (e.g. reporting table count to caller)
285    num_tables: usize,
286    #[allow(dead_code)] // Stored for introspection (e.g. reporting hash width to caller)
287    hash_size: usize,
288}
289
290struct LSHHashTable {
291    buckets: HashMap<u64, Vec<usize>>,
292    random_vectors: Vec<Vec<f32>>,
293}
294
295impl LSHTable {
296    /// Create a new LSH table for approximate nearest neighbor search
297    pub fn new(dimensions: usize, num_tables: usize, hash_size: usize) -> Self {
298        let mut tables = Vec::with_capacity(num_tables);
299
300        for _ in 0..num_tables {
301            let mut random_vectors = Vec::with_capacity(hash_size);
302
303            // Generate random unit vectors for hashing
304            for _ in 0..hash_size {
305                let mut vec = Vec::with_capacity(dimensions);
306                let mut sum_squares = 0.0;
307
308                // Generate random vector
309                use scirs2_core::random::thread_rng;
310                let mut rng = thread_rng();
311                for _ in 0..dimensions {
312                    let val: f32 = rng.random::<f32>() - 0.5;
313                    vec.push(val);
314                    sum_squares += val * val;
315                }
316
317                // Normalize to unit vector
318                let norm = sum_squares.sqrt();
319                if norm > 0.0 {
320                    for val in &mut vec {
321                        *val /= norm;
322                    }
323                }
324
325                random_vectors.push(vec);
326            }
327
328            tables.push(LSHHashTable {
329                buckets: HashMap::new(),
330                random_vectors,
331            });
332        }
333
334        LSHTable {
335            tables,
336            num_tables,
337            hash_size,
338        }
339    }
340
341    /// Add a point to the LSH table
342    pub fn add_point(&mut self, point: &[f32], index: usize) {
343        for i in 0..self.tables.len() {
344            let hash = self.hash_point(&self.tables[i], point);
345            self.tables[i].buckets.entry(hash).or_default().push(index);
346        }
347    }
348
349    /// Query for approximate nearest neighbors
350    pub fn query(&self, point: &[f32], max_candidates: usize) -> Vec<usize> {
351        let mut candidates = HashSet::new();
352
353        for table in &self.tables {
354            let hash = self.hash_point(table, point);
355
356            if let Some(bucket) = table.buckets.get(&hash) {
357                for &index in bucket {
358                    candidates.insert(index);
359                    if candidates.len() >= max_candidates {
360                        break;
361                    }
362                }
363            }
364
365            if candidates.len() >= max_candidates {
366                break;
367            }
368        }
369
370        candidates.into_iter().collect()
371    }
372
373    fn hash_point(&self, table: &LSHHashTable, point: &[f32]) -> u64 {
374        let mut hash = 0u64;
375
376        for (i, random_vec) in table.random_vectors.iter().enumerate() {
377            // Compute dot product
378            let dot_product = crate::vector::dot_product(point, random_vec);
379
380            // Use sign of dot product as hash bit
381            if dot_product >= 0.0 {
382                hash |= 1u64 << i;
383            }
384        }
385
386        hash
387    }
388}
389
390/// SIMD-optimized range search
391/// Returns all points within a specified distance of the query point
392pub fn range_search_simd(
393    points: &[&[f32]],
394    query: &[f32],
395    radius: f32,
396) -> Vec<NearestNeighborResult> {
397    let mut results = Vec::new();
398    let _radius_squared = radius * radius;
399
400    for (i, point) in points.iter().enumerate() {
401        let distance = euclidean_distance(query, point);
402        if distance <= radius {
403            results.push(NearestNeighborResult { index: i, distance });
404        }
405    }
406
407    // Sort by distance
408    results.sort_by(|a, b| {
409        a.distance
410            .partial_cmp(&b.distance)
411            .unwrap_or(Ordering::Equal)
412    });
413
414    results
415}
416
417/// SIMD-optimized argmax - find index of maximum element
418pub fn argmax_f32_simd(arr: &[f32]) -> Option<usize> {
419    if arr.is_empty() {
420        return None;
421    }
422
423    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
424    {
425        if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
426            return Some(unsafe { argmax_avx2(arr) });
427        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
428            return Some(unsafe { argmax_sse2(arr) });
429        }
430    }
431
432    argmax_scalar(arr)
433}
434
/// Scalar argmax: index of the first occurrence of the largest element.
///
/// Returns `None` for an empty slice. An element replaces the running maximum only if
/// it is strictly greater, so NaN elements are skipped; if the first element is NaN,
/// nothing compares greater and index 0 is returned.
fn argmax_scalar(arr: &[f32]) -> Option<usize> {
    let (&head, rest) = arr.split_first()?;

    // Carry (index, value) of the best element seen so far through the fold.
    let (best_idx, _) = rest.iter().enumerate().fold(
        (0usize, head),
        |(idx, best), (offset, &candidate)| {
            if candidate > best {
                // `rest` starts at position 1 of `arr`.
                (offset + 1, candidate)
            } else {
                (idx, best)
            }
        },
    );

    Some(best_idx)
}
452
/// Argmax that reads the input four `f32` lanes at a time with SSE2 loads.
///
/// Each vector is loaded, spilled to a small stack array, and its lanes are compared
/// one by one against the running maximum; only the memory access is vectorized.
///
/// Returns the index of the first occurrence of the largest element. An element
/// replaces the running maximum only if it is strictly greater, so NaN elements are
/// skipped; if `arr[0]` is NaN, index 0 is returned.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn argmax_sse2(arr: &[f32]) -> usize {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 4;

    let mut max_val = arr[0];
    let mut max_idx = 0;
    let mut chunks = arr.chunks_exact(LANES);

    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
        let mut lanes = [0.0f32; LANES];
        // SAFETY: `chunks_exact` guarantees `chunk` holds exactly `LANES` floats, and
        // `lanes` has room for the same number.
        _mm_storeu_ps(lanes.as_mut_ptr(), _mm_loadu_ps(chunk.as_ptr()));

        for (lane, &val) in lanes.iter().enumerate() {
            if val > max_val {
                max_val = val;
                max_idx = chunk_idx * LANES + lane;
            }
        }
    }

    // Fewer than `LANES` elements are left; handle them one by one.
    let tail = chunks.remainder();
    let tail_start = arr.len() - tail.len();
    for (offset, &val) in tail.iter().enumerate() {
        if val > max_val {
            max_val = val;
            max_idx = tail_start + offset;
        }
    }

    max_idx
}
488
/// Argmax that reads the input eight `f32` lanes at a time with AVX loads.
///
/// Each vector is loaded, spilled to a small stack array, and its lanes are compared
/// one by one against the running maximum; only the memory access is vectorized.
///
/// Returns the index of the first occurrence of the largest element. An element
/// replaces the running maximum only if it is strictly greater, so NaN elements are
/// skipped; if `arr[0]` is NaN, index 0 is returned.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmax_avx2(arr: &[f32]) -> usize {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let mut max_val = arr[0];
    let mut max_idx = 0;
    let mut chunks = arr.chunks_exact(LANES);

    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
        let mut lanes = [0.0f32; LANES];
        // SAFETY: `chunks_exact` guarantees `chunk` holds exactly `LANES` floats, and
        // `lanes` has room for the same number.
        _mm256_storeu_ps(lanes.as_mut_ptr(), _mm256_loadu_ps(chunk.as_ptr()));

        for (lane, &val) in lanes.iter().enumerate() {
            if val > max_val {
                max_val = val;
                max_idx = chunk_idx * LANES + lane;
            }
        }
    }

    // Fewer than `LANES` elements are left; handle them one by one.
    let tail = chunks.remainder();
    let tail_start = arr.len() - tail.len();
    for (offset, &val) in tail.iter().enumerate() {
        if val > max_val {
            max_val = val;
            max_idx = tail_start + offset;
        }
    }

    max_idx
}
524
525/// SIMD-optimized argmin - find index of minimum element
526pub fn argmin_f32_simd(arr: &[f32]) -> Option<usize> {
527    if arr.is_empty() {
528        return None;
529    }
530
531    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
532    {
533        if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
534            return Some(unsafe { argmin_avx2(arr) });
535        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
536            return Some(unsafe { argmin_sse2(arr) });
537        }
538    }
539
540    argmin_scalar(arr)
541}
542
/// Scalar argmin: index of the first occurrence of the smallest element.
///
/// Returns `None` for an empty slice. An element replaces the running minimum only if
/// it is strictly smaller, so NaN elements are skipped; if the first element is NaN,
/// nothing compares smaller and index 0 is returned.
fn argmin_scalar(arr: &[f32]) -> Option<usize> {
    let (&head, rest) = arr.split_first()?;

    // Carry (index, value) of the best element seen so far through the fold.
    let (best_idx, _) = rest.iter().enumerate().fold(
        (0usize, head),
        |(idx, best), (offset, &candidate)| {
            if candidate < best {
                // `rest` starts at position 1 of `arr`.
                (offset + 1, candidate)
            } else {
                (idx, best)
            }
        },
    );

    Some(best_idx)
}
560
/// Argmin that reads the input four `f32` lanes at a time with SSE2 loads.
///
/// Each vector is loaded, spilled to a small stack array, and its lanes are compared
/// one by one against the running minimum; only the memory access is vectorized.
///
/// Returns the index of the first occurrence of the smallest element. An element
/// replaces the running minimum only if it is strictly smaller, so NaN elements are
/// skipped; if `arr[0]` is NaN, index 0 is returned.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn argmin_sse2(arr: &[f32]) -> usize {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 4;

    let mut min_val = arr[0];
    let mut min_idx = 0;
    let mut chunks = arr.chunks_exact(LANES);

    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
        let mut lanes = [0.0f32; LANES];
        // SAFETY: `chunks_exact` guarantees `chunk` holds exactly `LANES` floats, and
        // `lanes` has room for the same number.
        _mm_storeu_ps(lanes.as_mut_ptr(), _mm_loadu_ps(chunk.as_ptr()));

        for (lane, &val) in lanes.iter().enumerate() {
            if val < min_val {
                min_val = val;
                min_idx = chunk_idx * LANES + lane;
            }
        }
    }

    // Fewer than `LANES` elements are left; handle them one by one.
    let tail = chunks.remainder();
    let tail_start = arr.len() - tail.len();
    for (offset, &val) in tail.iter().enumerate() {
        if val < min_val {
            min_val = val;
            min_idx = tail_start + offset;
        }
    }

    min_idx
}
596
/// Argmin that reads the input eight `f32` lanes at a time with AVX loads.
///
/// Each vector is loaded, spilled to a small stack array, and its lanes are compared
/// one by one against the running minimum; only the memory access is vectorized.
///
/// Returns the index of the first occurrence of the smallest element. An element
/// replaces the running minimum only if it is strictly smaller, so NaN elements are
/// skipped; if `arr[0]` is NaN, index 0 is returned.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmin_avx2(arr: &[f32]) -> usize {
    // Import the intrinsics module that matches the target; `core::arch::x86_64`
    // does not exist when compiling for 32-bit x86.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let mut min_val = arr[0];
    let mut min_idx = 0;
    let mut chunks = arr.chunks_exact(LANES);

    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
        let mut lanes = [0.0f32; LANES];
        // SAFETY: `chunks_exact` guarantees `chunk` holds exactly `LANES` floats, and
        // `lanes` has room for the same number.
        _mm256_storeu_ps(lanes.as_mut_ptr(), _mm256_loadu_ps(chunk.as_ptr()));

        for (lane, &val) in lanes.iter().enumerate() {
            if val < min_val {
                min_val = val;
                min_idx = chunk_idx * LANES + lane;
            }
        }
    }

    // Fewer than `LANES` elements are left; handle them one by one.
    let tail = chunks.remainder();
    let tail_start = arr.len() - tail.len();
    for (offset, &val) in tail.iter().enumerate() {
        if val < min_val {
            min_val = val;
            min_idx = tail_start + offset;
        }
    }

    min_idx
}
632
// Unit tests for the search routines. The module is compiled only for test builds
// without the `no-std` feature, so `vec!` and `Vec` come from the standard prelude.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_binary_search_found() {
        let arr = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        assert_eq!(binary_search_f32_simd(&arr, 7.0), Some(3));
        assert_eq!(binary_search_f32_simd(&arr, 1.0), Some(0));
        assert_eq!(binary_search_f32_simd(&arr, 15.0), Some(7));
    }

    #[test]
    fn test_binary_search_not_found() {
        let arr = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        assert_eq!(binary_search_f32_simd(&arr, 6.0), None);
        assert_eq!(binary_search_f32_simd(&arr, 0.0), None);
        assert_eq!(binary_search_f32_simd(&arr, 16.0), None);
    }

    #[test]
    fn test_binary_search_long_array() {
        // 32 elements: long enough for the dispatcher's 16-element AVX2 threshold.
        let arr: Vec<f32> = (0..32).map(|i| (2 * i + 1) as f32).collect();
        for (i, &value) in arr.iter().enumerate() {
            assert_eq!(binary_search_f32_simd(&arr, value), Some(i));
            // Even numbers are never present.
            assert_eq!(binary_search_f32_simd(&arr, value + 1.0), None);
        }
    }

    #[test]
    fn test_linear_search() {
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0];
        assert_eq!(linear_search_f32_simd(&arr, 4.0), Some(2));
        assert_eq!(linear_search_f32_simd(&arr, 1.0), Some(1)); // First occurrence
        assert_eq!(linear_search_f32_simd(&arr, 8.0), None);
    }

    #[test]
    fn test_linear_search_long_array() {
        // 11 elements: one full 8-lane chunk plus a 3-element tail.
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0, 7.0];
        assert_eq!(linear_search_f32_simd(&arr, 6.0), Some(7));
        assert_eq!(linear_search_f32_simd(&arr, 7.0), Some(10)); // In the tail
        assert_eq!(linear_search_f32_simd(&arr, 5.0), Some(4)); // First occurrence
        assert_eq!(linear_search_f32_simd(&arr, 8.0), None);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let p1 = [1.0, 1.0];
        let p2 = [2.0, 2.0];
        let p3 = [5.0, 5.0];
        let p4 = [6.0, 6.0];
        let points = vec![&p1[..], &p2[..], &p3[..], &p4[..]];

        let query = [1.5, 1.5];
        let neighbors = k_nearest_neighbors_simd(&points, &query, 2);

        assert_eq!(neighbors.len(), 2);
        // Should return the two closest points (p1 and p2)
        assert!(neighbors[0].index < 2);
        assert!(neighbors[1].index < 2);
        assert_ne!(neighbors[0].index, neighbors[1].index);
        // Results come back in ascending order of distance
        assert!(neighbors[0].distance <= neighbors[1].distance);
    }

    #[test]
    fn test_range_search() {
        let p1 = [1.0, 1.0];
        let p2 = [2.0, 2.0];
        let p3 = [5.0, 5.0];
        let points = vec![&p1[..], &p2[..], &p3[..]];

        let query = [1.5, 1.5];
        let results = range_search_simd(&points, &query, 1.0);

        // Should find p1 and p2 within distance 1.0, but not p3
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.distance <= 1.0));
        assert!(results.iter().all(|r| r.index < 2));
    }

    #[test]
    fn test_argmax() {
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(argmax_f32_simd(&arr), Some(5)); // Index of 9.0
    }

    #[test]
    fn test_argmin() {
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(argmin_f32_simd(&arr), Some(1)); // Index of first 1.0
    }

    #[test]
    fn test_argmax_argmin_long_array() {
        // 11 elements: long enough for the dispatcher's 8-element AVX2 threshold.
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0, 5.0];
        assert_eq!(argmax_f32_simd(&arr), Some(5)); // Index of 9.0
        assert_eq!(argmin_f32_simd(&arr), Some(1)); // Index of first 1.0
    }

    #[test]
    fn test_empty_arrays() {
        let empty: Vec<f32> = vec![];
        assert_eq!(binary_search_f32_simd(&empty, 1.0), None);
        assert_eq!(linear_search_f32_simd(&empty, 1.0), None);
        assert_eq!(argmax_f32_simd(&empty), None);
        assert_eq!(argmin_f32_simd(&empty), None);
    }

    #[test]
    fn test_single_element() {
        let arr = vec![42.0];
        assert_eq!(binary_search_f32_simd(&arr, 42.0), Some(0));
        assert_eq!(linear_search_f32_simd(&arr, 42.0), Some(0));
        assert_eq!(argmax_f32_simd(&arr), Some(0));
        assert_eq!(argmin_f32_simd(&arr), Some(0));
    }

    #[test]
    fn test_lsh_table() {
        let mut lsh = LSHTable::new(2, 3, 4);

        // Add some points
        let p1 = vec![1.0, 1.0];
        let p2 = vec![2.0, 2.0];
        let p3 = vec![10.0, 10.0];

        lsh.add_point(&p1, 0);
        lsh.add_point(&p2, 1);
        lsh.add_point(&p3, 2);

        // Query for similar points
        let query = vec![1.1, 1.1];
        let candidates = lsh.query(&query, 5);

        // Should return some candidates, all of them indices that were added
        assert!(!candidates.is_empty());
        assert!(candidates.iter().all(|&index| index < 3));
    }
}