sklears_preprocessing/
label_binarization.rs

1//! Label Binarization transformers
2//!
3//! This module provides transformers for binarizing labels:
4//! - LabelBinarizer: One-hot encoding for single-label classification
5//! - MultiLabelBinarizer: Binary encoding for multi-label classification
6
7use scirs2_core::ndarray::{Array1, Array2};
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10use std::marker::PhantomData;
11
12use sklears_core::{
13    error::{Result, SklearsError},
14    traits::{Estimator, Fit, Trained, Transform, Untrained},
15    types::Float,
16};
17
/// Settings that control the numeric encoding produced by `LabelBinarizer`.
#[derive(Debug, Clone)]
pub struct LabelBinarizerConfig {
    /// Number written to the output for the negative class.
    pub neg_label: i32,
    /// Number written to the output for the positive class.
    pub pos_label: i32,
    /// Request for sparse output; declared but not implemented.
    pub sparse_output: bool,
}

impl Default for LabelBinarizerConfig {
    /// Builds the conventional dense 0/1 encoding: `neg_label = 0`,
    /// `pos_label = 1`, `sparse_output = false`.
    fn default() -> Self {
        LabelBinarizerConfig {
            sparse_output: false,
            pos_label: 1,
            neg_label: 0,
        }
    }
}
38
39/// LabelBinarizer transforms labels to binary form
40pub struct LabelBinarizer<T: Eq + Hash + Clone = i32, State = Untrained> {
41    config: LabelBinarizerConfig,
42    state: PhantomData<State>,
43    classes_: Option<Vec<T>>,
44    class_to_index_: Option<HashMap<T, usize>>,
45}
46
47impl<T: Eq + Hash + Clone> LabelBinarizer<T, Untrained> {
48    /// Create a new LabelBinarizer with default configuration
49    pub fn new() -> Self {
50        Self {
51            config: LabelBinarizerConfig::default(),
52            state: PhantomData,
53            classes_: None,
54            class_to_index_: None,
55        }
56    }
57
58    /// Set the negative label value
59    pub fn neg_label(mut self, neg_label: i32) -> Self {
60        self.config.neg_label = neg_label;
61        self
62    }
63
64    /// Set the positive label value
65    pub fn pos_label(mut self, pos_label: i32) -> Self {
66        self.config.pos_label = pos_label;
67        self
68    }
69}
70
71impl<T: Eq + Hash + Clone> Default for LabelBinarizer<T, Untrained> {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Untrained> {
78    type Config = LabelBinarizerConfig;
79    type Error = SklearsError;
80    type Float = Float;
81
82    fn config(&self) -> &Self::Config {
83        &self.config
84    }
85}
86
87impl<T: Eq + Hash + Clone> Estimator for LabelBinarizer<T, Trained> {
88    type Config = LabelBinarizerConfig;
89    type Error = SklearsError;
90    type Float = Float;
91
92    fn config(&self) -> &Self::Config {
93        &self.config
94    }
95}
96
97impl<T: Eq + Hash + Clone + Ord + Send + Sync> Fit<Array1<T>, ()> for LabelBinarizer<T, Untrained> {
98    type Fitted = LabelBinarizer<T, Trained>;
99
100    fn fit(self, y: &Array1<T>, _x: &()) -> Result<Self::Fitted> {
101        // Extract unique classes
102        let mut classes = HashSet::new();
103        for label in y.iter() {
104            classes.insert(label.clone());
105        }
106
107        // Sort classes for consistency
108        let mut sorted_classes: Vec<T> = classes.into_iter().collect();
109        sorted_classes.sort();
110
111        // Create class to index mapping
112        let class_to_index: HashMap<T, usize> = sorted_classes
113            .iter()
114            .enumerate()
115            .map(|(i, c)| (c.clone(), i))
116            .collect();
117
118        Ok(LabelBinarizer {
119            config: self.config,
120            state: PhantomData,
121            classes_: Some(sorted_classes),
122            class_to_index_: Some(class_to_index),
123        })
124    }
125}
126
127impl<T: Eq + Hash + Clone> Transform<Array1<T>, Array2<Float>> for LabelBinarizer<T, Trained> {
128    fn transform(&self, y: &Array1<T>) -> Result<Array2<Float>> {
129        let classes = self.classes_.as_ref().unwrap();
130        let class_to_index = self.class_to_index_.as_ref().unwrap();
131        let n_samples = y.len();
132        let n_classes = classes.len();
133
134        if n_classes == 0 {
135            return Err(SklearsError::InvalidInput(
136                "No classes found during fit".to_string(),
137            ));
138        }
139
140        // Special case for binary classification
141        if n_classes == 2 {
142            let mut result = Array2::zeros((n_samples, 1));
143            for (i, label) in y.iter().enumerate() {
144                if let Some(&class_idx) = class_to_index.get(label) {
145                    result[[i, 0]] = if class_idx == 1 {
146                        self.config.pos_label as Float
147                    } else {
148                        self.config.neg_label as Float
149                    };
150                } else {
151                    return Err(SklearsError::InvalidInput(
152                        "Unknown label encountered during transform".to_string(),
153                    ));
154                }
155            }
156            Ok(result)
157        } else {
158            // Multi-class case: one-hot encoding
159            let mut result =
160                Array2::from_elem((n_samples, n_classes), self.config.neg_label as Float);
161            for (i, label) in y.iter().enumerate() {
162                if let Some(&class_idx) = class_to_index.get(label) {
163                    result[[i, class_idx]] = self.config.pos_label as Float;
164                } else {
165                    return Err(SklearsError::InvalidInput(
166                        "Unknown label encountered during transform".to_string(),
167                    ));
168                }
169            }
170            Ok(result)
171        }
172    }
173}
174
175impl<T: Eq + Hash + Clone> LabelBinarizer<T, Trained> {
176    /// Get the classes
177    pub fn classes(&self) -> &Vec<T> {
178        self.classes_.as_ref().unwrap()
179    }
180
181    /// Transform binary matrix back to original labels
182    pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Array1<T>> {
183        let classes = self.classes_.as_ref().unwrap();
184        let n_samples = y.nrows();
185        let n_classes = classes.len();
186
187        if n_classes == 2 && y.ncols() == 1 {
188            // Binary case
189            let mut result = Vec::with_capacity(n_samples);
190            let threshold = (self.config.neg_label + self.config.pos_label) as Float / 2.0;
191
192            for i in 0..n_samples {
193                let class_idx = if y[[i, 0]] > threshold { 1 } else { 0 };
194                result.push(classes[class_idx].clone());
195            }
196            Ok(Array1::from_vec(result))
197        } else if y.ncols() == n_classes {
198            // Multi-class case
199            let mut result = Vec::with_capacity(n_samples);
200
201            for i in 0..n_samples {
202                // Find the column with the maximum value
203                let row = y.row(i);
204                let mut max_idx = 0;
205                let mut max_val = row[0];
206
207                for j in 1..n_classes {
208                    if row[j] > max_val {
209                        max_val = row[j];
210                        max_idx = j;
211                    }
212                }
213
214                result.push(classes[max_idx].clone());
215            }
216            Ok(Array1::from_vec(result))
217        } else {
218            Err(SklearsError::InvalidInput(format!(
219                "Shape mismatch: y has {} columns but {} classes were expected",
220                y.ncols(),
221                n_classes
222            )))
223        }
224    }
225}
226
/// Options for `MultiLabelBinarizer`.
///
/// The derived default leaves `classes` unset and `sparse_output` off.
#[derive(Default, Clone, Debug)]
pub struct MultiLabelBinarizerConfig {
    /// Caller-supplied class list; `None` means the classes are inferred
    /// from the data.
    pub classes: Option<Vec<String>>,
    /// Request for sparse output; declared but not implemented.
    pub sparse_output: bool,
}
235
236/// MultiLabelBinarizer transforms between iterable of labels and binary matrix
237pub struct MultiLabelBinarizer<State = Untrained> {
238    config: MultiLabelBinarizerConfig,
239    state: PhantomData<State>,
240    classes_: Option<Vec<String>>,
241    class_to_index_: Option<HashMap<String, usize>>,
242}
243
244impl MultiLabelBinarizer<Untrained> {
245    /// Create a new MultiLabelBinarizer with default configuration
246    pub fn new() -> Self {
247        Self {
248            config: MultiLabelBinarizerConfig::default(),
249            state: PhantomData,
250            classes_: None,
251            class_to_index_: None,
252        }
253    }
254
255    /// Set the classes to use
256    pub fn classes(mut self, classes: Vec<String>) -> Self {
257        self.config.classes = Some(classes);
258        self
259    }
260}
261
262impl Default for MultiLabelBinarizer<Untrained> {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268impl Estimator for MultiLabelBinarizer<Untrained> {
269    type Config = MultiLabelBinarizerConfig;
270    type Error = SklearsError;
271    type Float = Float;
272
273    fn config(&self) -> &Self::Config {
274        &self.config
275    }
276}
277
278impl Estimator for MultiLabelBinarizer<Trained> {
279    type Config = MultiLabelBinarizerConfig;
280    type Error = SklearsError;
281    type Float = Float;
282
283    fn config(&self) -> &Self::Config {
284        &self.config
285    }
286}
287
288impl Fit<Vec<Vec<String>>, ()> for MultiLabelBinarizer<Untrained> {
289    type Fitted = MultiLabelBinarizer<Trained>;
290
291    fn fit(self, y: &Vec<Vec<String>>, _x: &()) -> Result<Self::Fitted> {
292        let classes = if let Some(ref classes) = self.config.classes {
293            classes.clone()
294        } else {
295            // Infer classes from data
296            let mut unique_classes = HashSet::new();
297            for labels in y.iter() {
298                for label in labels.iter() {
299                    unique_classes.insert(label.clone());
300                }
301            }
302
303            let mut sorted_classes: Vec<String> = unique_classes.into_iter().collect();
304            sorted_classes.sort();
305            sorted_classes
306        };
307
308        // Create class to index mapping
309        let class_to_index: HashMap<String, usize> = classes
310            .iter()
311            .enumerate()
312            .map(|(i, c)| (c.clone(), i))
313            .collect();
314
315        Ok(MultiLabelBinarizer {
316            config: self.config,
317            state: PhantomData,
318            classes_: Some(classes),
319            class_to_index_: Some(class_to_index),
320        })
321    }
322}
323
324impl Transform<Vec<Vec<String>>, Array2<Float>> for MultiLabelBinarizer<Trained> {
325    fn transform(&self, y: &Vec<Vec<String>>) -> Result<Array2<Float>> {
326        let classes = self.classes_.as_ref().unwrap();
327        let class_to_index = self.class_to_index_.as_ref().unwrap();
328        let n_samples = y.len();
329        let n_classes = classes.len();
330
331        let mut result = Array2::zeros((n_samples, n_classes));
332
333        for (i, labels) in y.iter().enumerate() {
334            for label in labels.iter() {
335                if let Some(&class_idx) = class_to_index.get(label) {
336                    result[[i, class_idx]] = 1.0;
337                }
338                // Ignore unknown labels during transform
339            }
340        }
341
342        Ok(result)
343    }
344}
345
346impl MultiLabelBinarizer<Trained> {
347    /// Get the classes
348    pub fn classes(&self) -> &Vec<String> {
349        self.classes_.as_ref().unwrap()
350    }
351
352    /// Transform binary matrix back to multi-label format
353    pub fn inverse_transform(&self, y: &Array2<Float>) -> Result<Vec<Vec<String>>> {
354        let classes = self.classes_.as_ref().unwrap();
355        let n_samples = y.nrows();
356        let n_classes = classes.len();
357
358        if y.ncols() != n_classes {
359            return Err(SklearsError::InvalidInput(format!(
360                "Shape mismatch: y has {} columns but {} classes were expected",
361                y.ncols(),
362                n_classes
363            )));
364        }
365
366        let mut result = Vec::with_capacity(n_samples);
367
368        for i in 0..n_samples {
369            let mut labels = Vec::new();
370            for j in 0..n_classes {
371                if y[[i, j]] > 0.5 {
372                    labels.push(classes[j].clone());
373                }
374            }
375            result.push(labels);
376        }
377
378        Ok(result)
379    }
380}
381
#[cfg(test)]
mod tests {
    //! Unit tests for `LabelBinarizer` and `MultiLabelBinarizer`: encoding
    //! layouts, round trips, and the documented error paths.

    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds owned multi-label samples from string literals.
    fn samples(rows: &[&[&str]]) -> Vec<Vec<String>> {
        rows.iter()
            .map(|row| row.iter().map(|&label| label.to_owned()).collect())
            .collect()
    }

    #[test]
    fn test_label_binarizer_binary() {
        let y = array![1, 0, 1, 0, 1];

        let binarizer = LabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        // Binary case: a single column, every row checked.
        assert_eq!(y_bin.shape(), &[5, 1]);
        assert_eq!(y_bin[[0, 0]], 1.0);
        assert_eq!(y_bin[[1, 0]], 0.0);
        assert_eq!(y_bin[[2, 0]], 1.0);
        assert_eq!(y_bin[[3, 0]], 0.0);
        assert_eq!(y_bin[[4, 0]], 1.0);
    }

    #[test]
    fn test_label_binarizer_multiclass() {
        let y = array![0, 1, 2, 1, 0];

        let binarizer = LabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        // Multiclass case: one column per class.
        assert_eq!(y_bin.shape(), &[5, 3]);
        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 0.0, 0.0]);
        assert_eq!(y_bin.row(1).to_vec(), vec![0.0, 1.0, 0.0]);
        assert_eq!(y_bin.row(2).to_vec(), vec![0.0, 0.0, 1.0]);
        assert_eq!(y_bin.row(3).to_vec(), vec![0.0, 1.0, 0.0]);
        assert_eq!(y_bin.row(4).to_vec(), vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_label_binarizer_inverse_transform() {
        let y = array!["cat", "dog", "cat", "bird", "dog"];

        let binarizer = LabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();
        let y_inv = binarizer.inverse_transform(&y_bin).unwrap();

        assert_eq!(y, y_inv);
    }

    #[test]
    fn test_label_binarizer_custom_labels() {
        let y = array![1, 0, 1, 0];

        let binarizer = LabelBinarizer::new()
            .neg_label(-1)
            .pos_label(1)
            .fit(&y, &())
            .unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        assert_eq!(y_bin[[0, 0]], 1.0); // positive class
        assert_eq!(y_bin[[1, 0]], -1.0); // negative class
    }

    #[test]
    fn test_label_binarizer_unknown_label_is_error() {
        // Multiclass layout.
        let multiclass = LabelBinarizer::new().fit(&array![0, 1, 2], &()).unwrap();
        assert!(multiclass.transform(&array![0, 3]).is_err());

        // Binary layout.
        let binary = LabelBinarizer::new().fit(&array![0, 1], &()).unwrap();
        assert!(binary.transform(&array![2]).is_err());
    }

    #[test]
    fn test_label_binarizer_inverse_shape_mismatch() {
        let binarizer = LabelBinarizer::new().fit(&array![0, 1, 2], &()).unwrap();

        // Three classes were fitted, so a two-column matrix fits no layout.
        let wrong_width = Array2::<Float>::zeros((2, 2));
        assert!(binarizer.inverse_transform(&wrong_width).is_err());
    }

    #[test]
    fn test_multilabel_binarizer() {
        let y = samples(&[&["sci-fi", "thriller"], &["comedy"], &["sci-fi", "comedy"]]);

        let binarizer = MultiLabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        // Inferred classes are sorted: comedy, sci-fi, thriller.
        assert_eq!(y_bin.shape(), &[3, 3]);
        assert_eq!(
            binarizer.classes(),
            &vec![
                "comedy".to_string(),
                "sci-fi".to_string(),
                "thriller".to_string()
            ]
        );

        assert_eq!(y_bin.row(0).to_vec(), vec![0.0, 1.0, 1.0]);
        assert_eq!(y_bin.row(1).to_vec(), vec![1.0, 0.0, 0.0]);
        assert_eq!(y_bin.row(2).to_vec(), vec![1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_multilabel_binarizer_inverse() {
        let y = samples(&[&["red", "blue"], &["green"], &["red", "green"]]);

        let binarizer = MultiLabelBinarizer::new().fit(&y, &()).unwrap();

        let y_bin = binarizer.transform(&y).unwrap();
        let y_inv = binarizer.inverse_transform(&y_bin).unwrap();

        // Labels come back in column order, so compare as sets.
        assert_eq!(y.len(), y_inv.len());
        for (original, reconstructed) in y.iter().zip(y_inv.iter()) {
            let orig_set: HashSet<_> = original.iter().collect();
            let recon_set: HashSet<_> = reconstructed.iter().collect();
            assert_eq!(orig_set, recon_set);
        }
    }

    #[test]
    fn test_multilabel_binarizer_with_classes() {
        let y = samples(&[&["a", "b"], &["c"]]);

        let classes = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];

        let binarizer = MultiLabelBinarizer::new()
            .classes(classes.clone())
            .fit(&y, &())
            .unwrap();

        let y_bin = binarizer.transform(&y).unwrap();

        // Four columns, including 'd' which never occurs in the data.
        assert_eq!(y_bin.shape(), &[2, 4]);
        assert_eq!(binarizer.classes(), &classes);
        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 1.0, 0.0, 0.0]);
        assert_eq!(y_bin.row(1).to_vec(), vec![0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_multilabel_binarizer_ignores_unknown_labels() {
        let train = samples(&[&["a"], &["b"]]);
        let binarizer = MultiLabelBinarizer::new().fit(&train, &()).unwrap();

        // "z" was never fitted: it is skipped rather than reported.
        let y_bin = binarizer.transform(&samples(&[&["a", "z"]])).unwrap();

        assert_eq!(y_bin.shape(), &[1, 2]);
        assert_eq!(y_bin.row(0).to_vec(), vec![1.0, 0.0]);
    }

    #[test]
    fn test_multilabel_binarizer_inverse_shape_mismatch() {
        let train = samples(&[&["a"], &["b"]]);
        let binarizer = MultiLabelBinarizer::new().fit(&train, &()).unwrap();

        let wrong_width = Array2::<Float>::zeros((1, 3));
        assert!(binarizer.inverse_transform(&wrong_width).is_err());
    }
}