scirs2_metrics/regression/
residual.rs

1//! Residual analysis for regression models
2//!
3//! This module provides functions for analyzing residuals of regression models,
4//! including histograms, Q-Q plots, and comprehensive residual analysis.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Dimension};
7use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
8use std::cmp::Ordering;
9
10use super::check_sameshape;
11use crate::error::{MetricsError, Result};
12
13/// Structure representing a histogram of residuals
14#[derive(Debug, Clone)]
15pub struct ErrorHistogram<F: Float> {
16    /// Bin edges (length = n_bins + 1)
17    pub bin_edges: Vec<F>,
18    /// Bin counts (length = n_bins)
19    pub bin_counts: Vec<usize>,
20    /// Number of observations in each bin
21    pub n_observations: usize,
22    /// Minimum residual value
23    pub min_error: F,
24    /// Maximum residual value
25    pub max_error: F,
26}
27
28/// Calculates a histogram of error/residual values
29///
30/// # Arguments
31///
32/// * `y_true` - Ground truth (correct) target values
33/// * `y_pred` - Estimated target values
34/// * `n_bins` - Number of bins for the histogram
35///
36/// # Returns
37///
38/// * An `ErrorHistogram` struct containing the histogram data
39///
40/// # Examples
41///
42/// ```
43/// use scirs2_core::ndarray::array;
44/// use scirs2_metrics::regression::error_histogram;
45///
46/// let y_true = array![3.0, -0.5, 2.0, 7.0, 5.0, 8.0, 1.0, 4.0];
47/// let y_pred = array![2.5, 0.0, 2.0, 8.0, 4.5, 7.5, 1.5, 3.5];
48///
49/// let hist = error_histogram(&y_true, &y_pred, 4).unwrap();
50/// assert_eq!(hist.bin_counts.len(), 4);
51/// assert_eq!(hist.bin_edges.len(), 5);
52/// assert_eq!(hist.n_observations, 8);
53/// ```
54#[allow(dead_code)]
55pub fn error_histogram<F, S1, S2, D1, D2>(
56    y_true: &ArrayBase<S1, D1>,
57    y_pred: &ArrayBase<S2, D2>,
58    n_bins: usize,
59) -> Result<ErrorHistogram<F>>
60where
61    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
62    S1: Data<Elem = F>,
63    S2: Data<Elem = F>,
64    D1: Dimension,
65    D2: Dimension,
66{
67    // Check that arrays have the same shape
68    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
69
70    if n_bins == 0 {
71        return Err(MetricsError::InvalidInput(
72            "Number of _bins must be positive".to_string(),
73        ));
74    }
75
76    // Calculate residuals
77    let n_samples = y_true.len();
78    let mut residuals = Vec::with_capacity(n_samples);
79
80    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
81        residuals.push(*yt - *yp);
82    }
83
84    // Find min and max residuals
85    let mut min_error = residuals[0];
86    let mut max_error = residuals[0];
87
88    for &residual in &residuals[1..] {
89        if residual < min_error {
90            min_error = residual;
91        }
92        if residual > max_error {
93            max_error = residual;
94        }
95    }
96
97    // Create bin edges
98    let range = if max_error > min_error {
99        max_error - min_error
100    } else {
101        F::one()
102    };
103    let bin_width = range / NumCast::from(n_bins).unwrap();
104
105    let mut bin_edges = Vec::with_capacity(n_bins + 1);
106    for i in 0..=n_bins {
107        bin_edges.push(min_error + F::from(i).unwrap() * bin_width);
108    }
109
110    // Count values in each bin
111    let mut bin_counts = vec![0; n_bins];
112
113    for &residual in &residuals {
114        if residual == max_error {
115            // Last bin for the maximum value
116            bin_counts[n_bins - 1] += 1;
117        } else {
118            // Find the appropriate bin
119            let bin_idx = ((residual - min_error) / bin_width).to_usize().unwrap();
120            bin_counts[bin_idx] += 1;
121        }
122    }
123
124    Ok(ErrorHistogram {
125        bin_edges,
126        bin_counts,
127        n_observations: n_samples,
128        min_error,
129        max_error,
130    })
131}
132
133/// Structure representing Q-Q plot data for residuals
134#[derive(Debug, Clone)]
135pub struct QQPlotData<F: Float> {
136    /// Theoretical quantiles
137    pub theoretical_quantiles: Vec<F>,
138    /// Sample quantiles (residuals)
139    pub sample_quantiles: Vec<F>,
140    /// 45-degree reference line points
141    pub reference_line: Vec<(F, F)>,
142}
143
144/// Calculates Q-Q plot data for residuals
145///
146/// # Arguments
147///
148/// * `y_true` - Ground truth (correct) target values
149/// * `y_pred` - Estimated target values
150/// * `n_quantiles` - Number of quantiles to calculate
151///
152/// # Returns
153///
154/// * A `QQPlotData` struct containing the Q-Q plot data
155///
156/// # Examples
157///
158/// ```
159/// use scirs2_core::ndarray::array;
160/// use scirs2_metrics::regression::qq_plot_data;
161///
162/// let y_true = array![3.0, -0.5, 2.0, 7.0, 5.0, 8.0, 1.0, 4.0];
163/// let y_pred = array![2.5, 0.0, 2.0, 8.0, 4.5, 7.5, 1.5, 3.5];
164///
165/// let qq_data = qq_plot_data(&y_true, &y_pred, 20).unwrap();
166/// assert_eq!(qq_data.theoretical_quantiles.len(), qq_data.sample_quantiles.len());
167/// ```
168#[allow(dead_code)]
169pub fn qq_plot_data<F, S1, S2, D1, D2>(
170    y_true: &ArrayBase<S1, D1>,
171    y_pred: &ArrayBase<S2, D2>,
172    n_quantiles: usize,
173) -> Result<QQPlotData<F>>
174where
175    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
176    S1: Data<Elem = F>,
177    S2: Data<Elem = F>,
178    D1: Dimension,
179    D2: Dimension,
180{
181    // Check that arrays have the same shape
182    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
183
184    if n_quantiles < 2 {
185        return Err(MetricsError::InvalidInput(
186            "Number of _quantiles must be at least 2".to_string(),
187        ));
188    }
189
190    // Calculate residuals
191    let n_samples = y_true.len();
192    let mut residuals = Vec::with_capacity(n_samples);
193
194    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
195        residuals.push(*yt - *yp);
196    }
197
198    // Sort residuals
199    residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
200
201    // Standardize residuals
202    let mean =
203        residuals.iter().fold(F::zero(), |acc, &x| acc + x) / NumCast::from(n_samples).unwrap();
204
205    let variance = residuals.iter().fold(F::zero(), |acc, &x| {
206        let diff = x - mean;
207        acc + diff * diff
208    }) / NumCast::from(n_samples).unwrap();
209
210    let std_dev = variance.sqrt();
211
212    let mut std_residuals = Vec::with_capacity(n_samples);
213    for &r in &residuals {
214        std_residuals.push((r - mean) / std_dev);
215    }
216
217    // Calculate theoretical _quantiles
218    let mut theoretical_quantiles = Vec::with_capacity(n_quantiles);
219    let mut sample_quantiles = Vec::with_capacity(n_quantiles);
220
221    let step = F::one() / NumCast::from(n_quantiles + 1).unwrap();
222
223    for i in 1..=n_quantiles {
224        let p: F = F::from(i).unwrap() * step;
225        let theoretical_q = normal_quantile(p.to_f64().unwrap());
226        theoretical_quantiles.push(F::from(theoretical_q).unwrap());
227
228        // Get corresponding sample quantile
229        let idx = (p * NumCast::from(n_samples).unwrap())
230            .to_usize()
231            .unwrap()
232            .min(n_samples - 1);
233        sample_quantiles.push(std_residuals[idx]);
234    }
235
236    // Create reference line
237    let mut min_val = theoretical_quantiles[0].min(sample_quantiles[0]);
238    let mut max_val = theoretical_quantiles[n_quantiles - 1].max(sample_quantiles[n_quantiles - 1]);
239
240    // Add some margin
241    let range = max_val - min_val;
242    min_val = min_val - range * F::from_f64(0.05).unwrap();
243    max_val = max_val + range * F::from_f64(0.05).unwrap();
244
245    let reference_line = vec![(min_val, min_val), (max_val, max_val)];
246
247    Ok(QQPlotData {
248        theoretical_quantiles,
249        sample_quantiles,
250        reference_line,
251    })
252}
253
/// Approximation of the standard normal quantile function (inverse CDF)
/// using the Beasley-Springer-Moro algorithm.
///
/// * For `0.08 <= p <= 0.92` a rational approximation in `p - 0.5` is used.
/// * In the tails a polynomial in `ln(-ln(t))` is used, where `t` is the
///   probability mass of the nearer tail (`p` or `1 - p`).
///
/// Probabilities outside the open interval `(0, 1)` do not panic: `p <= 0`
/// yields `-5.0` and `p >= 1` yields `5.0` as finite stand-ins for the
/// infinite quantiles. A NaN input yields NaN.
#[allow(dead_code)]
fn normal_quantile(p: f64) -> f64 {
    // Coefficients of the Beasley-Springer-Moro algorithm.
    const A: [f64; 4] = [
        2.50662823884,
        -18.61500062529,
        41.39119773534,
        -25.44106049637,
    ];
    const B: [f64; 4] = [
        -8.47351093090,
        23.08336743743,
        -21.06224101826,
        3.13082909833,
    ];
    const C: [f64; 9] = [
        0.3374754822726147,
        0.9761690190917186,
        0.1607979714918209,
        0.0276438810333863,
        0.0038405729373609,
        0.0003951896511919,
        0.0000321767881768,
        0.0000002888167364,
        0.0000003960315187,
    ];

    /// Evaluates `c[0] + c[1] x + c[2] x^2 + ...` with Horner's scheme.
    fn horner(coefficients: &[f64], x: f64) -> f64 {
        coefficients.iter().rev().fold(0.0, |acc, &c| acc * x + c)
    }

    if p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return -5.0; // Finite stand-in for negative infinity
    }
    if p >= 1.0 {
        return 5.0; // Finite stand-in for positive infinity
    }

    let centred = p - 0.5;

    // Central region: rational approximation in (p - 0.5).
    if (0.08..=0.92).contains(&p) {
        let r = centred * centred;
        return centred * horner(&A, r) / (1.0 + r * horner(&B, r));
    }

    // Tail region: the C coefficients are fitted against ln(-ln(t)), where t
    // is the probability mass of the nearer tail.
    let tail_mass = if centred < 0.0 { p } else { 1.0 - p };
    let r = (-tail_mass.ln()).ln();
    let magnitude = horner(&C, r);

    if centred < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
