scirs2_core/
utils.rs

1//! Utility functions for numerical operations
2//!
3//! This module provides common utility functions used throughout ``SciRS2``.
4
5use crate::error::{CoreError, CoreResult, ErrorContext};
6use ndarray::{Array, Array1, Array2, ArrayBase, Data, Dimension};
7use num_traits::{Float, FromPrimitive, Num, NumCast};
8use std::fmt::Debug;
9
10/// Checks if two floating-point values are approximately equal
11///
12/// # Arguments
13///
14/// * `a` - First value
15/// * `b` - Second value
16/// * `abs_tol` - Absolute tolerance
17/// * `reltol` - Relative tolerance
18///
19/// # Returns
20///
21/// * `true` if the values are approximately equal, `false` otherwise
22#[must_use]
23#[allow(dead_code)]
24pub fn is_close<F: Float>(a: F, b: F, abs_tol: F, reltol: F) -> bool {
25    let abs_diff = (a - b).abs();
26
27    if abs_diff <= abs_tol {
28        true
29    } else {
30        let abs_a = a.abs();
31        let abs_b = b.abs();
32        let max_abs = if abs_a > abs_b { abs_a } else { abs_b };
33
34        abs_diff <= max_abs * reltol
35    }
36}
37
38/// Check if two points are equal within a tolerance
39///
40/// Compares each element of the points to determine if they are
41/// approximately equal within a specified tolerance.
42///
43/// # Arguments
44///
45/// * `point1` - First point as a slice
46/// * `point2` - Second point as a slice
47/// * `tol` - Tolerance (default: 1e-8)
48///
49/// # Returns
50///
51/// * True if points are equal within tolerance
52///
53/// # Examples
54///
55/// ```
56/// use scirs2_core::utils::points_equal;
57///
58/// let point1 = [1.0, 2.0, 3.0];
59/// let point2 = [1.0, 2.0, 3.0];
60/// let point3 = [1.0, 2.0, 3.001];
61///
62/// assert!(points_equal(&point1, &point2, None));
63/// assert!(!points_equal(&point1, &point3, None));
64/// assert!(points_equal(&point1, &point3, Some(0.01)));
65/// ```
66#[must_use]
67#[allow(dead_code)]
68pub fn points_equal<T>(point1: &[T], point2: &[T], tol: Option<T>) -> bool
69where
70    T: PartialOrd + std::ops::Sub<Output = T> + Copy + FromPrimitive + num_traits::Zero,
71{
72    // Check for empty arrays first
73    if point1.is_empty() || point2.is_empty() {
74        return point1.is_empty() && point2.is_empty();
75    }
76
77    // Default tolerance as 1e-8 converted to type T
78    let tol = match tol {
79        Some(t) => t,
80        None => match T::from_f64(1e-8) {
81            Some(t) => t,
82            None => {
83                // Fall back to zero tolerance if conversion fails
84                T::from_f64(0.0).unwrap_or_else(|| {
85                    // If even zero conversion fails, use zero trait method
86                    T::zero()
87                })
88            }
89        },
90    };
91
92    point1.len() == point2.len()
93        && point1.iter().zip(point2.iter()).all(|(&a, &b)| {
94            let diff = if a > b { a - b } else { b - a };
95            diff <= tol
96        })
97}
98
99/// Compare arrays within a tolerance
100///
101/// Compares each element of the arrays to determine if they are
102/// approximately equal within the specified tolerance.
103///
104/// # Arguments
105///
106/// * `array1` - First array
107/// * `array2` - Second array
108/// * `tol` - Tolerance (default: 1e-8)
109///
110/// # Returns
111///
112/// * True if arrays are equal within tolerance
113///
114/// # Examples
115///
116/// ```
117/// use scirs2_core::utils::arrays_equal;
118/// use ndarray::array;
119///
120/// let arr1 = array![[1.0, 2.0], [3.0, 4.0]];
121/// let arr2 = array![[1.0, 2.0], [3.0, 4.0]];
122/// let arr3 = array![[1.0, 2.0], [3.0, 4.001]];
123///
124/// assert!(arrays_equal(&arr1, &arr2, None));
125/// assert!(!arrays_equal(&arr1, &arr3, None));
126/// assert!(arrays_equal(&arr1, &arr3, Some(0.01)));
127/// ```
128#[must_use]
129#[allow(dead_code)]
130pub fn arrays_equal<S1, S2, D, T>(
131    array1: &ArrayBase<S1, D>,
132    array2: &ArrayBase<S2, D>,
133    tol: Option<T>,
134) -> bool
135where
136    S1: Data<Elem = T>,
137    S2: Data<Elem = T>,
138    D: Dimension,
139    T: PartialOrd + std::ops::Sub<Output = T> + Copy + FromPrimitive + num_traits::Zero,
140{
141    if array1.shape() != array2.shape() {
142        return false;
143    }
144
145    let points1: Vec<T> = array1.iter().copied().collect();
146    let points2: Vec<T> = array2.iter().copied().collect();
147
148    points_equal(&points1, &points2, tol)
149}
150
151/// Fills the diagonal of a matrix with a value
152///
153/// # Arguments
154///
155/// * `mut a` - Matrix to modify
156/// * `val` - Value to set on the diagonal
157///
158/// # Returns
159///
160/// * The modified matrix
161#[must_use]
162#[allow(dead_code)]
163pub fn fill_diagonal<T: Clone>(mut a: Array2<T>, val: T) -> Array2<T> {
164    let min_dim = a.nrows().min(a.ncols());
165
166    for i in 0..min_dim {
167        a[[i, i]] = val.clone();
168    }
169
170    a
171}
172
/// Multiplies together every element yielded by an iterable.
///
/// The running product starts at the multiplicative identity `T::from(1u8)`, so an
/// empty iterable produces `1`.
///
/// # Arguments
///
/// * `iter` - Iterable of values
///
/// # Returns
///
/// * Product of all elements
#[must_use]
#[allow(dead_code)]
pub fn prod<I, T>(iter: I) -> T
where
    I: IntoIterator<Item = T>,
    T: std::ops::Mul<Output = T> + From<u8>,
{
    let mut product = T::from(1u8);
    for factor in iter {
        product = product * factor;
    }
    product
}
191
192/// Creates a range of values with a specified step size
193///
194/// # Arguments
195///
196/// * `start` - Start value (inclusive)
197/// * `stop` - Stop value (exclusive)
198/// * `step` - Step size
199///
200/// # Returns
201///
202/// * Vector of values
203#[allow(dead_code)]
204pub fn arange<F: Float + std::iter::Sum>(start: F, end: F, step: F) -> CoreResult<Vec<F>> {
205    if step == F::zero() {
206        return Err(CoreError::ValueError(ErrorContext::new(
207            "Step size cannot be zero".to_string(),
208        )));
209    }
210
211    let mut result = Vec::new();
212    let mut current = start;
213
214    if step > F::zero() {
215        while current < end {
216            result.push(current);
217            current = current + step;
218        }
219    } else {
220        while current > end {
221            result.push(current);
222            current = current + step;
223        }
224    }
225
226    Ok(result)
227}
228
229/// Convenience function that provides the old behavior (panics on error)
230#[must_use]
231#[allow(dead_code)]
232pub fn arange_unchecked<F: Float + std::iter::Sum>(start: F, end: F, step: F) -> Vec<F> {
233    arange(start, end, step).unwrap()
234}
235
/// Returns `true` when `predicate` holds for every element of `iter`.
///
/// Evaluation stops at the first element for which the predicate is false. An empty
/// iterable yields `true`.
///
/// # Arguments
///
/// * `iter` - Iterable of values
/// * `predicate` - Function to check each value
///
/// # Returns
///
/// * `true` if all elements satisfy the predicate, `false` otherwise
#[must_use]
#[allow(dead_code)]
pub fn all<I, T, F>(iter: I, predicate: F) -> bool
where
    I: IntoIterator<Item = T>,
    F: Fn(T) -> bool,
{
    for item in iter {
        if !predicate(item) {
            return false;
        }
    }
    true
}
255
/// Returns `true` when `predicate` holds for at least one element of `iter`.
///
/// Evaluation stops at the first element for which the predicate is true. An empty
/// iterable yields `false`.
///
/// # Arguments
///
/// * `iter` - Iterable of values
/// * `predicate` - Function to check each value
///
/// # Returns
///
/// * `true` if any element satisfies the predicate, `false` otherwise
#[must_use]
#[allow(dead_code)]
pub fn any<I, T, F>(iter: I, predicate: F) -> bool
where
    I: IntoIterator<Item = T>,
    F: Fn(T) -> bool,
{
    for item in iter {
        if predicate(item) {
            return true;
        }
    }
    false
}
275
276/// Creates a linearly spaced array between start and end (inclusive)
277///
278/// This function uses parallel processing when available and
279/// appropriate for better performance.
280///
281/// # Arguments
282///
283/// * `start` - Start value
284/// * `end` - End value (inclusive)
285/// * `num` - Number of points
286///
287/// # Returns
288///
289/// * Array of linearly spaced values
290#[must_use]
291#[allow(dead_code)]
292pub fn linspace<F: Float + std::iter::Sum + Send + Sync>(
293    start: F,
294    end: F,
295    num: usize,
296) -> Array1<F> {
297    if num < 2 {
298        return Array::from_vec(vec![start]);
299    }
300
301    // Use parallel implementation for larger arrays
302    #[cfg(feature = "parallel")]
303    {
304        if num >= 1000 {
305            use crate::parallel_ops::*;
306
307            let step = (end - start) / F::from(num - 1).unwrap();
308            let result: Vec<F> = (0..num)
309                .into_par_iter()
310                .map(|i| {
311                    if i == num - 1 {
312                        // Ensure the last value is exactly end
313                        end
314                    } else {
315                        start + step * F::from(i).unwrap()
316                    }
317                })
318                .collect::<Vec<F>>();
319
320            // The parallel collection doesn't guarantee order, but par_iter does preserve order
321            // when collecting, so this should be fine
322            return Array::from_vec(result);
323        }
324    }
325
326    // Fall back to standard implementation
327    let step = (end - start) / F::from(num - 1).unwrap();
328    let mut result = Vec::with_capacity(num);
329
330    for i in 0..num {
331        let value = start + step * F::from(i).unwrap();
332        result.push(value);
333    }
334
335    // Make sure the last value is exactly end to avoid floating point precision issues
336    if let Some(last) = result.last_mut() {
337        *last = end;
338    }
339
340    Array::from_vec(result)
341}
342
343/// Creates a logarithmically spaced array between base^start and base^end (inclusive)
344///
345/// # Arguments
346///
347/// * `start` - Start exponent
348/// * `end` - End exponent (inclusive)
349/// * `num` - Number of points
350/// * `base` - Base of the logarithm (default: 10.0)
351///
352/// # Returns
353///
354/// * Array of logarithmically spaced values
355#[must_use]
356#[allow(dead_code)]
357pub fn logspace<F: Float + std::iter::Sum + Send + Sync>(
358    start: F,
359    end: F,
360    num: usize,
361    base: Option<F>,
362) -> Array1<F> {
363    let base = base.unwrap_or_else(|| F::from(10.0).unwrap());
364
365    // Generate linearly spaced values in the exponent space
366    let linear = linspace(start, end, num);
367
368    // Convert to logarithmic space
369    linear.mapv(|x| base.powf(x))
370}
371
372/// Compute the element-wise maximum of two arrays
373///
374/// This function uses parallel processing when available and
375/// appropriate for the input arrays.
376///
377/// # Arguments
378///
379/// * `a` - First array
380/// * `b` - Second array
381///
382/// # Returns
383///
384/// * Element-wise maximum
385///
386/// # Panics
387///
388/// * If the arrays have different shapes
389#[must_use]
390#[allow(dead_code)]
391pub fn maximum<S1, S2, D, T>(
392    a: &ndarray::ArrayBase<S1, D>,
393    b: &ndarray::ArrayBase<S2, D>,
394) -> Array<T, D>
395where
396    S1: ndarray::Data<Elem = T>,
397    S2: ndarray::Data<Elem = T>,
398    D: Dimension,
399    T: Num + PartialOrd + Copy + Send + Sync,
400{
401    assert_eq!(
402        a.shape(),
403        b.shape(),
404        "Arrays must have the same shape for element-wise maximum"
405    );
406
407    // Use parallel implementation for larger arrays
408    #[cfg(feature = "parallel")]
409    {
410        if a.len() > 1000 {
411            use crate::parallel_ops::*;
412
413            // Convert to owned arrays for parallel processing
414            let (a_vec_, _) = a.to_owned().into_raw_vec_and_offset();
415            let (b_vec_, _) = b.to_owned().into_raw_vec_and_offset();
416
417            let result_vec: Vec<T> = a_vec_
418                .into_par_iter()
419                .zip(b_vec_.into_par_iter())
420                .map(|(a_val, b_val)| if b_val > a_val { b_val } else { a_val })
421                .collect();
422
423            return Array::from_shape_vec(a.raw_dim(), result_vec)
424                .expect("Shape mismatch in parallel maximum");
425        }
426    }
427
428    // Fall back to standard implementation
429    let mut result = a.to_owned();
430    for (i, elem) in result.iter_mut().enumerate() {
431        if let Some(b_slice) = b.as_slice() {
432            let b_val = b_slice[i];
433            if b_val > *elem {
434                *elem = b_val;
435            }
436        } else {
437            // Handle case where b cannot be converted to slice
438            let b_val = b.iter().nth(i).unwrap();
439            if *b_val > *elem {
440                *elem = *b_val;
441            }
442        }
443    }
444
445    result
446}
447
448/// Compute the element-wise minimum of two arrays
449///
450/// This function uses parallel processing when available and
451/// appropriate for the input arrays.
452///
453/// # Arguments
454///
455/// * `a` - First array
456/// * `b` - Second array
457///
458/// # Returns
459///
460/// * Element-wise minimum
461///
462/// # Panics
463///
464/// * If the arrays have different shapes
465#[must_use]
466#[allow(dead_code)]
467pub fn minimum<S1, S2, D, T>(
468    a: &ndarray::ArrayBase<S1, D>,
469    b: &ndarray::ArrayBase<S2, D>,
470) -> Array<T, D>
471where
472    S1: ndarray::Data<Elem = T>,
473    S2: ndarray::Data<Elem = T>,
474    D: Dimension,
475    T: Num + PartialOrd + Copy + Send + Sync,
476{
477    assert_eq!(
478        a.shape(),
479        b.shape(),
480        "Arrays must have the same shape for element-wise minimum"
481    );
482
483    // Use parallel implementation for larger arrays
484    #[cfg(feature = "parallel")]
485    {
486        if a.len() > 1000 {
487            use crate::parallel_ops::*;
488
489            // Convert to owned arrays for parallel processing
490            let (a_vec_, _) = a.to_owned().into_raw_vec_and_offset();
491            let (b_vec_, _) = b.to_owned().into_raw_vec_and_offset();
492
493            let result_vec: Vec<T> = a_vec_
494                .into_par_iter()
495                .zip(b_vec_.into_par_iter())
496                .map(|(a_val, b_val)| if b_val < a_val { b_val } else { a_val })
497                .collect();
498
499            return Array::from_shape_vec(a.raw_dim(), result_vec)
500                .expect("Shape mismatch in parallel minimum");
501        }
502    }
503
504    // Fall back to standard implementation
505    let mut result = a.to_owned();
506    for (i, elem) in result.iter_mut().enumerate() {
507        if let Some(b_slice) = b.as_slice() {
508            let b_val = b_slice[i];
509            if b_val < *elem {
510                *elem = b_val;
511            }
512        } else {
513            // Handle case where b cannot be converted to slice
514            let b_val = b.iter().nth(i).unwrap();
515            if *b_val < *elem {
516                *elem = *b_val;
517            }
518        }
519    }
520
521    result
522}
523
524/// Normalize a vector to have unit energy or unit peak amplitude.
525///
526/// # Arguments
527///
528/// * `x` - Input vector
529/// * `norm` - Normalization type: energy, "peak", "sum", or "max"
530///
531/// # Returns
532///
533/// * Normalized vector as `Vec<f64>`
534///
535/// # Examples
536///
537/// ```
538/// use scirs2_core::utils::normalize;
539///
540/// // Normalize a vector to unit energy
541/// let signal = vec![1.0, 2.0, 3.0, 4.0];
542/// let normalized = normalize(&signal, "energy").unwrap();
543///
544/// // Sum of squares should be 1.0
545/// let sum_of_squares: f64 = normalized.iter().map(|&x| x * x).sum();
546/// assert!((sum_of_squares - 1.0).abs() < 1e-10);
547/// ```
548///
549/// # Errors
550///
551/// Returns an error if the input signal is empty, has zero energy/peak/sum, or if a conversion fails.
552#[allow(dead_code)]
553pub fn normalize<T>(x: &[T], norm: &str) -> Result<Vec<f64>, &'static str>
554where
555    T: Float + NumCast + Debug,
556{
557    if x.is_empty() {
558        return Err("Input signal is empty");
559    }
560
561    // Convert to f64 for internal processing
562    let x_f64: Vec<f64> = x
563        .iter()
564        .map(|&val| NumCast::from(val).ok_or("Could not convert value to f64"))
565        .collect::<Result<Vec<_>, _>>()?;
566
567    // Normalize based on type
568    match norm.to_lowercase().as_str() {
569        "energy" => {
570            // Normalize to unit energy (sum of squares = 1.0)
571            let sum_of_squares: f64 = x_f64.iter().map(|&x| x * x).sum();
572
573            if sum_of_squares.abs() < f64::EPSILON {
574                return Err("Signal has zero energy, cannot normalize");
575            }
576
577            let scale = 1.0 / sum_of_squares.sqrt();
578            let normalized = x_f64.iter().map(|&x| x * scale).collect();
579
580            Ok(normalized)
581        }
582        "peak" => {
583            // Normalize to unit peak amplitude (max absolute value = 1.0)
584            let peak = x_f64.iter().fold(0.0, |a, &b| a.max(b.abs()));
585
586            if peak.abs() < f64::EPSILON {
587                return Err("Signal has zero peak, cannot normalize");
588            }
589
590            let scale = 1.0 / peak;
591            let normalized = x_f64.iter().map(|&x| x * scale).collect();
592
593            Ok(normalized)
594        }
595        "sum" => {
596            // Normalize to unit sum
597            let sum: f64 = x_f64.iter().sum();
598
599            if sum.abs() < f64::EPSILON {
600                return Err("Signal has zero sum, cannot normalize");
601            }
602
603            let scale = 1.0 / sum;
604            let normalized = x_f64.iter().map(|&x| x * scale).collect();
605
606            Ok(normalized)
607        }
608        "max" => {
609            // Normalize to max value = 1.0 (preserves sign)
610            let max_val = x_f64.iter().fold(0.0, |a, &b| a.max(b.abs()));
611
612            if max_val.abs() < f64::EPSILON {
613                return Err("Signal has zero maximum, cannot normalize");
614            }
615
616            let scale = 1.0 / max_val;
617            let normalized = x_f64.iter().map(|&x| x * scale).collect();
618
619            Ok(normalized)
620        }
621        _ => Err("Unknown normalization type. Supported types: 'energy', 'peak', 'sum', 'max'"),
622    }
623}
624
625/// Pad an array with values according to the specified mode.
626///
627/// # Arguments
628///
629/// * `input` - Input array
630/// * `pad_width` - Width of padding in each dimension (before, after)
631/// * `mode` - Padding mode: constant, "edge", "linear_ramp", "maximum", "mean", "median", "minimum", "reflect", "symmetric", "wrap"
632/// * `constant_value` - Value to use for constant padding (only used for "constant" mode)
633///
634/// # Returns
635///
636/// * Padded array
637///
638/// # Examples
639///
640/// ```
641/// // Example with a 1D array
642/// use scirs2_core::utils::pad_array;
643/// use ndarray::{Array1, array};
644///
645/// let arr = array![1.0, 2.0, 3.0];
646/// let padded = pad_array(&arr, &[(1, 2)], "constant", Some(0.0)).unwrap();
647/// assert_eq!(padded.shape(), &[6]);
648/// assert_eq!(padded, array![0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
649/// ```
650///
651/// # Errors
652///
653/// Returns an error if the input array is 0-dimensional, if pad_width length doesn't match input dimensions,
654/// or if the padding mode is unsupported for the given array dimensionality.
655#[allow(dead_code)]
656pub fn pad_array<T, D>(
657    input: &Array<T, D>,
658    pad_width: &[(usize, usize)],
659    mode: &str,
660    constant_value: Option<T>,
661) -> Result<Array<T, D>, String>
662where
663    T: Float + FromPrimitive + Debug + Clone,
664    D: Dimension,
665{
666    // Validate inputs
667    if input.ndim() == 0 {
668        return Err("Input array cannot be 0-dimensional".to_string());
669    }
670
671    if pad_width.len() != input.ndim() {
672        return Err(format!(
673            "Pad _width must have same length as input dimensions (got {} expected {})",
674            pad_width.len(),
675            input.ndim()
676        ));
677    }
678
679    // No padding needed - return copy of input
680    if pad_width.iter().all(|&(a, b)| a == 0 && b == 0) {
681        return Ok(input.to_owned());
682    }
683
684    // Calculate new shape
685    let mut newshape = Vec::with_capacity(input.ndim());
686    for (dim, &(pad_before, pad_after)) in pad_width.iter().enumerate().take(input.ndim()) {
687        newshape.push(input.shape()[dim] + pad_before + pad_after);
688    }
689
690    // Create output array with default constant value
691    let const_val = constant_value.unwrap_or_else(|| T::zero());
692    let mut output = Array::<T, D>::from_elem(
693        D::from_dimension(&ndarray::IxDyn(&newshape))
694            .expect("Could not create dimension from shape"),
695        const_val,
696    );
697
698    // For 1D arrays
699    if input.ndim() == 1 {
700        // Convert to Array1 for easier manipulation
701        let inputarray1 = input
702            .view()
703            .into_dimensionality::<ndarray::Ix1>()
704            .map_err(|_| "Failed to convert to 1D array".to_string())?;
705        let mut output_array1 = output
706            .view_mut()
707            .into_dimensionality::<ndarray::Ix1>()
708            .map_err(|_| "Failed to convert output to 1D array".to_string())?;
709
710        let input_len = inputarray1.len();
711        let start = pad_width[0].0;
712
713        // First copy the input to the center region
714        for i in 0..input_len {
715            output_array1[start + i] = inputarray1[i];
716        }
717
718        // Then pad the borders based on the mode
719        match mode.to_lowercase().as_str() {
720            "constant" => {
721                // Already filled with constant value
722            }
723            "edge" => {
724                // Pad left side with first value
725                for i in 0..pad_width[0].0 {
726                    output_array1[i] = inputarray1[0];
727                }
728                // Pad right side with last value
729                let offset = start + input_len;
730                for i in 0..pad_width[0].1 {
731                    output_array1[offset + i] = inputarray1[input_len - 1];
732                }
733            }
734            "reflect" => {
735                // Pad left side
736                for i in 0..pad_width[0].0 {
737                    let src_idx = pad_width[0].0 - i;
738                    if src_idx < input_len {
739                        output_array1[i] = inputarray1[src_idx];
740                    }
741                }
742                // Pad right side
743                let offset = start + input_len;
744                for i in 0..pad_width[0].1 {
745                    let src_idx = input_len - 2 - i;
746                    if src_idx < input_len {
747                        output_array1[offset + i] = inputarray1[src_idx];
748                    }
749                }
750            }
751            "wrap" => {
752                // Pad left side
753                for i in 0..pad_width[0].0 {
754                    let src_idx = (input_len - (pad_width[0].0 - i) % input_len) % input_len;
755                    output_array1[i] = inputarray1[src_idx];
756                }
757                // Pad right side
758                let offset = start + input_len;
759                for i in 0..pad_width[0].1 {
760                    let src_idx = i % input_len;
761                    output_array1[offset + i] = inputarray1[src_idx];
762                }
763            }
764            "maximum" => {
765                // Find maximum value
766                let max_val = inputarray1.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
767
768                // Pad with maximum value
769                for i in 0..pad_width[0].0 {
770                    output_array1[i] = max_val;
771                }
772                let offset = start + input_len;
773                for i in 0..pad_width[0].1 {
774                    output_array1[offset + i] = max_val;
775                }
776            }
777            "minimum" => {
778                // Find minimum value
779                let min_val = inputarray1.iter().fold(T::infinity(), |a, &b| a.min(b));
780
781                // Pad with minimum value
782                for i in 0..pad_width[0].0 {
783                    output_array1[i] = min_val;
784                }
785                let offset = start + input_len;
786                for i in 0..pad_width[0].1 {
787                    output_array1[offset + i] = min_val;
788                }
789            }
790            "mean" => {
791                // Calculate mean value
792                let sum = inputarray1.iter().fold(T::zero(), |a, &b| a + b);
793                let mean_val = sum / T::from_usize(input_len).unwrap();
794
795                // Pad with mean value
796                for i in 0..pad_width[0].0 {
797                    output_array1[i] = mean_val;
798                }
799                let offset = start + input_len;
800                for i in 0..pad_width[0].1 {
801                    output_array1[offset + i] = mean_val;
802                }
803            }
804            _ => return Err(format!("Unsupported padding mode: {mode}")),
805        }
806
807        return Ok(output);
808    }
809
810    // For 2D arrays, we could add specific implementation similar to the above
811    // For now, we'll just return the output with constant padding for 2D and higher
812
813    if mode.to_lowercase() != "constant" {
814        return Err(format!(
815            "Padding mode '{mode}' is not yet implemented for arrays with more than 1 dimension"
816        ));
817    }
818
819    // For higher dimensions, we'll just return a more simplified implementation with
820    // constant padding only for now, and a note that it needs more work
821
822    // We've already created the padded array with constant values,
823    // now just return it since other padding modes for higher dimensions
824    // would require more complex implementation
825    //
826    // NOTE: This is a placeholder implementation that needs to be improved
827    // in the future to support all padding modes for higher dimensions
828
829    Ok(output)
830}
831
/// Creates the samples of a window function.
///
/// Supported window types (case-insensitive): `"hamming"`, `"hanning"`/`"hann"`,
/// `"blackman"`, `"bartlett"`, `"boxcar"`/`"rectangular"` and `"triang"`. The formulas are
/// intended to match `scipy.signal.windows`.
///
/// # Arguments
///
/// * `window_type` - Name of the window function
/// * `length` - Number of samples to return
/// * `periodic` - Selects the variant:
///   * `true`: the periodic (DFT-even) variant, i.e. a symmetric window of `length + 1`
///     samples with the last one dropped;
///   * `false`: the symmetric window.
///
/// # Returns
///
/// * `length` window samples. A single-sample window is always `[1.0]`, as in SciPy.
///
/// # Errors
///
/// Returns an error if the window length is zero or if the window type is unknown.
#[allow(dead_code)]
pub fn generate_window(
    window_type: &str,
    length: usize,
    periodic: bool,
) -> Result<Vec<f64>, String> {
    use std::f64::consts::PI;

    if length == 0 {
        return Err("Window length must be positive".to_string());
    }

    // Length N of the symmetric window that is sampled. The periodic variant samples one
    // extra point and drops it by evaluating only indices 0..length.
    let n = if periodic { length + 1 } else { length };
    // The N - 1 denominator of the window formulas. It is zero for a symmetric
    // single-sample window; that case is replaced by `[1.0]` below.
    let denom = (n - 1) as f64;
    let phase = |i: usize| 2.0 * PI * i as f64 / denom;

    let window: Vec<f64> = match window_type.to_lowercase().as_str() {
        // Hamming window: 0.54 - 0.46 * cos(2πn/(N-1))
        "hamming" => (0..length).map(|i| 0.54 - 0.46 * phase(i).cos()).collect(),
        // Hann window: 0.5 * (1 - cos(2πn/(N-1)))
        "hanning" | "hann" => (0..length)
            .map(|i| 0.5 * (1.0 - phase(i).cos()))
            .collect(),
        // Blackman window: 0.42 - 0.5 * cos(2πn/(N-1)) + 0.08 * cos(4πn/(N-1))
        "blackman" => (0..length)
            .map(|i| 0.42 - 0.5 * phase(i).cos() + 0.08 * (2.0 * phase(i)).cos())
            .collect(),
        // Bartlett (triangular with zero end points)
        "bartlett" => {
            let m = denom / 2.0;
            (0..length)
                .map(|i| 1.0 - ((i as f64 - m) / m).abs())
                .collect()
        }
        // Rectangular window (all ones)
        "boxcar" | "rectangular" => vec![1.0; length],
        // Triangular window with non-zero end points. The half-width is N/2 for even N and
        // (N+1)/2 for odd N, which is exactly `(N + 1) / 2` in integer arithmetic.
        "triang" => {
            let m = denom / 2.0;
            let half_width = ((n + 1) / 2) as f64;
            (0..length)
                .map(|i| 1.0 - ((i as f64 - m) / half_width).abs())
                .collect()
        }
        _ => {
            return Err(format!("Unknown window type: {window_type}"));
        }
    };

    // A single-sample window is [1.0] for every type. The formulas above would instead
    // give NaN for the symmetric case (0/0) and an off-peak value for the periodic case.
    if length == 1 {
        return Ok(vec![1.0]);
    }

    Ok(window)
}
915
916/// Get window function compatible with SciPy API
917///
918/// This is a wrapper around `generate_window` that returns `CoreResult` for
919/// consistency with other SciRS2 functions.
920///
921/// # Arguments
922///
923/// * `window_type` - Type of window function ("hamming", "hann", "rectangular", etc.)
924/// * `length` - Length of the window
925/// * `periodic` - Whether the window should be periodic (default: false)
926///
927/// # Returns
928///
929/// * Window function values as a vector wrapped in `CoreResult`
930///
931/// # Examples
932///
933/// ```rust
934/// use scirs2_core::utils::get_window;
935///
936/// let hamming = get_window("hamming", 5, false).unwrap();
937/// assert_eq!(hamming.len(), 5);
938/// ```
939#[allow(dead_code)]
940pub fn get_window(window_type: &str, length: usize, periodic: bool) -> CoreResult<Vec<f64>> {
941    generate_window(window_type, length, periodic)
942        .map_err(|e| CoreError::ValueError(ErrorContext::new(e)))
943}
944
945/// Differentiate a function using central difference method.
946///
947/// # Arguments
948///
949/// * `x` - Point at which to evaluate the derivative
950/// * `h` - Step size for the finite difference
951/// * `eval_fn` - Function that evaluates the function at a point
952///
953/// # Returns
954///
955/// * Derivative of the function at x
956///
957/// # Examples
958///
959/// ```
960/// use scirs2_core::utils::differentiate;
961///
962/// // Differentiate f(x) = x^2 at x = 3
963/// let f = |x: f64| -> Result<f64, String> { Ok(x * x) };
964/// let derivative = differentiate(3.0, 0.001, f).unwrap();
965///
966/// // The exact derivative is 2x = 6 at x = 3
967/// assert!((derivative - 6.0).abs() < 1e-5);
968/// ```
969///
970/// # Errors
971///
972/// Returns an error if the evaluation function fails at either x+h or x-h.
973#[allow(dead_code)]
974pub fn differentiate<F, Func>(x: F, h: F, evalfn: Func) -> Result<F, String>
975where
976    F: Float + FromPrimitive + Debug,
977    Func: Fn(F) -> Result<F, String>,
978{
979    // Use central difference for better accuracy
980    let f_plus = evalfn(x + h).map_err(|e| format!("Error evaluating function at x+h: {e}"))?;
981    let f_minus = evalfn(x - h).map_err(|e| format!("Error evaluating function at x-h: {e}"))?;
982    let derivative = (f_plus - f_minus) / (F::from(2.0).unwrap() * h);
983    Ok(derivative)
984}
985
986/// Integrate a function using composite Simpson's rule.
987///
988/// # Arguments
989///
990/// * `a` - Lower bound of integration
991/// * `b` - Upper bound of integration
992/// * `n` - Number of intervals for the quadrature (must be even)
993/// * `eval_fn` - Function that evaluates the function at a point
994///
995/// # Returns
996///
997/// * Definite integral of the function from a to b
998///
999/// # Examples
1000///
1001/// ```
1002/// use scirs2_core::utils::integrate;
1003///
1004/// // Integrate f(x) = x^2 from 0 to 1
1005/// let f = |x: f64| -> Result<f64, String> { Ok(x * x) };
1006/// let integral = integrate(0.0, 1.0, 100, f).unwrap();
1007///
1008/// // The exact integral is x^3/3 = 1/3 from 0 to 1
1009/// assert!((integral - 1.0/3.0).abs() < 1e-5);
1010/// ```
1011///
1012/// # Errors
1013///
1014/// Returns an error if the number of intervals is less than 2, not even, or if the evaluation function fails.
1015#[allow(dead_code)]
1016pub fn integrate<F, Func>(a: F, b: F, n: usize, evalfn: Func) -> Result<F, String>
1017where
1018    F: Float + FromPrimitive + Debug,
1019    Func: Fn(F) -> Result<F, String>,
1020{
1021    if a > b {
1022        return integrate(b, a, n, evalfn).map(|result| -result);
1023    }
1024
1025    // Use composite Simpson's rule for integration
1026    if n < 2 {
1027        return Err("number of intervals must be at least 2".to_string());
1028    }
1029
1030    if n % 2 != 0 {
1031        return Err("number of intervals must be even".to_string());
1032    }
1033
1034    let h = (b - a) / F::from_usize(n).unwrap();
1035    let mut sum = evalfn(a).map_err(|e| format!("Error evaluating function at a: {e}"))?
1036        + evalfn(b).map_err(|e| format!("Error evaluating function at b: {e}"))?;
1037
1038    // Even-indexed points (except endpoints)
1039    for i in 1..n {
1040        if i % 2 == 0 {
1041            let x_i = a + F::from_usize(i).unwrap() * h;
1042            sum = sum
1043                + F::from(2.0).unwrap()
1044                    * evalfn(x_i)
1045                        .map_err(|e| format!("Error evaluating function at x_{i}: {e}"))?;
1046        }
1047    }
1048
1049    // Odd-indexed points
1050    for i in 1..n {
1051        if i % 2 == 1 {
1052            let x_i = a + F::from_usize(i).unwrap() * h;
1053            sum = sum
1054                + F::from(4.0).unwrap()
1055                    * evalfn(x_i)
1056                        .map_err(|e| format!("Error evaluating function at x_{i}: {e}"))?;
1057        }
1058    }
1059
1060    let integral = h * sum / F::from(3.0).unwrap();
1061    Ok(integral)
1062}
1063
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::{arr1, array};

    #[test]
    fn test_is_close() {
        assert!(is_close(1.0, 1.0, 1e-10, 1e-10));
        assert!(is_close(1.0, 1.0 + 1e-11, 1e-10, 1e-10));
        assert!(!is_close(1.0, 1.1, 1e-10, 1e-10));
        assert!(is_close(1e10, 1e10 + 1.0, 1e-10, 1e-9));
    }

    #[test]
    fn test_fill_diagonal() {
        let a = Array2::<f64>::zeros((3, 3));
        let a_diag = fill_diagonal(a, 5.0);

        assert_relative_eq!(a_diag[[0, 0]], 5.0);
        assert_relative_eq!(a_diag[[1, 1]], 5.0);
        assert_relative_eq!(a_diag[[2, 2]], 5.0);
        assert_relative_eq!(a_diag[[0, 1]], 0.0);
    }

    #[test]
    fn test_prod() {
        assert_eq!(prod(vec![1, 2, 3, 4]), 24);
        assert_eq!(prod(vec![2.0, 3.0, 4.0]), 24.0);
    }

    #[test]
    fn test_arange() {
        let result = arange(0.0, 5.0, 1.0).unwrap();
        assert_eq!(result, vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let result = arange(5.0, 0.0, -1.0).unwrap();
        assert_eq!(result, vec![5.0, 4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_arange_zero_step() {
        let result = arange(0.0, 5.0, 0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_all() {
        assert!(all(vec![2, 4, 6, 8], |x| x % 2 == 0));
        assert!(!all(vec![2, 4, 5, 8], |x| x % 2 == 0));
    }

    #[test]
    fn test_any() {
        assert!(any(vec![1, 2, 3, 4], |x| x % 2 == 0));
        assert!(!any(vec![1, 3, 5, 7], |x| x % 2 == 0));
    }

    #[test]
    fn test_linspace() {
        let result = linspace(0.0, 1.0, 5);
        let expected = arr1(&[0.0, 0.25, 0.5, 0.75, 1.0]);
        assert_eq!(result.len(), 5);

        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-14);
        }

        // Test with single value
        let result = linspace(5.0, 5.0, 1);
        assert_eq!(result.len(), 1);
        assert_relative_eq!(result[0], 5.0);

        // Test endpoints
        let result = linspace(-10.0, 10.0, 5);
        assert_relative_eq!(result[0], -10.0);
        assert_relative_eq!(result[4], 10.0);
    }

    #[test]
    fn test_logspace() {
        let result = logspace(0.0, 3.0, 4, None);
        let expected = arr1(&[1.0, 10.0, 100.0, 1000.0]);
        assert_eq!(result.len(), 4);

        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }

        // Test with custom base
        let result = logspace(0.0, 3.0, 4, Some(2.0));
        let expected = arr1(&[1.0, 2.0, 4.0, 8.0]);

        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_maximum() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 1], [7, 2]];

        let result = maximum(&a, &b);
        let expected = array![[5, 2], [7, 4]];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_minimum() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 1], [7, 2]];

        let result = minimum(&a, &b);
        let expected = array![[1, 1], [3, 2]];

        assert_eq!(result, expected);
    }

    #[test]
    #[should_panic]
    fn test_maximum_differentshapes() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 1, 2], [7, 2, 3]];

        // Shapes (2, 2) and (2, 3) differ, so this call is expected to panic;
        // the result is intentionally discarded.
        let _ = maximum(&a, &b);
    }

    #[test]
    fn test_points_equal() {
        let point1 = [1.0, 2.0, 3.0];
        let point2 = [1.0, 2.0, 3.0];
        let point3 = [1.0, 2.0, 3.001];
        let point4 = [1.0, 2.0, 4.0];

        assert!(points_equal(&point1, &point2, None));
        assert!(!points_equal(&point1, &point3, None));
        assert!(points_equal(&point1, &point3, Some(0.01)));
        assert!(!points_equal(&point1, &point4, Some(0.01)));
    }

    #[test]
    fn test_arrays_equal() {
        let arr1 = array![[1.0, 2.0], [3.0, 4.0]];
        let arr2 = array![[1.0, 2.0], [3.0, 4.0]];
        let arr3 = array![[1.0, 2.0], [3.0, 4.001]];

        assert!(arrays_equal(&arr1, &arr2, None));
        assert!(!arrays_equal(&arr1, &arr3, None));
        assert!(arrays_equal(&arr1, &arr3, Some(0.01)));
    }

    #[test]
    fn test_normalize() {
        // Test energy normalization
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let normalized = normalize(&signal, "energy").unwrap();

        // Sum of squares should be 1.0
        let sum_of_squares: f64 = normalized.iter().map(|&x| x * x).sum();
        assert_relative_eq!(sum_of_squares, 1.0, epsilon = 1e-10);

        // Test peak normalization
        let signal = vec![1.0, -2.0, 3.0, -4.0];
        let normalized = normalize(&signal, "peak").unwrap();

        // Max absolute value should be 1.0
        let peak = normalized.iter().fold(0.0, |a, &b| a.max(b.abs()));
        assert_relative_eq!(peak, 1.0, epsilon = 1e-10);
        assert_relative_eq!(normalized[3], -1.0, epsilon = 1e-10);

        // Test sum normalization
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let normalized = normalize(&signal, "sum").unwrap();

        // Sum should be 1.0
        let sum: f64 = normalized.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pad_array() {
        // Test constant padding on 1D array
        let arr = array![1.0, 2.0, 3.0];
        let padded = pad_array(&arr, &[(1, 2)], "constant", Some(0.0)).unwrap();

        assert_eq!(padded.shape(), &[6]);
        assert_eq!(padded, array![0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);

        // Test edge padding
        let arr = array![1.0, 2.0, 3.0];
        let padded = pad_array(&arr, &[(2, 2)], "edge", None).unwrap();

        assert_eq!(padded.shape(), &[7]);
        assert_eq!(padded, array![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);

        // Test maximum padding
        let arr = array![1.0, 2.0, 3.0];
        let padded = pad_array(&arr, &[(1, 1)], "maximum", None).unwrap();

        assert_eq!(padded.shape(), &[5]);
        assert_eq!(padded, array![3.0, 1.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_get_window() {
        // Test Hamming window
        let window = get_window("hamming", 5, false).unwrap();

        assert_eq!(window.len(), 5);
        assert!(window[0] > 0.0 && window[0] < 0.6); // First value around 0.54
        assert!(window[2] > 0.9); // Middle value close to 1.0

        // Test Hann window
        let window = get_window("hann", 5, false).unwrap();

        assert_eq!(window.len(), 5);
        assert!((window[0] - 0.0).abs() < 1e-10);
        assert!(window[2] > 0.9); // Middle value close to 1.0

        // Test rectangular window
        let window = get_window("rectangular", 5, false).unwrap();

        assert_eq!(window.len(), 5);
        assert!(window.iter().all(|&x| (x - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_differentiate_integrate() {
        // Test differentiation of x^2
        let f = |x: f64| -> Result<f64, String> { Ok(x * x) };

        let derivative = differentiate(3.0, 0.001, f).unwrap();
        assert_relative_eq!(derivative, 6.0, epsilon = 1e-3); // f'(x) = 2x => f'(3) = 6

        // Test integration of x^2 from 0 to 1
        let integral = integrate(0.0, 1.0, 100, f).unwrap();
        assert_relative_eq!(integral, 1.0 / 3.0, epsilon = 1e-5); // ∫x^2 dx = x^3/3 => [0,1] = 1/3

        // Test integration of x^2 from 0 to 2
        let integral = integrate(0.0, 2.0, 100, f).unwrap();
        assert_relative_eq!(integral, 8.0 / 3.0, epsilon = 1e-5); // ∫x^2 dx = x^3/3 => [0,2] = 8/3
    }

    #[test]
    fn test_integrate_reversed_bounds() {
        // Swapping the bounds must negate the integral.
        let f = |x: f64| -> Result<f64, String> { Ok(x * x) };

        let forward = integrate(0.0, 1.0, 100, f).unwrap();
        let backward = integrate(1.0, 0.0, 100, f).unwrap();

        assert_relative_eq!(backward, -forward, epsilon = 1e-12);
        assert_relative_eq!(backward, -1.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_integrate_invalid_intervals() {
        let f = |x: f64| -> Result<f64, String> { Ok(x) };

        // Fewer than two intervals, or an odd count, is rejected.
        assert!(integrate(0.0, 1.0, 0, f).is_err());
        assert!(integrate(0.0, 1.0, 1, f).is_err());
        assert!(integrate(0.0, 1.0, 3, f).is_err());
    }

    #[test]
    fn test_differentiate_propagates_errors() {
        let f = |_x: f64| -> Result<f64, String> { Err("boom".to_string()) };

        // The evaluation error is surfaced to the caller, not swallowed.
        let err = differentiate(1.0, 0.1, f).unwrap_err();
        assert!(err.contains("boom"));
    }
}