scirs2_metrics/regression/
mod.rs

//! Regression metrics module.
//!
//! This module provides functions for evaluating regression models, including
//! error metrics, correlation metrics, residual analysis, and robust metrics.
5
6mod correlation;
7mod error;
8mod residual;
9mod robust;
10
11// Re-export all public items from submodules
12pub use self::correlation::*;
13pub use self::error::*;
14pub use self::residual::*;
15pub use self::robust::*;
16
// Shared input-validation and statistics helpers, exposed crate-wide (`pub(crate)`) for the regression submodules.
18use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
19use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
20
21/// Check if two arrays have the same shape
22pub(crate) fn check_sameshape<F, S1, S2, D1, D2>(
23    y_true: &ArrayBase<S1, D1>,
24    y_pred: &ArrayBase<S2, D2>,
25) -> crate::error::Result<()>
26where
27    F: scirs2_core::numeric::Float,
28    S1: scirs2_core::ndarray::Data<Elem = F>,
29    S2: scirs2_core::ndarray::Data<Elem = F>,
30    D1: scirs2_core::ndarray::Dimension,
31    D2: scirs2_core::ndarray::Dimension,
32{
33    if y_true.shape() != y_pred.shape() {
34        return Err(crate::error::MetricsError::InvalidInput(format!(
35            "y_true and y_pred have different shapes: {:?} vs {:?}",
36            y_true.shape(),
37            y_pred.shape()
38        )));
39    }
40
41    let n_samples = y_true.len();
42    if n_samples == 0 {
43        return Err(crate::error::MetricsError::InvalidInput(
44            "Empty arrays provided".to_string(),
45        ));
46    }
47
48    Ok(())
49}
50
51/// Check if all values in arrays are non-negative
52pub(crate) fn check_non_negative<F, S1, S2, D1, D2>(
53    y_true: &ArrayBase<S1, D1>,
54    y_pred: &ArrayBase<S2, D2>,
55) -> crate::error::Result<()>
56where
57    F: scirs2_core::numeric::Float + std::fmt::Debug,
58    S1: scirs2_core::ndarray::Data<Elem = F>,
59    S2: scirs2_core::ndarray::Data<Elem = F>,
60    D1: scirs2_core::ndarray::Dimension,
61    D2: scirs2_core::ndarray::Dimension,
62{
63    for val in y_true.iter() {
64        if *val < F::zero() {
65            return Err(crate::error::MetricsError::InvalidInput(
66                "y_true contains negative values".to_string(),
67            ));
68        }
69    }
70
71    for val in y_pred.iter() {
72        if *val < F::zero() {
73            return Err(crate::error::MetricsError::InvalidInput(
74                "y_pred contains negative values".to_string(),
75            ));
76        }
77    }
78
79    Ok(())
80}
81
82/// Check if all values in arrays are strictly positive
83pub(crate) fn check_positive<F, S1, S2, D1, D2>(
84    y_true: &ArrayBase<S1, D1>,
85    y_pred: &ArrayBase<S2, D2>,
86) -> crate::error::Result<()>
87where
88    F: scirs2_core::numeric::Float + std::fmt::Debug,
89    S1: scirs2_core::ndarray::Data<Elem = F>,
90    S2: scirs2_core::ndarray::Data<Elem = F>,
91    D1: scirs2_core::ndarray::Dimension,
92    D2: scirs2_core::ndarray::Dimension,
93{
94    for val in y_true.iter() {
95        if *val <= F::zero() {
96            return Err(crate::error::MetricsError::InvalidInput(
97                "y_true contains non-positive values".to_string(),
98            ));
99        }
100    }
101
102    for val in y_pred.iter() {
103        if *val <= F::zero() {
104            return Err(crate::error::MetricsError::InvalidInput(
105                "y_pred contains non-positive values".to_string(),
106            ));
107        }
108    }
109
110    Ok(())
111}
112
113/// Calculate the mean of an array
114pub(crate) fn mean<F, S, D>(arr: &ArrayBase<S, D>) -> F
115where
116    F: scirs2_core::numeric::Float,
117    S: scirs2_core::ndarray::Data<Elem = F>,
118    D: scirs2_core::ndarray::Dimension,
119{
120    let sum = arr.iter().fold(F::zero(), |acc, &x| acc + x);
121    sum / scirs2_core::numeric::NumCast::from(arr.len()).unwrap()
122}