1use crate::distance::euclidean_distance;
7
8#[cfg(feature = "no-std")]
9use alloc::collections::{BTreeMap as HashMap, BTreeSet as HashSet};
10#[cfg(feature = "no-std")]
11use alloc::vec::Vec;
12#[cfg(not(feature = "no-std"))]
13use std::collections::{HashMap, HashSet};
14
15#[cfg(feature = "no-std")]
16use core::cmp::Ordering;
17#[cfg(not(feature = "no-std"))]
18use std::cmp::Ordering;
19
20pub fn binary_search_f32_simd(arr: &[f32], target: f32) -> Option<usize> {
23 if arr.is_empty() {
24 return None;
25 }
26
27 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
28 {
29 if crate::simd_feature_detected!("avx2") && arr.len() >= 16 {
30 return unsafe { binary_search_avx2(arr, target) };
31 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 8 {
32 return unsafe { binary_search_sse2(arr, target) };
33 }
34 }
35
36 binary_search_scalar(arr, target)
37}
38
/// Plain binary search over a sorted `f32` slice.
///
/// Returns the index of an element equal to `target`. Returns `None` when no
/// such element is found, and also as soon as a comparison is undefined
/// (i.e. a NaN is involved).
fn binary_search_scalar(arr: &[f32], target: f32) -> Option<usize> {
    let (mut lo, mut hi) = (0, arr.len());

    while lo < hi {
        let probe = lo + (hi - lo) / 2;

        // `?` bails out with `None` when the comparison has no defined order.
        let ordering = arr[probe].partial_cmp(&target)?;
        match ordering {
            Ordering::Less => lo = probe + 1,
            Ordering::Greater => hi = probe,
            Ordering::Equal => return Some(probe),
        }
    }

    None
}
56
57#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
58#[target_feature(enable = "sse2")]
59unsafe fn binary_search_sse2(arr: &[f32], target: f32) -> Option<usize> {
60 use core::arch::x86_64::*;
61
62 let mut left = 0;
63 let mut right = arr.len();
64 let target_vec = _mm_set1_ps(target);
65
66 while left < right {
67 let mid = left + (right - left) / 2;
68
69 if right - left <= 4 && left + 4 <= arr.len() {
71 let vec = _mm_loadu_ps(&arr[left]);
72 let eq_mask = _mm_cmpeq_ps(vec, target_vec);
73 let mask = _mm_movemask_ps(eq_mask);
74
75 if mask != 0 {
76 for i in 0..4 {
78 if (mask & (1 << i)) != 0 {
79 return Some(left + i);
80 }
81 }
82 }
83
84 return binary_search_scalar(&arr[left..right], target).map(|idx| left + idx);
86 }
87
88 match arr[mid].partial_cmp(&target) {
90 Some(Ordering::Equal) => return Some(mid),
91 Some(Ordering::Less) => left = mid + 1,
92 Some(Ordering::Greater) => right = mid,
93 None => return None,
94 }
95 }
96
97 None
98}
99
100#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
101#[target_feature(enable = "avx2")]
102unsafe fn binary_search_avx2(arr: &[f32], target: f32) -> Option<usize> {
103 use core::arch::x86_64::*;
104
105 let mut left = 0;
106 let mut right = arr.len();
107 let target_vec = _mm256_set1_ps(target);
108
109 while left < right {
110 let mid = left + (right - left) / 2;
111
112 if right - left <= 8 && left + 8 <= arr.len() {
114 let vec = _mm256_loadu_ps(&arr[left]);
115 let eq_mask = _mm256_cmp_ps(vec, target_vec, _CMP_EQ_OQ);
116 let mask = _mm256_movemask_ps(eq_mask);
117
118 if mask != 0 {
119 for i in 0..8 {
121 if (mask & (1 << i)) != 0 {
122 return Some(left + i);
123 }
124 }
125 }
126
127 return binary_search_scalar(&arr[left..right], target).map(|idx| left + idx);
129 }
130
131 match arr[mid].partial_cmp(&target) {
133 Some(Ordering::Equal) => return Some(mid),
134 Some(Ordering::Less) => left = mid + 1,
135 Some(Ordering::Greater) => right = mid,
136 None => return None,
137 }
138 }
139
140 None
141}
142
143pub fn linear_search_f32_simd(arr: &[f32], target: f32) -> Option<usize> {
146 if arr.is_empty() {
147 return None;
148 }
149
150 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
151 {
152 if crate::simd_feature_detected!("avx2") {
153 return unsafe { linear_search_avx2(arr, target) };
154 } else if crate::simd_feature_detected!("sse2") {
155 return unsafe { linear_search_sse2(arr, target) };
156 }
157 }
158
159 linear_search_scalar(arr, target)
160}
161
/// Scalar scan: index of the first element that compares equal to `target`,
/// or `None` if there is none (a NaN `target` never matches).
fn linear_search_scalar(arr: &[f32], target: f32) -> Option<usize> {
    arr.iter().position(|&candidate| candidate == target)
}
170
/// Linear search that compares four `f32` lanes per step with SSE2.
///
/// Full groups of four elements are tested with one vector comparison each;
/// the remaining tail (fewer than four elements) is scanned one by one.
///
/// Returns the index of the first element equal to `target`, or `None`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn linear_search_sse2(arr: &[f32], target: f32) -> Option<usize> {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from whichever `core::arch` module exists for the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let target_vec = _mm_set1_ps(target);
    let mut groups = arr.chunks_exact(4);
    let mut base = 0;

    for group in groups.by_ref() {
        // Every `chunks_exact(4)` item is exactly four floats long, so the
        // unaligned four-lane load stays inside the slice.
        let lanes = _mm_loadu_ps(group.as_ptr());
        let mask = _mm_movemask_ps(_mm_cmpeq_ps(lanes, target_vec));

        if mask != 0 {
            // The lowest set bit identifies the first matching lane.
            return Some(base + mask.trailing_zeros() as usize);
        }

        base += 4;
    }

    // Scalar scan over the leftover elements that did not fill a vector.
    groups
        .remainder()
        .iter()
        .position(|&value| value == target)
        .map(|offset| base + offset)
}
206
/// Linear search that compares eight `f32` lanes per step with AVX.
///
/// Full groups of eight elements are tested with one vector comparison each;
/// the remaining tail (fewer than eight elements) is scanned one by one.
///
/// Returns the index of the first element equal to `target`, or `None`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn linear_search_avx2(arr: &[f32], target: f32) -> Option<usize> {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from whichever `core::arch` module exists for the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let target_vec = _mm256_set1_ps(target);
    let mut groups = arr.chunks_exact(8);
    let mut base = 0;

    for group in groups.by_ref() {
        // Every `chunks_exact(8)` item is exactly eight floats long, so the
        // unaligned eight-lane load stays inside the slice.
        let lanes = _mm256_loadu_ps(group.as_ptr());
        let eq_mask = _mm256_cmp_ps(lanes, target_vec, _CMP_EQ_OQ);
        let mask = _mm256_movemask_ps(eq_mask);

        if mask != 0 {
            // The lowest set bit identifies the first matching lane.
            return Some(base + mask.trailing_zeros() as usize);
        }

        base += 8;
    }

    // Scalar scan over the leftover elements that did not fill a vector.
    groups
        .remainder()
        .iter()
        .position(|&value| value == target)
        .map(|offset| base + offset)
}
242
/// One hit produced by the neighbour searches in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct NearestNeighborResult {
    /// Position of the matched point within the `points` slice that was searched.
    pub index: usize,
    /// Distance between the query and that point, as returned by `euclidean_distance`.
    pub distance: f32,
}
249
250pub fn k_nearest_neighbors_simd(
253 points: &[&[f32]],
254 query: &[f32],
255 k: usize,
256) -> Vec<NearestNeighborResult> {
257 if points.is_empty() || k == 0 {
258 return Vec::new();
259 }
260
261 let k = k.min(points.len());
262 let mut distances: Vec<(usize, f32)> = Vec::with_capacity(points.len());
263
264 for (i, point) in points.iter().enumerate() {
266 let distance = euclidean_distance(query, point);
267 distances.push((i, distance));
268 }
269
270 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
272
273 distances
274 .into_iter()
275 .take(k)
276 .map(|(index, distance)| NearestNeighborResult { index, distance })
277 .collect()
278}
279
/// A collection of hash tables used to look up candidate point indices by
/// hash bucket (see `add_point` and `query`).
pub struct LSHTable {
    /// The individual hash tables; every point is inserted into each of them.
    tables: Vec<LSHHashTable>,
    /// Number of tables requested at construction; stored but not read.
    #[allow(dead_code)] num_tables: usize,
    /// Number of random vectors (hash bits) per table requested at
    /// construction; stored but not read.
    #[allow(dead_code)] hash_size: usize,
}
289
/// One hash table of an `LSHTable`.
struct LSHHashTable {
    /// Maps a hash value to the indices of the points that produced it.
    buckets: HashMap<u64, Vec<usize>>,
    /// The vectors a point is projected onto; each one contributes one bit of
    /// the hash.
    random_vectors: Vec<Vec<f32>>,
}
294
295impl LSHTable {
296 pub fn new(dimensions: usize, num_tables: usize, hash_size: usize) -> Self {
298 let mut tables = Vec::with_capacity(num_tables);
299
300 for _ in 0..num_tables {
301 let mut random_vectors = Vec::with_capacity(hash_size);
302
303 for _ in 0..hash_size {
305 let mut vec = Vec::with_capacity(dimensions);
306 let mut sum_squares = 0.0;
307
308 use scirs2_core::random::thread_rng;
310 let mut rng = thread_rng();
311 for _ in 0..dimensions {
312 let val: f32 = rng.random::<f32>() - 0.5;
313 vec.push(val);
314 sum_squares += val * val;
315 }
316
317 let norm = sum_squares.sqrt();
319 if norm > 0.0 {
320 for val in &mut vec {
321 *val /= norm;
322 }
323 }
324
325 random_vectors.push(vec);
326 }
327
328 tables.push(LSHHashTable {
329 buckets: HashMap::new(),
330 random_vectors,
331 });
332 }
333
334 LSHTable {
335 tables,
336 num_tables,
337 hash_size,
338 }
339 }
340
341 pub fn add_point(&mut self, point: &[f32], index: usize) {
343 for i in 0..self.tables.len() {
344 let hash = self.hash_point(&self.tables[i], point);
345 self.tables[i].buckets.entry(hash).or_default().push(index);
346 }
347 }
348
349 pub fn query(&self, point: &[f32], max_candidates: usize) -> Vec<usize> {
351 let mut candidates = HashSet::new();
352
353 for table in &self.tables {
354 let hash = self.hash_point(table, point);
355
356 if let Some(bucket) = table.buckets.get(&hash) {
357 for &index in bucket {
358 candidates.insert(index);
359 if candidates.len() >= max_candidates {
360 break;
361 }
362 }
363 }
364
365 if candidates.len() >= max_candidates {
366 break;
367 }
368 }
369
370 candidates.into_iter().collect()
371 }
372
373 fn hash_point(&self, table: &LSHHashTable, point: &[f32]) -> u64 {
374 let mut hash = 0u64;
375
376 for (i, random_vec) in table.random_vectors.iter().enumerate() {
377 let dot_product = crate::vector::dot_product(point, random_vec);
379
380 if dot_product >= 0.0 {
382 hash |= 1u64 << i;
383 }
384 }
385
386 hash
387 }
388}
389
390pub fn range_search_simd(
393 points: &[&[f32]],
394 query: &[f32],
395 radius: f32,
396) -> Vec<NearestNeighborResult> {
397 let mut results = Vec::new();
398 let _radius_squared = radius * radius;
399
400 for (i, point) in points.iter().enumerate() {
401 let distance = euclidean_distance(query, point);
402 if distance <= radius {
403 results.push(NearestNeighborResult { index: i, distance });
404 }
405 }
406
407 results.sort_by(|a, b| {
409 a.distance
410 .partial_cmp(&b.distance)
411 .unwrap_or(Ordering::Equal)
412 });
413
414 results
415}
416
417pub fn argmax_f32_simd(arr: &[f32]) -> Option<usize> {
419 if arr.is_empty() {
420 return None;
421 }
422
423 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
424 {
425 if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
426 return Some(unsafe { argmax_avx2(arr) });
427 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
428 return Some(unsafe { argmax_sse2(arr) });
429 }
430 }
431
432 argmax_scalar(arr)
433}
434
/// Scalar argmax: index of the first element that is strictly greater than
/// everything before it and not exceeded later; `None` for an empty slice.
fn argmax_scalar(arr: &[f32]) -> Option<usize> {
    // An empty slice has no first element, which ends the function here.
    let (&first, rest) = arr.split_first()?;

    let mut best_idx = 0;
    let mut best_val = first;

    for (offset, &candidate) in rest.iter().enumerate() {
        // Strict comparison keeps the earliest index among equal maxima.
        if candidate > best_val {
            best_val = candidate;
            best_idx = offset + 1;
        }
    }

    Some(best_idx)
}
452
/// Argmax that reads the input four `f32` lanes at a time with SSE loads.
///
/// Each group of four is loaded into a vector register, stored to a small
/// stack buffer, and its lanes are then compared one by one against the
/// running maximum. Leftover elements are handled with a scalar loop. Among
/// equal maxima the lowest index wins.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn argmax_sse2(arr: &[f32]) -> usize {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from whichever `core::arch` module exists for the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut max_val = arr[0];
    let mut max_idx = 0;
    let mut groups = arr.chunks_exact(4);
    let mut base = 0;

    for group in groups.by_ref() {
        // Every `chunks_exact(4)` item is exactly four floats long, so the
        // unaligned load stays inside the slice.
        let lanes = _mm_loadu_ps(group.as_ptr());
        let mut temp = [0.0f32; 4];
        _mm_storeu_ps(temp.as_mut_ptr(), lanes);

        for (j, &val) in temp.iter().enumerate() {
            if val > max_val {
                max_val = val;
                max_idx = base + j;
            }
        }

        base += 4;
    }

    // Leftover elements that did not fill a vector.
    for (j, &val) in groups.remainder().iter().enumerate() {
        if val > max_val {
            max_val = val;
            max_idx = base + j;
        }
    }

    max_idx
}
488
/// Argmax that reads the input eight `f32` lanes at a time with AVX loads.
///
/// Each group of eight is loaded into a vector register, stored to a small
/// stack buffer, and its lanes are then compared one by one against the
/// running maximum. Leftover elements are handled with a scalar loop. Among
/// equal maxima the lowest index wins.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmax_avx2(arr: &[f32]) -> usize {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from whichever `core::arch` module exists for the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut max_val = arr[0];
    let mut max_idx = 0;
    let mut groups = arr.chunks_exact(8);
    let mut base = 0;

    for group in groups.by_ref() {
        // Every `chunks_exact(8)` item is exactly eight floats long, so the
        // unaligned load stays inside the slice.
        let lanes = _mm256_loadu_ps(group.as_ptr());
        let mut temp = [0.0f32; 8];
        _mm256_storeu_ps(temp.as_mut_ptr(), lanes);

        for (j, &val) in temp.iter().enumerate() {
            if val > max_val {
                max_val = val;
                max_idx = base + j;
            }
        }

        base += 8;
    }

    // Leftover elements that did not fill a vector.
    for (j, &val) in groups.remainder().iter().enumerate() {
        if val > max_val {
            max_val = val;
            max_idx = base + j;
        }
    }

    max_idx
}
524
525pub fn argmin_f32_simd(arr: &[f32]) -> Option<usize> {
527 if arr.is_empty() {
528 return None;
529 }
530
531 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
532 {
533 if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
534 return Some(unsafe { argmin_avx2(arr) });
535 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
536 return Some(unsafe { argmin_sse2(arr) });
537 }
538 }
539
540 argmin_scalar(arr)
541}
542
/// Scalar argmin: index of the earliest smallest element, or `None` for an
/// empty slice.
fn argmin_scalar(arr: &[f32]) -> Option<usize> {
    // An empty slice has no first element, which ends the function here.
    let (&head, tail) = arr.split_first()?;

    let (min_idx, _) = tail.iter().enumerate().fold(
        (0, head),
        |(best_idx, best_val), (offset, &val)| {
            // Strict comparison keeps the earliest index among equal minima.
            if val < best_val {
                (offset + 1, val)
            } else {
                (best_idx, best_val)
            }
        },
    );

    Some(min_idx)
}
560
/// Argmin that reads the input four `f32` lanes at a time with SSE loads.
///
/// Each group of four is loaded into a vector register, stored to a small
/// stack buffer, and its lanes are then compared one by one against the
/// running minimum. Leftover elements are handled with a scalar loop. Among
/// equal minima the lowest index wins.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn argmin_sse2(arr: &[f32]) -> usize {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from whichever `core::arch` module exists for the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut min_val = arr[0];
    let mut min_idx = 0;
    let mut groups = arr.chunks_exact(4);
    let mut base = 0;

    for group in groups.by_ref() {
        // Every `chunks_exact(4)` item is exactly four floats long, so the
        // unaligned load stays inside the slice.
        let lanes = _mm_loadu_ps(group.as_ptr());
        let mut temp = [0.0f32; 4];
        _mm_storeu_ps(temp.as_mut_ptr(), lanes);

        for (j, &val) in temp.iter().enumerate() {
            if val < min_val {
                min_val = val;
                min_idx = base + j;
            }
        }

        base += 4;
    }

    // Leftover elements that did not fill a vector.
    for (j, &val) in groups.remainder().iter().enumerate() {
        if val < min_val {
            min_val = val;
            min_idx = base + j;
        }
    }

    min_idx
}
596
/// Argmin that reads the input eight `f32` lanes at a time with AVX loads.
///
/// Each group of eight is loaded into a vector register, stored to a small
/// stack buffer, and its lanes are then compared one by one against the
/// running minimum. Leftover elements are handled with a scalar loop. Among
/// equal minima the lowest index wins.
///
/// # Panics
///
/// Panics if `arr` is empty.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmin_avx2(arr: &[f32]) -> usize {
    // This function is built for both 32- and 64-bit x86, so the intrinsics
    // must come from whichever `core::arch` module exists for the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut min_val = arr[0];
    let mut min_idx = 0;
    let mut groups = arr.chunks_exact(8);
    let mut base = 0;

    for group in groups.by_ref() {
        // Every `chunks_exact(8)` item is exactly eight floats long, so the
        // unaligned load stays inside the slice.
        let lanes = _mm256_loadu_ps(group.as_ptr());
        let mut temp = [0.0f32; 8];
        _mm256_storeu_ps(temp.as_mut_ptr(), lanes);

        for (j, &val) in temp.iter().enumerate() {
            if val < min_val {
                min_val = val;
                min_idx = base + j;
            }
        }

        base += 8;
    }

    // Leftover elements that did not fill a vector.
    for (j, &val) in groups.remainder().iter().enumerate() {
        if val < min_val {
            min_val = val;
            min_idx = base + j;
        }
    }

    min_idx
}
632
// Unit tests for the public search helpers; only built for `std` test runs.
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    // NOTE(review): this import can never be active, because the enclosing
    // module is only compiled when the `no-std` feature is off.
    #[cfg(feature = "no-std")]
    use alloc::vec;

    // Values present in a sorted slice are found at their exact positions.
    #[test]
    fn test_binary_search_found() {
        let arr = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        assert_eq!(binary_search_f32_simd(&arr, 7.0), Some(3));
        assert_eq!(binary_search_f32_simd(&arr, 1.0), Some(0));
        assert_eq!(binary_search_f32_simd(&arr, 15.0), Some(7));
    }

    // Values between, below and above the stored elements are reported as absent.
    #[test]
    fn test_binary_search_not_found() {
        let arr = vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0];
        assert_eq!(binary_search_f32_simd(&arr, 6.0), None);
        assert_eq!(binary_search_f32_simd(&arr, 0.0), None);
        assert_eq!(binary_search_f32_simd(&arr, 16.0), None);
    }

    // Linear search works on unsorted data and returns the first of several matches.
    #[test]
    fn test_linear_search() {
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0];
        assert_eq!(linear_search_f32_simd(&arr, 4.0), Some(2));
        assert_eq!(linear_search_f32_simd(&arr, 1.0), Some(1));
        assert_eq!(linear_search_f32_simd(&arr, 8.0), None);
    }

    // The two points closest to the query are the first two inputs, in either order.
    #[test]
    fn test_k_nearest_neighbors() {
        let p1 = [1.0, 1.0];
        let p2 = [2.0, 2.0];
        let p3 = [5.0, 5.0];
        let p4 = [6.0, 6.0];
        let points = vec![&p1[..], &p2[..], &p3[..], &p4[..]];

        let query = [1.5, 1.5];
        let neighbors = k_nearest_neighbors_simd(&points, &query, 2);

        assert_eq!(neighbors.len(), 2);
        assert!(neighbors[0].index < 2);
        assert!(neighbors[1].index < 2);
    }

    // Every reported hit lies within the requested radius.
    #[test]
    fn test_range_search() {
        let p1 = [1.0, 1.0];
        let p2 = [2.0, 2.0];
        let p3 = [5.0, 5.0];
        let points = vec![&p1[..], &p2[..], &p3[..]];

        let query = [1.5, 1.5];
        let results = range_search_simd(&points, &query, 1.0);

        assert!(!results.is_empty());
        assert!(results.iter().all(|r| r.distance <= 1.0));
    }

    // The largest value (9.0) sits at index 5.
    #[test]
    fn test_argmax() {
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(argmax_f32_simd(&arr), Some(5));
    }

    // The smallest value (1.0) occurs twice; the first occurrence is reported.
    #[test]
    fn test_argmin() {
        let arr = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(argmin_f32_simd(&arr), Some(1));
    }

    // All entry points return `None` for empty input.
    #[test]
    fn test_empty_arrays() {
        let empty: Vec<f32> = vec![];
        assert_eq!(binary_search_f32_simd(&empty, 1.0), None);
        assert_eq!(linear_search_f32_simd(&empty, 1.0), None);
        assert_eq!(argmax_f32_simd(&empty), None);
        assert_eq!(argmin_f32_simd(&empty), None);
    }

    // A one-element slice is handled by every entry point.
    #[test]
    fn test_single_element() {
        let arr = vec![42.0];
        assert_eq!(binary_search_f32_simd(&arr, 42.0), Some(0));
        assert_eq!(linear_search_f32_simd(&arr, 42.0), Some(0));
        assert_eq!(argmax_f32_simd(&arr), Some(0));
        assert_eq!(argmin_f32_simd(&arr), Some(0));
    }

    // Smoke test: after inserting three points, a query yields at least one candidate.
    #[test]
    fn test_lsh_table() {
        let mut lsh = LSHTable::new(2, 3, 4);

        let p1 = vec![1.0, 1.0];
        let p2 = vec![2.0, 2.0];
        let p3 = vec![10.0, 10.0];

        lsh.add_point(&p1, 0);
        lsh.add_point(&p2, 1);
        lsh.add_point(&p3, 2);

        let query = vec![1.1, 1.1];
        let candidates = lsh.query(&query, 5);

        assert!(!candidates.is_empty());
    }
}