1use crate::error::SpecialResult;
7use crate::error_context::{ErrorContext, ErrorContextExt, RecoveryStrategy};
8use crate::special_error;
9use crate::validation;
10use scirs2_core::ndarray::{Array1, ArrayBase, ArrayView1};
11use scirs2_core::numeric::{Float, FromPrimitive};
12use std::fmt::{Debug, Display};
13
14#[derive(Debug, Clone)]
16pub struct ErrorConfig {
17 pub enable_recovery: bool,
19 pub default_recovery: RecoveryStrategy,
21 pub log_errors: bool,
23 pub max_iterations: usize,
25 pub tolerance: f64,
27}
28
29impl Default for ErrorConfig {
30 fn default() -> Self {
31 Self {
32 enable_recovery: false,
33 default_recovery: RecoveryStrategy::PropagateError,
34 log_errors: false,
35 max_iterations: 1000,
36 tolerance: 1e-10,
37 }
38 }
39}
40
41pub struct SingleArgWrapper<F, T> {
43 pub name: &'static str,
44 pub func: F,
45 pub config: ErrorConfig,
46 _phantom: std::marker::PhantomData<T>,
47}
48
49impl<F, T> SingleArgWrapper<F, T>
50where
51 F: Fn(T) -> T,
52 T: Float + Display + Debug + FromPrimitive,
53{
54 pub fn new(name: &'static str, func: F) -> Self {
55 Self {
56 name,
57 func,
58 config: ErrorConfig::default(),
59 _phantom: std::marker::PhantomData,
60 }
61 }
62
63 pub fn with_config(mut self, config: ErrorConfig) -> Self {
64 self.config = config;
65 self
66 }
67
68 pub fn evaluate(&self, x: T) -> SpecialResult<T> {
70 if x.is_nan() {
72 return Ok(T::nan());
73 }
74 if x.is_infinite() {
75 return Ok(T::infinity()); }
77
78 validation::check_finite(x, "x")
80 .with_context(|| ErrorContext::new(self.name, "input validation").with_param("x", x))?;
81
82 let result = (self.func)(x);
84
85 if result.is_nan() && !x.is_nan() {
87 if self.config.enable_recovery {
88 if let Some(recovered) = self.try_recover(x) {
90 return Ok(recovered);
91 }
92 }
93
94 return Err(special_error!(
95 computation: self.name, "evaluation",
96 "x" => x
97 ));
98 }
99
100 if result.is_infinite() && !x.is_infinite() {
101 if !self.is_expected_infinity(x) {
103 return Err(special_error!(
104 computation: self.name, "overflow",
105 "x" => x
106 ));
107 }
108 }
109
110 Ok(result)
111 }
112
113 fn is_expected_infinity(&self, x: T) -> bool {
115 match self.name {
117 "gamma" => x == T::zero(),
118 "digamma" => x == T::zero() || (x < T::zero() && x.fract() == T::zero()),
119 _ => false,
120 }
121 }
122
123 fn try_recover(&self, _x: T) -> Option<T> {
125 match self.config.default_recovery {
126 RecoveryStrategy::ReturnDefault => Some(T::zero()),
127 RecoveryStrategy::ClampToRange => {
128 None
130 }
131 RecoveryStrategy::UseApproximation => {
132 None
134 }
135 RecoveryStrategy::PropagateError => None,
136 }
137 }
138}
139
140pub struct TwoArgWrapper<F, T> {
142 pub name: &'static str,
143 pub func: F,
144 pub config: ErrorConfig,
145 _phantom: std::marker::PhantomData<T>,
146}
147
148impl<F, T> TwoArgWrapper<F, T>
149where
150 F: Fn(T, T) -> T,
151 T: Float + Display + Debug + FromPrimitive,
152{
153 pub fn new(name: &'static str, func: F) -> Self {
154 Self {
155 name,
156 func,
157 config: ErrorConfig::default(),
158 _phantom: std::marker::PhantomData,
159 }
160 }
161
162 pub fn with_config(mut self, config: ErrorConfig) -> Self {
163 self.config = config;
164 self
165 }
166
167 pub fn evaluate(&self, a: T, b: T) -> SpecialResult<T> {
169 validation::check_finite(a, "a").with_context(|| {
171 ErrorContext::new(self.name, "input validation")
172 .with_param("a", a)
173 .with_param("b", b)
174 })?;
175
176 validation::check_finite(b, "b").with_context(|| {
177 ErrorContext::new(self.name, "input validation")
178 .with_param("a", a)
179 .with_param("b", b)
180 })?;
181
182 self.validate_specific(a, b)?;
184
185 let result = (self.func)(a, b);
187
188 if result.is_nan() && !a.is_nan() && !b.is_nan() {
190 return Err(special_error!(
191 computation: self.name, "evaluation",
192 "a" => a,
193 "b" => b
194 ));
195 }
196
197 Ok(result)
198 }
199
200 fn validate_specific(&self, a: T, b: T) -> SpecialResult<()> {
202 match self.name {
203 "beta" => {
204 validation::check_positive(a, "a")?;
206 validation::check_positive(b, "b")?;
207 }
208 "bessel_jn" => {
209 }
212 _ => {}
213 }
214 Ok(())
215 }
216}
217
218pub struct ArrayWrapper<F, T> {
220 pub name: &'static str,
221 pub func: F,
222 pub config: ErrorConfig,
223 _phantom: std::marker::PhantomData<T>,
224}
225
226impl<F, T> ArrayWrapper<F, T>
227where
228 F: Fn(&ArrayView1<T>) -> Array1<T>,
229 T: Float + Display + Debug + FromPrimitive,
230{
231 pub fn new(name: &'static str, func: F) -> Self {
232 Self {
233 name,
234 func,
235 config: ErrorConfig::default(),
236 _phantom: std::marker::PhantomData,
237 }
238 }
239
240 pub fn evaluate<S>(
242 &self,
243 input: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
244 ) -> SpecialResult<Array1<T>>
245 where
246 S: scirs2_core::ndarray::Data<Elem = T>,
247 {
248 validation::check_array_finite(input, "input").with_context(|| {
250 ErrorContext::new(self.name, "array validation")
251 .with_param("shape", format!("{:?}", input.shape()))
252 })?;
253
254 validation::check_not_empty(input, "input")?;
255
256 let result = (self.func)(&input.view());
258
259 let nan_count = result.iter().filter(|&&x| x.is_nan()).count();
261 if nan_count > 0 {
262 let total = result.len();
263 return Err(special_error!(
264 computation: self.name, "array evaluation",
265 "nan_count" => nan_count,
266 "total_elements" => total
267 ));
268 }
269
270 Ok(result)
271 }
272}
273
274pub mod wrapped {
276 use super::*;
277 use crate::{beta, digamma, erf, erfc, gamma};
278
279 pub fn gamma_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
281 SingleArgWrapper::new("gamma", gamma::<f64>)
282 }
283
284 pub fn digamma_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
286 SingleArgWrapper::new("digamma", digamma::<f64>)
287 }
288
289 pub fn beta_wrapped() -> TwoArgWrapper<fn(f64, f64) -> f64, f64> {
291 TwoArgWrapper::new("beta", beta::<f64>)
292 }
293
294 pub fn erf_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
296 SingleArgWrapper::new("erf", erf)
297 }
298
299 pub fn erfc_wrapped() -> SingleArgWrapper<fn(f64) -> f64, f64> {
301 SingleArgWrapper::new("erfc", erfc)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::wrapped::*;
    use super::*;

    #[test]
    fn test_gamma_wrapped() {
        let gamma = gamma_wrapped();

        // Γ(5) = 4! = 24.
        let at_five = gamma.evaluate(5.0);
        assert!(at_five.is_ok());
        let at_five = at_five.expect("Operation failed");
        assert!((at_five - 24.0).abs() < 1e-10);

        // NaN input is passed through as NaN rather than reported as an error.
        let at_nan = gamma.evaluate(f64::NAN);
        assert!(at_nan.is_ok());
        assert!(at_nan.expect("Operation failed").is_nan());

        // +inf input yields an infinite value rather than an error.
        let at_inf = gamma.evaluate(f64::INFINITY);
        assert!(at_inf.is_ok());
        assert!(at_inf.expect("Operation failed").is_infinite());
    }

    #[test]
    fn test_beta_wrapped() {
        let beta = beta_wrapped();

        // Positive arguments are accepted.
        assert!(beta.evaluate(2.0, 3.0).is_ok());
        // A negative argument is rejected by the beta-specific validation.
        assert!(beta.evaluate(-1.0, 2.0).is_err());
    }

    #[test]
    fn test_array_wrapper() {
        use scirs2_core::ndarray::arr1;

        let arr_gamma = ArrayWrapper::new("gamma_array", |x: &ArrayView1<f64>| {
            x.mapv(crate::gamma::gamma::<f64>)
        });

        // All-finite input evaluates successfully.
        let finite_input = arr1(&[1.0, 2.0, 3.0, 4.0]);
        assert!(arr_gamma.evaluate(&finite_input).is_ok());

        // A NaN anywhere in the input is rejected.
        let nan_input = arr1(&[1.0, f64::NAN, 3.0]);
        assert!(arr_gamma.evaluate(&nan_input).is_err());
    }
}
361}