1use crate::error::{CoreError, CoreResult, ErrorContext};
14use ::ndarray::{Array1, ArrayView1};
15use num_traits::{Float, FromPrimitive};
16use std::fmt::{Debug, Display};
17
18pub fn absolute_error<T: Float>(a: T, b: T) -> T {
24 (a - b).abs()
25}
26
27pub fn relative_error<T: Float>(computed: T, reference: T) -> T {
32 let diff = (computed - reference).abs();
33 let denom = reference.abs();
34 if denom.is_zero() {
35 if diff.is_zero() {
36 T::zero()
37 } else {
38 T::infinity()
39 }
40 } else {
41 diff / denom
42 }
43}
44
45pub fn relative_errors<T: Float + Display>(
48 computed: &ArrayView1<T>,
49 reference: &ArrayView1<T>,
50) -> CoreResult<Array1<T>> {
51 if computed.len() != reference.len() {
52 return Err(CoreError::ShapeError(ErrorContext::new(format!(
53 "Array length mismatch: computed has {} elements, reference has {}",
54 computed.len(),
55 reference.len()
56 ))));
57 }
58 let out: Vec<T> = computed
59 .iter()
60 .zip(reference.iter())
61 .map(|(&c, &r)| relative_error(c, r))
62 .collect();
63 Ok(Array1::from_vec(out))
64}
65
66pub fn max_relative_error<T: Float + Display>(
68 computed: &ArrayView1<T>,
69 reference: &ArrayView1<T>,
70) -> CoreResult<T> {
71 let errs = relative_errors(computed, reference)?;
72 Ok(errs
73 .iter()
74 .copied()
75 .fold(T::zero(), |acc, e| if e > acc { e } else { acc }))
76}
77
78pub fn compensated_sum<T: Float>(values: &[T]) -> T {
87 let mut sum = T::zero();
88 let mut compensation = T::zero();
89 for &val in values {
90 let y = val - compensation;
91 let t = sum + y;
92 compensation = (t - sum) - y;
93 sum = t;
94 }
95 sum
96}
97
98pub fn compensated_sum_array<T: Float>(values: &ArrayView1<T>) -> T {
105 if values.is_empty() {
106 return T::zero();
107 }
108 let mut sum = values[0];
109 let mut compensation = T::zero();
110 for &val in values.iter().skip(1) {
111 let t = sum + val;
112 if sum.abs() >= val.abs() {
113 compensation = compensation + ((sum - t) + val);
114 } else {
115 compensation = compensation + ((val - t) + sum);
116 }
117 sum = t;
118 }
119 sum + compensation
120}
121
122pub fn pairwise_sum_array<T: Float>(values: &ArrayView1<T>) -> T {
130 const THRESHOLD: usize = 128;
131 let n = values.len();
132 match n {
133 0 => T::zero(),
134 1 => values[0],
135 _ if n <= THRESHOLD => compensated_sum_array(values),
136 _ => {
137 let mid = n / 2;
138 let left = values.slice(ndarray::s![..mid]);
139 let right = values.slice(ndarray::s![mid..]);
140 pairwise_sum_array(&left) + pairwise_sum_array(&right)
141 }
142 }
143}
144
145pub fn softmax_array<T: Float>(values: &ArrayView1<T>) -> Array1<T> {
153 if values.is_empty() {
154 return Array1::from_vec(vec![]);
155 }
156 let max_val = values
157 .iter()
158 .copied()
159 .fold(T::neg_infinity(), |a, b| a.max(b));
160
161 let exp_vals: Vec<T> = values.iter().map(|&v| (v - max_val).exp()).collect();
162 let sum: T = exp_vals.iter().copied().fold(T::zero(), |a, b| a + b);
163 Array1::from_vec(exp_vals.into_iter().map(|e| e / sum).collect())
164}
165
166pub fn sigmoid_array<T: Float>(values: &ArrayView1<T>) -> Array1<T> {
168 let out: Vec<T> = values
169 .iter()
170 .map(|&x| {
171 if x >= T::zero() {
172 let exp_neg = (-x).exp();
173 T::one() / (T::one() + exp_neg)
174 } else {
175 let exp_x = x.exp();
176 exp_x / (T::one() + exp_x)
177 }
178 })
179 .collect();
180 Array1::from_vec(out)
181}
182
183pub fn log_sum_exp_array<T: Float>(values: &ArrayView1<T>) -> T {
185 if values.is_empty() {
186 return T::neg_infinity();
187 }
188 let max_val = values
189 .iter()
190 .copied()
191 .fold(T::neg_infinity(), |a, b| a.max(b));
192 if max_val.is_infinite() && max_val < T::zero() {
193 return max_val;
194 }
195 let sum: T = values
196 .iter()
197 .map(|&v| (v - max_val).exp())
198 .fold(T::zero(), |a, b| a + b);
199 max_val + sum.ln()
200}
201
202pub fn condition_number_1d<T: Float + Display>(values: &ArrayView1<T>) -> CoreResult<T> {
211 if values.is_empty() {
212 return Err(CoreError::ValueError(ErrorContext::new(
213 "Cannot compute condition number of empty array",
214 )));
215 }
216 let mut max_abs = T::zero();
217 let mut min_abs = T::infinity();
218 for &v in values.iter() {
219 let a = v.abs();
220 if a > max_abs {
221 max_abs = a;
222 }
223 if a > T::zero() && a < min_abs {
224 min_abs = a;
225 }
226 }
227 if max_abs.is_zero() {
228 return Err(CoreError::ValueError(ErrorContext::new(
229 "All elements are zero; condition number is undefined",
230 )));
231 }
232 if min_abs.is_infinite() {
233 return Err(CoreError::ValueError(ErrorContext::new(
234 "No non-zero elements found for condition number",
235 )));
236 }
237 Ok(max_abs / min_abs)
238}
239
/// Finite-difference scheme used by `numerical_gradient`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DifferenceMode {
    /// One-sided difference `(f(x + h) - f(x)) / h`; O(h) truncation error.
    Forward,
    /// One-sided difference `(f(x) - f(x - h)) / h`; O(h) truncation error.
    Backward,
    /// Two-sided difference `(f(x + h) - f(x - h)) / (2h)`; O(h^2) truncation error.
    Central,
}
254
255pub fn numerical_gradient<T, F>(f: &F, x: &[T], h: T, mode: DifferenceMode) -> CoreResult<Array1<T>>
264where
265 T: Float + FromPrimitive + Debug,
266 F: Fn(&[T]) -> T,
267{
268 let n = x.len();
269 let two = T::from_f64(2.0).ok_or_else(|| {
270 CoreError::TypeError(ErrorContext::new("Failed to convert 2.0 to target type"))
271 })?;
272
273 let mut grad = Array1::zeros(n);
274 let mut x_perturbed = x.to_vec();
275
276 for i in 0..n {
277 let original = x_perturbed[i];
278
279 match mode {
280 DifferenceMode::Forward => {
281 x_perturbed[i] = original + h;
282 let f_plus = f(&x_perturbed);
283 x_perturbed[i] = original;
284 let f_0 = f(&x_perturbed);
285 grad[i] = (f_plus - f_0) / h;
286 }
287 DifferenceMode::Backward => {
288 x_perturbed[i] = original;
289 let f_0 = f(&x_perturbed);
290 x_perturbed[i] = original - h;
291 let f_minus = f(&x_perturbed);
292 grad[i] = (f_0 - f_minus) / h;
293 }
294 DifferenceMode::Central => {
295 x_perturbed[i] = original + h;
296 let f_plus = f(&x_perturbed);
297 x_perturbed[i] = original - h;
298 let f_minus = f(&x_perturbed);
299 grad[i] = (f_plus - f_minus) / (two * h);
300 }
301 }
302
303 x_perturbed[i] = original;
305 }
306
307 Ok(grad)
308}
309
/// Outcome of comparing an analytical gradient against a numerical estimate.
#[derive(Debug, Clone)]
pub struct GradientCheckResult<T: Float> {
    /// Per-component relative error between the analytical and numerical gradients.
    pub relative_errors: Array1<T>,
    /// Largest per-component relative error.
    pub max_relative_error: T,
    /// Mean of the per-component relative errors.
    pub mean_relative_error: T,
    /// True when `max_relative_error` is strictly below the caller-supplied tolerance.
    pub passed: bool,
}
326
327impl<T: Float + Display> std::fmt::Display for GradientCheckResult<T> {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 write!(
330 f,
331 "GradientCheck(passed={}, max_rel_err={}, mean_rel_err={})",
332 self.passed, self.max_relative_error, self.mean_relative_error,
333 )
334 }
335}
336
337pub fn check_gradient<T, F>(
347 f: &F,
348 analytical_grad: &ArrayView1<T>,
349 x: &[T],
350 h: T,
351 tolerance: T,
352) -> CoreResult<GradientCheckResult<T>>
353where
354 T: Float + FromPrimitive + Debug + Display,
355 F: Fn(&[T]) -> T,
356{
357 if analytical_grad.len() != x.len() {
358 return Err(CoreError::ShapeError(ErrorContext::new(format!(
359 "Analytical gradient length {} does not match input dimension {}",
360 analytical_grad.len(),
361 x.len()
362 ))));
363 }
364
365 let numerical = numerical_gradient(f, x, h, DifferenceMode::Central)?;
366 let rel_errs = relative_errors(&analytical_grad, &numerical.view())?;
367 let max_err = rel_errs
368 .iter()
369 .copied()
370 .fold(T::zero(), |a, b| if b > a { b } else { a });
371
372 let n_f = T::from_usize(rel_errs.len().max(1)).unwrap_or(T::one());
373 let sum_err = rel_errs.iter().copied().fold(T::zero(), |a, b| a + b);
374 let mean_err = sum_err / n_f;
375
376 Ok(GradientCheckResult {
377 relative_errors: rel_errs,
378 max_relative_error: max_err,
379 mean_relative_error: mean_err,
380 passed: max_err < tolerance,
381 })
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;

    #[test]
    fn test_absolute_error() {
        assert!((absolute_error(3.0_f64, 3.0) - 0.0).abs() < 1e-15);
        assert!((absolute_error(3.5_f64, 3.0) - 0.5).abs() < 1e-15);
    }

    #[test]
    fn test_relative_error_basic() {
        assert!((relative_error(1.01_f64, 1.0) - 0.01).abs() < 1e-10);
        // 0/0 is defined as zero error; nonzero/0 is infinite error.
        assert!((relative_error(0.0_f64, 0.0) - 0.0).abs() < 1e-15);
        assert!(relative_error(1.0_f64, 0.0).is_infinite());
    }

    #[test]
    fn test_relative_errors_array() {
        let computed = array![1.01, 2.02, 3.03];
        let reference = array![1.0, 2.0, 3.0];
        let errs = relative_errors(&computed.view(), &reference.view()).expect("should succeed");
        assert_eq!(errs.len(), 3);
        for &e in errs.iter() {
            assert!(e < 0.02);
        }
    }

    #[test]
    fn test_relative_errors_mismatch() {
        let a = array![1.0, 2.0];
        let b = array![1.0];
        assert!(relative_errors(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn test_compensated_sum_accuracy() {
        // 10_000 * 0.01 accumulates visible rounding error with a naive sum.
        let values: Vec<f64> = (0..10_000).map(|_| 0.01).collect();
        let result = compensated_sum(&values);
        assert!((result - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_compensated_sum_array_view() {
        // Classic Neumaier case: the small term would vanish in a naive sum.
        let arr = array![1e20, 1.0, -1e20];
        let result = compensated_sum_array(&arr.view());
        assert!((result - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_pairwise_sum_array() {
        let arr: Array1<f64> = Array1::from_vec((0..500).map(|i| 0.1 + 0.001 * i as f64).collect());
        let pw = pairwise_sum_array(&arr.view());
        let naive: f64 = arr.iter().sum();
        assert!((pw - naive).abs() < 1e-8);
    }

    #[test]
    fn test_softmax_array() {
        // Large equal inputs must not overflow and must split mass evenly.
        let vals = array![1000.0_f64, 1000.0, 1000.0];
        let sm = softmax_array(&vals.view());
        for &p in sm.iter() {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
        let total: f64 = sm.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_empty() {
        let vals: Array1<f64> = Array1::from_vec(vec![]);
        let sm = softmax_array(&vals.view());
        assert!(sm.is_empty());
    }

    #[test]
    fn test_sigmoid_array() {
        let vals = array![0.0_f64, 100.0, -100.0];
        let sig = sigmoid_array(&vals.view());
        assert!((sig[0] - 0.5).abs() < 1e-10);
        assert!((sig[1] - 1.0).abs() < 1e-10);
        assert!(sig[2] < 1e-30);
    }

    #[test]
    fn test_log_sum_exp_array() {
        // lse(c, c, c) = c + ln(3) without overflowing exp.
        let vals = array![1000.0_f64, 1000.0, 1000.0];
        let lse = log_sum_exp_array(&vals.view());
        let expected = 1000.0 + 3.0_f64.ln();
        assert!((lse - expected).abs() < 1e-10);
    }

    #[test]
    fn test_log_sum_exp_array_empty() {
        let vals: Array1<f64> = Array1::from_vec(vec![]);
        let lse = log_sum_exp_array(&vals.view());
        assert!(lse.is_infinite() && lse < 0.0);
    }

    #[test]
    fn test_condition_number_1d() {
        let vals = array![1.0_f64, 10.0, 100.0];
        let cn = condition_number_1d(&vals.view()).expect("should succeed");
        assert!((cn - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_condition_number_1d_all_zeros() {
        let vals = array![0.0_f64, 0.0];
        assert!(condition_number_1d(&vals.view()).is_err());
    }

    #[test]
    fn test_condition_number_1d_empty() {
        let vals: Array1<f64> = Array1::from_vec(vec![]);
        assert!(condition_number_1d(&vals.view()).is_err());
    }

    #[test]
    fn test_numerical_gradient_forward() {
        // f(x, y) = x^2 + y^2, so grad = (2x, 2y) = (6, 8) at (3, 4).
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let x = [3.0, 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-7, DifferenceMode::Forward).expect("should succeed");
        assert!((grad[0] - 6.0).abs() < 1e-4);
        assert!((grad[1] - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_numerical_gradient_backward() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let x = [3.0, 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-7, DifferenceMode::Backward).expect("should succeed");
        assert!((grad[0] - 6.0).abs() < 1e-4);
        assert!((grad[1] - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_numerical_gradient_central() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let x = [3.0, 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-5, DifferenceMode::Central).expect("should succeed");
        // Central differences are O(h^2), hence the much tighter tolerance.
        assert!((grad[0] - 6.0).abs() < 1e-8);
        assert!((grad[1] - 8.0).abs() < 1e-8);
    }

    #[test]
    fn test_numerical_gradient_sin() {
        // d/dx sin(x) = cos(x).
        let f = |x: &[f64]| x[0].sin();
        let x = [std::f64::consts::PI / 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-7, DifferenceMode::Central).expect("should succeed");
        let expected = (std::f64::consts::PI / 4.0).cos();
        assert!((grad[0] - expected).abs() < 1e-8);
    }

    #[test]
    fn test_check_gradient_passes() {
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
        let x = [3.0, 4.0];
        // grad = (2*x0, 4*x1) = (6, 16) at (3, 4).
        let analytical = array![6.0, 16.0];
        let result =
            check_gradient(&f, &analytical.view(), &x, 1e-5, 1e-4).expect("should succeed");
        assert!(result.passed, "gradient check should pass");
        assert!(result.max_relative_error < 1e-4);
    }

    #[test]
    fn test_check_gradient_fails() {
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
        let x = [3.0, 4.0];
        // Deliberately wrong gradient values.
        let bad_analytical = array![100.0, 200.0];
        let result =
            check_gradient(&f, &bad_analytical.view(), &x, 1e-5, 1e-4).expect("should succeed");
        assert!(
            !result.passed,
            "gradient check should fail with wrong gradient"
        );
    }

    #[test]
    fn test_check_gradient_dimension_mismatch() {
        let f = |x: &[f64]| x[0];
        let x = [1.0, 2.0];
        // One-element gradient against a two-dimensional input.
        let analytical = array![1.0];
        assert!(check_gradient(&f, &analytical.view(), &x, 1e-5, 1e-4).is_err());
    }

    #[test]
    fn test_max_relative_error() {
        let a = array![1.1_f64, 2.2, 3.3];
        let b = array![1.0, 2.0, 3.0];
        let mre = max_relative_error(&a.view(), &b.view()).expect("should succeed");
        assert!(mre > 0.09 && mre < 0.11);
    }

    #[test]
    fn test_compensated_sum_empty() {
        let empty: Vec<f64> = vec![];
        assert!((compensated_sum(&empty) - 0.0).abs() < 1e-15);
    }

    #[test]
    fn test_gradient_check_display() {
        let f = |x: &[f64]| x[0] * x[0];
        let x = [2.0];
        let analytical = array![4.0];
        let result =
            check_gradient(&f, &analytical.view(), &x, 1e-5, 1e-4).expect("should succeed");
        let display = format!("{result}");
        assert!(display.contains("GradientCheck"));
    }
}