// sklears_utils/multiclass.rs

1use crate::{UtilsError, UtilsResult};
2use scirs2_core::ndarray::{Array1, Array2, Axis};
3use std::collections::{HashMap, HashSet};
4
/// Strategy used to decompose a multi-class problem into binary sub-problems.
///
/// `Eq` and `Hash` are derived so the strategy can be used as a map/set key; the type is a
/// fieldless enum, so both are trivially consistent with `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiClassStrategy {
    /// One binary sub-problem per class (that class vs. all others), as produced by
    /// `one_vs_rest_transform`.
    OneVsRest,
    /// One binary sub-problem per pair of classes, as enumerated by `one_vs_one_pairs`.
    OneVsOne,
}

11pub struct OneVsRestClassifier<C> {
12    pub estimators: Vec<C>,
13    pub classes: Vec<i32>,
14    pub strategy: MultiClassStrategy,
15}
16
17impl<C> Default for OneVsRestClassifier<C> {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl<C> OneVsRestClassifier<C> {
24    pub fn new() -> Self {
25        Self {
26            estimators: Vec::new(),
27            classes: Vec::new(),
28            strategy: MultiClassStrategy::OneVsRest,
29        }
30    }
31}
32
33pub fn type_of_target(y: &Array1<i32>) -> UtilsResult<String> {
34    if y.is_empty() {
35        return Err(UtilsError::EmptyInput);
36    }
37
38    let unique_values: HashSet<i32> = y.iter().copied().collect();
39    let n_unique = unique_values.len();
40
41    if n_unique == 1 {
42        Ok("unknown".to_string())
43    } else if n_unique == 2 {
44        Ok("binary".to_string())
45    } else {
46        Ok("multiclass".to_string())
47    }
48}
49
50pub fn check_classification_targets(y: &Array1<i32>) -> UtilsResult<()> {
51    let target_type = type_of_target(y)?;
52
53    match target_type.as_str() {
54        "binary" | "multiclass" => Ok(()),
55        "unknown" => Err(UtilsError::InvalidParameter(
56            "Unknown label type: all samples have the same label".to_string(),
57        )),
58        _ => Err(UtilsError::InvalidParameter(format!(
59            "Unknown target type: {target_type}"
60        ))),
61    }
62}
63
64pub fn unique_labels_multiclass(y: &Array1<i32>) -> Vec<i32> {
65    let mut unique: Vec<i32> = y.iter().copied().collect();
66    unique.sort();
67    unique.dedup();
68    unique
69}
70
71pub fn class_distribution(y: &Array1<i32>) -> HashMap<i32, usize> {
72    let mut counts = HashMap::new();
73    for &label in y.iter() {
74        *counts.entry(label).or_insert(0) += 1;
75    }
76    counts
77}
78
79pub fn check_multi_class(y: &Array1<i32>) -> UtilsResult<bool> {
80    let unique_labels = unique_labels_multiclass(y);
81
82    if unique_labels.len() < 2 {
83        Err(UtilsError::InvalidParameter(
84            "Need at least 2 classes for classification".to_string(),
85        ))
86    } else {
87        Ok(unique_labels.len() > 2)
88    }
89}
90
91pub fn one_vs_rest_transform(y: &Array1<i32>, positive_class: i32) -> Array1<i32> {
92    y.mapv(|label| if label == positive_class { 1 } else { 0 })
93}
94
/// Lists every pair `(classes[i], classes[j])` with `i < j`.
///
/// Pairs are ordered by `i` first, then by `j`. Duplicate values in `classes` are not
/// filtered out, so they can produce pairs like `(c, c)`. Fewer than two classes yields an
/// empty vector.
pub fn one_vs_one_pairs(classes: &[i32]) -> Vec<(i32, i32)> {
    classes
        .iter()
        .enumerate()
        .flat_map(|(idx, &first)| {
            // `idx < classes.len()`, so `idx + 1..` is always a valid (possibly empty) range.
            classes[idx + 1..].iter().map(move |&second| (first, second))
        })
        .collect()
}

107pub fn one_vs_one_transform(
108    y: &Array1<i32>,
109    class_a: i32,
110    class_b: i32,
111) -> (Array1<i32>, Vec<usize>) {
112    let mut new_y = Vec::new();
113    let mut indices = Vec::new();
114
115    for (i, &label) in y.iter().enumerate() {
116        if label == class_a || label == class_b {
117            new_y.push(if label == class_a { 0 } else { 1 });
118            indices.push(i);
119        }
120    }
121
122    (Array1::from_vec(new_y), indices)
123}
124
125pub fn is_multilabel(y: &Array2<i32>) -> bool {
126    y.ncols() > 1
127}
128
129pub fn multilabel_to_indicator(y: &Array1<i32>, classes: &[i32]) -> UtilsResult<Array2<i32>> {
130    let n_samples = y.len();
131    let n_classes = classes.len();
132    let mut indicator = Array2::zeros((n_samples, n_classes));
133
134    for (i, &label) in y.iter().enumerate() {
135        if let Some(class_idx) = classes.iter().position(|&c| c == label) {
136            indicator[[i, class_idx]] = 1;
137        } else {
138            return Err(UtilsError::InvalidParameter(format!(
139                "Label {label} not found in classes"
140            )));
141        }
142    }
143
144    Ok(indicator)
145}
146
147pub fn indicator_to_multilabel(
148    indicator: &Array2<i32>,
149    classes: &[i32],
150) -> UtilsResult<Vec<Vec<i32>>> {
151    if indicator.ncols() != classes.len() {
152        return Err(UtilsError::ShapeMismatch {
153            expected: vec![indicator.nrows(), classes.len()],
154            actual: vec![indicator.nrows(), indicator.ncols()],
155        });
156    }
157
158    let mut result = Vec::new();
159
160    for row in indicator.axis_iter(Axis(0)) {
161        let mut labels = Vec::new();
162        for (j, &value) in row.iter().enumerate() {
163            if value == 1 {
164                labels.push(classes[j]);
165            }
166        }
167        result.push(labels);
168    }
169
170    Ok(result)
171}
172
173pub fn check_binary_indicators_multioutput(y: &Array2<i32>) -> UtilsResult<()> {
174    for value in y.iter() {
175        if *value != 0 && *value != 1 {
176            return Err(UtilsError::InvalidParameter(
177                "Binary indicators must contain only 0 and 1".to_string(),
178            ));
179        }
180    }
181    Ok(())
182}
183
184pub fn compute_class_weight_balanced(y: &Array1<i32>) -> HashMap<i32, f64> {
185    let class_counts = class_distribution(y);
186    let n_samples = y.len() as f64;
187    let n_classes = class_counts.len() as f64;
188
189    let mut weights = HashMap::new();
190    for (&class, &count) in &class_counts {
191        weights.insert(class, n_samples / (n_classes * count as f64));
192    }
193
194    weights
195}
196
197pub fn compute_sample_weight(y: &Array1<i32>, class_weight: &HashMap<i32, f64>) -> Array1<f64> {
198    let mut sample_weights = Array1::zeros(y.len());
199
200    for (i, &label) in y.iter().enumerate() {
201        sample_weights[i] = class_weight.get(&label).copied().unwrap_or(1.0);
202    }
203
204    sample_weights
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_type_of_target() {
        let binary = array![0, 1, 0, 1];
        assert_eq!(type_of_target(&binary).unwrap(), "binary");

        let multiclass = array![0, 1, 2, 0, 1, 2];
        assert_eq!(type_of_target(&multiclass).unwrap(), "multiclass");

        let constant = array![1, 1, 1, 1];
        assert_eq!(type_of_target(&constant).unwrap(), "unknown");
    }

    #[test]
    fn test_type_of_target_empty_input_errors() {
        let empty: Array1<i32> = Array1::from_vec(Vec::new());
        assert!(matches!(type_of_target(&empty), Err(UtilsError::EmptyInput)));
    }

    #[test]
    fn test_check_classification_targets() {
        assert!(check_classification_targets(&array![0, 1, 1]).is_ok());
        assert!(check_classification_targets(&array![0, 1, 2]).is_ok());
        assert!(matches!(
            check_classification_targets(&array![4, 4, 4]),
            Err(UtilsError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_check_multi_class() {
        let binary = array![0, 1, 0, 1];
        assert!(!check_multi_class(&binary).unwrap());

        let multiclass = array![0, 1, 2, 0, 1, 2];
        assert!(check_multi_class(&multiclass).unwrap());
    }

    #[test]
    fn test_check_multi_class_single_class_errors() {
        assert!(matches!(
            check_multi_class(&array![3, 3, 3]),
            Err(UtilsError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_unique_labels_multiclass_sorted_and_deduplicated() {
        assert_eq!(
            unique_labels_multiclass(&array![3, -1, 3, 0, -1]),
            vec![-1, 0, 3]
        );
    }

    #[test]
    fn test_one_vs_rest_transform() {
        let y = array![0, 1, 2, 0, 1, 2];
        let binary_y = one_vs_rest_transform(&y, 1);
        assert_eq!(binary_y, array![0, 1, 0, 0, 1, 0]);
    }

    #[test]
    fn test_one_vs_one_pairs() {
        let classes = vec![0, 1, 2];
        let pairs = one_vs_one_pairs(&classes);
        assert_eq!(pairs, vec![(0, 1), (0, 2), (1, 2)]);
    }

    #[test]
    fn test_one_vs_one_pairs_edge_cases() {
        assert!(one_vs_one_pairs(&[]).is_empty());
        assert!(one_vs_one_pairs(&[4]).is_empty());
    }

    #[test]
    fn test_one_vs_one_transform() {
        let y = array![0, 1, 2, 0, 1, 2];
        let (binary_y, indices) = one_vs_one_transform(&y, 0, 2);
        assert_eq!(binary_y, array![0, 1, 0, 1]);
        assert_eq!(indices, vec![0, 2, 3, 5]);
    }

    #[test]
    fn test_multilabel_to_indicator() {
        let y = array![0, 1, 2];
        let classes = vec![0, 1, 2];
        let indicator = multilabel_to_indicator(&y, &classes).unwrap();

        let expected = Array2::from_shape_vec((3, 3), vec![1, 0, 0, 0, 1, 0, 0, 0, 1]).unwrap();
        assert_eq!(indicator, expected);
    }

    #[test]
    fn test_multilabel_to_indicator_unknown_label_errors() {
        let y = array![0, 5];
        let classes = vec![0, 1];
        assert!(matches!(
            multilabel_to_indicator(&y, &classes),
            Err(UtilsError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_indicator_to_multilabel() {
        let indicator = Array2::from_shape_vec((2, 3), vec![1, 0, 1, 0, 0, 0]).unwrap();
        let labels = indicator_to_multilabel(&indicator, &[10, 20, 30]).unwrap();
        assert_eq!(labels, vec![vec![10, 30], vec![]]);
    }

    #[test]
    fn test_indicator_to_multilabel_shape_mismatch() {
        let indicator = Array2::<i32>::zeros((2, 2));
        assert!(matches!(
            indicator_to_multilabel(&indicator, &[1, 2, 3]),
            Err(UtilsError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_is_multilabel() {
        assert!(is_multilabel(&Array2::<i32>::zeros((3, 2))));
        assert!(!is_multilabel(&Array2::<i32>::zeros((3, 1))));
    }

    #[test]
    fn test_check_binary_indicators_multioutput() {
        let ok = Array2::from_shape_vec((2, 2), vec![0, 1, 1, 0]).unwrap();
        assert!(check_binary_indicators_multioutput(&ok).is_ok());

        let bad = Array2::from_shape_vec((1, 2), vec![0, 2]).unwrap();
        assert!(check_binary_indicators_multioutput(&bad).is_err());
    }

    #[test]
    fn test_compute_class_weight_balanced() {
        let y = array![0, 0, 1, 1, 1, 2]; // Class distribution: 0->2, 1->3, 2->1
        let weights = compute_class_weight_balanced(&y);

        // Expected: n_samples / (n_classes * class_count) = 6 / (3 * count)
        assert!((weights[&0] - 1.0).abs() < 1e-10); // 6 / (3 * 2) = 1.0
        assert!((weights[&1] - 2.0 / 3.0).abs() < 1e-10); // 6 / (3 * 3) = 2/3
        assert!((weights[&2] - 2.0).abs() < 1e-10); // 6 / (3 * 1) = 2.0
    }

    #[test]
    fn test_compute_sample_weight_defaults_missing_classes_to_one() {
        let y = array![0, 1, 2];
        let mut class_weight = HashMap::new();
        class_weight.insert(0, 0.5);
        class_weight.insert(1, 2.0);
        let weights = compute_sample_weight(&y, &class_weight);
        assert_eq!(weights, array![0.5, 2.0, 1.0]);
    }

    #[test]
    fn test_class_distribution() {
        let y = array![0, 1, 0, 2, 1, 1];
        let dist = class_distribution(&y);

        assert_eq!(dist[&0], 2);
        assert_eq!(dist[&1], 3);
        assert_eq!(dist[&2], 1);
    }

    #[test]
    fn test_one_vs_rest_classifier_default() {
        let clf: OneVsRestClassifier<()> = OneVsRestClassifier::default();
        assert!(clf.estimators.is_empty());
        assert!(clf.classes.is_empty());
        assert_eq!(clf.strategy, MultiClassStrategy::OneVsRest);
    }
}