1use rayon::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Axis};
9use sklears_core::{error::Result as SklResult, prelude::SklearsError};
10use wide::f64x4;
11
/// Row-padded copy of a 2-D dataset laid out for cache-friendly row access,
/// with missing entries tracked in a separate bitmask.
#[derive(Clone, Debug)]
pub struct CacheOptimizedData {
    // Row-major values; each row is padded to a whole number of cache lines.
    data: Vec<f64>,
    // One bit per logical (unpadded) cell; a set bit means the value is missing.
    missing_mask: Vec<u64>,
    // Logical number of rows.
    n_rows: usize,
    // Logical (unpadded) number of columns.
    n_cols: usize,
    // Cache line size in bytes used to compute the per-row padding.
    cache_line_size: usize,
}
25
26impl CacheOptimizedData {
27 pub fn new(data: &Array2<f64>, missing_val: f64) -> Self {
29 let (n_rows, n_cols) = data.dim();
30 let cache_line_size = 64; let padded_cols =
34 ((n_cols * 8 + cache_line_size - 1) / cache_line_size) * cache_line_size / 8;
35 let mut aligned_data = vec![0.0; n_rows * padded_cols];
36
37 for i in 0..n_rows {
39 for j in 0..n_cols {
40 aligned_data[i * padded_cols + j] = data[[i, j]];
41 }
42 }
43
44 let mask_len = (n_rows * n_cols + 63) / 64;
46 let mut missing_mask = vec![0u64; mask_len];
47
48 for i in 0..n_rows {
49 for j in 0..n_cols {
50 let idx = i * n_cols + j;
51 let is_missing = if missing_val.is_nan() {
52 data[[i, j]].is_nan()
53 } else {
54 (data[[i, j]] - missing_val).abs() < f64::EPSILON
55 };
56
57 if is_missing {
58 let word_idx = idx / 64;
59 let bit_idx = idx % 64;
60 missing_mask[word_idx] |= 1u64 << bit_idx;
61 }
62 }
63 }
64
65 Self {
66 data: aligned_data,
67 missing_mask,
68 n_rows,
69 n_cols,
70 cache_line_size,
71 }
72 }
73
74 pub fn get(&self, i: usize, j: usize) -> Option<f64> {
76 if i < self.n_rows && j < self.n_cols {
77 let padded_cols = ((self.n_cols * 8 + self.cache_line_size - 1) / self.cache_line_size)
78 * self.cache_line_size
79 / 8;
80 Some(self.data[i * padded_cols + j])
81 } else {
82 None
83 }
84 }
85
86 pub fn is_missing(&self, i: usize, j: usize) -> bool {
88 if i < self.n_rows && j < self.n_cols {
89 let idx = i * self.n_cols + j;
90 let word_idx = idx / 64;
91 let bit_idx = idx % 64;
92 (self.missing_mask[word_idx] & (1u64 << bit_idx)) != 0
93 } else {
94 false
95 }
96 }
97
98 pub fn get_row(&self, i: usize) -> Option<&[f64]> {
100 if i < self.n_rows {
101 let padded_cols = ((self.n_cols * 8 + self.cache_line_size - 1) / self.cache_line_size)
102 * self.cache_line_size
103 / 8;
104 let start = i * padded_cols;
105 Some(&self.data[start..start + self.n_cols])
106 } else {
107 None
108 }
109 }
110}
111
112pub struct SimdDistanceCalculator;
114
115impl SimdDistanceCalculator {
116 pub fn euclidean_distance_simd(x: &[f64], y: &[f64]) -> f64 {
118 assert_eq!(x.len(), y.len(), "Vectors must have the same length");
119
120 if x.len() < 4 {
121 return x
123 .iter()
124 .zip(y.iter())
125 .map(|(a, b)| (a - b).powi(2))
126 .sum::<f64>()
127 .sqrt();
128 }
129
130 unsafe { Self::euclidean_distance_simd_unsafe(x, y) }
131 }
132
133 unsafe fn euclidean_distance_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
135 let len = x.len();
136 let chunks = len / 4;
137
138 let mut sum = f64x4::splat(0.0);
139
140 for i in 0..chunks {
142 let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
143 let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
144 let diff = x_chunk - y_chunk;
145 sum += diff * diff;
146 }
147
148 let sum_array = sum.to_array();
150 let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];
151
152 for i in (chunks * 4)..len {
154 let diff = x[i] - y[i];
155 result += diff * diff;
156 }
157
158 result.sqrt()
159 }
160
161 pub fn manhattan_distance_simd(x: &[f64], y: &[f64]) -> f64 {
163 assert_eq!(x.len(), y.len(), "Vectors must have the same length");
164
165 if x.len() < 4 {
166 return x.iter().zip(y.iter()).map(|(a, b)| (a - b).abs()).sum();
168 }
169
170 unsafe { Self::manhattan_distance_simd_unsafe(x, y) }
171 }
172
173 unsafe fn manhattan_distance_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
175 let len = x.len();
176 let chunks = len / 4;
177
178 let mut sum = f64x4::splat(0.0);
179
180 for i in 0..chunks {
182 let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
183 let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
184 let diff = x_chunk - y_chunk;
185 sum += diff.abs();
186 }
187
188 let sum_array = sum.to_array();
190 let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];
191
192 for i in (chunks * 4)..len {
194 result += (x[i] - y[i]).abs();
195 }
196
197 result
198 }
199
200 pub fn nan_euclidean_distance_simd(x: &[f64], y: &[f64]) -> f64 {
202 assert_eq!(x.len(), y.len(), "Vectors must have the same length");
203
204 let mut sum_sq = 0.0;
205 let mut valid_count = 0;
206
207 for (&x_val, &y_val) in x.iter().zip(y.iter()) {
209 if !x_val.is_nan() && !y_val.is_nan() {
210 let diff = x_val - y_val;
211 sum_sq += diff * diff;
212 valid_count += 1;
213 }
214 }
215
216 if valid_count > 0 {
217 (sum_sq / valid_count as f64).sqrt()
218 } else {
219 f64::INFINITY
220 }
221 }
222
223 pub fn cosine_similarity_simd(x: &[f64], y: &[f64]) -> f64 {
225 assert_eq!(x.len(), y.len(), "Vectors must have the same length");
226
227 if x.len() < 4 {
228 let dot_product: f64 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
230 let norm_x: f64 = x.iter().map(|a| a * a).sum::<f64>().sqrt();
231 let norm_y: f64 = y.iter().map(|a| a * a).sum::<f64>().sqrt();
232
233 if norm_x == 0.0 || norm_y == 0.0 {
234 return 0.0;
235 }
236
237 return dot_product / (norm_x * norm_y);
238 }
239
240 unsafe { Self::cosine_similarity_simd_unsafe(x, y) }
241 }
242
243 unsafe fn cosine_similarity_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
245 let len = x.len();
246 let chunks = len / 4;
247
248 let mut dot_product = f64x4::splat(0.0);
249 let mut norm_x_sq = f64x4::splat(0.0);
250 let mut norm_y_sq = f64x4::splat(0.0);
251
252 for i in 0..chunks {
254 let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
255 let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
256
257 dot_product += x_chunk * y_chunk;
258 norm_x_sq += x_chunk * x_chunk;
259 norm_y_sq += y_chunk * y_chunk;
260 }
261
262 let dot_array = dot_product.to_array();
264 let norm_x_array = norm_x_sq.to_array();
265 let norm_y_array = norm_y_sq.to_array();
266
267 let mut dot_result = dot_array[0] + dot_array[1] + dot_array[2] + dot_array[3];
268 let mut norm_x_result =
269 norm_x_array[0] + norm_x_array[1] + norm_x_array[2] + norm_x_array[3];
270 let mut norm_y_result =
271 norm_y_array[0] + norm_y_array[1] + norm_y_array[2] + norm_y_array[3];
272
273 for i in (chunks * 4)..len {
275 dot_result += x[i] * y[i];
276 norm_x_result += x[i] * x[i];
277 norm_y_result += y[i] * y[i];
278 }
279
280 let norm_x = norm_x_result.sqrt();
281 let norm_y = norm_y_result.sqrt();
282
283 if norm_x == 0.0 || norm_y == 0.0 {
284 0.0
285 } else {
286 dot_result / (norm_x * norm_y)
287 }
288 }
289}
290
291pub struct SimdStatistics;
293
294impl SimdStatistics {
295 pub fn mean_simd(data: &[f64]) -> f64 {
297 if data.is_empty() {
298 return 0.0;
299 }
300
301 if data.len() < 4 {
302 return data.iter().sum::<f64>() / data.len() as f64;
303 }
304
305 unsafe { Self::mean_simd_unsafe(data) }
306 }
307
308 unsafe fn mean_simd_unsafe(data: &[f64]) -> f64 {
310 let len = data.len();
311 let chunks = len / 4;
312
313 let mut sum = f64x4::splat(0.0);
314
315 for i in 0..chunks {
317 let chunk = f64x4::new([
318 data[i * 4],
319 data[i * 4 + 1],
320 data[i * 4 + 2],
321 data[i * 4 + 3],
322 ]);
323 sum += chunk;
324 }
325
326 let sum_array = sum.to_array();
328 let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];
329
330 result += data
332 .iter()
333 .skip(chunks * 4)
334 .take(len - chunks * 4)
335 .sum::<f64>();
336
337 result / len as f64
338 }
339
340 pub fn variance_simd(data: &[f64], mean: Option<f64>) -> f64 {
342 if data.len() <= 1 {
343 return 0.0;
344 }
345
346 let mean = mean.unwrap_or_else(|| Self::mean_simd(data));
347
348 if data.len() < 4 {
349 let sum_sq_diff: f64 = data.iter().map(|&x| (x - mean).powi(2)).sum();
350 return sum_sq_diff / (data.len() - 1) as f64;
351 }
352
353 unsafe { Self::variance_simd_unsafe(data, mean) }
354 }
355
356 unsafe fn variance_simd_unsafe(data: &[f64], mean: f64) -> f64 {
358 let len = data.len();
359 let chunks = len / 4;
360
361 let mean_vec = f64x4::splat(mean);
362 let mut sum_sq_diff = f64x4::splat(0.0);
363
364 for i in 0..chunks {
366 let chunk = f64x4::new([
367 data[i * 4],
368 data[i * 4 + 1],
369 data[i * 4 + 2],
370 data[i * 4 + 3],
371 ]);
372 let diff = chunk - mean_vec;
373 sum_sq_diff += diff * diff;
374 }
375
376 let sum_array = sum_sq_diff.to_array();
378 let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];
379
380 result += data
382 .iter()
383 .skip(chunks * 4)
384 .take(len - chunks * 4)
385 .map(|&x| {
386 let diff = x - mean;
387 diff * diff
388 })
389 .sum::<f64>();
390
391 result / (len - 1) as f64
392 }
393
394 pub fn std_dev_simd(data: &[f64], mean: Option<f64>) -> f64 {
396 Self::variance_simd(data, mean).sqrt()
397 }
398
399 pub fn min_max_simd(data: &[f64]) -> (f64, f64) {
401 if data.is_empty() {
402 return (f64::NAN, f64::NAN);
403 }
404
405 if data.len() == 1 {
406 return (data[0], data[0]);
407 }
408
409 if data.len() < 4 {
410 let mut min_val = data[0];
411 let mut max_val = data[0];
412 for &val in &data[1..] {
413 if val < min_val {
414 min_val = val;
415 }
416 if val > max_val {
417 max_val = val;
418 }
419 }
420 return (min_val, max_val);
421 }
422
423 unsafe { Self::min_max_simd_unsafe(data) }
424 }
425
426 unsafe fn min_max_simd_unsafe(data: &[f64]) -> (f64, f64) {
428 let len = data.len();
429 let chunks = len / 4;
430
431 let mut min_result = f64::INFINITY;
432 let mut max_result = f64::NEG_INFINITY;
433
434 for i in 0..chunks {
436 let base_idx = i * 4;
437 for j in 0..4 {
438 let val = data[base_idx + j];
439 if val < min_result {
440 min_result = val;
441 }
442 if val > max_result {
443 max_result = val;
444 }
445 }
446 }
447
448 for &val in data.iter().skip(chunks * 4).take(len - chunks * 4) {
450 if val < min_result {
451 min_result = val;
452 }
453 if val > max_result {
454 max_result = val;
455 }
456 }
457
458 (min_result, max_result)
459 }
460
461 pub fn quantile_simd(data: &[f64], q: f64) -> f64 {
463 if data.is_empty() {
464 return f64::NAN;
465 }
466
467 let mut sorted_data: Vec<f64> = data.iter().filter(|&&x| !x.is_nan()).cloned().collect();
469
470 if sorted_data.is_empty() {
471 return f64::NAN;
472 }
473
474 sorted_data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
476
477 let index = q * (sorted_data.len() - 1) as f64;
478 let lower = index.floor() as usize;
479 let upper = index.ceil() as usize;
480
481 if lower == upper {
482 sorted_data[lower]
483 } else {
484 let weight = index - lower as f64;
485 sorted_data[lower] * (1.0 - weight) + sorted_data[upper] * weight
486 }
487 }
488}
489
490pub struct SimdMatrixOps;
492
493impl SimdMatrixOps {
494 pub fn matrix_vector_multiply_simd(
496 matrix: &Array2<f64>,
497 vector: &Array1<f64>,
498 ) -> SklResult<Array1<f64>> {
499 let (n_rows, n_cols) = matrix.dim();
500
501 if n_cols != vector.len() {
502 return Err(SklearsError::InvalidInput(format!(
503 "Matrix columns {} must match vector length {}",
504 n_cols,
505 vector.len()
506 )));
507 }
508
509 let mut result = Array1::zeros(n_rows);
510 let vector_slice = vector.as_slice().unwrap();
511
512 for i in 0..n_rows {
514 let row = matrix.row(i);
515 result[i] = Self::dot_product_simd(row.as_slice().unwrap(), vector_slice);
516 }
517
518 Ok(result)
519 }
520
521 pub fn dot_product_simd(x: &[f64], y: &[f64]) -> f64 {
523 assert_eq!(x.len(), y.len(), "Vectors must have the same length");
524
525 if x.len() < 4 {
526 return x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum();
527 }
528
529 unsafe { Self::dot_product_simd_unsafe(x, y) }
530 }
531
532 unsafe fn dot_product_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
534 let len = x.len();
535 let chunks = len / 4;
536
537 let mut sum = f64x4::splat(0.0);
538
539 for i in 0..chunks {
541 let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
542 let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
543 sum += x_chunk * y_chunk;
544 }
545
546 let sum_array = sum.to_array();
548 let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];
549
550 for i in (chunks * 4)..len {
552 result += x[i] * y[i];
553 }
554
555 result
556 }
557
558 pub fn transpose_simd(matrix: &Array2<f64>) -> Array2<f64> {
560 let (n_rows, n_cols) = matrix.dim();
561 let mut result = Array2::zeros((n_cols, n_rows));
562
563 const BLOCK_SIZE: usize = 64;
565
566 for i_block in (0..n_rows).step_by(BLOCK_SIZE) {
567 for j_block in (0..n_cols).step_by(BLOCK_SIZE) {
568 let i_end = (i_block + BLOCK_SIZE).min(n_rows);
569 let j_end = (j_block + BLOCK_SIZE).min(n_cols);
570
571 for i in i_block..i_end {
572 for j in j_block..j_end {
573 result[[j, i]] = matrix[[i, j]];
574 }
575 }
576 }
577 }
578
579 result
580 }
581
582 pub fn matrix_multiply_simd(a: &Array2<f64>, b: &Array2<f64>) -> SklResult<Array2<f64>> {
584 let (a_rows, a_cols) = a.dim();
585 let (b_rows, b_cols) = b.dim();
586
587 if a_cols != b_rows {
588 return Err(SklearsError::InvalidInput(format!(
589 "Matrix dimensions incompatible: {}x{} * {}x{}",
590 a_rows, a_cols, b_rows, b_cols
591 )));
592 }
593
594 let mut result = Array2::zeros((a_rows, b_cols));
595
596 const BLOCK_SIZE: usize = 64;
598
599 for i_block in (0..a_rows).step_by(BLOCK_SIZE) {
600 for j_block in (0..b_cols).step_by(BLOCK_SIZE) {
601 for k_block in (0..a_cols).step_by(BLOCK_SIZE) {
602 let i_end = (i_block + BLOCK_SIZE).min(a_rows);
603 let j_end = (j_block + BLOCK_SIZE).min(b_cols);
604 let k_end = (k_block + BLOCK_SIZE).min(a_cols);
605
606 for i in i_block..i_end {
607 for j in j_block..j_end {
608 let mut sum = 0.0;
609
610 let row = a.row(i);
612 let row_slice = row.as_slice().unwrap();
613 let k_slice = &row_slice[k_block..k_end];
614 let b_slice: Vec<f64> = (k_block..k_end).map(|k| b[[k, j]]).collect();
615
616 if k_slice.len() >= 4 {
617 unsafe {
618 sum += Self::dot_product_simd_unsafe(k_slice, &b_slice);
619 }
620 } else {
621 for k in k_block..k_end {
622 sum += a[[i, k]] * b[[k, j]];
623 }
624 }
625
626 result[[i, j]] += sum;
627 }
628 }
629 }
630 }
631 }
632
633 Ok(result)
634 }
635}
636
/// K-means helper routines parallelized with rayon.
pub struct SimdKMeans;

