Skip to main content

scirs2_special/
utility.rs

1//! Utility and convenience functions with mathematical foundations
2//!
3//! This module provides various utility functions commonly used in scientific
4//! computing, with detailed mathematical theory, proofs, and numerical analysis.
5//!
6//! ## Mathematical Theory and Foundations
7//!
8//! ### Elementary Functions with Special Properties
9//!
10//! This module contains fundamental mathematical functions that serve as building
11//! blocks for more complex special functions. Each function is implemented with
12//! careful attention to numerical stability and mathematical rigor.
13//!
14//! ### The Cube Root Function
15//!
16//! **Mathematical Definition**: For x ∈ ℝ, the cube root is defined as:
17//! ```text
18//! ∛x = x^(1/3) = exp(ln|x|/3) · sign(x)
19//! ```
20//!
21//! **Properties**:
22//! 1. **Domain and Range**: ∛: ℝ → ℝ (unlike square root, defined for all reals)
23//! 2. **Odd Function**: ∛(-x) = -∛x
//!    - **Proof**: (-∛x)³ = -(∛x)³ = -x, so -∛x is the (unique) real cube root of -x, i.e. ∛(-x) = -∛x
//! 3. **Strictly Increasing**: d/dx[∛x] = (1/3)x^(-2/3) > 0 for every x ≠ 0 (the tangent is vertical at x = 0)
26//! 4. **Inverse of Cubing**: (∛x)³ = x for all x ∈ ℝ
27//!
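//! **Numerical Implementation**: A power-based evaluation x^(1/3) is only real for
//! non-negative bases, so it must use the odd symmetry:
//! - For x ≥ 0: ∛x = x^(1/3)
//! - For x < 0: ∛x = -(-x)^(1/3)
//!
//! A dedicated cube-root routine is preferable to `powf(1.0 / 3.0)`: 1/3 is not
//! exactly representable, so the power form can miss exact cubes.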
31//!
32//! ### Exponential Functions with Different Bases
33//!
34//! **Base-10 Exponential (exp10)**:
35//! ```text
36//! exp10(x) = 10^x = e^(x·ln(10))
37//! ```
38//!
39//! **Mathematical Properties**:
40//! 1. **Exponential Law**: 10^(x+y) = 10^x · 10^y
41//! 2. **Inverse of log₁₀**: exp10(log₁₀(x)) = x for x > 0
42//! 3. **Derivative**: d/dx[10^x] = 10^x · ln(10)
43//! 4. **Growth Rate**: 10^x grows faster than any polynomial
44//!
45//! **Base-2 Exponential (exp2)**:
46//! ```text
47//! exp2(x) = 2^x = e^(x·ln(2))
48//! ```
49//!
50//! **Computer Science Applications**:
51//! - Binary representation: 2^n gives powers of 2
52//! - Information theory: 2^H(X) relates to entropy
53//! - Algorithm complexity: Many algorithms have 2^n complexity
54//!
55//! ### Trigonometric Functions in Degrees
56//!
57//! **Degree-Radian Conversion**:
58//! ```text
59//! radians = degrees × π/180
60//! degrees = radians × 180/π
61//! ```
62//!
63//! **Mathematical Justification**: A full circle contains 2π radians = 360°,
64//! establishing the conversion factor π/180.
65//!
66//! **Degree-based Trigonometric Functions**:
67//! - sindg(x) = sin(x × π/180)
68//! - cosdg(x) = cos(x × π/180)  
69//! - tandg(x) = tan(x × π/180)
70//! - cotdg(x) = cot(x × π/180) = 1/tan(x × π/180)
71//!
72//! ### Special Numerical Functions
73//!
74//! **exprel(x) = (e^x - 1)/x**:
75//! - **Purpose**: Numerically stable computation of (e^x - 1)/x near x = 0
76//! - **Taylor Series**: exprel(x) = 1 + x/2 + x²/6 + x³/24 + ... = Σ_{n=0}^∞ x^n/(n+1)!
77//! - **Limit**: lim_{x→0} exprel(x) = 1 (removable singularity)
78//! - **Applications**: Actuarial calculations, queuing theory
79//!
80//! **cosm1(x) = cos(x) - 1**:
81//! - **Purpose**: Accurate computation of cos(x) - 1 for small |x|
//! - **Series**: cosm1(x) = -x²/2 + x⁴/24 - x⁶/720 + ... = Σ_{n=1}^∞ (-1)^n x^(2n)/(2n)!
//! - **Identity**: cosm1(x) = -2·sin²(x/2), which involves no subtraction at all
83//! - **Numerical Advantage**: Avoids catastrophic cancellation when cos(x) ≈ 1
84//!
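//! **powm1(x, y) = (1 + x)^y - 1**:
//! - **Implementation**: For 1 + x > 0, powm1(x, y) = exp(y·ln(1 + x)) - 1 = expm1(y·log1p(x)),
//!   which stays accurate when x is small (no rounding of the base 1 + x) and when
//!   |y·log1p(x)| is small (no cancellation in the final subtraction)
//! - **Numerical Stability**: Avoids precision loss when (1 + x)^y ≈ 1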
88//!
89//! ### Advanced Utility Functions
90//!
91//! **Dirichlet Kernel (diric)**:
92//! ```text
//! diric(x, n) = sin(nx/2) / (n·sin(x/2))   for x ≠ 2πk
//! diric(2πk, n) = (-1)^(k(n-1))            (the limiting value; in particular diric(0, n) = 1)
95//! ```
96//!
97//! **Properties**:
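//! 1. **Periodicity**: diric(x + 2π, n) = (-1)^(n-1) · diric(x, n), so the period is 2π
//!    for odd n and 4π for even n
//! 2. **Normalization**: diric(0, n) = 1, and for odd n, ∫_{-π}^π diric(x, n) dx = 2π/n
//! 3. **Fourier Connection**: n·diric(ω, n) is, up to the phase factor e^{-iω(n-1)/2}, the
//!    discrete-time Fourier transform of a length-n rectangular window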
101//!
102//! **Owen's T Function**:
103//! ```text
104//! T(h, a) = (1/2π) ∫₀^a exp(-h²(1+t²)/2) / (1+t²) dt
105//! ```
106//!
107//! **Applications**:
108//! - Bivariate normal distribution calculations
109//! - Statistical hypothesis testing
110//! - Error probability computations
111//!
112//! ### Numerical Stability Considerations
113//!
114//! All functions in this module are implemented with careful attention to:
115//!
116//! 1. **Overflow/Underflow Prevention**: Using appropriate scaling and range reduction
117//! 2. **Catastrophic Cancellation Avoidance**: Special algorithms for near-zero differences
118//! 3. **Precision Preservation**: Maintaining accuracy across the full range of inputs
119//! 4. **Edge Case Handling**: Proper behavior at singularities and boundary conditions
120
121use crate::error::{SpecialError, SpecialResult};
122use crate::validation::check_finite;
123use scirs2_core::ndarray::{Array1, ArrayView1};
124use scirs2_core::numeric::{Float, FromPrimitive, Zero};
125use std::fmt::{Debug, Display};
126
127/// Cube root function
128///
129/// Computes the real cube root of x, handling negative values correctly.
130///
131/// # Arguments
132/// * `x` - Input value
133///
134/// # Returns
135/// The cube root of x
136///
137/// # Examples
138/// ```
139/// use scirs2_special::utility::cbrt;
140///
141/// assert_eq!(cbrt(8.0), 2.0);
142/// assert_eq!(cbrt(-8.0), -2.0);
143/// ```
144#[allow(dead_code)]
145pub fn cbrt<T>(x: T) -> T
146where
147    T: Float + FromPrimitive,
148{
149    if x >= T::zero() {
150        x.powf(T::from_f64(1.0 / 3.0).expect("Operation failed"))
151    } else {
152        -(-x).powf(T::from_f64(1.0 / 3.0).expect("Operation failed"))
153    }
154}
155
156/// Base-10 exponential function
157///
158/// Computes 10^x.
159///
160/// # Arguments
161/// * `x` - Exponent
162///
163/// # Returns
164/// 10 raised to the power x
165#[allow(dead_code)]
166pub fn exp10<T>(x: T) -> T
167where
168    T: Float + FromPrimitive,
169{
170    T::from_f64(10.0).expect("Operation failed").powf(x)
171}
172
173/// Base-2 exponential function
174///
175/// Computes 2^x.
176///
177/// # Arguments
178/// * `x` - Exponent
179///
180/// # Returns
181/// 2 raised to the power x
182#[allow(dead_code)]
183pub fn exp2<T>(x: T) -> T
184where
185    T: Float,
186{
187    x.exp2()
188}
189
190/// Convert degrees to radians
191///
192/// # Arguments
193/// * `degrees` - Angle in degrees
194///
195/// # Returns
196/// Angle in radians
197#[allow(dead_code)]
198pub fn radian<T>(degrees: T) -> T
199where
200    T: Float + FromPrimitive,
201{
202    let pi = T::from_f64(std::f64::consts::PI).expect("Operation failed");
203    degrees * pi / T::from_f64(180.0).expect("Operation failed")
204}
205
206/// Cosine of angle in degrees
207///
208/// # Arguments
209/// * `x` - Angle in degrees
210///
211/// # Returns
212/// cos(x) where x is in degrees
213#[allow(dead_code)]
214pub fn cosdg<T>(x: T) -> T
215where
216    T: Float + FromPrimitive,
217{
218    radian(x).cos()
219}
220
221/// Sine of angle in degrees
222///
223/// # Arguments
224/// * `x` - Angle in degrees
225///
226/// # Returns
227/// sin(x) where x is in degrees
228#[allow(dead_code)]
229pub fn sindg<T>(x: T) -> T
230where
231    T: Float + FromPrimitive,
232{
233    radian(x).sin()
234}
235
236/// Tangent of angle in degrees
237///
238/// # Arguments
239/// * `x` - Angle in degrees
240///
241/// # Returns
242/// tan(x) where x is in degrees
243#[allow(dead_code)]
244pub fn tandg<T>(x: T) -> T
245where
246    T: Float + FromPrimitive,
247{
248    radian(x).tan()
249}
250
251/// Cotangent of angle in degrees
252///
253/// # Arguments
254/// * `x` - Angle in degrees
255///
256/// # Returns
257/// cot(x) = 1/tan(x) where x is in degrees
258#[allow(dead_code)]
259pub fn cotdg<T>(x: T) -> T
260where
261    T: Float + FromPrimitive,
262{
263    T::from_f64(1.0).expect("Operation failed") / tandg(x)
264}
265
266/// Compute cos(x) - 1 accurately for small x
267///
268/// This function provides better numerical accuracy than directly computing cos(x) - 1
269/// when x is close to 0.
270///
271/// # Arguments
272/// * `x` - Input value
273///
274/// # Returns
275/// cos(x) - 1
276#[allow(dead_code)]
277pub fn cosm1<T>(x: T) -> T
278where
279    T: Float + FromPrimitive,
280{
281    // Use Taylor series for small x
282    if x.abs() < T::from_f64(0.1).expect("Operation failed") {
283        let x2 = x * x;
284        let mut sum = -x2 / T::from_f64(2.0).expect("Operation failed");
285        let mut term = sum;
286        let mut n = T::from_f64(4.0).expect("Operation failed");
287
288        while term.abs() > T::epsilon() * sum.abs() {
289            term = term * (-x2) / (n * (n - T::from_f64(1.0).expect("Operation failed")));
290            sum = sum + term;
291            n = n + T::from_f64(2.0).expect("Operation failed");
292        }
293
294        sum
295    } else {
296        x.cos() - T::from_f64(1.0).expect("Operation failed")
297    }
298}
299
300/// Compute (1 + x)^y - 1 accurately
301///
302/// This function provides better numerical accuracy than directly computing (1 + x)^y - 1
303/// when x is small.
304///
305/// # Arguments
306/// * `x` - Base adjustment
307/// * `y` - Exponent
308///
309/// # Returns
310/// (1 + x)^y - 1
311#[allow(dead_code)]
312pub fn powm1<T>(x: T, y: T) -> SpecialResult<T>
313where
314    T: Float + FromPrimitive + Display,
315{
316    check_finite(x, "x value")?;
317    check_finite(y, "y value")?;
318
319    if x.abs() < T::from_f64(0.1).expect("Operation failed")
320        && y.abs() < T::from_f64(10.0).expect("Operation failed")
321    {
322        // Use exp(y * log1p(x)) - 1 = expm1(y * log1p(x))
323        Ok((y * x.ln_1p()).exp_m1())
324    } else {
325        Ok((T::from_f64(1.0).expect("Operation failed") + x).powf(y)
326            - T::from_f64(1.0).expect("Operation failed"))
327    }
328}
329
330/// Compute x * log(y) safely
331///
332/// Returns 0 when x = 0, even if log(y) is undefined or infinite.
333///
334/// # Arguments
335/// * `x` - Multiplier
336/// * `y` - Argument to logarithm
337///
338/// # Returns
339/// x * log(y) with special handling for x = 0
340#[allow(dead_code)]
341pub fn xlogy<T>(x: T, y: T) -> T
342where
343    T: Float + Zero,
344{
345    if x.is_zero() {
346        T::zero()
347    } else if y <= T::zero() {
348        T::nan()
349    } else {
350        x * y.ln()
351    }
352}
353
354/// Compute x * log(1 + y) safely
355///
356/// Returns 0 when x = 0, provides accurate results for small y.
357///
358/// # Arguments
359/// * `x` - Multiplier
360/// * `y` - Argument to log1p
361///
362/// # Returns
363/// x * log(1 + y) with special handling
364#[allow(dead_code)]
365pub fn xlog1py<T>(x: T, y: T) -> T
366where
367    T: Float + Zero,
368{
369    if x.is_zero() {
370        T::zero()
371    } else {
372        x * y.ln_1p()
373    }
374}
375
376/// Relative exponential function
377///
378/// Computes (exp(x) - 1) / x accurately for small x.
379///
380/// # Arguments
381/// * `x` - Input value
382///
383/// # Returns
384/// (exp(x) - 1) / x
385#[allow(dead_code)]
386pub fn exprel<T>(x: T) -> T
387where
388    T: Float + FromPrimitive,
389{
390    if x.abs() < T::from_f64(1e-5).expect("Operation failed") {
391        // Taylor series: 1 + x/2 + x²/6 + x³/24 + ...
392        let mut sum = T::from_f64(1.0).expect("Operation failed");
393        let mut term = x / T::from_f64(2.0).expect("Operation failed");
394        let mut n = T::from_f64(2.0).expect("Operation failed");
395
396        sum = sum + term;
397
398        while term.abs() > T::epsilon() * sum.abs() {
399            term = term * x / (n + T::from_f64(1.0).expect("Operation failed"));
400            sum = sum + term;
401            n = n + T::from_f64(1.0).expect("Operation failed");
402        }
403
404        sum
405    } else {
406        x.exp_m1() / x
407    }
408}
409
410/// Round to nearest integer
411///
412/// Rounds half-integers to nearest even number (banker's rounding).
413///
414/// # Arguments
415/// * `x` - Value to round
416///
417/// # Returns
418/// Rounded value
419#[allow(dead_code)]
420pub fn round<T>(x: T) -> T
421where
422    T: Float,
423{
424    x.round()
425}
426
427/// Dirichlet kernel (periodic sinc function)
428///
429/// Computes sin(n * x/2) / (n * sin(x/2))
430///
431/// # Arguments
432/// * `x` - Input value
433/// * `n` - Integer parameter
434///
435/// # Returns
436/// The Dirichlet kernel value
437#[allow(dead_code)]
438pub fn diric<T>(x: T, n: i32) -> T
439where
440    T: Float + FromPrimitive,
441{
442    if n == 0 {
443        return T::zero();
444    }
445
446    let n_f = T::from_i32(n).expect("Operation failed");
447    let half = T::from_f64(0.5).expect("Operation failed");
448    let x_half = x * half;
449    let sin_x_half = x_half.sin();
450
451    if sin_x_half.abs() < T::epsilon() {
452        // Use limit as x -> 0
453        T::from_i32(n).expect("Operation failed")
454    } else {
455        (n_f * x_half).sin() / (n_f * sin_x_half)
456    }
457}
458
459/// Arithmetic-geometric mean
460///
461/// Computes the arithmetic-geometric mean of a and b.
462///
463/// # Arguments
464/// * `a` - First value (must be positive)
465/// * `b` - Second value (must be positive)
466///
467/// # Returns
468/// The arithmetic-geometric mean
469#[allow(dead_code)]
470pub fn agm<T>(a: T, b: T) -> SpecialResult<T>
471where
472    T: Float + FromPrimitive + Display,
473{
474    check_finite(a, "a value")?;
475    check_finite(b, "b value")?;
476
477    if a <= T::zero() || b <= T::zero() {
478        return Err(SpecialError::DomainError(
479            "agm: arguments must be positive".to_string(),
480        ));
481    }
482
483    let mut a_n = a;
484    let mut b_n = b;
485    let tol = T::epsilon() * a.max(b);
486
487    while (a_n - b_n).abs() > tol {
488        let a_next = (a_n + b_n) / T::from_f64(2.0).expect("Operation failed");
489        let b_next = (a_n * b_n).sqrt();
490        a_n = a_next;
491        b_n = b_next;
492    }
493
494    Ok(a_n)
495}
496
497/// Log of expit function
498///
499/// Computes log(1 / (1 + exp(-x))) = -log1p(exp(-x))
500///
501/// # Arguments
502/// * `x` - Input value
503///
504/// # Returns
505/// log(expit(x))
506#[allow(dead_code)]
507pub fn log_expit<T>(x: T) -> T
508where
509    T: Float,
510{
511    if x >= T::zero() {
512        -(-x).exp().ln_1p()
513    } else {
514        x - x.exp().ln_1p()
515    }
516}
517
518/// Softplus function
519///
520/// Computes log(1 + exp(x)) in a numerically stable way.
521///
522/// # Arguments
523/// * `x` - Input value
524///
525/// # Returns
526/// log(1 + exp(x))
527#[allow(dead_code)]
528pub fn softplus<T>(x: T) -> T
529where
530    T: Float + FromPrimitive,
531{
532    if x > T::from_f64(20.0).expect("Operation failed") {
533        // For large x, log(1 + exp(x)) ≈ x
534        x
535    } else if x < T::from_f64(-20.0).expect("Operation failed") {
536        // For large negative x, log(1 + exp(x)) ≈ exp(x)
537        x.exp()
538    } else {
539        x.exp().ln_1p()
540    }
541}
542
543/// Owen's T function
544///
545/// Computes T(h, a) = (1/2π) ∫₀ᵃ exp(-h²(1+x²)/2) / (1+x²) dx
546///
547/// # Arguments
548/// * `h` - First parameter
549/// * `a` - Second parameter
550///
551/// # Returns
552/// Owen's T function value
553///
554/// # Algorithm
555/// Uses a combination of series expansion for small |h|, asymptotic expansion
556/// for large |h|, and numerical integration for intermediate values.
557#[allow(dead_code)]
558pub fn owens_t<T>(h: T, a: T) -> SpecialResult<T>
559where
560    T: Float + FromPrimitive + Display + Debug,
561{
562    check_finite(h, "h value")?;
563    check_finite(a, "a value")?;
564
565    let zero = T::zero();
566    let one = T::one();
567    let two = T::from_f64(2.0).expect("Operation failed");
568    let pi = T::from_f64(std::f64::consts::PI).expect("Operation failed");
569
570    // Handle special cases
571    if a.is_zero() {
572        return Ok(zero);
573    }
574
575    if h.is_zero() {
576        return Ok(a.atan() / (two * pi));
577    }
578
579    let abs_h = h.abs();
580    let abs_a = a.abs();
581
582    // Use symmetry properties to reduce to first quadrant
583    let sign = if (h >= zero && a >= zero) || (h < zero && a < zero) {
584        one
585    } else {
586        -one
587    };
588
589    let result = if abs_h < T::from_f64(0.1).expect("Operation failed") {
590        // For small |h|, use series expansion
591        owens_t_series(abs_h, abs_a)?
592    } else if abs_h > T::from_f64(10.0).expect("Operation failed") {
593        // For large |h|, use asymptotic expansion
594        owens_t_asymptotic(abs_h, abs_a)?
595    } else {
596        // For intermediate values, use numerical integration
597        owens_t_numerical(abs_h, abs_a)?
598    };
599
600    Ok(sign * result)
601}
602
603/// Owen's T function using series expansion for small h
604#[allow(dead_code)]
605fn owens_t_series<T>(h: T, a: T) -> SpecialResult<T>
606where
607    T: Float + FromPrimitive,
608{
609    let zero = T::zero();
610    let one = T::one();
611    let two = T::from_f64(2.0).expect("Operation failed");
612    let pi = T::from_f64(std::f64::consts::PI).expect("Operation failed");
613
614    let h2 = h * h;
615    let a2 = a * a;
616    let atan_a = a.atan();
617
618    // Series: T(h,a) = (1/2π) * atan(a) - (h/2π) * ∑ (-1)^n * h^(2n) * I_n(a)
619    // where I_n(a) = ∫₀ᵃ x^(2n) / (1+x²) dx
620
621    let mut sum = zero;
622    let mut h_power = one;
623
624    for n in 0..20 {
625        let integral = if n == 0 {
626            atan_a
627        } else {
628            // I_n(a) can be computed recursively
629
630            if n == 1 {
631                (a2.ln_1p()) / two
632            } else {
633                // For higher n, use recursive relation or approximation
634                a.powi(2 * n as i32 - 1) / T::from_usize(2 * n - 1).expect("Operation failed")
635            }
636        };
637
638        let term = if n % 2 == 0 {
639            h_power * integral
640        } else {
641            -h_power * integral
642        };
643        sum = sum + term;
644
645        // Check for convergence
646        if term.abs() < T::from_f64(1e-15).expect("Operation failed") {
647            break;
648        }
649
650        h_power = h_power * h2;
651    }
652
653    Ok(atan_a / (two * pi) - h * sum / (two * pi))
654}
655
656/// Owen's T function using asymptotic expansion for large h
657#[allow(dead_code)]
658fn owens_t_asymptotic<T>(h: T, a: T) -> SpecialResult<T>
659where
660    T: Float + FromPrimitive,
661{
662    let one = T::one();
663    let two = T::from_f64(2.0).expect("Operation failed");
664    let pi = T::from_f64(std::f64::consts::PI).expect("Operation failed");
665
666    let h2 = h * h;
667    let a2 = a * a;
668    let exp_factor = (-h2 * (one + a2) / two).exp();
669
670    // Asymptotic expansion for large h
671    // T(h,a) ≈ (1/2π) * exp(-h²(1+a²)/2) * (a/(h²(1+a²))) * [1 + O(1/h²)]
672
673    let denominator = h2 * (one + a2);
674    let result = exp_factor * a / (two * pi * denominator);
675
676    // Add first correction term
677    let correction = one - (T::from_f64(3.0).expect("Operation failed") * a2) / (one + a2).powi(2);
678    let corrected_result = result * correction;
679
680    Ok(corrected_result)
681}
682
683/// Owen's T function using numerical integration
684#[allow(dead_code)]
685fn owens_t_numerical<T>(h: T, a: T) -> SpecialResult<T>
686where
687    T: Float + FromPrimitive,
688{
689    let zero = T::zero();
690    let one = T::one();
691    let two = T::from_f64(2.0).expect("Operation failed");
692    let pi = T::from_f64(std::f64::consts::PI).expect("Operation failed");
693
694    let h2 = h * h;
695
696    // Use Simpson's rule for numerical integration
697    let n = 1000; // Number of intervals
698    let dx = a / T::from_usize(n).expect("Operation failed");
699
700    let mut sum = zero;
701
702    for i in 0..=n {
703        let x = T::from_usize(i).expect("Operation failed") * dx;
704        let integrand = (-h2 * (one + x * x) / two).exp() / (one + x * x);
705
706        let weight = if i == 0 || i == n {
707            one
708        } else if i % 2 == 1 {
709            T::from_f64(4.0).expect("Operation failed")
710        } else {
711            two
712        };
713
714        sum = sum + weight * integrand;
715    }
716
717    let result = sum * dx / (T::from_f64(3.0).expect("Operation failed") * two * pi);
718    Ok(result)
719}
720
721/// Apply utility function to arrays
722#[allow(dead_code)]
723pub fn cbrt_array<T>(x: &ArrayView1<T>) -> Array1<T>
724where
725    T: Float + FromPrimitive + Send + Sync,
726{
727    x.mapv(cbrt)
728}
729
730#[allow(dead_code)]
731pub fn exp10_array<T>(x: &ArrayView1<T>) -> Array1<T>
732where
733    T: Float + FromPrimitive + Send + Sync,
734{
735    x.mapv(exp10)
736}
737
738#[allow(dead_code)]
739pub fn round_array<T>(x: &ArrayView1<T>) -> Array1<T>
740where
741    T: Float + Send + Sync,
742{
743    x.mapv(round)
744}
745
746/// Expit function (logistic function)
747///
748/// Computes the logistic function: expit(x) = 1 / (1 + exp(-x))
749/// This is equivalent to the logistic function but follows SciPy naming convention.
750///
751/// # Arguments
752/// * `x` - Input value
753///
754/// # Examples
755/// ```
756/// use scirs2_special::expit;
757/// assert_eq!(expit(0.0), 0.5);
758/// assert!(expit(10.0) > 0.99);
759/// assert!(expit(-10.0) < 0.01);
760/// ```
761#[allow(dead_code)]
762pub fn expit<T>(x: T) -> T
763where
764    T: Float + FromPrimitive + Copy,
765{
766    let one = T::one();
767    let neg_x = -x;
768    one / (one + neg_x.exp())
769}
770
771/// Logit function (inverse of expit)
772///
773/// Computes the logit function: logit(p) = log(p / (1 - p))
774/// This is the inverse of the expit function.
775///
776/// # Arguments
777/// * `p` - Probability value in (0, 1)
778///
779/// # Returns
780/// * `SpecialResult<T>` - The logit of p, or error if p is outside (0, 1)
781///
782/// # Examples
783/// ```
784/// use scirs2_special::logit;
785/// assert!((logit(0.5).expect("Operation failed") - 0.0f64).abs() < 1e-10);
786/// assert!(logit(0.0).is_err());
787/// assert!(logit(1.0).is_err());
788/// ```
789#[allow(dead_code)]
790pub fn logit<T>(p: T) -> SpecialResult<T>
791where
792    T: Float + FromPrimitive + Copy + Debug,
793{
794    let zero = T::zero();
795    let one = T::one();
796
797    if p <= zero || p >= one {
798        return Err(SpecialError::ValueError(format!(
799            "logit requires p in (0, 1), got {p:?}"
800        )));
801    }
802
803    Ok((p / (one - p)).ln())
804}
805
806/// Array version of expit function
807///
808/// Applies the expit function element-wise to an array.
809///
810/// # Arguments
811/// * `x` - Input array view
812///
813/// # Examples
814/// ```
815/// use scirs2_core::ndarray::array;
816/// use scirs2_special::expit_array;
817/// let input = array![0.0, 1.0, -1.0];
818/// let result = expit_array(&input.view());
819/// assert!((result[0] - 0.5f64).abs() < 1e-10);
820/// ```
821#[allow(dead_code)]
822pub fn expit_array<T>(x: &ArrayView1<T>) -> Array1<T>
823where
824    T: Float + FromPrimitive + Copy,
825{
826    x.mapv(|val| expit(val))
827}
828
829/// Array version of logit function
830///
831/// Applies the logit function element-wise to an array.
832/// Invalid values (outside (0, 1)) are set to NaN.
833///
834/// # Arguments
835/// * `x` - Input array view
836///
837/// # Examples
838/// ```
839/// use scirs2_core::ndarray::array;
840/// use scirs2_special::logit_array;
841/// let input = array![0.1, 0.5, 0.9];
842/// let result = logit_array(&input.view());
843/// assert!((result[1] - 0.0f64).abs() < 1e-10);
844/// ```
845#[allow(dead_code)]
846pub fn logit_array<T>(x: &ArrayView1<T>) -> Array1<T>
847where
848    T: Float + FromPrimitive + Copy + Debug,
849{
850    x.mapv(|val| logit(val).unwrap_or(T::nan()))
851}
852
853/// Compute x * log1p(y) safely
854///
855/// Returns 0 when x = 0, provides accurate results for small y.
856/// This is a convenience function commonly used in SciPy.
857///
858/// # Arguments
859/// * `x` - Multiplier
860/// * `y` - Argument to log1p
861///
862/// # Returns
863/// x * log1p(y) with special handling
864#[allow(dead_code)]
865pub fn xlog1py_scalar<T>(x: T, y: T) -> T
866where
867    T: Float + Zero,
868{
869    xlog1py(x, y)
870}
871
872/// Compute log(1 + x) element-wise for an array
873///
874/// This function provides better numerical accuracy than directly computing log(1 + x)
875/// when x is close to 0.
876///
877/// # Arguments
878/// * `x` - Input array view
879///
880/// # Examples
881/// ```
882/// use scirs2_core::ndarray::array;
883/// use scirs2_special::log1p_array_utility;
884/// let input = array![0.0, 1e-10, 0.1];
885/// let result = log1p_array_utility(&input.view());
886/// assert!((result[0] - 0.0f64).abs() < 1e-15);
887/// ```
888#[allow(dead_code)]
889pub fn log1p_array_utility<T>(x: &ArrayView1<T>) -> Array1<T>
890where
891    T: Float + Copy,
892{
893    x.mapv(|val| val.ln_1p())
894}
895
896/// Compute exp(x) - 1 element-wise for an array
897///
898/// This function provides better numerical accuracy than directly computing exp(x) - 1
899/// when x is close to 0.
900///
901/// # Arguments
902/// * `x` - Input array view
903///
904/// # Examples
905/// ```
906/// use scirs2_core::ndarray::array;
907/// use scirs2_special::expm1_array_utility;
908/// let input = array![0.0, 1e-10, 0.1];
909/// let result = expm1_array_utility(&input.view());
910/// assert!((result[0] - 0.0f64).abs() < 1e-15);
911/// ```
912#[allow(dead_code)]
913pub fn expm1_array_utility<T>(x: &ArrayView1<T>) -> Array1<T>
914where
915    T: Float + Copy,
916{
917    x.mapv(|val| val.exp_m1())
918}
919
920/// Spherical distance function
921///
922/// Computes the great circle distance between two points on a sphere.
923/// This is a common convenience function in geospatial calculations.
924///
925/// # Arguments
926/// * `lat1` - Latitude of first point in radians
927/// * `lon1` - Longitude of first point in radians  
928/// * `lat2` - Latitude of second point in radians
929/// * `lon2` - Longitude of second point in radians
930///
931/// # Returns
932/// Angular distance in radians
933#[allow(dead_code)]
934pub fn spherical_distance<T>(lat1: T, lon1: T, lat2: T, lon2: T) -> SpecialResult<T>
935where
936    T: Float + FromPrimitive + Display + Copy,
937{
938    check_finite(lat1, "lat1 value")?;
939    check_finite(lon1, "lon1 value")?;
940    check_finite(lat2, "lat2 value")?;
941    check_finite(lon2, "lon2 value")?;
942
943    let two = T::from_f64(2.0).expect("Operation failed");
944    let dlat = (lat2 - lat1) / two;
945    let dlon = (lon2 - lon1) / two;
946
947    let a = dlat.sin().powi(2) + lat1.cos() * lat2.cos() * dlon.sin().powi(2);
948    Ok(two * a.sqrt().asin())
949}
950
951/// Numerical gradient computation using central differences
952///
953/// Computes the gradient of a function represented by discrete points.
954/// This is useful for numerical differentiation.
955///
956/// # Arguments
957/// * `y` - Function values
958/// * `x` - Optional x coordinates (assumed equally spaced if None)
959///
960/// # Returns
961/// Gradient array
962#[allow(dead_code)]
963pub fn gradient<T>(y: &ArrayView1<T>, x: Option<&ArrayView1<T>>) -> SpecialResult<Array1<T>>
964where
965    T: Float + FromPrimitive + Copy,
966{
967    if y.len() < 2 {
968        return Err(SpecialError::DomainError(
969            "Need at least 2 points for gradient".to_string(),
970        ));
971    }
972
973    let n = y.len();
974    let mut grad = Array1::zeros(n);
975    let _one = T::one(); // Unused for now but may be needed for future functionality
976    let two = T::from_f64(2.0).expect("Operation failed");
977
978    if let Some(x_vals) = x {
979        if x_vals.len() != n {
980            return Err(SpecialError::DomainError(
981                "x and y arrays must have same length".to_string(),
982            ));
983        }
984
985        // Forward difference for first point
986        grad[0] = (y[1] - y[0]) / (x_vals[1] - x_vals[0]);
987
988        // Central difference for interior points
989        for i in 1..n - 1 {
990            grad[i] = (y[i + 1] - y[i - 1]) / (x_vals[i + 1] - x_vals[i - 1]);
991        }
992
993        // Backward difference for last point
994        grad[n - 1] = (y[n - 1] - y[n - 2]) / (x_vals[n - 1] - x_vals[n - 2]);
995    } else {
996        // Assume unit spacing
997        grad[0] = y[1] - y[0];
998
999        for i in 1..n - 1 {
1000            grad[i] = (y[i + 1] - y[i - 1]) / two;
1001        }
1002
1003        grad[n - 1] = y[n - 1] - y[n - 2];
1004    }
1005
1006    Ok(grad)
1007}
1008
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_cbrt() {
        assert_relative_eq!(cbrt(8.0), 2.0, epsilon = 1e-10);
        assert_relative_eq!(cbrt(-8.0), -2.0, epsilon = 1e-10);
        assert_relative_eq!(cbrt(27.0), 3.0, epsilon = 1e-10);
        assert_eq!(cbrt(0.0), 0.0);
    }

    #[test]
    fn test_exp10() {
        assert_relative_eq!(exp10(0.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(exp10(1.0), 10.0, epsilon = 1e-10);
        assert_relative_eq!(exp10(2.0), 100.0, epsilon = 1e-10);
        assert_relative_eq!(exp10(-1.0), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_exp2() {
        assert_eq!(exp2(0.0), 1.0);
        assert_eq!(exp2(1.0), 2.0);
        assert_eq!(exp2(3.0), 8.0);
        assert_eq!(exp2(-1.0), 0.5);
    }

    #[test]
    fn test_spherical_distance() {
        let quarter_turn = std::f64::consts::FRAC_PI_2;

        // A quarter of the way around the equator.
        assert_relative_eq!(
            spherical_distance(0.0, 0.0, 0.0, quarter_turn).expect("Operation failed"),
            quarter_turn,
            epsilon = 1e-12
        );
        // From the equator to the north pole.
        assert_relative_eq!(
            spherical_distance(0.0, 0.3, quarter_turn, 0.3).expect("Operation failed"),
            quarter_turn,
            epsilon = 1e-12
        );
        // Identical points are at distance zero.
        assert_eq!(
            spherical_distance(0.4, -1.2, 0.4, -1.2).expect("Operation failed"),
            0.0
        );
    }

    #[test]
    fn test_owens_t_special_values() {
        assert_eq!(owens_t(1.3, 0.0).expect("Operation failed"), 0.0);
        // T(0, a) = atan(a) / 2π
        assert_relative_eq!(owens_t(0.0, 1.0).expect("Operation failed"), 0.125, epsilon = 1e-12);
        assert_relative_eq!(owens_t(0.0, -1.0).expect("Operation failed"), -0.125, epsilon = 1e-12);
        // T(h, 1) = Φ(h)(1 - Φ(h)) / 2, which is ≈ 0.10667106296 for h = 0.5
        assert_relative_eq!(
            owens_t(0.5, 1.0).expect("Operation failed"),
            0.10667106296,
            epsilon = 1e-8
        );
    }

    #[test]
    fn test_trig_degrees() {
        assert_relative_eq!(cosdg(0.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(cosdg(90.0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(sindg(90.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(tandg(45.0), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_cosm1() {
        // For small x, cosm1 should be more accurate than cos(x) - 1
        let x = 1e-8;
        let result = cosm1(x);
        assert!(result < 0.0);
        assert!(result.abs() < 1e-15);
    }

    #[test]
    fn test_powm1() {
        assert_relative_eq!(powm1(0.5, 2.0).expect("Operation failed"), 1.25, epsilon = 1e-12);
        assert_relative_eq!(powm1(0.01, 3.0).expect("Operation failed"), 0.030301, epsilon = 1e-12);
        assert_eq!(powm1(-0.5, 0.0).expect("Operation failed"), 0.0);
    }

    #[test]
    fn test_xlogy() {
        assert_eq!(xlogy(0.0, 2.0), 0.0);
        assert_eq!(xlogy(0.0, 0.0), 0.0);
        assert!(xlogy(1.0, 0.0).is_nan());
        assert_relative_eq!(xlogy(2.0, 3.0), 2.0 * 3.0_f64.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_exprel() {
        assert_relative_eq!(exprel(0.0), 1.0, epsilon = 1e-10);
        let x = 1e-10;
        assert_relative_eq!(exprel(x), 1.0, epsilon = 1e-8);
    }

    #[test]
    fn test_round_non_ties() {
        assert_eq!(round(1.4), 1.0);
        assert_eq!(round(-1.6), -2.0);
        assert_eq!(round(3.0), 3.0);
    }

    #[test]
    fn test_agm() {
        let result = agm(1.0, 2.0).expect("Operation failed");
        assert_relative_eq!(result, 1.4567910310469068, epsilon = 1e-10);

        // AGM is symmetric
        assert_relative_eq!(
            agm(2.0, 1.0).expect("Operation failed"),
            result,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_diric() {
        // n == 0 is treated as identically zero.
        assert_eq!(diric(0.0, 0), 0.0);
        // Regular points follow sin(n x / 2) / (n sin(x / 2)).
        assert_relative_eq!(
            diric(1.0, 5),
            2.5_f64.sin() / (5.0 * 0.5_f64.sin()),
            epsilon = 1e-12
        );
        // The kernel is continuous at the origin, where its limit is 1.
        assert_relative_eq!(diric(1e-3, 5), 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_softplus_and_log_expit() {
        let ln2 = std::f64::consts::LN_2;
        assert_relative_eq!(softplus(0.0), ln2, epsilon = 1e-12);
        assert_relative_eq!(log_expit(0.0), -ln2, epsilon = 1e-12);
        // softplus(x) = -log(expit(-x))
        for &x in &[-3.0, -0.5, 1.5, 4.0] {
            assert_relative_eq!(softplus(x), -log_expit(-x), epsilon = 1e-12);
        }
    }

    #[test]
    fn test_gradient() {
        use scirs2_core::ndarray::array;

        // Unit spacing: y = 2t has slope 2 everywhere.
        let y = array![0.0, 2.0, 4.0, 6.0];
        let g = gradient(&y.view(), None).expect("Operation failed");
        for &v in g.iter() {
            assert_relative_eq!(v, 2.0, epsilon = 1e-12);
        }

        // Non-uniform spacing: every difference quotient of a linear function is exact.
        let x = array![0.0, 0.5, 1.5, 2.0];
        let y = x.mapv(|t: f64| 3.0 * t);
        let g = gradient(&y.view(), Some(&x.view())).expect("Operation failed");
        for &v in g.iter() {
            assert_relative_eq!(v, 3.0, epsilon = 1e-12);
        }

        // Too few points, or mismatched lengths, are rejected.
        assert!(gradient(&array![1.0].view(), None).is_err());
        assert!(gradient(&y.view(), Some(&array![0.0, 1.0].view())).is_err());
    }

    #[test]
    fn test_expit() {
        assert_relative_eq!(expit(0.0), 0.5, epsilon = 1e-10);
        assert!(expit(10.0) > 0.99);
        assert!(expit(-10.0) < 0.01);

        // Test numerical stability
        assert!(!expit(1000.0).is_infinite());
        assert!(!expit(-1000.0).is_nan());
    }

    #[test]
    fn test_logit() {
        assert_relative_eq!(logit(0.5).expect("Operation failed"), 0.0, epsilon = 1e-10);
        assert!(logit(0.9).expect("Operation failed") > 0.0);
        assert!(logit(0.1).expect("Operation failed") < 0.0);

        // Test edge cases
        assert!(logit(0.0).is_err());
        assert!(logit(1.0).is_err());
        assert!(logit(-0.1).is_err());
        assert!(logit(1.1).is_err());
    }

    #[test]
    fn test_expit_logit_inverse() {
        let values = [0.1, 0.3, 0.5, 0.7, 0.9];
        for &val in &values {
            let logit_val = logit(val).expect("Operation failed");
            let back = expit(logit_val);
            assert_relative_eq!(back, val, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_array_functions() {
        use scirs2_core::ndarray::array;

        // Test expit_array
        let input = array![0.0, 1.0, -1.0];
        let result = expit_array(&input.view());
        assert_relative_eq!(result[0], 0.5, epsilon = 1e-10);
        assert!(result[1] > 0.7);
        assert!(result[2] < 0.3);

        // Test logit_array
        let probinput = array![0.1, 0.5, 0.9];
        let logit_result = logit_array(&probinput.view());
        assert_relative_eq!(logit_result[1], 0.0, epsilon = 1e-10);
        assert!(logit_result[0] < 0.0);
        assert!(logit_result[2] > 0.0);
    }
}