Skip to main content

sklears_preprocessing/
label_binarization.rs

//! Label binarization transformers.
//!
//! This module provides transformers that turn class labels into numeric indicator matrices:
//! - `LabelBinarizer`: one-vs-all encoding for single-label classification
//!   (a single column when exactly two classes are present).
//! - `MultiLabelBinarizer`: 0/1 indicator encoding for multi-label classification.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Estimator, Fit, Trained, Transform, Untrained},
15    types::Float,
16};
17
/// Settings that control how a `LabelBinarizer` encodes labels.
///
/// `neg_label` and `pos_label` are the numeric values written into the output
/// matrix for "not this class" and "this class" respectively.
#[derive(Debug, Clone)]
pub struct LabelBinarizerConfig {
    /// Value written for the negative class / non-matching columns.
    pub neg_label: i32,
    /// Value written for the positive class / matching column.
    pub pos_label: i32,
    /// Request sparse output. Stored only: the transformers in this module
    /// always produce dense matrices.
    pub sparse_output: bool,
}

impl Default for LabelBinarizerConfig {
    /// `neg_label = 0`, `pos_label = 1`, dense output.
    fn default() -> Self {
        let (neg_label, pos_label) = (0, 1);
        Self {
            sparse_output: false,
            pos_label,
            neg_label,
        }
    }
}
38
39/// LabelBinarizer transforms labels to binary form
40pub struct LabelBinarizer<T: Eq + Hash + Clone = i32, State = Untrained> {
41    config: LabelBinarizerConfig,
42    state: PhantomData<State>,
43    classes_: Option<Vec<T>>,
44    class_to_index_: Option<HashMap<T, usize>>,
45}
46
47impl<T: Eq + Hash + Clone> LabelBinarizer<T, Untrained> {
48    /// Create a new LabelBinarizer with default configuration
49    pub fn new() -> Self {
50        Self {
51            config: LabelBinarizerConfig::default(),
52            state: PhantomData,
53            classes_: None,
54            class_to_index_: None,
55        }
56    }
57
58    /// Set the negative label value
59    pub fn neg_label(mut self, neg_label: i32) -> Self {
60        self.config.neg_label = neg_label;
61        self
62    }
63
64    /// Set the positive label value
65    pub fn pos_label(mut self, pos_label: i32) -> Self {
66        self.config.pos_label = pos_label;
67        self
68    }
69}
70
71impl<T: Eq + Hash + Clone> Default for LabelBinarizer<T, Untrained> {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Untrained> {
78    type Config = LabelBinarizerConfig;
79    type Error = SklearsError;
80    type Float = Float;
81
82    fn config(&self) -> &Self::Config {
83        &self.config
84    }
85}
86
87impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Trained> {
88    type Config = LabelBinarizerConfig;
89    type Error = SklearsError;
90    type Float = Float;
91
92    fn config(&self) -> &Self::Config {
93        &self.config
94    }
95}
96
97impl<T: Eq + Hash + Clone + Ord + Send + Sync> Fit<Array1<T>, ()> for LabelBinarizer<T, Untrained> {
98    type Fitted = LabelBinarizer<T, Trained>;
99
100    fn fit(self, y: &Array1<T>, _x: &()) -> Result<Self::Fitted> {
101        // Extract unique classes
102        let mut classes = HashSet::new();
103        for label in y.iter() {
104            classes.insert(label.clone());
105        }
106
107        // Sort classes for consistency
108        let mut sorted_classes: Vec<T> = classes.into_iter().collect();
109        sorted_classes.sort();
110
111        // Create class to index mapping
112        let class_to_index: HashMap<T, usize> = sorted_classes
113            .iter()
114            .enumerate()
115            .map(|(i, c)| (c.clone(), i))
116            .collect();
117
118        Ok(LabelBinarizer {
119            config: self.config,
120            state: PhantomData,
121            classes_: Some(sorted_classes),
122            class_to_index_: Some(class_to_index),
123        })
124    }
125}
126
127impl<T: Eq + Hash + Clone> Transform<Array1<T>, Array2<Float>> for LabelBinarizer<T, Trained> {
128    fn transform(&self, y: &Array1<T>) -> Result<Array2<Float>> {
129        let classes = self.classes_.as_ref().expect("operation should succeed");
130        let class_to_index = self
131            .class_to_index_
132            .as_ref()
133            .expect("operation should succeed");
134        let n_samples = y.len();
135        let n_classes = classes.len();
136
137        if n_classes == 0 {
138            return Err(SklearsError::InvalidInput(
139                "No classes found during fit".to_string(),
140            ));
141        }
142
143        // Special case for binary classification
144        if n_classes == 2 {
145            let mut result = Array2::zeros((n_samples, 1));
146            for (i, label) in y.iter().enumerate() {
147                if let Some(&class_idx) = class_to_index.get(label) {
148                    result[[i, 0]] = if class_idx == 1 {
149                        self.config.pos_label as Float
150                    } else {
151                        self.config.neg_label as Float
152                    };
153                } else {
154                    return Err(SklearsError::InvalidInput(
155                        "Unknown label encountered during transform".to_string(),
156                    ));
157                }
158            }
159            Ok(result)
160        } else {
161            // Multi-class case: one-hot encoding
162            let mut result =
163                Array2::from_elem((n_samples, n_classes), self.config.neg_label as Float);
164            for (i, label) in y.iter().enumerate() {
165                if let Some(&class_idx) = class_to_index.get(label) {
166                    result[[i, class_idx]] = self.config.pos_label as Float;
167                } else {
168                    return Err(SklearsError::InvalidInput(
169                        "Unknown label encountered during transform".to_string(),
170                    ));
171                }
172            }
173            Ok(result)
174        }
175    }
176}
177
178impl<T: Eq + Hash + Clone> LabelBinarizer<T, Trained> {
179    /// Get the classes
180    pub fn classes(&self) -> &Vec<T> {
181        self.classes_.as_ref().expect("operation should succeed")
182    }
183
184    /// Transform binary matrix back to original labels
185    pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Array1<T>> {
186        let classes = self.classes_.as_ref().expect("operation should succeed");
187        let n_samples = y.nrows();
188        let n_classes = classes.len();
189
190        if n_classes == 2 && y.ncols() == 1 {
191            // Binary case
192            let mut result = Vec::with_capacity(n_samples);
193            let threshold = (self.config.neg_label + self.config.pos_label) as Float / 2.0;
194
195            for i in 0..n_samples {
196                let class_idx = if y[[i, 0]] > threshold { 1 } else { 0 };
197                result.push(classes[class_idx].clone());
198            }
199            Ok(Array1::from_vec(result))
200        } else if y.ncols() == n_classes {
201            // Multi-class case
202            let mut result = Vec::with_capacity(n_samples);
203
204            for i in 0..n_samples {
205                // Find the column with the maximum value
206                let row = y.row(i);
207                let mut max_idx = 0;
208                let mut max_val = row[0];
209
210                for j in 1..n_classes {
211                    if row[j] > max_val {
212                        max_val = row[j];
213                        max_idx = j;
214                    }
215                }
216
217                result.push(classes[max_idx].clone());
218            }
219            Ok(Array1::from_vec(result))
220        } else {
221            Err(SklearsError::InvalidInput(format!(
222                "Shape mismatch: y has {} columns but {} classes were expected",
223                y.ncols(),
224                n_classes
225            )))
226        }
227    }
228}
229
/// Settings for a `MultiLabelBinarizer`.
#[derive(Clone, Debug, Default)]
pub struct MultiLabelBinarizerConfig {
    /// Explicit, ordered class list. `None` means the sorted distinct labels
    /// are learned from the training data.
    pub classes: Option<Vec<String>>,
    /// Request sparse output. Stored only: output is always dense in this module.
    pub sparse_output: bool,
}
238
239/// MultiLabelBinarizer transforms between iterable of labels and binary matrix
240pub struct MultiLabelBinarizer<State = Untrained> {
241    config: MultiLabelBinarizerConfig,
242    state: PhantomData<State>,
243    classes_: Option<Vec<String>>,
244    class_to_index_: Option<HashMap<String, usize>>,
245}
246
247impl MultiLabelBinarizer<Untrained> {
248    /// Create a new MultiLabelBinarizer with default configuration
249    pub fn new() -> Self {
250        Self {
251            config: MultiLabelBinarizerConfig::default(),
252            state: PhantomData,
253            classes_: None,
254            class_to_index_: None,
255        }
256    }
257
258    /// Set the classes to use
259    pub fn classes(mut self, classes: Vec<String>) -> Self {
260        self.config.classes = Some(classes);
261        self
262    }
263}
264
265impl Default for MultiLabelBinarizer<Untrained> {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271impl Estimator for MultiLabelBinarizer<Untrained> {
272    type Config = MultiLabelBinarizerConfig;
273    type Error = SklearsError;
274    type Float = Float;
275
276    fn config(&self) -> &Self::Config {
277        &self.config
278    }
279}
280
281impl Estimator for MultiLabelBinarizer<Trained> {
282    type Config = MultiLabelBinarizerConfig;
283    type Error = SklearsError;
284    type Float = Float;
285
286    fn config(&self) -> &Self::Config {
287        &self.config
288    }
289}
290
291impl Fit<Vec<Vec<String>>, ()> for MultiLabelBinarizer<Untrained> {
292    type Fitted = MultiLabelBinarizer<Trained>;
293
294    fn fit(self, y: &Vec<Vec<String>>, _x: &()) -> Result<Self::Fitted> {
295        let classes = if let Some(ref classes) = self.config.classes {
296            classes.clone()
297        } else {
298            // Infer classes from data
299            let mut unique_classes = HashSet::new();
300            for labels in y.iter() {
301                for label in labels.iter() {
302                    unique_classes.insert(label.clone());
303                }
304            }
305
306            let mut sorted_classes: Vec<String> = unique_classes.into_iter().collect();
307            sorted_classes.sort();
308            sorted_classes
309        };
310
311        // Create class to index mapping
312        let class_to_index: HashMap<String, usize> = classes
313            .iter()
314            .enumerate()
315            .map(|(i, c)| (c.clone(), i))
316            .collect();
317
318        Ok(MultiLabelBinarizer {
319            config: self.config,
320            state: PhantomData,
321            classes_: Some(classes),
322            class_to_index_: Some(class_to_index),
323        })
324    }
325}
326
327impl Transform<Vec<Vec<String>>, Array2<Float>> for MultiLabelBinarizer<Trained> {
328    fn transform(&self, y: &Vec<Vec<String>>) -> Result<Array2<Float>> {
329        let classes = self.classes_.as_ref().expect("operation should succeed");
330        let class_to_index = self
331            .class_to_index_
332            .as_ref()
333            .expect("operation should succeed");
334        let n_samples = y.len();
335        let n_classes = classes.len();
336
337        let mut result = Array2::zeros((n_samples, n_classes));
338
339        for (i, labels) in y.iter().enumerate() {
340            for label in labels.iter() {
341                if let Some(&class_idx) = class_to_index.get(label) {
342                    result[[i, class_idx]] = 1.0;
343                }
344                // Ignore unknown labels during transform
345            }
346        }
347
348        Ok(result)
349    }
350}
351
352impl MultiLabelBinarizer<Trained> {
353    /// Get the classes
354    pub fn classes(&self) -> &Vec<String> {
355        self.classes_.as_ref().expect("operation should succeed")
356    }
357
358    /// Transform binary matrix back to multi-label format
359    pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Vec<Vec<String>>> {
360        let classes = self.classes_.as_ref().expect("operation should succeed");
361        let n_samples = y.nrows();
362        let n_classes = classes.len();
363
364        if y.ncols() != n_classes {
365            return Err(SklearsError::InvalidInput(format!(
366                "Shape mismatch: y has {} columns but {} classes were expected",
367                y.ncols(),
368                n_classes
369            )));
370        }
371
372        let mut result = Vec::with_capacity(n_samples);
373
374        for i in 0..n_samples {
375            let mut labels = Vec::new();
376            for j in 0..n_classes {
377                if y[[i, j]] > 0.5 {
378                    labels.push(classes[j].clone());
379                }
380            }
381            result.push(labels);
382        }
383
384        Ok(result)
385    }
386}
387
// Unit tests for both binarizers. They cover encoding, round-trips through
// `inverse_transform`, custom labels, and the documented error paths.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_label_binarizer_binary() {
        let y = array![1, 0, 1, 0, 1];

        let binarizer = LabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");

        // Binary case: should have 1 column
        assert_eq!(y_bin.shape(), &[5, 1]);
        assert_eq!(y_bin[[0, 0]], 1.0);
        assert_eq!(y_bin[[1, 0]], 0.0);
        assert_eq!(y_bin[[2, 0]], 1.0);
    }

    #[test]
    fn test_label_binarizer_multiclass() {
        let y = array![0, 1, 2, 1, 0];

        let binarizer = LabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");

        // Multiclass case: should have 3 columns
        assert_eq!(y_bin.shape(), &[5, 3]);
        // First sample is class 0
        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
        // Second sample is class 1
        assert_eq!(y_bin.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
        // Third sample is class 2
        assert_eq!(y_bin.row(2).to_vec(), vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_label_binarizer_inverse_transform() {
        let y = array!["cat", "dog", "cat", "bird", "dog"];

        let binarizer = LabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");
        let y_inv = binarizer
            .inverse_transform(&y_bin)
            .expect("operation should succeed");

        assert_eq!(y, y_inv);
    }

    #[test]
    fn test_label_binarizer_custom_labels() {
        let y = array![1, 0, 1, 0];

        let binarizer = LabelBinarizer::new()
            .neg_label(-1)
            .pos_label(1)
            .fit(&y, &())
            .expect("operation should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");

        assert_eq!(y_bin[[0, 0]], 1.0); // positive class
        assert_eq!(y_bin[[1, 0]], -1.0); // negative class
    }

    #[test]
    fn test_label_binarizer_custom_labels_round_trip() {
        let y = array![1, 0, 1, 0];

        let binarizer = LabelBinarizer::new()
            .neg_label(-1)
            .pos_label(1)
            .fit(&y, &())
            .expect("model fitting should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");
        let y_inv = binarizer
            .inverse_transform(&y_bin)
            .expect("inverse transformation should succeed");

        assert_eq!(y, y_inv);
    }

    #[test]
    fn test_label_binarizer_unknown_label_is_error() {
        let y = array![0, 1, 2];
        let binarizer = LabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let unseen = array![0, 5];
        assert!(binarizer.transform(&unseen).is_err());
    }

    #[test]
    fn test_label_binarizer_inverse_shape_mismatch() {
        let y = array![0, 1, 2];
        let binarizer = LabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        // Three classes, but only two columns supplied.
        let wrong_width = Array2::<Float>::zeros((2, 2));
        assert!(binarizer.inverse_transform(&wrong_width).is_err());
    }

    #[test]
    fn test_multilabel_binarizer() {
        let y = vec![
            vec!["sci-fi".to_string(), "thriller".to_string()],
            vec!["comedy".to_string()],
            vec!["sci-fi".to_string(), "comedy".to_string()],
        ];

        let binarizer = MultiLabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");

        // Should have 3 samples, 3 classes
        assert_eq!(y_bin.shape(), &[3, 3]);
        let classes = binarizer.classes();
        assert_eq!(classes.len(), 3);

        // First sample has sci-fi and thriller
        let row0_sum: Float = y_bin.row(0).sum();
        assert_eq!(row0_sum, 2.0);

        // Second sample has only comedy
        let row1_sum: Float = y_bin.row(1).sum();
        assert_eq!(row1_sum, 1.0);
    }

    #[test]
    fn test_multilabel_binarizer_inverse() {
        let y = vec![
            vec!["red".to_string(), "blue".to_string()],
            vec!["green".to_string()],
            vec!["red".to_string(), "green".to_string()],
        ];

        let binarizer = MultiLabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");
        let y_inv = binarizer
            .inverse_transform(&y_bin)
            .expect("operation should succeed");

        // Check that we get back the same labels (order might differ)
        for (original, reconstructed) in y.iter().zip(y_inv.iter()) {
            let orig_set: HashSet<_> = original.iter().collect();
            let recon_set: HashSet<_> = reconstructed.iter().collect();
            assert_eq!(orig_set, recon_set);
        }
    }

    #[test]
    fn test_multilabel_binarizer_with_classes() {
        let y = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["c".to_string()],
        ];

        let classes = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];

        let binarizer = MultiLabelBinarizer::new()
            .classes(classes.clone())
            .fit(&y, &())
            .expect("operation should succeed");

        let y_bin = binarizer
            .transform(&y)
            .expect("transformation should succeed");

        // Should have 4 columns (including 'd' which wasn't in the data)
        assert_eq!(y_bin.shape(), &[2, 4]);
        assert_eq!(binarizer.classes(), &classes);
    }

    #[test]
    fn test_multilabel_binarizer_ignores_unknown_labels() {
        let y = vec![vec!["a".to_string()], vec!["b".to_string()]];

        let binarizer = MultiLabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        let with_unknown = vec![vec!["a".to_string(), "zzz".to_string()]];
        let y_bin = binarizer
            .transform(&with_unknown)
            .expect("transformation should succeed");

        assert_eq!(y_bin.shape(), &[1, 2]);
        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 0.0]);
    }

    #[test]
    fn test_multilabel_binarizer_inverse_shape_mismatch() {
        let y = vec![vec!["a".to_string()], vec!["b".to_string()]];

        let binarizer = MultiLabelBinarizer::new()
            .fit(&y, &())
            .expect("model fitting should succeed");

        // Two classes, but three columns supplied.
        let wrong_width = Array2::<Float>::zeros((1, 3));
        assert!(binarizer.inverse_transform(&wrong_width).is_err());
    }
}
555}