scirs2_metrics/regression/
mod.rs1mod correlation;
7mod error;
8mod residual;
9mod robust;
10
11pub use self::correlation::*;
13pub use self::error::*;
14pub use self::residual::*;
15pub use self::robust::*;
16
17use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
19use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
20
21pub(crate) fn check_sameshape<F, S1, S2, D1, D2>(
23 y_true: &ArrayBase<S1, D1>,
24 y_pred: &ArrayBase<S2, D2>,
25) -> crate::error::Result<()>
26where
27 F: scirs2_core::numeric::Float,
28 S1: scirs2_core::ndarray::Data<Elem = F>,
29 S2: scirs2_core::ndarray::Data<Elem = F>,
30 D1: scirs2_core::ndarray::Dimension,
31 D2: scirs2_core::ndarray::Dimension,
32{
33 if y_true.shape() != y_pred.shape() {
34 return Err(crate::error::MetricsError::InvalidInput(format!(
35 "y_true and y_pred have different shapes: {:?} vs {:?}",
36 y_true.shape(),
37 y_pred.shape()
38 )));
39 }
40
41 let n_samples = y_true.len();
42 if n_samples == 0 {
43 return Err(crate::error::MetricsError::InvalidInput(
44 "Empty arrays provided".to_string(),
45 ));
46 }
47
48 Ok(())
49}
50
51pub(crate) fn check_non_negative<F, S1, S2, D1, D2>(
53 y_true: &ArrayBase<S1, D1>,
54 y_pred: &ArrayBase<S2, D2>,
55) -> crate::error::Result<()>
56where
57 F: scirs2_core::numeric::Float + std::fmt::Debug,
58 S1: scirs2_core::ndarray::Data<Elem = F>,
59 S2: scirs2_core::ndarray::Data<Elem = F>,
60 D1: scirs2_core::ndarray::Dimension,
61 D2: scirs2_core::ndarray::Dimension,
62{
63 for val in y_true.iter() {
64 if *val < F::zero() {
65 return Err(crate::error::MetricsError::InvalidInput(
66 "y_true contains negative values".to_string(),
67 ));
68 }
69 }
70
71 for val in y_pred.iter() {
72 if *val < F::zero() {
73 return Err(crate::error::MetricsError::InvalidInput(
74 "y_pred contains negative values".to_string(),
75 ));
76 }
77 }
78
79 Ok(())
80}
81
82pub(crate) fn check_positive<F, S1, S2, D1, D2>(
84 y_true: &ArrayBase<S1, D1>,
85 y_pred: &ArrayBase<S2, D2>,
86) -> crate::error::Result<()>
87where
88 F: scirs2_core::numeric::Float + std::fmt::Debug,
89 S1: scirs2_core::ndarray::Data<Elem = F>,
90 S2: scirs2_core::ndarray::Data<Elem = F>,
91 D1: scirs2_core::ndarray::Dimension,
92 D2: scirs2_core::ndarray::Dimension,
93{
94 for val in y_true.iter() {
95 if *val <= F::zero() {
96 return Err(crate::error::MetricsError::InvalidInput(
97 "y_true contains non-positive values".to_string(),
98 ));
99 }
100 }
101
102 for val in y_pred.iter() {
103 if *val <= F::zero() {
104 return Err(crate::error::MetricsError::InvalidInput(
105 "y_pred contains non-positive values".to_string(),
106 ));
107 }
108 }
109
110 Ok(())
111}
112
113pub(crate) fn mean<F, S, D>(arr: &ArrayBase<S, D>) -> F
115where
116 F: scirs2_core::numeric::Float,
117 S: scirs2_core::ndarray::Data<Elem = F>,
118 D: scirs2_core::ndarray::Dimension,
119{
120 let sum = arr.iter().fold(F::zero(), |acc, &x| acc + x);
121 sum / scirs2_core::numeric::NumCast::from(arr.len()).unwrap()
122}