sklears_multioutput/metrics.rs

1//! Multi-output and multi-label evaluation metrics
2
// Use SciRS2-Core for array types (SciRS2 Policy)
4use scirs2_core::ndarray::ArrayView2;
5use sklears_core::error::{Result as SklResult, SklearsError};
6
7/// Hamming loss for multi-label classification
8///
/// The Hamming loss is the fraction of labels that are incorrectly predicted,
/// out of the total number of labels. It is a multi-label generalization of the zero-one loss.
11///
12/// # Arguments
13///
14/// * `y_true` - Ground truth (correct) labels
15/// * `y_pred` - Predicted labels
16///
17/// # Returns
18///
19/// The Hamming loss between y_true and y_pred
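///
/// # Examples
///
/// A minimal usage sketch on a 2x2 label matrix; the
/// `sklears_multioutput::metrics` import path is assumed from this file's
/// location and may differ in the actual crate layout.
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::hamming_loss;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // Exactly one of the four label assignments is wrong, so the loss is 1/4.
/// let loss = hamming_loss(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((loss - 0.25).abs() < 1e-12);
/// ```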
20pub fn hamming_loss(y_true: &ArrayView2<'_, i32>, y_pred: &ArrayView2<'_, i32>) -> SklResult<f64> {
21    if y_true.dim() != y_pred.dim() {
22        return Err(SklearsError::InvalidInput(
23            "y_true and y_pred must have the same shape".to_string(),
24        ));
25    }
26
27    let (n_samples, n_labels) = y_true.dim();
28    if n_samples == 0 || n_labels == 0 {
29        return Err(SklearsError::InvalidInput(
30            "Input arrays must have at least one sample and one label".to_string(),
31        ));
32    }
33
34    let mut total_errors = 0;
35    let total_elements = n_samples * n_labels;
36
37    for sample_idx in 0..n_samples {
38        for label_idx in 0..n_labels {
39            if y_true[[sample_idx, label_idx]] != y_pred[[sample_idx, label_idx]] {
40                total_errors += 1;
41            }
42        }
43    }
44
45    Ok(total_errors as f64 / total_elements as f64)
46}
47
48/// Subset accuracy for multi-label classification
49///
/// Subset accuracy is the strictest multi-label metric: a sample counts as
/// correct only if its entire label set is predicted exactly.
52///
53/// # Arguments
54///
55/// * `y_true` - Ground truth (correct) labels
56/// * `y_pred` - Predicted labels
57///
58/// # Returns
59///
60/// The subset accuracy between y_true and y_pred
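///
/// # Examples
///
/// A small sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::subset_accuracy;
///
/// let y_true = array![[1, 0], [0, 1], [1, 1]];
/// let y_pred = array![[1, 0], [0, 0], [1, 1]];
/// // Two of the three rows are predicted exactly.
/// let acc = subset_accuracy(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((acc - 2.0 / 3.0).abs() < 1e-12);
/// ```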
61pub fn subset_accuracy(
62    y_true: &ArrayView2<'_, i32>,
63    y_pred: &ArrayView2<'_, i32>,
64) -> SklResult<f64> {
65    if y_true.dim() != y_pred.dim() {
66        return Err(SklearsError::InvalidInput(
67            "y_true and y_pred must have the same shape".to_string(),
68        ));
69    }
70
71    let (n_samples, n_labels) = y_true.dim();
72    if n_samples == 0 {
73        return Err(SklearsError::InvalidInput(
74            "Input arrays must have at least one sample".to_string(),
75        ));
76    }
77
78    let mut correct_subsets = 0;
79
80    for sample_idx in 0..n_samples {
81        let mut subset_correct = true;
82        for label_idx in 0..n_labels {
83            if y_true[[sample_idx, label_idx]] != y_pred[[sample_idx, label_idx]] {
84                subset_correct = false;
85                break;
86            }
87        }
88        if subset_correct {
89            correct_subsets += 1;
90        }
91    }
92
93    Ok(correct_subsets as f64 / n_samples as f64)
94}
95
96/// Jaccard similarity coefficient for multi-label classification
97///
/// The Jaccard similarity coefficient is the size of the intersection divided
/// by the size of the union of the true and predicted label sets, computed
/// per sample and then averaged over samples.
100///
101/// # Arguments
102///
103/// * `y_true` - Ground truth (correct) labels
104/// * `y_pred` - Predicted labels
105///
106/// # Returns
107///
108/// The Jaccard similarity coefficient
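///
/// # Examples
///
/// Illustrative only (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::jaccard_score;
///
/// let y_true = array![[1, 0, 1], [0, 1, 0]];
/// let y_pred = array![[1, 1, 1], [0, 0, 0]];
/// // Sample 0: |intersection| = 2, |union| = 3 -> 2/3; sample 1: 0/1 -> 0.
/// let score = jaccard_score(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((score - 1.0 / 3.0).abs() < 1e-12);
/// ```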
109pub fn jaccard_score(y_true: &ArrayView2<'_, i32>, y_pred: &ArrayView2<'_, i32>) -> SklResult<f64> {
110    if y_true.dim() != y_pred.dim() {
111        return Err(SklearsError::InvalidInput(
112            "y_true and y_pred must have the same shape".to_string(),
113        ));
114    }
115
116    let (n_samples, n_labels) = y_true.dim();
117    if n_samples == 0 {
118        return Err(SklearsError::InvalidInput(
119            "Input arrays must have at least one sample".to_string(),
120        ));
121    }
122
123    let mut total_jaccard = 0.0;
124
125    for sample_idx in 0..n_samples {
126        let mut intersection = 0;
127        let mut union = 0;
128
129        for label_idx in 0..n_labels {
130            let true_label = y_true[[sample_idx, label_idx]];
131            let pred_label = y_pred[[sample_idx, label_idx]];
132
133            if true_label == 1 && pred_label == 1 {
134                intersection += 1;
135            }
136            if true_label == 1 || pred_label == 1 {
137                union += 1;
138            }
139        }
140
141        // Jaccard = intersection / union, handle division by zero
142        let sample_jaccard = if union > 0 {
143            intersection as f64 / union as f64
144        } else {
145            1.0 // If both sets are empty, Jaccard = 1
146        };
147
148        total_jaccard += sample_jaccard;
149    }
150
151    Ok(total_jaccard / n_samples as f64)
152}
153
154/// F1 score for multi-label classification
155///
/// Compute the F1 score using the specified multi-label averaging strategy
/// ('micro', 'macro', or 'samples').
157///
158/// # Arguments
159///
160/// * `y_true` - Ground truth (correct) labels
161/// * `y_pred` - Predicted labels
162/// * `average` - The averaging strategy ('micro', 'macro', 'samples')
163///
164/// # Returns
165///
166/// The F1 score according to the specified averaging strategy
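///
/// # Examples
///
/// A minimal sketch of the micro average (import path assumed from this
/// file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::f1_score;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // Globally: TP = 2, FP = 1, FN = 0 -> precision 2/3, recall 1, F1 = 0.8.
/// let micro = f1_score(&y_true.view(), &y_pred.view(), "micro").unwrap();
/// assert!((micro - 0.8).abs() < 1e-12);
/// ```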
167pub fn f1_score(
168    y_true: &ArrayView2<'_, i32>,
169    y_pred: &ArrayView2<'_, i32>,
170    average: &str,
171) -> SklResult<f64> {
172    if y_true.dim() != y_pred.dim() {
173        return Err(SklearsError::InvalidInput(
174            "y_true and y_pred must have the same shape".to_string(),
175        ));
176    }
177
178    let (n_samples, n_labels) = y_true.dim();
179    if n_samples == 0 || n_labels == 0 {
180        return Err(SklearsError::InvalidInput(
181            "Input arrays must have at least one sample and one label".to_string(),
182        ));
183    }
184
185    match average {
186        "micro" => {
187            // Compute global precision and recall
188            let mut total_tp = 0;
189            let mut total_fp = 0;
190            let mut total_false_negatives = 0;
191
192            for sample_idx in 0..n_samples {
193                for label_idx in 0..n_labels {
194                    let true_label = y_true[[sample_idx, label_idx]];
195                    let pred_label = y_pred[[sample_idx, label_idx]];
196
197                    if true_label == 1 && pred_label == 1 {
198                        total_tp += 1;
199                    } else if true_label == 0 && pred_label == 1 {
200                        total_fp += 1;
201                    } else if true_label == 1 && pred_label == 0 {
202                        total_false_negatives += 1;
203                    }
204                }
205            }
206
207            let precision = if total_tp + total_fp > 0 {
208                total_tp as f64 / (total_tp + total_fp) as f64
209            } else {
210                0.0
211            };
212
213            let recall = if total_tp + total_false_negatives > 0 {
214                total_tp as f64 / (total_tp + total_false_negatives) as f64
215            } else {
216                0.0
217            };
218
219            let f1 = if precision + recall > 0.0 {
220                2.0 * precision * recall / (precision + recall)
221            } else {
222                0.0
223            };
224
225            Ok(f1)
226        }
227        "macro" => {
228            // Compute F1 for each label and average
229            let mut label_f1_scores = Vec::new();
230
231            for label_idx in 0..n_labels {
232                let mut tp = 0;
233                let mut fp = 0;
234                let mut false_negatives = 0;
235
236                for sample_idx in 0..n_samples {
237                    let true_label = y_true[[sample_idx, label_idx]];
238                    let pred_label = y_pred[[sample_idx, label_idx]];
239
240                    if true_label == 1 && pred_label == 1 {
241                        tp += 1;
242                    } else if true_label == 0 && pred_label == 1 {
243                        fp += 1;
244                    } else if true_label == 1 && pred_label == 0 {
245                        false_negatives += 1;
246                    }
247                }
248
249                let precision = if tp + fp > 0 {
250                    tp as f64 / (tp + fp) as f64
251                } else {
252                    0.0
253                };
254
255                let recall = if tp + false_negatives > 0 {
256                    tp as f64 / (tp + false_negatives) as f64
257                } else {
258                    0.0
259                };
260
261                let f1 = if precision + recall > 0.0 {
262                    2.0 * precision * recall / (precision + recall)
263                } else {
264                    0.0
265                };
266
267                label_f1_scores.push(f1);
268            }
269
270            Ok(label_f1_scores.iter().sum::<f64>() / n_labels as f64)
271        }
272        "samples" => {
273            // Compute F1 for each sample and average
274            let mut sample_f1_scores = Vec::new();
275
276            for sample_idx in 0..n_samples {
277                let mut tp = 0;
278                let mut fp = 0;
279                let mut false_negatives = 0;
280
281                for label_idx in 0..n_labels {
282                    let true_label = y_true[[sample_idx, label_idx]];
283                    let pred_label = y_pred[[sample_idx, label_idx]];
284
285                    if true_label == 1 && pred_label == 1 {
286                        tp += 1;
287                    } else if true_label == 0 && pred_label == 1 {
288                        fp += 1;
289                    } else if true_label == 1 && pred_label == 0 {
290                        false_negatives += 1;
291                    }
292                }
293
294                let precision = if tp + fp > 0 {
295                    tp as f64 / (tp + fp) as f64
296                } else {
297                    0.0
298                };
299
300                let recall = if tp + false_negatives > 0 {
301                    tp as f64 / (tp + false_negatives) as f64
302                } else {
303                    0.0
304                };
305
306                let f1 = if precision + recall > 0.0 {
307                    2.0 * precision * recall / (precision + recall)
308                } else {
309                    0.0
310                };
311
312                sample_f1_scores.push(f1);
313            }
314
315            Ok(sample_f1_scores.iter().sum::<f64>() / n_samples as f64)
316        }
317        _ => Err(SklearsError::InvalidInput(format!(
318            "Unknown average type: {}. Valid options are 'micro', 'macro', 'samples'",
319            average
320        ))),
321    }
322}
323
324/// Coverage error for multi-label ranking
325///
326/// Coverage error measures how far we need to go through the ranked scores
327/// to cover all true labels. The best value is equal to the average number
328/// of labels in y_true per sample.
329///
330/// # Arguments
331///
332/// * `y_true` - Ground truth (correct) labels
333/// * `y_scores` - Target scores (predicted probabilities)
334///
335/// # Returns
336///
337/// The coverage error
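///
/// # Examples
///
/// A small sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::coverage_error;
///
/// let y_true = array![[1, 0, 1], [0, 1, 0]];
/// let y_scores = array![[0.9, 0.8, 0.1], [0.2, 0.7, 0.5]];
/// // Sample 0 needs the top 3 ranked labels, sample 1 only the top 1.
/// let cov = coverage_error(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((cov - 2.0).abs() < 1e-12);
/// ```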
338pub fn coverage_error(
339    y_true: &ArrayView2<'_, i32>,
340    y_scores: &ArrayView2<'_, f64>,
341) -> SklResult<f64> {
342    if y_true.dim() != y_scores.dim() {
343        return Err(SklearsError::InvalidInput(
344            "y_true and y_scores must have the same shape".to_string(),
345        ));
346    }
347
348    let (n_samples, n_labels) = y_true.dim();
349    if n_samples == 0 || n_labels == 0 {
350        return Err(SklearsError::InvalidInput(
351            "Input arrays must have at least one sample and one label".to_string(),
352        ));
353    }
354
355    let mut total_coverage = 0.0;
356
357    for sample_idx in 0..n_samples {
358        // Get the indices sorted by scores in descending order
359        let mut score_label_pairs: Vec<(f64, usize)> = (0..n_labels)
360            .map(|label_idx| (y_scores[[sample_idx, label_idx]], label_idx))
361            .collect();
362        score_label_pairs
363            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
364
365        // Find the position of the last true label in the ranked list
366        let mut last_true_position = 0;
367        for (position, &(_, label_idx)) in score_label_pairs.iter().enumerate() {
368            if y_true[[sample_idx, label_idx]] == 1 {
369                last_true_position = position + 1; // Convert to 1-based indexing
370            }
371        }
372
373        total_coverage += last_true_position as f64;
374    }
375
376    Ok(total_coverage / n_samples as f64)
377}
378
379/// Label ranking average precision for multi-label ranking
380///
/// The label ranking average precision (LRAP) averages over the samples the
/// answer to the following question: for each ground-truth label, what
/// fraction of the labels ranked at or above it are also true labels?
384///
385/// # Arguments
386///
387/// * `y_true` - Ground truth (correct) labels
388/// * `y_scores` - Target scores (predicted probabilities)
389///
390/// # Returns
391///
392/// The label ranking average precision
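///
/// # Examples
///
/// Illustrative only (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::label_ranking_average_precision;
///
/// let y_true = array![[1, 0, 1]];
/// let y_scores = array![[0.9, 0.8, 0.1]];
/// // Precision at the two true labels' rank positions is 1/1 and 2/3 -> mean 5/6.
/// let lrap = label_ranking_average_precision(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((lrap - 5.0 / 6.0).abs() < 1e-12);
/// ```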
393pub fn label_ranking_average_precision(
394    y_true: &ArrayView2<'_, i32>,
395    y_scores: &ArrayView2<'_, f64>,
396) -> SklResult<f64> {
397    if y_true.dim() != y_scores.dim() {
398        return Err(SklearsError::InvalidInput(
399            "y_true and y_scores must have the same shape".to_string(),
400        ));
401    }
402
403    let (n_samples, n_labels) = y_true.dim();
404    if n_samples == 0 || n_labels == 0 {
405        return Err(SklearsError::InvalidInput(
406            "Input arrays must have at least one sample and one label".to_string(),
407        ));
408    }
409
410    let mut total_lrap = 0.0;
411
412    for sample_idx in 0..n_samples {
413        // Get the indices sorted by scores in descending order
414        let mut score_label_pairs: Vec<(f64, usize)> = (0..n_labels)
415            .map(|label_idx| (y_scores[[sample_idx, label_idx]], label_idx))
416            .collect();
417        score_label_pairs
418            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
419
420        // Count true labels for this sample
421        let n_true_labels: i32 = (0..n_labels)
422            .map(|label_idx| y_true[[sample_idx, label_idx]])
423            .sum();
424
425        if n_true_labels == 0 {
426            continue; // Skip samples with no true labels
427        }
428
429        let mut precision_sum = 0.0;
430        let mut true_labels_seen = 0;
431
432        for (position, &(_, label_idx)) in score_label_pairs.iter().enumerate() {
433            if y_true[[sample_idx, label_idx]] == 1 {
434                true_labels_seen += 1;
435                let precision_at_position = true_labels_seen as f64 / (position + 1) as f64;
436                precision_sum += precision_at_position;
437            }
438        }
439
440        let sample_lrap = precision_sum / n_true_labels as f64;
441        total_lrap += sample_lrap;
442    }
443
444    Ok(total_lrap / n_samples as f64)
445}
446
447/// One-error for multi-label ranking
448///
449/// The one-error evaluates how many times the top-ranked label is not
450/// in the set of true labels. The best performance is achieved when
451/// one-error is 0, which means the top-ranked label is always correct.
452///
453/// # Arguments
454///
455/// * `y_true` - Ground truth (correct) labels
456/// * `y_scores` - Target scores (predicted probabilities)
457///
458/// # Returns
459///
460/// The one-error (fraction of samples where top-ranked label is incorrect)
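///
/// # Examples
///
/// A minimal sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::one_error;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_scores = array![[0.8, 0.2], [0.9, 0.1]];
/// // The top-ranked label is correct for the first sample only.
/// let err = one_error(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((err - 0.5).abs() < 1e-12);
/// ```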
461pub fn one_error(y_true: &ArrayView2<'_, i32>, y_scores: &ArrayView2<'_, f64>) -> SklResult<f64> {
462    if y_true.dim() != y_scores.dim() {
463        return Err(SklearsError::InvalidInput(
464            "y_true and y_scores must have the same shape".to_string(),
465        ));
466    }
467
468    let (n_samples, n_labels) = y_true.dim();
469    if n_samples == 0 || n_labels == 0 {
470        return Err(SklearsError::InvalidInput(
471            "Input arrays must have at least one sample and one label".to_string(),
472        ));
473    }
474
475    let mut errors = 0;
476
477    for sample_idx in 0..n_samples {
478        // Find the label with the highest score
479        let mut max_score = f64::NEG_INFINITY;
480        let mut top_label_idx = 0;
481
482        for label_idx in 0..n_labels {
483            let score = y_scores[[sample_idx, label_idx]];
484            if score > max_score {
485                max_score = score;
486                top_label_idx = label_idx;
487            }
488        }
489
490        // Check if the top-ranked label is correct
491        if y_true[[sample_idx, top_label_idx]] != 1 {
492            errors += 1;
493        }
494    }
495
496    Ok(errors as f64 / n_samples as f64)
497}
498
499/// Ranking loss for multi-label ranking
500///
501/// The ranking loss evaluates the average fraction of label pairs that are
502/// incorrectly ordered, given the predictions. The best performance is achieved
503/// when ranking loss is 0.
504///
505/// # Arguments
506///
507/// * `y_true` - Ground truth (correct) labels
508/// * `y_scores` - Target scores (predicted probabilities)
509///
510/// # Returns
511///
512/// The ranking loss
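///
/// # Examples
///
/// A small sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::ranking_loss;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_scores = array![[0.2, 0.8], [0.1, 0.9]];
/// // Sample 0 ranks its negative label above its positive label (loss 1),
/// // sample 1 orders the pair correctly (loss 0).
/// let loss = ranking_loss(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((loss - 0.5).abs() < 1e-12);
/// ```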
513pub fn ranking_loss(
514    y_true: &ArrayView2<'_, i32>,
515    y_scores: &ArrayView2<'_, f64>,
516) -> SklResult<f64> {
517    if y_true.dim() != y_scores.dim() {
518        return Err(SklearsError::InvalidInput(
519            "y_true and y_scores must have the same shape".to_string(),
520        ));
521    }
522
523    let (n_samples, n_labels) = y_true.dim();
524    if n_samples == 0 || n_labels == 0 {
525        return Err(SklearsError::InvalidInput(
526            "Input arrays must have at least one sample and one label".to_string(),
527        ));
528    }
529
530    let mut total_ranking_loss = 0.0;
531
532    for sample_idx in 0..n_samples {
533        let mut incorrect_pairs = 0;
534        let mut total_pairs = 0;
535
536        // Compare all pairs of labels
537        for i in 0..n_labels {
538            for j in 0..n_labels {
539                if i != j {
540                    let true_i = y_true[[sample_idx, i]];
541                    let true_j = y_true[[sample_idx, j]];
542                    let score_i = y_scores[[sample_idx, i]];
543                    let score_j = y_scores[[sample_idx, j]];
544
545                    // Check if this is a relevant pair (one positive, one negative)
546                    if (true_i == 1 && true_j == 0) || (true_i == 0 && true_j == 1) {
547                        total_pairs += 1;
548
549                        // Check if the ordering is incorrect
550                        if (true_i == 1 && true_j == 0 && score_i < score_j)
551                            || (true_i == 0 && true_j == 1 && score_i > score_j)
552                        {
553                            incorrect_pairs += 1;
554                        }
555                    }
556                }
557            }
558        }
559
560        // Add to total ranking loss
561        if total_pairs > 0 {
562            total_ranking_loss += incorrect_pairs as f64 / total_pairs as f64;
563        }
564    }
565
566    Ok(total_ranking_loss / n_samples as f64)
567}
568
569/// Average precision score for multi-label ranking
570///
/// Computes the average precision for each sample (the mean of the precision
/// values at the rank positions of that sample's true labels in the score
/// ranking) and then averages over all samples.
574///
575/// # Arguments
576///
577/// * `y_true` - Ground truth (correct) labels
578/// * `y_scores` - Target scores (predicted probabilities)
579///
580/// # Returns
581///
582/// The average precision score
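///
/// # Examples
///
/// Illustrative only (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::average_precision_score;
///
/// let y_true = array![[1, 0, 1]];
/// let y_scores = array![[0.9, 0.8, 0.1]];
/// // Precision at the true labels' rank positions is 1/1 and 2/3 -> mean 5/6.
/// let ap = average_precision_score(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((ap - 5.0 / 6.0).abs() < 1e-12);
/// ```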
583pub fn average_precision_score(
584    y_true: &ArrayView2<'_, i32>,
585    y_scores: &ArrayView2<'_, f64>,
586) -> SklResult<f64> {
587    if y_true.dim() != y_scores.dim() {
588        return Err(SklearsError::InvalidInput(
589            "y_true and y_scores must have the same shape".to_string(),
590        ));
591    }
592
593    let (n_samples, n_labels) = y_true.dim();
594    if n_samples == 0 || n_labels == 0 {
595        return Err(SklearsError::InvalidInput(
596            "Input arrays must have at least one sample and one label".to_string(),
597        ));
598    }
599
600    let mut total_ap = 0.0;
601
602    for sample_idx in 0..n_samples {
603        // Get the indices sorted by scores in descending order
604        let mut score_label_pairs: Vec<(f64, usize)> = (0..n_labels)
605            .map(|label_idx| (y_scores[[sample_idx, label_idx]], label_idx))
606            .collect();
607        score_label_pairs
608            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
609
610        // Count true labels for this sample
611        let n_true_labels: i32 = (0..n_labels)
612            .map(|label_idx| y_true[[sample_idx, label_idx]])
613            .sum();
614
615        if n_true_labels == 0 {
616            continue; // Skip samples with no true labels
617        }
618
619        let mut precision_sum = 0.0;
620        let mut true_labels_seen = 0;
621
622        for (position, &(_, label_idx)) in score_label_pairs.iter().enumerate() {
623            if y_true[[sample_idx, label_idx]] == 1 {
624                true_labels_seen += 1;
625                let precision_at_position = true_labels_seen as f64 / (position + 1) as f64;
626                precision_sum += precision_at_position;
627            }
628        }
629
630        let sample_ap = precision_sum / n_true_labels as f64;
631        total_ap += sample_ap;
632    }
633
634    Ok(total_ap / n_samples as f64)
635}
636
637/// Micro-averaged precision for multi-label classification
638///
639/// Calculate precision by counting the total true positives and false positives
640/// across all labels and samples.
641///
642/// # Arguments
643///
644/// * `y_true` - Ground truth (correct) labels
645/// * `y_pred` - Predicted labels
646///
647/// # Returns
648///
649/// The micro-averaged precision
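///
/// # Examples
///
/// A minimal sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::precision_score_micro;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // Globally: TP = 2, FP = 1 -> precision 2/3.
/// let p = precision_score_micro(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((p - 2.0 / 3.0).abs() < 1e-12);
/// ```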
650pub fn precision_score_micro(
651    y_true: &ArrayView2<'_, i32>,
652    y_pred: &ArrayView2<'_, i32>,
653) -> SklResult<f64> {
654    if y_true.dim() != y_pred.dim() {
655        return Err(SklearsError::InvalidInput(
656            "y_true and y_pred must have the same shape".to_string(),
657        ));
658    }
659
660    let (n_samples, n_labels) = y_true.dim();
661    if n_samples == 0 || n_labels == 0 {
662        return Err(SklearsError::InvalidInput(
663            "Input arrays must have at least one sample and one label".to_string(),
664        ));
665    }
666
667    let mut total_tp = 0;
668    let mut total_fp = 0;
669
670    for sample_idx in 0..n_samples {
671        for label_idx in 0..n_labels {
672            let true_label = y_true[[sample_idx, label_idx]];
673            let pred_label = y_pred[[sample_idx, label_idx]];
674
675            if true_label == 1 && pred_label == 1 {
676                total_tp += 1;
677            } else if true_label == 0 && pred_label == 1 {
678                total_fp += 1;
679            }
680        }
681    }
682
683    let precision = if total_tp + total_fp > 0 {
684        total_tp as f64 / (total_tp + total_fp) as f64
685    } else {
686        0.0
687    };
688
689    Ok(precision)
690}
691
692/// Micro-averaged recall for multi-label classification
693///
694/// Calculate recall by counting the total true positives and false negatives
695/// across all labels and samples.
696///
697/// # Arguments
698///
699/// * `y_true` - Ground truth (correct) labels
700/// * `y_pred` - Predicted labels
701///
702/// # Returns
703///
704/// The micro-averaged recall
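///
/// # Examples
///
/// A minimal sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::recall_score_micro;
///
/// let y_true = array![[1, 0], [1, 1]];
/// let y_pred = array![[1, 0], [0, 1]];
/// // Globally: TP = 2, FN = 1 -> recall 2/3.
/// let r = recall_score_micro(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((r - 2.0 / 3.0).abs() < 1e-12);
/// ```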
705pub fn recall_score_micro(
706    y_true: &ArrayView2<'_, i32>,
707    y_pred: &ArrayView2<'_, i32>,
708) -> SklResult<f64> {
709    if y_true.dim() != y_pred.dim() {
710        return Err(SklearsError::InvalidInput(
711            "y_true and y_pred must have the same shape".to_string(),
712        ));
713    }
714
715    let (n_samples, n_labels) = y_true.dim();
716    if n_samples == 0 || n_labels == 0 {
717        return Err(SklearsError::InvalidInput(
718            "Input arrays must have at least one sample and one label".to_string(),
719        ));
720    }
721
722    let mut total_tp = 0;
723    let mut total_fn = 0;
724
725    for sample_idx in 0..n_samples {
726        for label_idx in 0..n_labels {
727            let true_label = y_true[[sample_idx, label_idx]];
728            let pred_label = y_pred[[sample_idx, label_idx]];
729
730            if true_label == 1 && pred_label == 1 {
731                total_tp += 1;
732            } else if true_label == 1 && pred_label == 0 {
733                total_fn += 1;
734            }
735        }
736    }
737
738    let recall = if total_tp + total_fn > 0 {
739        total_tp as f64 / (total_tp + total_fn) as f64
740    } else {
741        0.0
742    };
743
744    Ok(recall)
745}
746
747// Additional imports for statistical tests
748use std::collections::HashMap;
749
750/// Per-label performance metrics
751///
752/// Container for detailed performance metrics for each label, including
753/// precision, recall, F1-score, support, and accuracy per label.
754#[derive(Debug, Clone)]
755pub struct PerLabelMetrics {
756    /// Precision score for each label
757    pub precision: Vec<f64>,
758    /// Recall score for each label
759    pub recall: Vec<f64>,
760    /// F1 score for each label
761    pub f1_score: Vec<f64>,
762    /// Support (number of true instances) for each label
763    pub support: Vec<usize>,
764    /// Accuracy for each label (considering label as binary classification)
765    pub accuracy: Vec<f64>,
766    /// Number of labels
767    pub n_labels: usize,
768}
769
770impl PerLabelMetrics {
771    /// Get the macro average of a metric
772    pub fn macro_average(&self, metric: &str) -> SklResult<f64> {
773        let values = match metric {
774            "precision" => &self.precision,
775            "recall" => &self.recall,
776            "f1_score" => &self.f1_score,
777            "accuracy" => &self.accuracy,
778            _ => return Err(SklearsError::InvalidInput(format!(
779                "Unknown metric: {}. Valid options are 'precision', 'recall', 'f1_score', 'accuracy'",
780                metric
781            )))
782        };
783
784        Ok(values.iter().sum::<f64>() / values.len() as f64)
785    }
786
787    /// Get the weighted average of a metric (weighted by support)
788    pub fn weighted_average(&self, metric: &str) -> SklResult<f64> {
789        let values = match metric {
790            "precision" => &self.precision,
791            "recall" => &self.recall,
792            "f1_score" => &self.f1_score,
793            "accuracy" => &self.accuracy,
794            _ => return Err(SklearsError::InvalidInput(format!(
795                "Unknown metric: {}. Valid options are 'precision', 'recall', 'f1_score', 'accuracy'",
796                metric
797            )))
798        };
799
800        let total_support: usize = self.support.iter().sum();
801        if total_support == 0 {
802            return Ok(0.0);
803        }
804
805        let weighted_sum: f64 = values
806            .iter()
807            .zip(self.support.iter())
808            .map(|(value, support)| value * (*support as f64))
809            .sum();
810
811        Ok(weighted_sum / total_support as f64)
812    }
813}
814
815/// Compute detailed per-label performance metrics
816///
817/// Calculates precision, recall, F1-score, support, and accuracy for each label
818/// individually, providing comprehensive per-label analysis.
819///
820/// # Arguments
821///
822/// * `y_true` - Ground truth (correct) labels
823/// * `y_pred` - Predicted labels
824///
825/// # Returns
826///
827/// PerLabelMetrics containing detailed metrics for each label
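///
/// # Examples
///
/// A small sketch (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::per_label_metrics;
///
/// let y_true = array![[1, 0], [1, 1]];
/// let y_pred = array![[1, 0], [0, 1]];
/// let m = per_label_metrics(&y_true.view(), &y_pred.view()).unwrap();
/// // Label 0 has two true instances, label 1 has one.
/// assert_eq!(m.support, vec![2, 1]);
/// // Per-label recalls are 0.5 and 1.0, so the macro average is 0.75.
/// assert!((m.macro_average("recall").unwrap() - 0.75).abs() < 1e-12);
/// ```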
828pub fn per_label_metrics(
829    y_true: &ArrayView2<'_, i32>,
830    y_pred: &ArrayView2<'_, i32>,
831) -> SklResult<PerLabelMetrics> {
832    if y_true.dim() != y_pred.dim() {
833        return Err(SklearsError::InvalidInput(
834            "y_true and y_pred must have the same shape".to_string(),
835        ));
836    }
837
838    let (n_samples, n_labels) = y_true.dim();
839    if n_samples == 0 || n_labels == 0 {
840        return Err(SklearsError::InvalidInput(
841            "Input arrays must have at least one sample and one label".to_string(),
842        ));
843    }
844
845    let mut precision = Vec::with_capacity(n_labels);
846    let mut recall = Vec::with_capacity(n_labels);
847    let mut f1_score = Vec::with_capacity(n_labels);
848    let mut support = Vec::with_capacity(n_labels);
849    let mut accuracy = Vec::with_capacity(n_labels);
850
851    // Calculate metrics for each label
852    for label_idx in 0..n_labels {
853        let mut tp = 0;
854        let mut fp = 0;
855        let mut fn_count = 0;
856        let mut tn = 0;
857
858        for sample_idx in 0..n_samples {
859            let true_label = y_true[[sample_idx, label_idx]];
860            let pred_label = y_pred[[sample_idx, label_idx]];
861
862            match (true_label, pred_label) {
863                (1, 1) => tp += 1,
864                (0, 1) => fp += 1,
865                (1, 0) => fn_count += 1,
866                (0, 0) => tn += 1,
867                _ => {} // Should not happen with binary labels
868            }
869        }
870
871        // Calculate precision
872        let label_precision = if tp + fp > 0 {
873            tp as f64 / (tp + fp) as f64
874        } else {
875            0.0
876        };
877
878        // Calculate recall
879        let label_recall = if tp + fn_count > 0 {
880            tp as f64 / (tp + fn_count) as f64
881        } else {
882            0.0
883        };
884
885        // Calculate F1 score
886        let label_f1 = if label_precision + label_recall > 0.0 {
887            2.0 * label_precision * label_recall / (label_precision + label_recall)
888        } else {
889            0.0
890        };
891
892        // Calculate accuracy (for this label as binary classification)
893        let label_accuracy = (tp + tn) as f64 / n_samples as f64;
894
895        // Support is number of true instances for this label
896        let label_support = (tp + fn_count) as usize;
897
898        precision.push(label_precision);
899        recall.push(label_recall);
900        f1_score.push(label_f1);
901        support.push(label_support);
902        accuracy.push(label_accuracy);
903    }
904
905    Ok(PerLabelMetrics {
906        precision,
907        recall,
908        f1_score,
909        support,
910        accuracy,
911        n_labels,
912    })
913}
914
915/// Statistical significance test result
916#[derive(Debug, Clone)]
917pub struct StatisticalTestResult {
918    /// Test statistic value
919    pub statistic: f64,
920    /// P-value of the test
921    pub p_value: f64,
922    /// Whether the result is statistically significant (p < 0.05)
923    pub is_significant: bool,
924    /// Test name
925    pub test_name: String,
926    /// Additional information about the test
927    pub additional_info: HashMap<String, f64>,
928}
929
930impl StatisticalTestResult {
931    /// Create a new statistical test result
932    pub fn new(
933        statistic: f64,
934        p_value: f64,
935        test_name: String,
936        additional_info: Option<HashMap<String, f64>>,
937    ) -> Self {
938        Self {
939            statistic,
940            p_value,
941            is_significant: p_value < 0.05,
942            test_name,
943            additional_info: additional_info.unwrap_or_default(),
944        }
945    }
946}
947
948/// McNemar's test for comparing two classifiers
949///
950/// Tests whether two classifiers have significantly different error rates.
951/// Appropriate for comparing paired predictions on the same test set.
952///
953/// # Arguments
954///
955/// * `y_true` - Ground truth labels
956/// * `y_pred1` - Predictions from first classifier
957/// * `y_pred2` - Predictions from second classifier
958///
959/// # Returns
960///
961/// Statistical test result with McNemar's test statistic and p-value
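///
/// # Examples
///
/// A degenerate but instructive sketch comparing a classifier with itself
/// (import path assumed from this file's location):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use sklears_multioutput::metrics::mcnemar_test;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // Identical predictions produce no disagreements, hence no evidence of a difference.
/// let result = mcnemar_test(&y_true.view(), &y_pred.view(), &y_pred.view()).unwrap();
/// assert_eq!(result.statistic, 0.0);
/// assert!(!result.is_significant);
/// ```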
962pub fn mcnemar_test(
963    y_true: &ArrayView2<'_, i32>,
964    y_pred1: &ArrayView2<'_, i32>,
965    y_pred2: &ArrayView2<'_, i32>,
966) -> SklResult<StatisticalTestResult> {
967    if y_true.dim() != y_pred1.dim() || y_true.dim() != y_pred2.dim() {
968        return Err(SklearsError::InvalidInput(
969            "All input arrays must have the same shape".to_string(),
970        ));
971    }
972
973    let (n_samples, n_labels) = y_true.dim();
974    if n_samples == 0 || n_labels == 0 {
975        return Err(SklearsError::InvalidInput(
976            "Input arrays must have at least one sample and one label".to_string(),
977        ));
978    }
979
980    // Count disagreements across all samples and labels
981    let mut n01 = 0; // Classifier 1 correct, classifier 2 incorrect
982    let mut n10 = 0; // Classifier 1 incorrect, classifier 2 correct
983
984    for sample_idx in 0..n_samples {
985        for label_idx in 0..n_labels {
986            let true_label = y_true[[sample_idx, label_idx]];
987            let pred1 = y_pred1[[sample_idx, label_idx]];
988            let pred2 = y_pred2[[sample_idx, label_idx]];
989
990            let correct1 = pred1 == true_label;
991            let correct2 = pred2 == true_label;
992
993            match (correct1, correct2) {
994                (true, false) => n01 += 1,
995                (false, true) => n10 += 1,
996                _ => {} // Both correct or both incorrect
997            }
998        }
999    }
1000
1001    // McNemar's test statistic
1002    let total_disagreements = n01 + n10;
1003    if total_disagreements == 0 {
1004        return Ok(StatisticalTestResult::new(
1005            0.0,
1006            1.0, // No disagreements means no significant difference
1007            "McNemar".to_string(),
1008            Some({
1009                let mut info = HashMap::new();
1010                info.insert("n01".to_string(), n01 as f64);
1011                info.insert("n10".to_string(), n10 as f64);
1012                info.insert(
1013                    "total_disagreements".to_string(),
1014                    total_disagreements as f64,
1015                );
1016                info
1017            }),
1018        ));
1019    }
1020
1021    // Use continuity correction for McNemar's test
1022    let statistic = ((n01 as f64 - n10 as f64).abs() - 1.0).max(0.0).powi(2) / (n01 + n10) as f64;
1023
1024    // Chi-square distribution approximation for p-value (1 degree of freedom)
1025    let p_value = chi_square_p_value(statistic, 1);
1026
1027    let mut info = HashMap::new();
1028    info.insert("n01".to_string(), n01 as f64);
1029    info.insert("n10".to_string(), n10 as f64);
1030    info.insert(
1031        "total_disagreements".to_string(),
1032        total_disagreements as f64,
1033    );
1034
1035    Ok(StatisticalTestResult::new(
1036        statistic,
1037        p_value,
1038        "McNemar".to_string(),
1039        Some(info),
1040    ))
1041}
1042
1043/// Paired t-test for metric comparisons
1044///
1045/// Tests whether the mean difference between paired metric values is
1046/// significantly different from zero.
1047///
1048/// # Arguments
1049///
1050/// * `metric_values1` - Metric values from first method
1051/// * `metric_values2` - Metric values from second method
1052///
1053/// # Returns
1054///
1055/// Statistical test result with t-statistic and p-value
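///
/// # Examples
///
/// A minimal sketch (import path assumed from this file's location):
///
/// ```
/// use sklears_multioutput::metrics::paired_t_test;
///
/// let a = vec![0.80, 0.72, 0.91, 0.65];
/// let b = vec![0.78, 0.70, 0.88, 0.60];
/// let result = paired_t_test(&a, &b).unwrap();
/// // `a` is uniformly higher, so the mean paired difference is positive.
/// assert!(*result.additional_info.get("mean_difference").unwrap() > 0.0);
/// assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
/// ```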
1056pub fn paired_t_test(
1057    metric_values1: &[f64],
1058    metric_values2: &[f64],
1059) -> SklResult<StatisticalTestResult> {
1060    if metric_values1.len() != metric_values2.len() {
1061        return Err(SklearsError::InvalidInput(
1062            "Metric value arrays must have the same length".to_string(),
1063        ));
1064    }
1065
1066    let n = metric_values1.len();
1067    if n < 2 {
1068        return Err(SklearsError::InvalidInput(
1069            "Need at least 2 paired observations for t-test".to_string(),
1070        ));
1071    }
1072
1073    // Calculate differences
1074    let differences: Vec<f64> = metric_values1
1075        .iter()
1076        .zip(metric_values2.iter())
1077        .map(|(v1, v2)| v1 - v2)
1078        .collect();
1079
1080    // Calculate mean difference
1081    let mean_diff = differences.iter().sum::<f64>() / n as f64;
1082
1083    // Calculate standard deviation of differences
1084    let variance = differences
1085        .iter()
1086        .map(|d| (d - mean_diff).powi(2))
1087        .sum::<f64>()
1088        / (n - 1) as f64;
1089    let std_dev = variance.sqrt();
1090
    // t-statistic; guard the zero-variance case (all paired differences equal),
    // which would otherwise divide by zero and produce a NaN statistic
    let t_statistic = if std_dev > 0.0 {
        mean_diff / (std_dev / (n as f64).sqrt())
    } else {
        0.0
    };
1093
1094    // Degrees of freedom
1095    let df = n - 1;
1096
1097    // Two-tailed p-value using t-distribution approximation
1098    let p_value = 2.0 * (1.0 - t_distribution_cdf(t_statistic.abs(), df as f64));
1099
1100    let mut info = HashMap::new();
1101    info.insert("mean_difference".to_string(), mean_diff);
1102    info.insert("std_dev_diff".to_string(), std_dev);
1103    info.insert("degrees_of_freedom".to_string(), df as f64);
1104    info.insert("n_observations".to_string(), n as f64);
1105
1106    Ok(StatisticalTestResult::new(
1107        t_statistic,
1108        p_value,
1109        "Paired t-test".to_string(),
1110        Some(info),
1111    ))
1112}
1113
1114/// Wilcoxon signed-rank test for non-parametric metric comparison
1115///
1116/// Non-parametric alternative to paired t-test that doesn't assume
1117/// normal distribution of differences.
1118///
1119/// # Arguments
1120///
1121/// * `metric_values1` - Metric values from first method
1122/// * `metric_values2` - Metric values from second method
1123///
1124/// # Returns
1125///
1126/// Statistical test result with Wilcoxon statistic and approximate p-value
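///
/// # Examples
///
/// A minimal sketch (import path assumed from this file's location):
///
/// ```
/// use sklears_multioutput::metrics::wilcoxon_signed_rank_test;
///
/// let a = vec![0.80, 0.72, 0.91, 0.65, 0.70];
/// let b = vec![0.60, 0.55, 0.70, 0.50, 0.52];
/// let result = wilcoxon_signed_rank_test(&a, &b).unwrap();
/// // Every paired difference is positive, so all signed ranks fall on the plus side.
/// assert_eq!(*result.additional_info.get("w_minus").unwrap(), 0.0);
/// assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
/// ```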
1127pub fn wilcoxon_signed_rank_test(
1128    metric_values1: &[f64],
1129    metric_values2: &[f64],
1130) -> SklResult<StatisticalTestResult> {
1131    if metric_values1.len() != metric_values2.len() {
1132        return Err(SklearsError::InvalidInput(
1133            "Metric value arrays must have the same length".to_string(),
1134        ));
1135    }
1136
1137    let n = metric_values1.len();
1138    if n < 3 {
1139        return Err(SklearsError::InvalidInput(
1140            "Need at least 3 paired observations for Wilcoxon signed-rank test".to_string(),
1141        ));
1142    }
1143
1144    // Calculate differences and their absolute values
1145    let mut differences_with_abs: Vec<(f64, f64, bool)> = metric_values1
1146        .iter()
1147        .zip(metric_values2.iter())
1148        .map(|(v1, v2)| {
1149            let diff = v1 - v2;
1150            (diff, diff.abs(), diff > 0.0)
1151        })
        .filter(|(_, abs_diff, _)| *abs_diff > 1e-10) // Drop zero differences; ties among nonzero differences are ranked below
1153        .collect();
1154
1155    let n_nonzero = differences_with_abs.len();
1156    if n_nonzero < 3 {
1157        return Err(SklearsError::InvalidInput(
1158            "Too many zero differences for Wilcoxon test".to_string(),
1159        ));
1160    }
1161
1162    // Sort by absolute difference to assign ranks
1163    differences_with_abs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1164
1165    // Assign ranks (handling ties by averaging)
1166    let mut ranks = vec![0.0; n_nonzero];
1167    let mut i = 0;
1168    while i < n_nonzero {
1169        let current_abs_diff = differences_with_abs[i].1;
1170        let mut j = i;
1171
1172        // Find end of tie group
1173        while j < n_nonzero && (differences_with_abs[j].1 - current_abs_diff).abs() < 1e-10 {
1174            j += 1;
1175        }
1176
1177        // Assign average rank to tied values
1178        let avg_rank = (i + j + 1) as f64 / 2.0;
1179        for rank in ranks.iter_mut().take(j).skip(i) {
1180            *rank = avg_rank;
1181        }
1182
1183        i = j;
1184    }
1185
1186    // Calculate positive and negative rank sums
1187    let mut w_plus = 0.0;
1188    let mut w_minus = 0.0;
1189
1190    for i in 0..n_nonzero {
1191        if differences_with_abs[i].2 {
1192            // Positive difference
1193            w_plus += ranks[i];
1194        } else {
1195            // Negative difference
1196            w_minus += ranks[i];
1197        }
1198    }
1199
1200    // Test statistic is the smaller of the two rank sums
1201    let w_statistic = w_plus.min(w_minus);
1202
    // Normal approximation for the p-value (accurate for n >= 10, rough for smaller n)
1204    let expected_w = (n_nonzero * (n_nonzero + 1)) as f64 / 4.0;
1205    let variance_w = (n_nonzero * (n_nonzero + 1) * (2 * n_nonzero + 1)) as f64 / 24.0;
1206    let std_w = variance_w.sqrt();
1207
1208    // Continuity correction
1209    let z_score = ((w_statistic - expected_w).abs() - 0.5) / std_w;
1210
1211    // Two-tailed p-value
1212    let p_value = 2.0 * (1.0 - standard_normal_cdf(z_score));
1213
1214    let mut info = HashMap::new();
1215    info.insert("w_plus".to_string(), w_plus);
1216    info.insert("w_minus".to_string(), w_minus);
1217    info.insert("n_nonzero_differences".to_string(), n_nonzero as f64);
1218    info.insert("z_score".to_string(), z_score);
1219
1220    Ok(StatisticalTestResult::new(
1221        w_statistic,
1222        p_value,
1223        "Wilcoxon signed-rank".to_string(),
1224        Some(info),
1225    ))
1226}
1227
1228/// Confidence interval for a metric
1229#[derive(Debug, Clone)]
1230pub struct ConfidenceInterval {
1231    /// Lower bound of confidence interval
1232    pub lower: f64,
1233    /// Upper bound of confidence interval
1234    pub upper: f64,
1235    /// Point estimate (mean)
1236    pub point_estimate: f64,
1237    /// Confidence level (e.g., 0.95 for 95%)
1238    pub confidence_level: f64,
1239}
1240
1241/// Calculate confidence interval for metric values
1242///
/// Computes a t-based confidence interval for the mean of the metric values,
/// assuming they are approximately normally distributed.
1244///
1245/// # Arguments
1246///
1247/// * `metric_values` - Array of metric values
1248/// * `confidence_level` - Confidence level (e.g., 0.95 for 95% CI)
1249///
1250/// # Returns
1251///
1252/// Confidence interval with lower and upper bounds
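///
/// # Examples
///
/// A small sketch (import path assumed from this file's location):
///
/// ```
/// use sklears_multioutput::metrics::confidence_interval;
///
/// let scores = vec![0.7, 0.8, 0.9];
/// let ci = confidence_interval(&scores, 0.95).unwrap();
/// assert!((ci.point_estimate - 0.8).abs() < 1e-12);
/// // The interval brackets the mean symmetrically.
/// assert!(ci.lower < ci.point_estimate && ci.point_estimate < ci.upper);
/// ```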
1253pub fn confidence_interval(
1254    metric_values: &[f64],
1255    confidence_level: f64,
1256) -> SklResult<ConfidenceInterval> {
1257    if metric_values.is_empty() {
1258        return Err(SklearsError::InvalidInput(
1259            "Metric values array cannot be empty".to_string(),
1260        ));
1261    }
1262
1263    if confidence_level <= 0.0 || confidence_level >= 1.0 {
1264        return Err(SklearsError::InvalidInput(
1265            "Confidence level must be between 0 and 1".to_string(),
1266        ));
1267    }
1268
1269    let n = metric_values.len();
1270    let mean = metric_values.iter().sum::<f64>() / n as f64;
1271
1272    if n == 1 {
1273        return Ok(ConfidenceInterval {
1274            lower: mean,
1275            upper: mean,
1276            point_estimate: mean,
1277            confidence_level,
1278        });
1279    }
1280
1281    // Calculate standard error
1282    let variance = metric_values
1283        .iter()
1284        .map(|v| (v - mean).powi(2))
1285        .sum::<f64>()
1286        / (n - 1) as f64;
1287    let std_error = (variance / n as f64).sqrt();
1288
1289    // Critical value for t-distribution
1290    let alpha = 1.0 - confidence_level;
1291    let df = (n - 1) as f64;
1292    let t_critical = t_distribution_quantile(1.0 - alpha / 2.0, df);
1293
1294    // Margin of error
1295    let margin_error = t_critical * std_error;
1296
1297    Ok(ConfidenceInterval {
1298        lower: mean - margin_error,
1299        upper: mean + margin_error,
1300        point_estimate: mean,
1301        confidence_level,
1302    })
1303}
1304
1305// Statistical distribution helper functions
1306
/// Chi-square upper-tail p-value approximation
1308fn chi_square_p_value(x: f64, df: usize) -> f64 {
1309    if df == 1 {
1310        // For 1 df, chi-square is the square of standard normal
1311        2.0 * (1.0 - standard_normal_cdf(x.sqrt()))
1312    } else {
        // Crude normal approximation to the chi-square upper tail for other
        // degrees of freedom
        let normalized = (x - df as f64) / (2.0 * df as f64).sqrt();
        1.0 - standard_normal_cdf(normalized)
1316    }
1317}
1318
1319/// Standard normal CDF approximation
1320fn standard_normal_cdf(z: f64) -> f64 {
1321    0.5 * (1.0 + erf(z / 2.0_f64.sqrt()))
1322}
1323
1324/// Error function approximation
1325fn erf(x: f64) -> f64 {
1326    // Abramowitz and Stegun approximation
1327    let a1 = 0.254829592;
1328    let a2 = -0.284496736;
1329    let a3 = 1.421413741;
1330    let a4 = -1.453152027;
1331    let a5 = 1.061405429;
1332    let p = 0.3275911;
1333
1334    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
1335    let x = x.abs();
1336
1337    let t = 1.0 / (1.0 + p * x);
1338    let y = 1.0 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * (-x * x).exp();
1339
1340    sign * y
1341}
1342
1343/// t-distribution CDF approximation
1344fn t_distribution_cdf(t: f64, df: f64) -> f64 {
1345    if df > 30.0 {
1346        // For large df, t-distribution approaches standard normal
1347        standard_normal_cdf(t)
1348    } else {
1349        // Simplified approximation
1350        let normalized = t / (df + t * t).sqrt();
1351        0.5 + 0.5 * erf(normalized)
1352    }
1353}
1354
1355/// t-distribution quantile approximation
1356fn t_distribution_quantile(p: f64, df: f64) -> f64 {
1357    if df > 100.0 {
1358        // For very large df, use normal quantile approximation
1359        normal_quantile(p)
1360    } else if df >= 2.0 {
1361        // Use Wilson-Hilferty approximation for better accuracy
1362        let z = normal_quantile(p);
1363        let h = 2.0 / (9.0 * df);
1364        let correction = z.powi(2) * h / 6.0;
        z * (1.0 + correction).max(0.1) // Defensive clamp on the multiplier
1366    } else {
1367        // For very small df, use simpler approximation
1368        let z = normal_quantile(p);
1369        z * (1.0 + (z.powi(2) + 1.0) / (4.0 * df))
1370    }
1371}
1372
/// Standard normal quantile approximation (inverse CDF) using a lookup table with linear interpolation
1374fn normal_quantile(p: f64) -> f64 {
1375    if p <= 0.0 {
1376        return f64::NEG_INFINITY;
1377    }
1378    if p >= 1.0 {
1379        return f64::INFINITY;
1380    }
1381    if (p - 0.5).abs() < f64::EPSILON {
1382        return 0.0;
1383    }
1384
1385    // Use a lookup table approach for key values and interpolation for others
1386    let known_values = [
1387        (0.001, -3.090232),
1388        (0.005, -2.575829),
1389        (0.01, -2.326348),
1390        (0.025, -1.959964),
1391        (0.05, -1.644854),
1392        (0.1, -1.281552),
1393        (0.15, -1.036433),
1394        (0.2, -0.841621),
1395        (0.25, -0.674490),
1396        (0.3, -0.524401),
1397        (0.35, -0.385320),
1398        (0.4, -0.253347),
1399        (0.45, -0.125661),
1400        (0.5, 0.0),
1401        (0.55, 0.125661),
1402        (0.6, 0.253347),
1403        (0.65, 0.385320),
1404        (0.7, 0.524401),
1405        (0.75, 0.674490),
1406        (0.8, 0.841621),
1407        (0.85, 1.036433),
1408        (0.9, 1.281552),
1409        (0.95, 1.644854),
1410        (0.975, 1.959964),
1411        (0.99, 2.326348),
1412        (0.995, 2.575829),
1413        (0.999, 3.090232),
1414    ];
1415
1416    // Find the closest tabulated values and interpolate
1417    if let Some(idx) = known_values.iter().position(|(prob, _)| *prob >= p) {
1418        if idx == 0 {
1419            return known_values[0].1;
1420        }
1421
1422        let (p1, z1) = known_values[idx - 1];
1423        let (p2, z2) = known_values[idx];
1424
1425        // Linear interpolation
1426        let weight = (p - p1) / (p2 - p1);
1427        z1 + weight * (z2 - z1)
1428    } else {
1429        // For very high probabilities, extrapolate
1430        3.5 // Conservative upper bound
1431    }
1432}
1433
1434#[allow(non_snake_case)]
1435#[cfg(test)]
1436mod tests {
1437    use super::*;
1438    use scirs2_core::ndarray::{array, Array2};
1439
1440    // Helper function to create test data
1441    fn create_test_data() -> (Array2<i32>, Array2<i32>) {
1442        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
1443        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1], [1, 0, 0], [1, 1, 1]];
1444        (y_true, y_pred)
1445    }
1446
1447    #[test]
1448    fn test_per_label_metrics_basic() {
1449        let (y_true, y_pred) = create_test_data();
1450        let y_true_view = y_true.view();
1451        let y_pred_view = y_pred.view();
1452
1453        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
1454
1455        assert_eq!(metrics.n_labels, 3);
1456        assert_eq!(metrics.precision.len(), 3);
1457        assert_eq!(metrics.recall.len(), 3);
1458        assert_eq!(metrics.f1_score.len(), 3);
1459        assert_eq!(metrics.support.len(), 3);
1460        assert_eq!(metrics.accuracy.len(), 3);
1461
1462        // Check support (number of true instances per label)
1463        assert_eq!(metrics.support[0], 3); // Label 0 has 3 true instances
1464        assert_eq!(metrics.support[1], 2); // Label 1 has 2 true instances
1465        assert_eq!(metrics.support[2], 3); // Label 2 has 3 true instances
1466
1467        // Verify all metrics are within valid range [0, 1]
1468        for i in 0..3 {
1469            assert!(metrics.precision[i] >= 0.0 && metrics.precision[i] <= 1.0);
1470            assert!(metrics.recall[i] >= 0.0 && metrics.recall[i] <= 1.0);
1471            assert!(metrics.f1_score[i] >= 0.0 && metrics.f1_score[i] <= 1.0);
1472            assert!(metrics.accuracy[i] >= 0.0 && metrics.accuracy[i] <= 1.0);
1473        }
1474    }
1475
1476    #[test]
1477    fn test_per_label_metrics_perfect_prediction() {
1478        let y_perfect = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
1479        let y_true_view = y_perfect.view();
1480        let y_pred_view = y_perfect.view(); // Same as true
1481
1482        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
1483
1484        // All metrics should be 1.0 for perfect prediction
1485        for i in 0..3 {
1486            assert!((metrics.precision[i] - 1.0).abs() < 1e-10);
1487            assert!((metrics.recall[i] - 1.0).abs() < 1e-10);
1488            assert!((metrics.f1_score[i] - 1.0).abs() < 1e-10);
1489            assert!((metrics.accuracy[i] - 1.0).abs() < 1e-10);
1490        }
1491    }
1492
1493    #[test]
1494    fn test_per_label_metrics_all_zeros() {
1495        let y_true = array![[0, 0, 0], [0, 0, 0], [0, 0, 0]];
1496        let y_pred = array![[0, 0, 0], [0, 0, 0], [0, 0, 0]];
1497        let y_true_view = y_true.view();
1498        let y_pred_view = y_pred.view();
1499
1500        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
1501
1502        // All support should be 0
1503        for i in 0..3 {
1504            assert_eq!(metrics.support[i], 0);
1505            assert!((metrics.accuracy[i] - 1.0).abs() < 1e-10); // All correct (all negative)
1506            assert!((metrics.precision[i] - 0.0).abs() < 1e-10); // No positive predictions
1507            assert!((metrics.recall[i] - 0.0).abs() < 1e-10); // No positive instances
1508        }
1509    }
1510
1511    #[test]
1512    fn test_per_label_metrics_macro_average() {
1513        let (y_true, y_pred) = create_test_data();
1514        let y_true_view = y_true.view();
1515        let y_pred_view = y_pred.view();
1516
1517        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
1518
1519        let macro_precision = metrics.macro_average("precision").unwrap();
1520        let macro_recall = metrics.macro_average("recall").unwrap();
1521        let macro_f1 = metrics.macro_average("f1_score").unwrap();
1522        let macro_accuracy = metrics.macro_average("accuracy").unwrap();
1523
1524        // Verify macro averages are calculated correctly
1525        let expected_precision = metrics.precision.iter().sum::<f64>() / 3.0;
1526        let expected_recall = metrics.recall.iter().sum::<f64>() / 3.0;
1527        let expected_f1 = metrics.f1_score.iter().sum::<f64>() / 3.0;
1528        let expected_accuracy = metrics.accuracy.iter().sum::<f64>() / 3.0;
1529
1530        assert!((macro_precision - expected_precision).abs() < 1e-10);
1531        assert!((macro_recall - expected_recall).abs() < 1e-10);
1532        assert!((macro_f1 - expected_f1).abs() < 1e-10);
1533        assert!((macro_accuracy - expected_accuracy).abs() < 1e-10);
1534    }
1535
1536    #[test]
1537    fn test_per_label_metrics_weighted_average() {
1538        let (y_true, y_pred) = create_test_data();
1539        let y_true_view = y_true.view();
1540        let y_pred_view = y_pred.view();
1541
1542        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
1543
1544        let weighted_precision = metrics.weighted_average("precision").unwrap();
1545        let weighted_recall = metrics.weighted_average("recall").unwrap();
1546        let weighted_f1 = metrics.weighted_average("f1_score").unwrap();
1547        let weighted_accuracy = metrics.weighted_average("accuracy").unwrap();
1548
1549        // All should be valid values
1550        assert!(weighted_precision >= 0.0 && weighted_precision <= 1.0);
1551        assert!(weighted_recall >= 0.0 && weighted_recall <= 1.0);
1552        assert!(weighted_f1 >= 0.0 && weighted_f1 <= 1.0);
1553        assert!(weighted_accuracy >= 0.0 && weighted_accuracy <= 1.0);
1554
1555        // Test invalid metric name
1556        assert!(metrics.weighted_average("invalid").is_err());
1557        assert!(metrics.macro_average("invalid").is_err());
1558    }
1559
1560    #[test]
1561    fn test_per_label_metrics_error_handling() {
1562        let y_true = array![[1, 0], [0, 1]];
1563        let y_pred = array![[1, 0, 1], [0, 1, 0]]; // Different shape
1564
1565        let result = per_label_metrics(&y_true.view(), &y_pred.view());
1566        assert!(result.is_err());
1567
1568        // Empty arrays
1569        let empty_true = Array2::<i32>::zeros((0, 0));
1570        let empty_pred = Array2::<i32>::zeros((0, 0));
1571        let result = per_label_metrics(&empty_true.view(), &empty_pred.view());
1572        assert!(result.is_err());
1573    }
1574
1575    #[test]
1576    fn test_mcnemar_test_identical_classifiers() {
1577        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
1578        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];
1579
1580        // Test identical classifiers (should have p-value = 1.0)
1581        let result = mcnemar_test(&y_true.view(), &y_pred.view(), &y_pred.view()).unwrap();
1582        assert_eq!(result.test_name, "McNemar");
1583        assert!((result.p_value - 1.0).abs() < 1e-10);
1584        assert!(!result.is_significant);
1585        assert_eq!(result.statistic, 0.0);
1586    }
1587
1588    #[test]
1589    fn test_mcnemar_test_different_classifiers() {
1590        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
1591        let y_pred1 = array![[1, 0, 0], [0, 1, 1], [1, 0, 1], [1, 0, 0], [1, 1, 1]];
1592        let y_pred2 = array![[0, 1, 1], [1, 0, 0], [0, 1, 0], [0, 1, 1], [0, 0, 0]];
1593
1594        let result = mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view()).unwrap();
1595        assert_eq!(result.test_name, "McNemar");
1596        assert!(result.statistic >= 0.0);
1597        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
1598
1599        // Check additional info
1600        assert!(result.additional_info.contains_key("n01"));
1601        assert!(result.additional_info.contains_key("n10"));
1602        assert!(result.additional_info.contains_key("total_disagreements"));
1603    }
1604
1605    #[test]
1606    fn test_mcnemar_test_error_handling() {
1607        let y_true = array![[1, 0], [0, 1]];
1608        let y_pred1 = array![[1, 0], [0, 1]];
1609        let y_pred2 = array![[1, 0, 1], [0, 1, 0]]; // Different shape
1610
1611        let result = mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view());
1612        assert!(result.is_err());
1613    }
1614
    #[test]
    fn test_paired_t_test() {
        let metric_values1 = vec![0.8, 0.7, 0.9, 0.6, 0.75];
        let metric_values2 = vec![0.7, 0.65, 0.85, 0.55, 0.7];

        let result = paired_t_test(&metric_values1, &metric_values2).unwrap();
        assert_eq!(result.test_name, "Paired t-test");
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);

        // Check additional info
        assert!(result.additional_info.contains_key("mean_difference"));
        assert!(result.additional_info.contains_key("std_dev_diff"));
        assert!(result.additional_info.contains_key("degrees_of_freedom"));
        assert!(result.additional_info.contains_key("n_observations"));

        // Mean difference should be positive (values1 > values2)
        let mean_diff = result.additional_info.get("mean_difference").unwrap();
        assert!(*mean_diff > 0.0);
    }

    #[test]
    fn test_paired_t_test_identical_values() {
        let metric_values = vec![0.8, 0.7, 0.9, 0.6, 0.75];

        let result = paired_t_test(&metric_values, &metric_values).unwrap();

        // With identical values, the mean difference should be 0 and the result
        // should not be flagged as significant
        let mean_diff = result.additional_info.get("mean_difference").unwrap();
        assert!(mean_diff.abs() < 1e-10);
        assert!(!result.is_significant);
    }

    #[test]
    fn test_paired_t_test_error_handling() {
        let values1 = vec![0.8, 0.7];
        let values2 = vec![0.6]; // Different length

        let result = paired_t_test(&values1, &values2);
        assert!(result.is_err());

        // Too few observations
        let single = vec![0.8];
        let result = paired_t_test(&single, &single);
        assert!(result.is_err());
    }

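    // The Wilcoxon signed-rank test is the nonparametric counterpart of the paired
    // t-test: it ranks the nonzero |d_i|, sums the ranks of positive differences
    // (w_plus) and negative differences (w_minus), and approximates the null
    // distribution with a normal, hence the reported z_score.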
    #[test]
    fn test_wilcoxon_signed_rank_test() {
        let metric_values1 = vec![0.8, 0.7, 0.9, 0.6, 0.75, 0.85, 0.65];
        let metric_values2 = vec![0.7, 0.65, 0.85, 0.55, 0.7, 0.8, 0.6];

        let result = wilcoxon_signed_rank_test(&metric_values1, &metric_values2).unwrap();
        assert_eq!(result.test_name, "Wilcoxon signed-rank");
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);

        // Check additional info
        assert!(result.additional_info.contains_key("w_plus"));
        assert!(result.additional_info.contains_key("w_minus"));
        assert!(result.additional_info.contains_key("n_nonzero_differences"));
        assert!(result.additional_info.contains_key("z_score"));
    }

    #[test]
    fn test_wilcoxon_signed_rank_test_identical_values() {
        let metric_values = vec![0.8, 0.7, 0.9, 0.6, 0.75];

        // Should fail with all zero differences
        let result = wilcoxon_signed_rank_test(&metric_values, &metric_values);
        assert!(result.is_err());
    }

    #[test]
    fn test_wilcoxon_signed_rank_test_error_handling() {
        let values1 = vec![0.8, 0.7];
        let values2 = vec![0.6]; // Different length

        let result = wilcoxon_signed_rank_test(&values1, &values2);
        assert!(result.is_err());

        // Too few observations
        let few = vec![0.8, 0.7];
        let result = wilcoxon_signed_rank_test(&few, &few);
        assert!(result.is_err());
    }

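    // confidence_interval returns an interval for the mean of the supplied metric
    // values; a classic construction is mean +/- t_{(1+c)/2, n-1} * s / sqrt(n), so
    // a higher confidence level c can only widen the interval (checked below).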
    #[test]
    fn test_confidence_interval() {
        let metric_values = vec![0.8, 0.75, 0.85, 0.7, 0.9, 0.65, 0.8, 0.82, 0.78, 0.88];

        let ci = confidence_interval(&metric_values, 0.95).unwrap();

        assert_eq!(ci.confidence_level, 0.95);
        assert!(ci.lower <= ci.point_estimate); // Allow equal in edge cases
        assert!(ci.point_estimate <= ci.upper);

        // Point estimate should be the mean
        let expected_mean = metric_values.iter().sum::<f64>() / metric_values.len() as f64;
        assert!((ci.point_estimate - expected_mean).abs() < 1e-10);

        // For meaningful confidence intervals with sufficient data, bounds should be different
        if metric_values.len() > 5 {
            assert!(ci.upper - ci.lower > 1e-6); // Should have some width
        }

        // Test different confidence levels
        let ci_99 = confidence_interval(&metric_values, 0.99).unwrap();
        assert!(ci_99.upper - ci_99.lower >= ci.upper - ci.lower); // Should be wider or equal
    }

    #[test]
    fn test_confidence_interval_single_value() {
        let single_value = vec![0.8];

        let ci = confidence_interval(&single_value, 0.95).unwrap();

        // With single value, lower and upper should equal the point estimate
        assert_eq!(ci.lower, ci.point_estimate);
        assert_eq!(ci.upper, ci.point_estimate);
        assert_eq!(ci.point_estimate, 0.8);
    }

    #[test]
    fn test_confidence_interval_error_handling() {
        let empty = vec![];
        let result = confidence_interval(&empty, 0.95);
        assert!(result.is_err());

        let values = vec![0.8, 0.7];
        let result = confidence_interval(&values, 0.0); // Invalid confidence level
        assert!(result.is_err());

        let result = confidence_interval(&values, 1.0); // Invalid confidence level
        assert!(result.is_err());
    }

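    // StatisticalTestResult derives is_significant from the p-value at the
    // conventional 0.05 level, as exercised below.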
    #[test]
    fn test_statistical_test_result_creation() {
        let mut info = HashMap::new();
        info.insert("degrees_of_freedom".to_string(), 9.0);

        let result = StatisticalTestResult::new(2.5, 0.03, "Test".to_string(), Some(info));

        assert_eq!(result.statistic, 2.5);
        assert_eq!(result.p_value, 0.03);
        assert!(result.is_significant); // p < 0.05
        assert_eq!(result.test_name, "Test");
        assert_eq!(result.additional_info.get("degrees_of_freedom"), Some(&9.0));

        // Test non-significant result
        let non_sig = StatisticalTestResult::new(1.2, 0.15, "Non-sig".to_string(), None);
        assert!(!non_sig.is_significant); // p >= 0.05
        assert!(non_sig.additional_info.is_empty());
    }

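    // standard_normal_cdf is Phi(z), the standard normal CDF, and normal_quantile
    // is its inverse (the probit function): Phi(0) = 0.5 and Phi(1.96) is
    // approximately 0.975, so normal_quantile(0.975) should be approximately 1.96.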
    #[test]
    fn test_distribution_helper_functions() {
        // Test standard normal CDF
        let z_zero = standard_normal_cdf(0.0);
        assert!((z_zero - 0.5).abs() < 1e-6);

        let z_positive = standard_normal_cdf(1.96);
        assert!((z_positive - 0.975).abs() < 0.01); // Approximately 0.975

        // Test normal quantile (the inverse of the standard normal CDF)
        let q_median = normal_quantile(0.5);
        assert!(q_median.abs() < 1e-6); // Should be approximately 0

        let q_975 = normal_quantile(0.975);
        assert!((q_975 - 1.96).abs() < 0.01); // Should be approximately 1.96

        let q_025 = normal_quantile(0.025);
        assert!((q_025 + 1.96).abs() < 0.01); // Should be approximately -1.96

        // Test edge cases
        assert_eq!(normal_quantile(0.0), f64::NEG_INFINITY);
        assert_eq!(normal_quantile(1.0), f64::INFINITY);

        // Test that quantile and CDF are approximately inverse functions
        let test_values = vec![0.25, 0.5, 0.75]; // Representative probabilities
        for &p in &test_values {
            let q = normal_quantile(p);
            if q.is_finite() {
                let p_back = standard_normal_cdf(q);
                assert!((p - p_back).abs() < 0.05); // Loose tolerance for the approximation
            }
        }
    }

    #[test]
    fn test_existing_metrics_compatibility() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        // Test that existing metrics still work
        let hamming = hamming_loss(&y_true_view, &y_pred_view).unwrap();
        assert!(hamming >= 0.0 && hamming <= 1.0);

        let subset_acc = subset_accuracy(&y_true_view, &y_pred_view).unwrap();
        assert!(subset_acc >= 0.0 && subset_acc <= 1.0);

        let jaccard = jaccard_score(&y_true_view, &y_pred_view).unwrap();
        assert!(jaccard >= 0.0 && jaccard <= 1.0);

        let f1_micro = f1_score(&y_true_view, &y_pred_view, "micro").unwrap();
        assert!(f1_micro >= 0.0 && f1_micro <= 1.0);

        let f1_macro = f1_score(&y_true_view, &y_pred_view, "macro").unwrap();
        assert!(f1_macro >= 0.0 && f1_macro <= 1.0);

        let f1_samples = f1_score(&y_true_view, &y_pred_view, "samples").unwrap();
        assert!(f1_samples >= 0.0 && f1_samples <= 1.0);
    }

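    // Macro averaging is the unweighted mean of the per-label scores, so the macro
    // F1 computed from per_label_metrics must agree with f1_score(..., "macro").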
    #[test]
    fn test_per_label_vs_global_metrics_consistency() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let per_label = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
        let global_f1_macro = f1_score(&y_true_view, &y_pred_view, "macro").unwrap();
        let per_label_f1_macro = per_label.macro_average("f1_score").unwrap();

        // The macro F1 from per-label metrics should match global macro F1
        assert!((global_f1_macro - per_label_f1_macro).abs() < 1e-10);
    }

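    // End-to-end sketch of a model-comparison workflow: per-label metrics for two
    // classifiers, McNemar's test on their predictions, paired t-test and Wilcoxon
    // test on their per-label F1 scores, and a confidence interval for one of them.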
    #[test]
    fn test_comprehensive_statistical_workflow() {
        // Simulate a complete statistical comparison workflow
        let y_true = array![
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
            [1, 0, 1, 1]
        ];

        // Two different classifiers
        let y_pred1 = array![
            [1, 0, 0, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [1, 1, 1, 0]
        ];

        let y_pred2 = array![
            [0, 1, 1, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [0, 1, 1, 0],
            [0, 0, 0, 1]
        ];

        // Get per-label metrics for both classifiers
        let metrics1 = per_label_metrics(&y_true.view(), &y_pred1.view()).unwrap();
        let metrics2 = per_label_metrics(&y_true.view(), &y_pred2.view()).unwrap();

        // Compare using McNemar's test
        let mcnemar_result =
            mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view()).unwrap();

        // Compare F1 scores using t-test
        let t_test_result = paired_t_test(&metrics1.f1_score, &metrics2.f1_score).unwrap();

        // Compare F1 scores using Wilcoxon test
        let wilcoxon_result =
            wilcoxon_signed_rank_test(&metrics1.f1_score, &metrics2.f1_score).unwrap();

        // Get confidence interval for first classifier's F1 scores
        let ci_result = confidence_interval(&metrics1.f1_score, 0.95).unwrap();

        // All results should be valid
        assert!(mcnemar_result.p_value >= 0.0 && mcnemar_result.p_value <= 1.0);
        assert!(t_test_result.p_value >= 0.0 && t_test_result.p_value <= 1.0);
        assert!(wilcoxon_result.p_value >= 0.0 && wilcoxon_result.p_value <= 1.0);
        assert!(ci_result.lower <= ci_result.point_estimate);
        assert!(ci_result.point_estimate <= ci_result.upper);

        // Results should contain proper metadata
        assert!(!mcnemar_result.additional_info.is_empty());
        assert!(!t_test_result.additional_info.is_empty());
        assert!(!wilcoxon_result.additional_info.is_empty());
    }
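
    // A minimal additional sketch (assumption-based, not part of the original
    // suite): reversing the order of the two metric vectors should not change the
    // spread of the paired differences or the magnitude of the t statistic. The
    // earlier tests also suggest that "mean_difference" is mean(values1 - values2)
    // and would therefore flip its sign, but that convention is not asserted here.
    #[test]
    fn test_paired_t_test_direction_invariance_sketch() {
        let values1 = vec![0.8, 0.7, 0.9, 0.6, 0.75];
        let values2 = vec![0.7, 0.65, 0.85, 0.55, 0.7];

        let forward = paired_t_test(&values1, &values2).unwrap();
        let reversed = paired_t_test(&values2, &values1).unwrap();

        // The standard deviation of the differences is direction-independent.
        let s_forward = forward.additional_info.get("std_dev_diff").unwrap();
        let s_reversed = reversed.additional_info.get("std_dev_diff").unwrap();
        assert!((s_forward - s_reversed).abs() < 1e-10);

        // The magnitude of the t statistic is direction-independent as well.
        assert!((forward.statistic.abs() - reversed.statistic.abs()).abs() < 1e-10);
    }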
}