318
319/// Structure representing comprehensive residual analysis
320#[derive(Debug, Clone)]
321pub struct ResidualAnalysis<F: Float> {
322    /// Residuals (y_true - y_pred)
323    pub residuals: Vec<F>,
324    /// Standardized residuals
325    pub standardized_residuals: Vec<F>,
326    /// Studentized residuals
327    pub studentized_residuals: Vec<F>,
328    /// Cook's distances (influence measure)
329    pub cooks_distances: Vec<F>,
330    /// DFFITS (influence measure)
331    pub dffits: Vec<F>,
332    /// Leverage values (hat matrix diagonal)
333    pub leverage: Vec<F>,
334    /// Residual histogram
335    pub histogram: ErrorHistogram<F>,
336    /// Q-Q plot data
337    pub qq_plot: QQPlotData<F>,
338    /// Durbin-Watson statistic (checks for autocorrelation)
339    pub durbin_watson: F,
340    /// Breusch-Pagan test statistic (checks for heteroscedasticity)
341    pub breusch_pagan: F,
342    /// Shapiro-Wilk test statistic (checks for normality)
343    pub shapiro_wilk: F,
344}
345
346/// Performs comprehensive residual analysis for a regression model
347///
348/// # Arguments
349///
350/// * `y_true` - Ground truth (correct) target values
351/// * `y_pred` - Estimated target values
352/// * `x` - Optional predictor variables matrix (needed for some diagnostics)
353/// * `hat_matrix` - Optional hat/projection matrix (can be provided to avoid recalculation)
354///
355/// # Returns
356///
357/// * A `ResidualAnalysis` struct containing various residual diagnostics
358///
359/// # Examples
360///
361/// ```
362/// use scirs2_core::ndarray::{array, Array2};
363/// use scirs2_metrics::regression::residual_analysis;
364///
365/// let y_true = array![3.0, -0.5, 2.0, 7.0, 5.0, 8.0, 1.0, 4.0];
366/// let y_pred = array![2.5, 0.0, 2.0, 8.0, 4.5, 7.5, 1.5, 3.5];
367///
368/// // Create dummy X matrix (features matrix) with 2 predictors
369/// let x = Array2::from_shape_fn((8, 2), |(i, j)| i as f64 + j as f64);
370///
371/// let analysis = residual_analysis(&y_true, &y_pred, Some(&x), None).unwrap();
372///
373/// // Access various diagnostics
374/// println!("Durbin-Watson statistic: {}", analysis.durbin_watson);
375/// println!("Number of residuals: {}", analysis.residuals.len());
376/// ```
377#[allow(dead_code)]
378pub fn residual_analysis<F, S1, S2, D1, D2>(
379    y_true: &ArrayBase<S1, D1>,
380    y_pred: &ArrayBase<S2, D2>,
381    x: Option<&Array2<F>>,
382    hat_matrix: Option<&Array2<F>>,
383) -> Result<ResidualAnalysis<F>>
384where
385    F: Float + NumCast + std::fmt::Debug + FromPrimitive + 'static,
386    S1: Data<Elem = F>,
387    S2: Data<Elem = F>,
388    D1: Dimension,
389    D2: Dimension,
390{
391    // Check that arrays have the same shape
392    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
393
394    let n_samples = y_true.len();
395
396    // Check X _matrix dimensions
397    if let Some(x_mat) = x {
398        if x_mat.shape()[0] != n_samples {
399            return Err(MetricsError::InvalidInput(format!(
400                "X _matrix has {} rows, but y_true has {} elements",
401                x_mat.shape()[0],
402                n_samples
403            )));
404        }
405    }
406
407    // Check hat _matrix dimensions
408    if let Some(h_mat) = hat_matrix {
409        if h_mat.shape() != [n_samples, n_samples] {
410            return Err(MetricsError::InvalidInput(format!(
411                "Hat _matrix has shape {:?}, but should be [{}, {}]",
412                h_mat.shape(),
413                n_samples,
414                n_samples
415            )));
416        }
417    }
418
419    // Calculate basic residuals
420    let mut residuals = Vec::with_capacity(n_samples);
421    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
422        residuals.push(*yt - *yp);
423    }
424
425    // Calculate mean of residuals
426    let residual_mean =
427        residuals.iter().fold(F::zero(), |acc, &r| acc + r) / NumCast::from(n_samples).unwrap();
428
429    // Calculate variance of residuals
430    let residual_var = residuals.iter().fold(F::zero(), |acc, &r| {
431        let diff = r - residual_mean;
432        acc + diff * diff
433    }) / NumCast::from(n_samples).unwrap();
434
435    let residual_std = residual_var.sqrt();
436
437    // Calculate standardized residuals
438    let mut standardized_residuals = Vec::with_capacity(n_samples);
439    for &r in &residuals {
440        standardized_residuals.push((r - residual_mean) / residual_std);
441    }
442
443    // Calculate leverage (hat _matrix diagonal)
444    let leverage = if let Some(h_mat) = hat_matrix {
445        // Extract diagonal from provided hat _matrix
446        let mut h_diag = Vec::with_capacity(n_samples);
447        for i in 0..n_samples {
448            h_diag.push(h_mat[[i, i]]);
449        }
450        h_diag
451    } else if let Some(x_mat) = x {
452        // Calculate hat _matrix diagonal using X _matrix: diag(X (X'X)^(-1) X')
453        let p = x_mat.shape()[1]; // Number of predictors
454        let xt = x_mat.t();
455
456        // Calculate X'X
457        let xtx = xt.dot(x_mat);
458
459        // Invert X'X (simplified - not a proper _matrix inversion)
460        let mut xtx_inv = Array2::<F>::zeros((p, p));
461
462        // Diagonal _matrix as a simple approximation
463        for i in 0..p {
464            if xtx[[i, i]] > F::epsilon() {
465                xtx_inv[[i, i]] = F::one() / xtx[[i, i]];
466            }
467        }
468
469        // Calculate hat _matrix diagonal
470        let mut h_diag = Vec::with_capacity(n_samples);
471        for i in 0..n_samples {
472            let mut h_ii = F::zero();
473            for j in 0..p {
474                for k in 0..p {
475                    h_ii = h_ii + x_mat[[i, j]] * xtx_inv[[j, k]] * x_mat[[i, k]];
476                }
477            }
478            h_diag.push(h_ii);
479        }
480
481        h_diag
482    } else {
483        // No X _matrix or hat _matrix provided, use default
484        vec![F::one() / NumCast::from(n_samples).unwrap(); n_samples]
485    };
486
487    // Calculate studentized residuals
488    let mut studentized_residuals = Vec::with_capacity(n_samples);
489    for (i, &r) in residuals.iter().enumerate() {
490        let h_ii = leverage[i];
491        if h_ii < F::one() {
492            let student_r = r / (residual_std * (F::one() - h_ii).sqrt());
493            studentized_residuals.push(student_r);
494        } else {
495            studentized_residuals.push(F::zero());
496        }
497    }
498
499    // Calculate Cook's distances
500    let mut cooks_distances = Vec::with_capacity(n_samples);
501    for (i, &r) in standardized_residuals.iter().enumerate() {
502        let h_ii = leverage[i];
503        if h_ii < F::one() {
504            // Use a default number of parameters if no X _matrix was provided
505            let p_value = if let Some(x_mat) = x {
506                x_mat.shape()[1]
507            } else {
508                1 // Default to 1 predictor
509            };
510            let cook_d = (r * r) * (h_ii / (F::one() - h_ii)) / NumCast::from(p_value).unwrap();
511            cooks_distances.push(cook_d);
512        } else {
513            cooks_distances.push(F::zero());
514        }
515    }
516
517    // Calculate DFFITS
518    let mut dffits = Vec::with_capacity(n_samples);
519    for (i, &r) in studentized_residuals.iter().enumerate() {
520        let h_ii = leverage[i];
521        if h_ii < F::one() {
522            let dffit = r * (h_ii / (F::one() - h_ii)).sqrt();
523            dffits.push(dffit);
524        } else {
525            dffits.push(F::zero());
526        }
527    }
528
529    // Calculate Durbin-Watson statistic (tests for autocorrelation)
530    let mut numerator = F::zero();
531    for i in 1..n_samples {
532        let diff = residuals[i] - residuals[i - 1];
533        numerator = numerator + diff * diff;
534    }
535
536    let denominator = residuals.iter().fold(F::zero(), |acc, &r| acc + r * r);
537    let durbin_watson = if denominator > F::epsilon() {
538        numerator / denominator
539    } else {
540        F::from(2.0).unwrap() // No autocorrelation
541    };
542
543    // Calculate Breusch-Pagan statistic (tests for heteroscedasticity)
544    // Simplified approach: regress squared residuals on fitted values
545    let squared_residuals: Vec<F> = residuals.iter().map(|&r| r * r).collect();
546    let mean_sq_residual = squared_residuals.iter().fold(F::zero(), |acc, &r| acc + r)
547        / NumCast::from(n_samples).unwrap();
548
549    let mut numerator = F::zero();
550    let mut denominator = F::zero();
551
552    for (i, &sq_r) in squared_residuals.iter().enumerate() {
553        let _pred = y_pred.iter().nth(i).unwrap();
554        let diff = sq_r - mean_sq_residual;
555        numerator = numerator + diff * diff;
556        denominator = denominator + (*_pred) * (*_pred);
557    }
558
559    let breusch_pagan = if denominator > F::epsilon() {
560        numerator / denominator
561    } else {
562        F::zero()
563    };
564
565    // Calculate Shapiro-Wilk statistic (tests for normality)
566    // Simplified approach based on correlation between ordered residuals and normal quantiles
567    let mut ordered_residuals = standardized_residuals.clone();
568    ordered_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
569
570    let mut expected_quantiles = Vec::with_capacity(n_samples);
571    for i in 0..n_samples {
572        let p = (F::from(i + 1).unwrap() - F::from(0.375).unwrap())
573            / (F::from(n_samples).unwrap() + F::from(0.25).unwrap());
574        let q = normal_quantile(p.to_f64().unwrap());
575        expected_quantiles.push(F::from(q).unwrap());
576    }
577
578    let mut numerator = F::zero();
579    let mut denominator = F::zero();
580
581    for (i, &r) in ordered_residuals.iter().enumerate() {
582        let q = expected_quantiles[i];
583        numerator = numerator + r * q;
584        denominator = denominator + r * r;
585    }
586
587    let shapiro_wilk = if denominator > F::epsilon() {
588        (numerator / denominator).powi(2)
589    } else {
590        F::zero()
591    };
592
593    // Calculate histogram and Q-Q plot
594    let histogram = error_histogram(y_true, y_pred, 10)?;
595    let qq_plot = qq_plot_data(y_true, y_pred, 20)?;
596
597    Ok(ResidualAnalysis {
598        residuals,
599        standardized_residuals,
600        studentized_residuals,
601        cooks_distances,
602        dffits,
603        leverage,
604        histogram,
605        qq_plot,
606        durbin_watson,
607        breusch_pagan,
608        shapiro_wilk,
609    })
610}
611
612/// Checks for heteroscedasticity in residuals using Breusch-Pagan test
613///
614/// # Arguments
615///
616/// * `y_true` - Ground truth (correct) target values
617/// * `y_pred` - Estimated target values
618///
619/// # Returns
620///
621/// * Test statistic for the Breusch-Pagan test
622///
623/// # Examples
624///
625/// ```
626/// use scirs2_core::ndarray::array;
627/// use scirs2_metrics::regression::test_heteroscedasticity;
628///
629/// let y_true = array![3.0, -0.5, 2.0, 7.0, 5.0, 8.0, 1.0, 4.0];
630/// let y_pred = array![2.5, 0.0, 2.0, 8.0, 4.5, 7.5, 1.5, 3.5];
631///
632/// let bp_stat = test_heteroscedasticity(&y_true, &y_pred).unwrap();
633/// assert!(bp_stat >= 0.0);
634/// ```
635#[allow(dead_code)]
636pub fn test_heteroscedasticity<F, S1, S2, D1, D2>(
637    y_true: &ArrayBase<S1, D1>,
638    y_pred: &ArrayBase<S2, D2>,
639) -> Result<F>
640where
641    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
642    S1: Data<Elem = F>,
643    S2: Data<Elem = F>,
644    D1: Dimension,
645    D2: Dimension,
646{
647    // Check that arrays have the same shape
648    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
649
650    let n_samples = y_true.len();
651
652    // Calculate residuals
653    let mut residuals = Vec::with_capacity(n_samples);
654    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
655        residuals.push(*yt - *yp);
656    }
657
658    // Calculate squared residuals
659    let squared_residuals: Vec<F> = residuals.iter().map(|&r| r * r).collect();
660    let mean_sq_residual = squared_residuals.iter().fold(F::zero(), |acc, &r| acc + r)
661        / NumCast::from(n_samples).unwrap();
662
663    // Regress squared residuals on fitted values
664    let mut numerator = F::zero();
665    let mut denominator = F::zero();
666
667    for (i, &sq_r) in squared_residuals.iter().enumerate() {
668        let _pred = y_pred.iter().nth(i).unwrap();
669        let diff = sq_r - mean_sq_residual;
670        numerator = numerator + diff * diff;
671        denominator = denominator + (*_pred) * (*_pred);
672    }
673
674    if denominator < F::epsilon() {
675        return Err(MetricsError::InvalidInput(
676            "Denominator in heteroscedasticity test is zero".to_string(),
677        ));
678    }
679
680    let bp_stat = numerator / denominator;
681    Ok(bp_stat)
682}
683
684/// Checks for autocorrelation in residuals using Durbin-Watson test
685///
686/// # Arguments
687///
688/// * `y_true` - Ground truth (correct) target values
689/// * `y_pred` - Estimated target values
690///
691/// # Returns
692///
693/// * Durbin-Watson test statistic
694///
695/// # Examples
696///
697/// ```
698/// use scirs2_core::ndarray::array;
699/// use scirs2_metrics::regression::test_autocorrelation;
700///
701/// let y_true = array![3.0, -0.5, 2.0, 7.0, 5.0, 8.0, 1.0, 4.0];
702/// let y_pred = array![2.5, 0.0, 2.0, 8.0, 4.5, 7.5, 1.5, 3.5];
703///
704/// let dw_stat = test_autocorrelation(&y_true, &y_pred).unwrap();
705/// // DW statistic ranges from 0 to 4, with 2 being no autocorrelation
706/// assert!(dw_stat >= 0.0 && dw_stat <= 4.0);
707/// ```
708#[allow(dead_code)]
709pub fn test_autocorrelation<F, S1, S2, D1, D2>(
710    y_true: &ArrayBase<S1, D1>,
711    y_pred: &ArrayBase<S2, D2>,
712) -> Result<F>
713where
714    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
715    S1: Data<Elem = F>,
716    S2: Data<Elem = F>,
717    D1: Dimension,
718    D2: Dimension,
719{
720    // Check that arrays have the same shape
721    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
722
723    let n_samples = y_true.len();
724
725    if n_samples < 2 {
726        return Err(MetricsError::InvalidInput(
727            "At least 2 samples required for autocorrelation test".to_string(),
728        ));
729    }
730
731    // Calculate residuals
732    let mut residuals = Vec::with_capacity(n_samples);
733    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
734        residuals.push(*yt - *yp);
735    }
736
737    // Calculate Durbin-Watson statistic
738    let mut numerator = F::zero();
739    for i in 1..n_samples {
740        let diff = residuals[i] - residuals[i - 1];
741        numerator = numerator + diff * diff;
742    }
743
744    let denominator = residuals.iter().fold(F::zero(), |acc, &r| acc + r * r);
745
746    if denominator < F::epsilon() {
747        return Err(MetricsError::InvalidInput(
748            "Sum of squared residuals is zero in autocorrelation test".to_string(),
749        ));
750    }
751
752    let dw_stat = numerator / denominator;
753    Ok(dw_stat)
754}
755
756/// Checks for normality of residuals using Shapiro-Wilk test
757///
758/// # Arguments
759///
760/// * `y_true` - Ground truth (correct) target values
761/// * `y_pred` - Estimated target values
762///
763/// # Returns
764///
765/// * Shapiro-Wilk test statistic
766///
767/// # Examples
768///
769/// ```
770/// use scirs2_core::ndarray::array;
771/// use scirs2_metrics::regression::test_normality;
772///
773/// let y_true = array![3.0, -0.5, 2.0, 7.0, 5.0, 8.0, 1.0, 4.0];
774/// let y_pred = array![2.5, 0.0, 2.0, 8.0, 4.5, 7.5, 1.5, 3.5];
775///
776/// let sw_stat = test_normality(&y_true, &y_pred).unwrap();
777/// // SW statistic ranges from 0 to 1, with values close to 1 indicating normality
778/// assert!(sw_stat >= 0.0 && sw_stat <= 1.0);
779/// ```
780#[allow(dead_code)]
781pub fn test_normality<F, S1, S2, D1, D2>(
782    y_true: &ArrayBase<S1, D1>,
783    y_pred: &ArrayBase<S2, D2>,
784) -> Result<F>
785where
786    F: Float + NumCast + std::fmt::Debug + FromPrimitive,
787    S1: Data<Elem = F>,
788    S2: Data<Elem = F>,
789    D1: Dimension,
790    D2: Dimension,
791{
792    // Check that arrays have the same shape
793    check_sameshape::<F, S1, S2, D1, D2>(y_true, y_pred)?;
794
795    let n_samples = y_true.len();
796
797    // Calculate residuals
798    let mut residuals = Vec::with_capacity(n_samples);
799    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
800        residuals.push(*yt - *yp);
801    }
802
803    // Calculate mean and standard deviation
804    let mean =
805        residuals.iter().fold(F::zero(), |acc, &r| acc + r) / NumCast::from(n_samples).unwrap();
806
807    let variance = residuals.iter().fold(F::zero(), |acc, &r| {
808        let diff = r - mean;
809        acc + diff * diff
810    }) / NumCast::from(n_samples).unwrap();
811
812    let std_dev = variance.sqrt();
813
814    // Standardize residuals
815    let mut std_residuals = Vec::with_capacity(n_samples);
816    for &r in &residuals {
817        std_residuals.push((r - mean) / std_dev);
818    }
819
820    // Sort standardized residuals
821    let mut ordered_residuals = std_residuals.clone();
822    ordered_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
823
824    // Calculate expected normal order statistics
825    let mut expected_quantiles = Vec::with_capacity(n_samples);
826    for i in 0..n_samples {
827        let p = (F::from(i + 1).unwrap() - F::from(0.375).unwrap())
828            / (F::from(n_samples).unwrap() + F::from(0.25).unwrap());
829        let q = normal_quantile(p.to_f64().unwrap());
830        expected_quantiles.push(F::from(q).unwrap());
831    }
832
833    // Calculate correlation between ordered residuals and expected quantiles
834    let mean_residual = F::zero(); // Standardized residuals have mean zero
835    let mean_quantile = expected_quantiles.iter().fold(F::zero(), |acc, &q| acc + q)
836        / NumCast::from(n_samples).unwrap();
837
838    let mut numerator = F::zero();
839    let mut denom_residual = F::zero();
840    let mut denom_quantile = F::zero();
841
842    for i in 0..n_samples {
843        let res_dev = ordered_residuals[i] - mean_residual;
844        let quant_dev = expected_quantiles[i] - mean_quantile;
845
846        numerator = numerator + res_dev * quant_dev;
847        denom_residual = denom_residual + res_dev * res_dev;
848        denom_quantile = denom_quantile + quant_dev * quant_dev;
849    }
850
851    let denominator = (denom_residual * denom_quantile).sqrt();
852
853    if denominator < F::epsilon() {
854        return Err(MetricsError::InvalidInput(
855            "Denominator in normality test is zero".to_string(),
856        ));
857    }
858
859    let correlation = numerator / denominator;
860
861    // Shapiro-Wilk statistic is approximately the square of this correlation
862    let sw_stat = correlation * correlation;
863
864    Ok(sw_stat.min(F::one()))
865}