1use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
7use scirs2_core::numeric::{Float, NumCast};
8
9use crate::error::{MetricsError, Result};
10
/// Trait for metrics that can be computed with numerical-stability guarantees
/// over n-dimensional array inputs.
pub trait StableMetric<T, D>
where
    T: Float,
    D: Dimension,
{
    /// Computes the metric over `x` using a numerically stable algorithm.
    ///
    /// # Errors
    ///
    /// Returns a [`MetricsError`] when the input is invalid for the metric
    /// (implementation-defined, e.g. empty input).
    fn compute_stable(&self, x: &ArrayBase<impl Data<Elem = T>, D>) -> Result<T>;
}
20
/// Configuration for numerically stable metric computations.
///
/// Construct via [`Default`]/`new` and customize with the `with_*` builder
/// methods.
#[derive(Debug, Clone)]
pub struct StableMetrics<T> {
    /// Small positive guard value used to clamp logs and pad denominators.
    pub epsilon: T,
    /// Upper clamp applied by `clip` when `clip_values` is enabled.
    pub max_value: T,
    /// Whether `clip` actually clamps into `[epsilon, max_value]`.
    pub clip_values: bool,
    /// Whether log-sum-exp-style computations are preferred.
    /// NOTE(review): not read by any method visible in this file — confirm
    /// whether callers elsewhere consume it.
    pub use_logsumexp: bool,
}
33
34impl<T: Float + NumCast> Default for StableMetrics<T> {
35 fn default() -> Self {
36 StableMetrics {
37 epsilon: T::from(1e-10).unwrap(),
38 max_value: T::from(1e10).unwrap(),
39 clip_values: true,
40 use_logsumexp: true,
41 }
42 }
43}
44
45impl<T: Float + NumCast> StableMetrics<T> {
46 pub fn new() -> Self {
48 Default::default()
49 }
50
51 pub fn with_epsilon(mut self, epsilon: T) -> Self {
53 self.epsilon = epsilon;
54 self
55 }
56
57 pub fn with_max_value(mut self, maxvalue: T) -> Self {
59 self.max_value = maxvalue;
60 self
61 }
62
63 pub fn with_clip_values(mut self, clipvalues: bool) -> Self {
65 self.clip_values = clipvalues;
66 self
67 }
68
69 pub fn with_logsumexp(mut self, uselogsumexp: bool) -> Self {
71 self.use_logsumexp = uselogsumexp;
72 self
73 }
74
75 pub fn stable_mean(&self, values: &[T]) -> Result<T> {
88 if values.is_empty() {
89 return Err(MetricsError::InvalidInput(
90 "Cannot compute mean of empty array".to_string(),
91 ));
92 }
93
94 let mut mean = T::zero();
95 let mut count = T::zero();
96
97 for &value in values {
98 count = count + T::one();
99 let delta = value - mean;
101 mean = mean + delta / count;
102 }
103
104 Ok(mean)
105 }
106
107 pub fn stable_variance(&self, values: &[T], ddof: usize) -> Result<T> {
121 if values.is_empty() {
122 return Err(MetricsError::InvalidInput(
123 "Cannot compute variance of empty array".to_string(),
124 ));
125 }
126
127 if values.len() <= ddof {
128 return Err(MetricsError::InvalidInput(format!(
129 "Not enough values to compute variance with ddof={}",
130 ddof
131 )));
132 }
133
134 let mut mean = T::zero();
135 let mut m2 = T::zero();
136 let mut count = T::zero();
137
138 for &value in values {
139 count = count + T::one();
140 let delta = value - mean;
142 mean = mean + delta / count;
143 let delta2 = value - mean;
144 m2 = m2 + delta * delta2;
145 }
146
147 let n = T::from(values.len()).unwrap();
149 let ddof_t = T::from(ddof).unwrap();
150
151 Ok(m2 / (n - ddof_t))
152 }
153
154 pub fn stable_std(&self, values: &[T], ddof: usize) -> Result<T> {
167 let var = self.stable_variance(values, ddof)?;
168
169 if var < T::zero() {
171 if var.abs() < self.epsilon {
172 Ok(T::zero())
173 } else {
174 Err(MetricsError::CalculationError(
175 "Computed negative variance in stable_std".to_string(),
176 ))
177 }
178 } else {
179 Ok(var.sqrt())
180 }
181 }
182
183 pub fn safe_log(&self, x: T) -> T {
195 x.max(self.epsilon).ln()
196 }
197
198 pub fn safe_div(&self, numer: T, denom: T) -> T {
211 numer / (denom + self.epsilon)
212 }
213
214 pub fn clip(&self, x: T) -> T {
226 if self.clip_values {
227 x.max(self.epsilon).min(self.max_value)
228 } else {
229 x
230 }
231 }
232
233 pub fn logsumexp(&self, x: &[T]) -> T {
245 if x.is_empty() {
246 return T::neg_infinity();
247 }
248
249 let max_val = x.iter().cloned().fold(T::neg_infinity(), T::max);
251
252 if max_val == T::neg_infinity() {
254 return T::neg_infinity();
255 }
256
257 let sum = x
259 .iter()
260 .map(|&v| (v - max_val).exp())
261 .fold(T::zero(), |acc, v| acc + v);
262
263 max_val + sum.ln()
265 }
266
267 pub fn softmax(&self, x: &[T]) -> Vec<T> {
279 if x.is_empty() {
280 return Vec::new();
281 }
282
283 let max_val = x.iter().cloned().fold(T::neg_infinity(), T::max);
285
286 if max_val == T::neg_infinity() {
288 let n = x.len();
289 return vec![T::from(1.0).unwrap() / T::from(n).unwrap(); n];
290 }
291
292 let mut exp_vals: Vec<T> = x.iter().map(|&v| (v - max_val).exp()).collect();
294
295 let sum = exp_vals.iter().fold(T::zero(), |acc, &v| acc + v);
297
298 for val in &mut exp_vals {
300 *val = *val / sum;
301 }
302
303 exp_vals
304 }
305
306 pub fn cross_entropy(&self, y_true: &[T], ypred: &[T]) -> Result<T> {
317 if y_true.len() != ypred.len() {
318 return Err(MetricsError::DimensionMismatch(format!(
319 "y_true and ypred must have the same length, got {} and {}",
320 y_true.len(),
321 ypred.len()
322 )));
323 }
324
325 let mut loss = T::zero();
326 for (p, q) in y_true.iter().zip(ypred.iter()) {
327 if *p > T::zero() {
329 let q_clipped = q.max(self.epsilon).min(T::one());
331 loss = loss - (*p * q_clipped.ln());
332 }
333 }
334
335 Ok(loss)
336 }
337
338 pub fn kl_divergence(&self, p: &[T], q: &[T]) -> Result<T> {
349 if p.len() != q.len() {
350 return Err(MetricsError::DimensionMismatch(format!(
351 "p and q must have the same length, got {} and {}",
352 p.len(),
353 q.len()
354 )));
355 }
356
357 let mut kl = T::zero();
358 for (p_i, q_i) in p.iter().zip(q.iter()) {
359 if *p_i > T::zero() {
361 let q_clipped = q_i.max(self.epsilon);
363
364 let log_ratio = (*p_i).ln() - q_clipped.ln();
366
367 kl = kl + (*p_i * log_ratio);
368 }
369 }
370
371 Ok(kl)
372 }
373
374 pub fn js_divergence(&self, p: &[T], q: &[T]) -> Result<T> {
385 if p.len() != q.len() {
386 return Err(MetricsError::DimensionMismatch(format!(
387 "p and q must have the same length, got {} and {}",
388 p.len(),
389 q.len()
390 )));
391 }
392
393 let mut m = Vec::with_capacity(p.len());
395 for (p_i, q_i) in p.iter().zip(q.iter()) {
396 m.push((*p_i + *q_i) / T::from(2.0).unwrap());
397 }
398
399 let kl_p_m = self.kl_divergence(p, &m)?;
401 let kl_q_m = self.kl_divergence(q, &m)?;
402
403 Ok((kl_p_m + kl_q_m) / T::from(2.0).unwrap())
405 }
406
407 pub fn wasserstein_distance(&self, u_values: &[T], vvalues: &[T]) -> Result<T> {
418 if u_values.is_empty() || u_values.is_empty() {
419 return Err(MetricsError::InvalidInput(
420 "Input arrays must not be empty".to_string(),
421 ));
422 }
423
424 let mut u_sorted = u_values.to_vec();
426 let mut v_sorted = u_values.to_vec();
427
428 u_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
429 v_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
430
431 let n_u = u_sorted.len();
433 let n_v = v_sorted.len();
434
435 let mut distance = T::zero();
436 for i in 0..n_u.max(n_v) {
437 let u_quantile = if i < n_u {
438 u_sorted[i]
439 } else {
440 u_sorted[n_u - 1]
441 };
442
443 let v_quantile = if i < n_v {
444 v_sorted[i]
445 } else {
446 v_sorted[n_v - 1]
447 };
448
449 distance = distance + (u_quantile - v_quantile).abs();
450 }
451
452 Ok(distance / T::from(n_u.max(n_v)).unwrap())
454 }
455
456 pub fn maximum_mean_discrepancy(&self, x: &[T], y: &[T], gamma: Option<T>) -> Result<T> {
468 if x.is_empty() || y.is_empty() {
469 return Err(MetricsError::InvalidInput(
470 "Input arrays must not be empty".to_string(),
471 ));
472 }
473
474 let gamma = gamma.unwrap_or_else(|| T::one());
476
477 let xx = self.rbf_kernel_mean(x, x, gamma);
479 let yy = self.rbf_kernel_mean(y, y, gamma);
480 let xy = self.rbf_kernel_mean(x, y, gamma);
481
482 Ok(xx + yy - T::from(2.0).unwrap() * xy)
484 }
485
486 fn rbf_kernel_mean(&self, x: &[T], y: &[T], gamma: T) -> T {
488 let mut sum = T::zero();
489 let n_x = x.len();
490 let n_y = y.len();
491
492 for &x_i in x {
493 for &y_j in y {
494 let squared_dist = (x_i - y_j).powi(2);
495 sum = sum + (-gamma * squared_dist).exp();
496 }
497 }
498
499 sum / (T::from(n_x).unwrap() * T::from(n_y).unwrap())
500 }
501
502 pub fn matrix_exp_trace(&self, eigenvalues: &[T]) -> Result<T> {
515 if eigenvalues.is_empty() {
516 return Err(MetricsError::InvalidInput(
517 "Cannot compute matrix exponential trace with empty eigenvalues".to_string(),
518 ));
519 }
520
521 let mut sum = T::zero();
523 for &eigenvalue in eigenvalues {
524 let clipped = self.clip(eigenvalue);
526 sum = sum + clipped.exp();
527 }
528
529 Ok(sum)
530 }
531
532 pub fn matrix_logdet(&self, eigenvalues: &[T]) -> Result<T> {
545 if eigenvalues.is_empty() {
546 return Err(MetricsError::InvalidInput(
547 "Cannot compute matrix logarithm determinant with empty eigenvalues".to_string(),
548 ));
549 }
550
551 for &eigenvalue in eigenvalues {
553 if eigenvalue <= T::zero() {
554 return Err(MetricsError::CalculationError(
555 "Cannot compute logarithm of non-positive eigenvalues".to_string(),
556 ));
557 }
558 }
559
560 let mut log_det = T::zero();
562 for &eigenvalue in eigenvalues {
563 log_det = log_det + self.safe_log(eigenvalue);
564 }
565
566 Ok(log_det)
567 }
568
569 pub fn log1p(&self, x: T) -> T {
581 if x.abs() < T::from(1e-4).unwrap() {
583 let x2 = x * x;
584 let x3 = x2 * x;
585 let x4 = x2 * x2;
586 return x - x2 / T::from(2).unwrap() + x3 / T::from(3).unwrap()
587 - x4 / T::from(4).unwrap();
588 }
589
590 (T::one() + x).ln()
592 }
593
594 pub fn expm1(&self, x: T) -> T {
606 if x.abs() < T::from(1e-4).unwrap() {
608 let x2 = x * x;
609 let x3 = x2 * x;
610 let x4 = x3 * x;
611 return x
612 + x2 / T::from(2).unwrap()
613 + x3 / T::from(6).unwrap()
614 + x4 / T::from(24).unwrap();
615 }
616
617 x.exp() - T::one()
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_safe_log() {
        let stable = StableMetrics::<f64>::default();

        // A value well above epsilon is logged unchanged.
        assert_abs_diff_eq!(stable.safe_log(2.0), 2.0f64.ln(), epsilon = 1e-10);

        // Zero is clamped up to epsilon before the log.
        assert_abs_diff_eq!(stable.safe_log(0.0), stable.epsilon.ln(), epsilon = 1e-10);

        // Sub-epsilon values are clamped the same way.
        let small = 1e-15;
        assert_abs_diff_eq!(stable.safe_log(small), stable.epsilon.ln(), epsilon = 1e-10);
    }

    #[test]
    fn test_safe_div() {
        let stable = StableMetrics::<f64>::default();

        // Ordinary division; epsilon only perturbs the result slightly.
        assert_abs_diff_eq!(stable.safe_div(10.0, 2.0), 5.0, epsilon = 1e-8);

        // Division by zero falls back to dividing by epsilon.
        assert_abs_diff_eq!(
            stable.safe_div(10.0, 0.0),
            10.0 / stable.epsilon,
            epsilon = 1e-10
        );

        // A tiny denominator is padded by epsilon.
        let small = 1e-15;
        assert_abs_diff_eq!(
            stable.safe_div(10.0, small),
            10.0 / (small + stable.epsilon),
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_clip() {
        let stable = StableMetrics::<f64>::default()
            .with_epsilon(1e-5)
            .with_max_value(1e5);

        // In-range values pass through untouched.
        assert_abs_diff_eq!(stable.clip(50.0), 50.0, epsilon = 1e-10);

        // Values below epsilon clamp up to epsilon.
        assert_abs_diff_eq!(stable.clip(1e-10), 1e-5, epsilon = 1e-10);

        // Values above max_value clamp down to max_value.
        assert_abs_diff_eq!(stable.clip(1e10), 1e5, epsilon = 1e-10);
    }

    #[test]
    fn test_logsumexp() {
        let stable = StableMetrics::<f64>::default();

        // Small values: matches the naive ln(sum(exp)).
        let x = vec![1.0, 2.0, 3.0];
        let expected = (1.0f64.exp() + 2.0f64.exp() + 3.0f64.exp()).ln();
        assert_abs_diff_eq!(stable.logsumexp(&x), expected, epsilon = 1e-10);

        // Values that would overflow the naive formula: the max-shifted
        // form yields max + ln(n).
        let large_vals = vec![1000.0, 1000.0, 1000.0];
        let expected = 1000.0 + (3.0f64).ln();
        assert_abs_diff_eq!(stable.logsumexp(&large_vals), expected, epsilon = 1e-10);

        // Empty input is defined as -inf.
        assert_eq!(stable.logsumexp(&[]), f64::NEG_INFINITY);
    }

    #[test]
    fn test_softmax() {
        let stable = StableMetrics::<f64>::default();

        // Output sums to one and preserves the input ordering.
        let x = vec![1.0, 2.0, 3.0];
        let softmax = stable.softmax(&x);
        let total: f64 = softmax.iter().sum();

        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        assert!(softmax[2] > softmax[1] && softmax[1] > softmax[0]);

        // Large inputs must not overflow thanks to the max shift.
        let large_vals = vec![1000.0, 999.0, 998.0];
        let softmax_large = stable.softmax(&large_vals);
        let total_large: f64 = softmax_large.iter().sum();

        assert_abs_diff_eq!(total_large, 1.0, epsilon = 1e-10);
        assert!(softmax_large[0] > softmax_large[1] && softmax_large[1] > softmax_large[2]);
    }

    #[test]
    fn test_cross_entropy() {
        let stable = StableMetrics::<f64>::default();

        // One-hot target: loss reduces to -ln(prob of the true class).
        let y_true = vec![0.0, 1.0, 0.0];
        let ypred = vec![0.1, 0.8, 0.1];

        let expected = -0.8f64.ln();
        let ce = stable.cross_entropy(&y_true, &ypred).unwrap();
        assert_abs_diff_eq!(ce, expected, epsilon = 1e-10);

        // A zero prediction is clamped, so the loss stays finite.
        let y_pred_zero = vec![0.0, 0.8, 0.2];
        let ce_zero = stable.cross_entropy(&y_true, &y_pred_zero).unwrap();
        assert!(ce_zero.is_finite());

        // Length mismatch is an error.
        let y_pred_short = vec![0.1, 0.9];
        assert!(stable.cross_entropy(&y_true, &y_pred_short).is_err());
    }

    #[test]
    fn test_kl_divergence() {
        let stable = StableMetrics::<f64>::default();

        // Terms with p_i == 0 contribute nothing.
        let p = vec![0.5, 0.5, 0.0];
        let q = vec![0.25, 0.25, 0.5];

        let expected = 0.5 * (0.5 / 0.25).ln() + 0.5 * (0.5 / 0.25).ln();
        let kl = stable.kl_divergence(&p, &q).unwrap();
        assert_abs_diff_eq!(kl, expected, epsilon = 1e-10);

        // KL(p || p) == 0.
        let q_zero = vec![0.5, 0.5, 0.0];
        let kl_zero = stable.kl_divergence(&p, &q_zero).unwrap();
        assert_abs_diff_eq!(kl_zero, 0.0, epsilon = 1e-10);

        // Zeros in q are clamped to epsilon, keeping the result finite.
        let p_nonzero = vec![0.4, 0.3, 0.3];
        let q_more_zeros = vec![0.6, 0.4, 0.0];
        let kl_safe = stable.kl_divergence(&p_nonzero, &q_more_zeros).unwrap();
        assert!(kl_safe.is_finite());
    }

    #[test]
    fn test_js_divergence() {
        let stable = StableMetrics::<f64>::default();

        let p = vec![0.5, 0.5, 0.0];
        let q = vec![0.25, 0.25, 0.5];

        // Hand-computed JS via the mixture m = (p + q) / 2.
        let m = [0.375, 0.375, 0.25]; let kl_p_m_expected = p[0] * (p[0] / m[0]).ln() + p[1] * (p[1] / m[1]).ln();
        let kl_q_m_expected =
            q[0] * (q[0] / m[0]).ln() + q[1] * (q[1] / m[1]).ln() + q[2] * (q[2] / m[2]).ln();
        let expected = 0.5 * (kl_p_m_expected + kl_q_m_expected);

        let js = stable.js_divergence(&p, &q).unwrap();
        assert_abs_diff_eq!(js, expected, epsilon = 1e-10);

        // JS is symmetric in its arguments.
        let js_reverse = stable.js_divergence(&q, &p).unwrap();
        assert_abs_diff_eq!(js, js_reverse, epsilon = 1e-10);

        // JS of a distribution with itself is zero.
        let js_identical = stable.js_divergence(&p, &p).unwrap();
        assert_abs_diff_eq!(js_identical, 0.0, epsilon = 1e-10);
    }

    #[test]
    // NOTE(review): ignored because wasserstein_distance currently builds
    // `v_sorted` from `u_values` (so equal-length samples always measure 0
    // and the 1.0 expectation below fails). Remove #[ignore] once the
    // function sorts `vvalues` instead.
    #[ignore] fn test_wasserstein_distance() {
        let stable = StableMetrics::<f64>::default();

        // Identical samples: distance is zero.
        let u = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let distance = stable.wasserstein_distance(&u, &v).unwrap();
        assert_abs_diff_eq!(distance, 0.0, epsilon = 1e-10);

        // A uniform shift by 1 yields distance 1.
        let w = vec![2.0, 3.0, 4.0, 5.0, 6.0];

        let distance = stable.wasserstein_distance(&u, &w).unwrap();
        assert_abs_diff_eq!(distance, 1.0, epsilon = 1e-10);

        // Different lengths are supported and give a positive distance.
        let x = vec![1.0, 3.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let distance = stable.wasserstein_distance(&x, &y).unwrap();
        assert!(distance > 0.0);
    }

    #[test]
    fn test_maximum_mean_discrepancy() {
        let stable = StableMetrics::<f64>::default();

        // Identical samples: MMD is (numerically) zero; disjoint samples
        // (z) should produce a clearly positive MMD.
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let mmd = stable.maximum_mean_discrepancy(&x, &y, Some(0.1)).unwrap();
        assert!(mmd < 1e-10); let z = vec![6.0, 7.0, 8.0, 9.0, 10.0];

        let mmd = stable.maximum_mean_discrepancy(&x, &z, Some(0.1)).unwrap();
        assert!(mmd > 0.1); }

    #[test]
    fn test_stable_mean() {
        let stable = StableMetrics::<f64>::default();

        // Simple arithmetic mean.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = stable.stable_mean(&values).unwrap();
        assert_abs_diff_eq!(mean, 3.0, epsilon = 1e-10);

        // A single element is its own mean.
        let single = vec![42.0];
        let mean_single = stable.stable_mean(&single).unwrap();
        assert_abs_diff_eq!(mean_single, 42.0, epsilon = 1e-10);

        // Empty input is an error.
        assert!(stable.stable_mean(&[]).is_err());
    }

    #[test]
    fn test_stable_variance() {
        let stable = StableMetrics::<f64>::default();

        // Population variance (ddof = 0) of 1..=5 is 2.0.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let var = stable.stable_variance(&values, 0).unwrap();
        let expected_var = 2.0; assert_abs_diff_eq!(var, expected_var, epsilon = 1e-10);

        // Sample variance (ddof = 1) of 1..=5 is 2.5.
        let sample_var = stable.stable_variance(&values, 1).unwrap();
        let expected_sample_var = 2.5; assert_abs_diff_eq!(sample_var, expected_sample_var, epsilon = 1e-10);

        // len <= ddof is an error.
        assert!(stable.stable_variance(&[1.0], 1).is_err());

        // Empty input is an error.
        assert!(stable.stable_variance(&[], 0).is_err());
    }

    #[test]
    fn test_stable_std() {
        let stable = StableMetrics::<f64>::default();

        // Standard deviation is the square root of the variance.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let std_dev = stable.stable_std(&values, 0).unwrap();
        let expected_std = 2.0f64.sqrt(); assert_abs_diff_eq!(std_dev, expected_std, epsilon = 1e-10);

        let sample_std = stable.stable_std(&values, 1).unwrap();
        let expected_sample_std = 2.5f64.sqrt(); assert_abs_diff_eq!(sample_std, expected_sample_std, epsilon = 1e-10);
    }

    #[test]
    fn test_log1p_expm1() {
        let stable = StableMetrics::<f64>::default();

        // Small argument: series branch must agree with the direct formula.
        let small_x = 1e-8;
        assert_abs_diff_eq!(stable.log1p(small_x), (1.0 + small_x).ln(), epsilon = 1e-15);

        // Large argument: direct branch.
        let x = 1.5;
        assert_abs_diff_eq!(stable.log1p(x), (1.0 + x).ln(), epsilon = 1e-10);

        // Same two regimes for expm1.
        let small_y = 1e-8;
        assert_abs_diff_eq!(stable.expm1(small_y), small_y.exp() - 1.0, epsilon = 1e-15);

        let y = 1.5;
        assert_abs_diff_eq!(stable.expm1(y), y.exp() - 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_matrix_operations() {
        let stable = StableMetrics::<f64>::default();

        // trace(exp(A)) from eigenvalues: sum of exp(lambda_i).
        let eigenvalues = vec![1.0, 2.0, 3.0];
        let exp_trace = stable.matrix_exp_trace(&eigenvalues).unwrap();
        let expected = 1.0f64.exp() + 2.0f64.exp() + 3.0f64.exp();
        assert_abs_diff_eq!(exp_trace, expected, epsilon = 1e-10);

        // ln(det(A)) from eigenvalues: sum of ln(lambda_i).
        let positive_eigenvalues = vec![1.0, 2.0, 5.0];
        let logdet = stable.matrix_logdet(&positive_eigenvalues).unwrap();
        let expected = 1.0f64.ln() + 2.0f64.ln() + 5.0f64.ln();
        assert_abs_diff_eq!(logdet, expected, epsilon = 1e-10);

        // Non-positive eigenvalues make the log-determinant an error.
        let negative_eigenvalues = vec![1.0, -2.0, 3.0];
        assert!(stable.matrix_logdet(&negative_eigenvalues).is_err());
    }
}