Skip to main content

scirs2_special/
error_wrappers.rs

1//! Error handling wrappers for special functions
2//!
3//! This module provides consistent error-handling wrappers for all special functions,
4//! ensuring proper validation, error context, and recovery strategies.
5
6use crate::error::SpecialResult;
7use crate::error_context::{ErrorContext, ErrorContextExt, RecoveryStrategy};
8use crate::special_error;
9use crate::validation;
10use scirs2_core::ndarray::{Array1, ArrayBase, ArrayView1};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use std::fmt::{Debug, Display};
13
14/// Configuration for error handling behavior
15#[derive(Debug, Clone)]
16pub struct ErrorConfig {
17    /// Whether to use recovery strategies
18    pub enable_recovery: bool,
19    /// Default recovery strategy
20    pub default_recovery: RecoveryStrategy,
21    /// Whether to log errors
22    pub log_errors: bool,
23    /// Maximum iterations before convergence error
24    pub max_iterations: usize,
25    /// Tolerance for convergence
26    pub tolerance: f64,
27}
28
29impl Default for ErrorConfig {
30    fn default() -> Self {
31        Self {
32            enable_recovery: false,
33            default_recovery: RecoveryStrategy::PropagateError,
34            log_errors: false,
35            max_iterations: 1000,
36            tolerance: 1e-10,
37        }
38    }
39}
40
41/// Wrapper for single-argument special functions
42pub struct SingleArgWrapper<F, T> {
43    pub name: &'static str,
44    pub func: F,
45    pub config: ErrorConfig,
46    _phantom: std::marker::PhantomData<T>,
47}
48
49impl<F, T> SingleArgWrapper<F, T>
50where
51    F: Fn(T) -> T,
52    T: Float + Display + Debug + FromPrimitive,
53{
54    pub fn new(name: &'static str, func: F) -> Self {
55        Self {
56            name,
57            func,
58            config: ErrorConfig::default(),
59            _phantom: std::marker::PhantomData,
60        }
61    }
62
63    pub fn with_config(mut self, config: ErrorConfig) -> Self {
64        self.config = config;
65        self
66    }
67
68    /// Evaluate the function with full error handling
69    pub fn evaluate(&self, x: T) -> SpecialResult<T> {
70        // Check for special cases that might cause issues
71        if x.is_nan() {
72            return Ok(T::nan());
73        }
74        if x.is_infinite() {
75            return Ok(T::infinity()); // Return positive infinity for gamma(∞)
76        }
77
78        // Validate input (after handling NaN and infinity)
79        validation::check_finite(x, "x")
80            .with_context(|| ErrorContext::new(self.name, "input validation").with_param("x", x))?;
81
82        // Compute the result
83        let result = (self.func)(x);
84
85        // Validate output
86        if result.is_nan() && !x.is_nan() {
87            if self.config.enable_recovery {
88                // Try recovery strategies
89                if let Some(recovered) = self.try_recover(x) {
90                    return Ok(recovered);
91                }
92            }
93
94            return Err(special_error!(
95                computation: self.name, "evaluation",
96                "x" => x
97            ));
98        }
99
100        if result.is_infinite() && !x.is_infinite() {
101            // Check if this is expected (e.g., gamma(0) = inf)
102            if !self.is_expected_infinity(x) {
103                return Err(special_error!(
104                    computation: self.name, "overflow",
105                    "x" => x
106                ));
107            }
108        }
109
110        Ok(result)
111    }
112
113    /// Check if infinity is expected for this input
114    fn is_expected_infinity(&self, x: T) -> bool {
115        // This would be customized per function
116        match self.name {
117            "gamma" => x == T::zero(),
118            "digamma" => x == T::zero() || (x < T::zero() && x.fract() == T::zero()),
119            _ => false,
120        }
121    }
122
123    /// Try to recover from an error
124    fn try_recover(&self, _x: T) -> Option<T> {
125        match self.config.default_recovery {
126            RecoveryStrategy::ReturnDefault => Some(T::zero()),
127            RecoveryStrategy::ClampToRange => {
128                // Function-specific clamping logic
129                None
130            }
131            RecoveryStrategy::UseApproximation => {
132                // Function-specific approximation
133                None
134            }
135            RecoveryStrategy::PropagateError => None,
136        }
137    }
138}
139
140/// Wrapper for two-argument special functions
141pub struct TwoArgWrapper<F, T> {
142    pub name: &'static str,
143    pub func: F,
144    pub config: ErrorConfig,
145    _phantom: std::marker::PhantomData<T>,
146}
147
148impl<F, T> TwoArgWrapper<F, T>
149where
150    F: Fn(T, T) -> T,
151    T: Float + Display + Debug + FromPrimitive,
152{
153    pub fn new(name: &'static str, func: F) -> Self {
154        Self {
155            name,
156            func,
157            config: ErrorConfig::default(),
158            _phantom: std::marker::PhantomData,
159        }
160    }
161
162    pub fn with_config(mut self, config: ErrorConfig) -> Self {
163        self.config = config;
164        self
165    }
166
167    /// Evaluate the function with full error handling
168    pub fn evaluate(&self, a: T, b: T) -> SpecialResult<T> {
169        // Validate inputs
170        validation::check_finite(a, "a").with_context(|| {
171            ErrorContext::new(self.name, "input validation")
172                .with_param("a", a)
173                .with_param("b", b)
174        })?;
175
176        validation::check_finite(b, "b").with_context(|| {
177            ErrorContext::new(self.name, "input validation")
178                .with_param("a", a)
179                .with_param("b", b)
180        })?;
181
182        // Additional function-specific validation
183        self.validate_specific(a, b)?;
184
185        // Compute the result
186        let result = (self.func)(a, b);
187
188        // Validate output
189        if result.is_nan() && !a.is_nan() && !b.is_nan() {
190            return Err(special_error!(
191                computation: self.name, "evaluation",
192                "a" => a,
193                "b" => b
194            ));
195        }
196
197        Ok(result)
198    }
199
200    /// Function-specific validation
201    fn validate_specific(&self, a: T, b: T) -> SpecialResult<()> {
202        match self.name {
203            "beta" => {
204                // Beta function requires positive arguments
205                validation::check_positive(a, "a")?;
206                validation::check_positive(b, "b")?;
207            }
208            "bessel_jn" => {
209                // Bessel functions might have order restrictions
210                // This would be more specific based on the actual function
211            }
212            _ => {}
213        }
214        Ok(())
215    }
216}
217
218/// Wrapper for array operations with error handling
219pub struct ArrayWrapper<F, T> {
220    pub name: &'static str,
221    pub func: F,
222    pub config: ErrorConfig,
223    _phantom: std::marker::PhantomData<T>,
224}
225
226impl<F, T> ArrayWrapper<F, T>
227where
228    F: Fn(&ArrayView1<T>) -> Array1<T>,
229    T: Float + Display + Debug + FromPrimitive,
230{
231    pub fn new(name: &'static str, func: F) -> Self {
232        Self {
233            name,
234            func,
235            config: ErrorConfig::default(),
236            _phantom: std::marker::PhantomData,
237        }
238    }
239
240    /// Evaluate the function on an array with full error handling
241    pub fn evaluate<S>(
242        &self,
243        input: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
244    ) -> SpecialResult<Array1<T>>
245    where
246        S: scirs2_core::ndarray::Data<Elem = T>,
247    {
248        // Validate array
249        validation::check_array_finite(input, "input").with_context(|| {
250            ErrorContext::new(self.name, "array validation")
251                .with_param("shape", format!("{:?}", input.shape()))
252        })?;
253
254        validation::check_not_empty(input, "input")?;
255
256        // Compute the result
257        let result = (self.func)(&input.view());
258
259        // Validate output
260        let nan_count = result.iter().filter(|&&x| x.is_nan()).count();
261        if nan_count > 0 {
262            let total = result.len();
263            return Err(special_error!(
264                computation: self.name, "array evaluation",
265                "nan_count" => nan_count,
266                "total_elements" => total
267            ));
268        }
269
270        Ok(result)
271    }
272}
273
274/// Create error-wrapped versions of functions
275pub mod wrapped {
276    use super::*;
277    use crate::{beta, digamma, erf, erfc, gamma};
278
279    /// Create a wrapped gamma function with error handling
280    pub fn gamma_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
281        SingleArgWrapper::new("gamma", gamma::<f64>)
282    }
283
284    /// Create a wrapped digamma function with error handling
285    pub fn digamma_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
286        SingleArgWrapper::new("digamma", digamma::<f64>)
287    }
288
289    /// Create a wrapped beta function with error handling
290    pub fn beta_wrapped() -> TwoArgWrapper<fn(f64, f64) -> f64, f64> {
291        TwoArgWrapper::new("beta", beta::<f64>)
292    }
293
294    /// Create a wrapped erf function with error handling
295    pub fn erf_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
296        SingleArgWrapper::new("erf", erf)
297    }
298
299    /// Create a wrapped erfc function with error handling
300    pub fn erfc_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
301        SingleArgWrapper::new("erfc", erfc)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::wrapped::*;
    use super::*;

    #[test]
    fn test_gamma_wrapped() {
        let gamma = gamma_wrapped();

        // Regular input: gamma(5) = 4! = 24.
        let result = gamma.evaluate(5.0);
        assert!(result.is_ok());
        assert!((result.expect("Operation failed") - 24.0).abs() < 1e-10);

        // NaN input is passed through as a NaN result, not an error.
        let result = gamma.evaluate(f64::NAN);
        assert!(result.is_ok());
        assert!(result.expect("Operation failed").is_nan());

        // +inf input yields an infinite result.
        let result = gamma.evaluate(f64::INFINITY);
        assert!(result.is_ok());
        assert!(result.expect("Operation failed").is_infinite());
    }

    #[test]
    fn test_digamma_wrapped() {
        let digamma = digamma_wrapped();

        // digamma(1) = -γ (Euler–Mascheroni constant).
        let value = digamma.evaluate(1.0).expect("Operation failed");
        assert!((value + 0.577_215_664_901_532_9).abs() < 1e-8);
    }

    #[test]
    fn test_erf_erfc_wrapped() {
        let erf = erf_wrapped();
        let erfc = erfc_wrapped();

        // erf(0) = 0 and erfc(0) = 1.
        assert!(erf.evaluate(0.0).expect("Operation failed").abs() < 1e-12);
        assert!((erfc.evaluate(0.0).expect("Operation failed") - 1.0).abs() < 1e-12);

        // erf(x) + erfc(x) = 1 for finite x.
        let sum = erf.evaluate(1.0).expect("Operation failed")
            + erfc.evaluate(1.0).expect("Operation failed");
        assert!((sum - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_beta_wrapped() {
        let beta = beta_wrapped();

        // B(2, 3) = Γ(2)Γ(3)/Γ(5) = 1·2/24 = 1/12.
        let result = beta.evaluate(2.0, 3.0);
        assert!(result.is_ok());
        assert!((result.expect("Operation failed") - 1.0 / 12.0).abs() < 1e-10);

        // Negative arguments are rejected by the beta-specific validation.
        assert!(beta.evaluate(-1.0, 2.0).is_err());
        assert!(beta.evaluate(2.0, -3.0).is_err());

        // Non-finite arguments are rejected before evaluation.
        assert!(beta.evaluate(f64::NAN, 2.0).is_err());
        assert!(beta.evaluate(2.0, f64::INFINITY).is_err());
    }

    #[test]
    fn test_array_wrapper() {
        use scirs2_core::ndarray::arr1;

        let arr_gamma = ArrayWrapper::new("gamma_array", |x: &ArrayView1<f64>| {
            x.mapv(crate::gamma::gamma::<f64>)
        });

        // Valid array: gamma(1..=4) = [1, 1, 2, 6].
        let input = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let result = arr_gamma.evaluate(&input);
        assert!(result.is_ok());
        let values = result.expect("Operation failed");
        for (got, want) in values.iter().zip([1.0, 1.0, 2.0, 6.0].iter()) {
            assert!((*got - *want).abs() < 1e-10);
        }

        // Array with NaN is rejected by input validation.
        let input = arr1(&[1.0, f64::NAN, 3.0]);
        let result = arr_gamma.evaluate(&input);
        assert!(result.is_err());

        // Empty array is rejected.
        let empty = Array1::<f64>::zeros(0);
        assert!(arr_gamma.evaluate(&empty).is_err());
    }
}