scirs2_core/
safe_ops.rs

//! Safe mathematical operations that handle edge cases and validate results.
//!
//! This module provides safe wrappers around common mathematical operations
//! that can produce NaN, Infinity, or other invalid results. Instead of
//! returning such values, each wrapper reports the problem as a `CoreError`.
5
6use crate::error::{CoreError, ErrorContext};
7use crate::validation::check_finite;
8use num_traits::Float;
9use std::fmt::{Debug, Display};
10
11/// Safely divide two numbers, checking for division by zero and validating the result
12#[inline]
13#[allow(dead_code)]
14pub fn safe_divide<T>(numerator: T, denominator: T) -> Result<T, CoreError>
15where
16    T: Float + Display + Debug,
17{
18    // Check for exact zero
19    if denominator == T::zero() {
20        return Err(CoreError::DomainError(ErrorContext::new(format!(
21            "Division by zero: {numerator} / 0"
22        ))));
23    }
24
25    // Check for near-zero values that could cause overflow
26    let epsilon = T::epsilon();
27    if denominator.abs() < epsilon {
28        return Err(CoreError::DomainError(ErrorContext::new(format!(
29            "Division by near-zero value: {numerator} / {denominator} (threshold: {epsilon})"
30        ))));
31    }
32
33    let result = numerator / denominator;
34
35    // Validate the result
36    check_finite(result, "division result").map_err(|_| {
37        CoreError::ComputationError(ErrorContext::new(format!(
38            "Division produced non-finite result: {numerator} / {denominator} = {result:?}"
39        )))
40    })?;
41
42    Ok(result)
43}
44
45/// Safely compute square root, checking for negative values
46#[inline]
47#[allow(dead_code)]
48pub fn safe_sqrt<T>(value: T) -> Result<T, CoreError>
49where
50    T: Float + Display + Debug,
51{
52    if value < T::zero() {
53        return Err(CoreError::DomainError(ErrorContext::new(format!(
54            "Cannot compute sqrt of negative value: {value}"
55        ))));
56    }
57
58    let result = value.sqrt();
59
60    // Even for valid inputs, check the result
61    check_finite(result, "sqrt result").map_err(|_| {
62        CoreError::ComputationError(ErrorContext::new(format!(
63            "Square root produced non-finite result: sqrt({value}) = {result:?}"
64        )))
65    })?;
66
67    Ok(result)
68}
69
70/// Safely compute natural logarithm, checking for non-positive values
71#[inline]
72#[allow(dead_code)]
73pub fn safelog<T>(value: T) -> Result<T, CoreError>
74where
75    T: Float + Display + Debug,
76{
77    if value <= T::zero() {
78        return Err(CoreError::DomainError(ErrorContext::new(format!(
79            "Cannot compute log of non-positive value: {value}"
80        ))));
81    }
82
83    let result = value.ln();
84
85    check_finite(result, "log result").map_err(|_| {
86        CoreError::ComputationError(ErrorContext::new(format!(
87            "Logarithm produced non-finite result: ln({value}) = {result:?}"
88        )))
89    })?;
90
91    Ok(result)
92}
93
94/// Safely compute base-10 logarithm
95#[inline]
96#[allow(dead_code)]
97pub fn safelog10<T>(value: T) -> Result<T, CoreError>
98where
99    T: Float + Display + Debug,
100{
101    if value <= T::zero() {
102        return Err(CoreError::DomainError(ErrorContext::new(format!(
103            "Cannot compute log10 of non-positive value: {value}"
104        ))));
105    }
106
107    let result = value.log10();
108
109    check_finite(result, "log10 result").map_err(|_| {
110        CoreError::ComputationError(ErrorContext::new(format!(
111            "Base-10 logarithm produced non-finite result: log10({value}) = {result:?}"
112        )))
113    })?;
114
115    Ok(result)
116}
117
118/// Safely compute power, checking for domain errors and overflow
119#[inline]
120#[allow(dead_code)]
121pub fn safe_pow<T>(base: T, exponent: T) -> Result<T, CoreError>
122where
123    T: Float + Display + Debug,
124{
125    // Special cases that could produce NaN or Inf
126    if base < T::zero() && exponent.fract() != T::zero() {
127        return Err(CoreError::DomainError(ErrorContext::new(format!(
128            "Cannot compute fractional power of negative number: {base}^{exponent}"
129        ))));
130    }
131
132    if base == T::zero() && exponent < T::zero() {
133        return Err(CoreError::DomainError(ErrorContext::new(format!(
134            "Cannot compute negative power of zero: 0^{exponent}"
135        ))));
136    }
137
138    let result = base.powf(exponent);
139
140    check_finite(result, "power result").map_err(|_| {
141        CoreError::ComputationError(ErrorContext::new(format!(
142            "Power operation produced non-finite result: {base}^{exponent} = {result:?}"
143        )))
144    })?;
145
146    Ok(result)
147}
148
149/// Safely compute exponential, checking for overflow
150#[inline]
151#[allow(dead_code)]
152pub fn safe_exp<T>(value: T) -> Result<T, CoreError>
153where
154    T: Float + Display + Debug,
155{
156    // Check for values that would cause overflow
157    // For f64, exp(x) overflows when x > ~709.78
158    let max_exp = T::from(700.0).unwrap_or(T::infinity());
159    if value > max_exp {
160        return Err(CoreError::ComputationError(ErrorContext::new(format!(
161            "Exponential would overflow: exp({value}) > exp({max_exp})"
162        ))));
163    }
164
165    let result = value.exp();
166
167    check_finite(result, "exp result").map_err(|_| {
168        CoreError::ComputationError(ErrorContext::new(format!(
169            "Exponential produced non-finite result: exp({value}) = {result:?}"
170        )))
171    })?;
172
173    Ok(result)
174}
175
176/// Safely normalize a value by dividing by a norm/magnitude
177#[inline]
178#[allow(dead_code)]
179pub fn safe_normalize<T>(value: T, norm: T) -> Result<T, CoreError>
180where
181    T: Float + Display + Debug,
182{
183    // Special case: if both are zero, return zero
184    if value == T::zero() && norm == T::zero() {
185        return Ok(T::zero());
186    }
187
188    safe_divide(value, norm)
189}
190
191/// Safely compute the mean of a slice, handling empty slices
192#[allow(dead_code)]
193pub fn safe_mean<T>(values: &[T]) -> Result<T, CoreError>
194where
195    T: Float + Display + Debug + std::iter::Sum,
196{
197    if values.is_empty() {
198        return Err(CoreError::InvalidArgument(ErrorContext::new(
199            "Cannot compute mean of empty array",
200        )));
201    }
202
203    let sum: T = values.iter().copied().sum();
204    let len = values.len();
205    let count = T::from(len).ok_or_else(|| {
206        CoreError::ComputationError(ErrorContext::new(format!(
207            "Failed to convert array length {len} to numeric type"
208        )))
209    })?;
210
211    safe_divide(sum, count)
212}
213
214/// Safely compute variance, handling numerical issues
215#[allow(dead_code)]
216pub fn safe_variance<T>(values: &[T], mean: T) -> Result<T, CoreError>
217where
218    T: Float + Display + Debug + std::iter::Sum,
219{
220    let len = values.len();
221    if len < 2 {
222        return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
223            "Cannot compute variance with {len} values (need at least 2)"
224        ))));
225    }
226
227    let sum_sq_diff: T = values
228        .iter()
229        .map(|&x| {
230            let diff = x - mean;
231            diff * diff
232        })
233        .sum();
234
235    let count = values.len() - 1;
236    let n_minus_1 = T::from(count).ok_or_else(|| {
237        CoreError::ComputationError(ErrorContext::new(format!(
238            "Failed to convert count {count} to numeric type"
239        )))
240    })?;
241
242    safe_divide(sum_sq_diff, n_minus_1)
243}
244
// Unit tests for the safe math wrappers. Error cases are checked against the
// specific `CoreError` variant each wrapper documents, not just `is_err()`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_safe_divide() {
        // Normal cases
        assert_eq!(safe_divide(10.0, 2.0).unwrap(), 5.0);
        assert_eq!(safe_divide(-10.0, 2.0).unwrap(), -5.0);

        // Division by exact zero and by a value below the epsilon threshold
        assert!(matches!(safe_divide(10.0, 0.0), Err(CoreError::DomainError(_))));
        assert!(matches!(safe_divide(10.0, 1e-100), Err(CoreError::DomainError(_))));

        // f64::MIN_POSITIVE is below f64::EPSILON, so this is rejected by the
        // near-zero check, not by the overflow check.
        assert!(matches!(
            safe_divide(f64::MAX, f64::MIN_POSITIVE),
            Err(CoreError::DomainError(_))
        ));

        // Genuine overflow: the divisor passes the near-zero check, but the
        // quotient (2 * f64::MAX) is infinite.
        assert!(matches!(
            safe_divide(f64::MAX, 0.5),
            Err(CoreError::ComputationError(_))
        ));
    }

    #[test]
    fn test_safe_sqrt() {
        // Normal cases
        assert_eq!(safe_sqrt(4.0).unwrap(), 2.0);
        assert_eq!(safe_sqrt(0.0).unwrap(), 0.0);

        // Negative input
        assert!(matches!(safe_sqrt(-1.0), Err(CoreError::DomainError(_))));
        assert!(matches!(safe_sqrt(-1e-10), Err(CoreError::DomainError(_))));
    }

    #[test]
    fn test_safelog() {
        // Normal cases
        assert!((safelog(std::f64::consts::E).unwrap() - 1.0).abs() < 1e-10);
        assert_eq!(safelog(1.0).unwrap(), 0.0);

        // Invalid inputs
        assert!(matches!(safelog(0.0), Err(CoreError::DomainError(_))));
        assert!(matches!(safelog(-1.0), Err(CoreError::DomainError(_))));
    }

    #[test]
    fn test_safelog10() {
        // Normal cases
        assert!((safelog10(100.0).unwrap() - 2.0).abs() < 1e-12);
        assert_eq!(safelog10(1.0).unwrap(), 0.0);

        // Invalid inputs
        assert!(matches!(safelog10(0.0), Err(CoreError::DomainError(_))));
        assert!(matches!(safelog10(-1.0), Err(CoreError::DomainError(_))));
    }

    #[test]
    fn test_safe_pow() {
        // Normal cases
        assert_eq!(safe_pow(2.0, 3.0).unwrap(), 8.0);
        assert_eq!(safe_pow(4.0, 0.5).unwrap(), 2.0);
        // Integral powers of a negative base are allowed
        assert_eq!(safe_pow(-2.0, 3.0).unwrap(), -8.0);

        // Invalid cases
        assert!(matches!(safe_pow(-2.0, 0.5), Err(CoreError::DomainError(_))));
        assert!(matches!(safe_pow(0.0, -1.0), Err(CoreError::DomainError(_))));

        // Overflow
        assert!(matches!(
            safe_pow(10.0, 1000.0),
            Err(CoreError::ComputationError(_))
        ));
    }

    #[test]
    fn test_safe_exp() {
        // Normal cases
        assert!((safe_exp(1.0).unwrap() - std::f64::consts::E).abs() < 1e-10);
        assert_eq!(safe_exp(0.0).unwrap(), 1.0);

        // Overflow
        assert!(matches!(safe_exp(1000.0), Err(CoreError::ComputationError(_))));
    }

    #[test]
    fn test_safe_mean() {
        // Normal case
        assert_eq!(safe_mean(&[1.0, 2.0, 3.0]).unwrap(), 2.0);

        // Empty array
        assert!(matches!(
            safe_mean::<f64>(&[]),
            Err(CoreError::InvalidArgument(_))
        ));

        // Single value
        assert_eq!(safe_mean(&[5.0]).unwrap(), 5.0);
    }

    #[test]
    fn test_safe_variance() {
        // Normal case
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = 3.0;
        assert!((safe_variance(&values, mean).unwrap() - 2.5).abs() < 1e-10);

        // Too few values
        assert!(matches!(
            safe_variance(&[1.0], 1.0),
            Err(CoreError::InvalidArgument(_))
        ));
        assert!(matches!(
            safe_variance::<f64>(&[], 0.0),
            Err(CoreError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_safe_normalize() {
        // Normal case
        assert_eq!(safe_normalize(3.0, 4.0).unwrap(), 0.75);

        // Zero norm
        assert!(matches!(
            safe_normalize(1.0, 0.0),
            Err(CoreError::DomainError(_))
        ));

        // Both zero
        assert_eq!(safe_normalize(0.0, 0.0).unwrap(), 0.0);
    }
}