scirs2_stats/
math_utils.rs

1//! Mathematical utility functions with SIMD acceleration
2//!
3//! This module provides common mathematical operations optimized with SIMD
4//! when available, with automatic fallback to scalar implementations.
5
6use crate::error::{StatsError, StatsResult};
7use scirs2_core::ndarray::{Array1, ArrayView1};
8use scirs2_core::simd_ops::SimdUnifiedOps;
9
10/// Compute absolute values of array elements (f64, SIMD-accelerated)
11///
12/// Uses scirs2-core's SIMD implementation with AVX2/NEON acceleration
13/// and automatic scalar fallback for unsupported platforms.
14///
15/// # Arguments
16///
17/// * `x` - Input array
18///
19/// # Returns
20///
21/// * `Ok(Array1<f64>)` - Array of absolute values
22/// * `Err(StatsError)` if input validation fails
23///
24/// # Examples
25///
26/// ```
27/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
28/// use scirs2_stats::math_utils::abs_f64;
29///
30/// let data = vec![-3.0, -1.5, 0.0, 2.5, 5.0];
31/// let result = abs_f64(&data)?;
32/// assert_eq!(result.to_vec(), vec![3.0, 1.5, 0.0, 2.5, 5.0]);
33/// # Ok(())
34/// # }
35/// ```
36///
37/// # Performance
38///
39/// - **AVX2 (x86_64)**: Processes 4 f64 elements per cycle
40/// - **NEON (ARM)**: Processes 2 f64 elements per cycle
41/// - **Scalar fallback**: Available on all platforms
42/// - **Speedup**: 1.5-2x for arrays with 1000+ elements
43pub fn abs_f64(x: &[f64]) -> StatsResult<Array1<f64>> {
44    if x.is_empty() {
45        return Err(StatsError::InvalidArgument(
46            "Input array cannot be empty".to_string(),
47        ));
48    }
49
50    // Check for finite values
51    for (i, &val) in x.iter().enumerate() {
52        if !val.is_finite() {
53            return Err(StatsError::InvalidArgument(format!(
54                "Input contains non-finite value {} at index {}",
55                val, i
56            )));
57        }
58    }
59
60    // Use scirs2-core SIMD implementation
61    let x_view = ArrayView1::from(x);
62    let result = f64::simd_abs(&x_view);
63
64    Ok(result)
65}
66
67/// Compute absolute values of array elements (f32, SIMD-accelerated)
68///
69/// f32 variant provides better SIMD performance (8 elements/cycle on AVX2).
70///
71/// # Arguments
72///
73/// * `x` - Input array
74///
75/// # Returns
76///
77/// * `Ok(Array1<f32>)` - Array of absolute values
78/// * `Err(StatsError)` if input validation fails
79///
80/// # Examples
81///
82/// ```
83/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
84/// use scirs2_stats::math_utils::abs_f32;
85///
86/// let data = vec![-3.0f32, -1.5, 0.0, 2.5, 5.0];
87/// let result = abs_f32(&data)?;
88/// assert_eq!(result.to_vec(), vec![3.0f32, 1.5, 0.0, 2.5, 5.0]);
89/// # Ok(())
90/// # }
91/// ```
92///
93/// # Performance
94///
95/// - **AVX2 (x86_64)**: Processes 8 f32 elements per cycle
96/// - **NEON (ARM)**: Processes 4 f32 elements per cycle
97/// - **Scalar fallback**: Available on all platforms
98/// - **Speedup**: 2-3x for arrays with 1000+ elements (better than f64)
99pub fn abs_f32(x: &[f32]) -> StatsResult<Array1<f32>> {
100    if x.is_empty() {
101        return Err(StatsError::InvalidArgument(
102            "Input array cannot be empty".to_string(),
103        ));
104    }
105
106    // Check for finite values
107    for (i, &val) in x.iter().enumerate() {
108        if !val.is_finite() {
109            return Err(StatsError::InvalidArgument(format!(
110                "Input contains non-finite value {} at index {}",
111                val, i
112            )));
113        }
114    }
115
116    // Use scirs2-core SIMD implementation
117    let x_view = ArrayView1::from(x);
118    let result = f32::simd_abs(&x_view);
119
120    Ok(result)
121}
122
123/// Compute sign of array elements (f64, SIMD-accelerated)
124///
125/// Returns -1 for negative values, 0 for zero, +1 for positive values.
126/// Uses scirs2-core's SIMD implementation with AVX2/NEON acceleration.
127///
128/// # Arguments
129///
130/// * `x` - Input array
131///
132/// # Returns
133///
134/// * `Ok(Array1<f64>)` - Array of signs (-1.0, 0.0, or 1.0)
135/// * `Err(StatsError)` if input validation fails
136///
137/// # Examples
138///
139/// ```
140/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
141/// use scirs2_stats::math_utils::sign_f64;
142///
143/// let data = vec![-3.0, -0.5, 0.0, 1.5, 5.0];
144/// let result = sign_f64(&data)?;
145/// assert_eq!(result.to_vec(), vec![-1.0, -1.0, 0.0, 1.0, 1.0]);
146/// # Ok(())
147/// # }
148/// ```
149///
150/// # Performance
151///
152/// - **AVX2 (x86_64)**: Processes 4 f64 elements per cycle
153/// - **NEON (ARM)**: Processes 2 f64 elements per cycle
154/// - **Scalar fallback**: Available on all platforms
155/// - **Speedup**: 1.5-2x for arrays with 1000+ elements
156pub fn sign_f64(x: &[f64]) -> StatsResult<Array1<f64>> {
157    if x.is_empty() {
158        return Err(StatsError::InvalidArgument(
159            "Input array cannot be empty".to_string(),
160        ));
161    }
162
163    // Check for finite values
164    for (i, &val) in x.iter().enumerate() {
165        if !val.is_finite() {
166            return Err(StatsError::InvalidArgument(format!(
167                "Input contains non-finite value {} at index {}",
168                val, i
169            )));
170        }
171    }
172
173    // Use scirs2-core SIMD implementation
174    let x_view = ArrayView1::from(x);
175    let result = f64::simd_sign(&x_view);
176
177    Ok(result)
178}
179
180/// Compute sign of array elements (f32, SIMD-accelerated)
181///
182/// Returns -1 for negative values, 0 for zero, +1 for positive values.
183/// f32 variant provides better SIMD performance (8 elements/cycle on AVX2).
184///
185/// # Arguments
186///
187/// * `x` - Input array
188///
189/// # Returns
190///
191/// * `Ok(Array1<f32>)` - Array of signs (-1.0, 0.0, or 1.0)
192/// * `Err(StatsError)` if input validation fails
193///
194/// # Examples
195///
196/// ```
197/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
198/// use scirs2_stats::math_utils::sign_f32;
199///
200/// let data = vec![-3.0f32, -0.5, 0.0, 1.5, 5.0];
201/// let result = sign_f32(&data)?;
202/// assert_eq!(result.to_vec(), vec![-1.0f32, -1.0, 0.0, 1.0, 1.0]);
203/// # Ok(())
204/// # }
205/// ```
206///
207/// # Performance
208///
209/// - **AVX2 (x86_64)**: Processes 8 f32 elements per cycle
210/// - **NEON (ARM)**: Processes 4 f32 elements per cycle
211/// - **Scalar fallback**: Available on all platforms
212/// - **Speedup**: 2-3x for arrays with 1000+ elements (better than f64)
213pub fn sign_f32(x: &[f32]) -> StatsResult<Array1<f32>> {
214    if x.is_empty() {
215        return Err(StatsError::InvalidArgument(
216            "Input array cannot be empty".to_string(),
217        ));
218    }
219
220    // Check for finite values
221    for (i, &val) in x.iter().enumerate() {
222        if !val.is_finite() {
223            return Err(StatsError::InvalidArgument(format!(
224                "Input contains non-finite value {} at index {}",
225                val, i
226            )));
227        }
228    }
229
230    // Use scirs2-core SIMD implementation
231    let x_view = ArrayView1::from(x);
232    let result = f32::simd_sign(&x_view);
233
234    Ok(result)
235}
236
#[cfg(test)]
mod tests {
    // Unit tests for the abs/sign helpers.
    //
    // Coverage is kept symmetric between the f64 and f32 variants: each
    // function is checked on a basic mixed-sign input and on the two
    // validation failures (empty input, non-finite element).

