use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};

/// Per-class (precision, recall, F-score, support) arrays, as returned by
/// `precision_recall_fscore_support_sklearn`.
type PrecisionRecallFscoreSupport = (Array1<f64>, Array1<f64>, Array1<f64>, Array1<usize>);

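/// sklearn-style classification report: per-class precision, recall, F1 and
/// support keyed by class name, plus overall accuracy and the macro and
/// support-weighted averages.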
#[derive(Debug, Clone)]
pub struct ClassificationReport {
    pub precision: HashMap<String, f64>,
    pub recall: HashMap<String, f64>,
    pub f1_score: HashMap<String, f64>,
    pub support: HashMap<String, usize>,
    pub accuracy: f64,
    pub macro_avg: ClassificationMetrics,
    pub weighted_avg: ClassificationMetrics,
}

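/// Aggregate precision/recall/F1 plus the number of supporting samples; used
/// for the macro and weighted rows of a report.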
#[derive(Debug, Clone)]
pub struct ClassificationMetrics {
    pub precision: f64,
    pub recall: f64,
    pub f1_score: f64,
    pub support: usize,
}

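/// Summary of a binary evaluation run: accuracy, support-weighted
/// precision/recall/F1, and ROC AUC (0.0 when no scores were supplied).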
#[derive(Debug, Clone)]
pub struct ClassificationResults {
    pub accuracy: f64,
    pub precision_weighted: f64,
    pub recall_weighted: f64,
    pub f1_weighted: f64,
    pub auc_roc: f64,
}

impl ClassificationMetrics {
    pub fn new() -> Self {
        Self {
            precision: 0.0,
            recall: 0.0,
            f1_score: 0.0,
            support: 0,
        }
    }

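    /// Computes accuracy, precision, recall, F1 and (optionally) ROC AUC for
    /// a binary 0/1 problem. When `y_scores` is provided it must have one
    /// column per class; column 1 is taken as the positive-class score.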
    pub fn compute(
        &mut self,
        y_true: scirs2_core::ndarray::ArrayView1<i32>,
        y_pred: scirs2_core::ndarray::ArrayView1<i32>,
        y_scores: Option<scirs2_core::ndarray::Array2<f64>>,
    ) -> Result<ClassificationResults> {
        if y_true.len() != y_pred.len() {
            return Err(MetricsError::InvalidInput(
                "y_true and y_pred must have the same length".to_string(),
            ));
        }

        let accuracy = crate::classification::accuracy_score(&y_true, &y_pred)?;

        let (precision, recall, f1) = self.calculate_binary_metrics(&y_true, &y_pred)?;

        // ROC AUC needs scores; without them we report 0.0 rather than fail.
        let auc_roc = if let Some(scores) = y_scores {
            if scores.ncols() < 2 {
                return Err(MetricsError::InvalidInput(
                    "y_scores must have at least two columns (one per class)".to_string(),
                ));
            }
            let y_true_u32: Vec<u32> = y_true.iter().map(|&x| x as u32).collect();
            let y_true_u32_array = scirs2_core::ndarray::Array1::from(y_true_u32);
            // Column 1 holds the positive-class scores in the binary case.
            let scores_f64 = scores.column(1).to_owned();
            crate::classification::roc_auc_score(&y_true_u32_array, &scores_f64)?
        } else {
            0.0
        };

        Ok(ClassificationResults {
            accuracy,
            precision_weighted: precision,
            recall_weighted: recall,
            f1_weighted: f1,
            auc_roc,
        })
    }

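    /// Binary-only helper: counts TP/FP/FN over 0/1 labels and derives
    /// precision, recall and F1 (all 0.0 when undefined).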
    fn calculate_binary_metrics(
        &self,
        y_true: &scirs2_core::ndarray::ArrayView1<i32>,
        y_pred: &scirs2_core::ndarray::ArrayView1<i32>,
    ) -> Result<(f64, f64, f64)> {
        let mut tp = 0;
        let mut fp = 0;
        let mut fn_count = 0;

        for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
            match (true_label, pred_label) {
                (1, 1) => tp += 1,
                (0, 1) => fp += 1,
                (1, 0) => fn_count += 1,
                // True negatives don't enter precision/recall, and labels
                // outside {0, 1} are ignored by this binary helper.
                _ => {}
            }
        }

        let precision = if tp + fp > 0 {
            tp as f64 / (tp + fp) as f64
        } else {
            0.0
        };

        let recall = if tp + fn_count > 0 {
            tp as f64 / (tp + fn_count) as f64
        } else {
            0.0
        };

        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        Ok((precision, recall, f1))
    }
}

impl Default for ClassificationMetrics {
    fn default() -> Self {
        Self::new()
    }
}

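/// Builds an sklearn-style classification report for integer-labelled
/// predictions. `labels` restricts and orders the classes, `target_names`
/// maps them to display names, and `zero_division` is the value substituted
/// whenever a metric's denominator is zero.
///
/// A minimal usage sketch (marked `ignore` since the enclosing crate paths
/// are assumed here, not guaranteed):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let y_true = Array1::from_vec(vec![0, 1, 1, 0]);
/// let y_pred = Array1::from_vec(vec![0, 1, 0, 0]);
/// let report = classification_report_sklearn(&y_true, &y_pred, None, None, 2, 0.0).unwrap();
/// println!("accuracy = {:.2}", report.accuracy);
/// println!("macro F1 = {:.2}", report.macro_avg.f1_score);
/// ```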
#[allow(dead_code)]
pub fn classification_report_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    labels: Option<&[i32]>,
    target_names: Option<&[String]>,
    _digits: usize,
    zero_division: f64,
) -> Result<ClassificationReport> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    let unique_labels: Vec<i32> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        let all_labels: HashSet<i32> = y_true.iter().chain(y_pred.iter()).copied().collect();
        let mut sorted_labels: Vec<i32> = all_labels.into_iter().collect();
        sorted_labels.sort();
        sorted_labels
    };

    let mut precision_map = HashMap::new();
    let mut recall_map = HashMap::new();
    let mut f1_map = HashMap::new();
    let mut support_map = HashMap::new();

    for &label in &unique_labels {
        let (precision, recall, f1, support) =
            calculate_class_metrics(y_true, y_pred, label, zero_division)?;

        // Use the matching display name when one is provided; otherwise fall
        // back to the label's numeric representation.
        let label_name = target_names
            .and_then(|names| {
                unique_labels
                    .iter()
                    .position(|&x| x == label)
                    .and_then(|pos| names.get(pos))
            })
            .cloned()
            .unwrap_or_else(|| label.to_string());

        precision_map.insert(label_name.clone(), precision);
        recall_map.insert(label_name.clone(), recall);
        f1_map.insert(label_name.clone(), f1);
        support_map.insert(label_name, support);
    }

    let accuracy = accuracy_score_sklearn(y_true, y_pred)?;

    let macro_precision = precision_map.values().sum::<f64>() / precision_map.len() as f64;
    let macro_recall = recall_map.values().sum::<f64>() / recall_map.len() as f64;
    let macro_f1 = f1_map.values().sum::<f64>() / f1_map.len() as f64;
    let macro_support = support_map.values().sum::<usize>();

    // Weighted averages: weight each class metric by its support, looking the
    // support up by key. (Zipping two `HashMap` iterators is not sound here:
    // their iteration orders are independent, so paired entries would rarely
    // refer to the same class.)
    let total_support = support_map.values().sum::<usize>() as f64;
    let weighted_precision = precision_map
        .iter()
        .map(|(label, &p)| p * support_map[label] as f64)
        .sum::<f64>()
        / total_support;

    let weighted_recall = recall_map
        .iter()
        .map(|(label, &r)| r * support_map[label] as f64)
        .sum::<f64>()
        / total_support;

    let weighted_f1 = f1_map
        .iter()
        .map(|(label, &f)| f * support_map[label] as f64)
        .sum::<f64>()
        / total_support;

    Ok(ClassificationReport {
        precision: precision_map,
        recall: recall_map,
        f1_score: f1_map,
        support: support_map,
        accuracy,
        macro_avg: ClassificationMetrics {
            precision: macro_precision,
            recall: macro_recall,
            f1_score: macro_f1,
            support: macro_support,
        },
        weighted_avg: ClassificationMetrics {
            precision: weighted_precision,
            recall: weighted_recall,
            f1_score: weighted_f1,
            support: macro_support,
        },
    })
}

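/// One-vs-rest counts for a single class: returns (precision, recall, F1,
/// support), substituting `zero_division` whenever a denominator is zero.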
#[allow(dead_code)]
fn calculate_class_metrics(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    target_class: i32,
    zero_division: f64,
) -> Result<(f64, f64, f64, usize)> {
    let mut tp = 0;
    let mut fp = 0;
    let mut fn_count = 0;
    let mut support = 0;

    for (&true_val, &pred_val) in y_true.iter().zip(y_pred.iter()) {
        if true_val == target_class {
            support += 1;
            if pred_val == target_class {
                tp += 1;
            } else {
                fn_count += 1;
            }
        } else if pred_val == target_class {
            fp += 1;
        }
    }

    let precision = if tp + fp > 0 {
        tp as f64 / (tp + fp) as f64
    } else {
        zero_division
    };

    let recall = if tp + fn_count > 0 {
        tp as f64 / (tp + fn_count) as f64
    } else {
        zero_division
    };

    let f1 = if precision + recall > 0.0 {
        2.0 * precision * recall / (precision + recall)
    } else {
        zero_division
    };

    Ok((precision, recall, f1, support))
}

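/// Fraction of exactly matching predictions.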
#[allow(dead_code)]
pub fn accuracy_score_sklearn(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> Result<f64> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }
    if y_true.is_empty() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must not be empty".to_string(),
        ));
    }

    let correct = y_true
        .iter()
        .zip(y_pred.iter())
        .filter(|(&true_val, &pred_val)| true_val == pred_val)
        .count();

    Ok(correct as f64 / y_true.len() as f64)
}

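/// sklearn-style `precision_recall_fscore_support`: per-class F-beta metrics,
/// optionally reduced with `"micro"`, `"macro"`, or `"weighted"` averaging
/// (averaged results come back as length-1 arrays). `_pos_label` and
/// `_warn_for` are accepted for signature compatibility but currently unused.
///
/// A minimal usage sketch (`ignore`d; the enclosing crate paths are assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
/// let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);
/// // Pass `None` for `average` to get per-class arrays back.
/// let (p, r, f, s) = precision_recall_fscore_support_sklearn(
///     &y_true, &y_pred, 1.0, None, None, None, None, 0.0,
/// )
/// .unwrap();
/// assert_eq!(p.len(), 3); // one entry per class
/// ```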
#[allow(dead_code)]
pub fn precision_recall_fscore_support_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    beta: f64,
    labels: Option<&[i32]>,
    _pos_label: Option<i32>,
    average: Option<&str>,
    _warn_for: Option<&[&str]>,
    zero_division: f64,
) -> Result<PrecisionRecallFscoreSupport> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    let target_labels: Vec<i32> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        let all_labels: HashSet<i32> = y_true.iter().chain(y_pred.iter()).copied().collect();
        let mut sorted_labels: Vec<i32> = all_labels.into_iter().collect();
        sorted_labels.sort();
        sorted_labels
    };

    let mut precisions = Vec::new();
    let mut recalls = Vec::new();
    let mut fscores = Vec::new();
    let mut supports = Vec::new();

    for &label in &target_labels {
        let (precision, recall, _f1, support) =
            calculate_class_metrics(y_true, y_pred, label, zero_division)?;

        // F-beta; guard on the actual denominator so that beta == 0 with
        // zero recall cannot divide by zero.
        let denominator = beta * beta * precision + recall;
        let fbeta = if denominator > 0.0 {
            (1.0 + beta * beta) * precision * recall / denominator
        } else {
            zero_division
        };

        precisions.push(precision);
        recalls.push(recall);
        fscores.push(fbeta);
        supports.push(support);
    }

    if let Some(avg_type) = average {
        match avg_type {
            "micro" => {
                let (micro_precision, micro_recall, micro_fbeta, total_support) =
                    calculate_micro_average(y_true, y_pred, beta, &target_labels, zero_division)?;
                Ok((
                    Array1::from_vec(vec![micro_precision]),
                    Array1::from_vec(vec![micro_recall]),
                    Array1::from_vec(vec![micro_fbeta]),
                    Array1::from_vec(vec![total_support]),
                ))
            }
            "macro" => {
                let macro_precision = precisions.iter().sum::<f64>() / precisions.len() as f64;
                let macro_recall = recalls.iter().sum::<f64>() / recalls.len() as f64;
                let macro_fbeta = fscores.iter().sum::<f64>() / fscores.len() as f64;
                let total_support = supports.iter().sum::<usize>();
                Ok((
                    Array1::from_vec(vec![macro_precision]),
                    Array1::from_vec(vec![macro_recall]),
                    Array1::from_vec(vec![macro_fbeta]),
                    Array1::from_vec(vec![total_support]),
                ))
            }
            "weighted" => {
                let total_support = supports.iter().sum::<usize>() as f64;
                let weighted_precision = precisions
                    .iter()
                    .zip(supports.iter())
                    .map(|(&p, &s)| p * s as f64)
                    .sum::<f64>()
                    / total_support;
                let weighted_recall = recalls
                    .iter()
                    .zip(supports.iter())
                    .map(|(&r, &s)| r * s as f64)
                    .sum::<f64>()
                    / total_support;
                let weighted_fbeta = fscores
                    .iter()
                    .zip(supports.iter())
                    .map(|(&f, &s)| f * s as f64)
                    .sum::<f64>()
                    / total_support;
                Ok((
                    Array1::from_vec(vec![weighted_precision]),
                    Array1::from_vec(vec![weighted_recall]),
                    Array1::from_vec(vec![weighted_fbeta]),
                    Array1::from_vec(vec![total_support as usize]),
                ))
            }
            _ => Err(MetricsError::InvalidInput(format!(
                "Unsupported average type: {}",
                avg_type
            ))),
        }
    } else {
        Ok((
            Array1::from_vec(precisions),
            Array1::from_vec(recalls),
            Array1::from_vec(fscores),
            Array1::from_vec(supports),
        ))
    }
}

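/// Pools TP/FP/FN over all requested labels before computing the metrics, so
/// every prediction counts equally regardless of class.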
#[allow(dead_code)]
fn calculate_micro_average(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    beta: f64,
    labels: &[i32],
    zero_division: f64,
) -> Result<(f64, f64, f64, usize)> {
    let mut total_tp = 0;
    let mut total_fp = 0;
    let mut total_fn = 0;
    let mut total_support = 0;

    for &label in labels {
        let mut tp = 0;
        let mut fp = 0;
        let mut fn_count = 0;

        for (&true_val, &pred_val) in y_true.iter().zip(y_pred.iter()) {
            if true_val == label {
                total_support += 1;
                if pred_val == label {
                    tp += 1;
                } else {
                    fn_count += 1;
                }
            } else if pred_val == label {
                fp += 1;
            }
        }

        total_tp += tp;
        total_fp += fp;
        total_fn += fn_count;
    }

    let micro_precision = if total_tp + total_fp > 0 {
        total_tp as f64 / (total_tp + total_fp) as f64
    } else {
        zero_division
    };

    let micro_recall = if total_tp + total_fn > 0 {
        total_tp as f64 / (total_tp + total_fn) as f64
    } else {
        zero_division
    };

    // Same denominator guard as the per-class F-beta.
    let denominator = beta * beta * micro_precision + micro_recall;
    let micro_fbeta = if denominator > 0.0 {
        (1.0 + beta * beta) * micro_precision * micro_recall / denominator
    } else {
        zero_division
    };

    Ok((micro_precision, micro_recall, micro_fbeta, total_support))
}

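/// Per-label confusion matrices for binary indicator targets. Unlike
/// sklearn's 3-D output, the 2x2 matrices are stacked vertically into a
/// `(2 * n_labels, 2)` array: rows `2i..2i + 2` hold `[[tn, fp], [fn, tp]]`
/// for label `i`.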
#[allow(dead_code)]
pub fn multilabel_confusion_matrix_sklearn(
    y_true: &Array2<i32>,
    y_pred: &Array2<i32>,
    sample_weight: Option<&Array1<f64>>,
    labels: Option<&[usize]>,
) -> Result<Array2<i32>> {
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();

    if let Some(weights) = sample_weight {
        if weights.len() != n_samples {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match number of samples".to_string(),
            ));
        }
    }

    let target_labels: Vec<usize> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        (0..n_labels).collect()
    };

    let mut confusion_matrices = Array2::zeros((target_labels.len() * 2, 2));

    for (label_idx, &label) in target_labels.iter().enumerate() {
        if label >= n_labels {
            return Err(MetricsError::InvalidInput(format!(
                "Label {} is out of bounds for {} labels",
                label, n_labels
            )));
        }

        let mut tp = 0;
        let mut fp = 0;
        let mut tn = 0;
        let mut fn_count = 0;

        for sample_idx in 0..n_samples {
            let true_val = y_true[[sample_idx, label]];
            let pred_val = y_pred[[sample_idx, label]];

            // The matrix is integer-valued, so fractional sample weights are
            // truncated toward zero here.
            let weight = if let Some(weights) = sample_weight {
                weights[sample_idx] as i32
            } else {
                1
            };

            match (true_val, pred_val) {
                (1, 1) => tp += weight,
                (0, 1) => fp += weight,
                (0, 0) => tn += weight,
                (1, 0) => fn_count += weight,
                _ => {
                    return Err(MetricsError::InvalidInput(
                        "Labels must be 0 or 1 for multilabel classification".to_string(),
                    ))
                }
            }
        }

        // Stack this label's 2x2 matrix below the previous ones.
        let base_idx = label_idx * 2;
        confusion_matrices[[base_idx, 0]] = tn;
        confusion_matrices[[base_idx, 1]] = fp;
        confusion_matrices[[base_idx + 1, 0]] = fn_count;
        confusion_matrices[[base_idx + 1, 1]] = tp;
    }

    Ok(confusion_matrices)
}

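/// Cohen's kappa between two annotators, optionally with `"linear"` or
/// `"quadratic"` disagreement weighting and per-sample weights.
///
/// A minimal usage sketch (`ignore`d; the enclosing crate paths are assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let y1 = Array1::from_vec(vec![0, 1, 2, 1]);
/// let y2 = Array1::from_vec(vec![0, 1, 2, 2]);
/// let kappa = cohen_kappa_score_sklearn(&y1, &y2, None, Some("quadratic"), None).unwrap();
/// assert!(kappa <= 1.0);
/// ```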
#[allow(dead_code)]
pub fn cohen_kappa_score_sklearn(
    y1: &Array1<i32>,
    y2: &Array1<i32>,
    labels: Option<&[i32]>,
    weights: Option<&str>,
    sample_weight: Option<&Array1<f64>>,
) -> Result<f64> {
    if y1.len() != y2.len() {
        return Err(MetricsError::InvalidInput(
            "y1 and y2 must have the same length".to_string(),
        ));
    }

    if let Some(sw) = sample_weight {
        if sw.len() != y1.len() {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match y1 and y2 length".to_string(),
            ));
        }
    }

    let unique_labels: Vec<i32> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        let all_labels: HashSet<i32> = y1.iter().chain(y2.iter()).copied().collect();
        let mut sorted_labels: Vec<i32> = all_labels.into_iter().collect();
        sorted_labels.sort();
        sorted_labels
    };

    let n_labels = unique_labels.len();

    // Build a weighted confusion matrix normalized to sum to 1.
    let mut confusion_matrix = Array2::zeros((n_labels, n_labels));
    let mut total_weight = 0.0;

    for (idx, (&true_val, &pred_val)) in y1.iter().zip(y2.iter()).enumerate() {
        let weight = if let Some(sw) = sample_weight {
            sw[idx]
        } else {
            1.0
        };

        if let (Some(true_idx), Some(pred_idx)) = (
            unique_labels.iter().position(|&x| x == true_val),
            unique_labels.iter().position(|&x| x == pred_val),
        ) {
            confusion_matrix[[true_idx, pred_idx]] += weight;
            total_weight += weight;
        }
    }

    if total_weight > 0.0 {
        confusion_matrix /= total_weight;
    }

    // Validate the weighting scheme up front.
    if !matches!(weights, None | Some("linear") | Some("quadratic")) {
        return Err(MetricsError::InvalidInput(
            "weights must be None, 'linear', or 'quadratic'".to_string(),
        ));
    }

    // Observed (po) and expected (pe) agreement. Weighted kappa applies the
    // same agreement weights to both terms; unweighted kappa is the special
    // case in which only the diagonal counts.
    let denom = n_labels.saturating_sub(1).max(1) as f64;
    let mut po = 0.0;
    let mut pe = 0.0;
    for i in 0..n_labels {
        let row_sum = confusion_matrix.row(i).sum();
        for j in 0..n_labels {
            let weight_ij = match weights {
                None => {
                    if i == j {
                        1.0
                    } else {
                        0.0
                    }
                }
                Some("linear") => 1.0 - (i as f64 - j as f64).abs() / denom,
                Some("quadratic") => {
                    let diff = (i as f64 - j as f64) / denom;
                    1.0 - diff * diff
                }
                _ => unreachable!(),
            };
            let col_sum = confusion_matrix.column(j).sum();
            po += weight_ij * confusion_matrix[[i, j]];
            pe += weight_ij * row_sum * col_sum;
        }
    }

    // Guard against dividing by zero when chance agreement is total.
    if (1.0 - pe).abs() < 1e-15 {
        Ok(1.0)
    } else {
        Ok((po - pe) / (1.0 - pe))
    }
}

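/// Multiclass hinge loss in sklearn's sense: per sample, the margin is the
/// true-class score minus the largest competing score, and the loss is
/// `max(0, 1 - margin)`, averaged with optional sample weights.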
#[allow(dead_code)]
pub fn hinge_loss_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array2<f64>,
    labels: Option<&[i32]>,
    sample_weight: Option<&Array1<f64>>,
) -> Result<f64> {
    let (n_samples, n_classes) = y_pred.dim();

    if y_true.len() != n_samples {
        return Err(MetricsError::InvalidInput(
            "y_true length must match number of samples in y_pred".to_string(),
        ));
    }

    if let Some(sw) = sample_weight {
        if sw.len() != n_samples {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match number of samples".to_string(),
            ));
        }
    }

    let class_labels: Vec<i32> = if let Some(labels) = labels {
        if labels.len() != n_classes {
            return Err(MetricsError::InvalidInput(
                "labels length must match number of classes in y_pred".to_string(),
            ));
        }
        labels.to_vec()
    } else {
        let unique_labels: HashSet<i32> = y_true.iter().copied().collect();
        let mut sorted_labels: Vec<i32> = unique_labels.into_iter().collect();
        sorted_labels.sort();
        if sorted_labels.len() != n_classes {
            return Err(MetricsError::InvalidInput(
                "Number of unique labels in y_true must match number of classes in y_pred"
                    .to_string(),
            ));
        }
        sorted_labels
    };

    let mut total_loss = 0.0;
    let mut total_weight = 0.0;

    for (sample_idx, &true_label) in y_true.iter().enumerate() {
        let weight = if let Some(sw) = sample_weight {
            sw[sample_idx]
        } else {
            1.0
        };

        if let Some(true_class_idx) = class_labels.iter().position(|&x| x == true_label) {
            let true_score = y_pred[[sample_idx, true_class_idx]];

            // sklearn defines the multiclass margin against the *largest*
            // competing score, not a sum over all other classes.
            let max_other = (0..n_classes)
                .filter(|&class_idx| class_idx != true_class_idx)
                .map(|class_idx| y_pred[[sample_idx, class_idx]])
                .fold(f64::NEG_INFINITY, f64::max);

            // With a single class there is no competitor, hence no loss.
            let sample_loss = if max_other.is_finite() {
                (1.0 - (true_score - max_other)).max(0.0)
            } else {
                0.0
            };

            total_loss += weight * sample_loss;
            total_weight += weight;
        } else {
            return Err(MetricsError::InvalidInput(format!(
                "Label {} not found in provided labels",
                true_label
            )));
        }
    }

    if total_weight > 0.0 {
        Ok(total_loss / total_weight)
    } else {
        Ok(0.0)
    }
}

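/// Zero-one loss: the (optionally sample-weighted) misclassification rate
/// when `normalize` is true, or the raw weighted error count otherwise.
///
/// A minimal usage sketch (`ignore`d; the enclosing crate paths are assumed):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let y_true = Array1::from_vec(vec![0, 1, 0, 1]);
/// let y_pred = Array1::from_vec(vec![0, 1, 1, 0]);
/// assert_eq!(zero_one_loss_sklearn(&y_true, &y_pred, true, None).unwrap(), 0.5);
/// assert_eq!(zero_one_loss_sklearn(&y_true, &y_pred, false, None).unwrap(), 2.0);
/// ```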
#[allow(dead_code)]
pub fn zero_one_loss_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    normalize: bool,
    sample_weight: Option<&Array1<f64>>,
) -> Result<f64> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    if let Some(sw) = sample_weight {
        if sw.len() != y_true.len() {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match y_true and y_pred length".to_string(),
            ));
        }
    }

    let mut total_errors = 0.0;
    let mut total_weight = 0.0;

    for (idx, (&true_val, &pred_val)) in y_true.iter().zip(y_pred.iter()).enumerate() {
        let weight = if let Some(sw) = sample_weight {
            sw[idx]
        } else {
            1.0
        };

        if true_val != pred_val {
            total_errors += weight;
        }
        total_weight += weight;
    }

    if normalize {
        if total_weight > 0.0 {
            Ok(total_errors / total_weight)
        } else {
            Ok(0.0)
        }
    } else {
        Ok(total_errors)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classification_report_sklearn() {
        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        let report = classification_report_sklearn(&y_true, &y_pred, None, None, 2, 0.0).unwrap();

        assert!(report.accuracy >= 0.0 && report.accuracy <= 1.0);
        assert!(report.precision.len() == 3);
        assert!(report.recall.len() == 3);
        assert!(report.f1_score.len() == 3);
    }

    #[test]
    fn test_precision_recall_fscore_support_sklearn() {
        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        let (precision, recall, fscore, support) = precision_recall_fscore_support_sklearn(
            &y_true,
            &y_pred,
            1.0,
            None,
            None,
            Some("macro"),
            None,
            0.0,
        )
        .unwrap();

        assert_eq!(precision.len(), 1);
        assert_eq!(recall.len(), 1);
        assert_eq!(fscore.len(), 1);
        assert_eq!(support.len(), 1);
    }

    #[test]
    fn test_cohen_kappa_score_sklearn() {
        let y1 = Array1::from_vec(vec![0, 1, 0, 1]);
        let y2 = Array1::from_vec(vec![0, 1, 0, 1]);

        let kappa = cohen_kappa_score_sklearn(&y1, &y2, None, None, None).unwrap();
        assert!((kappa - 1.0).abs() < 1e-10);

        let y3 = Array1::from_vec(vec![0, 1, 1, 0]);
        let kappa2 = cohen_kappa_score_sklearn(&y1, &y3, None, None, None).unwrap();
        assert!(kappa2 < 1.0);
    }

    #[test]
    fn test_zero_one_loss_sklearn() {
        let y_true = Array1::from_vec(vec![0, 1, 0, 1]);
        let y_pred = Array1::from_vec(vec![0, 1, 1, 0]);

        let loss_normalized = zero_one_loss_sklearn(&y_true, &y_pred, true, None).unwrap();
        assert!((loss_normalized - 0.5).abs() < 1e-10);

        let loss_count = zero_one_loss_sklearn(&y_true, &y_pred, false, None).unwrap();
        assert!((loss_count - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_multilabel_confusion_matrix_sklearn() {
        let y_true =
            Array2::from_shape_vec((4, 3), vec![1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1]).unwrap();

        let y_pred =
            Array2::from_shape_vec((4, 3), vec![1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]).unwrap();

        let confusion_matrices =
            multilabel_confusion_matrix_sklearn(&y_true, &y_pred, None, None).unwrap();

        // Three labels, each contributing a stacked 2x2 block.
        assert_eq!(confusion_matrices.shape(), [6, 2]);
    }
}