mean_squared_error

Function mean_squared_error 

pub fn mean_squared_error<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<F>
where F: Float + NumCast + Debug + SimdUnifiedOps, S1: Data<Elem = F>, S2: Data<Elem = F>, D1: Dimension, D2: Dimension,

Calculates the mean squared error (MSE) between y_true and y_pred.

§Mathematical Formulation

Mean Squared Error is defined as:

MSE = (1/n) * Σ(yᵢ - ŷᵢ)²

Where:

  • n = number of samples
  • yᵢ = true value for sample i
  • ŷᵢ = predicted value for sample i
  • Σ = sum over all samples
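The formula translates directly into a loop over paired samples. The sketch below is a plain-Rust illustration on `f64` slices, not the crate's implementation. `mse` is a hypothetical helper name, and returning `None` for empty or mismatched inputs is an assumption of this sketch, not necessarily how `mean_squared_error` reports errors.

```rust
/// MSE = (1/n) * Σ(yᵢ - ŷᵢ)² over two equal-length slices.
/// Returns `None` for empty or length-mismatched inputs
/// (an assumption of this sketch, not the crate's documented behavior).
fn mse(y_true: &[f64], y_pred: &[f64]) -> Option<f64> {
    if y_true.is_empty() || y_true.len() != y_pred.len() {
        return None;
    }
    // Σ(yᵢ - ŷᵢ)²
    let sum_sq: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(y, y_hat)| (y - y_hat).powi(2))
        .sum();
    // Divide by n to get the mean.
    Some(sum_sq / y_true.len() as f64)
}
```

With the data from the example at the end of this page, the squared errors are 0.25, 0.25, 0 and 1, so the result is 1.5 / 4 = 0.375.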

§Properties

  • MSE is always non-negative (≥ 0)
  • MSE = 0 indicates perfect predictions
  • MSE penalizes larger errors more heavily due to squaring
  • Units: squared units of the target variable
  • Differentiable everywhere (useful for optimization)
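Two of these properties can be checked numerically. Squaring means that doubling a residual quadruples its contribution. The derivative with respect to each prediction has the closed form ∂MSE/∂ŷᵢ = (2/n)(ŷᵢ − yᵢ). The following is a minimal plain-Rust sketch under those definitions. The helper names `mse`, `mse_grad` and `numeric_grad` are hypothetical and not part of this crate.

```rust
/// Plain MSE, assuming equal, non-zero lengths (no validation in this sketch).
fn mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    y_true
        .iter()
        .zip(y_pred)
        .map(|(y, y_hat)| (y - y_hat).powi(2))
        .sum::<f64>()
        / n
}

/// Analytic gradient with respect to each prediction: (2/n) * (ŷᵢ - yᵢ).
fn mse_grad(y_true: &[f64], y_pred: &[f64]) -> Vec<f64> {
    let n = y_true.len() as f64;
    y_true
        .iter()
        .zip(y_pred)
        .map(|(y, y_hat)| 2.0 / n * (y_hat - y))
        .collect()
}

/// Central finite-difference estimate of ∂MSE/∂ŷᵢ, used to cross-check `mse_grad`.
fn numeric_grad(y_true: &[f64], y_pred: &[f64], i: usize, h: f64) -> f64 {
    let mut plus = y_pred.to_vec();
    let mut minus = y_pred.to_vec();
    plus[i] += h;
    minus[i] -= h;
    (mse(y_true, &plus) - mse(y_true, &minus)) / (2.0 * h)
}
```

Because MSE is quadratic in each ŷᵢ, the central difference agrees with the analytic gradient up to floating-point rounding.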

§Interpretation

MSE measures the average squared difference between predicted and actual values:

  • Lower MSE indicates better model performance
  • Sensitive to outliers due to squaring of errors
  • Large errors contribute disproportionately to the total error
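To make the outlier sensitivity concrete, the sketch below compares MSE with mean absolute error (MAE) when one residual grows from 1 to 10. MAE is brought in only as a point of comparison. `mse` and `mae` are hypothetical plain-Rust helpers on `f64` slices, not this crate's functions.

```rust
/// Mean of squared residuals (assumes equal, non-zero lengths).
fn mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred).map(|(y, p)| (y - p).powi(2)).sum::<f64>() / n
}

/// Mean of absolute residuals, for comparison only.
fn mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred).map(|(y, p)| (y - p).abs()).sum::<f64>() / n
}
```

With targets `[0.0; 4]`, predictions `[1, 1, 1, 1]` give MSE = MAE = 1. Changing the last prediction to 10 raises MAE to 13 / 4 = 3.25 but MSE to 103 / 4 = 25.75, because the single outlier contributes 100 of the 103 units of squared error.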

§Relationship to Other Metrics

  • RMSE = √MSE (same units as target variable)
  • MAE ≤ RMSE always, with equality exactly when all absolute errors |yᵢ - ŷᵢ| are equal
  • MSE is the sample estimate of the expected squared error E[(Y - Ŷ)²]
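The RMSE and MAE relationships can be checked directly. The following is a minimal plain-Rust sketch. The helpers `mse`, `rmse` and `mae` are hypothetical names operating on `f64` slices, not functions confirmed to exist in this crate.

```rust
/// Mean of squared residuals (assumes equal, non-zero lengths).
fn mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred).map(|(y, p)| (y - p).powi(2)).sum::<f64>() / n
}

/// RMSE = √MSE, expressed in the same units as the target.
fn rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    mse(y_true, y_pred).sqrt()
}

/// Mean of absolute residuals.
fn mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len() as f64;
    y_true.iter().zip(y_pred).map(|(y, p)| (y - p).abs()).sum::<f64>() / n
}
```

For the example data on this page, MAE = 2 / 4 = 0.5 and RMSE = √0.375 ≈ 0.612, so MAE < RMSE. For residuals of equal magnitude, such as targets `[0, 0, 0]` and predictions `[2, -2, 2]`, both equal 2.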

§Use Cases

MSE is widely used because:

  • It’s differentiable (good for gradient-based optimization)
  • It heavily penalizes large errors
  • It’s the basis for ordinary least squares regression
  • Minimizing it is equivalent to maximizing the likelihood under Gaussian noise with fixed variance
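The link to ordinary least squares can be shown with a one-feature line y ≈ a + b·x. The closed-form OLS estimates b = Σ(xᵢ - x̄)(yᵢ - ȳ) / Σ(xᵢ - x̄)² and a = ȳ - b·x̄ are the parameters that minimize the MSE of the fitted line. This sketch is illustrative only. `ols_fit` and `line_mse` are hypothetical names, and the sketch assumes at least two distinct x values.

```rust
/// Closed-form OLS for y ≈ a + b·x; returns (a, b).
/// Assumes x has at least two distinct values (otherwise sxx == 0).
fn ols_fit(x: &[f64], y: &[f64]) -> (f64, f64) {
    let n = x.len() as f64;
    let x_mean = x.iter().sum::<f64>() / n;
    let y_mean = y.iter().sum::<f64>() / n;
    let mut sxy = 0.0;
    let mut sxx = 0.0;
    for (xi, yi) in x.iter().zip(y) {
        sxy += (xi - x_mean) * (yi - y_mean);
        sxx += (xi - x_mean).powi(2);
    }
    let b = sxy / sxx;
    (y_mean - b * x_mean, b)
}

/// MSE of the line a + b·x against the observed y.
fn line_mse(x: &[f64], y: &[f64], a: f64, b: f64) -> f64 {
    let n = x.len() as f64;
    x.iter()
        .zip(y)
        .map(|(xi, yi)| (yi - (a + b * xi)).powi(2))
        .sum::<f64>()
        / n
}
```

For x = [0, 1, 2, 3] and y = [1, 2, 2, 4] this gives a = b = 0.9. Nudging either parameter in any direction increases `line_mse`, since MSE is strictly convex in (a, b) when the x values are not all equal.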

§Arguments

  • y_true - Ground truth (correct) target values
  • y_pred - Estimated target values

§Returns

  • The mean squared error

§Examples

```rust
use scirs2_core::ndarray::array;
use scirs2_metrics::regression::mean_squared_error;

let y_true = array![3.0, -0.5, 2.0, 7.0];
let y_pred = array![2.5, 0.0, 2.0, 8.0];

let mse: f64 = mean_squared_error(&y_true, &y_pred).unwrap();
// Expected: ((3.0-2.5)² + (-0.5-0.0)² + (2.0-2.0)² + (7.0-8.0)²) / 4 = 1.5 / 4 = 0.375
assert!(mse < 0.38 && mse > 0.37);
```