scirs2_metrics/sklearn_compat.rs

//! Scikit-learn compatibility module
//!
//! This module provides implementations of metrics that are equivalent to
//! those found in scikit-learn, ensuring API compatibility and identical
//! results where possible.

use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};

/// Type alias for precision, recall, fscore, support tuple
type PrecisionRecallFscoreSupport = (Array1<f64>, Array1<f64>, Array1<f64>, Array1<usize>);

/// Classification report equivalent to the output of `sklearn.metrics.classification_report`.
///
/// Produced by [`classification_report_sklearn`]. Holds per-class precision,
/// recall, F1-score, and support keyed by label name, together with overall
/// accuracy and macro/weighted averages.
#[derive(Debug, Clone)]
pub struct ClassificationReport {
    pub precision: HashMap<String, f64>,
    pub recall: HashMap<String, f64>,
    pub f1_score: HashMap<String, f64>,
    pub support: HashMap<String, usize>,
    pub accuracy: f64,
    pub macro_avg: ClassificationMetrics,
    pub weighted_avg: ClassificationMetrics,
}

/// Per-class or averaged precision, recall, F1 and support values
#[derive(Debug, Clone)]
pub struct ClassificationMetrics {
    pub precision: f64,
    pub recall: f64,
    pub f1_score: f64,
    pub support: usize,
}

/// Classification results returned by [`ClassificationMetrics::compute`]
#[derive(Debug, Clone)]
pub struct ClassificationResults {
    pub accuracy: f64,
    pub precision_weighted: f64,
    pub recall_weighted: f64,
    pub f1_weighted: f64,
    pub auc_roc: f64,
}

impl ClassificationMetrics {
    pub fn new() -> Self {
        Self {
            precision: 0.0,
            recall: 0.0,
            f1_score: 0.0,
            support: 0,
        }
    }

    /// Compute accuracy, binary precision/recall/F1, and (optionally) ROC AUC.
    ///
    /// If `y_scores` is provided it is expected to contain per-class scores,
    /// with column 1 holding the positive-class score. When no scores are
    /// given, `auc_roc` is reported as 0.0.
    pub fn compute(
        &mut self,
        y_true: scirs2_core::ndarray::ArrayView1<i32>,
        y_pred: scirs2_core::ndarray::ArrayView1<i32>,
        y_scores: Option<scirs2_core::ndarray::Array2<f64>>,
    ) -> Result<ClassificationResults> {
        if y_true.len() != y_pred.len() {
            return Err(MetricsError::InvalidInput(
                "y_true and y_pred must have the same length".to_string(),
            ));
        }

        // Calculate basic metrics
        let accuracy = crate::classification::accuracy_score(&y_true, &y_pred)?;

        // Calculate precision, recall, F1 for binary classification
        let (precision, recall, f1) = self.calculate_binary_metrics(&y_true, &y_pred)?;

        // Calculate AUC if scores are provided
        let auc_roc = if let Some(scores) = y_scores {
            // Convert to the types expected by roc_auc_score; column 1 holds
            // the positive-class scores
            let y_true_u32: Vec<u32> = y_true.iter().map(|&x| x as u32).collect();
            let y_true_u32_array = scirs2_core::ndarray::Array1::from(y_true_u32);
            let scores_f64 = scores.column(1).to_owned();
            crate::classification::roc_auc_score(&y_true_u32_array, &scores_f64)?
        } else {
            0.0
        };

        Ok(ClassificationResults {
            accuracy,
            precision_weighted: precision,
            recall_weighted: recall,
            f1_weighted: f1,
            auc_roc,
        })
    }

    /// Calculate binary classification metrics (labels 0 and 1)
    fn calculate_binary_metrics(
        &self,
        y_true: &scirs2_core::ndarray::ArrayView1<i32>,
        y_pred: &scirs2_core::ndarray::ArrayView1<i32>,
    ) -> Result<(f64, f64, f64)> {
        let mut tp = 0;
        let mut fp = 0;
        let mut tn = 0;
        let mut fn_count = 0;

        for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
            match (true_label, pred_label) {
                (1, 1) => tp += 1,
                (0, 1) => fp += 1,
                (0, 0) => tn += 1,
                (1, 0) => fn_count += 1,
                _ => {} // Labels other than 0/1 are ignored here
            }
        }

        let precision = if tp + fp > 0 {
            tp as f64 / (tp + fp) as f64
        } else {
            0.0
        };

        let recall = if tp + fn_count > 0 {
            tp as f64 / (tp + fn_count) as f64
        } else {
            0.0
        };

        let f1 = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        Ok((precision, recall, f1))
    }
}

impl Default for ClassificationMetrics {
    fn default() -> Self {
        Self::new()
    }
}

/// Equivalent to sklearn.metrics.classification_report
///
/// Build a report of the main classification metrics.
///
/// # Arguments
///
/// * `y_true` - Ground truth (correct) target values
/// * `y_pred` - Estimated targets as returned by a classifier
/// * `labels` - Optional list of label values to include in the report
/// * `target_names` - Optional display names matching the labels (same order)
/// * `digits` - Number of digits for formatting output floating point values (currently unused)
/// * `zero_division` - Value to return when a metric's denominator is zero
///
/// # Returns
///
/// A [`ClassificationReport`] with per-class precision, recall, F1-score and
/// support, plus accuracy and macro/weighted averages.
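///
/// # Example
///
/// A minimal usage sketch (not compiled here; it assumes this module is
/// exposed as `scirs2_metrics::sklearn_compat`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use scirs2_metrics::sklearn_compat::classification_report_sklearn;
///
/// let y_true = Array1::from_vec(vec![0, 1, 1, 0]);
/// let y_pred = Array1::from_vec(vec![0, 1, 0, 0]);
/// // labels/target_names default to the observed labels, 2 digits, zero_division = 0.0
/// let report = classification_report_sklearn(&y_true, &y_pred, None, None, 2, 0.0).unwrap();
/// assert!((report.accuracy - 0.75).abs() < 1e-12); // 3 of 4 predictions match
/// ```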
#[allow(dead_code)]
pub fn classification_report_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    labels: Option<&[i32]>,
    target_names: Option<&[String]>,
    _digits: usize,
    zero_division: f64,
) -> Result<ClassificationReport> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    // Get unique labels
    let unique_labels: Vec<i32> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        let all_labels: HashSet<i32> = y_true.iter().chain(y_pred.iter()).copied().collect();
        let mut sorted_labels: Vec<i32> = all_labels.into_iter().collect();
        sorted_labels.sort();
        sorted_labels
    };

    // Calculate per-class metrics
    let mut precision_map = HashMap::new();
    let mut recall_map = HashMap::new();
    let mut f1_map = HashMap::new();
    let mut support_map = HashMap::new();

    for &label in &unique_labels {
        let (precision, recall, f1, support) =
            calculate_class_metrics(y_true, y_pred, label, zero_division)?;

        let label_name = if let Some(names) = target_names {
            if let Some(pos) = unique_labels.iter().position(|&x| x == label) {
                if pos < names.len() {
                    names[pos].clone()
                } else {
                    label.to_string()
                }
            } else {
                label.to_string()
            }
        } else {
            label.to_string()
        };

        precision_map.insert(label_name.clone(), precision);
        recall_map.insert(label_name.clone(), recall);
        f1_map.insert(label_name.clone(), f1);
        support_map.insert(label_name, support);
    }

    // Calculate accuracy
    let accuracy = accuracy_score_sklearn(y_true, y_pred)?;

    // Calculate macro averages
    let macro_precision = precision_map.values().sum::<f64>() / precision_map.len() as f64;
    let macro_recall = recall_map.values().sum::<f64>() / recall_map.len() as f64;
    let macro_f1 = f1_map.values().sum::<f64>() / f1_map.len() as f64;
    let macro_support = support_map.values().sum::<usize>();

    // Calculate weighted (support-weighted) averages. Look each support up by
    // label name; zipping two HashMap iterators would not align the same labels.
    let total_support = support_map.values().sum::<usize>() as f64;
    let weighted_precision = precision_map
        .iter()
        .map(|(label, &p)| p * support_map[label] as f64)
        .sum::<f64>()
        / total_support;

    let weighted_recall = recall_map
        .iter()
        .map(|(label, &r)| r * support_map[label] as f64)
        .sum::<f64>()
        / total_support;

    let weighted_f1 = f1_map
        .iter()
        .map(|(label, &f)| f * support_map[label] as f64)
        .sum::<f64>()
        / total_support;

    Ok(ClassificationReport {
        precision: precision_map,
        recall: recall_map,
        f1_score: f1_map,
        support: support_map,
        accuracy,
        macro_avg: ClassificationMetrics {
            precision: macro_precision,
            recall: macro_recall,
            f1_score: macro_f1,
            support: macro_support,
        },
        weighted_avg: ClassificationMetrics {
            precision: weighted_precision,
            recall: weighted_recall,
            f1_score: weighted_f1,
            support: macro_support,
        },
    })
}

/// Calculate metrics for a specific class
#[allow(dead_code)]
fn calculate_class_metrics(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    target_class: i32,
    zero_division: f64,
) -> Result<(f64, f64, f64, usize)> {
    let mut tp = 0;
    let mut fp = 0;
    let mut fn_count = 0;
    let mut support = 0;

    for (&true_val, &pred_val) in y_true.iter().zip(y_pred.iter()) {
        if true_val == target_class {
            support += 1;
            if pred_val == target_class {
                tp += 1;
            } else {
                fn_count += 1;
            }
        } else if pred_val == target_class {
            fp += 1;
        }
    }

    let precision = if tp + fp > 0 {
        tp as f64 / (tp + fp) as f64
    } else {
        zero_division
    };

    let recall = if tp + fn_count > 0 {
        tp as f64 / (tp + fn_count) as f64
    } else {
        zero_division
    };

    let f1 = if precision + recall > 0.0 {
        2.0 * precision * recall / (precision + recall)
    } else {
        zero_division
    };

    Ok((precision, recall, f1, support))
}

/// Equivalent to sklearn.metrics.accuracy_score
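///
/// # Example
///
/// Illustrative sketch (not compiled here; assumes this module is exposed as
/// `scirs2_metrics::sklearn_compat`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use scirs2_metrics::sklearn_compat::accuracy_score_sklearn;
///
/// let y_true = Array1::from_vec(vec![0, 1, 2, 2]);
/// let y_pred = Array1::from_vec(vec![0, 1, 1, 2]);
/// // 3 of 4 predictions match
/// assert!((accuracy_score_sklearn(&y_true, &y_pred).unwrap() - 0.75).abs() < 1e-12);
/// ```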
#[allow(dead_code)]
pub fn accuracy_score_sklearn(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> Result<f64> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    let correct = y_true
        .iter()
        .zip(y_pred.iter())
        .filter(|(&true_val, &pred_val)| true_val == pred_val)
        .count();

    Ok(correct as f64 / y_true.len() as f64)
}

/// Equivalent to sklearn.metrics.precision_recall_fscore_support
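///
/// # Example
///
/// Illustrative sketch (not compiled here; assumes this module is exposed as
/// `scirs2_metrics::sklearn_compat`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use scirs2_metrics::sklearn_compat::precision_recall_fscore_support_sklearn;
///
/// let y_true = Array1::from_vec(vec![0, 1, 1, 0]);
/// let y_pred = Array1::from_vec(vec![0, 1, 0, 0]);
/// // Pass average = None to get one entry per label
/// let (precision, recall, fscore, support) = precision_recall_fscore_support_sklearn(
///     &y_true, &y_pred, 1.0, None, None, None, None, 0.0,
/// ).unwrap();
/// assert_eq!(precision.len(), 2); // classes 0 and 1
/// assert_eq!(support.to_vec(), vec![2, 2]);
/// ```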
#[allow(dead_code)]
pub fn precision_recall_fscore_support_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    beta: f64,
    labels: Option<&[i32]>,
    _pos_label: Option<i32>,
    average: Option<&str>,
    _warn_for: Option<&[&str]>,
    zero_division: f64,
) -> Result<PrecisionRecallFscoreSupport> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    // Determine labels to use
    let target_labels: Vec<i32> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        let all_labels: HashSet<i32> = y_true.iter().chain(y_pred.iter()).copied().collect();
        let mut sorted_labels: Vec<i32> = all_labels.into_iter().collect();
        sorted_labels.sort();
        sorted_labels
    };

    let mut precisions = Vec::new();
    let mut recalls = Vec::new();
    let mut fscores = Vec::new();
    let mut supports = Vec::new();

    for &label in &target_labels {
        let (precision, recall, _f1, support) =
            calculate_class_metrics(y_true, y_pred, label, zero_division)?;

        // Calculate F-beta score
        let fbeta = if precision + recall > 0.0 {
            (1.0 + beta * beta) * precision * recall / (beta * beta * precision + recall)
        } else {
            zero_division
        };

        precisions.push(precision);
        recalls.push(recall);
        fscores.push(fbeta);
        supports.push(support);
    }

    // Handle averaging
    if let Some(avg_type) = average {
        match avg_type {
            "micro" => {
                let (micro_precision, micro_recall, micro_fbeta, total_support) =
                    calculate_micro_average(y_true, y_pred, beta, &target_labels, zero_division)?;
                Ok((
                    Array1::from_vec(vec![micro_precision]),
                    Array1::from_vec(vec![micro_recall]),
                    Array1::from_vec(vec![micro_fbeta]),
                    Array1::from_vec(vec![total_support]),
                ))
            }
            "macro" => {
                let macro_precision = precisions.iter().sum::<f64>() / precisions.len() as f64;
                let macro_recall = recalls.iter().sum::<f64>() / recalls.len() as f64;
                let macro_fbeta = fscores.iter().sum::<f64>() / fscores.len() as f64;
                let total_support = supports.iter().sum::<usize>();
                Ok((
                    Array1::from_vec(vec![macro_precision]),
                    Array1::from_vec(vec![macro_recall]),
                    Array1::from_vec(vec![macro_fbeta]),
                    Array1::from_vec(vec![total_support]),
                ))
            }
            "weighted" => {
                let total_support = supports.iter().sum::<usize>() as f64;
                let weighted_precision = precisions
                    .iter()
                    .zip(supports.iter())
                    .map(|(&p, &s)| p * s as f64)
                    .sum::<f64>()
                    / total_support;
                let weighted_recall = recalls
                    .iter()
                    .zip(supports.iter())
                    .map(|(&r, &s)| r * s as f64)
                    .sum::<f64>()
                    / total_support;
                let weighted_fbeta = fscores
                    .iter()
                    .zip(supports.iter())
                    .map(|(&f, &s)| f * s as f64)
                    .sum::<f64>()
                    / total_support;
                Ok((
                    Array1::from_vec(vec![weighted_precision]),
                    Array1::from_vec(vec![weighted_recall]),
                    Array1::from_vec(vec![weighted_fbeta]),
                    Array1::from_vec(vec![total_support as usize]),
                ))
            }
            _ => Err(MetricsError::InvalidInput(format!(
                "Unsupported average type: {}",
                avg_type
            ))),
        }
    } else {
        Ok((
            Array1::from_vec(precisions),
            Array1::from_vec(recalls),
            Array1::from_vec(fscores),
            Array1::from_vec(supports),
        ))
    }
}

/// Calculate micro-averaged metrics by pooling TP/FP/FN counts across classes
#[allow(dead_code)]
fn calculate_micro_average(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    beta: f64,
    labels: &[i32],
    zero_division: f64,
) -> Result<(f64, f64, f64, usize)> {
    let mut total_tp = 0;
    let mut total_fp = 0;
    let mut total_fn = 0;
    let mut total_support = 0;

    for &label in labels {
        let mut tp = 0;
        let mut fp = 0;
        let mut fn_count = 0;

        for (&true_val, &pred_val) in y_true.iter().zip(y_pred.iter()) {
            if true_val == label {
                total_support += 1;
                if pred_val == label {
                    tp += 1;
                } else {
                    fn_count += 1;
                }
            } else if pred_val == label {
                fp += 1;
            }
        }

        total_tp += tp;
        total_fp += fp;
        total_fn += fn_count;
    }

    let micro_precision = if total_tp + total_fp > 0 {
        total_tp as f64 / (total_tp + total_fp) as f64
    } else {
        zero_division
    };

    let micro_recall = if total_tp + total_fn > 0 {
        total_tp as f64 / (total_tp + total_fn) as f64
    } else {
        zero_division
    };

    let micro_fbeta = if micro_precision + micro_recall > 0.0 {
        (1.0 + beta * beta) * micro_precision * micro_recall
            / (beta * beta * micro_precision + micro_recall)
    } else {
        zero_division
    };

    Ok((micro_precision, micro_recall, micro_fbeta, total_support))
}

/// Equivalent to sklearn.metrics.multilabel_confusion_matrix
///
/// Returns the per-label 2x2 confusion matrices stacked vertically: the row
/// pair `[[tn, fp], [fn, tp]]` for each label, i.e. an array of shape
/// `(n_labels * 2, 2)` rather than sklearn's `(n_labels, 2, 2)`.
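///
/// # Example
///
/// Illustrative sketch (not compiled here; assumes this module is exposed as
/// `scirs2_metrics::sklearn_compat`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use scirs2_metrics::sklearn_compat::multilabel_confusion_matrix_sklearn;
///
/// let y_true = Array2::from_shape_vec((2, 2), vec![1, 0, 0, 1]).unwrap();
/// let y_pred = Array2::from_shape_vec((2, 2), vec![1, 0, 1, 1]).unwrap();
/// let cms = multilabel_confusion_matrix_sklearn(&y_true, &y_pred, None, None).unwrap();
/// assert_eq!(cms.shape(), [4, 2]); // 2 labels, one stacked 2x2 block each
/// assert_eq!(cms[[1, 1]], 1);      // tp for label 0 (rows 0..2 are its block)
/// assert_eq!(cms[[0, 1]], 1);      // fp for label 0
/// ```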
#[allow(dead_code)]
pub fn multilabel_confusion_matrix_sklearn(
    y_true: &Array2<i32>,
    y_pred: &Array2<i32>,
    sample_weight: Option<&Array1<f64>>,
    labels: Option<&[usize]>,
) -> Result<Array2<i32>> {
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same shape".to_string(),
        ));
    }

    let (n_samples, n_labels) = y_true.dim();

    if let Some(weights) = sample_weight {
        if weights.len() != n_samples {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match number of samples".to_string(),
            ));
        }
    }

    let target_labels: Vec<usize> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        (0..n_labels).collect()
    };

    let mut confusion_matrices = Array2::zeros((target_labels.len() * 2, 2));

    for (label_idx, &label) in target_labels.iter().enumerate() {
        if label >= n_labels {
            return Err(MetricsError::InvalidInput(format!(
                "Label {} is out of bounds for {} labels",
                label, n_labels
            )));
        }

        let mut tp = 0;
        let mut fp = 0;
        let mut tn = 0;
        let mut fn_count = 0;

        for sample_idx in 0..n_samples {
            let true_val = y_true[[sample_idx, label]];
            let pred_val = y_pred[[sample_idx, label]];

            // Note: the output matrix holds integer counts, so fractional
            // sample weights are truncated here.
            let weight = if let Some(weights) = sample_weight {
                weights[sample_idx] as i32
            } else {
                1
            };

            match (true_val, pred_val) {
                (1, 1) => tp += weight,
                (0, 1) => fp += weight,
                (0, 0) => tn += weight,
                (1, 0) => fn_count += weight,
                _ => {
                    return Err(MetricsError::InvalidInput(
                        "Labels must be 0 or 1 for multilabel classification".to_string(),
                    ))
                }
            }
        }

        let base_idx = label_idx * 2;
        confusion_matrices[[base_idx, 0]] = tn;
        confusion_matrices[[base_idx, 1]] = fp;
        confusion_matrices[[base_idx + 1, 0]] = fn_count;
        confusion_matrices[[base_idx + 1, 1]] = tp;
    }

    Ok(confusion_matrices)
}

/// Equivalent to sklearn.metrics.cohen_kappa_score
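///
/// # Example
///
/// Illustrative sketch (not compiled here; assumes this module is exposed as
/// `scirs2_metrics::sklearn_compat`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use scirs2_metrics::sklearn_compat::cohen_kappa_score_sklearn;
///
/// let y1 = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
/// let y2 = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
/// // Identical ratings give perfect agreement
/// let kappa = cohen_kappa_score_sklearn(&y1, &y2, None, None, None).unwrap();
/// assert!((kappa - 1.0).abs() < 1e-12);
///
/// // Weighted variants penalize disagreements by distance between labels
/// let y3 = Array1::from_vec(vec![0, 1, 1, 0, 2, 2]);
/// let linear = cohen_kappa_score_sklearn(&y1, &y3, None, Some("linear"), None).unwrap();
/// assert!(linear < 1.0);
/// ```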
#[allow(dead_code)]
pub fn cohen_kappa_score_sklearn(
    y1: &Array1<i32>,
    y2: &Array1<i32>,
    labels: Option<&[i32]>,
    weights: Option<&str>,
    sample_weight: Option<&Array1<f64>>,
) -> Result<f64> {
    if y1.len() != y2.len() {
        return Err(MetricsError::InvalidInput(
            "y1 and y2 must have the same length".to_string(),
        ));
    }

    if let Some(sw) = sample_weight {
        if sw.len() != y1.len() {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match y1 and y2 length".to_string(),
            ));
        }
    }

    // Determine unique labels
    let unique_labels: Vec<i32> = if let Some(labels) = labels {
        labels.to_vec()
    } else {
        let all_labels: HashSet<i32> = y1.iter().chain(y2.iter()).copied().collect();
        let mut sorted_labels: Vec<i32> = all_labels.into_iter().collect();
        sorted_labels.sort();
        sorted_labels
    };

    let n_labels = unique_labels.len();

    // Create confusion matrix
    let mut confusion_matrix = Array2::zeros((n_labels, n_labels));
    let mut total_weight = 0.0;

    for (idx, (&true_val, &pred_val)) in y1.iter().zip(y2.iter()).enumerate() {
        let weight = if let Some(sw) = sample_weight {
            sw[idx]
        } else {
            1.0
        };

        if let (Some(true_idx), Some(pred_idx)) = (
            unique_labels.iter().position(|&x| x == true_val),
            unique_labels.iter().position(|&x| x == pred_val),
        ) {
            confusion_matrix[[true_idx, pred_idx]] += weight;
            total_weight += weight;
        }
    }

    // Normalize confusion matrix to joint probabilities
    if total_weight > 0.0 {
        confusion_matrix /= total_weight;
    }

    // Agreement weights w_ij. Weighted kappa uses
    //   po = sum_ij w_ij * p_ij   and   pe = sum_ij w_ij * r_i * c_j,
    // where r_i and c_j are the row and column marginals. With identity
    // weights this reduces to the standard (unweighted) Cohen's kappa.
    let denom = n_labels.saturating_sub(1).max(1) as f64;
    let weight_ij: Box<dyn Fn(usize, usize) -> f64> = match weights {
        None => Box::new(|i, j| if i == j { 1.0 } else { 0.0 }),
        Some("linear") => {
            // Linear weights: w_ij = 1 - |i - j| / (n_labels - 1)
            Box::new(move |i, j| 1.0 - (i as f64 - j as f64).abs() / denom)
        }
        Some("quadratic") => {
            // Quadratic weights: w_ij = 1 - ((i - j) / (n_labels - 1))^2
            Box::new(move |i, j| {
                let diff = (i as f64 - j as f64) / denom;
                1.0 - diff * diff
            })
        }
        _ => {
            return Err(MetricsError::InvalidInput(
                "weights must be None, 'linear', or 'quadratic'".to_string(),
            ))
        }
    };

    let row_sums: Vec<f64> = (0..n_labels).map(|i| confusion_matrix.row(i).sum()).collect();
    let col_sums: Vec<f64> = (0..n_labels).map(|j| confusion_matrix.column(j).sum()).collect();

    // Observed and expected (chance) agreement
    let mut po = 0.0;
    let mut pe = 0.0;
    for i in 0..n_labels {
        for j in 0..n_labels {
            let w = weight_ij(i, j);
            po += w * confusion_matrix[[i, j]];
            pe += w * row_sums[i] * col_sums[j];
        }
    }

    // Calculate kappa
    if (1.0 - pe).abs() < 1e-15 {
        Ok(1.0) // Expected agreement is already perfect
    } else {
        Ok((po - pe) / (1.0 - pe))
    }
}

/// Equivalent to sklearn.metrics.hinge_loss
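///
/// # Example
///
/// Illustrative sketch (not compiled here; assumes this module is exposed as
/// `scirs2_metrics::sklearn_compat`). `y_pred` holds one decision score per
/// class and column:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
/// use scirs2_metrics::sklearn_compat::hinge_loss_sklearn;
///
/// let y_true = Array1::from_vec(vec![0, 1]);
/// // Scores for classes [0, 1]; each true class wins by a margin of at least 1
/// let y_pred = Array2::from_shape_vec((2, 2), vec![2.0, 0.0, -1.0, 1.5]).unwrap();
/// let loss = hinge_loss_sklearn(&y_true, &y_pred, None, None).unwrap();
/// assert!(loss.abs() < 1e-12); // all margins >= 1, so the hinge loss is 0
/// ```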
#[allow(dead_code)]
pub fn hinge_loss_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array2<f64>,
    labels: Option<&[i32]>,
    sample_weight: Option<&Array1<f64>>,
) -> Result<f64> {
    let (n_samples, n_classes) = y_pred.dim();

    if y_true.len() != n_samples {
        return Err(MetricsError::InvalidInput(
            "y_true length must match number of samples in y_pred".to_string(),
        ));
    }

    if let Some(sw) = sample_weight {
        if sw.len() != n_samples {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match number of samples".to_string(),
            ));
        }
    }

    // Determine class labels
    let class_labels: Vec<i32> = if let Some(labels) = labels {
        if labels.len() != n_classes {
            return Err(MetricsError::InvalidInput(
                "labels length must match number of classes in y_pred".to_string(),
            ));
        }
        labels.to_vec()
    } else {
        let unique_labels: HashSet<i32> = y_true.iter().copied().collect();
        let mut sorted_labels: Vec<i32> = unique_labels.into_iter().collect();
        sorted_labels.sort();
        if sorted_labels.len() != n_classes {
            return Err(MetricsError::InvalidInput(
                "Number of unique labels in y_true must match number of classes in y_pred"
                    .to_string(),
            ));
        }
        sorted_labels
    };

    let mut total_loss = 0.0;
    let mut total_weight = 0.0;

    for (sample_idx, &true_label) in y_true.iter().enumerate() {
        let weight = if let Some(sw) = sample_weight {
            sw[sample_idx]
        } else {
            1.0
        };

        // Find the index of the true label
        if let Some(true_class_idx) = class_labels.iter().position(|&x| x == true_label) {
            let true_score = y_pred[[sample_idx, true_class_idx]];

            // Multiclass hinge loss (Crammer-Singer, as in sklearn): only the
            // highest-scoring incorrect class contributes:
            //   loss = max(0, 1 - (s_true - max_{j != true} s_j))
            let mut max_other = f64::NEG_INFINITY;
            for class_idx in 0..class_labels.len() {
                if class_idx != true_class_idx {
                    max_other = max_other.max(y_pred[[sample_idx, class_idx]]);
                }
            }
            let sample_loss = if max_other.is_finite() {
                (1.0 - (true_score - max_other)).max(0.0)
            } else {
                0.0 // only one class: no competing score
            };

            total_loss += weight * sample_loss;
            total_weight += weight;
        } else {
            return Err(MetricsError::InvalidInput(format!(
                "Label {} not found in provided labels",
                true_label
            )));
        }
    }

    if total_weight > 0.0 {
        Ok(total_loss / total_weight)
    } else {
        Ok(0.0)
    }
}

/// Equivalent to sklearn.metrics.zero_one_loss
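///
/// # Example
///
/// Illustrative sketch (not compiled here; assumes this module is exposed as
/// `scirs2_metrics::sklearn_compat`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
/// use scirs2_metrics::sklearn_compat::zero_one_loss_sklearn;
///
/// let y_true = Array1::from_vec(vec![0, 1, 1, 0]);
/// let y_pred = Array1::from_vec(vec![0, 1, 0, 1]);
/// // normalize = true gives the error rate, false gives the raw error count
/// let rate = zero_one_loss_sklearn(&y_true, &y_pred, true, None).unwrap();
/// let count = zero_one_loss_sklearn(&y_true, &y_pred, false, None).unwrap();
/// assert!((rate - 0.5).abs() < 1e-12);
/// assert!((count - 2.0).abs() < 1e-12);
/// ```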
#[allow(dead_code)]
pub fn zero_one_loss_sklearn(
    y_true: &Array1<i32>,
    y_pred: &Array1<i32>,
    normalize: bool,
    sample_weight: Option<&Array1<f64>>,
) -> Result<f64> {
    if y_true.len() != y_pred.len() {
        return Err(MetricsError::InvalidInput(
            "y_true and y_pred must have the same length".to_string(),
        ));
    }

    if let Some(sw) = sample_weight {
        if sw.len() != y_true.len() {
            return Err(MetricsError::InvalidInput(
                "sample_weight length must match y_true and y_pred length".to_string(),
            ));
        }
    }

    let mut total_errors = 0.0;
    let mut total_weight = 0.0;

    for (idx, (&true_val, &pred_val)) in y_true.iter().zip(y_pred.iter()).enumerate() {
        let weight = if let Some(sw) = sample_weight {
            sw[idx]
        } else {
            1.0
        };

        if true_val != pred_val {
            total_errors += weight;
        }
        total_weight += weight;
    }

    if normalize {
        if total_weight > 0.0 {
            Ok(total_errors / total_weight)
        } else {
            Ok(0.0)
        }
    } else {
        Ok(total_errors)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classification_report_sklearn() {
        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        let report = classification_report_sklearn(&y_true, &y_pred, None, None, 2, 0.0).unwrap();

        assert!(report.accuracy >= 0.0 && report.accuracy <= 1.0);
        assert!(report.precision.len() == 3);
        assert!(report.recall.len() == 3);
        assert!(report.f1_score.len() == 3);
    }

    #[test]
    fn test_precision_recall_fscore_support_sklearn() {
        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        let (precision, recall, fscore, support) = precision_recall_fscore_support_sklearn(
            &y_true,
            &y_pred,
            1.0,
            None,
            None,
            Some("macro"),
            None,
            0.0,
        )
        .unwrap();

        assert_eq!(precision.len(), 1);
        assert_eq!(recall.len(), 1);
        assert_eq!(fscore.len(), 1);
        assert_eq!(support.len(), 1);
    }

    #[test]
    fn test_cohen_kappa_score_sklearn() {
        let y1 = Array1::from_vec(vec![0, 1, 0, 1]);
        let y2 = Array1::from_vec(vec![0, 1, 0, 1]);

        let kappa = cohen_kappa_score_sklearn(&y1, &y2, None, None, None).unwrap();
        assert!((kappa - 1.0).abs() < 1e-10); // Perfect agreement

        let y3 = Array1::from_vec(vec![0, 1, 1, 0]);
        let kappa2 = cohen_kappa_score_sklearn(&y1, &y3, None, None, None).unwrap();
        assert!(kappa2 < 1.0); // Less than perfect agreement
    }

    #[test]
    fn test_zero_one_loss_sklearn() {
        let y_true = Array1::from_vec(vec![0, 1, 0, 1]);
        let y_pred = Array1::from_vec(vec![0, 1, 1, 0]);

        let loss_normalized = zero_one_loss_sklearn(&y_true, &y_pred, true, None).unwrap();
        assert!((loss_normalized - 0.5).abs() < 1e-10); // 2 errors out of 4

        let loss_count = zero_one_loss_sklearn(&y_true, &y_pred, false, None).unwrap();
        assert!((loss_count - 2.0).abs() < 1e-10); // 2 errors
    }

    #[test]
    fn test_multilabel_confusion_matrix_sklearn() {
        let y_true =
            Array2::from_shape_vec((4, 3), vec![1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1]).unwrap();

        let y_pred =
            Array2::from_shape_vec((4, 3), vec![1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]).unwrap();

        let confusion_matrices =
            multilabel_confusion_matrix_sklearn(&y_true, &y_pred, None, None).unwrap();

        assert_eq!(confusion_matrices.shape(), [6, 2]); // 3 labels * 2 rows each
    }
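
    // A minimal sketch exercising ClassificationMetrics::compute without scores.
    // It assumes crate::classification::accuracy_score behaves as standard
    // accuracy and that auc_roc defaults to 0.0 when no scores are supplied.
    #[test]
    fn test_classification_metrics_compute() {
        let y_true = Array1::from_vec(vec![1, 0, 1, 0]);
        let y_pred = Array1::from_vec(vec![1, 0, 0, 0]);

        let mut metrics = ClassificationMetrics::new();
        let results = metrics
            .compute(y_true.view(), y_pred.view(), None)
            .unwrap();

        assert!((results.accuracy - 0.75).abs() < 1e-10); // 3 of 4 correct
        assert!((results.precision_weighted - 1.0).abs() < 1e-10); // tp=1, fp=0
        assert!((results.recall_weighted - 0.5).abs() < 1e-10); // tp=1, fn=1
        assert!((results.auc_roc - 0.0).abs() < 1e-10); // no scores given
    }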
}