// scry_learn/preprocess/encoder.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Label encoding for categorical variables.

use crate::error::{Result, ScryLearnError};

/// Encode string labels as integer indices.
///
/// Maintains a bidirectional mapping between labels and their numeric
/// indices: a label's index is its position in the stored class list.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct LabelEncoder {
    /// Distinct labels learned by `fit`, kept sorted and free of duplicates.
    classes: Vec<String>,
    /// `true` once `fit` has been called; checked before any transform.
    fitted: bool,
}

16impl LabelEncoder {
17    /// Create a new unfitted encoder.
18    pub fn new() -> Self {
19        Self {
20            classes: Vec::new(),
21            fitted: false,
22        }
23    }
24
25    /// Fit the encoder on a set of string labels.
26    pub fn fit(&mut self, labels: &[&str]) {
27        let mut unique: Vec<String> = labels
28            .iter()
29            .map(std::string::ToString::to_string)
30            .collect();
31        unique.sort();
32        unique.dedup();
33        self.classes = unique;
34        self.fitted = true;
35    }
36
37    /// Transform string labels to numeric indices.
38    pub fn transform(&self, labels: &[&str]) -> Result<Vec<f64>> {
39        if !self.fitted {
40            return Err(ScryLearnError::NotFitted);
41        }
42        labels
43            .iter()
44            .map(|&label| {
45                self.classes
46                    .iter()
47                    .position(|c| c == label)
48                    .map(|i| i as f64)
49                    .ok_or_else(|| {
50                        ScryLearnError::InvalidParameter(format!("unknown label: {label}"))
51                    })
52            })
53            .collect()
54    }
55
56    /// Reverse-transform numeric indices back to string labels.
57    pub fn inverse_transform(&self, indices: &[f64]) -> Result<Vec<String>> {
58        if !self.fitted {
59            return Err(ScryLearnError::NotFitted);
60        }
61        indices
62            .iter()
63            .map(|&idx| {
64                let i = idx as usize;
65                self.classes.get(i).cloned().ok_or_else(|| {
66                    ScryLearnError::InvalidParameter(format!("index out of range: {i}"))
67                })
68            })
69            .collect()
70    }
71
72    /// Get the list of known classes.
73    pub fn classes(&self) -> &[String] {
74        &self.classes
75    }
76
77    /// Number of known classes.
78    pub fn n_classes(&self) -> usize {
79        self.classes.len()
80    }
81}
82
83impl Default for LabelEncoder {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding returns the original labels.
    #[test]
    fn test_label_encoder_roundtrip() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["cat", "dog", "bird", "cat"]);
        assert_eq!(enc.n_classes(), 3);

        let encoded = enc.transform(&["dog", "cat", "bird"]).unwrap();
        assert_eq!(encoded, vec![2.0, 1.0, 0.0]); // sorted: bird=0, cat=1, dog=2

        let decoded = enc.inverse_transform(&encoded).unwrap();
        assert_eq!(decoded, vec!["dog", "cat", "bird"]);
    }

    /// A label that was not seen during `fit` is rejected.
    #[test]
    fn test_label_encoder_unknown() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["a", "b"]);
        assert!(enc.transform(&["c"]).is_err());
    }

    /// Both transform directions fail before `fit` has been called.
    #[test]
    fn test_label_encoder_not_fitted() {
        let enc = LabelEncoder::new();
        assert!(enc.transform(&["a"]).is_err());
        assert!(enc.inverse_transform(&[0.0]).is_err());
    }

    /// An index past the last class is rejected.
    #[test]
    fn test_label_encoder_index_out_of_range() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["a", "b"]);
        assert!(enc.inverse_transform(&[2.0]).is_err());
    }
}