1use ndarray::{Array2, ArrayBase, Data, Ix2};
7use num_traits::{Float, NumCast};
8use std::collections::HashMap;
9
10use crate::error::{Result, TransformError};
11
/// One-hot encoder for categorical features stored as numeric codes.
///
/// Every input column is treated as one categorical feature whose values are
/// cast to `u64` codes; each retained category becomes one 0/1 output column.
///
/// Derives `Debug` and `Clone` so it matches the other encoders in this module.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// Sorted, de-duplicated category codes per feature; `None` until `fit` has run.
    categories_: Option<Vec<Vec<u64>>>,
    /// Optional drop strategy: `"first"` or `"if_binary"` (validated by `new`).
    drop: Option<String>,
    /// Policy for categories not seen during fitting: `"error"` or `"ignore"`.
    handle_unknown: String,
    /// Whether sparse output was requested; `new` currently rejects `true`.
    // Never read while sparse output is unimplemented, hence the lint exemption.
    #[allow(dead_code)]
    sparse: bool,
}
27
28impl OneHotEncoder {
29 pub fn new(drop: Option<String>, handle_unknown: &str, sparse: bool) -> Result<Self> {
39 if let Some(ref drop_strategy) = drop {
40 if drop_strategy != "first" && drop_strategy != "if_binary" {
41 return Err(TransformError::InvalidInput(
42 "drop must be 'first', 'if_binary', or None".to_string(),
43 ));
44 }
45 }
46
47 if handle_unknown != "error" && handle_unknown != "ignore" {
48 return Err(TransformError::InvalidInput(
49 "handle_unknown must be 'error' or 'ignore'".to_string(),
50 ));
51 }
52
53 if sparse {
54 return Err(TransformError::InvalidInput(
55 "Sparse output is not yet implemented".to_string(),
56 ));
57 }
58
59 Ok(OneHotEncoder {
60 categories_: None,
61 drop,
62 handle_unknown: handle_unknown.to_string(),
63 sparse,
64 })
65 }
66
67 pub fn with_defaults() -> Self {
69 Self::new(None, "error", false).unwrap()
70 }
71
72 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
80 where
81 S: Data,
82 S::Elem: Float + NumCast,
83 {
84 let x_u64 = x.mapv(|x| {
85 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
86 val_f64 as u64
87 });
88
89 let n_samples = x_u64.shape()[0];
90 let n_features = x_u64.shape()[1];
91
92 if n_samples == 0 || n_features == 0 {
93 return Err(TransformError::InvalidInput("Empty input data".to_string()));
94 }
95
96 let mut categories = Vec::with_capacity(n_features);
97
98 for j in 0..n_features {
99 let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
101 unique_values.sort_unstable();
102 unique_values.dedup();
103
104 categories.push(unique_values);
105 }
106
107 self.categories_ = Some(categories);
108 Ok(())
109 }
110
111 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
119 where
120 S: Data,
121 S::Elem: Float + NumCast,
122 {
123 let x_u64 = x.mapv(|x| {
124 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
125 val_f64 as u64
126 });
127
128 let n_samples = x_u64.shape()[0];
129 let n_features = x_u64.shape()[1];
130
131 if self.categories_.is_none() {
132 return Err(TransformError::TransformationError(
133 "OneHotEncoder has not been fitted".to_string(),
134 ));
135 }
136
137 let categories = self.categories_.as_ref().unwrap();
138
139 if n_features != categories.len() {
140 return Err(TransformError::InvalidInput(format!(
141 "x has {} features, but OneHotEncoder was fitted with {} features",
142 n_features,
143 categories.len()
144 )));
145 }
146
147 let mut total_features = 0;
149 for (j, feature_categories) in categories.iter().enumerate() {
150 let n_cats = feature_categories.len();
151
152 let n_output_cats = match &self.drop {
154 Some(strategy) if strategy == "first" => n_cats.saturating_sub(1),
155 Some(strategy) if strategy == "if_binary" && n_cats == 2 => 1,
156 _ => n_cats,
157 };
158
159 if n_output_cats == 0 {
160 return Err(TransformError::InvalidInput(format!(
161 "Feature {} has only one category after dropping",
162 j
163 )));
164 }
165
166 total_features += n_output_cats;
167 }
168
169 let mut transformed = Array2::zeros((n_samples, total_features));
170
171 let mut category_mappings = Vec::new();
173 let mut current_col = 0;
174
175 for feature_categories in categories.iter() {
176 let mut mapping = HashMap::new();
177 let n_cats = feature_categories.len();
178
179 let (start_idx, n_output_cats) = match &self.drop {
181 Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
182 Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
183 _ => (0, n_cats),
184 };
185
186 for (cat_idx, &category) in feature_categories.iter().enumerate() {
187 if cat_idx >= start_idx && cat_idx < start_idx + n_output_cats {
188 mapping.insert(category, current_col + cat_idx - start_idx);
189 }
190 }
191
192 category_mappings.push(mapping);
193 current_col += n_output_cats;
194 }
195
196 for i in 0..n_samples {
198 for j in 0..n_features {
199 let value = x_u64[[i, j]];
200
201 if let Some(&col_idx) = category_mappings[j].get(&value) {
202 transformed[[i, col_idx]] = 1.0;
203 } else {
204 let feature_categories = &categories[j];
206 let is_dropped_category = match &self.drop {
207 Some(strategy) if strategy == "first" => {
208 !feature_categories.is_empty() && value == feature_categories[0]
210 }
211 Some(strategy)
212 if strategy == "if_binary" && feature_categories.len() == 2 =>
213 {
214 feature_categories.len() == 2 && value == feature_categories[1]
216 }
217 _ => false,
218 };
219
220 if !is_dropped_category && self.handle_unknown == "error" {
221 return Err(TransformError::InvalidInput(format!(
222 "Found unknown category {} in feature {}",
223 value, j
224 )));
225 }
226 }
228 }
229 }
230
231 Ok(transformed)
232 }
233
234 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
242 where
243 S: Data,
244 S::Elem: Float + NumCast,
245 {
246 self.fit(x)?;
247 self.transform(x)
248 }
249
    /// Returns the categories learned by `fit`, one ascending list per
    /// feature, or `None` when the encoder has not been fitted.
    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
        self.categories_.as_ref()
    }
