scirs2_metrics/regression/error.rs
1//! Error metrics for regression models
2//!
3//! This module provides functions for calculating error metrics between
4//! predicted values and true values in regression models.
5
6use scirs2_core::ndarray::{ArrayBase, ArrayView1, Data, Dimension};
7use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
8use scirs2_core::simd_ops::SimdUnifiedOps;
9use std::cmp::Ordering;
10
11use super::check_sameshape;
12use crate::error::{MetricsError, Result};
13
14/// Calculates the mean squared error (MSE)
15///
16/// # Mathematical Formulation
17///
18/// Mean Squared Error is defined as:
19///
20/// ```text
21/// MSE = (1/n) * Σ(yᵢ - ŷᵢ)²
22/// ```
23///
24/// Where:
25/// - n = number of samples
26/// - yᵢ = true value for sample i
27/// - ŷᵢ = predicted value for sample i
28/// - Σ = sum over all samples
29///
30/// # Properties
31///
32/// - MSE is always non-negative (≥ 0)
33/// - MSE = 0 indicates perfect predictions
34/// - MSE penalizes larger errors more heavily due to squaring
35/// - Units: squared units of the target variable
36/// - Differentiable everywhere (useful for optimization)
37///
38/// # Interpretation
39///
40/// MSE measures the average squared difference between predicted and actual values:
41/// - Lower MSE indicates better model performance
42/// - Sensitive to outliers due to squaring of errors
43/// - Large errors contribute disproportionately to the total error
44///
45/// # Relationship to Other Metrics
46///
47/// - RMSE = √MSE (same units as target variable)
48/// - MAE typically ≤ RMSE, with equality when all errors are equal
49/// - MSE is the expected value of squared error in probabilistic terms
50///
51/// # Use Cases
52///
53/// MSE is widely used because:
54/// - It's differentiable (good for gradient-based optimization)
55/// - It heavily penalizes large errors
56/// - It's the basis for ordinary least squares regression
57/// - It corresponds to Gaussian likelihood in probabilistic models
58///
59/// # Arguments
60///
61/// * `y_true` - Ground truth (correct) target values
62/// * `y_pred` - Estimated target values
63///
64/// # Returns
65///
66/// * The mean squared error
67///
68/// # Examples
69///
70/// ```no_run
71/// use scirs2_core::ndarray::array;
72/// use scirs2_metrics::regression::mean_squared_error;
73///
74/// let y_true = array![3.0, -0.5, 2.0, 7.0];
75/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
76///
77/// let mse: f64 = mean_squared_error(&y_true, &y_pred).unwrap();
78/// // Expecting: ((3.0-2.5)² + (-0.5-0.0)² + (2.0-2.0)² + (7.0-8.0)²) / 4
79/// assert!(mse < 0.38 && mse > 0.37);
80/// ```
81#[allow(dead_code)]
82pub fn mean_squared_error<F, S1, S2, D1, D2>(
83 y_true: &ArrayBase<S1, D1>,
84 y_pred: &ArrayBase<S2, D2>,
85) -> Result<F>
86where
87 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
88 S1: Data<Elem = F>,
89 S2: Data<Elem = F>,
90 D1: Dimension,
91 D2: Dimension,
92{
93 // Check that arrays have the same shape
94 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
95
96 let n_samples = y_true.len();
97
98 // Use SIMD optimizations for vector operations when data is contiguous
99 let squared_error_sum = if y_true.is_standard_layout() && y_pred.is_standard_layout() {
100 // SIMD-optimized computation - convert to 1D views for SIMD _ops
101 let y_true_view = y_true.view();
102 let y_pred_view = y_pred.view();
103 let y_true_reshaped = y_true_view.to_shape(y_true.len()).unwrap();
104 let y_pred_reshaped = y_pred_view.to_shape(y_pred.len()).unwrap();
105 let y_true_1d = y_true_reshaped.view();
106 let y_pred_1d = y_pred_reshaped.view();
107 let diff = F::simd_sub(&y_true_1d, &y_pred_1d);
108 let squared_diff = F::simd_mul(&diff.view(), &diff.view());
109 F::simd_sum(&squared_diff.view())
110 } else {
111 // Fallback for non-contiguous arrays
112 let mut sum = F::zero();
113 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
114 let error = *yt - *yp;
115 sum = sum + error * error;
116 }
117 sum
118 };
119
120 Ok(squared_error_sum / NumCast::from(n_samples).unwrap())
121}
122
123/// Calculates the root mean squared error (RMSE)
124///
125/// Root mean squared error is the square root of the mean squared error.
126///
127/// # Arguments
128///
129/// * `y_true` - Ground truth (correct) target values
130/// * `y_pred` - Estimated target values
131///
132/// # Returns
133///
134/// * The root mean squared error
135///
136/// # Examples
137///
138/// ```no_run
139/// use scirs2_core::ndarray::array;
140/// use scirs2_metrics::regression::root_mean_squared_error;
141///
142/// let y_true = array![3.0, -0.5, 2.0, 7.0];
143/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
144///
145/// let rmse: f64 = root_mean_squared_error(&y_true, &y_pred).unwrap();
146/// // RMSE is the square root of MSE
147/// assert!(rmse < 0.62 && rmse > 0.61);
148/// ```
149#[allow(dead_code)]
150pub fn root_mean_squared_error<F, S1, S2, D1, D2>(
151 y_true: &ArrayBase<S1, D1>,
152 y_pred: &ArrayBase<S2, D2>,
153) -> Result<F>
154where
155 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
156 S1: Data<Elem = F>,
157 S2: Data<Elem = F>,
158 D1: Dimension,
159 D2: Dimension,
160{
161 let mse = mean_squared_error(y_true, y_pred)?;
162 Ok(mse.sqrt())
163}
164
165/// Calculates the mean absolute error (MAE)
166///
167/// Mean absolute error measures the average absolute difference between
168/// the estimated values and the actual value.
169///
170/// # Arguments
171///
172/// * `y_true` - Ground truth (correct) target values
173/// * `y_pred` - Estimated target values
174///
175/// # Returns
176///
177/// * The mean absolute error
178///
179/// # Examples
180///
181/// ```no_run
182/// use scirs2_core::ndarray::array;
183/// use scirs2_metrics::regression::mean_absolute_error;
184///
185/// let y_true = array![3.0, -0.5, 2.0, 7.0];
186/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
187///
188/// let mae: f64 = mean_absolute_error(&y_true, &y_pred).unwrap();
189/// // Expecting: (|3.0-2.5| + |-0.5-0.0| + |2.0-2.0| + |7.0-8.0|) / 4 = 0.5
190/// assert!(mae > 0.499 && mae < 0.501);
191/// ```
192#[allow(dead_code)]
193pub fn mean_absolute_error<F, S1, S2, D1, D2>(
194 y_true: &ArrayBase<S1, D1>,
195 y_pred: &ArrayBase<S2, D2>,
196) -> Result<F>
197where
198 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
199 S1: Data<Elem = F>,
200 S2: Data<Elem = F>,
201 D1: Dimension,
202 D2: Dimension,
203{
204 // Check that arrays have the same shape
205 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
206
207 let n_samples = y_true.len();
208
209 // Use SIMD optimizations for vector operations when data is contiguous
210 let abs_error_sum = if y_true.is_standard_layout() && y_pred.is_standard_layout() {
211 // SIMD-optimized computation for 1D arrays
212 let y_true_view = y_true.view();
213 let y_pred_view = y_pred.view();
214 let y_true_reshaped = y_true_view.to_shape(y_true.len()).unwrap();
215 let y_pred_reshaped = y_pred_view.to_shape(y_pred.len()).unwrap();
216 let y_true_1d = y_true_reshaped.view();
217 let y_pred_1d = y_pred_reshaped.view();
218 let diff = F::simd_sub(&y_true_1d, &y_pred_1d);
219 let abs_diff = F::simd_abs(&diff.view());
220 F::simd_sum(&abs_diff.view())
221 } else {
222 // Fallback for non-contiguous arrays
223 let mut sum = F::zero();
224 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
225 let error = (*yt - *yp).abs();
226 sum = sum + error;
227 }
228 sum
229 };
230
231 Ok(abs_error_sum / NumCast::from(n_samples).unwrap())
232}
233
234/// Calculates the mean absolute percentage error (MAPE)
235///
236/// Mean absolute percentage error expresses the difference between true
237/// and predicted values as a percentage of the true values.
238///
239/// # Arguments
240///
241/// * `y_true` - Ground truth (correct) target values
242/// * `y_pred` - Estimated target values
243///
244/// # Returns
245///
246/// * The mean absolute percentage error
247///
248/// # Notes
249///
250/// MAPE is undefined when true values are zero. This implementation
251/// excludes those samples from the calculation.
252///
253/// # Examples
254///
255/// ```
256/// use scirs2_core::ndarray::array;
257/// use scirs2_metrics::regression::mean_absolute_percentage_error;
258///
259/// let y_true = array![3.0, 0.5, 2.0, 7.0];
260/// let y_pred = array![2.7, 0.4, 1.8, 7.7];
261///
262/// let mape = mean_absolute_percentage_error(&y_true, &y_pred).unwrap();
263/// // Example calculation: (|3.0-2.7|/3.0 + |0.5-0.4|/0.5 + |2.0-1.8|/2.0 + |7.0-7.7|/7.0) / 4 * 100
264/// assert!(mape < 13.0 && mape > 9.0);
265/// ```
266#[allow(dead_code)]
267pub fn mean_absolute_percentage_error<F, S1, S2, D1, D2>(
268 y_true: &ArrayBase<S1, D1>,
269 y_pred: &ArrayBase<S2, D2>,
270) -> Result<F>
271where
272 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
273 S1: Data<Elem = F>,
274 S2: Data<Elem = F>,
275 D1: Dimension,
276 D2: Dimension,
277{
278 // Check that arrays have the same shape
279 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
280
281 let mut percentage_error_sum = F::zero();
282 let mut valid_samples = 0;
283
284 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
285 if yt.abs() > F::epsilon() {
286 let percentage_error = ((*yt - *yp) / *yt).abs();
287 percentage_error_sum = percentage_error_sum + percentage_error;
288 valid_samples += 1;
289 }
290 }
291
292 if valid_samples == 0 {
293 return Err(MetricsError::InvalidInput(
294 "All y_true values are zero. MAPE is undefined.".to_string(),
295 ));
296 }
297
298 // Multiply by 100 to get percentage
299 Ok(percentage_error_sum / NumCast::from(valid_samples).unwrap() * NumCast::from(100).unwrap())
300}
301
302/// Calculates the symmetric mean absolute percentage error (SMAPE)
303///
304/// SMAPE is an alternative to MAPE that handles zero or near-zero values better.
305///
306/// # Arguments
307///
308/// * `y_true` - Ground truth (correct) target values
309/// * `y_pred` - Estimated target values
310///
311/// # Returns
312///
313/// * The symmetric mean absolute percentage error
314///
315/// # Examples
316///
317/// ```
318/// use scirs2_core::ndarray::array;
319/// use scirs2_metrics::regression::symmetric_mean_absolute_percentage_error;
320///
321/// let y_true = array![3.0, 0.01, 2.0, 7.0];
322/// let y_pred = array![2.7, 0.0, 1.8, 7.7];
323///
324/// let smape = symmetric_mean_absolute_percentage_error(&y_true, &y_pred).unwrap();
325/// assert!(smape > 0.0);
326/// ```
327#[allow(dead_code)]
328pub fn symmetric_mean_absolute_percentage_error<F, S1, S2, D1, D2>(
329 y_true: &ArrayBase<S1, D1>,
330 y_pred: &ArrayBase<S2, D2>,
331) -> Result<F>
332where
333 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
334 S1: Data<Elem = F>,
335 S2: Data<Elem = F>,
336 D1: Dimension,
337 D2: Dimension,
338{
339 // Check that arrays have the same shape
340 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
341
342 let mut percentage_error_sum = F::zero();
343 let mut valid_samples = 0;
344
345 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
346 // Skip samples where both y_true and y_pred are zero to avoid undefined values
347 if yt.abs() > F::epsilon() || yp.abs() > F::epsilon() {
348 let percentage_error = ((*yt - *yp).abs()) / (yt.abs() + yp.abs());
349 percentage_error_sum = percentage_error_sum + percentage_error;
350 valid_samples += 1;
351 }
352 }
353
354 if valid_samples == 0 {
355 return Err(MetricsError::InvalidInput(
356 "All values are zero. SMAPE is undefined.".to_string(),
357 ));
358 }
359
360 // Multiply by 200 to get percentage (SMAPE is typically defined with factor of 2)
361 Ok(percentage_error_sum / NumCast::from(valid_samples).unwrap() * NumCast::from(200).unwrap())
362}
363
364/// Calculates the maximum error
365///
366/// Maximum error is the maximum absolute difference between the true and predicted values.
367///
368/// # Arguments
369///
370/// * `y_true` - Ground truth (correct) target values
371/// * `y_pred` - Estimated target values
372///
373/// # Returns
374///
375/// * The maximum error
376///
377/// # Examples
378///
379/// ```
380/// use scirs2_core::ndarray::array;
381/// use scirs2_metrics::regression::max_error;
382///
383/// let y_true = array![3.0, -0.5, 2.0, 7.0];
384/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
385///
386/// let me = max_error(&y_true, &y_pred).unwrap();
387/// // Maximum of [|3.0-2.5|, |-0.5-0.0|, |2.0-2.0|, |7.0-8.0|]
388/// assert_eq!(me, 1.0);
389/// ```
390#[allow(dead_code)]
391pub fn max_error<F, S1, S2, D1, D2>(
392 y_true: &ArrayBase<S1, D1>,
393 y_pred: &ArrayBase<S2, D2>,
394) -> Result<F>
395where
396 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
397 S1: Data<Elem = F>,
398 S2: Data<Elem = F>,
399 D1: Dimension,
400 D2: Dimension,
401{
402 // Check that arrays have the same shape
403 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
404
405 let mut max_err = F::zero();
406 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
407 let error = (*yt - *yp).abs();
408 if error > max_err {
409 max_err = error;
410 }
411 }
412
413 Ok(max_err)
414}
415
416/// Calculates the median absolute error
417///
418/// Median absolute error is the median of all absolute differences between
419/// the true and predicted values. It is robust to outliers.
420///
421/// # Arguments
422///
423/// * `y_true` - Ground truth (correct) target values
424/// * `y_pred` - Estimated target values
425///
426/// # Returns
427///
428/// * The median absolute error
429///
430/// # Examples
431///
432/// ```
433/// use scirs2_core::ndarray::array;
434/// use scirs2_metrics::regression::median_absolute_error;
435///
436/// let y_true = array![3.0, -0.5, 2.0, 7.0];
437/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
438///
439/// let medae = median_absolute_error(&y_true, &y_pred).unwrap();
440/// // Median of [|3.0-2.5|, |-0.5-0.0|, |2.0-2.0|, |7.0-8.0|] = Median of [0.5, 0.5, 0.0, 1.0]
441/// assert_eq!(medae, 0.5);
442/// ```
443#[allow(dead_code)]
444pub fn median_absolute_error<F, S1, S2, D1, D2>(
445 y_true: &ArrayBase<S1, D1>,
446 y_pred: &ArrayBase<S2, D2>,
447) -> Result<F>
448where
449 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
450 S1: Data<Elem = F>,
451 S2: Data<Elem = F>,
452 D1: Dimension,
453 D2: Dimension,
454{
455 // Check that arrays have the same shape
456 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
457
458 let n_samples = y_true.len();
459
460 // Calculate absolute errors
461 let mut abs_errors = Vec::with_capacity(n_samples);
462 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
463 abs_errors.push((*yt - *yp).abs());
464 }
465
466 // Sort and get median
467 abs_errors.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
468
469 if n_samples % 2 == 1 {
470 // Odd number of samples
471 Ok(abs_errors[n_samples / 2])
472 } else {
473 // Even number of samples
474 let mid = n_samples / 2;
475 Ok((abs_errors[mid - 1] + abs_errors[mid]) / NumCast::from(2).unwrap())
476 }
477}
478
479/// Calculates the mean squared logarithmic error (MSLE)
480///
481/// Mean squared logarithmic error measures the average squared difference
482/// between the logarithm of the predicted and true values. This metric penalizes
483/// underestimates more than overestimates.
484///
485/// # Arguments
486///
487/// * `y_true` - Ground truth (correct) target values
488/// * `y_pred` - Estimated target values
489///
490/// # Returns
491///
492/// * The mean squared logarithmic error
493///
494/// # Notes
495///
496/// * This metric cannot be used with negative values
497///
498/// # Examples
499///
500/// ```
501/// use scirs2_core::ndarray::array;
502/// use scirs2_metrics::regression::mean_squared_log_error;
503///
504/// let y_true = array![3.0, 5.0, 2.5, 7.0];
505/// let y_pred = array![2.5, 5.0, 3.0, 8.0];
506///
507/// let msle = mean_squared_log_error(&y_true, &y_pred).unwrap();
508/// assert!(msle > 0.0);
509/// ```
510#[allow(dead_code)]
511pub fn mean_squared_log_error<F, S1, S2, D1, D2>(
512 y_true: &ArrayBase<S1, D1>,
513 y_pred: &ArrayBase<S2, D2>,
514) -> Result<F>
515where
516 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
517 S1: Data<Elem = F>,
518 S2: Data<Elem = F>,
519 D1: Dimension,
520 D2: Dimension,
521{
522 // Check that arrays have the same shape
523 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
524
525 let n_samples = y_true.len();
526
527 // Check that all values are non-negative
528 for &val in y_true.iter() {
529 if val < F::zero() {
530 return Err(MetricsError::InvalidInput(
531 "y_true contains negative values".to_string(),
532 ));
533 }
534 }
535
536 for &val in y_pred.iter() {
537 if val < F::zero() {
538 return Err(MetricsError::InvalidInput(
539 "y_pred contains negative values".to_string(),
540 ));
541 }
542 }
543
544 let mut squared_log_diff_sum = F::zero();
545 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
546 // Add 1 to avoid taking log of 0
547 let log_yt = (*yt + F::one()).ln();
548 let log_yp = (*yp + F::one()).ln();
549 let log_diff = log_yt - log_yp;
550 squared_log_diff_sum = squared_log_diff_sum + log_diff * log_diff;
551 }
552
553 Ok(squared_log_diff_sum / NumCast::from(n_samples).unwrap())
554}
555
556/// Calculates the Huber loss
557///
558/// Huber loss is less sensitive to outliers than squared error loss.
559/// For small errors, it behaves like squared error, and for large errors,
560/// it behaves like absolute error.
561///
562/// # Arguments
563///
564/// * `y_true` - Ground truth (correct) target values
565/// * `y_pred` - Estimated target values
566/// * `delta` - Threshold where the loss changes from squared to linear
567///
568/// # Returns
569///
570/// * The Huber loss
571///
572/// # Examples
573///
574/// ```
575/// use scirs2_core::ndarray::array;
576/// use scirs2_metrics::regression::huber_loss;
577///
578/// let y_true = array![3.0, -0.5, 2.0, 7.0];
579/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
580/// let delta = 0.5;
581///
582/// let loss = huber_loss(&y_true, &y_pred, delta).unwrap();
583/// assert!(loss > 0.0);
584/// ```
585#[allow(dead_code)]
586pub fn huber_loss<F, S1, S2, D1, D2>(
587 y_true: &ArrayBase<S1, D1>,
588 y_pred: &ArrayBase<S2, D2>,
589 delta: F,
590) -> Result<F>
591where
592 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
593 S1: Data<Elem = F>,
594 S2: Data<Elem = F>,
595 D1: Dimension,
596 D2: Dimension,
597{
598 // Check that arrays have the same shape
599 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
600
601 if delta <= F::zero() {
602 return Err(MetricsError::InvalidInput(
603 "delta must be positive".to_string(),
604 ));
605 }
606
607 let n_samples = y_true.len();
608 let mut loss_sum = F::zero();
609
610 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
611 let error = (*yt - *yp).abs();
612 if error <= delta {
613 // Quadratic part
614 loss_sum = loss_sum + F::from(0.5).unwrap() * error * error;
615 } else {
616 // Linear part
617 loss_sum = loss_sum + delta * (error - F::from(0.5).unwrap() * delta);
618 }
619 }
620
621 Ok(loss_sum / NumCast::from(n_samples).unwrap())
622}
623
624/// Calculates the normalized root mean squared error (NRMSE)
625///
626/// # Arguments
627///
628/// * `y_true` - Ground truth (correct) target values
629/// * `y_pred` - Estimated target values
630/// * `normalization` - Method used for normalization:
631/// * "mean" - RMSE / mean(y_true)
632/// * "range" - RMSE / (max(y_true) - min(y_true))
633/// * "iqr" - RMSE / interquartile range of y_true
634///
635/// # Returns
636///
637/// * The normalized root mean squared error
638///
639/// # Examples
640///
641/// ```no_run
642/// use scirs2_core::ndarray::array;
643/// use scirs2_metrics::regression::normalized_root_mean_squared_error;
644///
645/// let y_true = array![3.0, -0.5, 2.0, 7.0];
646/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
647///
648/// let nrmse_mean: f64 = normalized_root_mean_squared_error(&y_true, &y_pred, "mean").unwrap();
649/// let nrmse_range: f64 = normalized_root_mean_squared_error(&y_true, &y_pred, "range").unwrap();
650/// assert!(nrmse_mean > 0.0);
651/// assert!(nrmse_range > 0.0);
652/// ```
653#[allow(dead_code)]
654pub fn normalized_root_mean_squared_error<F, S1, S2, D1, D2>(
655 y_true: &ArrayBase<S1, D1>,
656 y_pred: &ArrayBase<S2, D2>,
657 normalization: &str,
658) -> Result<F>
659where
660 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
661 S1: Data<Elem = F>,
662 S2: Data<Elem = F>,
663 D1: Dimension,
664 D2: Dimension,
665{
666 let rmse = root_mean_squared_error(y_true, y_pred)?;
667
668 match normalization {
669 "mean" => {
670 // RMSE / mean(y_true)
671 let mean = y_true.iter().fold(F::zero(), |acc, &y| acc + y)
672 / NumCast::from(y_true.len()).unwrap();
673 if mean.abs() < F::epsilon() {
674 return Err(MetricsError::InvalidInput(
675 "Mean of y_true is zero, cannot normalize by mean".to_string(),
676 ));
677 }
678 Ok(rmse / mean.abs())
679 }
680 "range" => {
681 // RMSE / (max(y_true) - min(y_true))
682 let max = y_true
683 .iter()
684 .fold(F::neg_infinity(), |acc, &y| if y > acc { y } else { acc });
685 let min = y_true
686 .iter()
687 .fold(F::infinity(), |acc, &y| if y < acc { y } else { acc });
688 let range = max - min;
689 if range < F::epsilon() {
690 return Err(MetricsError::InvalidInput(
691 "Range of y_true is zero, cannot normalize by range".to_string(),
692 ));
693 }
694 Ok(rmse / range)
695 }
696 "iqr" => {
697 // RMSE / interquartile range of y_true
698 let mut values: Vec<F> = y_true.iter().cloned().collect();
699 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
700
701 let n = values.len();
702 let q1_idx = n / 4;
703 let q3_idx = 3 * n / 4;
704
705 let q1 = if n.is_multiple_of(4) {
706 (values[q1_idx - 1] + values[q1_idx]) / NumCast::from(2).unwrap()
707 } else {
708 values[q1_idx]
709 };
710
711 let q3 = if n.is_multiple_of(4) {
712 (values[q3_idx - 1] + values[q3_idx]) / NumCast::from(2).unwrap()
713 } else {
714 values[q3_idx]
715 };
716
717 let iqr = q3 - q1;
718 if iqr < F::epsilon() {
719 return Err(MetricsError::InvalidInput(
720 "Interquartile range of y_true is zero, cannot normalize by IQR".to_string(),
721 ));
722 }
723 Ok(rmse / iqr)
724 }
725 _ => Err(MetricsError::InvalidInput(format!(
726 "Unknown normalization method: {}. Valid options are 'mean', 'range', 'iqr'.",
727 normalization
728 ))),
729 }
730}
731
732/// Calculates the relative absolute error (RAE)
733///
734/// RAE is the ratio of the sum of absolute errors to the sum of absolute
735/// deviations from the mean of the true values.
736///
737/// # Arguments
738///
739/// * `y_true` - Ground truth (correct) target values
740/// * `y_pred` - Estimated target values
741///
742/// # Returns
743///
744/// * The relative absolute error
745///
746/// # Examples
747///
748/// ```
749/// use scirs2_core::ndarray::array;
750/// use scirs2_metrics::regression::relative_absolute_error;
751///
752/// let y_true = array![3.0, -0.5, 2.0, 7.0];
753/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
754///
755/// let rae = relative_absolute_error(&y_true, &y_pred).unwrap();
756/// assert!(rae > 0.0 && rae < 1.0);
757/// ```
758#[allow(dead_code)]
759pub fn relative_absolute_error<F, S1, S2, D1, D2>(
760 y_true: &ArrayBase<S1, D1>,
761 y_pred: &ArrayBase<S2, D2>,
762) -> Result<F>
763where
764 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
765 S1: Data<Elem = F>,
766 S2: Data<Elem = F>,
767 D1: Dimension,
768 D2: Dimension,
769{
770 // Check that arrays have the same shape
771 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
772
773 // Calculate mean of y_true
774 let y_true_mean =
775 y_true.iter().fold(F::zero(), |acc, &y| acc + y) / NumCast::from(y_true.len()).unwrap();
776
777 let mut abs_error_sum = F::zero();
778 let mut abs_mean_diff_sum = F::zero();
779
780 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
781 abs_error_sum = abs_error_sum + (*yt - *yp).abs();
782 abs_mean_diff_sum = abs_mean_diff_sum + (*yt - y_true_mean).abs();
783 }
784
785 if abs_mean_diff_sum < F::epsilon() {
786 return Err(MetricsError::InvalidInput(
787 "Sum of absolute deviations from mean is zero".to_string(),
788 ));
789 }
790
791 Ok(abs_error_sum / abs_mean_diff_sum)
792}
793
794/// Calculates the relative squared error (RSE)
795///
796/// RSE is the ratio of the sum of squared errors to the sum of squared
797/// deviations from the mean of the true values.
798///
799/// # Arguments
800///
801/// * `y_true` - Ground truth (correct) target values
802/// * `y_pred` - Estimated target values
803///
804/// # Returns
805///
806/// * The relative squared error
807///
808/// # Examples
809///
810/// ```
811/// use scirs2_core::ndarray::array;
812/// use scirs2_metrics::regression::relative_squared_error;
813///
814/// let y_true = array![3.0, -0.5, 2.0, 7.0];
815/// let y_pred = array![2.5, 0.0, 2.0, 8.0];
816///
817/// let rse = relative_squared_error(&y_true, &y_pred).unwrap();
818/// assert!(rse > 0.0 && rse < 1.0);
819/// ```
820#[allow(dead_code)]
821pub fn relative_squared_error<F, S1, S2, D1, D2>(
822 y_true: &ArrayBase<S1, D1>,
823 y_pred: &ArrayBase<S2, D2>,
824) -> Result<F>
825where
826 F: Float + NumCast + std::fmt::Debug + scirs2_core::simd_ops::SimdUnifiedOps,
827 S1: Data<Elem = F>,
828 S2: Data<Elem = F>,
829 D1: Dimension,
830 D2: Dimension,
831{
832 // Check that arrays have the same shape
833 check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
834
835 // Calculate mean of y_true
836 let y_true_mean =
837 y_true.iter().fold(F::zero(), |acc, &y| acc + y) / NumCast::from(y_true.len()).unwrap();
838
839 let mut squared_error_sum = F::zero();
840 let mut squared_mean_diff_sum = F::zero();
841
842 for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
843 let error = *yt - *yp;
844 squared_error_sum = squared_error_sum + error * error;
845
846 let mean_diff = *yt - y_true_mean;
847 squared_mean_diff_sum = squared_mean_diff_sum + mean_diff * mean_diff;
848 }
849
850 if squared_mean_diff_sum < F::epsilon() {
851 return Err(MetricsError::InvalidInput(
852 "Sum of squared deviations from mean is zero".to_string(),
853 ));
854 }
855
856 Ok(squared_error_sum / squared_mean_diff_sum)
857}