sklears_impute/simd_ops.rs

//! Optimized numerical operations for high-performance imputation
//!
//! This module provides optimized implementations of common numerical
//! operations used in missing-data imputation. Performance-critical paths use
//! SIMD instructions (via the `wide` crate), `unsafe` fast paths, and
//! cache-aware data layouts.
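//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`; the import path assumes this
//! module is exposed as `sklears_impute::simd_ops`):
//!
//! ```ignore
//! use sklears_impute::simd_ops::{SimdDistanceCalculator, SimdStatistics};
//!
//! let x = [1.0, 2.0, 3.0, 4.0];
//! let y = [2.0, 3.0, 4.0, 5.0];
//! let d = SimdDistanceCalculator::euclidean_distance_simd(&x, &y);
//! assert!((d - 2.0).abs() < 1e-12); // sqrt(4 * 1^2) = 2
//! assert!((SimdStatistics::mean_simd(&x) - 2.5).abs() < 1e-12);
//! ```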

use rayon::prelude::*;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{error::Result as SklResult, prelude::SklearsError};
use wide::f64x4;

/// Cache-friendly data layout for imputation operations
#[derive(Clone, Debug)]
pub struct CacheOptimizedData {
    /// Row-major data layout optimized for cache access
    data: Vec<f64>,
    /// Missing value indicators packed as bits
    missing_mask: Vec<u64>,
    /// Dimensions
    n_rows: usize,
    n_cols: usize,
    /// Cache line size alignment
    cache_line_size: usize,
}

impl CacheOptimizedData {
    /// Create cache-optimized data layout
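    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; uses `scirs2_core::ndarray::Array2`
    /// as imported above):
    ///
    /// ```ignore
    /// let data = Array2::from_shape_vec((2, 2), vec![1.0, f64::NAN, 3.0, 4.0]).unwrap();
    /// let opt = CacheOptimizedData::new(&data, f64::NAN);
    /// assert!(opt.is_missing(0, 1));
    /// assert_eq!(opt.get(1, 0), Some(3.0));
    /// ```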
    pub fn new(data: &Array2<f64>, missing_val: f64) -> Self {
        let (n_rows, n_cols) = data.dim();
        let cache_line_size = 64; // Common cache line size

        // Pad columns to cache line boundaries
        let padded_cols =
            ((n_cols * 8 + cache_line_size - 1) / cache_line_size) * cache_line_size / 8;
        let mut aligned_data = vec![0.0; n_rows * padded_cols];

        // Copy data with padding
        for i in 0..n_rows {
            for j in 0..n_cols {
                aligned_data[i * padded_cols + j] = data[[i, j]];
            }
        }

        // Create packed missing mask (64 bits per u64)
        let mask_len = (n_rows * n_cols + 63) / 64;
        let mut missing_mask = vec![0u64; mask_len];

        for i in 0..n_rows {
            for j in 0..n_cols {
                let idx = i * n_cols + j;
                let is_missing = if missing_val.is_nan() {
                    data[[i, j]].is_nan()
                } else {
                    (data[[i, j]] - missing_val).abs() < f64::EPSILON
                };

                if is_missing {
                    let word_idx = idx / 64;
                    let bit_idx = idx % 64;
                    missing_mask[word_idx] |= 1u64 << bit_idx;
                }
            }
        }

        Self {
            data: aligned_data,
            missing_mask,
            n_rows,
            n_cols,
            cache_line_size,
        }
    }

    /// Number of stored `f64` columns per row after cache-line padding
    fn padded_cols(&self) -> usize {
        ((self.n_cols * 8 + self.cache_line_size - 1) / self.cache_line_size)
            * self.cache_line_size
            / 8
    }

    /// Get value at position (i, j) with bounds checking
    pub fn get(&self, i: usize, j: usize) -> Option<f64> {
        if i < self.n_rows && j < self.n_cols {
            Some(self.data[i * self.padded_cols() + j])
        } else {
            None
        }
    }

    /// Check if value is missing
    pub fn is_missing(&self, i: usize, j: usize) -> bool {
        if i < self.n_rows && j < self.n_cols {
            let idx = i * self.n_cols + j;
            let word_idx = idx / 64;
            let bit_idx = idx % 64;
            (self.missing_mask[word_idx] & (1u64 << bit_idx)) != 0
        } else {
            false
        }
    }

    /// Get row slice for cache-friendly access
    pub fn get_row(&self, i: usize) -> Option<&[f64]> {
        if i < self.n_rows {
            let start = i * self.padded_cols();
            Some(&self.data[start..start + self.n_cols])
        } else {
            None
        }
    }
}

/// Optimized distance calculations using SIMD and unsafe code
pub struct SimdDistanceCalculator;

impl SimdDistanceCalculator {
    /// Optimized Euclidean distance calculation using SIMD
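    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`):
    ///
    /// ```ignore
    /// // 3-4-5 right triangle in the first two coordinates
    /// let x = [3.0, 0.0, 0.0, 0.0];
    /// let y = [0.0, 4.0, 0.0, 0.0];
    /// let d = SimdDistanceCalculator::euclidean_distance_simd(&x, &y);
    /// assert!((d - 5.0).abs() < 1e-12);
    /// ```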
    pub fn euclidean_distance_simd(x: &[f64], y: &[f64]) -> f64 {
        assert_eq!(x.len(), y.len(), "Vectors must have the same length");

        if x.len() < 4 {
            // Fallback for small vectors
            return x
                .iter()
                .zip(y.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();
        }

        unsafe { Self::euclidean_distance_simd_unsafe(x, y) }
    }

    /// Unsafe SIMD implementation for maximum performance
    unsafe fn euclidean_distance_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
        let len = x.len();
        let chunks = len / 4;

        let mut sum = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
            let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
            let diff = x_chunk - y_chunk;
            sum += diff * diff;
        }

        // Sum the SIMD lanes
        let sum_array = sum.to_array();
        let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];

        // Handle remaining elements
        for i in (chunks * 4)..len {
            let diff = x[i] - y[i];
            result += diff * diff;
        }

        result.sqrt()
    }

    /// Optimized Manhattan distance calculation using SIMD
    pub fn manhattan_distance_simd(x: &[f64], y: &[f64]) -> f64 {
        assert_eq!(x.len(), y.len(), "Vectors must have the same length");

        if x.len() < 4 {
            // Fallback for small vectors
            return x.iter().zip(y.iter()).map(|(a, b)| (a - b).abs()).sum();
        }

        unsafe { Self::manhattan_distance_simd_unsafe(x, y) }
    }

    /// Unsafe SIMD implementation for Manhattan distance
    unsafe fn manhattan_distance_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
        let len = x.len();
        let chunks = len / 4;

        let mut sum = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
            let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
            let diff = x_chunk - y_chunk;
            sum += diff.abs();
        }

        // Sum the SIMD lanes
        let sum_array = sum.to_array();
        let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];

        // Handle remaining elements
        for i in (chunks * 4)..len {
            result += (x[i] - y[i]).abs();
        }

        result
    }

    /// NaN-aware Euclidean distance for missing data: the root-mean-square
    /// difference over jointly observed (non-NaN) entries, or `INFINITY` when
    /// no entries are jointly observed
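    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`): only the first two coordinates are
    /// observed in both vectors, so the result is sqrt((1 + 1) / 2) = 1.
    ///
    /// ```ignore
    /// let x = [1.0, 2.0, f64::NAN, 4.0];
    /// let y = [2.0, 3.0, 4.0, f64::NAN];
    /// let d = SimdDistanceCalculator::nan_euclidean_distance_simd(&x, &y);
    /// assert!((d - 1.0).abs() < 1e-12);
    /// ```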
    pub fn nan_euclidean_distance_simd(x: &[f64], y: &[f64]) -> f64 {
        assert_eq!(x.len(), y.len(), "Vectors must have the same length");

        let mut sum_sq = 0.0;
        let mut valid_count = 0;

        // Scalar loop: the per-element NaN checks make this path branchy and
        // hard to vectorize with f64x4
        for (&x_val, &y_val) in x.iter().zip(y.iter()) {
            if !x_val.is_nan() && !y_val.is_nan() {
                let diff = x_val - y_val;
                sum_sq += diff * diff;
                valid_count += 1;
            }
        }

        if valid_count > 0 {
            (sum_sq / valid_count as f64).sqrt()
        } else {
            f64::INFINITY
        }
    }

    /// Optimized cosine similarity calculation
    pub fn cosine_similarity_simd(x: &[f64], y: &[f64]) -> f64 {
        assert_eq!(x.len(), y.len(), "Vectors must have the same length");

        if x.len() < 4 {
            // Fallback for small vectors
            let dot_product: f64 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
            let norm_x: f64 = x.iter().map(|a| a * a).sum::<f64>().sqrt();
            let norm_y: f64 = y.iter().map(|a| a * a).sum::<f64>().sqrt();

            if norm_x == 0.0 || norm_y == 0.0 {
                return 0.0;
            }

            return dot_product / (norm_x * norm_y);
        }

        unsafe { Self::cosine_similarity_simd_unsafe(x, y) }
    }

    /// Unsafe SIMD implementation for cosine similarity
    unsafe fn cosine_similarity_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
        let len = x.len();
        let chunks = len / 4;

        let mut dot_product = f64x4::splat(0.0);
        let mut norm_x_sq = f64x4::splat(0.0);
        let mut norm_y_sq = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
            let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);

            dot_product += x_chunk * y_chunk;
            norm_x_sq += x_chunk * x_chunk;
            norm_y_sq += y_chunk * y_chunk;
        }

        // Sum the SIMD lanes
        let dot_array = dot_product.to_array();
        let norm_x_array = norm_x_sq.to_array();
        let norm_y_array = norm_y_sq.to_array();

        let mut dot_result = dot_array[0] + dot_array[1] + dot_array[2] + dot_array[3];
        let mut norm_x_result =
            norm_x_array[0] + norm_x_array[1] + norm_x_array[2] + norm_x_array[3];
        let mut norm_y_result =
            norm_y_array[0] + norm_y_array[1] + norm_y_array[2] + norm_y_array[3];

        // Handle remaining elements
        for i in (chunks * 4)..len {
            dot_result += x[i] * y[i];
            norm_x_result += x[i] * x[i];
            norm_y_result += y[i] * y[i];
        }

        let norm_x = norm_x_result.sqrt();
        let norm_y = norm_y_result.sqrt();

        if norm_x == 0.0 || norm_y == 0.0 {
            0.0
        } else {
            dot_result / (norm_x * norm_y)
        }
    }
}

/// Optimized statistical calculations using SIMD
pub struct SimdStatistics;

impl SimdStatistics {
    /// Optimized mean calculation using SIMD
    pub fn mean_simd(data: &[f64]) -> f64 {
        if data.is_empty() {
            return 0.0;
        }

        if data.len() < 4 {
            return data.iter().sum::<f64>() / data.len() as f64;
        }

        unsafe { Self::mean_simd_unsafe(data) }
    }

    /// Unsafe SIMD implementation for mean calculation
    unsafe fn mean_simd_unsafe(data: &[f64]) -> f64 {
        let len = data.len();
        let chunks = len / 4;

        let mut sum = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let chunk = f64x4::new([
                data[i * 4],
                data[i * 4 + 1],
                data[i * 4 + 2],
                data[i * 4 + 3],
            ]);
            sum += chunk;
        }

        // Sum the SIMD lanes
        let sum_array = sum.to_array();
        let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];

        // Handle remaining elements
        result += data.iter().skip(chunks * 4).sum::<f64>();

        result / len as f64
    }

    /// Optimized sample variance calculation using SIMD (n - 1 denominator)
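    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`): for [1, 2, 3, 4] the mean is 2.5
    /// and the sample variance is (2.25 + 0.25 + 0.25 + 2.25) / 3 = 5/3.
    ///
    /// ```ignore
    /// let data = [1.0, 2.0, 3.0, 4.0];
    /// let var = SimdStatistics::variance_simd(&data, None);
    /// assert!((var - 5.0 / 3.0).abs() < 1e-12);
    /// ```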
    pub fn variance_simd(data: &[f64], mean: Option<f64>) -> f64 {
        if data.len() <= 1 {
            return 0.0;
        }

        let mean = mean.unwrap_or_else(|| Self::mean_simd(data));

        if data.len() < 4 {
            let sum_sq_diff: f64 = data.iter().map(|&x| (x - mean).powi(2)).sum();
            return sum_sq_diff / (data.len() - 1) as f64;
        }

        unsafe { Self::variance_simd_unsafe(data, mean) }
    }

    /// Unsafe SIMD implementation for variance calculation
    unsafe fn variance_simd_unsafe(data: &[f64], mean: f64) -> f64 {
        let len = data.len();
        let chunks = len / 4;

        let mean_vec = f64x4::splat(mean);
        let mut sum_sq_diff = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let chunk = f64x4::new([
                data[i * 4],
                data[i * 4 + 1],
                data[i * 4 + 2],
                data[i * 4 + 3],
            ]);
            let diff = chunk - mean_vec;
            sum_sq_diff += diff * diff;
        }

        // Sum the SIMD lanes
        let sum_array = sum_sq_diff.to_array();
        let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];

        // Handle remaining elements
        result += data
            .iter()
            .skip(chunks * 4)
            .map(|&x| {
                let diff = x - mean;
                diff * diff
            })
            .sum::<f64>();

        result / (len - 1) as f64
    }

    /// Optimized standard deviation calculation
    pub fn std_dev_simd(data: &[f64], mean: Option<f64>) -> f64 {
        Self::variance_simd(data, mean).sqrt()
    }

    /// Optimized min/max finding
    pub fn min_max_simd(data: &[f64]) -> (f64, f64) {
        if data.is_empty() {
            return (f64::NAN, f64::NAN);
        }

        if data.len() == 1 {
            return (data[0], data[0]);
        }

        if data.len() < 4 {
            let mut min_val = data[0];
            let mut max_val = data[0];
            for &val in &data[1..] {
                if val < min_val {
                    min_val = val;
                }
                if val > max_val {
                    max_val = val;
                }
            }
            return (min_val, max_val);
        }

        unsafe { Self::min_max_simd_unsafe(data) }
    }

    /// Unrolled implementation for min/max finding (scalar comparisons; the
    /// f64x4 lanes are not used here)
    unsafe fn min_max_simd_unsafe(data: &[f64]) -> (f64, f64) {
        let len = data.len();
        let chunks = len / 4;

        let mut min_result = f64::INFINITY;
        let mut max_result = f64::NEG_INFINITY;

        // Process 4 elements at a time
        for i in 0..chunks {
            let base_idx = i * 4;
            for j in 0..4 {
                let val = data[base_idx + j];
                if val < min_result {
                    min_result = val;
                }
                if val > max_result {
                    max_result = val;
                }
            }
        }

        // Handle remaining elements
        for &val in data.iter().skip(chunks * 4) {
            if val < min_result {
                min_result = val;
            }
            if val > max_result {
                max_result = val;
            }
        }

        (min_result, max_result)
    }

    /// Quantile calculation with linear interpolation; NaNs are dropped and a
    /// copy of the data is sorted (no SIMD is involved here)
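    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`): q = 0.25 on [1, 2, 3, 4] gives
    /// index 0.75, interpolating 0.25 * 1 + 0.75 * 2 = 1.75.
    ///
    /// ```ignore
    /// let data = [1.0, 2.0, 3.0, 4.0];
    /// let q25 = SimdStatistics::quantile_simd(&data, 0.25);
    /// assert!((q25 - 1.75).abs() < 1e-12);
    /// ```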
    pub fn quantile_simd(data: &[f64], q: f64) -> f64 {
        if data.is_empty() {
            return f64::NAN;
        }

        // Create a copy for sorting, dropping NaNs
        let mut sorted_data: Vec<f64> = data.iter().filter(|&&x| !x.is_nan()).cloned().collect();

        if sorted_data.is_empty() {
            return f64::NAN;
        }

        // Use unstable sort for better performance
        sorted_data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());

        let index = q * (sorted_data.len() - 1) as f64;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            sorted_data[lower]
        } else {
            let weight = index - lower as f64;
            sorted_data[lower] * (1.0 - weight) + sorted_data[upper] * weight
        }
    }
}

/// Optimized matrix operations using SIMD
pub struct SimdMatrixOps;

impl SimdMatrixOps {
    /// Optimized matrix-vector multiplication using SIMD
    pub fn matrix_vector_multiply_simd(
        matrix: &Array2<f64>,
        vector: &Array1<f64>,
    ) -> SklResult<Array1<f64>> {
        let (n_rows, n_cols) = matrix.dim();

        if n_cols != vector.len() {
            return Err(SklearsError::InvalidInput(format!(
                "Matrix columns {} must match vector length {}",
                n_cols,
                vector.len()
            )));
        }

        let mut result = Array1::zeros(n_rows);
        let vector_slice = vector.as_slice().unwrap();

        // One SIMD dot product per row
        for i in 0..n_rows {
            let row = matrix.row(i);
            result[i] = Self::dot_product_simd(row.as_slice().unwrap(), vector_slice);
        }

        Ok(result)
    }

    /// Optimized dot product using SIMD
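    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let x = [1.0, 2.0, 3.0, 4.0];
    /// let y = [4.0, 3.0, 2.0, 1.0];
    /// // 1*4 + 2*3 + 3*2 + 4*1 = 20
    /// let dot = SimdMatrixOps::dot_product_simd(&x, &y);
    /// assert!((dot - 20.0).abs() < 1e-12);
    /// ```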
    pub fn dot_product_simd(x: &[f64], y: &[f64]) -> f64 {
        assert_eq!(x.len(), y.len(), "Vectors must have the same length");

        if x.len() < 4 {
            return x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum();
        }

        unsafe { Self::dot_product_simd_unsafe(x, y) }
    }

    /// Unsafe SIMD implementation for dot product
    unsafe fn dot_product_simd_unsafe(x: &[f64], y: &[f64]) -> f64 {
        let len = x.len();
        let chunks = len / 4;

        let mut sum = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let x_chunk = f64x4::new([x[i * 4], x[i * 4 + 1], x[i * 4 + 2], x[i * 4 + 3]]);
            let y_chunk = f64x4::new([y[i * 4], y[i * 4 + 1], y[i * 4 + 2], y[i * 4 + 3]]);
            sum += x_chunk * y_chunk;
        }

        // Sum the SIMD lanes
        let sum_array = sum.to_array();
        let mut result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];

        // Handle remaining elements
        for i in (chunks * 4)..len {
            result += x[i] * y[i];
        }

        result
    }

    /// Optimized matrix transpose with cache-friendly access
    pub fn transpose_simd(matrix: &Array2<f64>) -> Array2<f64> {
        let (n_rows, n_cols) = matrix.dim();
        let mut result = Array2::zeros((n_cols, n_rows));

        // Use cache-friendly block transpose for better performance
        const BLOCK_SIZE: usize = 64;

        for i_block in (0..n_rows).step_by(BLOCK_SIZE) {
            for j_block in (0..n_cols).step_by(BLOCK_SIZE) {
                let i_end = (i_block + BLOCK_SIZE).min(n_rows);
                let j_end = (j_block + BLOCK_SIZE).min(n_cols);

                for i in i_block..i_end {
                    for j in j_block..j_end {
                        result[[j, i]] = matrix[[i, j]];
                    }
                }
            }
        }

        result
    }

    /// Optimized matrix-matrix multiplication using SIMD and blocking
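    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; uses `scirs2_core::ndarray::Array2`
    /// as imported above):
    ///
    /// ```ignore
    /// let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// let b = Array2::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    /// let c = SimdMatrixOps::matrix_multiply_simd(&a, &b).unwrap();
    /// // [[1*5 + 2*7, 1*6 + 2*8], [3*5 + 4*7, 3*6 + 4*8]] = [[19, 22], [43, 50]]
    /// assert_eq!(c[[0, 0]], 19.0);
    /// assert_eq!(c[[1, 1]], 50.0);
    /// ```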
    pub fn matrix_multiply_simd(a: &Array2<f64>, b: &Array2<f64>) -> SklResult<Array2<f64>> {
        let (a_rows, a_cols) = a.dim();
        let (b_rows, b_cols) = b.dim();

        if a_cols != b_rows {
            return Err(SklearsError::InvalidInput(format!(
                "Matrix dimensions incompatible: {}x{} * {}x{}",
                a_rows, a_cols, b_rows, b_cols
            )));
        }

        let mut result = Array2::zeros((a_rows, b_cols));

        // Use cache-friendly blocked multiplication
        const BLOCK_SIZE: usize = 64;

        for i_block in (0..a_rows).step_by(BLOCK_SIZE) {
            for j_block in (0..b_cols).step_by(BLOCK_SIZE) {
                for k_block in (0..a_cols).step_by(BLOCK_SIZE) {
                    let i_end = (i_block + BLOCK_SIZE).min(a_rows);
                    let j_end = (j_block + BLOCK_SIZE).min(b_cols);
                    let k_end = (k_block + BLOCK_SIZE).min(a_cols);

                    for i in i_block..i_end {
                        for j in j_block..j_end {
                            let mut sum = 0.0;

                            // Use SIMD for the inner product if the block is large enough
                            let row = a.row(i);
                            let row_slice = row.as_slice().unwrap();
                            let k_slice = &row_slice[k_block..k_end];
                            let b_slice: Vec<f64> = (k_block..k_end).map(|k| b[[k, j]]).collect();

                            if k_slice.len() >= 4 {
                                unsafe {
                                    sum += Self::dot_product_simd_unsafe(k_slice, &b_slice);
                                }
                            } else {
                                for k in k_block..k_end {
                                    sum += a[[i, k]] * b[[k, j]];
                                }
                            }

                            result[[i, j]] += sum;
                        }
                    }
                }
            }
        }

        Ok(result)
    }
}

/// Optimized K-means clustering
pub struct SimdKMeans;

impl SimdKMeans {
    /// Optimized centroid calculation
    pub fn calculate_centroids_simd(data: &Array2<f64>, labels: &[usize], k: usize) -> Array2<f64> {
        let (_n_samples, n_features) = data.dim();
        let mut centroids = Array2::zeros((k, n_features));
        let mut counts = vec![0; k];

        // Count points in each cluster
        for &label in labels {
            counts[label] += 1;
        }

        // Calculate centroids using parallel processing
        centroids
            .axis_iter_mut(Axis(0))
            .enumerate()
            .par_bridge()
            .for_each(|(cluster_idx, mut centroid)| {
                let mut sums = vec![0.0; n_features];

                for (sample_idx, &label) in labels.iter().enumerate() {
                    if label == cluster_idx {
                        let sample = data.row(sample_idx);
                        for (i, &val) in sample.iter().enumerate() {
                            sums[i] += val;
                        }
                    }
                }

                // Divide by count to get centroid
                if counts[cluster_idx] > 0 {
                    let count = counts[cluster_idx] as f64;
                    for (i, &sum) in sums.iter().enumerate() {
                        centroid[i] = sum / count;
                    }
                }
            });

        centroids
    }
}

/// Enhanced imputation operations with SIMD optimizations
pub struct SimdImputationOps;

impl SimdImputationOps {
    /// Optimized weighted mean calculation for KNN imputation using SIMD;
    /// falls back to the unweighted mean when the weight sum is not positive
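    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`): (3*1 + 1*2 + 1*3) / (3 + 1 + 1) = 1.6.
    ///
    /// ```ignore
    /// let values = [1.0, 2.0, 3.0];
    /// let weights = [3.0, 1.0, 1.0];
    /// let wm = SimdImputationOps::weighted_mean_simd(&values, &weights);
    /// assert!((wm - 1.6).abs() < 1e-12);
    /// ```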
    pub fn weighted_mean_simd(values: &[f64], weights: &[f64]) -> f64 {
        assert_eq!(
            values.len(),
            weights.len(),
            "Values and weights must have same length"
        );

        if values.is_empty() {
            return 0.0;
        }

        if values.len() < 8 {
            let weighted_sum: f64 = values
                .iter()
                .zip(weights.iter())
                .map(|(&v, &w)| v * w)
                .sum();
            let weight_sum: f64 = weights.iter().sum();

            return if weight_sum > 0.0 {
                weighted_sum / weight_sum
            } else {
                SimdStatistics::mean_simd(values)
            };
        }

        unsafe { Self::weighted_mean_simd_unsafe(values, weights) }
    }

    /// Unsafe SIMD implementation for weighted mean
    unsafe fn weighted_mean_simd_unsafe(values: &[f64], weights: &[f64]) -> f64 {
        let len = values.len();
        let chunks = len / 4;

        let mut weighted_sum = f64x4::splat(0.0);
        let mut weight_sum = f64x4::splat(0.0);

        // Process 4 elements at a time
        for i in 0..chunks {
            let values_chunk = f64x4::new([
                values[i * 4],
                values[i * 4 + 1],
                values[i * 4 + 2],
                values[i * 4 + 3],
            ]);
            let weights_chunk = f64x4::new([
                weights[i * 4],
                weights[i * 4 + 1],
                weights[i * 4 + 2],
                weights[i * 4 + 3],
            ]);

            weighted_sum += values_chunk * weights_chunk;
            weight_sum += weights_chunk;
        }

        // Sum the SIMD lanes
        let weighted_array = weighted_sum.to_array();
        let weight_array = weight_sum.to_array();
        let mut weighted_result =
            weighted_array[0] + weighted_array[1] + weighted_array[2] + weighted_array[3];
        let mut weight_result =
            weight_array[0] + weight_array[1] + weight_array[2] + weight_array[3];

        // Handle remaining elements
        for i in (chunks * 4)..len {
            weighted_result += values[i] * weights[i];
            weight_result += weights[i];
        }

        if weight_result > 0.0 {
            weighted_result / weight_result
        } else {
            SimdStatistics::mean_simd(values)
        }
    }

    /// Optimized missing value (NaN) counting
    pub fn count_missing_simd(data: &[f64]) -> usize {
        if data.len() < 4 {
            return data.iter().filter(|&&x| x.is_nan()).count();
        }

        unsafe { Self::count_missing_simd_unsafe(data) }
    }

    /// Unrolled implementation for missing value counting
    unsafe fn count_missing_simd_unsafe(data: &[f64]) -> usize {
        let len = data.len();
        let chunks = len / 4;

        let mut missing_count = 0;

        // Process 4 elements at a time (manual check since f64x4 has no is_nan)
        for i in 0..chunks {
            let base_idx = i * 4;
            for j in 0..4 {
                if data[base_idx + j].is_nan() {
                    missing_count += 1;
                }
            }
        }

        // Handle remaining elements
        missing_count += data.iter().skip(chunks * 4).filter(|x| x.is_nan()).count();

        missing_count
    }

    /// Optimized batch distance calculation for KNN; `metric` is one of
    /// "euclidean", "manhattan", "cosine" (reported as 1 - similarity), or
    /// "nan_euclidean", with any other value falling back to Euclidean
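    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; uses `scirs2_core::ndarray::Array2`
    /// as imported above):
    ///
    /// ```ignore
    /// let query = [0.0, 0.0];
    /// let points = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 3.0, 4.0]).unwrap();
    /// let d = SimdImputationOps::batch_distances_simd(&query, &points, "manhattan");
    /// assert_eq!(d, vec![2.0, 7.0]);
    /// ```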
    pub fn batch_distances_simd(
        query_point: &[f64],
        data_points: &Array2<f64>,
        metric: &str,
    ) -> Vec<f64> {
        let n_points = data_points.nrows();
        let mut distances = Vec::with_capacity(n_points);

        match metric {
            "euclidean" => {
                distances.par_extend((0..n_points).into_par_iter().map(|i| {
                    let row = data_points.row(i);
                    let point = row.as_slice().unwrap();
                    SimdDistanceCalculator::euclidean_distance_simd(query_point, point)
                }));
            }
            "manhattan" => {
                distances.par_extend((0..n_points).into_par_iter().map(|i| {
                    let row = data_points.row(i);
                    let point = row.as_slice().unwrap();
                    SimdDistanceCalculator::manhattan_distance_simd(query_point, point)
                }));
            }
            "cosine" => {
                distances.par_extend((0..n_points).into_par_iter().map(|i| {
                    let row = data_points.row(i);
                    let point = row.as_slice().unwrap();
                    1.0 - SimdDistanceCalculator::cosine_similarity_simd(query_point, point)
                }));
            }
            "nan_euclidean" => {
                distances.par_extend((0..n_points).into_par_iter().map(|i| {
                    let row = data_points.row(i);
                    let point = row.as_slice().unwrap();
                    SimdDistanceCalculator::nan_euclidean_distance_simd(query_point, point)
                }));
            }
            _ => {
                // Fallback to Euclidean for unknown metrics
                distances.par_extend((0..n_points).into_par_iter().map(|i| {
                    let row = data_points.row(i);
                    let point = row.as_slice().unwrap();
                    SimdDistanceCalculator::euclidean_distance_simd(query_point, point)
                }));
            }
        }

        distances
    }

    /// Optimized k-nearest neighbors finding: returns up to `k`
    /// (row index, distance) pairs sorted by ascending distance, skipping
    /// rows whose distance is not finite
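    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; uses `scirs2_core::ndarray::Array2`
    /// as imported above):
    ///
    /// ```ignore
    /// let query = [0.0, 0.0];
    /// let points = Array2::from_shape_vec((3, 2), vec![5.0, 5.0, 1.0, 0.0, 0.0, 2.0]).unwrap();
    /// let knn = SimdImputationOps::find_knn_simd(&query, &points, 2, "euclidean");
    /// // Rows 1 (distance 1) and 2 (distance 2) are the two nearest
    /// assert_eq!(knn[0].0, 1);
    /// assert_eq!(knn[1].0, 2);
    /// ```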
    pub fn find_knn_simd(
        query_point: &[f64],
        data_points: &Array2<f64>,
        k: usize,
        metric: &str,
    ) -> Vec<(usize, f64)> {
        let distances = Self::batch_distances_simd(query_point, data_points, metric);

        let mut indexed_distances: Vec<(usize, f64)> = distances
            .into_iter()
            .enumerate()
            .filter(|(_, dist)| dist.is_finite())
            .collect();

        // Use partial sort for better performance when k << n
        if k < indexed_distances.len() {
            indexed_distances.select_nth_unstable_by(k, |a, b| a.1.partial_cmp(&b.1).unwrap());
            indexed_distances.truncate(k);
        }

        indexed_distances.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        indexed_distances
    }

    /// Memory-efficient mean imputation for large datasets: column means are
    /// accumulated and then applied in place, processing `chunk_size` rows at
    /// a time
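    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; uses `scirs2_core::ndarray::Array2`
    /// as imported above):
    ///
    /// ```ignore
    /// let mut data = Array2::from_shape_vec((3, 1), vec![1.0, f64::NAN, 3.0]).unwrap();
    /// SimdImputationOps::streaming_mean_imputation(&mut data, 2, f64::NAN);
    /// // The NaN is replaced by the column mean (1 + 3) / 2 = 2
    /// assert_eq!(data[[1, 0]], 2.0);
    /// ```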
    pub fn streaming_mean_imputation(data: &mut Array2<f64>, chunk_size: usize, missing_val: f64) {
        let (n_rows, n_cols) = data.dim();

        // Calculate column means in chunks to save memory
        let mut column_means = vec![0.0; n_cols];
        let mut column_counts = vec![0; n_cols];

        for row_chunk in (0..n_rows).step_by(chunk_size) {
            let end_row = (row_chunk + chunk_size).min(n_rows);

            for i in row_chunk..end_row {
                for j in 0..n_cols {
                    let val = data[[i, j]];
                    let is_missing = if missing_val.is_nan() {
                        val.is_nan()
                    } else {
                        (val - missing_val).abs() < f64::EPSILON
                    };

                    if !is_missing {
                        column_means[j] += val;
                        column_counts[j] += 1;
                    }
                }
            }
        }

        // Finalize means
        for j in 0..n_cols {
            if column_counts[j] > 0 {
                column_means[j] /= column_counts[j] as f64;
            }
        }

        // Apply imputation in chunks
        for row_chunk in (0..n_rows).step_by(chunk_size) {
            let end_row = (row_chunk + chunk_size).min(n_rows);

            for i in row_chunk..end_row {
                for j in 0..n_cols {
                    let val = data[[i, j]];
                    let is_missing = if missing_val.is_nan() {
                        val.is_nan()
                    } else {
                        (val - missing_val).abs() < f64::EPSILON
                    };

                    if is_missing && column_counts[j] > 0 {
                        data[[i, j]] = column_means[j];
                    }
                }
            }
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_euclidean_distance() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let distance = SimdDistanceCalculator::euclidean_distance_simd(&x, &y);
        let expected = 3.0; // sqrt(9 * 1^2) = 3.0

        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mean = SimdStatistics::mean_simd(&data);
        let expected = 5.5;

        assert_abs_diff_eq!(mean, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_dot_product() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let dot_product = SimdMatrixOps::dot_product_simd(&x, &y);
        let expected = 240.0; // 1*2 + 2*3 + ... + 8*9

        assert_abs_diff_eq!(dot_product, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_mean() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let weights = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

        let weighted_mean = SimdImputationOps::weighted_mean_simd(&values, &weights);
        let expected = 4.5; // Simple mean when all weights are equal

        assert_abs_diff_eq!(weighted_mean, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_cache_optimized_data() {
        let data = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0,
                2.0,
                f64::NAN,
                4.0,
                5.0,
                f64::NAN,
                7.0,
                8.0,
                9.0,
                10.0,
                11.0,
                f64::NAN,
            ],
        )
        .unwrap();

        let optimized = CacheOptimizedData::new(&data, f64::NAN);

        // Test value access
        assert_eq!(optimized.get(0, 0), Some(1.0));
        assert_eq!(optimized.get(0, 1), Some(2.0));
        assert_eq!(optimized.get(1, 0), Some(5.0));

        // Test missing value detection
        assert!(optimized.is_missing(0, 2));
        assert!(optimized.is_missing(1, 1));
        assert!(optimized.is_missing(2, 3));
        assert!(!optimized.is_missing(0, 0));
        assert!(!optimized.is_missing(1, 0));
    }

    #[test]
    fn test_simd_distance_calculations() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];

        // Test Manhattan distance
        let manhattan = SimdDistanceCalculator::manhattan_distance_simd(&x, &y);
        assert_abs_diff_eq!(manhattan, 10.0, epsilon = 1e-10);

        // Test cosine similarity
        let cosine_sim = SimdDistanceCalculator::cosine_similarity_simd(&x, &y);
        assert!(cosine_sim > 0.9); // Should be very similar vectors
    }

    #[test]
    fn test_simd_statistics() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        // Test variance
        let variance = SimdStatistics::variance_simd(&data, None);
        let expected_variance = 9.166666666666666; // Variance of 1..10
        assert_abs_diff_eq!(variance, expected_variance, epsilon = 1e-10);

        // Test min/max
        let (min_val, max_val) = SimdStatistics::min_max_simd(&data);
        assert_eq!(min_val, 1.0);
        assert_eq!(max_val, 10.0);

        // Test quantile
        let median = SimdStatistics::quantile_simd(&data, 0.5);
        assert_abs_diff_eq!(median, 5.5, epsilon = 1e-10);
    }

    #[test]
    fn test_matrix_operations() {
        let matrix =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();

        // Test transpose
        let transposed = SimdMatrixOps::transpose_simd(&matrix);
        assert_eq!(transposed[[0, 0]], 1.0);
        assert_eq!(transposed[[0, 1]], 4.0);
        assert_eq!(transposed[[0, 2]], 7.0);
        assert_eq!(transposed[[1, 0]], 2.0);

        // Test matrix-vector multiplication
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = SimdMatrixOps::matrix_vector_multiply_simd(&matrix, &vector).unwrap();

        // Expected: [1*1 + 2*2 + 3*3, 4*1 + 5*2 + 6*3, 7*1 + 8*2 + 9*3] = [14, 32, 50]
        assert_abs_diff_eq!(result[0], 14.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 32.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[2], 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_batch_distances() {
        let query = vec![1.0, 2.0, 3.0];
        let data = Array2::from_shape_vec(
            (3, 3),
            vec![
                1.0, 2.0, 3.0, // Distance 0
                2.0, 3.0, 4.0, // Distance sqrt(3)
                4.0, 5.0, 6.0, // Distance sqrt(27)
            ],
        )
        .unwrap();

        let distances = SimdImputationOps::batch_distances_simd(&query, &data, "euclidean");

        assert_eq!(distances.len(), 3);
        assert_abs_diff_eq!(distances[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(distances[1], 3.0_f64.sqrt(), epsilon = 1e-10);
        assert_abs_diff_eq!(distances[2], 27.0_f64.sqrt(), epsilon = 1e-10);
    }

    #[test]
    fn test_knn_finding() {
        let query = vec![1.0, 2.0, 3.0];
        let data = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, // Distance 0 (closest)
                2.0, 3.0, 4.0, // Distance sqrt(3)
                4.0, 5.0, 6.0, // Distance sqrt(27)
                0.5, 1.5, 2.5, // Distance sqrt(0.75)
                10.0, 11.0, 12.0, // Distance sqrt(243) (farthest)
            ],
        )
        .unwrap();

        let knn = SimdImputationOps::find_knn_simd(&query, &data, 3, "euclidean");

        assert_eq!(knn.len(), 3);
        assert_eq!(knn[0].0, 0); // Closest is the identical point
        assert_eq!(knn[1].0, 3); // Second closest
        assert_eq!(knn[2].0, 1); // Third closest
    }

    #[test]
    fn test_missing_count() {
        let data = vec![
            1.0,
            f64::NAN,
            3.0,
            f64::NAN,
            5.0,
            6.0,
            f64::NAN,
            8.0,
            9.0,
            f64::NAN,
        ];
        let count = SimdImputationOps::count_missing_simd(&data);
        assert_eq!(count, 4);
    }

    #[test]
    fn test_streaming_imputation() {
        let mut data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0,
                f64::NAN,
                3.0,
                4.0,
                5.0,
                f64::NAN,
                f64::NAN,
                8.0,
                9.0,
                10.0,
                11.0,
                12.0,
            ],
        )
        .unwrap();

        SimdImputationOps::streaming_mean_imputation(&mut data, 2, f64::NAN);

        // Check that missing values were replaced with column means
        // Column 0 mean: (1.0 + 4.0 + 10.0) / 3 = 5.0
        // Column 1 mean: (5.0 + 8.0 + 11.0) / 3 = 8.0
        // Column 2 mean: (3.0 + 9.0 + 12.0) / 3 = 8.0

        assert_abs_diff_eq!(data[[0, 1]], 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[1, 2]], 8.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[2, 0]], 5.0, epsilon = 1e-10);

        // Non-missing values should remain unchanged
        assert_abs_diff_eq!(data[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[1, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(data[[3, 2]], 12.0, epsilon = 1e-10);
    }
}