    use super::*;
    use approx::assert_abs_diff_eq;

    // ---------------------------------------------------------------------
    // abs: happy path
    // ---------------------------------------------------------------------

    #[test]
    fn test_abs_f64_basic() {
        let data = vec![-3.0, -1.5, 0.0, 2.5, 5.0];
        let result = abs_f64(&data).expect("Operation failed");
        let expected = vec![3.0, 1.5, 0.0, 2.5, 5.0];
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_abs_f32_basic() {
        let data = vec![-3.0f32, -1.5, 0.0, 2.5, 5.0];
        let result = abs_f32(&data).expect("Operation failed");
        let expected = vec![3.0f32, 1.5, 0.0, 2.5, 5.0];
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_abs_f64_large() {
        // Test with large array to ensure SIMD path is used
        let data: Vec<f64> = (0..10000).map(|i| i as f64 - 5000.0).collect();
        let result = abs_f64(&data).expect("Operation failed");
        assert_eq!(result.len(), data.len());
        for (&val, &input) in result.iter().zip(data.iter()) {
            assert_abs_diff_eq!(val, input.abs(), epsilon = 1e-10);
        }
    }

    #[test]
    fn test_abs_all_positive() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = abs_f64(&data).expect("Operation failed");
        assert_eq!(result.to_vec(), data);
    }

    // ---------------------------------------------------------------------
    // sign: happy path
    // ---------------------------------------------------------------------

    #[test]
    fn test_sign_f64_basic() {
        let data = vec![-3.0, -0.5, 0.0, 1.5, 5.0];
        let result = sign_f64(&data).expect("Operation failed");
        let expected = vec![-1.0, -1.0, 0.0, 1.0, 1.0];
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_sign_f32_basic() {
        let data = vec![-3.0f32, -0.5, 0.0, 1.5, 5.0];
        let result = sign_f32(&data).expect("Operation failed");
        let expected = vec![-1.0f32, -1.0, 0.0, 1.0, 1.0];
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_sign_f64_large() {
        // Test with large array to ensure SIMD path is used
        let data: Vec<f64> = (0..10000).map(|i| i as f64 - 5000.0).collect();
        let result = sign_f64(&data).expect("Operation failed");
        assert_eq!(result.len(), data.len());
        for (&val, &input) in result.iter().zip(data.iter()) {
            let expected = if input > 0.0 {
                1.0
            } else if input < 0.0 {
                -1.0
            } else {
                0.0
            };
            assert_abs_diff_eq!(val, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_sign_all_positive() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = sign_f64(&data).expect("Operation failed");
        let expected = vec![1.0; 5];
        assert_eq!(result.to_vec(), expected);
    }

    #[test]
    fn test_sign_all_negative() {
        let data = vec![-1.0, -2.0, -3.0, -4.0, -5.0];
        let result = sign_f64(&data).expect("Operation failed");
        let expected = vec![-1.0; 5];
        assert_eq!(result.to_vec(), expected);
    }

    #[test]
    fn test_sign_all_zero() {
        let data = vec![0.0; 100];
        let result = sign_f64(&data).expect("Operation failed");
        let expected = vec![0.0; 100];
        assert_eq!(result.to_vec(), expected);
    }

    // ---------------------------------------------------------------------
    // validation: empty input
    // ---------------------------------------------------------------------

    #[test]
    fn test_abs_empty() {
        let data: Vec<f64> = vec![];
        let result = abs_f64(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_sign_empty() {
        let data: Vec<f64> = vec![];
        let result = sign_f64(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_abs_f32_empty() {
        let data: Vec<f32> = vec![];
        assert!(matches!(
            abs_f32(&data),
            Err(StatsError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_sign_f32_empty() {
        let data: Vec<f32> = vec![];
        assert!(matches!(
            sign_f32(&data),
            Err(StatsError::InvalidArgument(_))
        ));
    }

    // ---------------------------------------------------------------------
    // validation: non-finite elements
    // ---------------------------------------------------------------------

    #[test]
    fn test_abs_nonfinite() {
        let data = vec![1.0, f64::NAN, 3.0];
        let result = abs_f64(&data);
        assert!(result.is_err());

        let data = vec![1.0, f64::INFINITY, 3.0];
        let result = abs_f64(&data);
        assert!(result.is_err());

        let data = vec![1.0, f64::NEG_INFINITY, 3.0];
        let result = abs_f64(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_sign_nonfinite() {
        let data = vec![1.0, f64::NAN, 3.0];
        let result = sign_f64(&data);
        assert!(result.is_err());

        let data = vec![1.0, f64::INFINITY, 3.0];
        let result = sign_f64(&data);
        assert!(result.is_err());

        let data = vec![1.0, f64::NEG_INFINITY, 3.0];
        let result = sign_f64(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_abs_f32_nonfinite() {
        for bad in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let data = vec![1.0f32, bad, 3.0];
            assert!(matches!(
                abs_f32(&data),
                Err(StatsError::InvalidArgument(_))
            ));
        }
    }

    #[test]
    fn test_sign_f32_nonfinite() {
        for bad in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let data = vec![1.0f32, bad, 3.0];
            assert!(matches!(
                sign_f32(&data),
                Err(StatsError::InvalidArgument(_))
            ));
        }
    }

    #[test]
    fn test_nonfinite_error_reports_first_index() {
        // Two bad values: the message must point at the first one (index 1).
        let data = vec![1.0, f64::NAN, f64::INFINITY];
        match abs_f64(&data) {
            Err(StatsError::InvalidArgument(msg)) => {
                assert!(msg.contains("at index 1"), "unexpected message: {}", msg);
            }
            _ => panic!("expected StatsError::InvalidArgument"),
        }
        match sign_f64(&data) {
            Err(StatsError::InvalidArgument(msg)) => {
                assert!(msg.contains("at index 1"), "unexpected message: {}", msg);
            }
            _ => panic!("expected StatsError::InvalidArgument"),
        }
    }
}