sklears_preprocessing/simd_optimizations.rs

//! SIMD optimizations for preprocessing operations
//!
//! This module provides SIMD-accelerated implementations of common preprocessing
//! operations like element-wise arithmetic, statistical calculations, and data
//! transformations that are frequently used in scaling, normalization, and other
//! preprocessing tasks.
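//!
//! # Example
//!
//! A minimal usage sketch. The crate path is assumed from the file location and
//! the snippet is marked `ignore`, so treat it as illustrative rather than a
//! verified doctest.
//!
//! ```ignore
//! use sklears_preprocessing::simd_optimizations::{add_scalar_f64_simd, mean_f64_simd, SimdConfig};
//!
//! let config = SimdConfig::default();
//! let mut data = vec![1.0_f64; 1024];
//!
//! // Shift every element by 2.5, then take the mean of the shifted data.
//! add_scalar_f64_simd(&mut data, 2.5, &config);
//! let mean = mean_f64_simd(&data, &config);
//! assert!((mean - 3.5).abs() < 1e-12);
//! ```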

use scirs2_core::ndarray::{Array1, Array2, Axis};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Configuration for SIMD optimizations
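///
/// # Example
///
/// A small sketch of overriding the defaults (marked `ignore`; the crate path
/// is assumed):
///
/// ```ignore
/// use sklears_preprocessing::simd_optimizations::SimdConfig;
///
/// // Require at least 128 elements before using SIMD and keep the
/// // parallel path for large arrays disabled.
/// let config = SimdConfig {
///     min_size_threshold: 128,
///     use_parallel: false,
///     ..SimdConfig::default()
/// };
/// assert!(config.enabled);
/// ```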
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SimdConfig {
    /// Whether to use SIMD optimizations
    pub enabled: bool,
    /// Minimum array size to use SIMD (avoids overhead for small arrays)
    pub min_size_threshold: usize,
    /// Force specific SIMD width (None for auto-detection)
    pub force_width: Option<usize>,
    /// Whether to use parallel SIMD for large arrays
    pub use_parallel: bool,
}

impl Default for SimdConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            min_size_threshold: 32,
            force_width: None,
            use_parallel: true,
        }
    }
}

/// SIMD-optimized element-wise addition of a scalar to a vector
///
/// Dispatches to AVX2 or SSE2 on x86_64 and to NEON on aarch64, falling back to
/// a plain scalar loop when SIMD is disabled or the slice is below the size threshold.
pub fn add_scalar_f64_simd(data: &mut [f64], scalar: f64, config: &SimdConfig) {
    if !config.enabled || data.len() < config.min_size_threshold {
        add_scalar_f64_scalar(data, scalar);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { add_scalar_f64_avx2(data, scalar) };
            return;
        } else if is_x86_feature_detected!("sse2") {
            unsafe { add_scalar_f64_sse2(data, scalar) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe {
        add_scalar_f64_neon(data, scalar)
    };

    #[cfg(not(target_arch = "aarch64"))]
    add_scalar_f64_scalar(data, scalar);
}

/// SIMD-optimized element-wise subtraction of a scalar from a vector
pub fn sub_scalar_f64_simd(data: &mut [f64], scalar: f64, config: &SimdConfig) {
    add_scalar_f64_simd(data, -scalar, config);
}

/// SIMD-optimized element-wise multiplication of a vector by a scalar
pub fn mul_scalar_f64_simd(data: &mut [f64], scalar: f64, config: &SimdConfig) {
    if !config.enabled || data.len() < config.min_size_threshold {
        mul_scalar_f64_scalar(data, scalar);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { mul_scalar_f64_avx2(data, scalar) };
            return;
        } else if is_x86_feature_detected!("sse2") {
            unsafe { mul_scalar_f64_sse2(data, scalar) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe {
        mul_scalar_f64_neon(data, scalar)
    };

    #[cfg(not(target_arch = "aarch64"))]
    mul_scalar_f64_scalar(data, scalar);
}

/// SIMD-optimized element-wise division of a vector by a scalar
///
/// A zero divisor is treated as a no-op: the data is left unchanged.
pub fn div_scalar_f64_simd(data: &mut [f64], scalar: f64, config: &SimdConfig) {
    if scalar != 0.0 {
        mul_scalar_f64_simd(data, 1.0 / scalar, config);
    }
}

/// SIMD-optimized vector addition
///
/// Panics if `a`, `b`, and `result` do not all have the same length.
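///
/// # Example
///
/// A minimal sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use sklears_preprocessing::simd_optimizations::{add_vectors_f64_simd, SimdConfig};
///
/// let config = SimdConfig::default();
/// let a = vec![1.0, 2.0, 3.0, 4.0];
/// let b = vec![10.0, 20.0, 30.0, 40.0];
/// let mut sum = vec![0.0; 4];
/// add_vectors_f64_simd(&a, &b, &mut sum, &config);
/// assert_eq!(sum, vec![11.0, 22.0, 33.0, 44.0]);
/// ```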
pub fn add_vectors_f64_simd(a: &[f64], b: &[f64], result: &mut [f64], config: &SimdConfig) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), result.len());

    if !config.enabled || a.len() < config.min_size_threshold {
        add_vectors_f64_scalar(a, b, result);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { add_vectors_f64_avx2(a, b, result) };
            return;
        } else if is_x86_feature_detected!("sse2") {
            unsafe { add_vectors_f64_sse2(a, b, result) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe {
        add_vectors_f64_neon(a, b, result)
    };

    #[cfg(not(target_arch = "aarch64"))]
    add_vectors_f64_scalar(a, b, result);
}

/// SIMD-optimized vector subtraction
///
/// Panics if `a`, `b`, and `result` do not all have the same length.
pub fn sub_vectors_f64_simd(a: &[f64], b: &[f64], result: &mut [f64], config: &SimdConfig) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), result.len());

    if !config.enabled || a.len() < config.min_size_threshold {
        sub_vectors_f64_scalar(a, b, result);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { sub_vectors_f64_avx2(a, b, result) };
            return;
        } else if is_x86_feature_detected!("sse2") {
            unsafe { sub_vectors_f64_sse2(a, b, result) };
            return;
        }
    }

    #[cfg(target_arch = "aarch64")]
    unsafe {
        sub_vectors_f64_neon(a, b, result)
    };

    #[cfg(not(target_arch = "aarch64"))]
    sub_vectors_f64_scalar(a, b, result);
}

/// SIMD-optimized mean calculation
///
/// Returns 0.0 for an empty slice.
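///
/// # Example
///
/// A minimal sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use sklears_preprocessing::simd_optimizations::{mean_f64_simd, SimdConfig};
///
/// let config = SimdConfig::default();
/// let data = vec![2.0, 4.0, 6.0, 8.0];
/// assert!((mean_f64_simd(&data, &config) - 5.0).abs() < 1e-12);
/// ```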
pub fn mean_f64_simd(data: &[f64], config: &SimdConfig) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    if !config.enabled || data.len() < config.min_size_threshold {
        return mean_f64_scalar(data);
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { mean_f64_avx2(data) };
        } else if is_x86_feature_detected!("sse2") {
            return unsafe { mean_f64_sse2(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    return unsafe { mean_f64_neon(data) };

    #[cfg(not(target_arch = "aarch64"))]
    mean_f64_scalar(data)
}

/// SIMD-optimized variance calculation
///
/// Computes the sample variance (dividing by `n - 1`) around the supplied `mean`;
/// returns 0.0 when the slice has fewer than two elements.
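///
/// # Example
///
/// A minimal sketch of the usual two-pass mean/variance computation (marked
/// `ignore`; the crate path is assumed):
///
/// ```ignore
/// use sklears_preprocessing::simd_optimizations::{mean_f64_simd, variance_f64_simd, SimdConfig};
///
/// let config = SimdConfig::default();
/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let mean = mean_f64_simd(&data, &config);
/// let variance = variance_f64_simd(&data, mean, &config);
/// assert!((variance - 2.5).abs() < 1e-12);
/// ```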
pub fn variance_f64_simd(data: &[f64], mean: f64, config: &SimdConfig) -> f64 {
    if data.len() <= 1 {
        return 0.0;
    }

    if !config.enabled || data.len() < config.min_size_threshold {
        return variance_f64_scalar(data, mean);
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { variance_f64_avx2(data, mean) };
        } else if is_x86_feature_detected!("sse2") {
            return unsafe { variance_f64_sse2(data, mean) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    return unsafe { variance_f64_neon(data, mean) };

    #[cfg(not(target_arch = "aarch64"))]
    variance_f64_scalar(data, mean)
}

/// SIMD-optimized min/max finding
///
/// Returns `(f64::NAN, f64::NAN)` for an empty slice.
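///
/// # Example
///
/// A minimal sketch (marked `ignore`; the crate path is assumed):
///
/// ```ignore
/// use sklears_preprocessing::simd_optimizations::{min_max_f64_simd, SimdConfig};
///
/// let config = SimdConfig::default();
/// let data = vec![3.0, -1.0, 4.0, 1.5];
/// assert_eq!(min_max_f64_simd(&data, &config), (-1.0, 4.0));
/// ```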
pub fn min_max_f64_simd(data: &[f64], config: &SimdConfig) -> (f64, f64) {
    if data.is_empty() {
        return (f64::NAN, f64::NAN);
    }

    if !config.enabled || data.len() < config.min_size_threshold {
        return min_max_f64_scalar(data);
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { min_max_f64_avx2(data) };
        } else if is_x86_feature_detected!("sse2") {
            return unsafe { min_max_f64_sse2(data) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    return unsafe { min_max_f64_neon(data) };

    #[cfg(not(target_arch = "aarch64"))]
    min_max_f64_scalar(data)
}

// Scalar fallback implementations

fn add_scalar_f64_scalar(data: &mut [f64], scalar: f64) {
    for x in data.iter_mut() {
        *x += scalar;
    }
}

fn mul_scalar_f64_scalar(data: &mut [f64], scalar: f64) {
    for x in data.iter_mut() {
        *x *= scalar;
    }
}

fn add_vectors_f64_scalar(a: &[f64], b: &[f64], result: &mut [f64]) {
    for ((x, y), r) in a.iter().zip(b.iter()).zip(result.iter_mut()) {
        *r = x + y;
    }
}

fn sub_vectors_f64_scalar(a: &[f64], b: &[f64], result: &mut [f64]) {
    for ((x, y), r) in a.iter().zip(b.iter()).zip(result.iter_mut()) {
        *r = x - y;
    }
}

fn mean_f64_scalar(data: &[f64]) -> f64 {
    data.iter().sum::<f64>() / data.len() as f64
}

fn variance_f64_scalar(data: &[f64], mean: f64) -> f64 {
    data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (data.len() - 1) as f64
}

fn min_max_f64_scalar(data: &[f64]) -> (f64, f64) {
    let mut min = data[0];
    let mut max = data[0];

    for &x in &data[1..] {
        if x < min {
            min = x;
        }
        if x > max {
            max = x;
        }
    }

    (min, max)
}

// x86_64 SSE2 implementations

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn add_scalar_f64_sse2(data: &mut [f64], scalar: f64) {
    use std::arch::x86_64::*;

    let scalar_vec = _mm_set1_pd(scalar);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = _mm_loadu_pd(data.as_ptr().add(i));
        let result = _mm_add_pd(data_vec, scalar_vec);
        _mm_storeu_pd(data.as_mut_ptr().add(i), result);
        i += 2;
    }

    // Handle remaining elements
    while i < data.len() {
        data[i] += scalar;
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn mul_scalar_f64_sse2(data: &mut [f64], scalar: f64) {
    use std::arch::x86_64::*;

    let scalar_vec = _mm_set1_pd(scalar);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = _mm_loadu_pd(data.as_ptr().add(i));
        let result = _mm_mul_pd(data_vec, scalar_vec);
        _mm_storeu_pd(data.as_mut_ptr().add(i), result);
        i += 2;
    }

    while i < data.len() {
        data[i] *= scalar;
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn add_vectors_f64_sse2(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    let mut i = 0;

    while i + 2 <= a.len() {
        let a_vec = _mm_loadu_pd(a.as_ptr().add(i));
        let b_vec = _mm_loadu_pd(b.as_ptr().add(i));
        let result_vec = _mm_add_pd(a_vec, b_vec);
        _mm_storeu_pd(result.as_mut_ptr().add(i), result_vec);
        i += 2;
    }

    while i < a.len() {
        result[i] = a[i] + b[i];
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn sub_vectors_f64_sse2(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    let mut i = 0;

    while i + 2 <= a.len() {
        let a_vec = _mm_loadu_pd(a.as_ptr().add(i));
        let b_vec = _mm_loadu_pd(b.as_ptr().add(i));
        let result_vec = _mm_sub_pd(a_vec, b_vec);
        _mm_storeu_pd(result.as_mut_ptr().add(i), result_vec);
        i += 2;
    }

    while i < a.len() {
        result[i] = a[i] - b[i];
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn mean_f64_sse2(data: &[f64]) -> f64 {
    use std::arch::x86_64::*;

    let mut sum = _mm_setzero_pd();
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = _mm_loadu_pd(data.as_ptr().add(i));
        sum = _mm_add_pd(sum, data_vec);
        i += 2;
    }

    let mut result = [0.0f64; 2];
    _mm_storeu_pd(result.as_mut_ptr(), sum);
    let mut scalar_sum = result[0] + result[1];

    while i < data.len() {
        scalar_sum += data[i];
        i += 1;
    }

    scalar_sum / data.len() as f64
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn variance_f64_sse2(data: &[f64], mean: f64) -> f64 {
    use std::arch::x86_64::*;

    let mean_vec = _mm_set1_pd(mean);
    let mut sum = _mm_setzero_pd();
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = _mm_loadu_pd(data.as_ptr().add(i));
        let diff = _mm_sub_pd(data_vec, mean_vec);
        let squared = _mm_mul_pd(diff, diff);
        sum = _mm_add_pd(sum, squared);
        i += 2;
    }

    let mut result = [0.0f64; 2];
    _mm_storeu_pd(result.as_mut_ptr(), sum);
    let mut scalar_sum = result[0] + result[1];

    while i < data.len() {
        let diff = data[i] - mean;
        scalar_sum += diff * diff;
        i += 1;
    }

    scalar_sum / (data.len() - 1) as f64
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn min_max_f64_sse2(data: &[f64]) -> (f64, f64) {
    use std::arch::x86_64::*;

    let mut min_vec = _mm_set1_pd(data[0]);
    let mut max_vec = _mm_set1_pd(data[0]);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = _mm_loadu_pd(data.as_ptr().add(i));
        min_vec = _mm_min_pd(min_vec, data_vec);
        max_vec = _mm_max_pd(max_vec, data_vec);
        i += 2;
    }

    let mut min_result = [0.0f64; 2];
    let mut max_result = [0.0f64; 2];
    _mm_storeu_pd(min_result.as_mut_ptr(), min_vec);
    _mm_storeu_pd(max_result.as_mut_ptr(), max_vec);

    let mut min_val = min_result[0].min(min_result[1]);
    let mut max_val = max_result[0].max(max_result[1]);

    while i < data.len() {
        if data[i] < min_val {
            min_val = data[i];
        }
        if data[i] > max_val {
            max_val = data[i];
        }
        i += 1;
    }

    (min_val, max_val)
}

// x86_64 AVX2 implementations

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_scalar_f64_avx2(data: &mut [f64], scalar: f64) {
    use std::arch::x86_64::*;

    let scalar_vec = _mm256_set1_pd(scalar);
    let mut i = 0;

    while i + 4 <= data.len() {
        let data_vec = _mm256_loadu_pd(data.as_ptr().add(i));
        let result = _mm256_add_pd(data_vec, scalar_vec);
        _mm256_storeu_pd(data.as_mut_ptr().add(i), result);
        i += 4;
    }

    while i < data.len() {
        data[i] += scalar;
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mul_scalar_f64_avx2(data: &mut [f64], scalar: f64) {
    use std::arch::x86_64::*;

    let scalar_vec = _mm256_set1_pd(scalar);
    let mut i = 0;

    while i + 4 <= data.len() {
        let data_vec = _mm256_loadu_pd(data.as_ptr().add(i));
        let result = _mm256_mul_pd(data_vec, scalar_vec);
        _mm256_storeu_pd(data.as_mut_ptr().add(i), result);
        i += 4;
    }

    while i < data.len() {
        data[i] *= scalar;
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn add_vectors_f64_avx2(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    let mut i = 0;

    while i + 4 <= a.len() {
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(i));
        let result_vec = _mm256_add_pd(a_vec, b_vec);
        _mm256_storeu_pd(result.as_mut_ptr().add(i), result_vec);
        i += 4;
    }

    while i < a.len() {
        result[i] = a[i] + b[i];
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn sub_vectors_f64_avx2(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    let mut i = 0;

    while i + 4 <= a.len() {
        let a_vec = _mm256_loadu_pd(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_pd(b.as_ptr().add(i));
        let result_vec = _mm256_sub_pd(a_vec, b_vec);
        _mm256_storeu_pd(result.as_mut_ptr().add(i), result_vec);
        i += 4;
    }

    while i < a.len() {
        result[i] = a[i] - b[i];
        i += 1;
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mean_f64_avx2(data: &[f64]) -> f64 {
    use std::arch::x86_64::*;

    let mut sum = _mm256_setzero_pd();
    let mut i = 0;

    while i + 4 <= data.len() {
        let data_vec = _mm256_loadu_pd(data.as_ptr().add(i));
        sum = _mm256_add_pd(sum, data_vec);
        i += 4;
    }

    let mut result = [0.0f64; 4];
    _mm256_storeu_pd(result.as_mut_ptr(), sum);
    let mut scalar_sum = result.iter().sum::<f64>();

    while i < data.len() {
        scalar_sum += data[i];
        i += 1;
    }

    scalar_sum / data.len() as f64
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn variance_f64_avx2(data: &[f64], mean: f64) -> f64 {
    use std::arch::x86_64::*;

    let mean_vec = _mm256_set1_pd(mean);
    let mut sum = _mm256_setzero_pd();
    let mut i = 0;

    while i + 4 <= data.len() {
        let data_vec = _mm256_loadu_pd(data.as_ptr().add(i));
        let diff = _mm256_sub_pd(data_vec, mean_vec);
        let squared = _mm256_mul_pd(diff, diff);
        sum = _mm256_add_pd(sum, squared);
        i += 4;
    }

    let mut result = [0.0f64; 4];
    _mm256_storeu_pd(result.as_mut_ptr(), sum);
    let mut scalar_sum = result.iter().sum::<f64>();

    while i < data.len() {
        let diff = data[i] - mean;
        scalar_sum += diff * diff;
        i += 1;
    }

    scalar_sum / (data.len() - 1) as f64
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn min_max_f64_avx2(data: &[f64]) -> (f64, f64) {
    use std::arch::x86_64::*;

    let mut min_vec = _mm256_set1_pd(data[0]);
    let mut max_vec = _mm256_set1_pd(data[0]);
    let mut i = 0;

    while i + 4 <= data.len() {
        let data_vec = _mm256_loadu_pd(data.as_ptr().add(i));
        min_vec = _mm256_min_pd(min_vec, data_vec);
        max_vec = _mm256_max_pd(max_vec, data_vec);
        i += 4;
    }

    let mut min_result = [0.0f64; 4];
    let mut max_result = [0.0f64; 4];
    _mm256_storeu_pd(min_result.as_mut_ptr(), min_vec);
    _mm256_storeu_pd(max_result.as_mut_ptr(), max_vec);

    let mut min_val = min_result.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let mut max_val = max_result.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    while i < data.len() {
        if data[i] < min_val {
            min_val = data[i];
        }
        if data[i] > max_val {
            max_val = data[i];
        }
        i += 1;
    }

    (min_val, max_val)
}

// ARM NEON implementations

#[cfg(target_arch = "aarch64")]
unsafe fn add_scalar_f64_neon(data: &mut [f64], scalar: f64) {
    use std::arch::aarch64::*;

    let scalar_vec = vdupq_n_f64(scalar);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = vld1q_f64(data.as_ptr().add(i));
        let result = vaddq_f64(data_vec, scalar_vec);
        vst1q_f64(data.as_mut_ptr().add(i), result);
        i += 2;
    }

    while i < data.len() {
        data[i] += scalar;
        i += 1;
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn mul_scalar_f64_neon(data: &mut [f64], scalar: f64) {
    use std::arch::aarch64::*;

    let scalar_vec = vdupq_n_f64(scalar);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = vld1q_f64(data.as_ptr().add(i));
        let result = vmulq_f64(data_vec, scalar_vec);
        vst1q_f64(data.as_mut_ptr().add(i), result);
        i += 2;
    }

    while i < data.len() {
        data[i] *= scalar;
        i += 1;
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn add_vectors_f64_neon(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::aarch64::*;

    let mut i = 0;

    while i + 2 <= a.len() {
        let a_vec = vld1q_f64(a.as_ptr().add(i));
        let b_vec = vld1q_f64(b.as_ptr().add(i));
        let result_vec = vaddq_f64(a_vec, b_vec);
        vst1q_f64(result.as_mut_ptr().add(i), result_vec);
        i += 2;
    }

    while i < a.len() {
        result[i] = a[i] + b[i];
        i += 1;
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn sub_vectors_f64_neon(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::aarch64::*;

    let mut i = 0;

    while i + 2 <= a.len() {
        let a_vec = vld1q_f64(a.as_ptr().add(i));
        let b_vec = vld1q_f64(b.as_ptr().add(i));
        let result_vec = vsubq_f64(a_vec, b_vec);
        vst1q_f64(result.as_mut_ptr().add(i), result_vec);
        i += 2;
    }

    while i < a.len() {
        result[i] = a[i] - b[i];
        i += 1;
    }
}

#[cfg(target_arch = "aarch64")]
unsafe fn mean_f64_neon(data: &[f64]) -> f64 {
    use std::arch::aarch64::*;

    let mut sum = vdupq_n_f64(0.0);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = vld1q_f64(data.as_ptr().add(i));
        sum = vaddq_f64(sum, data_vec);
        i += 2;
    }

    let mut scalar_sum = vaddvq_f64(sum);

    while i < data.len() {
        scalar_sum += data[i];
        i += 1;
    }

    scalar_sum / data.len() as f64
}

#[cfg(target_arch = "aarch64")]
unsafe fn variance_f64_neon(data: &[f64], mean: f64) -> f64 {
    use std::arch::aarch64::*;

    let mean_vec = vdupq_n_f64(mean);
    let mut sum = vdupq_n_f64(0.0);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = vld1q_f64(data.as_ptr().add(i));
        let diff = vsubq_f64(data_vec, mean_vec);
        let squared = vmulq_f64(diff, diff);
        sum = vaddq_f64(sum, squared);
        i += 2;
    }

    let mut scalar_sum = vaddvq_f64(sum);

    while i < data.len() {
        let diff = data[i] - mean;
        scalar_sum += diff * diff;
        i += 1;
    }

    scalar_sum / (data.len() - 1) as f64
}

#[cfg(target_arch = "aarch64")]
unsafe fn min_max_f64_neon(data: &[f64]) -> (f64, f64) {
    use std::arch::aarch64::*;

    let mut min_vec = vdupq_n_f64(data[0]);
    let mut max_vec = vdupq_n_f64(data[0]);
    let mut i = 0;

    while i + 2 <= data.len() {
        let data_vec = vld1q_f64(data.as_ptr().add(i));
        min_vec = vminq_f64(min_vec, data_vec);
        max_vec = vmaxq_f64(max_vec, data_vec);
        i += 2;
    }

    let min_val = vminvq_f64(min_vec);
    let max_val = vmaxvq_f64(max_vec);

    let mut final_min = min_val;
    let mut final_max = max_val;

    while i < data.len() {
        if data[i] < final_min {
            final_min = data[i];
        }
        if data[i] > final_max {
            final_max = data[i];
        }
        i += 1;
    }

    (final_min, final_max)
}

/// High-level SIMD-accelerated operations for ndarray integration
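///
/// # Example
///
/// A sketch of the column-wise helpers (marked `ignore`; the crate path and the
/// `scirs2_core` ndarray re-export are assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use sklears_preprocessing::simd_optimizations::{ndarray_ops, SimdConfig};
///
/// let config = SimdConfig::default();
/// let x = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let means = ndarray_ops::column_means(&x, &config);
/// let variances = ndarray_ops::column_variances(&x, &means, &config);
/// assert!((means[0] - 2.0).abs() < 1e-12);
/// assert!((variances[0] - 2.0).abs() < 1e-12);
/// ```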
pub mod ndarray_ops {
    use super::*;

    /// SIMD-optimized element-wise array addition with scalar
    pub fn add_scalar_array(array: &mut Array2<f64>, scalar: f64, config: &SimdConfig) {
        if config.use_parallel && array.len() > 1000 {
            #[cfg(feature = "parallel")]
            {
                use rayon::prelude::*;
                array
                    .axis_iter_mut(Axis(0))
                    .into_par_iter()
                    .for_each(|mut row| {
                        if let Some(slice) = row.as_slice_mut() {
                            add_scalar_f64_simd(slice, scalar, config);
                        } else {
                            // Non-contiguous case - fallback to element-wise
                            for elem in row.iter_mut() {
                                *elem += scalar;
                            }
                        }
                    });
                return;
            }
        }

        for mut row in array.axis_iter_mut(Axis(0)) {
            if let Some(slice) = row.as_slice_mut() {
                add_scalar_f64_simd(slice, scalar, config);
            } else {
                // Non-contiguous case - fallback to element-wise
                for elem in row.iter_mut() {
                    *elem += scalar;
                }
            }
        }
    }

    /// SIMD-optimized element-wise array multiplication with scalar
    pub fn mul_scalar_array(array: &mut Array2<f64>, scalar: f64, config: &SimdConfig) {
        if config.use_parallel && array.len() > 1000 {
            #[cfg(feature = "parallel")]
            {
                use rayon::prelude::*;
                array
                    .axis_iter_mut(Axis(0))
                    .into_par_iter()
                    .for_each(|mut row| {
                        if let Some(slice) = row.as_slice_mut() {
                            mul_scalar_f64_simd(slice, scalar, config);
                        } else {
                            // Non-contiguous case - fallback to element-wise
                            for elem in row.iter_mut() {
                                *elem *= scalar;
                            }
                        }
                    });
                return;
            }
        }

        for mut row in array.axis_iter_mut(Axis(0)) {
            if let Some(slice) = row.as_slice_mut() {
                mul_scalar_f64_simd(slice, scalar, config);
            } else {
                for elem in row.iter_mut() {
                    *elem *= scalar;
                }
            }
        }
    }

    /// SIMD-optimized column-wise mean calculation
    pub fn column_means(array: &Array2<f64>, config: &SimdConfig) -> Array1<f64> {
        let mut means = Array1::zeros(array.ncols());

        for (j, mean_col) in means.iter_mut().enumerate() {
            let column = array.column(j);
            if let Some(slice) = column.as_slice() {
                *mean_col = mean_f64_simd(slice, config);
            } else {
                *mean_col = column.iter().sum::<f64>() / array.nrows() as f64;
            }
        }

        means
    }

    /// SIMD-optimized column-wise variance calculation
    pub fn column_variances(
        array: &Array2<f64>,
        means: &Array1<f64>,
        config: &SimdConfig,
    ) -> Array1<f64> {
        let mut variances = Array1::zeros(array.ncols());

        for (j, var_col) in variances.iter_mut().enumerate() {
            let column = array.column(j);
            if let Some(slice) = column.as_slice() {
                *var_col = variance_f64_simd(slice, means[j], config);
            } else {
                let sum_sq_diff: f64 = column.iter().map(|x| (x - means[j]).powi(2)).sum();
                *var_col = sum_sq_diff / (array.nrows() - 1) as f64;
            }
        }

        variances
    }

    /// SIMD-optimized column-wise min/max calculation
    pub fn column_min_max(array: &Array2<f64>, config: &SimdConfig) -> (Array1<f64>, Array1<f64>) {
        let mut mins = Array1::zeros(array.ncols());
        let mut maxs = Array1::zeros(array.ncols());

        for j in 0..array.ncols() {
            let column = array.column(j);
            if let Some(slice) = column.as_slice() {
                let (min_val, max_val) = min_max_f64_simd(slice, config);
                mins[j] = min_val;
                maxs[j] = max_val;
            } else {
                let mut min_val = column[0];
                let mut max_val = column[0];
                for &val in column.iter().skip(1) {
                    if val < min_val {
                        min_val = val;
                    }
                    if val > max_val {
                        max_val = val;
                    }
                }
                mins[j] = min_val;
                maxs[j] = max_val;
            }
        }

        (mins, maxs)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_config() {
        let config = SimdConfig::default();
        assert!(config.enabled);
        assert_eq!(config.min_size_threshold, 32);
        assert!(config.use_parallel);
    }

    #[test]
    fn test_add_scalar_simd() {
        let config = SimdConfig::default();
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let original = data.clone();

        add_scalar_f64_simd(&mut data, 10.0, &config);

        for (i, &val) in data.iter().enumerate() {
            assert_relative_eq!(val, original[i] + 10.0, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_mul_scalar_simd() {
        let config = SimdConfig::default();
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let original = data.clone();

        mul_scalar_f64_simd(&mut data, 2.5, &config);

        for (i, &val) in data.iter().enumerate() {
            assert_relative_eq!(val, original[i] * 2.5, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_vector_operations_simd() {
        let config = SimdConfig::default();
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut result = vec![0.0; 8];

        add_vectors_f64_simd(&a, &b, &mut result, &config);

        for (i, &val) in result.iter().enumerate() {
            assert_relative_eq!(val, a[i] + b[i], epsilon = 1e-14);
        }
    }

    #[test]
    fn test_mean_simd() {
        let config = SimdConfig::default();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let mean = mean_f64_simd(&data, &config);
        let expected = 5.5;

        assert_relative_eq!(mean, expected, epsilon = 1e-14);
    }

    #[test]
    fn test_variance_simd() {
        let config = SimdConfig::default();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = 3.0;

        let variance = variance_f64_simd(&data, mean, &config);
        let expected = 2.5; // Sample variance

        assert_relative_eq!(variance, expected, epsilon = 1e-14);
    }

    #[test]
    fn test_min_max_simd() {
        let config = SimdConfig::default();
        let data = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];

        let (min_val, max_val) = min_max_f64_simd(&data, &config);

        assert_relative_eq!(min_val, 1.0, epsilon = 1e-14);
        assert_relative_eq!(max_val, 9.0, epsilon = 1e-14);
    }

    #[test]
    fn test_ndarray_operations() {
        let config = SimdConfig::default();
        let mut array = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        // Test scalar addition
        let original = array.clone();
        ndarray_ops::add_scalar_array(&mut array, 5.0, &config);

        for (&new_val, &old_val) in array.iter().zip(original.iter()) {
            assert_relative_eq!(new_val, old_val + 5.0, epsilon = 1e-14);
        }

        // Test column means
        let means = ndarray_ops::column_means(&original, &config);
        assert_relative_eq!(means[0], 5.5, epsilon = 1e-14); // (1+4+7+10)/4
        assert_relative_eq!(means[1], 6.5, epsilon = 1e-14); // (2+5+8+11)/4
        assert_relative_eq!(means[2], 7.5, epsilon = 1e-14); // (3+6+9+12)/4
    }

    #[test]
    fn test_disabled_simd() {
        let config = SimdConfig {
            enabled: false,
            ..Default::default()
        };

        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        add_scalar_f64_simd(&mut data, 10.0, &config);

        // Should still work correctly even with SIMD disabled
        assert_relative_eq!(data[0], 11.0, epsilon = 1e-14);
        assert_relative_eq!(data[1], 12.0, epsilon = 1e-14);
        assert_relative_eq!(data[2], 13.0, epsilon = 1e-14);
        assert_relative_eq!(data[3], 14.0, epsilon = 1e-14);
    }

    #[test]
    fn test_small_array_threshold() {
        let config = SimdConfig {
            min_size_threshold: 100, // Larger than test data
            ..Default::default()
        };

        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        add_scalar_f64_simd(&mut data, 10.0, &config);

        // Should fall back to scalar implementation for small arrays
        assert_relative_eq!(data[0], 11.0, epsilon = 1e-14);
        assert_relative_eq!(data[1], 12.0, epsilon = 1e-14);
        assert_relative_eq!(data[2], 13.0, epsilon = 1e-14);
        assert_relative_eq!(data[3], 14.0, epsilon = 1e-14);
    }
}