Skip to main content

sklears_simd/vector/
statistics_ops.rs

1//! # SIMD Vector Statistical and Reduction Operations
2//!
3//! High-performance SIMD-optimized statistical computations and reduction operations.
4//! Provides functions for computing statistical measures, norms, and vector reductions
5//! with optimal performance on modern CPU architectures.
6//!
7//! ## Features
8//!
9//! - **Reduction Operations**: Sum, product, min, max with SIMD horizontal reductions
10//! - **Statistical Measures**: Mean, variance, standard deviation computations
11//! - **Vector Norms**: L1, L2, and squared L2 norms
12//! - **Dot Products**: Optimized vector dot product computation
13//! - **Multi-Platform SIMD**: SSE2, AVX2, AVX512, NEON optimizations
14//! - **Automatic Fallback**: Graceful fallback to scalar implementations
15//!
16//! ## Performance Notes
17//!
18//! Statistical operations benefit greatly from SIMD instructions through parallel
19//! computation followed by horizontal reduction. The implementations use the most
20//! efficient reduction patterns for each target architecture.
21
22// Import ARM64 feature detection macro
23#[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
24use std::arch::is_aarch64_feature_detected;
25
26/// SIMD-optimized sum of all elements in a vector
27///
28/// Computes the sum of all elements using SIMD horizontal addition.
29///
30/// # Arguments
31/// * `input` - Input vector
32///
33/// # Returns
34/// The sum of all elements as a single f32 value
35///
36/// # Examples
37/// ```rust
38/// use sklears_simd::vector::statistics_ops::sum_vec;
39///
40/// let input = vec![1.0, 2.0, 3.0, 4.0];
41/// let result = sum_vec(&input);
42/// assert_eq!(result, 10.0);
43///
44/// // Test with empty vector
45/// let empty: Vec<f32> = vec![];
46/// assert_eq!(sum_vec(&empty), 0.0);
47/// ```
48pub fn sum_vec(input: &[f32]) -> f32 {
49    if input.is_empty() {
50        return 0.0;
51    }
52
53    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
54    {
55        if crate::simd_feature_detected!("avx512f") {
56            return unsafe { sum_vec_avx512(input) };
57        } else if crate::simd_feature_detected!("avx2") {
58            return unsafe { sum_vec_avx2(input) };
59        } else if crate::simd_feature_detected!("sse2") {
60            return unsafe { sum_vec_sse2(input) };
61        }
62    }
63
64    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
65    {
66        if is_aarch64_feature_detected!("neon") {
67            return unsafe { sum_vec_neon(input) };
68        }
69    }
70
71    sum_vec_scalar(input)
72}
73
74/// SIMD-optimized product of all elements in a vector
75///
76/// Computes the product of all elements using SIMD horizontal multiplication.
77///
78/// # Arguments
79/// * `input` - Input vector
80///
81/// # Returns
82/// The product of all elements as a single f32 value
83///
84/// # Examples
85/// ```rust
86/// use sklears_simd::vector::statistics_ops::product_vec;
87///
88/// let input = vec![1.0, 2.0, 3.0, 4.0];
89/// let result = product_vec(&input);
90/// assert_eq!(result, 24.0);
91///
92/// // Test with empty vector
93/// let empty: Vec<f32> = vec![];
94/// assert_eq!(product_vec(&empty), 1.0);
95/// ```
96pub fn product_vec(input: &[f32]) -> f32 {
97    if input.is_empty() {
98        return 1.0;
99    }
100
101    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
102    {
103        if crate::simd_feature_detected!("avx512f") {
104            return unsafe { product_vec_avx512(input) };
105        } else if crate::simd_feature_detected!("avx2") {
106            return unsafe { product_vec_avx2(input) };
107        } else if crate::simd_feature_detected!("sse2") {
108            return unsafe { product_vec_sse2(input) };
109        }
110    }
111
112    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
113    {
114        if is_aarch64_feature_detected!("neon") {
115            return unsafe { product_vec_neon(input) };
116        }
117    }
118
119    product_vec_scalar(input)
120}
121
122/// SIMD-optimized minimum value in a vector
123///
124/// Finds the minimum value using SIMD horizontal reduction.
125///
126/// # Arguments
127/// * `input` - Input vector (must not be empty)
128///
129/// # Returns
130/// The minimum value as a single f32 value
131///
132/// # Panics
133/// Panics if the input vector is empty
134///
135/// # Examples
136/// ```rust
137/// use sklears_simd::vector::statistics_ops::min_vec;
138///
139/// let input = vec![3.0, 1.0, 4.0, 1.0, 5.0];
140/// let result = min_vec(&input);
141/// assert_eq!(result, 1.0);
142/// ```
143pub fn min_vec(input: &[f32]) -> f32 {
144    assert!(!input.is_empty(), "Input vector must not be empty");
145
146    if input.iter().any(|x| x.is_nan()) {
147        return f32::NAN;
148    }
149
150    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
151    {
152        if crate::simd_feature_detected!("avx512f") {
153            return unsafe { min_vec_avx512(input) };
154        } else if crate::simd_feature_detected!("avx2") {
155            return unsafe { min_vec_avx2(input) };
156        } else if crate::simd_feature_detected!("sse2") {
157            return unsafe { min_vec_sse2(input) };
158        }
159    }
160
161    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
162    {
163        if is_aarch64_feature_detected!("neon") {
164            return unsafe { min_vec_neon(input) };
165        }
166    }
167
168    min_vec_scalar(input)
169}
170
171/// SIMD-optimized maximum value in a vector
172///
173/// Finds the maximum value using SIMD horizontal reduction.
174///
175/// # Arguments
176/// * `input` - Input vector (must not be empty)
177///
178/// # Returns
179/// The maximum value as a single f32 value
180///
181/// # Panics
182/// Panics if the input vector is empty
183///
184/// # Examples
185/// ```rust
186/// use sklears_simd::vector::statistics_ops::max_vec;
187///
188/// let input = vec![3.0, 1.0, 4.0, 1.0, 5.0];
189/// let result = max_vec(&input);
190/// assert_eq!(result, 5.0);
191/// ```
192pub fn max_vec(input: &[f32]) -> f32 {
193    assert!(!input.is_empty(), "Input vector must not be empty");
194
195    if input.iter().any(|x| x.is_nan()) {
196        return f32::NAN;
197    }
198
199    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
200    {
201        if crate::simd_feature_detected!("avx512f") {
202            return unsafe { max_vec_avx512(input) };
203        } else if crate::simd_feature_detected!("avx2") {
204            return unsafe { max_vec_avx2(input) };
205        } else if crate::simd_feature_detected!("sse2") {
206            return unsafe { max_vec_sse2(input) };
207        }
208    }
209
210    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
211    {
212        if is_aarch64_feature_detected!("neon") {
213            return unsafe { max_vec_neon(input) };
214        }
215    }
216
217    max_vec_scalar(input)
218}
219
220/// SIMD-optimized computation of both minimum and maximum values
221///
222/// Computes both min and max in a single pass for efficiency.
223///
224/// # Arguments
225/// * `input` - Input vector (must not be empty)
226///
227/// # Returns
228/// A tuple (min, max) containing the minimum and maximum values
229///
230/// # Panics
231/// Panics if the input vector is empty
232///
233/// # Examples
234/// ```rust
235/// use sklears_simd::vector::statistics_ops::min_max_vec;
236///
237/// let input = vec![3.0, 1.0, 4.0, 1.0, 5.0];
238/// let (min_val, max_val) = min_max_vec(&input);
239/// assert_eq!(min_val, 1.0);
240/// assert_eq!(max_val, 5.0);
241/// ```
242pub fn min_max_vec(input: &[f32]) -> (f32, f32) {
243    assert!(!input.is_empty(), "Input vector must not be empty");
244
245    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
246    {
247        if crate::simd_feature_detected!("avx512f") {
248            return unsafe { min_max_vec_avx512(input) };
249        } else if crate::simd_feature_detected!("avx2") {
250            return unsafe { min_max_vec_avx2(input) };
251        } else if crate::simd_feature_detected!("sse2") {
252            return unsafe { min_max_vec_sse2(input) };
253        }
254    }
255
256    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
257    {
258        if is_aarch64_feature_detected!("neon") {
259            return unsafe { min_max_vec_neon(input) };
260        }
261    }
262
263    min_max_vec_scalar(input)
264}
265
266/// SIMD-optimized arithmetic mean (average) of a vector
267///
268/// Computes the arithmetic mean using SIMD sum followed by division.
269///
270/// # Arguments
271/// * `input` - Input vector (must not be empty)
272///
273/// # Returns
274/// The arithmetic mean as a single f32 value
275///
276/// # Panics
277/// Panics if the input vector is empty
278///
279/// # Examples
280/// ```rust
281/// use sklears_simd::vector::statistics_ops::mean_vec;
282///
283/// let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
284/// let result = mean_vec(&input);
285/// assert_eq!(result, 3.0);
286/// ```
287pub fn mean_vec(input: &[f32]) -> f32 {
288    assert!(!input.is_empty(), "Input vector must not be empty");
289
290    let sum = sum_vec(input);
291    sum / (input.len() as f32)
292}
293
294/// SIMD-optimized dot product of two vectors
295///
296/// Computes the dot product using SIMD multiply-accumulate operations.
297///
298/// # Arguments
299/// * `a` - First input vector
300/// * `b` - Second input vector (must have same length as `a`)
301///
302/// # Returns
303/// The dot product as a single f32 value
304///
305/// # Panics
306/// Panics if the vectors have different lengths
307///
308/// # Examples
309/// ```rust
310/// use sklears_simd::vector::statistics_ops::dot_product;
311///
312/// let a = vec![1.0, 2.0, 3.0];
313/// let b = vec![4.0, 5.0, 6.0];
314/// let result = dot_product(&a, &b);
315/// assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
316///
317/// // Test with empty vectors
318/// let empty_a: Vec<f32> = vec![];
319/// let empty_b: Vec<f32> = vec![];
320/// assert_eq!(dot_product(&empty_a, &empty_b), 0.0);
321/// ```
322pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
323    assert_eq!(a.len(), b.len(), "Input vectors must have the same length");
324
325    if a.is_empty() {
326        return 0.0;
327    }
328
329    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
330    {
331        if crate::simd_feature_detected!("avx512f") {
332            return unsafe { dot_product_avx512(a, b) };
333        } else if crate::simd_feature_detected!("avx2") {
334            return unsafe { dot_product_avx2(a, b) };
335        } else if crate::simd_feature_detected!("sse2") {
336            return unsafe { dot_product_sse2(a, b) };
337        }
338    }
339
340    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
341    {
342        if is_aarch64_feature_detected!("neon") {
343            return unsafe { dot_product_neon(a, b) };
344        }
345    }
346
347    dot_product_scalar(a, b)
348}
349
350/// SIMD-optimized L1 norm (Manhattan distance) of a vector
351///
352/// Computes the sum of absolute values using SIMD operations.
353///
354/// # Arguments
355/// * `input` - Input vector
356///
357/// # Returns
358/// The L1 norm as a single f32 value
359///
360/// # Examples
361/// ```rust
362/// use sklears_simd::vector::statistics_ops::norm_l1;
363///
364/// let input = vec![-1.0, 2.0, -3.0, 4.0];
365/// let result = norm_l1(&input);
366/// assert_eq!(result, 10.0); // |−1| + |2| + |−3| + |4| = 1 + 2 + 3 + 4 = 10
367/// ```
368pub fn norm_l1(input: &[f32]) -> f32 {
369    if input.is_empty() {
370        return 0.0;
371    }
372
373    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
374    {
375        if crate::simd_feature_detected!("avx512f") {
376            return unsafe { norm_l1_avx512(input) };
377        } else if crate::simd_feature_detected!("avx2") {
378            return unsafe { norm_l1_avx2(input) };
379        } else if crate::simd_feature_detected!("sse2") {
380            return unsafe { norm_l1_sse2(input) };
381        }
382    }
383
384    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
385    {
386        if is_aarch64_feature_detected!("neon") {
387            return unsafe { norm_l1_neon(input) };
388        }
389    }
390
391    norm_l1_scalar(input)
392}
393
394/// SIMD-optimized squared L2 norm of a vector
395///
396/// Computes the sum of squared values, which is the squared L2 norm.
397/// This is more efficient than L2 norm when the square root is not needed.
398///
399/// # Arguments
400/// * `input` - Input vector
401///
402/// # Returns
403/// The squared L2 norm as a single f32 value
404///
405/// # Examples
406/// ```rust
407/// use sklears_simd::vector::statistics_ops::norm_l2_squared;
408///
409/// let input = vec![1.0, 2.0, 3.0, 4.0];
410/// let result = norm_l2_squared(&input);
411/// assert_eq!(result, 30.0); // 1² + 2² + 3² + 4² = 1 + 4 + 9 + 16 = 30
412/// ```
413pub fn norm_l2_squared(input: &[f32]) -> f32 {
414    if input.is_empty() {
415        return 0.0;
416    }
417
418    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
419    {
420        if crate::simd_feature_detected!("avx512f") {
421            return unsafe { norm_l2_squared_avx512(input) };
422        } else if crate::simd_feature_detected!("avx2") {
423            return unsafe { norm_l2_squared_avx2(input) };
424        } else if crate::simd_feature_detected!("sse2") {
425            return unsafe { norm_l2_squared_sse2(input) };
426        }
427    }
428
429    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
430    {
431        if is_aarch64_feature_detected!("neon") {
432            return unsafe { norm_l2_squared_neon(input) };
433        }
434    }
435
436    norm_l2_squared_scalar(input)
437}
438
439/// SIMD-optimized L2 norm (Euclidean distance) of a vector
440///
441/// Computes the square root of the sum of squared values.
442///
443/// # Arguments
444/// * `input` - Input vector
445///
446/// # Returns
447/// The L2 norm as a single f32 value
448///
449/// # Examples
450/// ```rust
451/// use sklears_simd::vector::statistics_ops::norm_l2;
452///
453/// let input = vec![3.0, 4.0];
454/// let result = norm_l2(&input);
455/// assert_eq!(result, 5.0); // sqrt(3² + 4²) = sqrt(9 + 16) = sqrt(25) = 5
456/// ```
457pub fn norm_l2(input: &[f32]) -> f32 {
458    norm_l2_squared(input).sqrt()
459}
460
461/// SIMD-optimized population variance of a vector
462///
463/// Computes the population variance using the formula: Var(X) = E\[X²\] - E\[X\]²
464///
465/// # Arguments
466/// * `input` - Input vector (must not be empty)
467///
468/// # Returns
469/// The population variance as a single f32 value
470///
471/// # Panics
472/// Panics if the input vector is empty
473///
474/// # Examples
475/// ```rust
476/// use sklears_simd::vector::statistics_ops::variance_vec;
477///
478/// let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
479/// let result = variance_vec(&input);
480/// assert!((result - 2.0).abs() < 1e-6); // Population variance = 2.0
481/// ```
482pub fn variance_vec(input: &[f32]) -> f32 {
483    assert!(!input.is_empty(), "Input vector must not be empty");
484
485    let mean = mean_vec(input);
486    let sum_of_squares = norm_l2_squared(input);
487    let n = input.len() as f32;
488
489    sum_of_squares / n - mean * mean
490}
491
492/// SIMD-optimized standard deviation of a vector
493///
494/// Computes the population standard deviation as the square root of the variance.
495///
496/// # Arguments
497/// * `input` - Input vector (must not be empty)
498///
499/// # Returns
500/// The standard deviation as a single f32 value
501///
502/// # Panics
503/// Panics if the input vector is empty
504///
505/// # Examples
506/// ```rust
507/// use sklears_simd::vector::statistics_ops::std_dev_vec;
508///
509/// let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
510/// let result = std_dev_vec(&input);
511/// assert!((result - 1.41421356).abs() < 1e-6); // sqrt(2) ≈ 1.41421356
512/// ```
513pub fn std_dev_vec(input: &[f32]) -> f32 {
514    variance_vec(input).sqrt()
515}
516
517// ============================================================================
518// Scalar implementations (fallbacks)
519// ============================================================================
520
/// Portable fallback for `sum_vec`: a plain left-to-right sum.
fn sum_vec_scalar(input: &[f32]) -> f32 {
    input.iter().copied().sum::<f32>()
}
524
/// Portable fallback for `product_vec`: left-to-right product starting at 1.0.
fn product_vec_scalar(input: &[f32]) -> f32 {
    input.iter().product()
}
528
/// Portable fallback for `min_vec`: linear scan keeping the smallest value seen.
///
/// Panics (index out of bounds) on an empty slice.
fn min_vec_scalar(input: &[f32]) -> f32 {
    let mut smallest = input[0];
    for &candidate in &input[1..] {
        if candidate < smallest {
            smallest = candidate;
        }
    }
    smallest
}
534
/// Portable fallback for `max_vec`: linear scan keeping the largest value seen.
///
/// Panics (index out of bounds) on an empty slice.
fn max_vec_scalar(input: &[f32]) -> f32 {
    let mut largest = input[0];
    for &candidate in &input[1..] {
        if candidate > largest {
            largest = candidate;
        }
    }
    largest
}
540
/// Portable fallback for `min_max_vec`: one fold that carries both extremes.
///
/// Panics (index out of bounds) on an empty slice.
fn min_max_vec_scalar(input: &[f32]) -> (f32, f32) {
    let seed = input[0];
    input[1..].iter().fold((seed, seed), |(lo, hi), &x| {
        let lo = if x < lo { x } else { lo };
        let hi = if x > hi { x } else { hi };
        (lo, hi)
    })
}
554
/// Portable fallback for `dot_product`: pairwise products summed in order.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
558
/// Portable fallback for `norm_l1`: sum of absolute values.
fn norm_l1_scalar(input: &[f32]) -> f32 {
    input.iter().map(|v| v.abs()).sum()
}
562
/// Portable fallback for `norm_l2_squared`: sum of squares.
fn norm_l2_squared_scalar(input: &[f32]) -> f32 {
    input.iter().map(|v| v * v).sum()
}
566
567// ============================================================================
568// SSE2 implementations (x86/x86_64)
569// ============================================================================
570
/// Sums `input` with SSE2. Four-lane partial sums cover the bulk of the slice,
/// they are reduced horizontally, and the remaining tail elements are added
/// one by one.
///
/// # Safety
/// The running CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn sum_vec_sse2(input: &[f32]) -> f32 {
    // The intrinsics live in `core::arch::x86` on 32-bit targets and in
    // `core::arch::x86_64` on 64-bit ones. Importing only `x86_64` fails to
    // compile on 32-bit x86 even though this function is built there.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // Process 4 elements at a time
    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        sum = _mm_add_ps(sum, chunk);
        i += 4;
    }

    // Horizontal sum: add the high pair onto the low pair, then lane 1 onto lane 0.
    let temp = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    // Remaining (< 4) elements
    for &x in &input[i..] {
        final_sum += x;
    }

    final_sum
}
602
/// Multiplies all elements of `input` with SSE2. Four-lane partial products are
/// reduced horizontally, then the tail elements are folded in.
///
/// # Safety
/// The running CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn product_vec_sse2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut product = _mm_set1_ps(1.0);
    let mut i = 0;

    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        product = _mm_mul_ps(product, chunk);
        i += 4;
    }

    // Horizontal product
    let temp = _mm_mul_ps(product, _mm_movehl_ps(product, product));
    let result = _mm_mul_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_product = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        final_product *= x;
    }

    final_product
}
632
/// Finds the minimum of `input` with SSE2 lane-wise `min`, a horizontal
/// reduction, and a scalar tail.
///
/// # Safety
/// The running CPU must support SSE2. `input` must be non-empty (`input[0]`
/// seeds the accumulator). NaN inputs are expected to be filtered by the
/// caller, because `_mm_min_ps` and the scalar `<` treat NaN differently.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn min_vec_sse2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut min_val = _mm_load1_ps(&input[0]);
    let mut i = 0;

    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        min_val = _mm_min_ps(min_val, chunk);
        i += 4;
    }

    // Horizontal min
    let temp = _mm_min_ps(min_val, _mm_movehl_ps(min_val, min_val));
    let result = _mm_min_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_min = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        if x < final_min {
            final_min = x;
        }
    }

    final_min
}
664
/// Finds the maximum of `input` with SSE2 lane-wise `max`, a horizontal
/// reduction, and a scalar tail.
///
/// # Safety
/// The running CPU must support SSE2. `input` must be non-empty (`input[0]`
/// seeds the accumulator). NaN inputs are expected to be filtered by the
/// caller, because `_mm_max_ps` and the scalar `>` treat NaN differently.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn max_vec_sse2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut max_val = _mm_load1_ps(&input[0]);
    let mut i = 0;

    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        max_val = _mm_max_ps(max_val, chunk);
        i += 4;
    }

    // Horizontal max
    let temp = _mm_max_ps(max_val, _mm_movehl_ps(max_val, max_val));
    let result = _mm_max_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_max = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        if x > final_max {
            final_max = x;
        }
    }

    final_max
}
696
/// Computes `(min, max)` of `input` in one SSE2 sweep, reducing each
/// accumulator horizontally before handling the scalar tail.
///
/// # Safety
/// The running CPU must support SSE2. `input` must be non-empty (`input[0]`
/// seeds both accumulators). NaN inputs are expected to be filtered by the
/// caller.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn min_max_vec_sse2(input: &[f32]) -> (f32, f32) {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut min_val = _mm_load1_ps(&input[0]);
    let mut max_val = _mm_load1_ps(&input[0]);
    let mut i = 0;

    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        min_val = _mm_min_ps(min_val, chunk);
        max_val = _mm_max_ps(max_val, chunk);
        i += 4;
    }

    // Horizontal reductions
    let min_temp = _mm_min_ps(min_val, _mm_movehl_ps(min_val, min_val));
    let min_result = _mm_min_ps(min_temp, _mm_shuffle_ps(min_temp, min_temp, 0x01));
    let mut final_min = _mm_cvtss_f32(min_result);

    let max_temp = _mm_max_ps(max_val, _mm_movehl_ps(max_val, max_val));
    let max_result = _mm_max_ps(max_temp, _mm_shuffle_ps(max_temp, max_temp, 0x01));
    let mut final_max = _mm_cvtss_f32(max_result);

    for &x in &input[i..] {
        if x < final_min {
            final_min = x;
        }
        if x > final_max {
            final_max = x;
        }
    }

    (final_min, final_max)
}
737
/// Dot product of `a` and `b` with SSE2 four-lane multiply/add, a horizontal
/// sum, and a scalar tail.
///
/// # Safety
/// The running CPU must support SSE2, and `b` must be at least as long as `a`.
/// Both slices are read up to `a.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn dot_product_sse2(a: &[f32], b: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= a.len() {
        // SAFETY: `i + 4 <= a.len() <= b.len()` (caller contract), so both
        // unaligned loads stay in bounds.
        let a_chunk = _mm_loadu_ps(a.as_ptr().add(i));
        let b_chunk = _mm_loadu_ps(b.as_ptr().add(i));
        let product = _mm_mul_ps(a_chunk, b_chunk);
        sum = _mm_add_ps(sum, product);
        i += 4;
    }

    // Horizontal sum
    let temp = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    while i < a.len() {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}
