Skip to main content

scirs2_core/
utils.rs

1//! Utility functions for numerical operations
2//!
3//! This module provides common utility functions used throughout ``SciRS2``.
4
5use crate::error::{CoreError, CoreResult, ErrorContext};
6use ::ndarray::{Array, Array1, Array2, ArrayBase, Data, Dimension};
7use num_traits::{Float, FromPrimitive, Num, NumCast};
8use std::fmt::Debug;
9
10/// Convert a constant f64 value to a generic float type
11///
12/// # Arguments
13///
14/// * `value` - The f64 value to convert
15///
16/// # Returns
17///
18/// * The value converted to type `F`
19///
20/// # Panics
21///
22/// Panics if the conversion fails
23#[inline]
24pub fn const_f64<F: Float + NumCast>(value: f64) -> F {
25    F::from(value).expect("Failed to convert constant to target float type")
26}
27
28/// Checks if two floating-point values are approximately equal
29///
30/// # Arguments
31///
32/// * `a` - First value
33/// * `b` - Second value
34/// * `abs_tol` - Absolute tolerance
35/// * `reltol` - Relative tolerance
36///
37/// # Returns
38///
39/// * `true` if the values are approximately equal, `false` otherwise
40#[must_use]
41#[allow(dead_code)]
42pub fn is_close<F: Float>(a: F, b: F, abs_tol: F, reltol: F) -> bool {
43    let abs_diff = (a - b).abs();
44
45    if abs_diff <= abs_tol {
46        true
47    } else {
48        let abs_a = a.abs();
49        let abs_b = b.abs();
50        let max_abs = if abs_a > abs_b { abs_a } else { abs_b };
51
52        abs_diff <= max_abs * reltol
53    }
54}
55
56/// Check if two points are equal within a tolerance
57///
58/// Compares each element of the points to determine if they are
59/// approximately equal within a specified tolerance.
60///
61/// # Arguments
62///
63/// * `point1` - First point as a slice
64/// * `point2` - Second point as a slice
65/// * `tol` - Tolerance (default: 1e-8)
66///
67/// # Returns
68///
69/// * True if points are equal within tolerance
70///
71/// # Examples
72///
73/// ```
74/// use scirs2_core::utils::points_equal;
75///
76/// let point1 = [1.0, 2.0, 3.0];
77/// let point2 = [1.0, 2.0, 3.0];
78/// let point3 = [1.0, 2.0, 3.001];
79///
80/// assert!(points_equal(&point1, &point2, None));
81/// assert!(!points_equal(&point1, &point3, None));
82/// assert!(points_equal(&point1, &point3, Some(0.01)));
83/// ```
84#[must_use]
85#[allow(dead_code)]
86pub fn points_equal<T>(point1: &[T], point2: &[T], tol: Option<T>) -> bool
87where
88    T: PartialOrd + std::ops::Sub<Output = T> + Copy + FromPrimitive + num_traits::Zero,
89{
90    // Check for empty arrays first
91    if point1.is_empty() || point2.is_empty() {
92        return point1.is_empty() && point2.is_empty();
93    }
94
95    // Default tolerance as 1e-8 converted to type T
96    let tol = match tol {
97        Some(t) => t,
98        None => match T::from_f64(1e-8) {
99            Some(t) => t,
100            None => {
101                // Fall back to zero tolerance if conversion fails
102                T::from_f64(0.0).unwrap_or_else(|| {
103                    // If even zero conversion fails, use zero trait method
104                    T::zero()
105                })
106            }
107        },
108    };
109
110    point1.len() == point2.len()
111        && point1.iter().zip(point2.iter()).all(|(&a, &b)| {
112            let diff = if a > b { a - b } else { b - a };
113            diff <= tol
114        })
115}
116
117/// Compare arrays within a tolerance
118///
119/// Compares each element of the arrays to determine if they are
120/// approximately equal within the specified tolerance.
121///
122/// # Arguments
123///
124/// * `array1` - First array
125/// * `array2` - Second array
126/// * `tol` - Tolerance (default: 1e-8)
127///
128/// # Returns
129///
130/// * True if arrays are equal within tolerance
131///
132/// # Examples
133///
134/// ```
135/// use scirs2_core::utils::arrays_equal;
136/// use ::ndarray::array;
137///
138/// let arr1 = array![[1.0, 2.0], [3.0, 4.0]];
139/// let arr2 = array![[1.0, 2.0], [3.0, 4.0]];
140/// let arr3 = array![[1.0, 2.0], [3.0, 4.001]];
141///
142/// assert!(arrays_equal(&arr1, &arr2, None));
143/// assert!(!arrays_equal(&arr1, &arr3, None));
144/// assert!(arrays_equal(&arr1, &arr3, Some(0.01)));
145/// ```
146#[must_use]
147#[allow(dead_code)]
148pub fn arrays_equal<S1, S2, D, T>(
149    array1: &ArrayBase<S1, D>,
150    array2: &ArrayBase<S2, D>,
151    tol: Option<T>,
152) -> bool
153where
154    S1: Data<Elem = T>,
155    S2: Data<Elem = T>,
156    D: Dimension,
157    T: PartialOrd + std::ops::Sub<Output = T> + Copy + FromPrimitive + num_traits::Zero,
158{
159    if array1.shape() != array2.shape() {
160        return false;
161    }
162
163    let points1: Vec<T> = array1.iter().copied().collect();
164    let points2: Vec<T> = array2.iter().copied().collect();
165
166    points_equal(&points1, &points2, tol)
167}
168
169/// Fills the diagonal of a matrix with a value
170///
171/// # Arguments
172///
173/// * `mut a` - Matrix to modify
174/// * `val` - Value to set on the diagonal
175///
176/// # Returns
177///
178/// * The modified matrix
179#[must_use]
180#[allow(dead_code)]
181pub fn fill_diagonal<T: Clone>(mut a: Array2<T>, val: T) -> Array2<T> {
182    let min_dim = a.nrows().min(a.ncols());
183
184    for i in 0..min_dim {
185        a[[i, i]] = val.clone();
186    }
187
188    a
189}
190
/// Multiplies together every element produced by `iter`.
///
/// The running product starts at `T::from(1)` and multiplies left to right,
/// so an empty input yields one.
///
/// # Arguments
///
/// * `iter` - Values to multiply
///
/// # Returns
///
/// * The product of all elements (`1` for an empty input)
#[must_use]
#[allow(dead_code)]
pub fn prod<I, T>(iter: I) -> T
where
    I: IntoIterator<Item = T>,
    T: std::ops::Mul<Output = T> + From<u8>,
{
    let mut product = T::from(1u8);
    for factor in iter {
        product = product * factor;
    }
    product
}
209
210/// Creates a range of values with a specified step size
211///
212/// # Arguments
213///
214/// * `start` - Start value (inclusive)
215/// * `stop` - Stop value (exclusive)
216/// * `step` - Step size
217///
218/// # Returns
219///
220/// * Vector of values
221#[allow(dead_code)]
222pub fn arange<F: Float + std::iter::Sum>(start: F, end: F, step: F) -> CoreResult<Vec<F>> {
223    if step == F::zero() {
224        return Err(CoreError::ValueError(ErrorContext::new(
225            "Step size cannot be zero".to_string(),
226        )));
227    }
228
229    let mut result = Vec::new();
230    let mut current = start;
231
232    if step > F::zero() {
233        while current < end {
234            result.push(current);
235            current = current + step;
236        }
237    } else {
238        while current > end {
239            result.push(current);
240            current = current + step;
241        }
242    }
243
244    Ok(result)
245}
246
247/// Convenience function that provides the old behavior (panics on error)
248#[must_use]
249#[allow(dead_code)]
250pub fn arange_unchecked<F: Float + std::iter::Sum>(start: F, end: F, step: F) -> Vec<F> {
251    arange(start, end, step).expect("Operation failed")
252}
253
/// Checks if all elements in an iterable satisfy a predicate
///
/// Short-circuits on the first element for which `predicate` returns
/// `false`; an empty iterable yields `true`. The predicate may be stateful
/// (`FnMut`), matching [`Iterator::all`].
///
/// # Arguments
///
/// * `iter` - Iterable of values
/// * `predicate` - Function to check each value
///
/// # Returns
///
/// * `true` if all elements satisfy the predicate, `false` otherwise
#[must_use]
#[allow(dead_code)]
pub fn all<I, T, F>(iter: I, predicate: F) -> bool
where
    I: IntoIterator<Item = T>,
    F: FnMut(T) -> bool,
{
    iter.into_iter().all(predicate)
}
273
/// Checks if any element in an iterable satisfies a predicate
///
/// Short-circuits on the first element for which `predicate` returns
/// `true`; an empty iterable yields `false`. The predicate may be stateful
/// (`FnMut`), matching [`Iterator::any`].
///
/// # Arguments
///
/// * `iter` - Iterable of values
/// * `predicate` - Function to check each value
///
/// # Returns
///
/// * `true` if any element satisfies the predicate, `false` otherwise
#[must_use]
#[allow(dead_code)]
pub fn any<I, T, F>(iter: I, predicate: F) -> bool
where
    I: IntoIterator<Item = T>,
    F: FnMut(T) -> bool,
{
    iter.into_iter().any(predicate)
}
293
294/// Creates a linearly spaced array between start and end (inclusive)
295///
296/// This function uses parallel processing when available and
297/// appropriate for better performance.
298///
299/// # Arguments
300///
301/// * `start` - Start value
302/// * `end` - End value (inclusive)
303/// * `num` - Number of points
304///
305/// # Returns
306///
307/// * Array of linearly spaced values
308#[must_use]
309#[allow(dead_code)]
310pub fn linspace<F: Float + std::iter::Sum + Send + Sync>(
311    start: F,
312    end: F,
313    num: usize,
314) -> Array1<F> {
315    if num < 2 {
316        return Array::from_vec(vec![start]);
317    }
318
319    // Use parallel implementation for larger arrays
320    #[cfg(feature = "parallel")]
321    {
322        if num >= 1000 {
323            use crate::parallel_ops::*;
324
325            let step = (end - start) / F::from(num - 1).expect("Failed to convert to float");
326            let result: Vec<F> = (0..num)
327                .into_par_iter()
328                .map(|i| {
329                    if i == num - 1 {
330                        // Ensure the last value is exactly end
331                        end
332                    } else {
333                        start + step * F::from(i).expect("Failed to convert to float")
334                    }
335                })
336                .collect::<Vec<F>>();
337
338            // The parallel collection doesn't guarantee order, but par_iter does preserve order
339            // when collecting, so this should be fine
340            return Array::from_vec(result);
341        }
342    }
343
344    // Fall back to standard implementation
345    let step = (end - start) / F::from(num - 1).expect("Failed to convert to float");
346    let mut result = Vec::with_capacity(num);
347
348    for i in 0..num {
349        let value = start + step * F::from(i).expect("Failed to convert to float");
350        result.push(value);
351    }
352
353    // Make sure the last value is exactly end to avoid floating point precision issues
354    if let Some(last) = result.last_mut() {
355        *last = end;
356    }
357
358    Array::from_vec(result)
359}
360
361/// Creates a logarithmically spaced array between base^start and base^end (inclusive)
362///
363/// # Arguments
364///
365/// * `start` - Start exponent
366/// * `end` - End exponent (inclusive)
367/// * `num` - Number of points
368/// * `base` - Base of the logarithm (default: 10.0)
369///
370/// # Returns
371///
372/// * Array of logarithmically spaced values
373#[must_use]
374#[allow(dead_code)]
375pub fn logspace<F: Float + std::iter::Sum + Send + Sync>(
376    start: F,
377    end: F,
378    num: usize,
379    base: Option<F>,
380) -> Array1<F> {
381    let base = base.unwrap_or_else(|| F::from(10.0).expect("Failed to convert constant to float"));
382
383    // Generate linearly spaced values in the exponent space
384    let linear = linspace(start, end, num);
385
386    // Convert to logarithmic space
387    linear.mapv(|x| base.powf(x))
388}
389
390/// Compute the element-wise maximum of two arrays
391///
392/// This function uses parallel processing when available and
393/// appropriate for the input arrays.
394///
395/// # Arguments
396///
397/// * `a` - First array
398/// * `b` - Second array
399///
400/// # Returns
401///
402/// * Element-wise maximum
403///
404/// # Panics
405///
406/// * If the arrays have different shapes
407#[must_use]
408#[allow(dead_code)]
409pub fn maximum<S1, S2, D, T>(
410    a: &crate::ndarray::ArrayBase<S1, D>,
411    b: &crate::ndarray::ArrayBase<S2, D>,
412) -> Array<T, D>
413where
414    S1: crate::ndarray::Data<Elem = T>,
415    S2: crate::ndarray::Data<Elem = T>,
416    D: Dimension,
417    T: Num + PartialOrd + Copy + Send + Sync,
418{
419    assert_eq!(
420        a.shape(),
421        b.shape(),
422        "Arrays must have the same shape for element-wise maximum"
423    );
424
425    // Use parallel implementation for larger arrays
426    #[cfg(feature = "parallel")]
427    {
428        if a.len() > 1000 {
429            use crate::parallel_ops::*;
430
431            // Convert to owned arrays for parallel processing
432            let (a_vec_, _) = a.to_owned().into_raw_vec_and_offset();
433            let (b_vec_, _) = b.to_owned().into_raw_vec_and_offset();
434
435            let result_vec: Vec<T> = a_vec_
436                .into_par_iter()
437                .zip(b_vec_.into_par_iter())
438                .map(|(a_val, b_val)| if b_val > a_val { b_val } else { a_val })
439                .collect();
440
441            return Array::from_shape_vec(a.raw_dim(), result_vec)
442                .expect("Shape mismatch in parallel maximum");
443        }
444    }
445
446    // Fall back to standard implementation
447    let mut result = a.to_owned();
448    for (i, elem) in result.iter_mut().enumerate() {
449        if let Some(b_slice) = b.as_slice() {
450            let b_val = b_slice[i];
451            if b_val > *elem {
452                *elem = b_val;
453            }
454        } else {
455            // Handle case where b cannot be converted to slice
456            let b_val = b.iter().nth(i).expect("Operation failed");
457            if *b_val > *elem {
458                *elem = *b_val;
459            }
460        }
461    }
462
463    result
464}
465
466/// Compute the element-wise minimum of two arrays
467///
468/// This function uses parallel processing when available and
469/// appropriate for the input arrays.
470///
471/// # Arguments
472///
473/// * `a` - First array
474/// * `b` - Second array
475///
476/// # Returns
477///
478/// * Element-wise minimum
479///
480/// # Panics
481///
482/// * If the arrays have different shapes
483#[must_use]
484#[allow(dead_code)]
485pub fn minimum<S1, S2, D, T>(
486    a: &crate::ndarray::ArrayBase<S1, D>,
487    b: &crate::ndarray::ArrayBase<S2, D>,
488) -> Array<T, D>
489where
490    S1: crate::ndarray::Data<Elem = T>,
491    S2: crate::ndarray::Data<Elem = T>,
492    D: Dimension,
493    T: Num + PartialOrd + Copy + Send + Sync,
494{
495    assert_eq!(
496        a.shape(),
497        b.shape(),
498        "Arrays must have the same shape for element-wise minimum"
499    );
500
501    // Use parallel implementation for larger arrays
502    #[cfg(feature = "parallel")]
503    {
504        if a.len() > 1000 {
505            use crate::parallel_ops::*;
506
507            // Convert to owned arrays for parallel processing
508            let (a_vec_, _) = a.to_owned().into_raw_vec_and_offset();
509            let (b_vec_, _) = b.to_owned().into_raw_vec_and_offset();
510
511            let result_vec: Vec<T> = a_vec_
512                .into_par_iter()
513                .zip(b_vec_.into_par_iter())
514                .map(|(a_val, b_val)| if b_val < a_val { b_val } else { a_val })
515                .collect();
516
517            return Array::from_shape_vec(a.raw_dim(), result_vec)
518                .expect("Shape mismatch in parallel minimum");
519        }
520    }
521
522    // Fall back to standard implementation
523    let mut result = a.to_owned();
524    for (i, elem) in result.iter_mut().enumerate() {
525        if let Some(b_slice) = b.as_slice() {
526            let b_val = b_slice[i];
527            if b_val < *elem {
528                *elem = b_val;
529            }
530        } else {
531            // Handle case where b cannot be converted to slice
532            let b_val = b.iter().nth(i).expect("Operation failed");
533            if *b_val < *elem {
534                *elem = *b_val;
535            }
536        }
537    }
538
539    result
540}
541
542/// Normalize a vector to have unit energy or unit peak amplitude.
543///
544/// # Arguments
545///
546/// * `x` - Input vector
547/// * `norm` - Normalization type: energy, "peak", "sum", or "max"
548///
549/// # Returns
550///
551/// * Normalized vector as `Vec<f64>`
552///
553/// # Examples
554///
555/// ```
556/// use scirs2_core::utils::normalize;
557///
558/// // Normalize a vector to unit energy
559/// let signal = vec![1.0, 2.0, 3.0, 4.0];
560/// let normalized = normalize(&signal, "energy").expect("Operation failed");
561///
562/// // Sum of squares should be 1.0
563/// let sum_of_squares: f64 = normalized.iter().map(|&x| x * x).sum();
564/// assert!((sum_of_squares - 1.0).abs() < 1e-10);
565/// ```
566///
567/// # Errors
568///
569/// Returns an error if the input signal is empty, has zero energy/peak/sum, or if a conversion fails.
570#[allow(dead_code)]
571pub fn normalize<T>(x: &[T], norm: &str) -> Result<Vec<f64>, &'static str>
572where
573    T: Float + NumCast + Debug,
574{
575    if x.is_empty() {
576        return Err("Input signal is empty");
577    }
578
579    // Convert to f64 for internal processing
580    let x_f64: Vec<f64> = x
581        .iter()
582        .map(|&val| NumCast::from(val).ok_or("Could not convert value to f64"))
583        .collect::<Result<Vec<_>, _>>()?;
584
585    // Normalize based on type
586    match norm.to_lowercase().as_str() {
587        "energy" => {
588            // Normalize to unit energy (sum of squares = 1.0)
589            let sum_of_squares: f64 = x_f64.iter().map(|&x| x * x).sum();
590
591            if sum_of_squares.abs() < f64::EPSILON {
592                return Err("Signal has zero energy, cannot normalize");
593            }
594
595            let scale = 1.0 / sum_of_squares.sqrt();
596            let normalized = x_f64.iter().map(|&x| x * scale).collect();
597
598            Ok(normalized)
599        }
600        "peak" => {
601            // Normalize to unit peak amplitude (max absolute value = 1.0)
602            let peak = x_f64.iter().fold(0.0, |a, &b| a.max(b.abs()));
603
604            if peak.abs() < f64::EPSILON {
605                return Err("Signal has zero peak, cannot normalize");
606            }
607
608            let scale = 1.0 / peak;
609            let normalized = x_f64.iter().map(|&x| x * scale).collect();
610
611            Ok(normalized)
612        }
613        "sum" => {
614            // Normalize to unit sum
615            let sum: f64 = x_f64.iter().sum();
616
617            if sum.abs() < f64::EPSILON {
618                return Err("Signal has zero sum, cannot normalize");
619            }
620
621            let scale = 1.0 / sum;
622            let normalized = x_f64.iter().map(|&x| x * scale).collect();
623
624            Ok(normalized)
625        }
626        "max" => {
627            // Normalize to max value = 1.0 (preserves sign)
628            let max_val = x_f64.iter().fold(0.0, |a, &b| a.max(b.abs()));
629
630            if max_val.abs() < f64::EPSILON {
631                return Err("Signal has zero maximum, cannot normalize");
632            }
633
634            let scale = 1.0 / max_val;
635            let normalized = x_f64.iter().map(|&x| x * scale).collect();
636
637            Ok(normalized)
638        }
639        _ => Err("Unknown normalization type. Supported types: 'energy', 'peak', 'sum', 'max'"),
640    }
641}
642
643/// Pad an array with values according to the specified mode.
644///
645/// # Arguments
646///
647/// * `input` - Input array
648/// * `pad_width` - Width of padding in each dimension (before, after)
649/// * `mode` - Padding mode: constant, "edge", "linear_ramp", "maximum", "mean", "median", "minimum", "reflect", "symmetric", "wrap"
650/// * `constant_value` - Value to use for constant padding (only used for "constant" mode)
651///
652/// # Returns
653///
654/// * Padded array
655///
656/// # Examples
657///
658/// ```
659/// // Example with a 1D array
660/// use scirs2_core::utils::pad_array;
661/// use ndarray::{Array1, array};
662///
663/// let arr = array![1.0, 2.0, 3.0];
664/// let padded = pad_array(&arr, &[(1, 2)], "constant", Some(0.0)).expect("Operation failed");
665/// assert_eq!(padded.shape(), &[6]);
666/// assert_eq!(padded, array![0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
667/// ```
668///
669/// # Errors
670///
671/// Returns an error if the input array is 0-dimensional, if pad_width length doesn't match input dimensions,
672/// or if the padding mode is unsupported for the given array dimensionality.
673#[allow(dead_code)]
674pub fn pad_array<T, D>(
675    input: &Array<T, D>,
676    pad_width: &[(usize, usize)],
677    mode: &str,
678    constant_value: Option<T>,
679) -> Result<Array<T, D>, String>
680where
681    T: Float + FromPrimitive + Debug + Clone,
682    D: Dimension,
683{
684    // Validate inputs
685    if input.ndim() == 0 {
686        return Err("Input array cannot be 0-dimensional".to_string());
687    }
688
689    if pad_width.len() != input.ndim() {
690        return Err(format!(
691            "Pad _width must have same length as input dimensions (got {} expected {})",
692            pad_width.len(),
693            input.ndim()
694        ));
695    }
696
697    // No padding needed - return copy of input
698    if pad_width.iter().all(|&(a, b)| a == 0 && b == 0) {
699        return Ok(input.to_owned());
700    }
701
702    // Calculate new shape
703    let mut newshape = Vec::with_capacity(input.ndim());
704    for (dim, &(pad_before, pad_after)) in pad_width.iter().enumerate().take(input.ndim()) {
705        newshape.push(input.shape()[dim] + pad_before + pad_after);
706    }
707
708    // Create output array with default constant value
709    let const_val = constant_value.unwrap_or_else(|| T::zero());
710    let mut output = Array::<T, D>::from_elem(
711        D::from_dimension(&crate::ndarray::IxDyn(&newshape))
712            .expect("Could not create dimension from shape"),
713        const_val,
714    );
715
716    // For 1D arrays
717    if input.ndim() == 1 {
718        // Convert to Array1 for easier manipulation
719        let inputarray1 = input
720            .view()
721            .into_dimensionality::<crate::ndarray::Ix1>()
722            .map_err(|_| "Failed to convert to 1D array".to_string())?;
723        let mut output_array1 = output
724            .view_mut()
725            .into_dimensionality::<crate::ndarray::Ix1>()
726            .map_err(|_| "Failed to convert output to 1D array".to_string())?;
727
728        let input_len = inputarray1.len();
729        let start = pad_width[0].0;
730
731        // First copy the input to the center region
732        for i in 0..input_len {
733            output_array1[start + i] = inputarray1[i];
734        }
735
736        // Then pad the borders based on the mode
737        match mode.to_lowercase().as_str() {
738            "constant" => {
739                // Already filled with constant value
740            }
741            "edge" => {
742                // Pad left side with first value
743                for i in 0..pad_width[0].0 {
744                    output_array1[i] = inputarray1[0];
745                }
746                // Pad right side with last value
747                let offset = start + input_len;
748                for i in 0..pad_width[0].1 {
749                    output_array1[offset + i] = inputarray1[input_len - 1];
750                }
751            }
752            "reflect" => {
753                // Pad left side
754                for i in 0..pad_width[0].0 {
755                    let src_idx = pad_width[0].0 - i;
756                    if src_idx < input_len {
757                        output_array1[i] = inputarray1[src_idx];
758                    }
759                }
760                // Pad right side
761                let offset = start + input_len;
762                for i in 0..pad_width[0].1 {
763                    let src_idx = input_len - 2 - i;
764                    if src_idx < input_len {
765                        output_array1[offset + i] = inputarray1[src_idx];
766                    }
767                }
768            }
769            "wrap" => {
770                // Pad left side
771                for i in 0..pad_width[0].0 {
772                    let src_idx = (input_len - (pad_width[0].0 - i) % input_len) % input_len;
773                    output_array1[i] = inputarray1[src_idx];
774                }
775                // Pad right side
776                let offset = start + input_len;
777                for i in 0..pad_width[0].1 {
778                    let src_idx = i % input_len;
779                    output_array1[offset + i] = inputarray1[src_idx];
780                }
781            }
782            "maximum" => {
783                // Find maximum value
784                let max_val = inputarray1.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
785
786                // Pad with maximum value
787                for i in 0..pad_width[0].0 {
788                    output_array1[i] = max_val;
789                }
790                let offset = start + input_len;
791                for i in 0..pad_width[0].1 {
792                    output_array1[offset + i] = max_val;
793                }
794            }
795            "minimum" => {
796                // Find minimum value
797                let min_val = inputarray1.iter().fold(T::infinity(), |a, &b| a.min(b));
798
799                // Pad with minimum value
800                for i in 0..pad_width[0].0 {
801                    output_array1[i] = min_val;
802                }
803                let offset = start + input_len;
804                for i in 0..pad_width[0].1 {
805                    output_array1[offset + i] = min_val;
806                }
807            }
808            "mean" => {
809                // Calculate mean value
810                let sum = inputarray1.iter().fold(T::zero(), |a, &b| a + b);
811                let mean_val = sum / T::from_usize(input_len).expect("Operation failed");
812
813                // Pad with mean value
814                for i in 0..pad_width[0].0 {
815                    output_array1[i] = mean_val;
816                }
817                let offset = start + input_len;
818                for i in 0..pad_width[0].1 {
819                    output_array1[offset + i] = mean_val;
820                }
821            }
822            _ => return Err(format!("Unsupported padding mode: {mode}")),
823        }
824
825        return Ok(output);
826    }
827
828    // For 2D arrays, we could add specific implementation similar to the above
829    // For now, we'll just return the output with constant padding for 2D and higher
830
831    if mode.to_lowercase() != "constant" {
832        return Err(format!(
833            "Padding mode '{mode}' is not yet implemented for arrays with more than 1 dimension"
834        ));
835    }
836
837    // For higher dimensions, we'll just return a more simplified implementation with
838    // constant padding only for now, and a note that it needs more work
839
840    // We've already created the padded array with constant values,
841    // now just return it since other padding modes for higher dimensions
842    // would require more complex implementation
843    //
844    // NOTE: This is a placeholder implementation that needs to be improved
845    // in the future to support all padding modes for higher dimensions
846
847    Ok(output)
848}
849
/// Create window functions of various types.
///
/// Windows follow SciPy's definitions. A symmetric window has `length` points.
/// A periodic window (`periodic = true`, the usual choice for spectral
/// analysis) is the first `length` points of a symmetric window of
/// `length + 1` points. A window of length 1 is always `[1.0]`.
///
/// # Arguments
///
/// * `window_type` - Type of window function (case-insensitive): "hamming",
///   "hanning"/"hann", "blackman", "bartlett", "boxcar"/"rectangular", "triang"
/// * `length` - Length of the window
/// * `periodic` - Whether the window should be periodic
///
/// # Returns
///
/// * Window function values as a vector of `length` elements
///
/// # Errors
///
/// Returns an error if the window length is zero or if the window type is unknown.
#[allow(dead_code)]
pub fn generate_window(
    window_type: &str,
    length: usize,
    periodic: bool,
) -> Result<Vec<f64>, String> {
    if length == 0 {
        return Err("Window length must be positive".to_string());
    }

    let kind = window_type.to_lowercase();
    let known = matches!(
        kind.as_str(),
        "hamming"
            | "hanning"
            | "hann"
            | "blackman"
            | "bartlett"
            | "boxcar"
            | "rectangular"
            | "triang"
    );
    if !known {
        return Err(format!("Unknown window type: {window_type}"));
    }

    // The formulas divide by (n - 1); a single point would give 0/0 = NaN.
    if length == 1 {
        return Ok(vec![1.0]);
    }

    // Size of the symmetric window the samples are taken from.
    let n = if periodic { length + 1 } else { length };
    let denom = (n - 1) as f64;
    // Centre of the symmetric window.
    let m = denom / 2.0;
    // Half-width used by `triang`: (n + 1) / 2 for odd n, n / 2 for even n.
    let triang_half = if n % 2 == 1 {
        (n + 1) as f64 / 2.0
    } else {
        n as f64 / 2.0
    };

    let window = (0..length)
        .map(|i| {
            let x = i as f64;
            match kind.as_str() {
                // 0.54 - 0.46 * cos(2πi/(N-1))
                "hamming" => 0.54 - 0.46 * (2.0 * std::f64::consts::PI * x / denom).cos(),
                // 0.5 * (1 - cos(2πi/(N-1)))
                "hanning" | "hann" => {
                    0.5 * (1.0 - (2.0 * std::f64::consts::PI * x / denom).cos())
                }
                // 0.42 - 0.5 * cos(2πi/(N-1)) + 0.08 * cos(4πi/(N-1))
                "blackman" => {
                    0.42 - 0.5 * (2.0 * std::f64::consts::PI * x / denom).cos()
                        + 0.08 * (4.0 * std::f64::consts::PI * x / denom).cos()
                }
                // Triangle reaching zero at both ends
                "bartlett" => 1.0 - ((x - m) / m).abs(),
                // Triangle that does not reach zero at the ends
                "triang" => 1.0 - ((x - m) / triang_half).abs(),
                // "boxcar" / "rectangular" (validated above)
                _ => 1.0,
            }
        })
        .collect();

    Ok(window)
}
933
934/// Get window function compatible with SciPy API
935///
936/// This is a wrapper around `generate_window` that returns `CoreResult` for
937/// consistency with other SciRS2 functions.
938///
939/// # Arguments
940///
941/// * `window_type` - Type of window function ("hamming", "hann", "rectangular", etc.)
942/// * `length` - Length of the window
943/// * `periodic` - Whether the window should be periodic (default: false)
944///
945/// # Returns
946///
947/// * Window function values as a vector wrapped in `CoreResult`
948///
949/// # Examples
950///
951/// ```rust
952/// use scirs2_core::utils::get_window;
953///
954/// let hamming = get_window("hamming", 5, false).expect("Operation failed");
955/// assert_eq!(hamming.len(), 5);
956/// ```
957#[allow(dead_code)]
958pub fn get_window(window_type: &str, length: usize, periodic: bool) -> CoreResult<Vec<f64>> {
959    generate_window(window_type, length, periodic)
960        .map_err(|e| CoreError::ValueError(ErrorContext::new(e)))
961}
962
963/// Differentiate a function using central difference method.
964///
965/// # Arguments
966///
967/// * `x` - Point at which to evaluate the derivative
968/// * `h` - Step size for the finite difference
969/// * `eval_fn` - Function that evaluates the function at a point
970///
971/// # Returns
972///
973/// * Derivative of the function at x
974///
975/// # Examples
976///
977/// ```
978/// use scirs2_core::utils::differentiate;
979///
980/// // Differentiate f(x) = x^2 at x = 3
981/// let f = |x: f64| -> Result<f64, String> { Ok(x * x) };
982/// let derivative = differentiate(3.0, 0.001, f).expect("Operation failed");
983///
984/// // The exact derivative is 2x = 6 at x = 3
985/// assert!((derivative - 6.0).abs() < 1e-5);
986/// ```
987///
988/// # Errors
989///
990/// Returns an error if the evaluation function fails at either x+h or x-h.
991#[allow(dead_code)]
992pub fn differentiate<F, Func>(x: F, h: F, evalfn: Func) -> Result<F, String>
993where
994    F: Float + FromPrimitive + Debug,
995    Func: Fn(F) -> Result<F, String>,
996{
997    // Use central difference for better accuracy
998    let f_plus = evalfn(x + h).map_err(|e| format!("Error evaluating function at x+h: {e}"))?;
999    let f_minus = evalfn(x - h).map_err(|e| format!("Error evaluating function at x-h: {e}"))?;
1000    let derivative =
1001        (f_plus - f_minus) / (F::from(2.0).expect("Failed to convert constant to float") * h);
1002    Ok(derivative)
1003}
1004
1005/// Integrate a function using composite Simpson's rule.
1006///
1007/// # Arguments
1008///
1009/// * `a` - Lower bound of integration
1010/// * `b` - Upper bound of integration
1011/// * `n` - Number of intervals for the quadrature (must be even)
1012/// * `eval_fn` - Function that evaluates the function at a point
1013///
1014/// # Returns
1015///
1016/// * Definite integral of the function from a to b
1017///
1018/// # Examples
1019///
1020/// ```
1021/// use scirs2_core::utils::integrate;
1022///
1023/// // Integrate f(x) = x^2 from 0 to 1
1024/// let f = |x: f64| -> Result<f64, String> { Ok(x * x) };
1025/// let integral = integrate(0.0, 1.0, 100, f).expect("Operation failed");
1026///
1027/// // The exact integral is x^3/3 = 1/3 from 0 to 1
1028/// assert!((integral - 1.0/3.0).abs() < 1e-5);
1029/// ```
1030///
1031/// # Errors
1032///
1033/// Returns an error if the number of intervals is less than 2, not even, or if the evaluation function fails.
1034#[allow(dead_code)]
1035pub fn integrate<F, Func>(a: F, b: F, n: usize, evalfn: Func) -> Result<F, String>
1036where
1037    F: Float + FromPrimitive + Debug,
1038    Func: Fn(F) -> Result<F, String>,
1039{
1040    if a > b {
1041        return integrate(b, a, n, evalfn).map(|result| -result);
1042    }
1043
1044    // Use composite Simpson's rule for integration
1045    if n < 2 {
1046        return Err("number of intervals must be at least 2".to_string());
1047    }
1048
1049    if n % 2 != 0 {
1050        return Err("number of intervals must be even".to_string());
1051    }
1052
1053    let h = (b - a) / F::from_usize(n).expect("Operation failed");
1054    let mut sum = evalfn(a).map_err(|e| format!("Error evaluating function at a: {e}"))?
1055        + evalfn(b).map_err(|e| format!("Error evaluating function at b: {e}"))?;
1056
1057    // Even-indexed points (except endpoints)
1058    for i in 1..n {
1059        if i % 2 == 0 {
1060            let x_i = a + F::from_usize(i).expect("Operation failed") * h;
1061            sum = sum
1062                + F::from(2.0).expect("Failed to convert constant to float")
1063                    * evalfn(x_i)
1064                        .map_err(|e| format!("Error evaluating function at x_{i}: {e}"))?;
1065        }
1066    }
1067
1068    // Odd-indexed points
1069    for i in 1..n {
1070        if i % 2 == 1 {
1071            let x_i = a + F::from_usize(i).expect("Operation failed") * h;
1072            sum = sum
1073                + F::from(4.0).expect("Failed to convert constant to float")
1074                    * evalfn(x_i)
1075                        .map_err(|e| format!("Error evaluating function at x_{i}: {e}"))?;
1076        }
1077    }
1078
1079    let integral = h * sum / F::from(3.0).expect("Failed to convert constant to float");
1080    Ok(integral)
1081}
1082
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::{arr1, array};

    #[test]
    fn test_is_close() {
        assert!(is_close(1.0, 1.0, 1e-10, 1e-10));
        assert!(is_close(1.0, 1.0 + 1e-11, 1e-10, 1e-10));
        assert!(!is_close(1.0, 1.1, 1e-10, 1e-10));
        assert!(is_close(1e10, 1e10 + 1.0, 1e-10, 1e-9));
    }

    #[test]
    fn test_fill_diagonal() {
        let a = Array2::<f64>::zeros((3, 3));
        let a_diag = fill_diagonal(a, 5.0);

        assert_relative_eq!(a_diag[[0, 0]], 5.0);
        assert_relative_eq!(a_diag[[1, 1]], 5.0);
        assert_relative_eq!(a_diag[[2, 2]], 5.0);
        assert_relative_eq!(a_diag[[0, 1]], 0.0);
    }

    #[test]
    fn test_prod() {
        assert_eq!(prod(vec![1, 2, 3, 4]), 24);
        assert_eq!(prod(vec![2.0, 3.0, 4.0]), 24.0);
    }

    #[test]
    fn test_arange() {
        let result = arange(0.0, 5.0, 1.0).expect("Operation failed");
        assert_eq!(result, vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let result = arange(5.0, 0.0, -1.0).expect("Operation failed");
        assert_eq!(result, vec![5.0, 4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_arange_zero_step() {
        let result = arange(0.0, 5.0, 0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_all() {
        assert!(all(vec![2, 4, 6, 8], |x| x % 2 == 0));
        assert!(!all(vec![2, 4, 5, 8], |x| x % 2 == 0));
    }

    #[test]
    fn test_any() {
        assert!(any(vec![1, 2, 3, 4], |x| x % 2 == 0));
        assert!(!any(vec![1, 3, 5, 7], |x| x % 2 == 0));
    }

    #[test]
    fn test_linspace() {
        let result = linspace(0.0, 1.0, 5);
        let expected = arr1(&[0.0, 0.25, 0.5, 0.75, 1.0]);
        assert_eq!(result.len(), 5);

        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-14);
        }

        // Test with single value
        let result = linspace(5.0, 5.0, 1);
        assert_eq!(result.len(), 1);
        assert_relative_eq!(result[0], 5.0);

        // Test endpoints
        let result = linspace(-10.0, 10.0, 5);
        assert_relative_eq!(result[0], -10.0);
        assert_relative_eq!(result[4], 10.0);
    }

    #[test]
    fn test_logspace() {
        let result = logspace(0.0, 3.0, 4, None);
        let expected = arr1(&[1.0, 10.0, 100.0, 1000.0]);
        assert_eq!(result.len(), 4);

        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }

        // Test with custom base
        let result = logspace(0.0, 3.0, 4, Some(2.0));
        let expected = arr1(&[1.0, 2.0, 4.0, 8.0]);

        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_maximum() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 1], [7, 2]];

        let result = maximum(&a, &b);
        let expected = array![[5, 2], [7, 4]];

        assert_eq!(result, expected);
    }

    #[test]
    fn test_minimum() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 1], [7, 2]];

        let result = minimum(&a, &b);
        let expected = array![[1, 1], [3, 2]];

        assert_eq!(result, expected);
    }

    #[test]
    #[should_panic]
    fn test_maximum_differentshapes() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 1, 2], [7, 2, 3]];

        // Only the panic matters; the result is intentionally discarded.
        let _ = maximum(&a, &b);
    }

    #[test]
    fn test_points_equal() {
        let point1 = [1.0, 2.0, 3.0];
        let point2 = [1.0, 2.0, 3.0];
        let point3 = [1.0, 2.0, 3.001];
        let point4 = [1.0, 2.0, 4.0];

        assert!(points_equal(&point1, &point2, None));
        assert!(!points_equal(&point1, &point3, None));
        assert!(points_equal(&point1, &point3, Some(0.01)));
        assert!(!points_equal(&point1, &point4, Some(0.01)));
    }

    #[test]
    fn test_arrays_equal() {
        let arr1 = array![[1.0, 2.0], [3.0, 4.0]];
        let arr2 = array![[1.0, 2.0], [3.0, 4.0]];
        let arr3 = array![[1.0, 2.0], [3.0, 4.001]];

        assert!(arrays_equal(&arr1, &arr2, None));
        assert!(!arrays_equal(&arr1, &arr3, None));
        assert!(arrays_equal(&arr1, &arr3, Some(0.01)));
    }

    #[test]
    fn test_normalize() {
        // Test energy normalization
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let normalized = normalize(&signal, "energy").expect("Operation failed");

        // Sum of squares should be 1.0
        let sum_of_squares: f64 = normalized.iter().map(|&x| x * x).sum();
        assert_relative_eq!(sum_of_squares, 1.0, epsilon = 1e-10);

        // Test peak normalization
        let signal = vec![1.0, -2.0, 3.0, -4.0];
        let normalized = normalize(&signal, "peak").expect("Operation failed");

        // Max absolute value should be 1.0
        let peak = normalized.iter().fold(0.0, |a, &b| a.max(b.abs()));
        assert_relative_eq!(peak, 1.0, epsilon = 1e-10);
        assert_relative_eq!(normalized[3], -1.0, epsilon = 1e-10);

        // Test sum normalization
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let normalized = normalize(&signal, "sum").expect("Operation failed");

        // Sum should be 1.0
        let sum: f64 = normalized.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pad_array() {
        // Test constant padding on 1D array
        let arr = array![1.0, 2.0, 3.0];
        let padded = pad_array(&arr, &[(1, 2)], "constant", Some(0.0)).expect("Operation failed");

        assert_eq!(padded.shape(), &[6]);
        assert_eq!(padded, array![0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);

        // Test edge padding
        let arr = array![1.0, 2.0, 3.0];
        let padded = pad_array(&arr, &[(2, 2)], "edge", None).expect("Operation failed");

        assert_eq!(padded.shape(), &[7]);
        assert_eq!(padded, array![1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0]);

        // Test maximum padding
        let arr = array![1.0, 2.0, 3.0];
        let padded = pad_array(&arr, &[(1, 1)], "maximum", None).expect("Operation failed");

        assert_eq!(padded.shape(), &[5]);
        assert_eq!(padded, array![3.0, 1.0, 2.0, 3.0, 3.0]);
    }

    #[test]
    fn test_get_window() {
        // Test Hamming window
        let window = get_window("hamming", 5, false).expect("Operation failed");

        assert_eq!(window.len(), 5);
        // Endpoints of a symmetric Hamming window are small but non-zero
        // (0.54 - 0.46 = 0.08 for the standard formula).
        assert!(window[0] > 0.0 && window[0] < 0.6);
        assert!(window[2] > 0.9); // Middle value close to 1.0

        // Test Hann window
        let window = get_window("hann", 5, false).expect("Operation failed");

        assert_eq!(window.len(), 5);
        assert!((window[0] - 0.0).abs() < 1e-10);
        assert!(window[2] > 0.9); // Middle value close to 1.0

        // Test rectangular window
        let window = get_window("rectangular", 5, false).expect("Operation failed");

        assert_eq!(window.len(), 5);
        assert!(window.iter().all(|&x| (x - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_differentiate_integrate() {
        // Test differentiation of x^2
        let f = |x: f64| -> Result<f64, String> { Ok(x * x) };

        let derivative = differentiate(3.0, 0.001, f).expect("Operation failed");
        assert_relative_eq!(derivative, 6.0, epsilon = 1e-3); // f'(x) = 2x => f'(3) = 6

        // Test integration of x^2 from 0 to 1
        let integral = integrate(0.0, 1.0, 100, f).expect("Operation failed");
        assert_relative_eq!(integral, 1.0 / 3.0, epsilon = 1e-5); // ∫x^2 dx = x^3/3 => [0,1] = 1/3

        // Test integration of x^2 from 0 to 2
        let integral = integrate(0.0, 2.0, 100, f).expect("Operation failed");
        assert_relative_eq!(integral, 8.0 / 3.0, epsilon = 1e-5); // ∫x^2 dx = x^3/3 => [0,2] = 8/3
    }

    #[test]
    fn test_integrate_edge_cases() {
        let square = |x: f64| -> Result<f64, String> { Ok(x * x) };
        let cube = |x: f64| -> Result<f64, String> { Ok(x * x * x) };

        // Swapping the bounds negates the integral.
        let reversed = integrate(1.0, 0.0, 100, square).expect("Operation failed");
        assert_relative_eq!(reversed, -1.0 / 3.0, epsilon = 1e-5);

        // Simpson's rule is exact for cubics (up to rounding): ∫_0^2 x^3 dx = 4.
        let integral = integrate(0.0, 2.0, 4, cube).expect("Operation failed");
        assert_relative_eq!(integral, 4.0, epsilon = 1e-10);

        // Interval counts that are too small or odd are rejected.
        assert!(integrate(0.0, 1.0, 0, square).is_err());
        assert!(integrate(0.0, 1.0, 1, square).is_err());
        assert!(integrate(0.0, 1.0, 3, square).is_err());
    }

    #[test]
    fn test_differentiate_propagates_error() {
        let failing = |_x: f64| -> Result<f64, String> { Err("boom".to_string()) };

        let err = differentiate(1.0, 0.1, failing)
            .expect_err("evaluation error must propagate");
        assert!(err.contains("boom"));
    }
}