use scirs2_core::ndarray::ArrayView2;
use sklears_core::error::{Result as SklResult, SklearsError};
use std::collections::HashMap;

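/// Hamming loss: the fraction of label assignments that differ between
/// `y_true` and `y_pred`, averaged over all samples and labels (0.0 is best).
///
/// A minimal usage sketch (the `array!` macro comes from
/// `scirs2_core::ndarray`, as the test module below already imports it):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // One of four entries differs, so the loss is 0.25.
/// let loss = hamming_loss(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((loss - 0.25).abs() < 1e-12);
/// ```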
pub fn hamming_loss(y_true: &ArrayView2<'_, i32>, y_pred: &ArrayView2<'_, i32>) -> SklResult<f64> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_errors = 0;
    let total_elements = n_samples * n_labels;

    for sample_idx in 0..n_samples {
        for label_idx in 0..n_labels {
            if y_true[[sample_idx, label_idx]] != y_pred[[sample_idx, label_idx]] {
                total_errors += 1;
            }
        }
    }

    Ok(total_errors as f64 / total_elements as f64)
}

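/// Subset accuracy (exact match ratio): the fraction of samples whose entire
/// label vector is predicted exactly (1.0 is best).
///
/// Sketch, under the same `array!` assumption as above:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 0], [1, 1]];
/// // The first row matches exactly, the second does not: 1 of 2 samples.
/// let acc = subset_accuracy(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((acc - 0.5).abs() < 1e-12);
/// ```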
pub fn subset_accuracy(
    y_true: &ArrayView2<'_, i32>,
    y_pred: &ArrayView2<'_, i32>,
) -> SklResult<f64> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample".to_string(),
        ));
    }

    let mut correct_subsets = 0;

    for sample_idx in 0..n_samples {
        let mut subset_correct = true;
        for label_idx in 0..n_labels {
            if y_true[[sample_idx, label_idx]] != y_pred[[sample_idx, label_idx]] {
                subset_correct = false;
                break;
            }
        }
        if subset_correct {
            correct_subsets += 1;
        }
    }

    Ok(correct_subsets as f64 / n_samples as f64)
}

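/// Sample-averaged Jaccard score: for each sample, |true ∩ pred| / |true ∪ pred|
/// over the positive labels; a sample with no positives on either side counts
/// as a perfect 1.0.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 1, 0]];
/// let y_pred = array![[1, 0, 0]];
/// // Intersection {0}, union {0, 1}: Jaccard = 1/2.
/// let j = jaccard_score(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((j - 0.5).abs() < 1e-12);
/// ```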
pub fn jaccard_score(y_true: &ArrayView2<'_, i32>, y_pred: &ArrayView2<'_, i32>) -> SklResult<f64> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample".to_string(),
        ));
    }

    let mut total_jaccard = 0.0;

    for sample_idx in 0..n_samples {
        let mut intersection = 0;
        let mut union = 0;

        for label_idx in 0..n_labels {
            let true_label = y_true[[sample_idx, label_idx]];
            let pred_label = y_pred[[sample_idx, label_idx]];

            if true_label == 1 && pred_label == 1 {
                intersection += 1;
            }
            if true_label == 1 || pred_label == 1 {
                union += 1;
            }
        }

        let sample_jaccard = if union > 0 {
            intersection as f64 / union as f64
        } else {
            // Two empty label sets are treated as a perfect match.
            1.0
        };

        total_jaccard += sample_jaccard;
    }

    Ok(total_jaccard / n_samples as f64)
}

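/// Multilabel F1 score with `"micro"`, `"macro"`, or `"samples"` averaging.
/// Micro pools counts over all (sample, label) cells; macro averages per-label
/// F1; samples averages per-sample F1.
///
/// Sketch of the micro case (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // tp = 2, fp = 1, fn = 0: precision = 2/3, recall = 1, F1 = 0.8.
/// let f1 = f1_score(&y_true.view(), &y_pred.view(), "micro").unwrap();
/// assert!((f1 - 0.8).abs() < 1e-12);
/// ```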
pub fn f1_score(
    y_true: &ArrayView2<'_, i32>,
    y_pred: &ArrayView2<'_, i32>,
    average: &str,
) -> SklResult<f64> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    match average {
        "micro" => {
            let mut total_tp = 0;
            let mut total_fp = 0;
            let mut total_false_negatives = 0;

            for sample_idx in 0..n_samples {
                for label_idx in 0..n_labels {
                    let true_label = y_true[[sample_idx, label_idx]];
                    let pred_label = y_pred[[sample_idx, label_idx]];

                    if true_label == 1 && pred_label == 1 {
                        total_tp += 1;
                    } else if true_label == 0 && pred_label == 1 {
                        total_fp += 1;
                    } else if true_label == 1 && pred_label == 0 {
                        total_false_negatives += 1;
                    }
                }
            }

            let precision = if total_tp + total_fp > 0 {
                total_tp as f64 / (total_tp + total_fp) as f64
            } else {
                0.0
            };

            let recall = if total_tp + total_false_negatives > 0 {
                total_tp as f64 / (total_tp + total_false_negatives) as f64
            } else {
                0.0
            };

            let f1 = if precision + recall > 0.0 {
                2.0 * precision * recall / (precision + recall)
            } else {
                0.0
            };

            Ok(f1)
        }
        "macro" => {
            let mut label_f1_scores = Vec::new();

            for label_idx in 0..n_labels {
                let mut tp = 0;
                let mut fp = 0;
                let mut false_negatives = 0;

                for sample_idx in 0..n_samples {
                    let true_label = y_true[[sample_idx, label_idx]];
                    let pred_label = y_pred[[sample_idx, label_idx]];

                    if true_label == 1 && pred_label == 1 {
                        tp += 1;
                    } else if true_label == 0 && pred_label == 1 {
                        fp += 1;
                    } else if true_label == 1 && pred_label == 0 {
                        false_negatives += 1;
                    }
                }

                let precision = if tp + fp > 0 {
                    tp as f64 / (tp + fp) as f64
                } else {
                    0.0
                };

                let recall = if tp + false_negatives > 0 {
                    tp as f64 / (tp + false_negatives) as f64
                } else {
                    0.0
                };

                let f1 = if precision + recall > 0.0 {
                    2.0 * precision * recall / (precision + recall)
                } else {
                    0.0
                };

                label_f1_scores.push(f1);
            }

            Ok(label_f1_scores.iter().sum::<f64>() / n_labels as f64)
        }
        "samples" => {
            let mut sample_f1_scores = Vec::new();

            for sample_idx in 0..n_samples {
                let mut tp = 0;
                let mut fp = 0;
                let mut false_negatives = 0;

                for label_idx in 0..n_labels {
                    let true_label = y_true[[sample_idx, label_idx]];
                    let pred_label = y_pred[[sample_idx, label_idx]];

                    if true_label == 1 && pred_label == 1 {
                        tp += 1;
                    } else if true_label == 0 && pred_label == 1 {
                        fp += 1;
                    } else if true_label == 1 && pred_label == 0 {
                        false_negatives += 1;
                    }
                }

                let precision = if tp + fp > 0 {
                    tp as f64 / (tp + fp) as f64
                } else {
                    0.0
                };

                let recall = if tp + false_negatives > 0 {
                    tp as f64 / (tp + false_negatives) as f64
                } else {
                    0.0
                };

                let f1 = if precision + recall > 0.0 {
                    2.0 * precision * recall / (precision + recall)
                } else {
                    0.0
                };

                sample_f1_scores.push(f1);
            }

            Ok(sample_f1_scores.iter().sum::<f64>() / n_samples as f64)
        }
        _ => Err(SklearsError::InvalidInput(format!(
            "Unknown average type: {}. Valid options are 'micro', 'macro', 'samples'",
            average
        ))),
    }
}

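/// Coverage error: for each sample, how far down the score-sorted label
/// ranking one must go to cover every true label, averaged over samples
/// (lower is better; the minimum equals the average number of true labels).
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0, 1]];
/// let y_scores = array![[0.9, 0.8, 0.1]];
/// // The lowest-ranked true label sits at rank 3.
/// let cov = coverage_error(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((cov - 3.0).abs() < 1e-12);
/// ```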
pub fn coverage_error(
    y_true: &ArrayView2<'_, i32>,
    y_scores: &ArrayView2<'_, f64>,
) -> SklResult<f64> {
    if y_true.dim() != y_scores.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_scores must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_coverage = 0.0;

    for sample_idx in 0..n_samples {
        // Sort labels by descending score.
        let mut score_label_pairs: Vec<(f64, usize)> = (0..n_labels)
            .map(|label_idx| (y_scores[[sample_idx, label_idx]], label_idx))
            .collect();
        score_label_pairs
            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        // Find the 1-based rank of the lowest-ranked true label.
        let mut last_true_position = 0;
        for (position, &(_, label_idx)) in score_label_pairs.iter().enumerate() {
            if y_true[[sample_idx, label_idx]] == 1 {
                last_true_position = position + 1;
            }
        }

        total_coverage += last_true_position as f64;
    }

    Ok(total_coverage / n_samples as f64)
}

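/// Label ranking average precision (LRAP): for each true label, the fraction
/// of labels ranked at or above it that are also true, averaged first over a
/// sample's true labels and then over samples (1.0 is best). Samples with no
/// true labels are skipped.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0, 1]];
/// let y_scores = array![[0.9, 0.8, 0.1]];
/// // Precision at ranks 1 and 3 is 1/1 and 2/3: LRAP = (1 + 2/3) / 2 = 5/6.
/// let lrap = label_ranking_average_precision(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((lrap - 5.0 / 6.0).abs() < 1e-12);
/// ```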
pub fn label_ranking_average_precision(
    y_true: &ArrayView2<'_, i32>,
    y_scores: &ArrayView2<'_, f64>,
) -> SklResult<f64> {
    if y_true.dim() != y_scores.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_scores must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_lrap = 0.0;

    for sample_idx in 0..n_samples {
        // Sort labels by descending score.
        let mut score_label_pairs: Vec<(f64, usize)> = (0..n_labels)
            .map(|label_idx| (y_scores[[sample_idx, label_idx]], label_idx))
            .collect();
        score_label_pairs
            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        let n_true_labels: i32 = (0..n_labels)
            .map(|label_idx| y_true[[sample_idx, label_idx]])
            .sum();

        // Samples without any true label contribute nothing.
        if n_true_labels == 0 {
            continue;
        }

        let mut precision_sum = 0.0;
        let mut true_labels_seen = 0;

        for (position, &(_, label_idx)) in score_label_pairs.iter().enumerate() {
            if y_true[[sample_idx, label_idx]] == 1 {
                true_labels_seen += 1;
                let precision_at_position = true_labels_seen as f64 / (position + 1) as f64;
                precision_sum += precision_at_position;
            }
        }

        let sample_lrap = precision_sum / n_true_labels as f64;
        total_lrap += sample_lrap;
    }

    Ok(total_lrap / n_samples as f64)
}

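/// One-error: the fraction of samples whose single top-scoring label is not a
/// true label (0.0 is best).
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[0, 1], [1, 0]];
/// let y_scores = array![[0.9, 0.1], [0.8, 0.2]];
/// // Sample 0's top label is wrong, sample 1's is right: one-error = 0.5.
/// let err = one_error(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((err - 0.5).abs() < 1e-12);
/// ```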
pub fn one_error(y_true: &ArrayView2<'_, i32>, y_scores: &ArrayView2<'_, f64>) -> SklResult<f64> {
    if y_true.dim() != y_scores.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_scores must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut errors = 0;

    for sample_idx in 0..n_samples {
        // Locate the top-scoring label for this sample.
        let mut max_score = f64::NEG_INFINITY;
        let mut top_label_idx = 0;

        for label_idx in 0..n_labels {
            let score = y_scores[[sample_idx, label_idx]];
            if score > max_score {
                max_score = score;
                top_label_idx = label_idx;
            }
        }

        if y_true[[sample_idx, top_label_idx]] != 1 {
            errors += 1;
        }
    }

    Ok(errors as f64 / n_samples as f64)
}

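/// Ranking loss: the fraction of (true, false) label pairs that the scores
/// order incorrectly, averaged over samples (0.0 is best). Each unordered pair
/// is visited twice below, which leaves the ratio unchanged.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0]];
/// let y_scores = array![[0.2, 0.8]];
/// // The only (true, false) pair is inverted, so the loss is 1.0.
/// let loss = ranking_loss(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((loss - 1.0).abs() < 1e-12);
/// ```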
pub fn ranking_loss(
    y_true: &ArrayView2<'_, i32>,
    y_scores: &ArrayView2<'_, f64>,
) -> SklResult<f64> {
    if y_true.dim() != y_scores.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_scores must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_ranking_loss = 0.0;

    for sample_idx in 0..n_samples {
        let mut incorrect_pairs = 0;
        let mut total_pairs = 0;

        for i in 0..n_labels {
            for j in 0..n_labels {
                if i != j {
                    let true_i = y_true[[sample_idx, i]];
                    let true_j = y_true[[sample_idx, j]];
                    let score_i = y_scores[[sample_idx, i]];
                    let score_j = y_scores[[sample_idx, j]];

                    // Only (true, false) label pairs are relevant.
                    if (true_i == 1 && true_j == 0) || (true_i == 0 && true_j == 1) {
                        total_pairs += 1;

                        // The pair is ranked incorrectly when the false label
                        // outscores the true one.
                        if (true_i == 1 && true_j == 0 && score_i < score_j)
                            || (true_i == 0 && true_j == 1 && score_i > score_j)
                        {
                            incorrect_pairs += 1;
                        }
                    }
                }
            }
        }

        if total_pairs > 0 {
            total_ranking_loss += incorrect_pairs as f64 / total_pairs as f64;
        }
    }

    Ok(total_ranking_loss / n_samples as f64)
}

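/// Sample-averaged average precision. In this implementation the computation
/// matches [`label_ranking_average_precision`]: both sum precision at the
/// rank of each true label. Samples without true labels are skipped.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0, 1]];
/// let y_scores = array![[0.9, 0.8, 0.1]];
/// let ap = average_precision_score(&y_true.view(), &y_scores.view()).unwrap();
/// assert!((ap - 5.0 / 6.0).abs() < 1e-12);
/// ```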
pub fn average_precision_score(
    y_true: &ArrayView2<'_, i32>,
    y_scores: &ArrayView2<'_, f64>,
) -> SklResult<f64> {
    if y_true.dim() != y_scores.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_scores must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_ap = 0.0;

    for sample_idx in 0..n_samples {
        // Sort labels by descending score.
        let mut score_label_pairs: Vec<(f64, usize)> = (0..n_labels)
            .map(|label_idx| (y_scores[[sample_idx, label_idx]], label_idx))
            .collect();
        score_label_pairs
            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        let n_true_labels: i32 = (0..n_labels)
            .map(|label_idx| y_true[[sample_idx, label_idx]])
            .sum();

        // Samples without any true label contribute nothing.
        if n_true_labels == 0 {
            continue;
        }

        let mut precision_sum = 0.0;
        let mut true_labels_seen = 0;

        for (position, &(_, label_idx)) in score_label_pairs.iter().enumerate() {
            if y_true[[sample_idx, label_idx]] == 1 {
                true_labels_seen += 1;
                let precision_at_position = true_labels_seen as f64 / (position + 1) as f64;
                precision_sum += precision_at_position;
            }
        }

        let sample_ap = precision_sum / n_true_labels as f64;
        total_ap += sample_ap;
    }

    Ok(total_ap / n_samples as f64)
}

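/// Micro-averaged precision: TP / (TP + FP) pooled over every (sample, label) cell.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 1], [0, 1]];
/// // tp = 2, fp = 1: precision = 2/3.
/// let p = precision_score_micro(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((p - 2.0 / 3.0).abs() < 1e-12);
/// ```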
pub fn precision_score_micro(
    y_true: &ArrayView2<'_, i32>,
    y_pred: &ArrayView2<'_, i32>,
) -> SklResult<f64> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_tp = 0;
    let mut total_fp = 0;

    for sample_idx in 0..n_samples {
        for label_idx in 0..n_labels {
            let true_label = y_true[[sample_idx, label_idx]];
            let pred_label = y_pred[[sample_idx, label_idx]];

            if true_label == 1 && pred_label == 1 {
                total_tp += 1;
            } else if true_label == 0 && pred_label == 1 {
                total_fp += 1;
            }
        }
    }

    let precision = if total_tp + total_fp > 0 {
        total_tp as f64 / (total_tp + total_fp) as f64
    } else {
        0.0
    };

    Ok(precision)
}

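/// Micro-averaged recall: TP / (TP + FN) pooled over every (sample, label) cell.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred = array![[1, 0], [0, 0]];
/// // tp = 1, fn = 1: recall = 0.5.
/// let r = recall_score_micro(&y_true.view(), &y_pred.view()).unwrap();
/// assert!((r - 0.5).abs() < 1e-12);
/// ```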
pub fn recall_score_micro(
    y_true: &ArrayView2<'_, i32>,
    y_pred: &ArrayView2<'_, i32>,
) -> SklResult<f64> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut total_tp = 0;
    let mut total_fn = 0;

    for sample_idx in 0..n_samples {
        for label_idx in 0..n_labels {
            let true_label = y_true[[sample_idx, label_idx]];
            let pred_label = y_pred[[sample_idx, label_idx]];

            if true_label == 1 && pred_label == 1 {
                total_tp += 1;
            } else if true_label == 1 && pred_label == 0 {
                total_fn += 1;
            }
        }
    }

    let recall = if total_tp + total_fn > 0 {
        total_tp as f64 / (total_tp + total_fn) as f64
    } else {
        0.0
    };

    Ok(recall)
}

/// Per-label classification metrics for a multilabel problem.
#[derive(Debug, Clone)]
pub struct PerLabelMetrics {
    /// Precision for each label.
    pub precision: Vec<f64>,
    /// Recall for each label.
    pub recall: Vec<f64>,
    /// F1 score for each label.
    pub f1_score: Vec<f64>,
    /// Support (true positives plus false negatives) per label.
    pub support: Vec<usize>,
    /// Accuracy for each label.
    pub accuracy: Vec<f64>,
    /// Number of labels.
    pub n_labels: usize,
}

impl PerLabelMetrics {
    /// Unweighted mean of the given metric across labels.
    pub fn macro_average(&self, metric: &str) -> SklResult<f64> {
        let values = match metric {
            "precision" => &self.precision,
            "recall" => &self.recall,
            "f1_score" => &self.f1_score,
            "accuracy" => &self.accuracy,
            _ => return Err(SklearsError::InvalidInput(format!(
                "Unknown metric: {}. Valid options are 'precision', 'recall', 'f1_score', 'accuracy'",
                metric
            )))
        };

        Ok(values.iter().sum::<f64>() / values.len() as f64)
    }

    /// Support-weighted mean of the given metric across labels.
    pub fn weighted_average(&self, metric: &str) -> SklResult<f64> {
        let values = match metric {
            "precision" => &self.precision,
            "recall" => &self.recall,
            "f1_score" => &self.f1_score,
            "accuracy" => &self.accuracy,
            _ => return Err(SklearsError::InvalidInput(format!(
                "Unknown metric: {}. Valid options are 'precision', 'recall', 'f1_score', 'accuracy'",
                metric
            )))
        };

        let total_support: usize = self.support.iter().sum();
        if total_support == 0 {
            return Ok(0.0);
        }

        let weighted_sum: f64 = values
            .iter()
            .zip(self.support.iter())
            .map(|(value, support)| value * (*support as f64))
            .sum();

        Ok(weighted_sum / total_support as f64)
    }
}

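/// Computes precision, recall, F1, support, and accuracy for each label
/// independently, treating every label column as its own binary problem.
///
/// Sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1], [1, 1]];
/// let y_pred = array![[1, 0], [0, 1], [0, 1]];
/// let m = per_label_metrics(&y_true.view(), &y_pred.view()).unwrap();
/// assert_eq!(m.support, vec![2, 2]);
/// let macro_f1 = m.macro_average("f1_score").unwrap();
/// let weighted_f1 = m.weighted_average("f1_score").unwrap();
/// ```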
pub fn per_label_metrics(
    y_true: &ArrayView2<'_, i32>,
    y_pred: &ArrayView2<'_, i32>,
) -> SklResult<PerLabelMetrics> {
    if y_true.dim() != y_pred.dim() {
        return Err(SklearsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    let mut precision = Vec::with_capacity(n_labels);
    let mut recall = Vec::with_capacity(n_labels);
    let mut f1_score = Vec::with_capacity(n_labels);
    let mut support = Vec::with_capacity(n_labels);
    let mut accuracy = Vec::with_capacity(n_labels);

    for label_idx in 0..n_labels {
        // Confusion-matrix counts for this label.
        let mut tp = 0;
        let mut fp = 0;
        let mut fn_count = 0;
        let mut tn = 0;

        for sample_idx in 0..n_samples {
            let true_label = y_true[[sample_idx, label_idx]];
            let pred_label = y_pred[[sample_idx, label_idx]];

            match (true_label, pred_label) {
                (1, 1) => tp += 1,
                (0, 1) => fp += 1,
                (1, 0) => fn_count += 1,
                (0, 0) => tn += 1,
                _ => {}
            }
        }

        let label_precision = if tp + fp > 0 {
            tp as f64 / (tp + fp) as f64
        } else {
            0.0
        };

        let label_recall = if tp + fn_count > 0 {
            tp as f64 / (tp + fn_count) as f64
        } else {
            0.0
        };

        let label_f1 = if label_precision + label_recall > 0.0 {
            2.0 * label_precision * label_recall / (label_precision + label_recall)
        } else {
            0.0
        };

        let label_accuracy = (tp + tn) as f64 / n_samples as f64;

        let label_support = (tp + fn_count) as usize;

        precision.push(label_precision);
        recall.push(label_recall);
        f1_score.push(label_f1);
        support.push(label_support);
        accuracy.push(label_accuracy);
    }

    Ok(PerLabelMetrics {
        precision,
        recall,
        f1_score,
        support,
        accuracy,
        n_labels,
    })
}

/// Result of a statistical significance test.
#[derive(Debug, Clone)]
pub struct StatisticalTestResult {
    /// Test statistic.
    pub statistic: f64,
    /// Two-sided p-value.
    pub p_value: f64,
    /// Whether the result is significant at the 0.05 level.
    pub is_significant: bool,
    /// Name of the test.
    pub test_name: String,
    /// Additional test-specific quantities.
    pub additional_info: HashMap<String, f64>,
}

impl StatisticalTestResult {
    /// Builds a result; `is_significant` is set using a fixed 0.05 threshold.
    pub fn new(
        statistic: f64,
        p_value: f64,
        test_name: String,
        additional_info: Option<HashMap<String, f64>>,
    ) -> Self {
        Self {
            statistic,
            p_value,
            is_significant: p_value < 0.05,
            test_name,
            additional_info: additional_info.unwrap_or_default(),
        }
    }
}

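/// McNemar's test for comparing two classifiers on the same data, applied
/// cell-wise to the multilabel predictions. Uses the continuity-corrected
/// statistic (|n01 - n10| - 1)^2 / (n01 + n10) with a chi-square(1) p-value.
///
/// A minimal sketch (same `array!` assumption):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let y_true = array![[1, 0], [0, 1]];
/// let y_pred1 = array![[1, 0], [0, 1]];
/// let y_pred2 = array![[0, 0], [0, 1]];
/// let result = mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view()).unwrap();
/// println!("chi2 = {}, p = {}", result.statistic, result.p_value);
/// ```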
pub fn mcnemar_test(
    y_true: &ArrayView2<'_, i32>,
    y_pred1: &ArrayView2<'_, i32>,
    y_pred2: &ArrayView2<'_, i32>,
) -> SklResult<StatisticalTestResult> {
    if y_true.dim() != y_pred1.dim() || y_true.dim() != y_pred2.dim() {
        return Err(SklearsError::InvalidInput(
            "All input arrays must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();
    if n_samples == 0 || n_labels == 0 {
        return Err(SklearsError::InvalidInput(
            "Input arrays must have at least one sample and one label".to_string(),
        ));
    }

    // n01: classifier 1 correct while classifier 2 is wrong;
    // n10: classifier 1 wrong while classifier 2 is correct.
    let mut n01 = 0;
    let mut n10 = 0;

    for sample_idx in 0..n_samples {
        for label_idx in 0..n_labels {
            let true_label = y_true[[sample_idx, label_idx]];
            let pred1 = y_pred1[[sample_idx, label_idx]];
            let pred2 = y_pred2[[sample_idx, label_idx]];

            let correct1 = pred1 == true_label;
            let correct2 = pred2 == true_label;

            match (correct1, correct2) {
                (true, false) => n01 += 1,
                (false, true) => n10 += 1,
                _ => {}
            }
        }
    }

    let total_disagreements = n01 + n10;
    if total_disagreements == 0 {
        // The classifiers never disagree, so there is nothing to test.
        return Ok(StatisticalTestResult::new(
            0.0,
            1.0,
            "McNemar".to_string(),
            Some({
                let mut info = HashMap::new();
                info.insert("n01".to_string(), n01 as f64);
                info.insert("n10".to_string(), n10 as f64);
                info.insert(
                    "total_disagreements".to_string(),
                    total_disagreements as f64,
                );
                info
            }),
        ));
    }

    // Continuity-corrected McNemar statistic: (|n01 - n10| - 1)^2 / (n01 + n10).
    let statistic = ((n01 as f64 - n10 as f64).abs() - 1.0).max(0.0).powi(2) / (n01 + n10) as f64;

    let p_value = chi_square_p_value(statistic, 1);

    let mut info = HashMap::new();
    info.insert("n01".to_string(), n01 as f64);
    info.insert("n10".to_string(), n10 as f64);
    info.insert(
        "total_disagreements".to_string(),
        total_disagreements as f64,
    );

    Ok(StatisticalTestResult::new(
        statistic,
        p_value,
        "McNemar".to_string(),
        Some(info),
    ))
}

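/// Paired t-test on two sequences of per-fold (or per-label) metric values.
/// Returns the t statistic and a two-sided p-value from the approximate
/// t-distribution CDF below.
///
/// A minimal sketch:
///
/// ```ignore
/// let scores_a = vec![0.81, 0.79, 0.84, 0.80, 0.78];
/// let scores_b = vec![0.75, 0.74, 0.80, 0.77, 0.73];
/// let result = paired_t_test(&scores_a, &scores_b).unwrap();
/// assert!(result.p_value <= 1.0);
/// ```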
pub fn paired_t_test(
    metric_values1: &[f64],
    metric_values2: &[f64],
) -> SklResult<StatisticalTestResult> {
    if metric_values1.len() != metric_values2.len() {
        return Err(SklearsError::InvalidInput(
            "Metric value arrays must have the same length".to_string(),
        ));
    }

    let n = metric_values1.len();
    if n < 2 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 paired observations for t-test".to_string(),
        ));
    }

    let differences: Vec<f64> = metric_values1
        .iter()
        .zip(metric_values2.iter())
        .map(|(v1, v2)| v1 - v2)
        .collect();

    let mean_diff = differences.iter().sum::<f64>() / n as f64;

    // Sample variance of the differences (Bessel-corrected).
    let variance = differences
        .iter()
        .map(|d| (d - mean_diff).powi(2))
        .sum::<f64>()
        / (n - 1) as f64;
    let std_dev = variance.sqrt();

    let t_statistic = mean_diff / (std_dev / (n as f64).sqrt());

    let df = n - 1;

    // Two-sided p-value.
    let p_value = 2.0 * (1.0 - t_distribution_cdf(t_statistic.abs(), df as f64));

    let mut info = HashMap::new();
    info.insert("mean_difference".to_string(), mean_diff);
    info.insert("std_dev_diff".to_string(), std_dev);
    info.insert("degrees_of_freedom".to_string(), df as f64);
    info.insert("n_observations".to_string(), n as f64);

    Ok(StatisticalTestResult::new(
        t_statistic,
        p_value,
        "Paired t-test".to_string(),
        Some(info),
    ))
}

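/// Wilcoxon signed-rank test: a nonparametric alternative to the paired
/// t-test. Zero differences are dropped, ties share their average rank, and
/// the p-value comes from a continuity-corrected normal approximation.
///
/// A minimal sketch:
///
/// ```ignore
/// let scores_a = vec![0.81, 0.79, 0.84, 0.80, 0.78, 0.83];
/// let scores_b = vec![0.75, 0.74, 0.80, 0.77, 0.73, 0.79];
/// let result = wilcoxon_signed_rank_test(&scores_a, &scores_b).unwrap();
/// assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
/// ```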
pub fn wilcoxon_signed_rank_test(
    metric_values1: &[f64],
    metric_values2: &[f64],
) -> SklResult<StatisticalTestResult> {
    if metric_values1.len() != metric_values2.len() {
        return Err(SklearsError::InvalidInput(
            "Metric value arrays must have the same length".to_string(),
        ));
    }

    let n = metric_values1.len();
    if n < 3 {
        return Err(SklearsError::InvalidInput(
            "Need at least 3 paired observations for Wilcoxon signed-rank test".to_string(),
        ));
    }

    // Keep (difference, |difference|, is_positive), dropping zero differences.
    let mut differences_with_abs: Vec<(f64, f64, bool)> = metric_values1
        .iter()
        .zip(metric_values2.iter())
        .map(|(v1, v2)| {
            let diff = v1 - v2;
            (diff, diff.abs(), diff > 0.0)
        })
        .filter(|(_, abs_diff, _)| *abs_diff > 1e-10)
        .collect();

    let n_nonzero = differences_with_abs.len();
    if n_nonzero < 3 {
        return Err(SklearsError::InvalidInput(
            "Too many zero differences for Wilcoxon test".to_string(),
        ));
    }

    // Rank by absolute difference.
    differences_with_abs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    // Assign ranks, giving tied values their average rank.
    let mut ranks = vec![0.0; n_nonzero];
    let mut i = 0;
    while i < n_nonzero {
        let current_abs_diff = differences_with_abs[i].1;
        let mut j = i;

        while j < n_nonzero && (differences_with_abs[j].1 - current_abs_diff).abs() < 1e-10 {
            j += 1;
        }

        // The 1-based ranks i+1..=j average to (i + j + 1) / 2.
        let avg_rank = (i + j + 1) as f64 / 2.0;
        for rank in ranks.iter_mut().take(j).skip(i) {
            *rank = avg_rank;
        }

        i = j;
    }

    // Sum the ranks of positive and negative differences separately.
    let mut w_plus = 0.0;
    let mut w_minus = 0.0;

    for i in 0..n_nonzero {
        if differences_with_abs[i].2 {
            w_plus += ranks[i];
        } else {
            w_minus += ranks[i];
        }
    }

    let w_statistic = w_plus.min(w_minus);

    // Normal approximation with continuity correction.
    let expected_w = (n_nonzero * (n_nonzero + 1)) as f64 / 4.0;
    let variance_w = (n_nonzero * (n_nonzero + 1) * (2 * n_nonzero + 1)) as f64 / 24.0;
    let std_w = variance_w.sqrt();

    let z_score = ((w_statistic - expected_w).abs() - 0.5) / std_w;

    let p_value = 2.0 * (1.0 - standard_normal_cdf(z_score));

    let mut info = HashMap::new();
    info.insert("w_plus".to_string(), w_plus);
    info.insert("w_minus".to_string(), w_minus);
    info.insert("n_nonzero_differences".to_string(), n_nonzero as f64);
    info.insert("z_score".to_string(), z_score);

    Ok(StatisticalTestResult::new(
        w_statistic,
        p_value,
        "Wilcoxon signed-rank".to_string(),
        Some(info),
    ))
}

/// A confidence interval around a point estimate.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
    /// Point estimate (the sample mean).
    pub point_estimate: f64,
    /// Confidence level used, e.g. 0.95.
    pub confidence_level: f64,
}

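/// A t-based confidence interval for the mean of a set of metric values.
/// With a single observation the interval degenerates to the point estimate.
///
/// A minimal sketch:
///
/// ```ignore
/// let f1_per_fold = vec![0.78, 0.81, 0.79, 0.83, 0.80];
/// let ci = confidence_interval(&f1_per_fold, 0.95).unwrap();
/// assert!(ci.lower <= ci.point_estimate && ci.point_estimate <= ci.upper);
/// ```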
pub fn confidence_interval(
    metric_values: &[f64],
    confidence_level: f64,
) -> SklResult<ConfidenceInterval> {
    if metric_values.is_empty() {
        return Err(SklearsError::InvalidInput(
            "Metric values array cannot be empty".to_string(),
        ));
    }

    if confidence_level <= 0.0 || confidence_level >= 1.0 {
        return Err(SklearsError::InvalidInput(
            "Confidence level must be between 0 and 1".to_string(),
        ));
    }

    let n = metric_values.len();
    let mean = metric_values.iter().sum::<f64>() / n as f64;

    if n == 1 {
        // With a single observation, the interval collapses to the mean.
        return Ok(ConfidenceInterval {
            lower: mean,
            upper: mean,
            point_estimate: mean,
            confidence_level,
        });
    }

    // Sample variance (Bessel-corrected) and standard error of the mean.
    let variance = metric_values
        .iter()
        .map(|v| (v - mean).powi(2))
        .sum::<f64>()
        / (n - 1) as f64;
    let std_error = (variance / n as f64).sqrt();

    let alpha = 1.0 - confidence_level;
    let df = (n - 1) as f64;
    let t_critical = t_distribution_quantile(1.0 - alpha / 2.0, df);

    let margin_error = t_critical * std_error;

    Ok(ConfidenceInterval {
        lower: mean - margin_error,
        upper: mean + margin_error,
        point_estimate: mean,
        confidence_level,
    })
}

/// Approximate upper-tail p-value for a chi-square statistic.
fn chi_square_p_value(x: f64, df: usize) -> f64 {
    if df == 1 {
        // A chi-square(1) variable is a squared standard normal.
        2.0 * (1.0 - standard_normal_cdf(x.sqrt()))
    } else {
        // Crude normal approximation for larger degrees of freedom.
        let normalized = (x - df as f64) / (2.0 * df as f64).sqrt();
        2.0 * (1.0 - standard_normal_cdf(normalized.abs()))
    }
}

/// Standard normal CDF via the error function.
fn standard_normal_cdf(z: f64) -> f64 {
    0.5 * (1.0 + erf(z / 2.0_f64.sqrt()))
}

/// Error function approximation (Abramowitz & Stegun 7.1.26, max error ~1.5e-7).
fn erf(x: f64) -> f64 {
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;
    let p = 0.3275911;

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * (-x * x).exp();

    sign * y
}

/// Approximate CDF of Student's t-distribution.
fn t_distribution_cdf(t: f64, df: f64) -> f64 {
    if df > 30.0 {
        // For large df the t-distribution is close to standard normal.
        standard_normal_cdf(t)
    } else {
        // Rough approximation via the normalized statistic t / sqrt(df + t^2).
        let normalized = t / (df + t * t).sqrt();
        0.5 + 0.5 * erf(normalized)
    }
}

/// Approximate quantile of Student's t-distribution.
fn t_distribution_quantile(p: f64, df: f64) -> f64 {
    if df > 100.0 {
        // For large df, use the normal quantile directly.
        normal_quantile(p)
    } else if df >= 2.0 {
        // Inflate the normal quantile with a small df-dependent correction.
        let z = normal_quantile(p);
        let h = 2.0 / (9.0 * df);
        let correction = z.powi(2) * h / 6.0;
        z * (1.0 + correction).max(0.1)
    } else {
        // Heavier correction for very small df.
        let z = normal_quantile(p);
        z * (1.0 + (z.powi(2) + 1.0) / (4.0 * df))
    }
}

/// Standard normal quantile via table lookup with linear interpolation.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if (p - 0.5).abs() < f64::EPSILON {
        return 0.0;
    }

    // (probability, quantile) pairs of the standard normal distribution.
    let known_values = [
        (0.001, -3.090232),
        (0.005, -2.575829),
        (0.01, -2.326348),
        (0.025, -1.959964),
        (0.05, -1.644854),
        (0.1, -1.281552),
        (0.15, -1.036433),
        (0.2, -0.841621),
        (0.25, -0.674490),
        (0.3, -0.524401),
        (0.35, -0.385320),
        (0.4, -0.253347),
        (0.45, -0.125661),
        (0.5, 0.0),
        (0.55, 0.125661),
        (0.6, 0.253347),
        (0.65, 0.385320),
        (0.7, 0.524401),
        (0.75, 0.674490),
        (0.8, 0.841621),
        (0.85, 1.036433),
        (0.9, 1.281552),
        (0.95, 1.644854),
        (0.975, 1.959964),
        (0.99, 2.326348),
        (0.995, 2.575829),
        (0.999, 3.090232),
    ];

    if let Some(idx) = known_values.iter().position(|(prob, _)| *prob >= p) {
        if idx == 0 {
            return known_values[0].1;
        }

        let (p1, z1) = known_values[idx - 1];
        let (p2, z2) = known_values[idx];

        // Linear interpolation between the bracketing table entries.
        let weight = (p - p1) / (p2 - p1);
        z1 + weight * (z2 - z1)
    } else {
        // p lies beyond the last table entry (> 0.999): return a far-tail value.
        3.5
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    fn create_test_data() -> (Array2<i32>, Array2<i32>) {
        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1], [1, 0, 0], [1, 1, 1]];
        (y_true, y_pred)
    }

    #[test]
    fn test_per_label_metrics_basic() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();

        assert_eq!(metrics.n_labels, 3);
        assert_eq!(metrics.precision.len(), 3);
        assert_eq!(metrics.recall.len(), 3);
        assert_eq!(metrics.f1_score.len(), 3);
        assert_eq!(metrics.support.len(), 3);
        assert_eq!(metrics.accuracy.len(), 3);

        assert_eq!(metrics.support[0], 3);
        assert_eq!(metrics.support[1], 2);
        assert_eq!(metrics.support[2], 3);

        for i in 0..3 {
            assert!(metrics.precision[i] >= 0.0 && metrics.precision[i] <= 1.0);
            assert!(metrics.recall[i] >= 0.0 && metrics.recall[i] <= 1.0);
            assert!(metrics.f1_score[i] >= 0.0 && metrics.f1_score[i] <= 1.0);
            assert!(metrics.accuracy[i] >= 0.0 && metrics.accuracy[i] <= 1.0);
        }
    }

    #[test]
    fn test_per_label_metrics_perfect_prediction() {
        let y_perfect = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
        let y_true_view = y_perfect.view();
        let y_pred_view = y_perfect.view();

        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();

        for i in 0..3 {
            assert!((metrics.precision[i] - 1.0).abs() < 1e-10);
            assert!((metrics.recall[i] - 1.0).abs() < 1e-10);
            assert!((metrics.f1_score[i] - 1.0).abs() < 1e-10);
            assert!((metrics.accuracy[i] - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_per_label_metrics_all_zeros() {
        let y_true = array![[0, 0, 0], [0, 0, 0], [0, 0, 0]];
        let y_pred = array![[0, 0, 0], [0, 0, 0], [0, 0, 0]];
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();

        for i in 0..3 {
            assert_eq!(metrics.support[i], 0);
            // Every cell is a true negative, so accuracy is perfect while
            // precision and recall fall back to 0.0.
            assert!((metrics.accuracy[i] - 1.0).abs() < 1e-10);
            assert!((metrics.precision[i] - 0.0).abs() < 1e-10);
            assert!((metrics.recall[i] - 0.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_per_label_metrics_macro_average() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();

        let macro_precision = metrics.macro_average("precision").unwrap();
        let macro_recall = metrics.macro_average("recall").unwrap();
        let macro_f1 = metrics.macro_average("f1_score").unwrap();
        let macro_accuracy = metrics.macro_average("accuracy").unwrap();

        let expected_precision = metrics.precision.iter().sum::<f64>() / 3.0;
        let expected_recall = metrics.recall.iter().sum::<f64>() / 3.0;
        let expected_f1 = metrics.f1_score.iter().sum::<f64>() / 3.0;
        let expected_accuracy = metrics.accuracy.iter().sum::<f64>() / 3.0;

        assert!((macro_precision - expected_precision).abs() < 1e-10);
        assert!((macro_recall - expected_recall).abs() < 1e-10);
        assert!((macro_f1 - expected_f1).abs() < 1e-10);
        assert!((macro_accuracy - expected_accuracy).abs() < 1e-10);
    }

    #[test]
    fn test_per_label_metrics_weighted_average() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let metrics = per_label_metrics(&y_true_view, &y_pred_view).unwrap();

        let weighted_precision = metrics.weighted_average("precision").unwrap();
        let weighted_recall = metrics.weighted_average("recall").unwrap();
        let weighted_f1 = metrics.weighted_average("f1_score").unwrap();
        let weighted_accuracy = metrics.weighted_average("accuracy").unwrap();

        assert!(weighted_precision >= 0.0 && weighted_precision <= 1.0);
        assert!(weighted_recall >= 0.0 && weighted_recall <= 1.0);
        assert!(weighted_f1 >= 0.0 && weighted_f1 <= 1.0);
        assert!(weighted_accuracy >= 0.0 && weighted_accuracy <= 1.0);

        assert!(metrics.weighted_average("invalid").is_err());
        assert!(metrics.macro_average("invalid").is_err());
    }

    #[test]
    fn test_per_label_metrics_error_handling() {
        // Shape mismatch.
        let y_true = array![[1, 0], [0, 1]];
        let y_pred = array![[1, 0, 1], [0, 1, 0]];
        let result = per_label_metrics(&y_true.view(), &y_pred.view());
        assert!(result.is_err());

        // Empty inputs.
        let empty_true = Array2::<i32>::zeros((0, 0));
        let empty_pred = Array2::<i32>::zeros((0, 0));
        let result = per_label_metrics(&empty_true.view(), &empty_pred.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_mcnemar_test_identical_classifiers() {
        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];

        let result = mcnemar_test(&y_true.view(), &y_pred.view(), &y_pred.view()).unwrap();
        assert_eq!(result.test_name, "McNemar");
        assert!((result.p_value - 1.0).abs() < 1e-10);
        assert!(!result.is_significant);
        assert_eq!(result.statistic, 0.0);
    }

    #[test]
    fn test_mcnemar_test_different_classifiers() {
        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1], [0, 0, 0], [1, 0, 1]];
        let y_pred1 = array![[1, 0, 0], [0, 1, 1], [1, 0, 1], [1, 0, 0], [1, 1, 1]];
        let y_pred2 = array![[0, 1, 1], [1, 0, 0], [0, 1, 0], [0, 1, 1], [0, 0, 0]];

        let result = mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view()).unwrap();
        assert_eq!(result.test_name, "McNemar");
        assert!(result.statistic >= 0.0);
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);

        assert!(result.additional_info.contains_key("n01"));
        assert!(result.additional_info.contains_key("n10"));
        assert!(result.additional_info.contains_key("total_disagreements"));
    }

    #[test]
    fn test_mcnemar_test_error_handling() {
        // Shape mismatch.
        let y_true = array![[1, 0], [0, 1]];
        let y_pred1 = array![[1, 0], [0, 1]];
        let y_pred2 = array![[1, 0, 1], [0, 1, 0]];
        let result = mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_paired_t_test() {
        let metric_values1 = vec![0.8, 0.7, 0.9, 0.6, 0.75];
        let metric_values2 = vec![0.7, 0.65, 0.85, 0.55, 0.7];

        let result = paired_t_test(&metric_values1, &metric_values2).unwrap();
        assert_eq!(result.test_name, "Paired t-test");
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);

        assert!(result.additional_info.contains_key("mean_difference"));
        assert!(result.additional_info.contains_key("std_dev_diff"));
        assert!(result.additional_info.contains_key("degrees_of_freedom"));
        assert!(result.additional_info.contains_key("n_observations"));

        // The first set of values is consistently higher.
        let mean_diff = result.additional_info.get("mean_difference").unwrap();
        assert!(*mean_diff > 0.0);
    }

    #[test]
    fn test_paired_t_test_identical_values() {
        let metric_values = vec![0.8, 0.7, 0.9, 0.6, 0.75];

        let result = paired_t_test(&metric_values, &metric_values).unwrap();

        let mean_diff = result.additional_info.get("mean_difference").unwrap();
        assert!(mean_diff.abs() < 1e-10);
        assert!(!result.is_significant);
    }

    #[test]
    fn test_paired_t_test_error_handling() {
        // Length mismatch.
        let values1 = vec![0.8, 0.7];
        let values2 = vec![0.6];
        let result = paired_t_test(&values1, &values2);
        assert!(result.is_err());

        // Too few observations.
        let single = vec![0.8];
        let result = paired_t_test(&single, &single);
        assert!(result.is_err());
    }

    #[test]
    fn test_wilcoxon_signed_rank_test() {
        let metric_values1 = vec![0.8, 0.7, 0.9, 0.6, 0.75, 0.85, 0.65];
        let metric_values2 = vec![0.7, 0.65, 0.85, 0.55, 0.7, 0.8, 0.6];

        let result = wilcoxon_signed_rank_test(&metric_values1, &metric_values2).unwrap();
        assert_eq!(result.test_name, "Wilcoxon signed-rank");
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);

        assert!(result.additional_info.contains_key("w_plus"));
        assert!(result.additional_info.contains_key("w_minus"));
        assert!(result.additional_info.contains_key("n_nonzero_differences"));
        assert!(result.additional_info.contains_key("z_score"));
    }

    #[test]
    fn test_wilcoxon_signed_rank_test_identical_values() {
        let metric_values = vec![0.8, 0.7, 0.9, 0.6, 0.75];

        // All differences are zero, which the test rejects.
        let result = wilcoxon_signed_rank_test(&metric_values, &metric_values);
        assert!(result.is_err());
    }

    #[test]
    fn test_wilcoxon_signed_rank_test_error_handling() {
        // Length mismatch.
        let values1 = vec![0.8, 0.7];
        let values2 = vec![0.6];
        let result = wilcoxon_signed_rank_test(&values1, &values2);
        assert!(result.is_err());

        // Too few observations.
        let few = vec![0.8, 0.7];
        let result = wilcoxon_signed_rank_test(&few, &few);
        assert!(result.is_err());
    }

    #[test]
    fn test_confidence_interval() {
        let metric_values = vec![0.8, 0.75, 0.85, 0.7, 0.9, 0.65, 0.8, 0.82, 0.78, 0.88];

        let ci = confidence_interval(&metric_values, 0.95).unwrap();

        assert_eq!(ci.confidence_level, 0.95);
        assert!(ci.lower <= ci.point_estimate);
        assert!(ci.point_estimate <= ci.upper);

        let expected_mean = metric_values.iter().sum::<f64>() / metric_values.len() as f64;
        assert!((ci.point_estimate - expected_mean).abs() < 1e-10);

        // With this much spread the interval should have nonzero width.
        if metric_values.len() > 5 {
            assert!(ci.upper - ci.lower > 1e-6);
        }

        // A higher confidence level should give an interval at least as wide.
        let ci_99 = confidence_interval(&metric_values, 0.99).unwrap();
        assert!(ci_99.upper - ci_99.lower >= ci.upper - ci.lower);
    }

    #[test]
    fn test_confidence_interval_single_value() {
        let single_value = vec![0.8];

        let ci = confidence_interval(&single_value, 0.95).unwrap();

        assert_eq!(ci.lower, ci.point_estimate);
        assert_eq!(ci.upper, ci.point_estimate);
        assert_eq!(ci.point_estimate, 0.8);
    }

    #[test]
    fn test_confidence_interval_error_handling() {
        let empty = vec![];
        let result = confidence_interval(&empty, 0.95);
        assert!(result.is_err());

        // The confidence level must lie strictly between 0 and 1.
        let values = vec![0.8, 0.7];
        let result = confidence_interval(&values, 0.0);
        assert!(result.is_err());

        let result = confidence_interval(&values, 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_statistical_test_result_creation() {
        let mut info = HashMap::new();
        info.insert("degrees_of_freedom".to_string(), 9.0);

        let result = StatisticalTestResult::new(2.5, 0.03, "Test".to_string(), Some(info));

        assert_eq!(result.statistic, 2.5);
        assert_eq!(result.p_value, 0.03);
        assert!(result.is_significant); // 0.03 < 0.05
        assert_eq!(result.test_name, "Test");
        assert_eq!(result.additional_info.get("degrees_of_freedom"), Some(&9.0));

        let non_sig = StatisticalTestResult::new(1.2, 0.15, "Non-sig".to_string(), None);
        assert!(!non_sig.is_significant); // 0.15 >= 0.05
        assert!(non_sig.additional_info.is_empty());
    }

    #[test]
    fn test_distribution_helper_functions() {
        let z_zero = standard_normal_cdf(0.0);
        assert!((z_zero - 0.5).abs() < 1e-6);

        let z_positive = standard_normal_cdf(1.96);
        assert!((z_positive - 0.975).abs() < 0.01);

        let q_median = normal_quantile(0.5);
        assert!(q_median.abs() < 1e-6);

        let q_975 = normal_quantile(0.975);
        assert!((q_975 - 1.96).abs() < 0.01);

        let q_025 = normal_quantile(0.025);
        assert!((q_025 + 1.96).abs() < 0.01);

        assert_eq!(normal_quantile(0.0), f64::NEG_INFINITY);
        assert_eq!(normal_quantile(1.0), f64::INFINITY);

        // The quantile and CDF functions should roughly invert each other.
        let test_values = vec![0.25, 0.5, 0.75];
        for &p in &test_values {
            let q = normal_quantile(p);
            if q.is_finite() {
                let p_back = standard_normal_cdf(q);
                assert!((p - p_back).abs() < 0.05);
            }
        }
    }

    #[test]
    fn test_existing_metrics_compatibility() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let hamming = hamming_loss(&y_true_view, &y_pred_view).unwrap();
        assert!(hamming >= 0.0 && hamming <= 1.0);

        let subset_acc = subset_accuracy(&y_true_view, &y_pred_view).unwrap();
        assert!(subset_acc >= 0.0 && subset_acc <= 1.0);

        let jaccard = jaccard_score(&y_true_view, &y_pred_view).unwrap();
        assert!(jaccard >= 0.0 && jaccard <= 1.0);

        let f1_micro = f1_score(&y_true_view, &y_pred_view, "micro").unwrap();
        assert!(f1_micro >= 0.0 && f1_micro <= 1.0);

        let f1_macro = f1_score(&y_true_view, &y_pred_view, "macro").unwrap();
        assert!(f1_macro >= 0.0 && f1_macro <= 1.0);

        let f1_samples = f1_score(&y_true_view, &y_pred_view, "samples").unwrap();
        assert!(f1_samples >= 0.0 && f1_samples <= 1.0);
    }

    #[test]
    fn test_per_label_vs_global_metrics_consistency() {
        let (y_true, y_pred) = create_test_data();
        let y_true_view = y_true.view();
        let y_pred_view = y_pred.view();

        let per_label = per_label_metrics(&y_true_view, &y_pred_view).unwrap();
        let global_f1_macro = f1_score(&y_true_view, &y_pred_view, "macro").unwrap();
        let per_label_f1_macro = per_label.macro_average("f1_score").unwrap();

        // Both paths compute the same macro-averaged F1.
        assert!((global_f1_macro - per_label_f1_macro).abs() < 1e-10);
    }

    #[test]
    fn test_comprehensive_statistical_workflow() {
        let y_true = array![
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
            [1, 0, 1, 1]
        ];

        let y_pred1 = array![
            [1, 0, 0, 0],
            [0, 1, 1, 1],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [1, 1, 1, 0]
        ];

        let y_pred2 = array![
            [0, 1, 1, 0],
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [0, 1, 1, 0],
            [0, 0, 0, 1]
        ];

        let metrics1 = per_label_metrics(&y_true.view(), &y_pred1.view()).unwrap();
        let metrics2 = per_label_metrics(&y_true.view(), &y_pred2.view()).unwrap();

        let mcnemar_result =
            mcnemar_test(&y_true.view(), &y_pred1.view(), &y_pred2.view()).unwrap();

        let t_test_result = paired_t_test(&metrics1.f1_score, &metrics2.f1_score).unwrap();

        let wilcoxon_result =
            wilcoxon_signed_rank_test(&metrics1.f1_score, &metrics2.f1_score).unwrap();

        let ci_result = confidence_interval(&metrics1.f1_score, 0.95).unwrap();

        assert!(mcnemar_result.p_value >= 0.0 && mcnemar_result.p_value <= 1.0);
        assert!(t_test_result.p_value >= 0.0 && t_test_result.p_value <= 1.0);
        assert!(wilcoxon_result.p_value >= 0.0 && wilcoxon_result.p_value <= 1.0);
        assert!(ci_result.lower <= ci_result.point_estimate);
        assert!(ci_result.point_estimate <= ci_result.upper);

        assert!(!mcnemar_result.additional_info.is_empty());
        assert!(!t_test_result.additional_info.is_empty());
        assert!(!wilcoxon_result.additional_info.is_empty());
    }
}