1use crate::error::{CoreError, ErrorContext};
7use crate::validation::check_finite;
8use num_traits::Float;
9use std::fmt::{Debug, Display};
10
11#[inline]
13#[allow(dead_code)]
14pub fn safe_divide<T>(numerator: T, denominator: T) -> Result<T, CoreError>
15where
16 T: Float + Display + Debug,
17{
18 if denominator == T::zero() {
20 return Err(CoreError::DomainError(ErrorContext::new(format!(
21 "Division by zero: {numerator} / 0"
22 ))));
23 }
24
25 let epsilon = T::epsilon();
27 if denominator.abs() < epsilon {
28 return Err(CoreError::DomainError(ErrorContext::new(format!(
29 "Division by near-zero value: {numerator} / {denominator} (threshold: {epsilon})"
30 ))));
31 }
32
33 let result = numerator / denominator;
34
35 check_finite(result, "division result").map_err(|_| {
37 CoreError::ComputationError(ErrorContext::new(format!(
38 "Division produced non-finite result: {numerator} / {denominator} = {result:?}"
39 )))
40 })?;
41
42 Ok(result)
43}
44
45#[inline]
47#[allow(dead_code)]
48pub fn safe_sqrt<T>(value: T) -> Result<T, CoreError>
49where
50 T: Float + Display + Debug,
51{
52 if value < T::zero() {
53 return Err(CoreError::DomainError(ErrorContext::new(format!(
54 "Cannot compute sqrt of negative value: {value}"
55 ))));
56 }
57
58 let result = value.sqrt();
59
60 check_finite(result, "sqrt result").map_err(|_| {
62 CoreError::ComputationError(ErrorContext::new(format!(
63 "Square root produced non-finite result: sqrt({value}) = {result:?}"
64 )))
65 })?;
66
67 Ok(result)
68}
69
70#[inline]
72#[allow(dead_code)]
73pub fn safelog<T>(value: T) -> Result<T, CoreError>
74where
75 T: Float + Display + Debug,
76{
77 if value <= T::zero() {
78 return Err(CoreError::DomainError(ErrorContext::new(format!(
79 "Cannot compute log of non-positive value: {value}"
80 ))));
81 }
82
83 let result = value.ln();
84
85 check_finite(result, "log result").map_err(|_| {
86 CoreError::ComputationError(ErrorContext::new(format!(
87 "Logarithm produced non-finite result: ln({value}) = {result:?}"
88 )))
89 })?;
90
91 Ok(result)
92}
93
94#[inline]
96#[allow(dead_code)]
97pub fn safelog10<T>(value: T) -> Result<T, CoreError>
98where
99 T: Float + Display + Debug,
100{
101 if value <= T::zero() {
102 return Err(CoreError::DomainError(ErrorContext::new(format!(
103 "Cannot compute log10 of non-positive value: {value}"
104 ))));
105 }
106
107 let result = value.log10();
108
109 check_finite(result, "log10 result").map_err(|_| {
110 CoreError::ComputationError(ErrorContext::new(format!(
111 "Base-10 logarithm produced non-finite result: log10({value}) = {result:?}"
112 )))
113 })?;
114
115 Ok(result)
116}
117
118#[inline]
120#[allow(dead_code)]
121pub fn safe_pow<T>(base: T, exponent: T) -> Result<T, CoreError>
122where
123 T: Float + Display + Debug,
124{
125 if base < T::zero() && exponent.fract() != T::zero() {
127 return Err(CoreError::DomainError(ErrorContext::new(format!(
128 "Cannot compute fractional power of negative number: {base}^{exponent}"
129 ))));
130 }
131
132 if base == T::zero() && exponent < T::zero() {
133 return Err(CoreError::DomainError(ErrorContext::new(format!(
134 "Cannot compute negative power of zero: 0^{exponent}"
135 ))));
136 }
137
138 let result = base.powf(exponent);
139
140 check_finite(result, "power result").map_err(|_| {
141 CoreError::ComputationError(ErrorContext::new(format!(
142 "Power operation produced non-finite result: {base}^{exponent} = {result:?}"
143 )))
144 })?;
145
146 Ok(result)
147}
148
149#[inline]
151#[allow(dead_code)]
152pub fn safe_exp<T>(value: T) -> Result<T, CoreError>
153where
154 T: Float + Display + Debug,
155{
156 let max_exp = T::from(700.0).unwrap_or(T::infinity());
159 if value > max_exp {
160 return Err(CoreError::ComputationError(ErrorContext::new(format!(
161 "Exponential would overflow: exp({value}) > exp({max_exp})"
162 ))));
163 }
164
165 let result = value.exp();
166
167 check_finite(result, "exp result").map_err(|_| {
168 CoreError::ComputationError(ErrorContext::new(format!(
169 "Exponential produced non-finite result: exp({value}) = {result:?}"
170 )))
171 })?;
172
173 Ok(result)
174}
175
176#[inline]
178#[allow(dead_code)]
179pub fn safe_normalize<T>(value: T, norm: T) -> Result<T, CoreError>
180where
181 T: Float + Display + Debug,
182{
183 if value == T::zero() && norm == T::zero() {
185 return Ok(T::zero());
186 }
187
188 safe_divide(value, norm)
189}
190
191#[allow(dead_code)]
193pub fn safe_mean<T>(values: &[T]) -> Result<T, CoreError>
194where
195 T: Float + Display + Debug + std::iter::Sum,
196{
197 if values.is_empty() {
198 return Err(CoreError::InvalidArgument(ErrorContext::new(
199 "Cannot compute mean of empty array",
200 )));
201 }
202
203 let sum: T = values.iter().copied().sum();
204 let len = values.len();
205 let count = T::from(len).ok_or_else(|| {
206 CoreError::ComputationError(ErrorContext::new(format!(
207 "Failed to convert array length {len} to numeric type"
208 )))
209 })?;
210
211 safe_divide(sum, count)
212}
213
214#[allow(dead_code)]
216pub fn safe_variance<T>(values: &[T], mean: T) -> Result<T, CoreError>
217where
218 T: Float + Display + Debug + std::iter::Sum,
219{
220 let len = values.len();
221 if len < 2 {
222 return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
223 "Cannot compute variance with {len} values (need at least 2)"
224 ))));
225 }
226
227 let sum_sq_diff: T = values
228 .iter()
229 .map(|&x| {
230 let diff = x - mean;
231 diff * diff
232 })
233 .sum();
234
235 let count = values.len() - 1;
236 let n_minus_1 = T::from(count).ok_or_else(|| {
237 CoreError::ComputationError(ErrorContext::new(format!(
238 "Failed to convert count {count} to numeric type"
239 )))
240 })?;
241
242 safe_divide(sum_sq_diff, n_minus_1)
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within `1e-10` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_safe_divide() {
        assert_eq!(safe_divide(10.0, 2.0).unwrap(), 5.0);
        assert_eq!(safe_divide(-10.0, 2.0).unwrap(), -5.0);

        // Exact zero and divisors below epsilon are both rejected.
        // f64::MIN_POSITIVE is far below epsilon.
        assert!(safe_divide(10.0, 0.0).is_err());
        assert!(safe_divide(10.0, 1e-100).is_err());
        assert!(safe_divide(f64::MAX, f64::MIN_POSITIVE).is_err());
    }

    #[test]
    fn test_safe_sqrt() {
        assert_eq!(safe_sqrt(4.0).unwrap(), 2.0);
        assert_eq!(safe_sqrt(0.0).unwrap(), 0.0);

        assert!(safe_sqrt(-1.0).is_err());
        assert!(safe_sqrt(-1e-10).is_err());
    }

    #[test]
    fn test_safelog() {
        assert_close(safelog(std::f64::consts::E).unwrap(), 1.0);
        assert_eq!(safelog(1.0).unwrap(), 0.0);

        assert!(safelog(0.0).is_err());
        assert!(safelog(-1.0).is_err());
    }

    #[test]
    fn test_safe_pow() {
        assert_eq!(safe_pow(2.0, 3.0).unwrap(), 8.0);
        assert_eq!(safe_pow(4.0, 0.5).unwrap(), 2.0);

        assert!(safe_pow(-2.0, 0.5).is_err());
        assert!(safe_pow(0.0, -1.0).is_err());
        assert!(safe_pow(10.0, 1000.0).is_err());
    }

    #[test]
    fn test_safe_exp() {
        assert_close(safe_exp(1.0).unwrap(), std::f64::consts::E);
        assert_eq!(safe_exp(0.0).unwrap(), 1.0);

        assert!(safe_exp(1000.0).is_err());
    }

    #[test]
    fn test_safe_mean() {
        assert_eq!(safe_mean(&[1.0, 2.0, 3.0]).unwrap(), 2.0);
        assert!(safe_mean::<f64>(&[]).is_err());
        assert_eq!(safe_mean(&[5.0]).unwrap(), 5.0);
    }

    #[test]
    fn test_safe_variance() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = 3.0;
        assert_close(safe_variance(&values, mean).unwrap(), 2.5);

        assert!(safe_variance(&[1.0], 1.0).is_err());
        assert!(safe_variance::<f64>(&[], 0.0).is_err());
    }

    #[test]
    fn test_safe_normalize() {
        assert_eq!(safe_normalize(3.0, 4.0).unwrap(), 0.75);
        assert!(safe_normalize(1.0, 0.0).is_err());
        assert_eq!(safe_normalize(0.0, 0.0).unwrap(), 0.0);
    }
}