257
258 pub fn get_feature_names(&self, input_features: Option<&[String]>) -> Result<Vec<String>> {
266 if self.categories_.is_none() {
267 return Err(TransformError::TransformationError(
268 "OneHotEncoder has not been fitted".to_string(),
269 ));
270 }
271
272 let categories = self.categories_.as_ref().unwrap();
273 let mut feature_names = Vec::new();
274
275 for (j, feature_categories) in categories.iter().enumerate() {
276 let feature_name = if let Some(names) = input_features {
277 if j < names.len() {
278 names[j].clone()
279 } else {
280 format!("x{}", j)
281 }
282 } else {
283 format!("x{}", j)
284 };
285
286 let n_cats = feature_categories.len();
287
288 let (start_idx, n_output_cats) = match &self.drop {
290 Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
291 Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
292 _ => (0, n_cats),
293 };
294
295 for &category in feature_categories
296 .iter()
297 .skip(start_idx)
298 .take(n_output_cats)
299 {
300 feature_names.push(format!("{}_cat_{}", feature_name, category));
301 }
302 }
303
304 Ok(feature_names)
305 }
306}
307
/// Ordinal encoder: replaces every category by its rank among the fitted
/// categories of its feature.
///
/// Derives `Debug` and `Clone` so it matches the other encoders in this module.
#[derive(Debug, Clone)]
pub struct OrdinalEncoder {
    /// Sorted, de-duplicated category codes per feature; `None` until `fit` has run.
    categories_: Option<Vec<Vec<u64>>>,
    /// Policy for unseen categories: `"error"` or `"use_encoded_value"`.
    handle_unknown: String,
    /// Value written for unseen categories when `handle_unknown` is
    /// `"use_encoded_value"`.
    unknown_value: Option<f64>,
}
320
321impl OrdinalEncoder {
322 pub fn new(handle_unknown: &str, unknown_value: Option<f64>) -> Result<Self> {
331 if handle_unknown != "error" && handle_unknown != "use_encoded_value" {
332 return Err(TransformError::InvalidInput(
333 "handle_unknown must be 'error' or 'use_encoded_value'".to_string(),
334 ));
335 }
336
337 if handle_unknown == "use_encoded_value" && unknown_value.is_none() {
338 return Err(TransformError::InvalidInput(
339 "unknown_value must be specified when handle_unknown='use_encoded_value'"
340 .to_string(),
341 ));
342 }
343
344 Ok(OrdinalEncoder {
345 categories_: None,
346 handle_unknown: handle_unknown.to_string(),
347 unknown_value,
348 })
349 }
350
351 pub fn with_defaults() -> Self {
353 Self::new("error", None).unwrap()
354 }
355
356 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
364 where
365 S: Data,
366 S::Elem: Float + NumCast,
367 {
368 let x_u64 = x.mapv(|x| {
369 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
370 val_f64 as u64
371 });
372
373 let n_samples = x_u64.shape()[0];
374 let n_features = x_u64.shape()[1];
375
376 if n_samples == 0 || n_features == 0 {
377 return Err(TransformError::InvalidInput("Empty input data".to_string()));
378 }
379
380 let mut categories = Vec::with_capacity(n_features);
381
382 for j in 0..n_features {
383 let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
385 unique_values.sort_unstable();
386 unique_values.dedup();
387
388 categories.push(unique_values);
389 }
390
391 self.categories_ = Some(categories);
392 Ok(())
393 }
394
395 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
403 where
404 S: Data,
405 S::Elem: Float + NumCast,
406 {
407 let x_u64 = x.mapv(|x| {
408 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
409 val_f64 as u64
410 });
411
412 let n_samples = x_u64.shape()[0];
413 let n_features = x_u64.shape()[1];
414
415 if self.categories_.is_none() {
416 return Err(TransformError::TransformationError(
417 "OrdinalEncoder has not been fitted".to_string(),
418 ));
419 }
420
421 let categories = self.categories_.as_ref().unwrap();
422
423 if n_features != categories.len() {
424 return Err(TransformError::InvalidInput(format!(
425 "x has {} features, but OrdinalEncoder was fitted with {} features",
426 n_features,
427 categories.len()
428 )));
429 }
430
431 let mut transformed = Array2::zeros((n_samples, n_features));
432
433 let mut category_mappings = Vec::new();
435 for feature_categories in categories {
436 let mut mapping = HashMap::new();
437 for (ordinal, &category) in feature_categories.iter().enumerate() {
438 mapping.insert(category, ordinal as f64);
439 }
440 category_mappings.push(mapping);
441 }
442
443 for i in 0..n_samples {
445 for j in 0..n_features {
446 let value = x_u64[[i, j]];
447
448 if let Some(&ordinal_value) = category_mappings[j].get(&value) {
449 transformed[[i, j]] = ordinal_value;
450 } else if self.handle_unknown == "error" {
451 return Err(TransformError::InvalidInput(format!(
452 "Found unknown category {} in feature {}",
453 value, j
454 )));
455 } else {
456 transformed[[i, j]] = self.unknown_value.unwrap();
458 }
459 }
460 }
461
462 Ok(transformed)
463 }
464
465 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
473 where
474 S: Data,
475 S::Elem: Float + NumCast,
476 {
477 self.fit(x)?;
478 self.transform(x)
479 }
480
    /// Returns the categories learned by `fit`, one ascending list per
    /// feature, or `None` when the encoder has not been fitted.
    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
        self.categories_.as_ref()
    }