769
/// L1 norm of `input` with SSE2. The absolute value is taken by clearing the
/// sign bit with a bitmask, then the lanes are summed and the tail is added.
///
/// # Safety
/// The running CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn norm_l1_sse2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm_setzero_ps();
    // All bits except the IEEE-754 sign bit: AND-ing with it yields |x|.
    let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFFFFFF));
    let mut i = 0;

    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        let abs_chunk = _mm_and_ps(chunk, abs_mask);
        sum = _mm_add_ps(sum, abs_chunk);
        i += 4;
    }

    // Horizontal sum
    let temp = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        final_sum += x.abs();
    }

    final_sum
}
801
/// Sum of squares of `input` with SSE2 four-lane multiply/add, a horizontal
/// sum, and a scalar tail.
///
/// # Safety
/// The running CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn norm_l2_squared_sse2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= input.len() {
        // SAFETY: `i + 4 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm_loadu_ps(input.as_ptr().add(i));
        let squared = _mm_mul_ps(chunk, chunk);
        sum = _mm_add_ps(sum, squared);
        i += 4;
    }

    // Horizontal sum
    let temp = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        final_sum += x * x;
    }

    final_sum
}
832
833// ============================================================================
834// AVX2 implementations (x86/x86_64)
835// ============================================================================
836
/// Sums `input` with 256-bit AVX vectors. The two 128-bit halves are folded
/// together, reduced horizontally, and the tail is added.
///
/// # Safety
/// The running CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn sum_vec_avx2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= input.len() {
        // SAFETY: `i + 8 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        sum = _mm256_add_ps(sum, chunk);
        i += 8;
    }

    // Extract and sum both 128-bit lanes
    let sum128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
    let temp = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        final_sum += x;
    }

    final_sum
}
867
/// Multiplies all elements of `input` with 256-bit AVX vectors. The two
/// 128-bit halves are combined, reduced horizontally, and the tail is folded in.
///
/// # Safety
/// The running CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn product_vec_avx2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut product = _mm256_set1_ps(1.0);
    let mut i = 0;

    while i + 8 <= input.len() {
        // SAFETY: `i + 8 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        product = _mm256_mul_ps(product, chunk);
        i += 8;
    }

    // Extract and multiply both 128-bit lanes
    let prod128 = _mm_mul_ps(
        _mm256_extractf128_ps(product, 0),
        _mm256_extractf128_ps(product, 1),
    );
    let temp = _mm_mul_ps(prod128, _mm_movehl_ps(prod128, prod128));
    let result = _mm_mul_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_product = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        final_product *= x;
    }

    final_product
}
901
/// Finds the minimum of `input` with 256-bit AVX lane-wise `min`. The halves
/// and lanes are then reduced, and the tail is scanned.
///
/// # Safety
/// The running CPU must support AVX2. `input` must be non-empty (`input[0]`
/// seeds the accumulator). NaN inputs are expected to be filtered by the
/// caller.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn min_vec_avx2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut min_val = _mm256_broadcast_ss(&input[0]);
    let mut i = 0;

    while i + 8 <= input.len() {
        // SAFETY: `i + 8 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        min_val = _mm256_min_ps(min_val, chunk);
        i += 8;
    }

    // Extract and min both 128-bit lanes
    let min128 = _mm_min_ps(
        _mm256_extractf128_ps(min_val, 0),
        _mm256_extractf128_ps(min_val, 1),
    );
    let temp = _mm_min_ps(min128, _mm_movehl_ps(min128, min128));
    let result = _mm_min_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_min = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        if x < final_min {
            final_min = x;
        }
    }

    final_min
}
937
/// Finds the maximum of `input` with 256-bit AVX lane-wise `max`. The halves
/// and lanes are then reduced, and the tail is scanned.
///
/// # Safety
/// The running CPU must support AVX2. `input` must be non-empty (`input[0]`
/// seeds the accumulator). NaN inputs are expected to be filtered by the
/// caller.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn max_vec_avx2(input: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut max_val = _mm256_broadcast_ss(&input[0]);
    let mut i = 0;

    while i + 8 <= input.len() {
        // SAFETY: `i + 8 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        max_val = _mm256_max_ps(max_val, chunk);
        i += 8;
    }

    // Extract and max both 128-bit lanes
    let max128 = _mm_max_ps(
        _mm256_extractf128_ps(max_val, 0),
        _mm256_extractf128_ps(max_val, 1),
    );
    let temp = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
    let result = _mm_max_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_max = _mm_cvtss_f32(result);

    for &x in &input[i..] {
        if x > final_max {
            final_max = x;
        }
    }

    final_max
}
973
/// Computes `(min, max)` of `input` in one 256-bit AVX sweep. Each accumulator
/// is reduced across halves and lanes before the tail is scanned.
///
/// # Safety
/// The running CPU must support AVX2. `input` must be non-empty (`input[0]`
/// seeds both accumulators). NaN inputs are expected to be filtered by the
/// caller.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn min_max_vec_avx2(input: &[f32]) -> (f32, f32) {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut min_val = _mm256_broadcast_ss(&input[0]);
    let mut max_val = _mm256_broadcast_ss(&input[0]);
    let mut i = 0;

    while i + 8 <= input.len() {
        // SAFETY: `i + 8 <= input.len()`, so the unaligned load stays in bounds.
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        min_val = _mm256_min_ps(min_val, chunk);
        max_val = _mm256_max_ps(max_val, chunk);
        i += 8;
    }

    // Extract and reduce both lanes
    let min128 = _mm_min_ps(
        _mm256_extractf128_ps(min_val, 0),
        _mm256_extractf128_ps(min_val, 1),
    );
    let min_temp = _mm_min_ps(min128, _mm_movehl_ps(min128, min128));
    let min_result = _mm_min_ps(min_temp, _mm_shuffle_ps(min_temp, min_temp, 0x01));
    let mut final_min = _mm_cvtss_f32(min_result);

    let max128 = _mm_max_ps(
        _mm256_extractf128_ps(max_val, 0),
        _mm256_extractf128_ps(max_val, 1),
    );
    let max_temp = _mm_max_ps(max128, _mm_movehl_ps(max128, max128));
    let max_result = _mm_max_ps(max_temp, _mm_shuffle_ps(max_temp, max_temp, 0x01));
    let mut final_max = _mm_cvtss_f32(max_result);

    for &x in &input[i..] {
        if x < final_min {
            final_min = x;
        }
        if x > final_max {
            final_max = x;
        }
    }

    (final_min, final_max)
}
1022
/// Dot product of `a` and `b` with 256-bit AVX multiply/add. The two 128-bit
/// halves are combined, reduced horizontally, and the tail is added.
///
/// # Safety
/// The running CPU must support AVX2, and `b` must be at least as long as `a`.
/// Both slices are read up to `a.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    // Pick the intrinsics module for the actual target (`x86_64` does not
    // exist on 32-bit x86).
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= a.len() {
        // SAFETY: `i + 8 <= a.len() <= b.len()` (caller contract), so both
        // unaligned loads stay in bounds.
        let a_chunk = _mm256_loadu_ps(a.as_ptr().add(i));
        let b_chunk = _mm256_loadu_ps(b.as_ptr().add(i));
        let product = _mm256_mul_ps(a_chunk, b_chunk);
        sum = _mm256_add_ps(sum, product);
        i += 8;
    }

    // Extract and sum both 128-bit lanes
    let sum128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
    let temp = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    while i < a.len() {
        final_sum += a[i] * b[i];
        i += 1;
    }

    final_sum
}
1055
/// AVX2 kernel for the L1 norm `sum(|x|)` of `input`.
///
/// Absolute values are taken by clearing the IEEE-754 sign bit of eight lanes
/// at a time. The accumulator is reduced horizontally, and the remaining
/// `input.len() % 8` elements are added with `f32::abs`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn norm_l1_avx2(input: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so import the module
    // that matches the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm256_setzero_ps();
    // -0.0 has only the sign bit set, so `andnot(-0.0, x)` == `x & !sign` == |x|.
    let sign_mask = _mm256_set1_ps(-0.0);
    let mut i = 0;

    while i + 8 <= input.len() {
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        let abs_chunk = _mm256_andnot_ps(sign_mask, chunk);
        sum = _mm256_add_ps(sum, abs_chunk);
        i += 8;
    }

    // Fold the upper 128-bit half onto the lower one, then reduce 4 -> 2 -> 1 lanes.
    let sum128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
    let temp = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    // Scalar remainder: the elements after the last full 8-wide chunk.
    for &x in &input[i..] {
        final_sum += x.abs();
    }

    final_sum
}
1088
/// AVX2 kernel for the squared L2 norm `sum(x * x)` of `input`.
///
/// Squares are computed with a separate multiply and add (no FMA), eight lanes
/// at a time. The accumulator is reduced horizontally, and the remaining
/// `input.len() % 8` squares are added with scalar code.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn norm_l2_squared_avx2(input: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so import the module
    // that matches the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= input.len() {
        let chunk = _mm256_loadu_ps(input.as_ptr().add(i));
        let squared = _mm256_mul_ps(chunk, chunk);
        sum = _mm256_add_ps(sum, squared);
        i += 8;
    }

    // Fold the upper 128-bit half onto the lower one, then reduce 4 -> 2 -> 1 lanes.
    let sum128 = _mm_add_ps(_mm256_extractf128_ps(sum, 0), _mm256_extractf128_ps(sum, 1));
    let temp = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let result = _mm_add_ps(temp, _mm_shuffle_ps(temp, temp, 0x01));
    let mut final_sum = _mm_cvtss_f32(result);

    // Scalar remainder: the elements after the last full 8-wide chunk.
    for &x in &input[i..] {
        final_sum += x * x;
    }

    final_sum
}
1120
1121// ============================================================================
// AVX-512 entry points (x86/x86_64) - these currently delegate to the AVX2 kernels
1123// ============================================================================
1124
1125#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1126#[target_feature(enable = "avx512f")]
1127unsafe fn sum_vec_avx512(input: &[f32]) -> f32 {
1128    // For brevity, using AVX2 fallback - would implement proper AVX512 horizontal reductions
1129    sum_vec_avx2(input)
1130}
1131
1132#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1133#[target_feature(enable = "avx512f")]
1134unsafe fn product_vec_avx512(input: &[f32]) -> f32 {
1135    product_vec_avx2(input)
1136}
1137
1138#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1139#[target_feature(enable = "avx512f")]
1140unsafe fn min_vec_avx512(input: &[f32]) -> f32 {
1141    min_vec_avx2(input)
1142}
1143
1144#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1145#[target_feature(enable = "avx512f")]
1146unsafe fn max_vec_avx512(input: &[f32]) -> f32 {
1147    max_vec_avx2(input)
1148}
1149
1150#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1151#[target_feature(enable = "avx512f")]
1152unsafe fn min_max_vec_avx512(input: &[f32]) -> (f32, f32) {
1153    min_max_vec_avx2(input)
1154}
1155
1156#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1157#[target_feature(enable = "avx512f")]
1158unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
1159    dot_product_avx2(a, b)
1160}
1161
1162#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1163#[target_feature(enable = "avx512f")]
1164unsafe fn norm_l1_avx512(input: &[f32]) -> f32 {
1165    norm_l1_avx2(input)
1166}
1167
1168#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
1169#[target_feature(enable = "avx512f")]
1170unsafe fn norm_l2_squared_avx512(input: &[f32]) -> f32 {
1171    norm_l2_squared_avx2(input)
1172}
1173
1174// ============================================================================
// NEON implementations (ARM AArch64) - the product reduction falls back to scalar code
1176// ============================================================================
1177
/// NEON kernel for the sum of all elements.
///
/// Full 4-element blocks are accumulated lane-wise. The four lanes are reduced
/// as `(l0 + l1) + (l2 + l3)`, and any leftover elements are then added in order.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn sum_vec_neon(input: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let blocks = input.chunks_exact(4);
    let tail = blocks.remainder();

    let mut lanes = vdupq_n_f32(0.0);
    for block in blocks {
        lanes = vaddq_f32(lanes, vld1q_f32(block.as_ptr()));
    }

    // Pairwise reduction: [l0+l1, l2+l3] -> (l0+l1)+(l2+l3).
    let halves = vpadd_f32(vget_low_f32(lanes), vget_high_f32(lanes));
    let reduced = vpadd_f32(halves, halves);
    let head = vget_lane_f32(reduced, 0);

    tail.iter().fold(head, |acc, &x| acc + x)
}
1205
/// NEON entry point for the product of all elements.
///
/// No vectorized kernel is implemented here; the call is forwarded to
/// `product_vec_scalar`, so the result is whatever that routine returns.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn product_vec_neon(input: &[f32]) -> f32 {
    // No NEON product kernel yet: delegate to the scalar implementation.
    product_vec_scalar(input)
}
1213
/// NEON kernel for the smallest element of a non-empty slice.
///
/// Every lane is seeded with `input[0]`, so slices shorter than four elements
/// still give the right answer. Leftover elements after the last full block
/// are compared with `<`, which never selects a NaN.
///
/// # Panics
/// Panics if `input` is empty.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn min_vec_neon(input: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let mut lanes = vdupq_n_f32(input[0]);

    let blocks = input.chunks_exact(4);
    let tail = blocks.remainder();
    for block in blocks {
        lanes = vminq_f32(lanes, vld1q_f32(block.as_ptr()));
    }

    // Pairwise minimum: 4 lanes -> 2 -> 1.
    let halves = vpmin_f32(vget_low_f32(lanes), vget_high_f32(lanes));
    let mut smallest = vget_lane_f32(vpmin_f32(halves, halves), 0);

    for &x in tail {
        if x < smallest {
            smallest = x;
        }
    }

    smallest
}
1243
/// NEON kernel for the largest element of a non-empty slice.
///
/// Every lane is seeded with `input[0]`, so slices shorter than four elements
/// still give the right answer. Leftover elements after the last full block
/// are compared with `>`, which never selects a NaN.
///
/// # Panics
/// Panics if `input` is empty.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn max_vec_neon(input: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let mut lanes = vdupq_n_f32(input[0]);

    let blocks = input.chunks_exact(4);
    let tail = blocks.remainder();
    for block in blocks {
        lanes = vmaxq_f32(lanes, vld1q_f32(block.as_ptr()));
    }

    // Pairwise maximum: 4 lanes -> 2 -> 1.
    let halves = vpmax_f32(vget_low_f32(lanes), vget_high_f32(lanes));
    let mut largest = vget_lane_f32(vpmax_f32(halves, halves), 0);

    for &x in tail {
        if x > largest {
            largest = x;
        }
    }

    largest
}
1273
/// NEON kernel returning `(min, max)` of a non-empty slice in one pass.
///
/// Both accumulators are seeded with `input[0]` and updated block by block.
/// Each is reduced pairwise, and the leftover elements are then checked with
/// `<` / `>` comparisons.
///
/// # Panics
/// Panics if `input` is empty.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn min_max_vec_neon(input: &[f32]) -> (f32, f32) {
    use core::arch::aarch64::*;

    let seed = vdupq_n_f32(input[0]);
    let mut low_lanes = seed;
    let mut high_lanes = vdupq_n_f32(input[0]);

    let blocks = input.chunks_exact(4);
    let tail = blocks.remainder();
    for block in blocks {
        let v = vld1q_f32(block.as_ptr());
        low_lanes = vminq_f32(low_lanes, v);
        high_lanes = vmaxq_f32(high_lanes, v);
    }

    // Pairwise reductions, 4 lanes -> 2 -> 1, for each accumulator.
    let low_pair = vpmin_f32(vget_low_f32(low_lanes), vget_high_f32(low_lanes));
    let mut lowest = vget_lane_f32(vpmin_f32(low_pair, low_pair), 0);

    let high_pair = vpmax_f32(vget_low_f32(high_lanes), vget_high_f32(high_lanes));
    let mut highest = vget_lane_f32(vpmax_f32(high_pair, high_pair), 0);

    for &x in tail {
        if x < lowest {
            lowest = x;
        }
        if x > highest {
            highest = x;
        }
    }

    (lowest, highest)
}
1312
/// NEON kernel for the dot product `sum(a[i] * b[i])` over `i < a.len()`.
///
/// Four lane-wise products are accumulated per iteration. The lanes are reduced
/// pairwise, and the leftover `a.len() % 4` products are added with scalar code.
///
/// # Panics
/// Panics if `b` is shorter than `a`. Elements of `b` beyond `a.len()` are ignored.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    // `vld1q_f32` reads through a raw pointer without bounds checks. Re-slicing
    // `b` to `a`'s length turns a too-short `b` into a panic instead of an
    // out-of-bounds read.
    let b = &b[..a.len()];

    let mut sum = vdupq_n_f32(0.0);
    let mut i = 0;

    while i + 4 <= a.len() {
        let a_chunk = vld1q_f32(a.as_ptr().add(i));
        let b_chunk = vld1q_f32(b.as_ptr().add(i));
        let product = vmulq_f32(a_chunk, b_chunk);
        sum = vaddq_f32(sum, product);
        i += 4;
    }

    // Horizontal sum using pairwise addition: (l0+l1)+(l2+l3).
    let sum2 = vpadd_f32(vget_low_f32(sum), vget_high_f32(sum));
    let sum1 = vpadd_f32(sum2, sum2);
    let mut final_sum = vget_lane_f32(sum1, 0);

    // Scalar remainder: the products after the last full 4-wide chunk.
    for (&x, &y) in a[i..].iter().zip(&b[i..]) {
        final_sum += x * y;
    }

    final_sum
}
1342
/// NEON kernel for the L1 norm `sum(|x|)`.
///
/// Absolute values of full 4-element blocks are accumulated lane-wise and
/// reduced pairwise. The leftover elements are then added in order.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn norm_l1_neon(input: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let blocks = input.chunks_exact(4);
    let tail = blocks.remainder();

    let mut lanes = vdupq_n_f32(0.0);
    for block in blocks {
        let magnitudes = vabsq_f32(vld1q_f32(block.as_ptr()));
        lanes = vaddq_f32(lanes, magnitudes);
    }

    let halves = vpadd_f32(vget_low_f32(lanes), vget_high_f32(lanes));
    let head = vget_lane_f32(vpadd_f32(halves, halves), 0);

    tail.iter().fold(head, |acc, &x| acc + x.abs())
}
1371
/// NEON kernel for the squared L2 norm `sum(x * x)`.
///
/// Each full 4-element block is squared with a separate multiply and then
/// added (no fused multiply-add). The lanes are reduced pairwise, and leftover
/// elements are added in order.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn norm_l2_squared_neon(input: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let blocks = input.chunks_exact(4);
    let tail = blocks.remainder();

    let mut lanes = vdupq_n_f32(0.0);
    for block in blocks {
        let v = vld1q_f32(block.as_ptr());
        lanes = vaddq_f32(lanes, vmulq_f32(v, v));
    }

    let halves = vpadd_f32(vget_low_f32(lanes), vget_high_f32(lanes));
    let head = vget_lane_f32(vpadd_f32(halves, halves), 0);

    tail.iter().fold(head, |acc, &x| acc + x * x)
}
1400
// Unit tests for the public reduction API. The module is compiled only with std
// available, so `vec!`/`Vec` come from the std prelude.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    #[test]
    fn test_sum_vec() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let result = sum_vec(&input);
        assert_eq!(result, 10.0);

        // Test with empty vector
        let empty: Vec<f32> = vec![];
        assert_eq!(sum_vec(&empty), 0.0);

        // Test with negative numbers
        let negative = vec![-1.0, -2.0, 3.0, 4.0];
        assert_eq!(sum_vec(&negative), 4.0);
    }

    #[test]
    fn test_product_vec() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let result = product_vec(&input);
        assert_eq!(result, 24.0);

        // Test with empty vector
        let empty: Vec<f32> = vec![];
        assert_eq!(product_vec(&empty), 1.0);

        // Test with zeros
        let with_zero = vec![1.0, 2.0, 0.0, 4.0];
        assert_eq!(product_vec(&with_zero), 0.0);
    }

    #[test]
    fn test_min_vec() {
        let input = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let result = min_vec(&input);
        assert_eq!(result, 1.0);

        // Test with negative numbers
        let negative = vec![-1.0, -5.0, 2.0, -3.0];
        assert_eq!(min_vec(&negative), -5.0);
    }

    #[test]
    fn test_max_vec() {
        let input = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let result = max_vec(&input);
        assert_eq!(result, 5.0);

        // Test with negative numbers
        let negative = vec![-1.0, -5.0, 2.0, -3.0];
        assert_eq!(max_vec(&negative), 2.0);
    }

    #[test]
    fn test_min_max_vec() {
        let input = vec![3.0, 1.0, 4.0, 1.0, 5.0];
        let (min_val, max_val) = min_max_vec(&input);
        assert_eq!(min_val, 1.0);
        assert_eq!(max_val, 5.0);

        // Test with single element
        let single = vec![42.0];
        let (min_single, max_single) = min_max_vec(&single);
        assert_eq!(min_single, 42.0);
        assert_eq!(max_single, 42.0);
    }

    #[test]
    fn test_mean_vec() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = mean_vec(&input);
        assert_eq!(result, 3.0);

        // Test with non-integer mean: (1 + 2 + 3 + 4) / 4 = 2.5
        let decimal = vec![1.0, 2.0, 3.0, 4.0];
        let decimal_mean = mean_vec(&decimal);
        assert!((decimal_mean - 2.5).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = dot_product(&a, &b);
        assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32

        // Test with empty vectors
        let empty_a: Vec<f32> = vec![];
        let empty_b: Vec<f32> = vec![];
        assert_eq!(dot_product(&empty_a, &empty_b), 0.0);

        // Test orthogonal vectors
        let ortho_a = vec![1.0, 0.0, 0.0];
        let ortho_b = vec![0.0, 1.0, 0.0];
        assert_eq!(dot_product(&ortho_a, &ortho_b), 0.0);
    }

    #[test]
    fn test_norm_l1() {
        let input = vec![-1.0, 2.0, -3.0, 4.0];
        let result = norm_l1(&input);
        assert_eq!(result, 10.0); // |−1| + |2| + |−3| + |4| = 1 + 2 + 3 + 4 = 10

        // Test with empty vector
        let empty: Vec<f32> = vec![];
        assert_eq!(norm_l1(&empty), 0.0);

        // Test with all positive
        let positive = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(norm_l1(&positive), 10.0);
    }

    #[test]
    fn test_norm_l2_squared() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let result = norm_l2_squared(&input);
        assert_eq!(result, 30.0); // 1² + 2² + 3² + 4² = 1 + 4 + 9 + 16 = 30

        // Test with empty vector
        let empty: Vec<f32> = vec![];
        assert_eq!(norm_l2_squared(&empty), 0.0);
    }

    #[test]
    fn test_norm_l2() {
        let input = vec![3.0, 4.0];
        let result = norm_l2(&input);
        assert_eq!(result, 5.0); // sqrt(3² + 4²) = sqrt(9 + 16) = sqrt(25) = 5

        // Test with unit vector
        let unit = vec![1.0, 0.0, 0.0];
        assert!((norm_l2(&unit) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_variance_vec() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = variance_vec(&input);
        assert!((result - 2.0).abs() < EPSILON); // Population variance = 2.0

        // Test with constant values (variance should be 0)
        let constant = vec![5.0, 5.0, 5.0, 5.0];
        assert!(variance_vec(&constant).abs() < EPSILON);
    }

    #[test]
    fn test_std_dev_vec() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = std_dev_vec(&input);
        assert!((result - 2.0_f32.sqrt()).abs() < EPSILON); // sqrt(2) ≈ 1.41421356

        // Test with constant values (std dev should be 0)
        let constant = vec![5.0, 5.0, 5.0, 5.0];
        assert!(std_dev_vec(&constant).abs() < EPSILON);
    }

    #[test]
    fn test_special_values() {
        // Test with infinity
        let with_inf = vec![1.0, f32::INFINITY, 3.0];
        assert_eq!(sum_vec(&with_inf), f32::INFINITY);
        assert_eq!(max_vec(&with_inf), f32::INFINITY);

        // Test with NaN
        let with_nan = vec![1.0, f32::NAN, 3.0];
        assert!(sum_vec(&with_nan).is_nan());
        assert!(min_vec(&with_nan).is_nan());
    }

    #[test]
    fn test_large_vectors() {
        let size = 10000;
        let input: Vec<f32> = (1..=size).map(|i| i as f32).collect();

        let expected_sum = (size * (size + 1) / 2) as f32;
        let actual_sum = sum_vec(&input);
        assert!((actual_sum - expected_sum).abs() < 1e-3);

        assert_eq!(min_vec(&input), 1.0);
        assert_eq!(max_vec(&input), size as f32);
    }

    #[test]
    fn test_simd_body_and_tail() {
        // 19 elements is not a multiple of 4, 8 or 16, so every SIMD kernel runs
        // its vector loop AND its scalar remainder. All values are small
        // integers, so the results are exact regardless of summation order.
        let input: Vec<f32> = (0..19).map(|i| i as f32 - 9.0).collect(); // -9..=9
        assert_eq!(sum_vec(&input), 0.0);
        assert_eq!(min_vec(&input), -9.0);
        assert_eq!(max_vec(&input), 9.0);
        assert_eq!(min_max_vec(&input), (-9.0, 9.0));
        assert_eq!(norm_l1(&input), 90.0); // 2 * (1 + ... + 9)
        assert_eq!(norm_l2_squared(&input), 570.0); // 2 * (1² + ... + 9²)
        assert_eq!(dot_product(&input, &input), 570.0);

        // Reversed order moves the extremes between the vector body and the tail.
        let reversed: Vec<f32> = input.iter().rev().copied().collect();
        assert_eq!(min_max_vec(&reversed), (-9.0, 9.0));
        assert_eq!(min_vec(&reversed), -9.0);
        assert_eq!(max_vec(&reversed), 9.0);

        let ones = vec![1.0f32; 19];
        assert_eq!(dot_product(&input, &ones), 0.0);
    }

    #[test]
    #[should_panic(expected = "Input vector must not be empty")]
    fn test_min_vec_empty() {
        let empty: Vec<f32> = vec![];
        min_vec(&empty);
    }

    #[test]
    #[should_panic(expected = "Input vector must not be empty")]
    fn test_max_vec_empty() {
        let empty: Vec<f32> = vec![];
        max_vec(&empty);
    }

    #[test]
    #[should_panic(expected = "Input vectors must have the same length")]
    fn test_dot_product_dimension_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        dot_product(&a, &b);
    }

    #[test]
    fn test_mathematical_properties() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        // Test that dot product is commutative: a·b = b·a
        assert_eq!(dot_product(&a, &b), dot_product(&b, &a));

        // Test Cauchy-Schwarz inequality: |a·b| ≤ ||a|| ||b||
        let dot_ab = dot_product(&a, &b).abs();
        let norm_a = norm_l2(&a);
        let norm_b = norm_l2(&b);
        assert!(dot_ab <= norm_a * norm_b + EPSILON);

        // Test triangle inequality: ||a + b|| ≤ ||a|| + ||b||
        let sum_ab: Vec<f32> = a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
        let norm_sum = norm_l2(&sum_ab);
        assert!(norm_sum <= norm_a + norm_b + EPSILON);
    }
}