1use crate::dataset::Dataset;
21use crate::error::{Result, ScryLearnError};
22use crate::preprocess::Transformer;
23
/// Controls whether an indicator column is omitted for each encoded feature.
///
/// Categories are stored in ascending order, so "first" always means the
/// smallest learned value of a feature.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DropStrategy {
    /// Keep one indicator column per category (nothing is dropped).
    #[default]
    None,
    /// Always omit the indicator column of the first category.
    First,
    /// Omit the first category's column only for features that have exactly
    /// two categories.
    IfBinary,
}
39
/// Controls what `transform` does with a value that matches none of the
/// categories learned during `fit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum UnknownStrategy {
    /// Abort the transform with an error naming the offending value.
    #[default]
    Error,
    /// Encode the value as `0.0` in every indicator column of its feature.
    Ignore,
}
51
/// One-hot encoder for selected feature columns of a `Dataset`.
///
/// Configure it with [`OneHotEncoder::new`] plus the builder methods, then
/// drive it through the `Transformer` trait (`fit`, `transform`,
/// `inverse_transform`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct OneHotEncoder {
    /// Indices of the dataset features that are one-hot encoded.
    feature_indices: Vec<usize>,
    /// Whether (and when) the first category's indicator column is omitted.
    drop_strategy: DropStrategy,
    /// How `transform` reacts to values that were not seen during `fit`.
    unknown_strategy: UnknownStrategy,
    /// Learned categories: `categories[i]` holds the sorted distinct values of
    /// feature `feature_indices[i]`. Empty until `fit` has run.
    categories: Vec<Vec<f64>>,
    /// Feature names of the dataset passed to `fit`, in original column order.
    orig_feature_names: Vec<String>,
    /// Set to `true` once `fit` has completed successfully.
    fitted: bool,
}
73
74impl OneHotEncoder {
77 pub fn new(feature_indices: Vec<usize>) -> Self {
79 Self {
80 feature_indices,
81 drop_strategy: DropStrategy::None,
82 unknown_strategy: UnknownStrategy::Error,
83 categories: Vec::new(),
84 orig_feature_names: Vec::new(),
85 fitted: false,
86 }
87 }
88
89 pub fn drop(mut self, strategy: DropStrategy) -> Self {
91 self.drop_strategy = strategy;
92 self
93 }
94
95 pub fn handle_unknown(mut self, strategy: UnknownStrategy) -> Self {
97 self.unknown_strategy = strategy;
98 self
99 }
100
101 pub fn categories(&self) -> &[Vec<f64>] {
105 &self.categories
106 }
107
108 pub fn get_feature_names(&self) -> Vec<String> {
110 if !self.fitted || self.orig_feature_names.is_empty() {
111 return Vec::new();
112 }
113 let encoded_set: std::collections::HashSet<usize> =
114 self.feature_indices.iter().copied().collect();
115 let mut names = Vec::new();
116 for (j, orig_name) in self.orig_feature_names.iter().enumerate() {
117 if encoded_set.contains(&j) {
118 let cat_idx = self
119 .feature_indices
120 .iter()
121 .position(|&fi| fi == j)
122 .expect("encoded_set built from feature_indices");
123 let cats = &self.categories[cat_idx];
124 let skip = self.n_drop(cat_idx);
125 for (ci, &cat_val) in cats.iter().enumerate() {
126 if ci < skip {
127 continue;
128 }
129 names.push(format!("{}_{}", orig_name, cat_val as i64));
130 }
131 } else {
132 names.push(orig_name.clone());
133 }
134 }
135 names
136 }
137}
138
139impl OneHotEncoder {
142 fn n_drop(&self, cat_idx: usize) -> usize {
145 match self.drop_strategy {
146 DropStrategy::None => 0,
147 DropStrategy::First => 1,
148 DropStrategy::IfBinary => usize::from(self.categories[cat_idx].len() == 2),
149 }
150 }
151}
152
153impl Transformer for OneHotEncoder {
156 fn fit(&mut self, data: &Dataset) -> Result<()> {
157 if data.n_samples() == 0 {
158 return Err(ScryLearnError::EmptyDataset);
159 }
160 for &idx in &self.feature_indices {
161 if idx >= data.n_features() {
162 return Err(ScryLearnError::InvalidParameter(format!(
163 "feature index {idx} out of range (dataset has {} features)",
164 data.n_features()
165 )));
166 }
167 }
168
169 self.categories.clear();
170 self.orig_feature_names.clone_from(&data.feature_names);
171 for &idx in &self.feature_indices {
172 let mut unique: Vec<f64> = data.features[idx].clone();
173 unique.sort_by(|a, b| a.total_cmp(b));
174 unique.dedup();
175 self.categories.push(unique);
176 }
177 self.fitted = true;
178 Ok(())
179 }
180
181 fn transform(&self, data: &mut Dataset) -> Result<()> {
182 if !self.fitted {
183 return Err(ScryLearnError::NotFitted);
184 }
185 let n = data.n_samples();
186
187 let encoded_set: std::collections::HashSet<usize> =
189 self.feature_indices.iter().copied().collect();
190
191 let mut new_features: Vec<Vec<f64>> = Vec::new();
192 let mut new_names: Vec<String> = Vec::new();
193
194 for j in 0..data.n_features() {
195 if encoded_set.contains(&j) {
196 let cat_idx = self
198 .feature_indices
199 .iter()
200 .position(|&fi| fi == j)
201 .ok_or(ScryLearnError::InvalidFeatureIndex(j))?;
202 let cats = &self.categories[cat_idx];
203 let skip = self.n_drop(cat_idx);
204 let orig_name = &data.feature_names[j];
205
206 for (ci, &cat_val) in cats.iter().enumerate() {
207 if ci < skip {
208 continue;
209 }
210 let mut col = Vec::with_capacity(n);
211 for s in 0..n {
212 let val = data.features[j][s];
213 if (val - cat_val).abs() < 1e-10 {
214 col.push(1.0);
215 } else if cats.iter().any(|&c| (val - c).abs() < 1e-10) {
216 col.push(0.0);
217 } else {
218 match self.unknown_strategy {
220 UnknownStrategy::Error => {
221 return Err(ScryLearnError::InvalidParameter(format!(
222 "unknown category {val} in feature '{orig_name}'"
223 )));
224 }
225 UnknownStrategy::Ignore => {
226 col.push(0.0);
227 }
228 }
229 }
230 }
231 new_features.push(col);
232 new_names.push(format!("{}_{}", orig_name, cat_val as i64));
233 }
234 } else {
235 new_features.push(data.features[j].clone());
237 new_names.push(data.feature_names[j].clone());
238 }
239 }
240
241 data.features = new_features;
242 data.feature_names = new_names;
243 data.sync_matrix();
244 Ok(())
245 }
246
247 fn inverse_transform(&self, data: &mut Dataset) -> Result<()> {
248 if !self.fitted {
249 return Err(ScryLearnError::NotFitted);
250 }
251 let n = data.n_samples();
252
253 let mut new_features: Vec<Vec<f64>> = Vec::new();
256 let mut new_names: Vec<String> = Vec::new();
257
258 let mut j = 0;
259 let mut cat_idx = 0;
260
261 while j < data.n_features() {
269 if cat_idx < self.feature_indices.len() {
270 let cats = &self.categories[cat_idx];
271 let skip = self.n_drop(cat_idx);
272 let n_cols = cats.len() - skip;
273
274 if j + n_cols <= data.n_features() {
276 let first_name = &data.feature_names[j];
278 let prefix = first_name
279 .rfind('_')
280 .map_or(first_name.as_str(), |pos| &first_name[..pos]);
281
282 let mut col = Vec::with_capacity(n);
284 for s in 0..n {
285 let mut found = false;
286 for (ci, &cat_val) in cats.iter().enumerate().skip(skip) {
287 let col_idx = j + ci - skip;
288 if data.features[col_idx][s] > 0.5 {
289 col.push(cat_val);
290 found = true;
291 break;
292 }
293 }
294 if !found {
295 if skip > 0 {
298 col.push(cats[0]);
299 } else {
300 col.push(f64::NAN);
301 }
302 }
303 }
304 new_features.push(col);
305 new_names.push(prefix.to_string());
306 j += n_cols;
307 cat_idx += 1;
308 continue;
309 }
310 }
311
312 new_features.push(data.features[j].clone());
314 new_names.push(data.feature_names[j].clone());
315 j += 1;
316 }
317
318 data.features = new_features;
319 data.feature_names = new_names;
320 data.sync_matrix();
321 Ok(())
322 }
323}
324
#[cfg(test)]
// Exact float comparison is intended: indicator columns hold literal 0.0 / 1.0.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Six samples with a three-level categorical column `color` and a numeric
    /// column `value`.
    fn color_dataset() -> Dataset {
        let color = vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0];
        let value = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        Dataset::new(
            vec![color, value],
            vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0],
            vec!["color".into(), "value".into()],
            "target",
        )
    }

    #[test]
    fn onehot_basic_encoding() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 4);
        let expected_names = ["color_0", "color_1", "color_2", "value"];
        for (col, expected) in expected_names.iter().enumerate() {
            assert_eq!(ds.feature_names[col], *expected);
        }

        // Sample 0 has color 0, sample 2 has color 2.
        let expected_rows = [(0usize, [1.0, 0.0, 0.0]), (2usize, [0.0, 0.0, 1.0])];
        for (sample, row) in expected_rows.iter() {
            for (col, expected) in row.iter().enumerate() {
                assert_eq!(ds.features[col][*sample], *expected);
            }
        }
    }

    #[test]
    fn onehot_drop_first() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]).drop(DropStrategy::First);
        enc.fit_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 3);
        for (col, expected) in ["color_1", "color_2"].iter().enumerate() {
            assert_eq!(ds.feature_names[col], *expected);
        }
    }

    #[test]
    fn onehot_drop_if_binary() {
        // Two categories: the first indicator column is dropped.
        let mut binary = Dataset::new(
            vec![vec![0.0, 1.0, 0.0, 1.0], vec![10.0, 20.0, 30.0, 40.0]],
            vec![0.0; 4],
            vec!["binary".into(), "num".into()],
            "y",
        );
        let mut binary_enc = OneHotEncoder::new(vec![0]).drop(DropStrategy::IfBinary);
        binary_enc.fit_transform(&mut binary).unwrap();

        assert_eq!(binary.n_features(), 2);
        assert_eq!(binary.feature_names[0], "binary_1");

        // Three categories: nothing is dropped.
        let mut ternary = color_dataset();
        let mut ternary_enc = OneHotEncoder::new(vec![0]).drop(DropStrategy::IfBinary);
        ternary_enc.fit_transform(&mut ternary).unwrap();
        assert_eq!(ternary.n_features(), 4);
    }

    #[test]
    fn onehot_unknown_error() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit(&ds).unwrap();

        // 99.0 was never seen during fit.
        ds.features[0][0] = 99.0;
        let outcome = enc.transform(&mut ds);
        assert!(outcome.is_err());
    }

    #[test]
    fn onehot_unknown_ignore() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]).handle_unknown(UnknownStrategy::Ignore);
        enc.fit(&ds).unwrap();

        ds.features[0][0] = 99.0;
        enc.transform(&mut ds).unwrap();

        // The unknown value yields zeros in all three indicator columns.
        for col in 0..3 {
            assert_eq!(ds.features[col][0], 0.0);
        }
    }

    #[test]
    fn onehot_roundtrip_inverse() {
        let original = color_dataset();
        let mut ds = original.clone();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit_transform(&mut ds).unwrap();
        enc.inverse_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 2);
        for i in 0..original.n_samples() {
            let restored = ds.features[0][i];
            let expected = original.features[0][i];
            assert!(
                (restored - expected).abs() < 1e-10,
                "roundtrip mismatch at sample {i}"
            );
        }
    }

    #[test]
    fn onehot_feature_names() {
        let mut ds = color_dataset();
        let mut enc = OneHotEncoder::new(vec![0]);
        enc.fit_transform(&mut ds).unwrap();

        let names = enc.get_feature_names();
        assert_eq!(names, &["color_0", "color_1", "color_2", "value"]);
    }

    #[test]
    fn onehot_not_fitted_error() {
        let enc = OneHotEncoder::new(vec![0]);
        let mut ds = color_dataset();
        let outcome = enc.transform(&mut ds);
        assert!(outcome.is_err());
    }

    #[test]
    fn onehot_multiple_features() {
        let a = vec![0.0, 1.0, 0.0, 1.0];
        let b = vec![0.0, 1.0, 2.0, 0.0];
        let num = vec![5.0, 6.0, 7.0, 8.0];
        let mut ds = Dataset::new(
            vec![a, b, num],
            vec![0.0; 4],
            vec!["a".into(), "b".into(), "num".into()],
            "y",
        );
        let mut enc = OneHotEncoder::new(vec![0, 1]);
        enc.fit_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 6);
        let expected_names = ["a_0", "a_1", "b_0", "b_1", "b_2", "num"];
        for (col, expected) in expected_names.iter().enumerate() {
            assert_eq!(ds.feature_names[col], *expected);
        }
    }
}