impl SimdKMeans {
    /// Computes the k cluster centroids as per-cluster means of `data` rows.
    ///
    /// `labels[i]` is the cluster index of row `i`.
    /// NOTE(review): every label must be < k, otherwise `counts[label]`
    /// indexes out of bounds and panics — confirm callers guarantee this.
    /// Empty clusters keep their zero-initialized centroid.
    pub fn calculate_centroids_simd(data: &Array2<f64>, labels: &[usize], k: usize) -> Array2<f64> {
        let (_n_samples, n_features) = data.dim();
        let mut centroids = Array2::zeros((k, n_features));
        let mut counts = vec![0; k];

        // First pass: cluster sizes.
        for &label in labels {
            counts[label] += 1;
        }

        // Each centroid row is computed independently; enumerate() runs
        // before par_bridge(), so each task keeps its correct cluster index.
        centroids
            .axis_iter_mut(Axis(0))
            .enumerate()
            .par_bridge()
            .for_each(|(cluster_idx, mut centroid)| {
                let mut sums = vec![0.0; n_features];

                // Sum all samples assigned to this cluster.
                for (sample_idx, &label) in labels.iter().enumerate() {
                    if label == cluster_idx {
                        let sample = data.row(sample_idx);
                        for (i, &val) in sample.iter().enumerate() {
                            sums[i] += val;
                        }
                    }
                }

                // Divide by the cluster size; empty clusters stay all-zero.
                if counts[cluster_idx] > 0 {
                    let count = counts[cluster_idx] as f64;
                    for (i, &sum) in sums.iter().enumerate() {
                        centroid[i] = sum / count;
                    }
                }
            });

        centroids
    }
}
681
682pub struct SimdImputationOps;
684
685impl SimdImputationOps {
686 pub fn weighted_mean_simd(values: &[f64], weights: &[f64]) -> f64 {
688 assert_eq!(
689 values.len(),
690 weights.len(),
691 "Values and weights must have same length"
692 );
693
694 if values.is_empty() {
695 return 0.0;
696 }
697
698 if values.len() < 8 {
699 let weighted_sum: f64 = values
700 .iter()
701 .zip(weights.iter())
702 .map(|(&v, &w)| v * w)
703 .sum();
704 let weight_sum: f64 = weights.iter().sum();
705
706 return if weight_sum > 0.0 {
707 weighted_sum / weight_sum
708 } else {
709 SimdStatistics::mean_simd(values)
710 };
711 }
712
713 unsafe { Self::weighted_mean_simd_unsafe(values, weights) }
714 }
715
716 unsafe fn weighted_mean_simd_unsafe(values: &[f64], weights: &[f64]) -> f64 {
718 let len = values.len();
719 let chunks = len / 4;
720
721 let mut weighted_sum = f64x4::splat(0.0);
722 let mut weight_sum = f64x4::splat(0.0);
723
724 for i in 0..chunks {
726 let values_chunk = f64x4::new([
727 values[i * 4],
728 values[i * 4 + 1],
729 values[i * 4 + 2],
730 values[i * 4 + 3],
731 ]);
732 let weights_chunk = f64x4::new([
733 weights[i * 4],
734 weights[i * 4 + 1],
735 weights[i * 4 + 2],
736 weights[i * 4 + 3],
737 ]);
738
739 weighted_sum += values_chunk * weights_chunk;
740 weight_sum += weights_chunk;
741 }
742
743 let weighted_array = weighted_sum.to_array();
745 let weight_array = weight_sum.to_array();
746 let mut weighted_result =
747 weighted_array[0] + weighted_array[1] + weighted_array[2] + weighted_array[3];
748 let mut weight_result =
749 weight_array[0] + weight_array[1] + weight_array[2] + weight_array[3];
750
751 for i in (chunks * 4)..len {
753 weighted_result += values[i] * weights[i];
754 weight_result += weights[i];
755 }
756
757 if weight_result > 0.0 {
758 weighted_result / weight_result
759 } else {
760 SimdStatistics::mean_simd(values)
761 }
762 }
763
764 pub fn count_missing_simd(data: &[f64]) -> usize {
766 if data.len() < 4 {
767 return data.iter().filter(|&&x| x.is_nan()).count();
768 }
769
770 unsafe { Self::count_missing_simd_unsafe(data) }
771 }
772
773 unsafe fn count_missing_simd_unsafe(data: &[f64]) -> usize {
775 let len = data.len();
776 let chunks = len / 4;
777
778 let mut missing_count = 0;
779
780 for i in 0..chunks {
782 let base_idx = i * 4;
783 for j in 0..4 {
784 if data[base_idx + j].is_nan() {
785 missing_count += 1;
786 }
787 }
788 }
789
790 missing_count += data
792 .iter()
793 .skip(chunks * 4)
794 .take(len - chunks * 4)
795 .filter(|x| x.is_nan())
796 .count();
797
798 missing_count
799 }
800
801 pub fn batch_distances_simd(
803 query_point: &[f64],
804 data_points: &Array2<f64>,
805 metric: &str,
806 ) -> Vec<f64> {
807 let n_points = data_points.nrows();
808 let mut distances = Vec::with_capacity(n_points);
809
810 match metric {
811 "euclidean" => {
812 distances.par_extend((0..n_points).into_par_iter().map(|i| {
813 let row = data_points.row(i);
814 let point = row.as_slice().unwrap();
815 SimdDistanceCalculator::euclidean_distance_simd(query_point, point)
816 }));
817 }
818 "manhattan" => {
819 distances.par_extend((0..n_points).into_par_iter().map(|i| {
820 let row = data_points.row(i);
821 let point = row.as_slice().unwrap();
822 SimdDistanceCalculator::manhattan_distance_simd(query_point, point)
823 }));
824 }
825 "cosine" => {
826 distances.par_extend((0..n_points).into_par_iter().map(|i| {
827 let row = data_points.row(i);
828 let point = row.as_slice().unwrap();
829 1.0 - SimdDistanceCalculator::cosine_similarity_simd(query_point, point)
830 }));
831 }
832 "nan_euclidean" => {
833 distances.par_extend((0..n_points).into_par_iter().map(|i| {
834 let row = data_points.row(i);
835 let point = row.as_slice().unwrap();
836 SimdDistanceCalculator::nan_euclidean_distance_simd(query_point, point)
837 }));
838 }
839 _ => {
840 distances.par_extend((0..n_points).into_par_iter().map(|i| {
842 let row = data_points.row(i);
843 let point = row.as_slice().unwrap();
844 SimdDistanceCalculator::euclidean_distance_simd(query_point, point)
845 }));
846 }
847 }
848
849 distances
850 }
851
852 pub fn find_knn_simd(
854 query_point: &[f64],
855 data_points: &Array2<f64>,
856 k: usize,
857 metric: &str,
858 ) -> Vec<(usize, f64)> {
859 let distances = Self::batch_distances_simd(query_point, data_points, metric);
860
861 let mut indexed_distances: Vec<(usize, f64)> = distances
862 .into_iter()
863 .enumerate()
864 .filter(|(_, dist)| dist.is_finite())
865 .collect();
866
867 if k < indexed_distances.len() {
869 indexed_distances.select_nth_unstable_by(k, |a, b| a.1.partial_cmp(&b.1).unwrap());
870 indexed_distances.truncate(k);
871 }
872
873 indexed_distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
874 indexed_distances
875 }
876
877 pub fn streaming_mean_imputation(data: &mut Array2<f64>, chunk_size: usize, missing_val: f64) {
879 let (n_rows, n_cols) = data.dim();
880
881 let mut column_means = vec![0.0; n_cols];
883 let mut column_counts = vec![0; n_cols];
884
885 for row_chunk in (0..n_rows).step_by(chunk_size) {
886 let end_row = (row_chunk + chunk_size).min(n_rows);
887
888 for i in row_chunk..end_row {
889 for j in 0..n_cols {
890 let val = data[[i, j]];
891 let is_missing = if missing_val.is_nan() {
892 val.is_nan()
893 } else {
894 (val - missing_val).abs() < f64::EPSILON
895 };
896
897 if !is_missing {
898 column_means[j] += val;
899 column_counts[j] += 1;
900 }
901 }
902 }
903 }
904
905 for j in 0..n_cols {
907 if column_counts[j] > 0 {
908 column_means[j] /= column_counts[j] as f64;
909 }
910 }
911
912 for row_chunk in (0..n_rows).step_by(chunk_size) {
914 let end_row = (row_chunk + chunk_size).min(n_rows);
915
916 for i in row_chunk..end_row {
917 for j in 0..n_cols {
918 let val = data[[i, j]];
919 let is_missing = if missing_val.is_nan() {
920 val.is_nan()
921 } else {
922 (val - missing_val).abs() < f64::EPSILON
923 };
924
925 if is_missing && column_counts[j] > 0 {
926 data[[i, j]] = column_means[j];
927 }
928 }
929 }
930 }
931 }
932}
933
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // 9 coordinates each differing by 1 => sqrt(9) = 3.
    #[test]
    fn test_euclidean_distance() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let distance = SimdDistanceCalculator::euclidean_distance_simd(&x, &y);
        let expected = 3.0; assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    // Mean of 1..=10 is 5.5; length 10 exercises the SIMD path plus tail.
    #[test]
    fn test_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mean = SimdStatistics::mean_simd(&data);
        let expected = 5.5;

        assert_abs_diff_eq!(mean, expected, epsilon = 1e-10);
    }

    // 1*2 + 2*3 + ... + 8*9 = 240; length 8 = two full SIMD chunks.
    #[test]
    fn test_dot_product() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let dot_product = SimdMatrixOps::dot_product_simd(&x, &y);
        let expected = 240.0; assert_abs_diff_eq!(dot_product, expected, epsilon = 1e-10);
    }

    // Uniform weights reduce the weighted mean to the plain mean (4.5).
    #[test]
    fn test_weighted_mean() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let weights = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

        let weighted_mean = SimdImputationOps::weighted_mean_simd(&values, &weights);
        let expected = 4.5; assert_abs_diff_eq!(weighted_mean, expected, epsilon = 1e-10);
    }

    // 3x4 matrix with NaN sentinels: checks padded reads and the missing
    // bitmask (NaN cells flagged, observed cells not).
    #[test]
    fn test_cache_optimized_data() {
        let data = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0,
                2.0,
                f64::NAN,
                4.0,
                5.0,
                f64::NAN,
                7.0,
                8.0,
                9.0,
                10.0,
                11.0,
                f64::NAN,
            ],
        )
        .unwrap();

        let optimized = CacheOptimizedData::new(&data, f64::NAN);

        assert_eq!(optimized.get(0, 0), Some(1.0));
        assert_eq!(optimized.get(0, 1), Some(2.0));
        assert_eq!(optimized.get(1, 0), Some(5.0));

        assert!(optimized.is_missing(0, 2));
        assert!(optimized.is_missing(1, 1));
        assert!(optimized.is_missing(2, 3));
        assert!(!optimized.is_missing(0, 0));
        assert!(!optimized.is_missing(1, 0));
    }

    // Manhattan: 10 coordinates each off by 1 => 10. Cosine: nearly parallel
    // vectors should score close to 1.
    #[test]
    fn test_simd_distance_calculations() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];

        let manhattan = SimdDistanceCalculator::manhattan_distance_simd(&x, &y);
        assert_abs_diff_eq!(manhattan, 10.0, epsilon = 1e-10);

        let cosine_sim = SimdDistanceCalculator::cosine_similarity_simd(&x, &y);
        assert!(cosine_sim > 0.9); }

    // Sample variance of 1..=10 is 55/6 ≈ 9.1667 (n-1 denominator).
    #[test]
    fn test_simd_statistics() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let variance = SimdStatistics::variance_simd(&data, None);
        let expected_variance = 9.166666666666666; assert_abs_diff_eq!(variance, expected_variance, epsilon = 1e-10);

        let (min_val, max_val) = SimdStatistics::min_max_simd(&data);
        assert_eq!(min_val, 1.0);
        assert_eq!(max_val, 10.0);

        // Median of an even-length set: interpolated between 5 and 6.
        let median = SimdStatistics::quantile_simd(&data, 0.5);
        assert_abs_diff_eq!(median, 5.5, epsilon = 1e-10);
    }

    // Transpose swaps (i, j); matrix-vector product of the 1..9 matrix with
    // [1, 2, 3] gives [14, 32, 50].
    #[test]
    fn test_matrix_operations() {
        let matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();

        let transposed = SimdMatrixOps::transpose_simd(&matrix);
        assert_eq!(transposed[[0, 0]], 1.0);
        assert_eq!(transposed[[0, 1]], 4.0);
        assert_eq!(transposed[[0, 2]], 7.0);
        assert_eq!(transposed[[1, 0]], 2.0);

        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = SimdMatrixOps::matrix_vector_multiply_simd(&matrix, &vector).unwrap();

        assert_abs_diff_eq!(result[0], 14.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 32.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[2], 50.0, epsilon = 1e-10);
    }

    // Row 0 equals the query (distance 0); rows 1 and 2 are offset by 1 and 3
    // per coordinate => sqrt(3) and sqrt(27).
    #[test]
    fn test_batch_distances() {
        let query = vec![1.0, 2.0, 3.0];
        let data = Array2::from_shape_vec(
            (3, 3),
            vec![
                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 4.0, 5.0, 6.0, ],
        )
        .unwrap();

        let distances = SimdImputationOps::batch_distances_simd(&query, &data, "euclidean");

        assert_eq!(distances.len(), 3);
        assert_abs_diff_eq!(distances[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(distances[1], 3.0_f64.sqrt(), epsilon = 1e-10);
        assert_abs_diff_eq!(distances[2], 27.0_f64.sqrt(), epsilon = 1e-10);
    }

    // Nearest neighbours in order: exact match (row 0), then row 3 (offset
    // 0.5 per coordinate), then row 1 (offset 1 per coordinate).
    #[test]
    fn test_knn_finding() {
        let query = vec![1.0, 2.0, 3.0];
        let data = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 10.0, 11.0, 12.0, ],
        )
        .unwrap();

        let knn = SimdImputationOps::find_knn_simd(&query, &data, 3, "euclidean");

        assert_eq!(knn.len(), 3);
        assert_eq!(knn[0].0, 0); assert_eq!(knn[1].0, 3); assert_eq!(knn[2].0, 1); }

    // Four NaNs among ten values.
    #[test]
    fn test_missing_count() {
        let data = vec![
            1.0,
            f64::NAN,
            3.0,
            f64::NAN,
            5.0,
            6.0,
            f64::NAN,
            8.0,
            9.0,
            f64::NAN,
        ];
        let count = SimdImputationOps::count_missing_simd(&data);
        assert_eq!(count, 4);
    }

    // Column means over observed values: col0 (1,4,10)/3=5, col1 (5,8,11)/3=8,
    // col2 (3,9,12)/3=8; NaN cells are replaced, observed cells untouched.
    #[test]
    fn test_streaming_imputation() {
        let mut data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0,
                f64::NAN,
                3.0,
                4.0,
                5.0,
                f64::NAN,
                f64::NAN,
                8.0,
                9.0,
                10.0,
                11.0,
                12.0,
            ],
        )
        .unwrap();

        SimdImputationOps::streaming_mean_imputation(&mut data, 2, f64::NAN);

        assert_abs_diff_eq!(data[[0, 1]], 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[1, 2]], 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[2, 0]], 5.0, epsilon = 1e-10);

        assert_abs_diff_eq!(data[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[1, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[3, 2]], 12.0, epsilon = 1e-10);
    }
}