488}
489
/// Target encoder: replaces every category by a statistic of the target
/// values observed for that category during fitting.
#[derive(Debug, Clone)]
pub struct TargetEncoder {
    /// Statistic to compute per category: `"mean"`, `"median"`, `"count"` or `"sum"`.
    strategy: String,
    /// Weight of the global target mean when blending; only applied to the
    /// `"mean"` and `"median"` strategies and only when greater than zero.
    smoothing: f64,
    /// Value used by `transform` for unseen categories; `0.0` means
    /// "use the global target mean instead".
    global_stat: f64,
    /// Per-feature map from category code to encoded value; `None` until fitted.
    encodings_: Option<Vec<HashMap<u64, f64>>>,
    /// Set to `true` once `fit` has completed.
    is_fitted: bool,
    /// Mean of the target values seen by the most recent fit.
    global_mean_: f64,
}
522
523impl TargetEncoder {
524 pub fn new(strategy: &str, smoothing: f64, global_stat: f64) -> Result<Self> {
534 if !["mean", "median", "count", "sum"].contains(&strategy) {
535 return Err(TransformError::InvalidInput(
536 "strategy must be 'mean', 'median', 'count', or 'sum'".to_string(),
537 ));
538 }
539
540 if smoothing < 0.0 {
541 return Err(TransformError::InvalidInput(
542 "smoothing parameter must be non-negative".to_string(),
543 ));
544 }
545
546 Ok(TargetEncoder {
547 strategy: strategy.to_string(),
548 smoothing,
549 global_stat,
550 encodings_: None,
551 is_fitted: false,
552 global_mean_: 0.0,
553 })
554 }
555
556 pub fn with_mean(smoothing: f64) -> Self {
558 TargetEncoder {
559 strategy: "mean".to_string(),
560 smoothing,
561 global_stat: 0.0,
562 encodings_: None,
563 is_fitted: false,
564 global_mean_: 0.0,
565 }
566 }
567
568 pub fn with_median(smoothing: f64) -> Self {
570 TargetEncoder {
571 strategy: "median".to_string(),
572 smoothing,
573 global_stat: 0.0,
574 encodings_: None,
575 is_fitted: false,
576 global_mean_: 0.0,
577 }
578 }
579
580 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
589 where
590 S: Data,
591 S::Elem: Float + NumCast,
592 {
593 let x_u64 = x.mapv(|x| {
594 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
595 val_f64 as u64
596 });
597
598 let n_samples = x_u64.shape()[0];
599 let n_features = x_u64.shape()[1];
600
601 if n_samples == 0 || n_features == 0 {
602 return Err(TransformError::InvalidInput("Empty input data".to_string()));
603 }
604
605 if y.len() != n_samples {
606 return Err(TransformError::InvalidInput(
607 "Number of target values must match number of samples".to_string(),
608 ));
609 }
610
611 self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
613
614 let mut encodings = Vec::with_capacity(n_features);
615
616 for j in 0..n_features {
617 let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
619
620 for i in 0..n_samples {
621 let category = x_u64[[i, j]];
622 category_targets.entry(category).or_default().push(y[i]);
623 }
624
625 let mut category_encoding = HashMap::new();
627
628 for (category, targets) in category_targets.iter() {
629 let encoded_value = match self.strategy.as_str() {
630 "mean" => {
631 let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
632 let count = targets.len() as f64;
633
634 if self.smoothing > 0.0 {
636 (count * category_mean + self.smoothing * self.global_mean_)
637 / (count + self.smoothing)
638 } else {
639 category_mean
640 }
641 }
642 "median" => {
643 let mut sorted_targets = targets.clone();
644 sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
645
646 let median = if sorted_targets.len() % 2 == 0 {
647 let mid = sorted_targets.len() / 2;
648 (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
649 } else {
650 sorted_targets[sorted_targets.len() / 2]
651 };
652
653 if self.smoothing > 0.0 {
655 let count = targets.len() as f64;
656 (count * median + self.smoothing * self.global_mean_)
657 / (count + self.smoothing)
658 } else {
659 median
660 }
661 }
662 "count" => targets.len() as f64,
663 "sum" => targets.iter().sum::<f64>(),
664 _ => unreachable!(),
665 };
666
667 category_encoding.insert(*category, encoded_value);
668 }
669
670 encodings.push(category_encoding);
671 }
672
673 self.encodings_ = Some(encodings);
674 self.is_fitted = true;
675 Ok(())
676 }
677
678 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
686 where
687 S: Data,
688 S::Elem: Float + NumCast,
689 {
690 if !self.is_fitted {
691 return Err(TransformError::TransformationError(
692 "TargetEncoder has not been fitted".to_string(),
693 ));
694 }
695
696 let x_u64 = x.mapv(|x| {
697 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
698 val_f64 as u64
699 });
700
701 let n_samples = x_u64.shape()[0];
702 let n_features = x_u64.shape()[1];
703
704 let encodings = self.encodings_.as_ref().unwrap();
705
706 if n_features != encodings.len() {
707 return Err(TransformError::InvalidInput(format!(
708 "x has {} features, but TargetEncoder was fitted with {} features",
709 n_features,
710 encodings.len()
711 )));
712 }
713
714 let mut transformed = Array2::zeros((n_samples, n_features));
715
716 for i in 0..n_samples {
717 for j in 0..n_features {
718 let category = x_u64[[i, j]];
719
720 if let Some(&encoded_value) = encodings[j].get(&category) {
721 transformed[[i, j]] = encoded_value;
722 } else {
723 transformed[[i, j]] = if self.global_stat != 0.0 {
725 self.global_stat
726 } else {
727 self.global_mean_
728 };
729 }
730 }
731 }
732
733 Ok(transformed)
734 }
735
736 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
745 where
746 S: Data,
747 S::Elem: Float + NumCast,
748 {
749 self.fit(x, y)?;
750 self.transform(x)
751 }
752
    /// Returns the learned per-feature map from category code to encoded
    /// value, or `None` when the encoder has not been fitted.
    pub fn encodings(&self) -> Option<&Vec<HashMap<u64, f64>>> {
        self.encodings_.as_ref()
    }
760
    /// Returns `true` once `fit` has completed successfully.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
765
    /// Returns the mean of the target values from the most recent fit
    /// (`0.0` before any fit).
    pub fn global_mean(&self) -> f64 {
        self.global_mean_
    }
770
771 pub fn fit_transform_cv<S>(
785 &mut self,
786 x: &ArrayBase<S, Ix2>,
787 y: &[f64],
788 cv_folds: usize,
789 ) -> Result<Array2<f64>>
790 where
791 S: Data,
792 S::Elem: Float + NumCast,
793 {
794 let x_u64 = x.mapv(|x| {
795 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
796 val_f64 as u64
797 });
798
799 let n_samples = x_u64.shape()[0];
800 let n_features = x_u64.shape()[1];
801
802 if n_samples == 0 || n_features == 0 {
803 return Err(TransformError::InvalidInput("Empty input data".to_string()));
804 }
805
806 if y.len() != n_samples {
807 return Err(TransformError::InvalidInput(
808 "Number of target values must match number of samples".to_string(),
809 ));
810 }
811
812 if cv_folds < 2 {
813 return Err(TransformError::InvalidInput(
814 "cv_folds must be at least 2".to_string(),
815 ));
816 }
817
818 let mut transformed = Array2::zeros((n_samples, n_features));
819
820 self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
822
823 let fold_size = n_samples / cv_folds;
825 let mut fold_indices = Vec::new();
826 for fold in 0..cv_folds {
827 let start = fold * fold_size;
828 let end = if fold == cv_folds - 1 {
829 n_samples
830 } else {
831 (fold + 1) * fold_size
832 };
833 fold_indices.push((start, end));
834 }
835
836 for fold in 0..cv_folds {
838 let (val_start, val_end) = fold_indices[fold];
839
840 let mut train_indices = Vec::new();
842 for (other_fold, &(start, end)) in fold_indices.iter().enumerate().take(cv_folds) {
843 if other_fold != fold {
844 train_indices.extend(start..end);
845 }
846 }
847
848 for j in 0..n_features {
850 let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
851
852 for &train_idx in &train_indices {
854 let category = x_u64[[train_idx, j]];
855 category_targets
856 .entry(category)
857 .or_default()
858 .push(y[train_idx]);
859 }
860
861 let mut category_encoding = HashMap::new();
863 for (category, targets) in category_targets.iter() {
864 let encoded_value = match self.strategy.as_str() {
865 "mean" => {
866 let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
867 let count = targets.len() as f64;
868
869 if self.smoothing > 0.0 {
870 (count * category_mean + self.smoothing * self.global_mean_)
871 / (count + self.smoothing)
872 } else {
873 category_mean
874 }
875 }
876 "median" => {
877 let mut sorted_targets = targets.clone();
878 sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
879
880 let median = if sorted_targets.len() % 2 == 0 {
881 let mid = sorted_targets.len() / 2;
882 (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
883 } else {
884 sorted_targets[sorted_targets.len() / 2]
885 };
886
887 if self.smoothing > 0.0 {
888 let count = targets.len() as f64;
889 (count * median + self.smoothing * self.global_mean_)
890 / (count + self.smoothing)
891 } else {
892 median
893 }
894 }
895 "count" => targets.len() as f64,
896 "sum" => targets.iter().sum::<f64>(),
897 _ => unreachable!(),
898 };
899
900 category_encoding.insert(*category, encoded_value);
901 }
902
903 for val_idx in val_start..val_end {
905 let category = x_u64[[val_idx, j]];
906
907 if let Some(&encoded_value) = category_encoding.get(&category) {
908 transformed[[val_idx, j]] = encoded_value;
909 } else {
910 transformed[[val_idx, j]] = self.global_mean_;
912 }
913 }
914 }
915 }
916
917 self.fit(x, y)?;
919
920 Ok(transformed)
921 }
922}
923
/// Binary encoder: represents every category by the bits of its index among
/// the fitted categories of its feature.
#[derive(Debug, Clone)]
pub struct BinaryEncoder {
    /// Per-feature map from category code to its bit pattern (most
    /// significant bit first); `None` until fitted.
    categories_: Option<Vec<HashMap<u64, Vec<u8>>>>,
    /// Number of output columns (bits) used by each feature; `None` until fitted.
    n_binary_features_: Option<Vec<usize>>,
    /// Policy for unseen categories: `"error"` or `"ignore"`.
    handle_unknown: String,
    /// Set to `true` once `fit` has completed.
    is_fitted: bool,
}
942
943impl BinaryEncoder {
944 pub fn new(handle_unknown: &str) -> Result<Self> {
954 if handle_unknown != "error" && handle_unknown != "ignore" {
955 return Err(TransformError::InvalidInput(
956 "handle_unknown must be 'error' or 'ignore'".to_string(),
957 ));
958 }
959
960 Ok(BinaryEncoder {
961 categories_: None,
962 n_binary_features_: None,
963 handle_unknown: handle_unknown.to_string(),
964 is_fitted: false,
965 })
966 }
967
968 pub fn with_defaults() -> Self {
970 Self::new("error").unwrap()
971 }
972
973 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
981 where
982 S: Data,
983 S::Elem: Float + NumCast,
984 {
985 let x_u64 = x.mapv(|x| {
986 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
987 val_f64 as u64
988 });
989
990 let n_samples = x_u64.shape()[0];
991 let n_features = x_u64.shape()[1];
992
993 if n_samples == 0 || n_features == 0 {
994 return Err(TransformError::InvalidInput("Empty input data".to_string()));
995 }
996
997 let mut categories = Vec::with_capacity(n_features);
998 let mut n_binary_features = Vec::with_capacity(n_features);
999
1000 for j in 0..n_features {
1001 let mut unique_categories: Vec<u64> = x_u64.column(j).to_vec();
1003 unique_categories.sort_unstable();
1004 unique_categories.dedup();
1005
1006 if unique_categories.is_empty() {
1007 return Err(TransformError::InvalidInput(
1008 "Feature has no valid categories".to_string(),
1009 ));
1010 }
1011
1012 let n_cats = unique_categories.len();
1014 let n_bits = if n_cats <= 1 {
1015 1
1016 } else {
1017 (n_cats as f64).log2().ceil() as usize
1018 };
1019
1020 let mut category_map = HashMap::new();
1022 for (idx, &category) in unique_categories.iter().enumerate() {
1023 let binary_code = Self::int_to_binary(idx, n_bits);
1024 category_map.insert(category, binary_code);
1025 }
1026
1027 categories.push(category_map);
1028 n_binary_features.push(n_bits);
1029 }
1030
1031 self.categories_ = Some(categories);
1032 self.n_binary_features_ = Some(n_binary_features);
1033 self.is_fitted = true;
1034
1035 Ok(())
1036 }
1037
1038 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1046 where
1047 S: Data,
1048 S::Elem: Float + NumCast,
1049 {
1050 if !self.is_fitted {
1051 return Err(TransformError::InvalidInput(
1052 "Encoder has not been fitted yet".to_string(),
1053 ));
1054 }
1055
1056 let categories = self.categories_.as_ref().unwrap();
1057 let n_binary_features = self.n_binary_features_.as_ref().unwrap();
1058
1059 let x_u64 = x.mapv(|x| {
1060 let val_f64 = num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0);
1061 val_f64 as u64
1062 });
1063
1064 let n_samples = x_u64.shape()[0];
1065 let n_features = x_u64.shape()[1];
1066
1067 if n_features != categories.len() {
1068 return Err(TransformError::InvalidInput(format!(
1069 "Number of features ({}) does not match fitted features ({})",
1070 n_features,
1071 categories.len()
1072 )));
1073 }
1074
1075 let total_binary_features: usize = n_binary_features.iter().sum();
1077 let mut result = Array2::<f64>::zeros((n_samples, total_binary_features));
1078
1079 let mut output_col = 0;
1080 for j in 0..n_features {
1081 let category_map = &categories[j];
1082 let n_bits = n_binary_features[j];
1083
1084 for i in 0..n_samples {
1085 let category = x_u64[[i, j]];
1086
1087 if let Some(binary_code) = category_map.get(&category) {
1088 for (bit_idx, &bit_val) in binary_code.iter().enumerate() {
1090 result[[i, output_col + bit_idx]] = bit_val as f64;
1091 }
1092 } else {
1093 match self.handle_unknown.as_str() {
1095 "error" => {
1096 return Err(TransformError::InvalidInput(format!(
1097 "Unknown category {} in feature {}",
1098 category, j
1099 )));
1100 }
1101 "ignore" => {
1102 }
1104 _ => unreachable!(),
1105 }
1106 }
1107 }
1108
1109 output_col += n_bits;
1110 }
1111
1112 Ok(result)
1113 }
1114
1115 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1123 where
1124 S: Data,
1125 S::Elem: Float + NumCast,
1126 {
1127 self.fit(x)?;
1128 self.transform(x)
1129 }
1130
    /// Returns `true` once `fit` has completed successfully.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
1135
    /// Returns the learned per-feature map from category code to bit
    /// pattern, or `None` when the encoder has not been fitted.
    pub fn categories(&self) -> Option<&Vec<HashMap<u64, Vec<u8>>>> {
        self.categories_.as_ref()
    }
1140
    /// Returns the number of output columns (bits) used by each input
    /// feature, or `None` when the encoder has not been fitted.
    pub fn n_binary_features(&self) -> Option<&Vec<usize>> {
        self.n_binary_features_.as_ref()
    }
1145
1146 pub fn n_output_features(&self) -> Option<usize> {
1148 self.n_binary_features_.as_ref().map(|v| v.iter().sum())
1149 }
1150
1151 fn int_to_binary(value: usize, n_bits: usize) -> Vec<u8> {
1153 let mut binary = Vec::with_capacity(n_bits);
1154 let mut val = value;
1155
1156 for _ in 0..n_bits {
1157 binary.push((val & 1) as u8);
1158 val >>= 1;
1159 }
1160
1161 binary.reverse(); binary
1163 }
1164}
1165
1166#[cfg(test)]
1167mod tests {
1168 use super::*;
1169 use approx::assert_abs_diff_eq;
1170 use ndarray::Array;
1171
1172 #[test]
1173 fn test_one_hot_encoder_basic() {
1174 let data = Array::from_shape_vec(
1176 (4, 2),
1177 vec![
1178 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 0.0, 1.0,
1180 ],
1181 )
1182 .unwrap();
1183
1184 let mut encoder = OneHotEncoder::with_defaults();
1185 let encoded = encoder.fit_transform(&data).unwrap();
1186
1187 assert_eq!(encoded.shape(), &[4, 6]);
1189
1190 assert_abs_diff_eq!(encoded[[0, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 2]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 3]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 4]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 5]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 1]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 2]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 3]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 4]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 5]], 0.0, epsilon = 1e-10); }
1206
1207 #[test]
1208 fn test_one_hot_encoder_drop_first() {
1209 let data = Array::from_shape_vec((3, 2), vec![0.0, 1.0, 1.0, 2.0, 2.0, 1.0]).unwrap();
1211
1212 let mut encoder = OneHotEncoder::new(Some("first".to_string()), "error", false).unwrap();
1213 let encoded = encoder.fit_transform(&data).unwrap();
1214
1215 assert_eq!(encoded.shape(), &[3, 3]);
1217
1218 assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 2]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 2]], 1.0, epsilon = 1e-10); }
1231
1232 #[test]
1233 fn test_ordinal_encoder() {
1234 let data = Array::from_shape_vec(
1236 (4, 2),
1237 vec![
1238 2.0, 10.0, 1.0, 20.0, 3.0, 10.0, 2.0, 30.0,
1240 ],
1241 )
1242 .unwrap();
1243
1244 let mut encoder = OrdinalEncoder::with_defaults();
1245 let encoded = encoder.fit_transform(&data).unwrap();
1246
1247 assert_eq!(encoded.shape(), &[4, 2]);
1249
1250 assert_abs_diff_eq!(encoded[[0, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 1]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[2, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[3, 1]], 2.0, epsilon = 1e-10); }
1263
1264 #[test]
1265 fn test_unknown_category_handling() {
1266 let train_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
1267
1268 let test_data = Array::from_shape_vec(
1269 (1, 1),
1270 vec![3.0], )
1272 .unwrap();
1273
1274 let mut encoder = OneHotEncoder::with_defaults(); encoder.fit(&train_data).unwrap();
1277 assert!(encoder.transform(&test_data).is_err());
1278
1279 let mut encoder = OneHotEncoder::new(None, "ignore", false).unwrap();
1281 encoder.fit(&train_data).unwrap();
1282 let encoded = encoder.transform(&test_data).unwrap();
1283
1284 assert_eq!(encoded.shape(), &[1, 2]);
1286 assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10);
1287 assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10);
1288 }
1289
1290 #[test]
1291 fn test_ordinal_encoder_unknown_value() {
1292 let train_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
1293
1294 let test_data = Array::from_shape_vec(
1295 (1, 1),
1296 vec![3.0], )
1298 .unwrap();
1299
1300 let mut encoder = OrdinalEncoder::new("use_encoded_value", Some(-1.0)).unwrap();
1301 encoder.fit(&train_data).unwrap();
1302 let encoded = encoder.transform(&test_data).unwrap();
1303
1304 assert_eq!(encoded.shape(), &[1, 1]);
1306 assert_abs_diff_eq!(encoded[[0, 0]], -1.0, epsilon = 1e-10);
1307 }
1308
1309 #[test]
1310 fn test_get_feature_names() {
1311 let data = Array::from_shape_vec((2, 2), vec![1.0, 10.0, 2.0, 20.0]).unwrap();
1312
1313 let mut encoder = OneHotEncoder::with_defaults();
1314 encoder.fit(&data).unwrap();
1315
1316 let feature_names = encoder.get_feature_names(None).unwrap();
1317 assert_eq!(feature_names.len(), 4); let custom_names = vec!["feat_a".to_string(), "feat_b".to_string()];
1320 let feature_names = encoder.get_feature_names(Some(&custom_names)).unwrap();
1321 assert!(feature_names[0].starts_with("feat_a_cat_"));
1322 assert!(feature_names[2].starts_with("feat_b_cat_"));
1323 }
1324
1325 #[test]
1326 fn test_target_encoder_mean_strategy() {
1327 let x = Array::from_shape_vec((6, 1), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]).unwrap();
1329 let y = vec![1.0, 2.0, 3.0, 1.5, 2.5, 3.5];
1330
1331 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
1332 let encoded = encoder.fit_transform(&x, &y).unwrap();
1333
1334 assert_eq!(encoded.shape(), &[6, 1]);
1336
1337 assert_abs_diff_eq!(encoded[[0, 0]], 1.25, epsilon = 1e-10);
1343 assert_abs_diff_eq!(encoded[[1, 0]], 2.25, epsilon = 1e-10);
1344 assert_abs_diff_eq!(encoded[[2, 0]], 3.25, epsilon = 1e-10);
1345 assert_abs_diff_eq!(encoded[[3, 0]], 1.25, epsilon = 1e-10);
1346 assert_abs_diff_eq!(encoded[[4, 0]], 2.25, epsilon = 1e-10);
1347 assert_abs_diff_eq!(encoded[[5, 0]], 3.25, epsilon = 1e-10);
1348
1349 assert_abs_diff_eq!(encoder.global_mean(), 2.25, epsilon = 1e-10);
1351 }
1352
1353 #[test]
1354 fn test_target_encoder_median_strategy() {
1355 let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
1356 let y = vec![1.0, 2.0, 3.0, 4.0];
1357
1358 let mut encoder = TargetEncoder::new("median", 0.0, 0.0).unwrap();
1359 let encoded = encoder.fit_transform(&x, &y).unwrap();
1360
1361 assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
1365 assert_abs_diff_eq!(encoded[[1, 0]], 3.0, epsilon = 1e-10);
1366 assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10);
1367 assert_abs_diff_eq!(encoded[[3, 0]], 3.0, epsilon = 1e-10);
1368 }
1369
1370 #[test]
1371 fn test_target_encoder_count_strategy() {
1372 let x = Array::from_shape_vec((5, 1), vec![0.0, 1.0, 0.0, 2.0, 1.0]).unwrap();
1373 let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1374
1375 let mut encoder = TargetEncoder::new("count", 0.0, 0.0).unwrap();
1376 let encoded = encoder.fit_transform(&x, &y).unwrap();
1377
1378 assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
1383 assert_abs_diff_eq!(encoded[[1, 0]], 2.0, epsilon = 1e-10);
1384 assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10);
1385 assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10);
1386 assert_abs_diff_eq!(encoded[[4, 0]], 2.0, epsilon = 1e-10);
1387 }
1388
1389 #[test]
1390 fn test_target_encoder_sum_strategy() {
1391 let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
1392 let y = vec![1.0, 2.0, 3.0, 4.0];
1393
1394 let mut encoder = TargetEncoder::new("sum", 0.0, 0.0).unwrap();
1395 let encoded = encoder.fit_transform(&x, &y).unwrap();
1396
1397 assert_abs_diff_eq!(encoded[[0, 0]], 4.0, epsilon = 1e-10);
1401 assert_abs_diff_eq!(encoded[[1, 0]], 6.0, epsilon = 1e-10);
1402 assert_abs_diff_eq!(encoded[[2, 0]], 4.0, epsilon = 1e-10);
1403 assert_abs_diff_eq!(encoded[[3, 0]], 6.0, epsilon = 1e-10);
1404 }
1405
1406 #[test]
1407 fn test_target_encoder_smoothing() {
1408 let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
1409 let y = vec![1.0, 2.0, 3.0];
1410
1411 let mut encoder = TargetEncoder::new("mean", 1.0, 0.0).unwrap();
1412 let encoded = encoder.fit_transform(&x, &y).unwrap();
1413
1414 assert_abs_diff_eq!(encoded[[0, 0]], 1.5, epsilon = 1e-10);
1420 assert_abs_diff_eq!(encoded[[1, 0]], 2.0, epsilon = 1e-10);
1421 assert_abs_diff_eq!(encoded[[2, 0]], 2.5, epsilon = 1e-10);
1422 }
1423
#[test]
fn test_target_encoder_unknown_categories() {
    let train_features = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
    let train_targets = vec![1.0, 2.0, 3.0];

    // Categories 3 and 4 never appear in the training data.
    let unseen_features = Array::from_shape_vec((2, 1), vec![3.0, 4.0]).unwrap();

    let mut encoder = TargetEncoder::new("mean", 0.0, -1.0).unwrap();
    encoder.fit(&train_features, &train_targets).unwrap();
    let encoded = encoder.transform(&unseen_features).unwrap();

    // Both unseen rows must encode to -1.0, the value handed to the
    // constructor as its third argument.
    for row in 0..2 {
        assert_abs_diff_eq!(encoded[[row, 0]], -1.0, epsilon = 1e-10);
    }
}
1439
#[test]
fn test_target_encoder_unknown_categories_global_mean() {
    // Training targets 1.0, 2.0 and 3.0 average to 2.0.
    let train_features = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
    let train_targets = vec![1.0, 2.0, 3.0];

    // Category 3 is absent from the training data.
    let unseen_features = Array::from_shape_vec((1, 1), vec![3.0]).unwrap();

    let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
    encoder.fit(&train_features, &train_targets).unwrap();

    let encoded = encoder.transform(&unseen_features).unwrap();
    let unseen_value = encoded[[0, 0]];

    // The unseen category must come back as the mean of the training targets.
    assert_abs_diff_eq!(unseen_value, 2.0, epsilon = 1e-10);
}
1454
#[test]
fn test_target_encoder_multi_feature() {
    // Rows are (feature 0, feature 1): (0, 0), (1, 1), (0, 1), (1, 0).
    let features =
        Array::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]).unwrap();
    let targets = vec![1.0, 2.0, 3.0, 4.0];

    let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
    let encoded = encoder.fit_transform(&features, &targets).unwrap();

    // The output keeps the input layout: one encoded column per input column.
    assert_eq!(encoded.shape(), &[4, 2]);

    // (row, column, expected per-category target mean):
    //   feature 0: category 0 -> (1 + 3) / 2 = 2.0, category 1 -> (2 + 4) / 2 = 3.0
    //   feature 1: category 0 -> (1 + 4) / 2 = 2.5, category 1 -> (2 + 3) / 2 = 2.5
    for &(row, col, expected) in &[(0, 0, 2.0), (0, 1, 2.5), (1, 0, 3.0), (1, 1, 2.5)] {
        assert_abs_diff_eq!(encoded[[row, col]], expected, epsilon = 1e-10);
    }
}
1476
#[test]
fn test_target_encoder_cross_validation() {
    // Categories alternate 0, 1, 0, 1, ...; every category-0 target (<= 1.5)
    // lies below every category-1 target (>= 2.0).
    let features = Array::from_shape_vec(
        (10, 1),
        vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
    )
    .unwrap();
    let targets = vec![1.0, 2.0, 1.5, 2.5, 1.2, 2.2, 1.3, 2.3, 1.1, 2.1];

    let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
    let encoded = encoder.fit_transform_cv(&features, &targets, 5).unwrap();

    // Cross-validated encoding must still yield one value per input cell.
    assert_eq!(encoded.shape(), &[10, 1]);

    // Within each checked pair, the category-0 row must encode strictly
    // lower than the category-1 row.
    for &(low_row, high_row) in &[(0, 1), (2, 3)] {
        assert!(encoded[[low_row, 0]] < encoded[[high_row, 0]]);
    }
}
1498
#[test]
fn test_target_encoder_convenience_methods() {
    // The convenience constructors are checked purely through the fields
    // they set, so no feature or target data is needed here.

    // `with_mean` must select the "mean" strategy and keep the given smoothing.
    let mean_encoder = TargetEncoder::with_mean(1.0);
    assert_eq!(mean_encoder.strategy, "mean");
    assert_abs_diff_eq!(mean_encoder.smoothing, 1.0, epsilon = 1e-10);

    // `with_median` must select the "median" strategy and keep the given smoothing.
    let median_encoder = TargetEncoder::with_median(0.5);
    assert_eq!(median_encoder.strategy, "median");
    assert_abs_diff_eq!(median_encoder.smoothing, 0.5, epsilon = 1e-10);
}
1512
#[test]
fn test_target_encoder_validation_errors() {
    // The constructor must reject an unrecognised strategy name ...
    assert!(TargetEncoder::new("invalid", 0.0, 0.0).is_err());
    // ... and a negative second argument.
    assert!(TargetEncoder::new("mean", -1.0, 0.0).is_err());

    let features = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();

    // Fitting must fail when there are fewer targets (2) than feature rows (3).
    {
        let short_targets = vec![1.0, 2.0];
        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
        assert!(encoder.fit(&features, &short_targets).is_err());
    }

    // Transforming with an encoder that was never fitted must fail.
    {
        let unfitted = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
        assert!(unfitted.transform(&features).is_err());
    }

    // Transforming must fail when the column count differs from the fitted
    // data (fitted on 1 column, asked to transform 2).
    {
        let train_features = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
        let wide_features = Array::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
        let train_targets = vec![1.0, 2.0];

        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
        encoder.fit(&train_features, &train_targets).unwrap();
        assert!(encoder.transform(&wide_features).is_err());
    }

    // Cross-validated fitting must fail when asked for a single fold.
    {
        let cv_features = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
        let cv_targets = vec![1.0, 2.0, 3.0, 4.0];
        let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
        assert!(encoder.fit_transform_cv(&cv_features, &cv_targets, 1).is_err());
    }
}
1547
#[test]
fn test_target_encoder_accessors() {
    let features = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
    let targets = vec![1.0, 2.0, 3.0];

    let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();

    // Before fitting: not marked as fitted and no encodings exposed.
    assert!(!encoder.is_fitted());
    assert!(encoder.encodings().is_none());

    encoder.fit(&features, &targets).unwrap();

    // After fitting: fitted flag set, encodings available, and the global
    // mean equals the mean of the targets (1 + 2 + 3) / 3 = 2.0.
    assert!(encoder.is_fitted());
    assert!(encoder.encodings().is_some());
    assert_abs_diff_eq!(encoder.global_mean(), 2.0, epsilon = 1e-10);

    // One entry per input column, holding one item per distinct category.
    let per_feature = encoder.encodings().unwrap();
    assert_eq!(per_feature.len(), 1);
    let first_feature = &per_feature[0];
    assert_eq!(first_feature.len(), 3);
}
1568
#[test]
fn test_target_encoder_empty_data() {
    // Zero rows and a matching empty target vector.
    let no_rows = Array2::<f64>::zeros((0, 1));
    let no_targets = vec![];

    let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
    let outcome = encoder.fit(&no_rows, &no_targets);

    // Fitting on empty input must be reported as an error.
    assert!(outcome.is_err());
}
1577
#[test]
fn test_binary_encoder_basic() {
    // Four distinct categories in a single column.
    let input = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();
    let encoded = encoder.fit_transform(&input).unwrap();

    // Four categories must fit into two binary columns.
    assert_eq!(encoded.shape(), &[4, 2]);

    // Row k must hold the 2-bit pattern of k, most significant bit first.
    let expected_bits = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
    for (row, bits) in expected_bits.iter().enumerate() {
        for (col, &bit) in bits.iter().enumerate() {
            assert_abs_diff_eq!(encoded[[row, col]], bit, epsilon = 1e-10);
        }
    }
}
1601
#[test]
fn test_binary_encoder_power_of_two() {
    // Eight distinct categories: exactly 2^3.
    let input =
        Array::from_shape_vec((8, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();
    let encoded = encoder.fit_transform(&input).unwrap();

    // Eight categories must need exactly three binary columns.
    assert_eq!(encoded.shape(), &[8, 3]);

    // The first row must be all zeros (000) and the last row all ones (111).
    for &(row, bit) in &[(0, 0.0), (7, 1.0)] {
        for col in 0..3 {
            assert_abs_diff_eq!(encoded[[row, col]], bit, epsilon = 1e-10);
        }
    }
}
1623
#[test]
fn test_binary_encoder_non_power_of_two() {
    // Five distinct categories: more than 2^2, fewer than 2^3.
    let input = Array::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();
    let encoded = encoder.fit_transform(&input).unwrap();

    // Five categories must round up to three binary columns, and the
    // accessor must report the same width as the produced matrix.
    assert_eq!(encoded.shape(), &[5, 3]);
    let reported_width = encoder.n_output_features().unwrap();
    assert_eq!(reported_width, 3);
}
1636
#[test]
fn test_binary_encoder_single_category() {
    // Every row holds the same category value.
    let input = Array::from_shape_vec((3, 1), vec![5.0, 5.0, 5.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();
    let encoded = encoder.fit_transform(&input).unwrap();

    // A lone category must still produce one output column.
    assert_eq!(encoded.shape(), &[3, 1]);
    assert_eq!(encoder.n_output_features().unwrap(), 1);

    // That single column must be zero for every row.
    (0..3).for_each(|row| {
        assert_abs_diff_eq!(encoded[[row, 0]], 0.0, epsilon = 1e-10);
    });
}
1654
#[test]
fn test_binary_encoder_multi_feature() {
    // Column 0 takes the values {0, 1, 2}; column 1 takes the values {10, 11}.
    let cells = vec![0.0, 10.0, 1.0, 11.0, 2.0, 10.0, 0.0, 11.0];
    let input = Array::from_shape_vec((4, 2), cells).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();
    let encoded = encoder.fit_transform(&input).unwrap();

    // Two bits for the first column plus one bit for the second: three
    // output columns in total.
    assert_eq!(encoded.shape(), &[4, 3]);
    assert_eq!(encoder.n_output_features().unwrap(), 3);

    // The per-column bit widths must be reported individually as well.
    assert_eq!(encoder.n_binary_features().unwrap(), &[2, 1]);
}
1678
#[test]
fn test_binary_encoder_separate_fit_transform() {
    let train_input = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
    // Only categories already present in the training data, in a different order.
    let test_input = Array::from_shape_vec((2, 1), vec![1.0, 0.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();

    encoder.fit(&train_input).unwrap();
    assert!(encoder.is_fitted());

    // Three fitted categories give two binary columns for any later input.
    let test_encoded = encoder.transform(&test_input).unwrap();
    assert_eq!(test_encoded.shape(), &[2, 2]);

    // Category 1 sits in test row 0 and in training row 1; both rows must
    // receive the same bit pattern.
    let train_encoded = encoder.transform(&train_input).unwrap();
    for col in 0..2 {
        assert_abs_diff_eq!(
            test_encoded[[0, col]],
            train_encoded[[1, col]],
            epsilon = 1e-10
        );
    }
}
1699
#[test]
fn test_binary_encoder_unknown_categories_error() {
    let train_input = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
    // Category 2 is absent from the training data.
    let unseen_input = Array::from_shape_vec((1, 1), vec![2.0]).unwrap();

    let mut encoder = BinaryEncoder::new("error").unwrap();
    encoder.fit(&train_input).unwrap();

    // In "error" mode an unseen category must make the transform fail.
    let outcome = encoder.transform(&unseen_input);
    assert!(outcome.is_err());
}
1711
#[test]
fn test_binary_encoder_unknown_categories_ignore() {
    let train_input = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
    // Category 2 is absent from the training data.
    let unseen_input = Array::from_shape_vec((1, 1), vec![2.0]).unwrap();

    let mut encoder = BinaryEncoder::new("ignore").unwrap();
    encoder.fit(&train_input).unwrap();
    let encoded = encoder.transform(&unseen_input).unwrap();

    // In "ignore" mode the transform must succeed: two fitted categories
    // give a single output column, and the unseen row must be zero in it.
    assert_eq!(encoded.shape(), &[1, 1]);
    let unseen_bit = encoded[[0, 0]];
    assert_abs_diff_eq!(unseen_bit, 0.0, epsilon = 1e-10);
}
1725
#[test]
fn test_binary_encoder_categories_accessor() {
    let input = Array::from_shape_vec((3, 1), vec![10.0, 20.0, 30.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();

    // Before fitting, no accessor may expose any state.
    assert!(!encoder.is_fitted());
    assert!(encoder.categories().is_none());
    assert!(encoder.n_binary_features().is_none());
    assert!(encoder.n_output_features().is_none());

    encoder.fit(&input).unwrap();

    // After fitting, every accessor must be populated.
    assert!(encoder.is_fitted());
    assert!(encoder.categories().is_some());
    assert!(encoder.n_binary_features().is_some());
    assert!(encoder.n_output_features().is_some());

    // One category map per input column, holding three distinct categories.
    let per_feature = encoder.categories().unwrap();
    assert_eq!(per_feature.len(), 1);
    assert_eq!(per_feature[0].len(), 3);

    // The map must be keyed by the integer form of each category value.
    let first_map = &per_feature[0];
    for key in &[10, 20, 30] {
        assert!(first_map.contains_key(key));
    }
}
1756
#[test]
fn test_binary_encoder_int_to_binary() {
    // (value, number of bits, expected bits with the most significant bit first)
    let cases = [
        (0, 3, vec![0, 0, 0]),
        (1, 3, vec![0, 0, 1]),
        (2, 3, vec![0, 1, 0]),
        (3, 3, vec![0, 1, 1]),
        (7, 3, vec![1, 1, 1]),
        // A wider field and the narrowest possible field.
        (5, 4, vec![0, 1, 0, 1]),
        (1, 1, vec![1]),
    ];

    for (value, n_bits, expected) in cases.iter() {
        assert_eq!(BinaryEncoder::int_to_binary(*value, *n_bits), *expected);
    }
}
1770
#[test]
fn test_binary_encoder_validation_errors() {
    // The constructor must reject an unrecognised unknown-category mode.
    assert!(BinaryEncoder::new("invalid").is_err());

    // Fitting on input with zero rows must fail.
    {
        let no_rows = Array2::<f64>::zeros((0, 1));
        let mut encoder = BinaryEncoder::with_defaults();
        assert!(encoder.fit(&no_rows).is_err());
    }

    // Transforming with an encoder that was never fitted must fail.
    {
        let input = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
        let unfitted = BinaryEncoder::with_defaults();
        assert!(unfitted.transform(&input).is_err());
    }

    // Transforming must fail when the column count differs from the fitted
    // data (fitted on 1 column, asked to transform 2).
    {
        let train_input = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
        let wide_input = Array::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();

        let mut encoder = BinaryEncoder::with_defaults();
        encoder.fit(&train_input).unwrap();
        assert!(encoder.transform(&wide_input).is_err());
    }
}
1794
#[test]
fn test_binary_encoder_consistency() {
    // Category 1 appears twice, in rows 1 and 3.
    let input = Array::from_shape_vec((4, 1), vec![3.0, 1.0, 4.0, 1.0]).unwrap();

    let mut encoder = BinaryEncoder::with_defaults();
    let from_fit_transform = encoder.fit_transform(&input).unwrap();
    let from_transform = encoder.transform(&input).unwrap();

    // Re-transforming the same data must reproduce the fit_transform
    // output cell by cell.
    let n_rows = from_fit_transform.shape()[0];
    let n_cols = from_fit_transform.shape()[1];
    for row in 0..n_rows {
        for col in 0..n_cols {
            assert_abs_diff_eq!(
                from_fit_transform[[row, col]],
                from_transform[[row, col]],
                epsilon = 1e-10
            );
        }
    }

    // Rows holding the same category must share the same bit pattern.
    for col in 0..2 {
        assert_abs_diff_eq!(
            from_fit_transform[[1, col]],
            from_fit_transform[[3, col]],
            epsilon = 1e-10
        );
    }
}
1815
#[test]
fn test_binary_encoder_memory_efficiency() {
    // Ten distinct categories in a single column.
    let input = Array::from_shape_vec(
        (10, 1),
        vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
    )
    .unwrap();

    let mut binary_encoder = BinaryEncoder::with_defaults();
    let binary_encoded = binary_encoder.fit_transform(&input).unwrap();

    let mut onehot_encoder = OneHotEncoder::with_defaults();
    let onehot_encoded = onehot_encoder.fit_transform(&input).unwrap();

    // Binary encoding must need 4 columns where one-hot encoding needs 10,
    // i.e. strictly fewer.
    let binary_width = binary_encoded.shape()[1];
    let onehot_width = onehot_encoded.shape()[1];
    assert_eq!(binary_width, 4);
    assert_eq!(onehot_width, 10);
    assert!(binary_width < onehot_width);
}
1837}