scirs2_metrics/regression/
error.rs

1//! Error metrics for regression models
2//!
3//! This module provides functions for calculating error metrics between
4//! predicted values and true values in regression models.
5
6use scirs2_core::ndarray::{ArrayBase, ArrayView1, Data, Dimension};
7use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
8use scirs2_core::simd_ops::SimdUnifiedOps;
9use std::cmp::Ordering;
10
11use super::check_sameshape;
12use crate::error::{MetricsError, Result};
13
14/// Calculates the mean squared error (MSE)
15///
16/// # Mathematical Formulation
17///
18/// Mean Squared Error is defined as:
19///
20/// ```text
21/// MSE = (1/n) * Σ(yᵢ - ŷᵢ)²
22/// ```
23///
24/// Where:
25/// - n = number of samples
26/// - yᵢ = true value for sample i
27/// - ŷᵢ = predicted value for sample i
28/// - Σ = sum over all samples
29///
30/// # Properties
31///
32/// - MSE is always non-negative (≥ 0)
33/// - MSE = 0 indicates perfect predictions
34/// - MSE penalizes larger errors more heavily due to squaring
35/// - Units: squared units of the target variable
36/// - Differentiable everywhere (useful for optimization)
37///
38/// # Interpretation
39///
40/// MSE measures the average squared difference between predicted and actual values:
41/// - Lower MSE indicates better model performance
42/// - Sensitive to outliers due to squaring of errors
43/// - Large errors contribute disproportionately to the total error
44///
45/// # Relationship to Other Metrics
46///
47/// - RMSE = √MSE (same units as target variable)
48/// - MAE typically ≤ RMSE, with equality when all errors are equal
49/// - MSE is the expected value of squared error in probabilistic terms
50///
51/// # Use Cases
52///
53/// MSE is widely used because:
54/// - It's differentiable (good for gradient-based optimization)
55/// - It heavily penalizes large errors
56/// - It's the basis for ordinary least squares regression
57/// - It corresponds to Gaussian likelihood in probabilistic models
58///
59/// # Arguments
60///
61/// * `y_true` - Ground truth (correct) target values
62/// * `y_pred` - Estimated target values
63///
64/// # Returns
65///
66/// * The mean squared error
67///
68/// # Examples
69///
70/// ```no_run
71/// use scirs2_core::ndarray::array;
72/// use scirs2_metrics::regression::mean_squared_error;
73///
74/// let y_true = array![3.0, -0.5, 2.0, 7.0];
75/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
76///
77/// let mse: f64 = mean_squared_error(&y_true, &y_pred).unwrap();
78/// // Expecting: ((3.0-2.5)² + (-0.5-0.0)² + (2.0-2.0)² + (7.0-8.0)²) / 4
79/// assert!(mse < 0.38 && mse > 0.37);
80/// ```
81#[allow(dead_code)]
82pub fn mean_squared_error<F, S1, S2, D1, D2>(
83    y_true: &ArrayBase<S1, D1>,
84    y_pred: &ArrayBase<S2, D2>,
85) -> Result<F>
86where
87    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
88    S1: Data<Elem = F>,
89    S2: Data<Elem = F>,
90    D1: Dimension,
91    D2: Dimension,
92{
93    // Check that arrays have the same shape
94    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
95
96    let n_samples = y_true.len();
97
98    // Use SIMD optimizations for vector operations when data is contiguous
99    let squared_error_sum = if y_true.is_standard_layout() && y_pred.is_standard_layout() {
100        // SIMD-optimized computation - convert to 1D views for SIMD _ops
101        let y_true_view = y_true.view();
102        let y_pred_view = y_pred.view();
103        let y_true_reshaped = y_true_view.to_shape(y_true.len()).unwrap();
104        let y_pred_reshaped = y_pred_view.to_shape(y_pred.len()).unwrap();
105        let y_true_1d = y_true_reshaped.view();
106        let y_pred_1d = y_pred_reshaped.view();
107        let diff = F::simd_sub(&y_true_1d, &y_pred_1d);
108        let squared_diff = F::simd_mul(&diff.view(), &diff.view());
109        F::simd_sum(&squared_diff.view())
110    } else {
111        // Fallback for non-contiguous arrays
112        let mut sum = F::zero();
113        for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
114            let error = *yt - *yp;
115            sum = sum + error * error;
116        }
117        sum
118    };
119
120    Ok(squared_error_sum / NumCast::from(n_samples).unwrap())
121}
122
123/// Calculates the root mean squared error (RMSE)
124///
125/// Root mean squared error is the square root of the mean squared error.
126///
127/// # Arguments
128///
129/// * `y_true` - Ground truth (correct) target values
130/// * `y_pred` - Estimated target values
131///
132/// # Returns
133///
134/// * The root mean squared error
135///
136/// # Examples
137///
138/// ```no_run
139/// use scirs2_core::ndarray::array;
140/// use scirs2_metrics::regression::root_mean_squared_error;
141///
142/// let y_true = array![3.0, -0.5, 2.0, 7.0];
143/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
144///
145/// let rmse: f64 = root_mean_squared_error(&y_true, &y_pred).unwrap();
146/// // RMSE is the square root of MSE
147/// assert!(rmse < 0.62 && rmse > 0.61);
148/// ```
149#[allow(dead_code)]
150pub fn root_mean_squared_error<F, S1, S2, D1, D2>(
151    y_true: &ArrayBase<S1, D1>,
152    y_pred: &ArrayBase<S2, D2>,
153) -> Result<F>
154where
155    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
156    S1: Data<Elem = F>,
157    S2: Data<Elem = F>,
158    D1: Dimension,
159    D2: Dimension,
160{
161    let mse = mean_squared_error(y_true, y_pred)?;
162    Ok(mse.sqrt())
163}
164
165/// Calculates the mean absolute error (MAE)
166///
167/// Mean absolute error measures the average absolute difference between
168/// the estimated values and the actual value.
169///
170/// # Arguments
171///
172/// * `y_true` - Ground truth (correct) target values
173/// * `y_pred` - Estimated target values
174///
175/// # Returns
176///
177/// * The mean absolute error
178///
179/// # Examples
180///
181/// ```no_run
182/// use scirs2_core::ndarray::array;
183/// use scirs2_metrics::regression::mean_absolute_error;
184///
185/// let y_true = array![3.0, -0.5, 2.0, 7.0];
186/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
187///
188/// let mae: f64 = mean_absolute_error(&y_true, &y_pred).unwrap();
189/// // Expecting: (|3.0-2.5| + |-0.5-0.0| + |2.0-2.0| + |7.0-8.0|) / 4 = 0.5
190/// assert!(mae > 0.499 && mae < 0.501);
191/// ```
192#[allow(dead_code)]
193pub fn mean_absolute_error<F, S1, S2, D1, D2>(
194    y_true: &ArrayBase<S1, D1>,
195    y_pred: &ArrayBase<S2, D2>,
196) -> Result<F>
197where
198    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
199    S1: Data<Elem = F>,
200    S2: Data<Elem = F>,
201    D1: Dimension,
202    D2: Dimension,
203{
204    // Check that arrays have the same shape
205    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
206
207    let n_samples = y_true.len();
208
209    // Use SIMD optimizations for vector operations when data is contiguous
210    let abs_error_sum = if y_true.is_standard_layout() && y_pred.is_standard_layout() {
211        // SIMD-optimized computation for 1D arrays
212        let y_true_view = y_true.view();
213        let y_pred_view = y_pred.view();
214        let y_true_reshaped = y_true_view.to_shape(y_true.len()).unwrap();
215        let y_pred_reshaped = y_pred_view.to_shape(y_pred.len()).unwrap();
216        let y_true_1d = y_true_reshaped.view();
217        let y_pred_1d = y_pred_reshaped.view();
218        let diff = F::simd_sub(&y_true_1d, &y_pred_1d);
219        let abs_diff = F::simd_abs(&diff.view());
220        F::simd_sum(&abs_diff.view())
221    } else {
222        // Fallback for non-contiguous arrays
223        let mut sum = F::zero();
224        for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
225            let error = (*yt - *yp).abs();
226            sum = sum + error;
227        }
228        sum
229    };
230
231    Ok(abs_error_sum / NumCast::from(n_samples).unwrap())
232}
233
234/// Calculates the mean absolute percentage error (MAPE)
235///
236/// Mean absolute percentage error expresses the difference between true
237/// and predicted values as a percentage of the true values.
238///
239/// # Arguments
240///
241/// * `y_true` - Ground truth (correct) target values
242/// * `y_pred` - Estimated target values
243///
244/// # Returns
245///
246/// * The mean absolute percentage error
247///
248/// # Notes
249///
250/// MAPE is undefined when true values are zero. This implementation
251/// excludes those samples from the calculation.
252///
253/// # Examples
254///
255/// ```
256/// use scirs2_core::ndarray::array;
257/// use scirs2_metrics::regression::mean_absolute_percentage_error;
258///
259/// let y_true = array![3.0, 0.5, 2.0, 7.0];
260/// let y_pred = array![2.7, 0.4, 1.8, 7.7];
261///
262/// let mape = mean_absolute_percentage_error(&y_true, &y_pred).unwrap();
263/// // Example calculation: (|3.0-2.7|/3.0 + |0.5-0.4|/0.5 + |2.0-1.8|/2.0 + |7.0-7.7|/7.0) / 4 * 100
264/// assert!(mape < 13.0 && mape > 9.0);
265/// ```
266#[allow(dead_code)]
267pub fn mean_absolute_percentage_error<F, S1, S2, D1, D2>(
268    y_true: &ArrayBase<S1, D1>,
269    y_pred: &ArrayBase<S2, D2>,
270) -> Result<F>
271where
272    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
273    S1: Data<Elem = F>,
274    S2: Data<Elem = F>,
275    D1: Dimension,
276    D2: Dimension,
277{
278    // Check that arrays have the same shape
279    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
280
281    let mut percentage_error_sum = F::zero();
282    let mut valid_samples = 0;
283
284    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
285        if yt.abs() > F::epsilon() {
286            let percentage_error = ((*yt - *yp) / *yt).abs();
287            percentage_error_sum = percentage_error_sum + percentage_error;
288            valid_samples += 1;
289        }
290    }
291
292    if valid_samples == 0 {
293        return Err(MetricsError::InvalidInput(
294            "All y_true values are zero. MAPE is undefined.".to_string(),
295        ));
296    }
297
298    // Multiply by 100 to get percentage
299    Ok(percentage_error_sum / NumCast::from(valid_samples).unwrap() * NumCast::from(100).unwrap())
300}
301
302/// Calculates the symmetric mean absolute percentage error (SMAPE)
303///
304/// SMAPE is an alternative to MAPE that handles zero or near-zero values better.
305///
306/// # Arguments
307///
308/// * `y_true` - Ground truth (correct) target values
309/// * `y_pred` - Estimated target values
310///
311/// # Returns
312///
313/// * The symmetric mean absolute percentage error
314///
315/// # Examples
316///
317/// ```
318/// use scirs2_core::ndarray::array;
319/// use scirs2_metrics::regression::symmetric_mean_absolute_percentage_error;
320///
321/// let y_true = array![3.0, 0.01, 2.0, 7.0];
322/// let y_pred = array![2.7, 0.0, 1.8, 7.7];
323///
324/// let smape = symmetric_mean_absolute_percentage_error(&y_true, &y_pred).unwrap();
325/// assert!(smape > 0.0);
326/// ```
327#[allow(dead_code)]
328pub fn symmetric_mean_absolute_percentage_error<F, S1, S2, D1, D2>(
329    y_true: &ArrayBase<S1, D1>,
330    y_pred: &ArrayBase<S2, D2>,
331) -> Result<F>
332where
333    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
334    S1: Data<Elem = F>,
335    S2: Data<Elem = F>,
336    D1: Dimension,
337    D2: Dimension,
338{
339    // Check that arrays have the same shape
340    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
341
342    let mut percentage_error_sum = F::zero();
343    let mut valid_samples = 0;
344
345    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
346        // Skip samples where both y_true and y_pred are zero to avoid undefined values
347        if yt.abs() > F::epsilon() || yp.abs() > F::epsilon() {
348            let percentage_error = ((*yt - *yp).abs()) / (yt.abs() + yp.abs());
349            percentage_error_sum = percentage_error_sum + percentage_error;
350            valid_samples += 1;
351        }
352    }
353
354    if valid_samples == 0 {
355        return Err(MetricsError::InvalidInput(
356            "All values are zero. SMAPE is undefined.".to_string(),
357        ));
358    }
359
360    // Multiply by 200 to get percentage (SMAPE is typically defined with factor of 2)
361    Ok(percentage_error_sum / NumCast::from(valid_samples).unwrap() * NumCast::from(200).unwrap())
362}
363
364/// Calculates the maximum error
365///
366/// Maximum error is the maximum absolute difference between the true and predicted values.
367///
368/// # Arguments
369///
370/// * `y_true` - Ground truth (correct) target values
371/// * `y_pred` - Estimated target values
372///
373/// # Returns
374///
375/// * The maximum error
376///
377/// # Examples
378///
379/// ```
380/// use scirs2_core::ndarray::array;
381/// use scirs2_metrics::regression::max_error;
382///
383/// let y_true = array![3.0, -0.5, 2.0, 7.0];
384/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
385///
386/// let me = max_error(&y_true, &y_pred).unwrap();
387/// // Maximum of [|3.0-2.5|, |-0.5-0.0|, |2.0-2.0|, |7.0-8.0|]
388/// assert_eq!(me, 1.0);
389/// ```
390#[allow(dead_code)]
391pub fn max_error<F, S1, S2, D1, D2>(
392    y_true: &ArrayBase<S1, D1>,
393    y_pred: &ArrayBase<S2, D2>,
394) -> Result<F>
395where
396    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
397    S1: Data<Elem = F>,
398    S2: Data<Elem = F>,
399    D1: Dimension,
400    D2: Dimension,
401{
402    // Check that arrays have the same shape
403    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
404
405    let mut max_err = F::zero();
406    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
407        let error = (*yt - *yp).abs();
408        if error > max_err {
409            max_err = error;
410        }
411    }
412
413    Ok(max_err)
414}
415
416/// Calculates the median absolute error
417///
418/// Median absolute error is the median of all absolute differences between
419/// the true and predicted values. It is robust to outliers.
420///
421/// # Arguments
422///
423/// * `y_true` - Ground truth (correct) target values
424/// * `y_pred` - Estimated target values
425///
426/// # Returns
427///
428/// * The median absolute error
429///
430/// # Examples
431///
432/// ```
433/// use scirs2_core::ndarray::array;
434/// use scirs2_metrics::regression::median_absolute_error;
435///
436/// let y_true = array![3.0, -0.5, 2.0, 7.0];
437/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
438///
439/// let medae = median_absolute_error(&y_true, &y_pred).unwrap();
440/// // Median of [|3.0-2.5|, |-0.5-0.0|, |2.0-2.0|, |7.0-8.0|] = Median of [0.5, 0.5, 0.0, 1.0]
441/// assert_eq!(medae, 0.5);
442/// ```
443#[allow(dead_code)]
444pub fn median_absolute_error<F, S1, S2, D1, D2>(
445    y_true: &ArrayBase<S1, D1>,
446    y_pred: &ArrayBase<S2, D2>,
447) -> Result<F>
448where
449    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
450    S1: Data<Elem = F>,
451    S2: Data<Elem = F>,
452    D1: Dimension,
453    D2: Dimension,
454{
455    // Check that arrays have the same shape
456    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
457
458    let n_samples = y_true.len();
459
460    // Calculate absolute errors
461    let mut abs_errors = Vec::with_capacity(n_samples);
462    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
463        abs_errors.push((*yt - *yp).abs());
464    }
465
466    // Sort and get median
467    abs_errors.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
468
469    if n_samples % 2 == 1 {
470        // Odd number of samples
471        Ok(abs_errors[n_samples / 2])
472    } else {
473        // Even number of samples
474        let mid = n_samples / 2;
475        Ok((abs_errors[mid - 1] + abs_errors[mid]) / NumCast::from(2).unwrap())
476    }
477}
478
479/// Calculates the mean squared logarithmic error (MSLE)
480///
481/// Mean squared logarithmic error measures the average squared difference
482/// between the logarithm of the predicted and true values. This metric penalizes
483/// underestimates more than overestimates.
484///
485/// # Arguments
486///
487/// * `y_true` - Ground truth (correct) target values
488/// * `y_pred` - Estimated target values
489///
490/// # Returns
491///
492/// * The mean squared logarithmic error
493///
494/// # Notes
495///
496/// * This metric cannot be used with negative values
497///
498/// # Examples
499///
500/// ```
501/// use scirs2_core::ndarray::array;
502/// use scirs2_metrics::regression::mean_squared_log_error;
503///
504/// let y_true = array![3.0, 5.0, 2.5, 7.0];
505/// let y_pred = array![2.5, 5.0, 3.0, 8.0];
506///
507/// let msle = mean_squared_log_error(&y_true, &y_pred).unwrap();
508/// assert!(msle > 0.0);
509/// ```
510#[allow(dead_code)]
511pub fn mean_squared_log_error<F, S1, S2, D1, D2>(
512    y_true: &ArrayBase<S1, D1>,
513    y_pred: &ArrayBase<S2, D2>,
514) -> Result<F>
515where
516    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
517    S1: Data<Elem = F>,
518    S2: Data<Elem = F>,
519    D1: Dimension,
520    D2: Dimension,
521{
522    // Check that arrays have the same shape
523    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
524
525    let n_samples = y_true.len();
526
527    // Check that all values are non-negative
528    for &val in y_true.iter() {
529        if val < F::zero() {
530            return Err(MetricsError::InvalidInput(
531                "y_true contains negative values".to_string(),
532            ));
533        }
534    }
535
536    for &val in y_pred.iter() {
537        if val < F::zero() {
538            return Err(MetricsError::InvalidInput(
539                "y_pred contains negative values".to_string(),
540            ));
541        }
542    }
543
544    let mut squared_log_diff_sum = F::zero();
545    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
546        // Add 1 to avoid taking log of 0
547        let log_yt = (*yt + F::one()).ln();
548        let log_yp = (*yp + F::one()).ln();
549        let log_diff = log_yt - log_yp;
550        squared_log_diff_sum = squared_log_diff_sum + log_diff * log_diff;
551    }
552
553    Ok(squared_log_diff_sum / NumCast::from(n_samples).unwrap())
554}
555
556/// Calculates the Huber loss
557///
558/// Huber loss is less sensitive to outliers than squared error loss.
559/// For small errors, it behaves like squared error, and for large errors,
560/// it behaves like absolute error.
561///
562/// # Arguments
563///
564/// * `y_true` - Ground truth (correct) target values
565/// * `y_pred` - Estimated target values
566/// * `delta` - Threshold where the loss changes from squared to linear
567///
568/// # Returns
569///
570/// * The Huber loss
571///
572/// # Examples
573///
574/// ```
575/// use scirs2_core::ndarray::array;
576/// use scirs2_metrics::regression::huber_loss;
577///
578/// let y_true = array![3.0, -0.5, 2.0, 7.0];
579/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
580/// let delta = 0.5;
581///
582/// let loss = huber_loss(&y_true, &y_pred, delta).unwrap();
583/// assert!(loss > 0.0);
584/// ```
585#[allow(dead_code)]
586pub fn huber_loss<F, S1, S2, D1, D2>(
587    y_true: &ArrayBase<S1, D1>,
588    y_pred: &ArrayBase<S2, D2>,
589    delta: F,
590) -> Result<F>
591where
592    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
593    S1: Data<Elem = F>,
594    S2: Data<Elem = F>,
595    D1: Dimension,
596    D2: Dimension,
597{
598    // Check that arrays have the same shape
599    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
600
601    if delta <= F::zero() {
602        return Err(MetricsError::InvalidInput(
603            "delta must be positive".to_string(),
604        ));
605    }
606
607    let n_samples = y_true.len();
608    let mut loss_sum = F::zero();
609
610    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
611        let error = (*yt - *yp).abs();
612        if error <= delta {
613            // Quadratic part
614            loss_sum = loss_sum + F::from(0.5).unwrap() * error * error;
615        } else {
616            // Linear part
617            loss_sum = loss_sum + delta * (error - F::from(0.5).unwrap() * delta);
618        }
619    }
620
621    Ok(loss_sum / NumCast::from(n_samples).unwrap())
622}
623
/// Calculates the normalized root mean squared error (NRMSE)
///
/// NRMSE divides the RMSE by a scale statistic of `y_true`, making the
/// metric comparable across targets with different units or magnitudes.
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) target values
/// * `y_pred` - Estimated target values
/// * `normalization` - Method used for normalization:
///   * "mean" - RMSE / mean(y_true)
///   * "range" - RMSE / (max(y_true) - min(y_true))
///   * "iqr" - RMSE / interquartile range of y_true
///
/// # Returns
///
/// * The normalized root mean squared error
///
/// # Examples
///
/// ```no_run
/// use scirs2_core::ndarray::array;
/// use scirs2_metrics::regression::normalized_root_mean_squared_error;
///
/// let y_true = array![3.0, -0.5, 2.0, 7.0];
/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
///
/// let nrmse_mean: f64 = normalized_root_mean_squared_error(&y_true, &y_pred, "mean").unwrap();
/// let nrmse_range: f64 = normalized_root_mean_squared_error(&y_true, &y_pred, "range").unwrap();
/// assert!(nrmse_mean > 0.0);
/// assert!(nrmse_range > 0.0);
/// ```
#[allow(dead_code)]
pub fn normalized_root_mean_squared_error<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    normalization: &str,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
    S1: Data<Elem = F>,
    S2: Data<Elem = F>,
    D1: Dimension,
    D2: Dimension,
{
    // Shape validation happens inside root_mean_squared_error.
    let rmse = root_mean_squared_error(y_true, y_pred)?;

    match normalization {
        "mean" => {
            // RMSE / |mean(y_true)|; the absolute value keeps the result
            // positive for negative-mean targets.
            let mean = y_true.iter().fold(F::zero(), |acc, &y| acc + y)
                / NumCast::from(y_true.len()).unwrap();
            if mean.abs() < F::epsilon() {
                return Err(MetricsError::InvalidInput(
                    "Mean of y_true is zero, cannot normalize by mean".to_string(),
                ));
            }
            Ok(rmse / mean.abs())
        }
        "range" => {
            // RMSE / (max(y_true) - min(y_true)); folds start from the
            // opposite infinities so any finite value replaces them.
            let max = y_true
                .iter()
                .fold(F::neg_infinity(), |acc, &y| if y > acc { y } else { acc });
            let min = y_true
                .iter()
                .fold(F::infinity(), |acc, &y| if y < acc { y } else { acc });
            let range = max - min;
            if range < F::epsilon() {
                return Err(MetricsError::InvalidInput(
                    "Range of y_true is zero, cannot normalize by range".to_string(),
                ));
            }
            Ok(rmse / range)
        }
        "iqr" => {
            // RMSE / interquartile range (Q3 - Q1) of y_true.
            let mut values: Vec<F> = y_true.iter().cloned().collect();
            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));

            let n = values.len();
            let q1_idx = n / 4;
            let q3_idx = 3 * n / 4;

            // When n is divisible by 4 the quartile falls between two
            // elements, so average them; otherwise take the element at the
            // truncated index.
            // NOTE(review): for n == 0 the `q1_idx - 1` below underflows and
            // panics — confirm whether empty input is excluded upstream.
            let q1 = if n.is_multiple_of(4) {
                (values[q1_idx - 1] + values[q1_idx]) / NumCast::from(2).unwrap()
            } else {
                values[q1_idx]
            };

            let q3 = if n.is_multiple_of(4) {
                (values[q3_idx - 1] + values[q3_idx]) / NumCast::from(2).unwrap()
            } else {
                values[q3_idx]
            };

            let iqr = q3 - q1;
            if iqr < F::epsilon() {
                return Err(MetricsError::InvalidInput(
                    "Interquartile range of y_true is zero, cannot normalize by IQR".to_string(),
                ));
            }
            Ok(rmse / iqr)
        }
        // Any other normalization string is rejected explicitly.
        _ => Err(MetricsError::InvalidInput(format!(
            "Unknown normalization method: {}. Valid options are 'mean', 'range', 'iqr'.",
            normalization
        ))),
    }
}
731
732/// Calculates the relative absolute error (RAE)
733///
734/// RAE is the ratio of the sum of absolute errors to the sum of absolute
735/// deviations from the mean of the true values.
736///
737/// # Arguments
738///
739/// * `y_true` - Ground truth (correct) target values
740/// * `y_pred` - Estimated target values
741///
742/// # Returns
743///
744/// * The relative absolute error
745///
746/// # Examples
747///
748/// ```
749/// use scirs2_core::ndarray::array;
750/// use scirs2_metrics::regression::relative_absolute_error;
751///
752/// let y_true = array![3.0, -0.5, 2.0, 7.0];
753/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
754///
755/// let rae = relative_absolute_error(&y_true, &y_pred).unwrap();
756/// assert!(rae > 0.0 && rae < 1.0);
757/// ```
758#[allow(dead_code)]
759pub fn relative_absolute_error<F, S1, S2, D1, D2>(
760    y_true: &ArrayBase<S1, D1>,
761    y_pred: &ArrayBase<S2, D2>,
762) -> Result<F>
763where
764    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
765    S1: Data<Elem = F>,
766    S2: Data<Elem = F>,
767    D1: Dimension,
768    D2: Dimension,
769{
770    // Check that arrays have the same shape
771    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
772
773    // Calculate mean of y_true
774    let y_true_mean =
775        y_true.iter().fold(F::zero(), |acc, &y| acc + y) / NumCast::from(y_true.len()).unwrap();
776
777    let mut abs_error_sum = F::zero();
778    let mut abs_mean_diff_sum = F::zero();
779
780    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
781        abs_error_sum = abs_error_sum + (*yt - *yp).abs();
782        abs_mean_diff_sum = abs_mean_diff_sum + (*yt - y_true_mean).abs();
783    }
784
785    if abs_mean_diff_sum < F::epsilon() {
786        return Err(MetricsError::InvalidInput(
787            "Sum of absolute deviations from mean is zero".to_string(),
788        ));
789    }
790
791    Ok(abs_error_sum / abs_mean_diff_sum)
792}
793
794/// Calculates the relative squared error (RSE)
795///
796/// RSE is the ratio of the sum of squared errors to the sum of squared
797/// deviations from the mean of the true values.
798///
799/// # Arguments
800///
801/// * `y_true` - Ground truth (correct) target values
802/// * `y_pred` - Estimated target values
803///
804/// # Returns
805///
806/// * The relative squared error
807///
808/// # Examples
809///
810/// ```
811/// use scirs2_core::ndarray::array;
812/// use scirs2_metrics::regression::relative_squared_error;
813///
814/// let y_true = array![3.0, -0.5, 2.0, 7.0];
815/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
816///
817/// let rse = relative_squared_error(&y_true, &y_pred).unwrap();
818/// assert!(rse > 0.0 && rse < 1.0);
819/// ```
820#[allow(dead_code)]
821pub fn relative_squared_error<F, S1, S2, D1, D2>(
822    y_true: &ArrayBase<S1, D1>,
823    y_pred: &ArrayBase<S2, D2>,
824) -> Result<F>
825where
826    F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
827    S1: Data<Elem = F>,
828    S2: Data<Elem = F>,
829    D1: Dimension,
830    D2: Dimension,
831{
832    // Check that arrays have the same shape
833    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
834
835    // Calculate mean of y_true
836    let y_true_mean =
837        y_true.iter().fold(F::zero(), |acc, &y| acc + y) / NumCast::from(y_true.len()).unwrap();
838
839    let mut squared_error_sum = F::zero();
840    let mut squared_mean_diff_sum = F::zero();
841
842    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
843        let error = *yt - *yp;
844        squared_error_sum = squared_error_sum + error * error;
845
846        let mean_diff = *yt - y_true_mean;
847        squared_mean_diff_sum = squared_mean_diff_sum + mean_diff * mean_diff;
848    }
849
850    if squared_mean_diff_sum < F::epsilon() {
851        return Err(MetricsError::InvalidInput(
852            "Sum of squared deviations from mean is zero".to_string(),
853        ));
854    }
855
856    Ok(squared_error_sum / squared_mean_diff_sum)
857}