// scry_learn/preprocess/encoder.rs
use crate::error::{Result, ScryLearnError};

/// Encodes string labels as class indices (stored as `f64`) and decodes them back.
///
/// An encoder starts unfitted; `fit` learns the set of distinct labels, after which
/// labels can be mapped to indices and indices back to labels.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct LabelEncoder {
    /// Distinct labels seen by the last `fit`, sorted ascending with duplicates
    /// removed; a label's position in this list is its encoded value.
    classes: Vec<String>,
    /// Whether `fit` has been called at least once.
    fitted: bool,
}

16impl LabelEncoder {
17 pub fn new() -> Self {
19 Self {
20 classes: Vec::new(),
21 fitted: false,
22 }
23 }
24
25 pub fn fit(&mut self, labels: &[&str]) {
27 let mut unique: Vec<String> = labels
28 .iter()
29 .map(std::string::ToString::to_string)
30 .collect();
31 unique.sort();
32 unique.dedup();
33 self.classes = unique;
34 self.fitted = true;
35 }
36
37 pub fn transform(&self, labels: &[&str]) -> Result<Vec<f64>> {
39 if !self.fitted {
40 return Err(ScryLearnError::NotFitted);
41 }
42 labels
43 .iter()
44 .map(|&label| {
45 self.classes
46 .iter()
47 .position(|c| c == label)
48 .map(|i| i as f64)
49 .ok_or_else(|| {
50 ScryLearnError::InvalidParameter(format!("unknown label: {label}"))
51 })
52 })
53 .collect()
54 }
55
56 pub fn inverse_transform(&self, indices: &[f64]) -> Result<Vec<String>> {
58 if !self.fitted {
59 return Err(ScryLearnError::NotFitted);
60 }
61 indices
62 .iter()
63 .map(|&idx| {
64 let i = idx as usize;
65 self.classes.get(i).cloned().ok_or_else(|| {
66 ScryLearnError::InvalidParameter(format!("index out of range: {i}"))
67 })
68 })
69 .collect()
70 }
71
72 pub fn classes(&self) -> &[String] {
74 &self.classes
75 }
76
77 pub fn n_classes(&self) -> usize {
79 self.classes.len()
80 }
81}
82
83impl Default for LabelEncoder {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_label_encoder_roundtrip() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["cat", "dog", "bird", "cat"]);
        assert_eq!(enc.n_classes(), 3);
        assert_eq!(enc.classes().to_vec(), vec!["bird", "cat", "dog"]);

        let encoded = enc.transform(&["dog", "cat", "bird"]).unwrap();
        assert_eq!(encoded, vec![2.0, 1.0, 0.0]);

        let decoded = enc.inverse_transform(&encoded).unwrap();
        assert_eq!(decoded, vec!["dog", "cat", "bird"]);
    }

    #[test]
    fn test_label_encoder_unknown() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["a", "b"]);
        assert!(enc.transform(&["c"]).is_err());
    }

    #[test]
    fn test_label_encoder_not_fitted() {
        let enc = LabelEncoder::new();
        assert!(matches!(
            enc.transform(&["a"]),
            Err(ScryLearnError::NotFitted)
        ));
        assert!(matches!(
            enc.inverse_transform(&[0.0]),
            Err(ScryLearnError::NotFitted)
        ));
    }

    #[test]
    fn test_label_encoder_index_out_of_range() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["a", "b"]);
        assert!(matches!(
            enc.inverse_transform(&[2.0]),
            Err(ScryLearnError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_label_encoder_refit_replaces_classes() {
        let mut enc = LabelEncoder::new();
        enc.fit(&["x"]);
        enc.fit(&["z", "y", "z"]);
        assert_eq!(enc.classes().to_vec(), vec!["y", "z"]);
        assert!(enc.transform(&["x"]).is_err());
    }

    #[test]
    fn test_label_encoder_empty_fit() {
        let mut enc = LabelEncoder::new();
        enc.fit(&[]);
        assert_eq!(enc.n_classes(), 0);
        assert_eq!(enc.transform(&[]).unwrap(), Vec::<f64>::new());
        assert!(enc.transform(&["a"]).is_err());
    }
}