1use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8use std::collections::HashMap;
9
10use crate::error::{Result, TransformError};
11
/// Sparse matrix stored as three parallel coordinate lists.
///
/// Stored entry `k` has value `values[k]` and lives at position
/// (`row_indices[k]`, `col_indices[k]`).
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Dimensions of the matrix as `(rows, columns)`.
    pub shape: (usize, usize),
    /// Row index of each stored entry.
    pub row_indices: Vec<usize>,
    /// Column index of each stored entry.
    pub col_indices: Vec<usize>,
    /// Value of each stored entry.
    pub values: Vec<f64>,
}
24
25impl SparseMatrix {
26 pub fn new(shape: (usize, usize)) -> Self {
28 SparseMatrix {
29 shape,
30 row_indices: Vec::new(),
31 col_indices: Vec::new(),
32 values: Vec::new(),
33 }
34 }
35
36 pub fn push(&mut self, row: usize, col: usize, value: f64) {
38 if row < self.shape.0 && col < self.shape.1 && value != 0.0 {
39 self.row_indices.push(row);
40 self.col_indices.push(col);
41 self.values.push(value);
42 }
43 }
44
45 pub fn to_dense(&self) -> Array2<f64> {
47 let mut dense = Array2::zeros(self.shape);
48 for ((&row, &col), &val) in self
49 .row_indices
50 .iter()
51 .zip(self.col_indices.iter())
52 .zip(self.values.iter())
53 {
54 dense[[row, col]] = val;
55 }
56 dense
57 }
58
59 pub fn nnz(&self) -> usize {
61 self.values.len()
62 }
63}
64
/// Encoded data in either dense or sparse representation.
#[derive(Debug, Clone)]
pub enum EncodedOutput {
    /// Dense two-dimensional array of encoded values.
    Dense(Array2<f64>),
    /// Coordinate-list sparse matrix of encoded values.
    Sparse(SparseMatrix),
}
73
74impl EncodedOutput {
75 pub fn to_dense(&self) -> Array2<f64> {
77 match self {
78 EncodedOutput::Dense(arr) => arr.clone(),
79 EncodedOutput::Sparse(sparse) => sparse.to_dense(),
80 }
81 }
82
83 pub fn shape(&self) -> (usize, usize) {
85 match self {
86 EncodedOutput::Dense(arr) => (arr.nrows(), arr.ncols()),
87 EncodedOutput::Sparse(sparse) => sparse.shape,
88 }
89 }
90}
91
/// One-hot encoder for integer-coded categorical features.
///
/// Derives `Debug` and `Clone` for parity with the other encoder types in
/// this module.
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    /// Sorted, de-duplicated category codes per feature; `None` until fitted.
    categories_: Option<Vec<Vec<u64>>>,
    /// Drop strategy: `Some("first")`, `Some("if_binary")`, or `None`.
    drop: Option<String>,
    /// Unknown-category policy: `"error"` or `"ignore"`.
    handleunknown: String,
    /// Whether `transform` produces sparse output instead of dense.
    sparse: bool,
}
106
107impl OneHotEncoder {
108 pub fn new(_drop: Option<String>, handleunknown: &str, sparse: bool) -> Result<Self> {
118 if let Some(ref drop_strategy) = _drop {
119 if drop_strategy != "first" && drop_strategy != "if_binary" {
120 return Err(TransformError::InvalidInput(
121 "_drop must be 'first', 'if_binary', or None".to_string(),
122 ));
123 }
124 }
125
126 if handleunknown != "error" && handleunknown != "ignore" {
127 return Err(TransformError::InvalidInput(
128 "handleunknown must be 'error' or 'ignore'".to_string(),
129 ));
130 }
131
132 Ok(OneHotEncoder {
133 categories_: None,
134 drop: _drop,
135 handleunknown: handleunknown.to_string(),
136 sparse,
137 })
138 }
139
    /// Creates an encoder with no drop strategy, `"error"` handling of
    /// unknown categories, and dense output.
    ///
    /// The `unwrap` cannot fail: these arguments pass the validation in `new`.
    pub fn with_defaults() -> Self {
        Self::new(None, "error", false).unwrap()
    }
144
145 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
153 where
154 S: Data,
155 S::Elem: Float + NumCast,
156 {
157 let x_u64 = x.mapv(|x| {
158 let val_f64 = NumCast::from(x).unwrap_or(0.0);
159 val_f64 as u64
160 });
161
162 let n_samples = x_u64.shape()[0];
163 let n_features = x_u64.shape()[1];
164
165 if n_samples == 0 || n_features == 0 {
166 return Err(TransformError::InvalidInput("Empty input data".to_string()));
167 }
168
169 let mut categories = Vec::with_capacity(n_features);
170
171 for j in 0..n_features {
172 let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
174 unique_values.sort_unstable();
175 unique_values.dedup();
176
177 categories.push(unique_values);
178 }
179
180 self.categories_ = Some(categories);
181 Ok(())
182 }
183
184 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<EncodedOutput>
192 where
193 S: Data,
194 S::Elem: Float + NumCast,
195 {
196 let x_u64 = x.mapv(|x| {
197 let val_f64 = NumCast::from(x).unwrap_or(0.0);
198 val_f64 as u64
199 });
200
201 let n_samples = x_u64.shape()[0];
202 let n_features = x_u64.shape()[1];
203
204 if self.categories_.is_none() {
205 return Err(TransformError::TransformationError(
206 "OneHotEncoder has not been fitted".to_string(),
207 ));
208 }
209
210 let categories = self.categories_.as_ref().unwrap();
211
212 if n_features != categories.len() {
213 return Err(TransformError::InvalidInput(format!(
214 "x has {} features, but OneHotEncoder was fitted with {} features",
215 n_features,
216 categories.len()
217 )));
218 }
219
220 let mut total_features = 0;
222 for (j, feature_categories) in categories.iter().enumerate() {
223 let n_cats = feature_categories.len();
224
225 let n_output_cats = match &self.drop {
227 Some(strategy) if strategy == "first" => n_cats.saturating_sub(1),
228 Some(strategy) if strategy == "if_binary" && n_cats == 2 => 1,
229 _ => n_cats,
230 };
231
232 if n_output_cats == 0 {
233 return Err(TransformError::InvalidInput(format!(
234 "Feature {j} has only one category after dropping"
235 )));
236 }
237
238 total_features += n_output_cats;
239 }
240
241 let mut category_mappings = Vec::new();
243 let mut current_col = 0;
244
245 for feature_categories in categories.iter() {
246 let mut mapping = HashMap::new();
247 let n_cats = feature_categories.len();
248
249 let (start_idx, n_output_cats) = match &self.drop {
251 Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
252 Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
253 _ => (0, n_cats),
254 };
255
256 for (cat_idx, &category) in feature_categories.iter().enumerate() {
257 if cat_idx >= start_idx && cat_idx < start_idx + n_output_cats {
258 mapping.insert(category, current_col + cat_idx - start_idx);
259 }
260 }
261
262 category_mappings.push(mapping);
263 current_col += n_output_cats;
264 }
265
266 if self.sparse {
268 let mut sparse_matrix = SparseMatrix::new((n_samples, total_features));
270
271 for i in 0..n_samples {
272 for j in 0..n_features {
273 let value = x_u64[[i, j]];
274
275 if let Some(&col_idx) = category_mappings[j].get(&value) {
276 sparse_matrix.push(i, col_idx, 1.0);
277 } else {
278 let feature_categories = &categories[j];
280 let is_dropped_category = match &self.drop {
281 Some(strategy) if strategy == "first" => {
282 !feature_categories.is_empty() && value == feature_categories[0]
283 }
284 Some(strategy)
285 if strategy == "if_binary" && feature_categories.len() == 2 =>
286 {
287 feature_categories.len() == 2 && value == feature_categories[1]
288 }
289 _ => false,
290 };
291
292 if !is_dropped_category && self.handleunknown == "error" {
293 return Err(TransformError::InvalidInput(format!(
294 "Found unknown category {value} in feature {j}"
295 )));
296 }
297 }
299 }
300 }
301
302 Ok(EncodedOutput::Sparse(sparse_matrix))
303 } else {
304 let mut transformed = Array2::zeros((n_samples, total_features));
306
307 for i in 0..n_samples {
308 for j in 0..n_features {
309 let value = x_u64[[i, j]];
310
311 if let Some(&col_idx) = category_mappings[j].get(&value) {
312 transformed[[i, col_idx]] = 1.0;
313 } else {
314 let feature_categories = &categories[j];
316 let is_dropped_category = match &self.drop {
317 Some(strategy) if strategy == "first" => {
318 !feature_categories.is_empty() && value == feature_categories[0]
319 }
320 Some(strategy)
321 if strategy == "if_binary" && feature_categories.len() == 2 =>
322 {
323 feature_categories.len() == 2 && value == feature_categories[1]
324 }
325 _ => false,
326 };
327
328 if !is_dropped_category && self.handleunknown == "error" {
329 return Err(TransformError::InvalidInput(format!(
330 "Found unknown category {value} in feature {j}"
331 )));
332 }
333 }
335 }
336 }
337
338 Ok(EncodedOutput::Dense(transformed))
339 }
340 }
341
    /// Fits the encoder on `x`, then encodes that same data.
    ///
    /// # Errors
    /// Propagates any error returned by `fit` or `transform`.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<EncodedOutput>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }
357
    /// Encodes `x` and always returns a dense array, converting sparse
    /// output when the encoder was configured with `sparse = true`.
    ///
    /// # Errors
    /// Propagates any error returned by `transform`.
    pub fn transform_dense<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        Ok(self.transform(x)?.to_dense())
    }
372
    /// Fits the encoder on `x`, encodes it, and returns the result as a
    /// dense array regardless of the `sparse` setting.
    ///
    /// # Errors
    /// Propagates any error returned by `fit_transform`.
    pub fn fit_transform_dense<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        Ok(self.fit_transform(x)?.to_dense())
    }
387
    /// Returns the learned category codes per feature, or `None` when the
    /// encoder has not been fitted.
    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
        self.categories_.as_ref()
    }
395
396 pub fn get_feature_names(&self, inputfeatures: Option<&[String]>) -> Result<Vec<String>> {
404 if self.categories_.is_none() {
405 return Err(TransformError::TransformationError(
406 "OneHotEncoder has not been fitted".to_string(),
407 ));
408 }
409
410 let categories = self.categories_.as_ref().unwrap();
411 let mut feature_names = Vec::new();
412
413 for (j, feature_categories) in categories.iter().enumerate() {
414 let feature_name = if let Some(names) = inputfeatures {
415 if j < names.len() {
416 names[j].clone()
417 } else {
418 format!("x{j}")
419 }
420 } else {
421 format!("x{j}")
422 };
423
424 let n_cats = feature_categories.len();
425
426 let (start_idx, n_output_cats) = match &self.drop {
428 Some(strategy) if strategy == "first" => (1, n_cats.saturating_sub(1)),
429 Some(strategy) if strategy == "if_binary" && n_cats == 2 => (0, 1),
430 _ => (0, n_cats),
431 };
432
433 for &category in feature_categories
434 .iter()
435 .skip(start_idx)
436 .take(n_output_cats)
437 {
438 feature_names.push(format!("{feature_name}_cat_{category}"));
439 }
440 }
441
442 Ok(feature_names)
443 }
444}
445
/// Ordinal encoder that replaces each category code by its rank among the
/// sorted categories of its feature.
///
/// Derives `Debug` and `Clone` for parity with the other encoder types in
/// this module.
#[derive(Debug, Clone)]
pub struct OrdinalEncoder {
    /// Sorted, de-duplicated category codes per feature; `None` until fitted.
    categories_: Option<Vec<Vec<u64>>>,
    /// Unknown-category policy: `"error"` or `"use_encoded_value"`.
    handleunknown: String,
    /// Value emitted for unknown categories under `"use_encoded_value"`.
    unknownvalue: Option<f64>,
}
458
459impl OrdinalEncoder {
460 pub fn new(handleunknown: &str, unknownvalue: Option<f64>) -> Result<Self> {
469 if handleunknown != "error" && handleunknown != "use_encoded_value" {
470 return Err(TransformError::InvalidInput(
471 "handleunknown must be 'error' or 'use_encoded_value'".to_string(),
472 ));
473 }
474
475 if handleunknown == "use_encoded_value" && unknownvalue.is_none() {
476 return Err(TransformError::InvalidInput(
477 "unknownvalue must be specified when handleunknown='use_encoded_value'".to_string(),
478 ));
479 }
480
481 Ok(OrdinalEncoder {
482 categories_: None,
483 handleunknown: handleunknown.to_string(),
484 unknownvalue,
485 })
486 }
487
    /// Creates an encoder that raises an error on unknown categories.
    ///
    /// The `unwrap` cannot fail: these arguments pass the validation in `new`.
    pub fn with_defaults() -> Self {
        Self::new("error", None).unwrap()
    }
492
493 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
501 where
502 S: Data,
503 S::Elem: Float + NumCast,
504 {
505 let x_u64 = x.mapv(|x| {
506 let val_f64 = NumCast::from(x).unwrap_or(0.0);
507 val_f64 as u64
508 });
509
510 let n_samples = x_u64.shape()[0];
511 let n_features = x_u64.shape()[1];
512
513 if n_samples == 0 || n_features == 0 {
514 return Err(TransformError::InvalidInput("Empty input data".to_string()));
515 }
516
517 let mut categories = Vec::with_capacity(n_features);
518
519 for j in 0..n_features {
520 let mut unique_values: Vec<u64> = x_u64.column(j).to_vec();
522 unique_values.sort_unstable();
523 unique_values.dedup();
524
525 categories.push(unique_values);
526 }
527
528 self.categories_ = Some(categories);
529 Ok(())
530 }
531
532 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
540 where
541 S: Data,
542 S::Elem: Float + NumCast,
543 {
544 let x_u64 = x.mapv(|x| {
545 let val_f64 = NumCast::from(x).unwrap_or(0.0);
546 val_f64 as u64
547 });
548
549 let n_samples = x_u64.shape()[0];
550 let n_features = x_u64.shape()[1];
551
552 if self.categories_.is_none() {
553 return Err(TransformError::TransformationError(
554 "OrdinalEncoder has not been fitted".to_string(),
555 ));
556 }
557
558 let categories = self.categories_.as_ref().unwrap();
559
560 if n_features != categories.len() {
561 return Err(TransformError::InvalidInput(format!(
562 "x has {} features, but OrdinalEncoder was fitted with {} features",
563 n_features,
564 categories.len()
565 )));
566 }
567
568 let mut transformed = Array2::zeros((n_samples, n_features));
569
570 let mut category_mappings = Vec::new();
572 for feature_categories in categories {
573 let mut mapping = HashMap::new();
574 for (ordinal, &category) in feature_categories.iter().enumerate() {
575 mapping.insert(category, ordinal as f64);
576 }
577 category_mappings.push(mapping);
578 }
579
580 for i in 0..n_samples {
582 for j in 0..n_features {
583 let value = x_u64[[i, j]];
584
585 if let Some(&ordinal_value) = category_mappings[j].get(&value) {
586 transformed[[i, j]] = ordinal_value;
587 } else if self.handleunknown == "error" {
588 return Err(TransformError::InvalidInput(format!(
589 "Found unknown category {value} in feature {j}"
590 )));
591 } else {
592 transformed[[i, j]] = self.unknownvalue.unwrap();
594 }
595 }
596 }
597
598 Ok(transformed)
599 }
600
    /// Fits the encoder on `x`, then encodes that same data.
    ///
    /// # Errors
    /// Propagates any error returned by `fit` or `transform`.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }
616
    /// Returns the learned category codes per feature, or `None` when the
    /// encoder has not been fitted.
    pub fn categories(&self) -> Option<&Vec<Vec<u64>>> {
        self.categories_.as_ref()
    }
624}
625
/// Target encoder: replaces each category code by a statistic of the target
/// values observed for that category.
#[derive(Debug, Clone)]
pub struct TargetEncoder {
    /// Statistic to compute: `"mean"`, `"median"`, `"count"`, or `"sum"`.
    strategy: String,
    /// Weight used to shrink mean/median encodings toward the global mean;
    /// no shrinkage is applied when it is not greater than zero.
    smoothing: f64,
    /// Value used by `transform` for unseen categories when non-zero;
    /// a value of `0.0` means the global mean is used instead.
    globalstat: f64,
    /// Per-feature map from category code to encoded value; `None` until fitted.
    encodings_: Option<Vec<HashMap<u64, f64>>>,
    /// Set to `true` once `fit` has completed.
    is_fitted: bool,
    /// Mean of the target values seen during fitting.
    global_mean_: f64,
}
658
659impl TargetEncoder {
660 pub fn new(_strategy: &str, smoothing: f64, globalstat: f64) -> Result<Self> {
670 if !["mean", "median", "count", "sum"].contains(&_strategy) {
671 return Err(TransformError::InvalidInput(
672 "_strategy must be 'mean', 'median', 'count', or 'sum'".to_string(),
673 ));
674 }
675
676 if smoothing < 0.0 {
677 return Err(TransformError::InvalidInput(
678 "smoothing parameter must be non-negative".to_string(),
679 ));
680 }
681
682 Ok(TargetEncoder {
683 strategy: _strategy.to_string(),
684 smoothing,
685 globalstat,
686 encodings_: None,
687 is_fitted: false,
688 global_mean_: 0.0,
689 })
690 }
691
692 pub fn with_mean(smoothing: f64) -> Self {
694 TargetEncoder {
695 strategy: "mean".to_string(),
696 smoothing,
697 globalstat: 0.0,
698 encodings_: None,
699 is_fitted: false,
700 global_mean_: 0.0,
701 }
702 }
703
704 pub fn with_median(smoothing: f64) -> Self {
706 TargetEncoder {
707 strategy: "median".to_string(),
708 smoothing,
709 globalstat: 0.0,
710 encodings_: None,
711 is_fitted: false,
712 global_mean_: 0.0,
713 }
714 }
715
716 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
725 where
726 S: Data,
727 S::Elem: Float + NumCast,
728 {
729 let x_u64 = x.mapv(|x| {
730 let val_f64 = NumCast::from(x).unwrap_or(0.0);
731 val_f64 as u64
732 });
733
734 let n_samples = x_u64.shape()[0];
735 let n_features = x_u64.shape()[1];
736
737 if n_samples == 0 || n_features == 0 {
738 return Err(TransformError::InvalidInput("Empty input data".to_string()));
739 }
740
741 if y.len() != n_samples {
742 return Err(TransformError::InvalidInput(
743 "Number of target values must match number of samples".to_string(),
744 ));
745 }
746
747 self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
749
750 let mut encodings = Vec::with_capacity(n_features);
751
752 for j in 0..n_features {
753 let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
755
756 for i in 0..n_samples {
757 let category = x_u64[[i, j]];
758 category_targets.entry(category).or_default().push(y[i]);
759 }
760
761 let mut category_encoding = HashMap::new();
763
764 for (category, targets) in category_targets.iter() {
765 let encoded_value = match self.strategy.as_str() {
766 "mean" => {
767 let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
768 let count = targets.len() as f64;
769
770 if self.smoothing > 0.0 {
772 (count * category_mean + self.smoothing * self.global_mean_)
773 / (count + self.smoothing)
774 } else {
775 category_mean
776 }
777 }
778 "median" => {
779 let mut sorted_targets = targets.clone();
780 sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
781
782 let median = if sorted_targets.len() % 2 == 0 {
783 let mid = sorted_targets.len() / 2;
784 (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
785 } else {
786 sorted_targets[sorted_targets.len() / 2]
787 };
788
789 if self.smoothing > 0.0 {
791 let count = targets.len() as f64;
792 (count * median + self.smoothing * self.global_mean_)
793 / (count + self.smoothing)
794 } else {
795 median
796 }
797 }
798 "count" => targets.len() as f64,
799 "sum" => targets.iter().sum::<f64>(),
800 _ => unreachable!(),
801 };
802
803 category_encoding.insert(*category, encoded_value);
804 }
805
806 encodings.push(category_encoding);
807 }
808
809 self.encodings_ = Some(encodings);
810 self.is_fitted = true;
811 Ok(())
812 }
813
814 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
822 where
823 S: Data,
824 S::Elem: Float + NumCast,
825 {
826 if !self.is_fitted {
827 return Err(TransformError::TransformationError(
828 "TargetEncoder has not been fitted".to_string(),
829 ));
830 }
831
832 let x_u64 = x.mapv(|x| {
833 let val_f64 = NumCast::from(x).unwrap_or(0.0);
834 val_f64 as u64
835 });
836
837 let n_samples = x_u64.shape()[0];
838 let n_features = x_u64.shape()[1];
839
840 let encodings = self.encodings_.as_ref().unwrap();
841
842 if n_features != encodings.len() {
843 return Err(TransformError::InvalidInput(format!(
844 "x has {} features, but TargetEncoder was fitted with {} features",
845 n_features,
846 encodings.len()
847 )));
848 }
849
850 let mut transformed = Array2::zeros((n_samples, n_features));
851
852 for i in 0..n_samples {
853 for j in 0..n_features {
854 let category = x_u64[[i, j]];
855
856 if let Some(&encoded_value) = encodings[j].get(&category) {
857 transformed[[i, j]] = encoded_value;
858 } else {
859 transformed[[i, j]] = if self.globalstat != 0.0 {
861 self.globalstat
862 } else {
863 self.global_mean_
864 };
865 }
866 }
867 }
868
869 Ok(transformed)
870 }
871
    /// Fits the encoder on `x` and `y`, then encodes that same `x`.
    ///
    /// # Errors
    /// Propagates any error returned by `fit` or `transform`.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x, y)?;
        self.transform(x)
    }
888
    /// Returns the per-feature maps from category code to encoded value, or
    /// `None` when the encoder has not been fitted.
    pub fn encodings(&self) -> Option<&Vec<HashMap<u64, f64>>> {
        self.encodings_.as_ref()
    }
896
    /// Returns `true` once the encoder has been fitted.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
901
    /// Returns the mean of the target values seen during fitting
    /// (`0.0` before any fit).
    pub fn global_mean(&self) -> f64 {
        self.global_mean_
    }
906
907 pub fn fit_transform_cv<S>(
921 &mut self,
922 x: &ArrayBase<S, Ix2>,
923 y: &[f64],
924 cv_folds: usize,
925 ) -> Result<Array2<f64>>
926 where
927 S: Data,
928 S::Elem: Float + NumCast,
929 {
930 let x_u64 = x.mapv(|x| {
931 let val_f64 = NumCast::from(x).unwrap_or(0.0);
932 val_f64 as u64
933 });
934
935 let n_samples = x_u64.shape()[0];
936 let n_features = x_u64.shape()[1];
937
938 if n_samples == 0 || n_features == 0 {
939 return Err(TransformError::InvalidInput("Empty input data".to_string()));
940 }
941
942 if y.len() != n_samples {
943 return Err(TransformError::InvalidInput(
944 "Number of target values must match number of samples".to_string(),
945 ));
946 }
947
948 if cv_folds < 2 {
949 return Err(TransformError::InvalidInput(
950 "cv_folds must be at least 2".to_string(),
951 ));
952 }
953
954 let mut transformed = Array2::zeros((n_samples, n_features));
955
956 self.global_mean_ = y.iter().sum::<f64>() / y.len() as f64;
958
959 let fold_size = n_samples / cv_folds;
961 let mut fold_indices = Vec::new();
962 for fold in 0..cv_folds {
963 let start = fold * fold_size;
964 let end = if fold == cv_folds - 1 {
965 n_samples
966 } else {
967 (fold + 1) * fold_size
968 };
969 fold_indices.push((start, end));
970 }
971
972 for fold in 0..cv_folds {
974 let (val_start, val_end) = fold_indices[fold];
975
976 let mut train_indices = Vec::new();
978 for (other_fold, &(start, end)) in fold_indices.iter().enumerate().take(cv_folds) {
979 if other_fold != fold {
980 train_indices.extend(start..end);
981 }
982 }
983
984 for j in 0..n_features {
986 let mut category_targets: HashMap<u64, Vec<f64>> = HashMap::new();
987
988 for &train_idx in &train_indices {
990 let category = x_u64[[train_idx, j]];
991 category_targets
992 .entry(category)
993 .or_default()
994 .push(y[train_idx]);
995 }
996
997 let mut category_encoding = HashMap::new();
999 for (category, targets) in category_targets.iter() {
1000 let encoded_value = match self.strategy.as_str() {
1001 "mean" => {
1002 let category_mean = targets.iter().sum::<f64>() / targets.len() as f64;
1003 let count = targets.len() as f64;
1004
1005 if self.smoothing > 0.0 {
1006 (count * category_mean + self.smoothing * self.global_mean_)
1007 / (count + self.smoothing)
1008 } else {
1009 category_mean
1010 }
1011 }
1012 "median" => {
1013 let mut sorted_targets = targets.clone();
1014 sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
1015
1016 let median = if sorted_targets.len() % 2 == 0 {
1017 let mid = sorted_targets.len() / 2;
1018 (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
1019 } else {
1020 sorted_targets[sorted_targets.len() / 2]
1021 };
1022
1023 if self.smoothing > 0.0 {
1024 let count = targets.len() as f64;
1025 (count * median + self.smoothing * self.global_mean_)
1026 / (count + self.smoothing)
1027 } else {
1028 median
1029 }
1030 }
1031 "count" => targets.len() as f64,
1032 "sum" => targets.iter().sum::<f64>(),
1033 _ => unreachable!(),
1034 };
1035
1036 category_encoding.insert(*category, encoded_value);
1037 }
1038
1039 for val_idx in val_start..val_end {
1041 let category = x_u64[[val_idx, j]];
1042
1043 if let Some(&encoded_value) = category_encoding.get(&category) {
1044 transformed[[val_idx, j]] = encoded_value;
1045 } else {
1046 transformed[[val_idx, j]] = self.global_mean_;
1048 }
1049 }
1050 }
1051 }
1052
1053 self.fit(x, y)?;
1055
1056 Ok(transformed)
1057 }
1058}
1059
/// Binary encoder: represents each category code by the bits of its rank
/// among the sorted categories of its feature.
#[derive(Debug, Clone)]
pub struct BinaryEncoder {
    /// Per-feature map from category code to its bit vector; `None` until fitted.
    categories_: Option<Vec<HashMap<u64, Vec<u8>>>>,
    /// Number of output bits per feature; `None` until fitted.
    n_binary_features_: Option<Vec<usize>>,
    /// Unknown-category policy: `"error"` or `"ignore"`.
    handleunknown: String,
    /// Set to `true` once `fit` has completed.
    is_fitted: bool,
}
1078
1079impl BinaryEncoder {
1080 pub fn new(handleunknown: &str) -> Result<Self> {
1090 if handleunknown != "error" && handleunknown != "ignore" {
1091 return Err(TransformError::InvalidInput(
1092 "handleunknown must be 'error' or 'ignore'".to_string(),
1093 ));
1094 }
1095
1096 Ok(BinaryEncoder {
1097 categories_: None,
1098 n_binary_features_: None,
1099 handleunknown: handleunknown.to_string(),
1100 is_fitted: false,
1101 })
1102 }
1103
    /// Creates an encoder that raises an error on unknown categories.
    ///
    /// The `unwrap` cannot fail: `"error"` passes the validation in `new`.
    pub fn with_defaults() -> Self {
        Self::new("error").unwrap()
    }
1108
1109 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
1117 where
1118 S: Data,
1119 S::Elem: Float + NumCast,
1120 {
1121 let x_u64 = x.mapv(|x| {
1122 let val_f64 = NumCast::from(x).unwrap_or(0.0);
1123 val_f64 as u64
1124 });
1125
1126 let n_samples = x_u64.shape()[0];
1127 let n_features = x_u64.shape()[1];
1128
1129 if n_samples == 0 || n_features == 0 {
1130 return Err(TransformError::InvalidInput("Empty input data".to_string()));
1131 }
1132
1133 let mut categories = Vec::with_capacity(n_features);
1134 let mut n_binary_features = Vec::with_capacity(n_features);
1135
1136 for j in 0..n_features {
1137 let mut unique_categories: Vec<u64> = x_u64.column(j).to_vec();
1139 unique_categories.sort_unstable();
1140 unique_categories.dedup();
1141
1142 if unique_categories.is_empty() {
1143 return Err(TransformError::InvalidInput(
1144 "Feature has no valid categories".to_string(),
1145 ));
1146 }
1147
1148 let n_cats = unique_categories.len();
1150 let nbits = if n_cats <= 1 {
1151 1
1152 } else {
1153 (n_cats as f64).log2().ceil() as usize
1154 };
1155
1156 let mut category_map = HashMap::new();
1158 for (idx, &category) in unique_categories.iter().enumerate() {
1159 let binary_code = Self::int_to_binary(idx, nbits);
1160 category_map.insert(category, binary_code);
1161 }
1162
1163 categories.push(category_map);
1164 n_binary_features.push(nbits);
1165 }
1166
1167 self.categories_ = Some(categories);
1168 self.n_binary_features_ = Some(n_binary_features);
1169 self.is_fitted = true;
1170
1171 Ok(())
1172 }
1173
1174 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1182 where
1183 S: Data,
1184 S::Elem: Float + NumCast,
1185 {
1186 if !self.is_fitted {
1187 return Err(TransformError::InvalidInput(
1188 "Encoder has not been fitted yet".to_string(),
1189 ));
1190 }
1191
1192 let categories = self.categories_.as_ref().unwrap();
1193 let n_binary_features = self.n_binary_features_.as_ref().unwrap();
1194
1195 let x_u64 = x.mapv(|x| {
1196 let val_f64 = NumCast::from(x).unwrap_or(0.0);
1197 val_f64 as u64
1198 });
1199
1200 let n_samples = x_u64.shape()[0];
1201 let n_features = x_u64.shape()[1];
1202
1203 if n_features != categories.len() {
1204 return Err(TransformError::InvalidInput(format!(
1205 "Number of features ({}) does not match fitted features ({})",
1206 n_features,
1207 categories.len()
1208 )));
1209 }
1210
1211 let total_binary_features: usize = n_binary_features.iter().sum();
1213 let mut result = Array2::<f64>::zeros((n_samples, total_binary_features));
1214
1215 let mut output_col = 0;
1216 for j in 0..n_features {
1217 let category_map = &categories[j];
1218 let nbits = n_binary_features[j];
1219
1220 for i in 0..n_samples {
1221 let category = x_u64[[i, j]];
1222
1223 if let Some(binary_code) = category_map.get(&category) {
1224 for (bit_idx, &bit_val) in binary_code.iter().enumerate() {
1226 result[[i, output_col + bit_idx]] = bit_val as f64;
1227 }
1228 } else {
1229 match self.handleunknown.as_str() {
1231 "error" => {
1232 return Err(TransformError::InvalidInput(format!(
1233 "Unknown category {category} in feature {j}"
1234 )));
1235 }
1236 "ignore" => {
1237 }
1239 _ => unreachable!(),
1240 }
1241 }
1242 }
1243
1244 output_col += nbits;
1245 }
1246
1247 Ok(result)
1248 }
1249
    /// Fits the encoder on `x`, then encodes that same data.
    ///
    /// # Errors
    /// Propagates any error returned by `fit` or `transform`.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }
1265
    /// Returns `true` once the encoder has been fitted.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
1270
    /// Returns the per-feature maps from category code to bit vector, or
    /// `None` when the encoder has not been fitted.
    pub fn categories(&self) -> Option<&Vec<HashMap<u64, Vec<u8>>>> {
        self.categories_.as_ref()
    }
1275
    /// Returns the number of output bits used for each feature, or `None`
    /// when the encoder has not been fitted.
    pub fn n_binary_features(&self) -> Option<&Vec<usize>> {
        self.n_binary_features_.as_ref()
    }
1280
    /// Returns the total number of output columns (the sum of the per-feature
    /// bit widths), or `None` when the encoder has not been fitted.
    pub fn n_output_features(&self) -> Option<usize> {
        self.n_binary_features_.as_ref().map(|v| v.iter().sum())
    }
1285
1286 fn int_to_binary(_value: usize, nbits: usize) -> Vec<u8> {
1288 let mut binary = Vec::with_capacity(nbits);
1289 let mut val = _value;
1290
1291 for _ in 0..nbits {
1292 binary.push((val & 1) as u8);
1293 val >>= 1;
1294 }
1295
1296 binary.reverse(); binary
1298 }
1299}
1300
/// Frequency encoder: replaces each category code by how often it occurred
/// in the fitted data.
#[derive(Debug, Clone)]
pub struct FrequencyEncoder {
    /// Per-feature map from category code to its frequency; `None` until fitted.
    frequency_maps_: Option<Vec<HashMap<u64, f64>>>,
    /// When `true`, frequencies are counts divided by the number of fitted
    /// rows; otherwise raw counts.
    normalize: bool,
    /// Unknown-category policy: `"error"`, `"ignore"`, or `"use_encoded_value"`.
    handleunknown: String,
    /// Value emitted for unknown categories under `"use_encoded_value"`.
    unknownvalue: f64,
    /// Set to `true` once `fit` has completed.
    is_fitted: bool,
}
1319
1320impl FrequencyEncoder {
1321 pub fn new(normalize: bool, handleunknown: &str, unknownvalue: f64) -> Result<Self> {
1331 if !["error", "ignore", "use_encoded_value"].contains(&handleunknown) {
1332 return Err(TransformError::InvalidInput(
1333 "handleunknown must be 'error', 'ignore', or 'use_encoded_value'".to_string(),
1334 ));
1335 }
1336
1337 Ok(FrequencyEncoder {
1338 frequency_maps_: None,
1339 normalize,
1340 handleunknown: handleunknown.to_string(),
1341 unknownvalue,
1342 is_fitted: false,
1343 })
1344 }
1345
    /// Creates an encoder that emits raw counts and raises an error on
    /// unknown categories.
    ///
    /// The `unwrap` cannot fail: these arguments pass the validation in `new`.
    pub fn with_defaults() -> Self {
        Self::new(false, "error", 0.0).unwrap()
    }
1350
    /// Creates an encoder that emits relative frequencies and raises an error
    /// on unknown categories.
    ///
    /// The `unwrap` cannot fail: these arguments pass the validation in `new`.
    pub fn with_normalization() -> Self {
        Self::new(true, "error", 0.0).unwrap()
    }
1355
1356 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
1364 where
1365 S: Data,
1366 S::Elem: Float + NumCast,
1367 {
1368 let x_u64 = x.mapv(|x| {
1369 let val_f64 = NumCast::from(x).unwrap_or(0.0);
1370 val_f64 as u64
1371 });
1372
1373 let n_samples = x_u64.shape()[0];
1374 let n_features = x_u64.shape()[1];
1375
1376 if n_samples == 0 || n_features == 0 {
1377 return Err(TransformError::InvalidInput("Empty input data".to_string()));
1378 }
1379
1380 let mut frequency_maps = Vec::with_capacity(n_features);
1381
1382 for j in 0..n_features {
1383 let mut category_counts: HashMap<u64, usize> = HashMap::new();
1385 for i in 0..n_samples {
1386 let category = x_u64[[i, j]];
1387 *category_counts.entry(category).or_insert(0) += 1;
1388 }
1389
1390 let mut frequency_map = HashMap::new();
1392 for (category, count) in category_counts {
1393 let frequency = if self.normalize {
1394 count as f64 / n_samples as f64
1395 } else {
1396 count as f64
1397 };
1398 frequency_map.insert(category, frequency);
1399 }
1400
1401 frequency_maps.push(frequency_map);
1402 }
1403
1404 self.frequency_maps_ = Some(frequency_maps);
1405 self.is_fitted = true;
1406 Ok(())
1407 }
1408
1409 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1417 where
1418 S: Data,
1419 S::Elem: Float + NumCast,
1420 {
1421 if !self.is_fitted {
1422 return Err(TransformError::TransformationError(
1423 "FrequencyEncoder has not been fitted".to_string(),
1424 ));
1425 }
1426
1427 let frequency_maps = self.frequency_maps_.as_ref().unwrap();
1428
1429 let x_u64 = x.mapv(|x| {
1430 let val_f64 = NumCast::from(x).unwrap_or(0.0);
1431 val_f64 as u64
1432 });
1433
1434 let n_samples = x_u64.shape()[0];
1435 let n_features = x_u64.shape()[1];
1436
1437 if n_features != frequency_maps.len() {
1438 return Err(TransformError::InvalidInput(format!(
1439 "x has {} features, but FrequencyEncoder was fitted with {} features",
1440 n_features,
1441 frequency_maps.len()
1442 )));
1443 }
1444
1445 let mut transformed = Array2::zeros((n_samples, n_features));
1446
1447 for i in 0..n_samples {
1448 for j in 0..n_features {
1449 let category = x_u64[[i, j]];
1450
1451 if let Some(&frequency) = frequency_maps[j].get(&category) {
1452 transformed[[i, j]] = frequency;
1453 } else {
1454 match self.handleunknown.as_str() {
1456 "error" => {
1457 return Err(TransformError::InvalidInput(format!(
1458 "Unknown category {category} in feature {j}"
1459 )));
1460 }
1461 "ignore" => {
1462 transformed[[i, j]] = 0.0;
1463 }
1464 "use_encoded_value" => {
1465 transformed[[i, j]] = self.unknownvalue;
1466 }
1467 _ => unreachable!(),
1468 }
1469 }
1470 }
1471 }
1472
1473 Ok(transformed)
1474 }
1475
    /// Fits the encoder on `x`, then encodes that same data.
    ///
    /// # Errors
    /// Propagates any error returned by `fit` or `transform`.
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }
1491
    /// Returns `true` once the encoder has been fitted.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
1496
    /// Returns the per-feature maps from category code to frequency, or
    /// `None` when the encoder has not been fitted.
    pub fn frequency_maps(&self) -> Option<&Vec<HashMap<u64, f64>>> {
        self.frequency_maps_.as_ref()
    }
1501}
1502
/// Weight-of-evidence (WOE) encoder for categorical features with a binary
/// target.
#[derive(Debug, Clone)]
pub struct WOEEncoder {
    /// Per-feature map from category code to its WOE value; `None` until fitted.
    woe_maps_: Option<Vec<HashMap<u64, f64>>>,
    /// One value per feature; `None` until fitted.
    /// NOTE(review): presumably the information value of each feature —
    /// the computation is outside the visible part of the file; confirm.
    information_values_: Option<Vec<f64>>,
    /// Non-negative constant added to the per-category event and non-event
    /// counts when computing WOE.
    regularization: f64,
    /// Unknown-category policy: `"error"`, `"global_woe"`, or `"use_encoded_value"`.
    handleunknown: String,
    /// NOTE(review): presumably the value emitted for unknown categories
    /// under `"use_encoded_value"` — confirm against `transform`.
    unknownvalue: f64,
    /// Natural log of (total events / total non-events) over the fitted targets.
    global_woe_: f64,
    /// NOTE(review): presumably set once `fit` completes — the assignment is
    /// outside the visible part of the file; confirm.
    is_fitted: bool,
}
1527
1528impl WOEEncoder {
1529 pub fn new(regularization: f64, handleunknown: &str, unknownvalue: f64) -> Result<Self> {
1539 if regularization < 0.0 {
1540 return Err(TransformError::InvalidInput(
1541 "regularization must be non-negative".to_string(),
1542 ));
1543 }
1544
1545 if !["error", "global_woe", "use_encoded_value"].contains(&handleunknown) {
1546 return Err(TransformError::InvalidInput(
1547 "handleunknown must be 'error', 'global_woe', or 'use_encoded_value'".to_string(),
1548 ));
1549 }
1550
1551 Ok(WOEEncoder {
1552 woe_maps_: None,
1553 information_values_: None,
1554 regularization,
1555 handleunknown: handleunknown.to_string(),
1556 unknownvalue,
1557 global_woe_: 0.0,
1558 is_fitted: false,
1559 })
1560 }
1561
    /// Creates an encoder with regularization `0.5` and the `"global_woe"`
    /// policy for unknown categories.
    ///
    /// The `unwrap` cannot fail: these arguments pass the validation in `new`.
    pub fn with_defaults() -> Self {
        Self::new(0.5, "global_woe", 0.0).unwrap()
    }
1566
    /// Creates an encoder with the given regularization and the
    /// `"global_woe"` policy for unknown categories.
    ///
    /// # Errors
    /// Returns `TransformError::InvalidInput` when `regularization` is negative.
    pub fn with_regularization(regularization: f64) -> Result<Self> {
        Self::new(regularization, "global_woe", 0.0)
    }
1571
1572 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<()>
1581 where
1582 S: Data,
1583 S::Elem: Float + NumCast,
1584 {
1585 let x_u64 = x.mapv(|x| {
1586 let val_f64 = NumCast::from(x).unwrap_or(0.0);
1587 val_f64 as u64
1588 });
1589
1590 let n_samples = x_u64.shape()[0];
1591 let n_features = x_u64.shape()[1];
1592
1593 if n_samples == 0 || n_features == 0 {
1594 return Err(TransformError::InvalidInput("Empty input data".to_string()));
1595 }
1596
1597 if y.len() != n_samples {
1598 return Err(TransformError::InvalidInput(
1599 "Number of target values must match number of samples".to_string(),
1600 ));
1601 }
1602
1603 for &target in y {
1605 if target != 0.0 && target != 1.0 {
1606 return Err(TransformError::InvalidInput(
1607 "Target values must be binary (0 or 1)".to_string(),
1608 ));
1609 }
1610 }
1611
1612 let total_events: f64 = y.iter().sum();
1614 let total_non_events = n_samples as f64 - total_events;
1615
1616 if total_events == 0.0 || total_non_events == 0.0 {
1617 return Err(TransformError::InvalidInput(
1618 "Target must contain both 0 and 1 values".to_string(),
1619 ));
1620 }
1621
1622 self.global_woe_ = (total_events / total_non_events).ln();
1624
1625 let mut woe_maps = Vec::with_capacity(n_features);
1626 let mut information_values = Vec::with_capacity(n_features);
1627
1628 for j in 0..n_features {
1629 let mut category_stats: HashMap<u64, (f64, f64)> = HashMap::new(); for i in 0..n_samples {
1633 let category = x_u64[[i, j]];
1634 let target = y[i];
1635
1636 let (events, non_events) = category_stats.entry(category).or_insert((0.0, 0.0));
1637 if target == 1.0 {
1638 *events += 1.0;
1639 } else {
1640 *non_events += 1.0;
1641 }
1642 }
1643
1644 let mut woe_map = HashMap::new();
1646 let mut feature_iv = 0.0;
1647
1648 for (category, (events, non_events)) in category_stats.iter() {
1649 let reg_events = events + self.regularization;
1651 let reg_non_events = non_events + self.regularization;
1652 let reg_total_events =
1653 total_events + self.regularization * category_stats.len() as f64;
1654 let reg_total_non_events =
1655 total_non_events + self.regularization * category_stats.len() as f64;
1656
1657 let event_rate = reg_events / reg_total_events;
1659 let non_event_rate = reg_non_events / reg_total_non_events;
1660
1661 let woe = (event_rate / non_event_rate).ln();
1663 woe_map.insert(*category, woe);
1664
1665 let iv_contribution = (event_rate - non_event_rate) * woe;
1667 feature_iv += iv_contribution;
1668 }
1669
1670 woe_maps.push(woe_map);
1671 information_values.push(feature_iv);
1672 }
1673
1674 self.woe_maps_ = Some(woe_maps);
1675 self.information_values_ = Some(information_values);
1676 self.is_fitted = true;
1677 Ok(())
1678 }
1679
1680 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
1688 where
1689 S: Data,
1690 S::Elem: Float + NumCast,
1691 {
1692 if !self.is_fitted {
1693 return Err(TransformError::TransformationError(
1694 "WOEEncoder has not been fitted".to_string(),
1695 ));
1696 }
1697
1698 let woe_maps = self.woe_maps_.as_ref().unwrap();
1699
1700 let x_u64 = x.mapv(|x| {
1701 let val_f64 = NumCast::from(x).unwrap_or(0.0);
1702 val_f64 as u64
1703 });
1704
1705 let n_samples = x_u64.shape()[0];
1706 let n_features = x_u64.shape()[1];
1707
1708 if n_features != woe_maps.len() {
1709 return Err(TransformError::InvalidInput(format!(
1710 "x has {} features, but WOEEncoder was fitted with {} features",
1711 n_features,
1712 woe_maps.len()
1713 )));
1714 }
1715
1716 let mut transformed = Array2::zeros((n_samples, n_features));
1717
1718 for i in 0..n_samples {
1719 for j in 0..n_features {
1720 let category = x_u64[[i, j]];
1721
1722 if let Some(&woe_value) = woe_maps[j].get(&category) {
1723 transformed[[i, j]] = woe_value;
1724 } else {
1725 match self.handleunknown.as_str() {
1727 "error" => {
1728 return Err(TransformError::InvalidInput(format!(
1729 "Unknown category {category} in feature {j}"
1730 )));
1731 }
1732 "global_woe" => {
1733 transformed[[i, j]] = self.global_woe_;
1734 }
1735 "use_encoded_value" => {
1736 transformed[[i, j]] = self.unknownvalue;
1737 }
1738 _ => unreachable!(),
1739 }
1740 }
1741 }
1742 }
1743
1744 Ok(transformed)
1745 }
1746
1747 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>, y: &[f64]) -> Result<Array2<f64>>
1756 where
1757 S: Data,
1758 S::Elem: Float + NumCast,
1759 {
1760 self.fit(x, y)?;
1761 self.transform(x)
1762 }
1763
    /// Returns `true` once `fit` has completed successfully.
    pub fn is_fitted(&self) -> bool {
        self.is_fitted
    }
1768
    /// Returns the learned category-to-WOE maps, one per feature, or `None` before fitting.
    pub fn woe_maps(&self) -> Option<&Vec<HashMap<u64, f64>>> {
        self.woe_maps_.as_ref()
    }
1773
    /// Returns the Information Value computed for each feature, or `None` before fitting.
    pub fn information_values(&self) -> Option<&Vec<f64>> {
        self.information_values_.as_ref()
    }
1785
    /// Returns the global WOE, `ln(total_events / total_non_events)` of the training
    /// target; it is `0.0` on a freshly constructed encoder.
    pub fn global_woe(&self) -> f64 {
        self.global_woe_
    }
1790
1791 pub fn feature_importance_ranking(&self) -> Option<Vec<(usize, f64)>> {
1796 self.information_values_.as_ref().map(|ivs| {
1797 let mut ranking: Vec<(usize, f64)> =
1798 ivs.iter().enumerate().map(|(idx, &iv)| (idx, iv)).collect();
1799 ranking.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1800 ranking
1801 })
1802 }
1803}
1804
1805#[cfg(test)]
1806mod tests {
1807 use super::*;
1808 use approx::assert_abs_diff_eq;
1809 use scirs2_core::ndarray::Array;
1810
1811 #[test]
1812 fn test_one_hot_encoder_basic() {
1813 let data = Array::from_shape_vec(
1815 (4, 2),
1816 vec![
1817 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 0.0, 1.0,
1819 ],
1820 )
1821 .unwrap();
1822
1823 let mut encoder = OneHotEncoder::with_defaults();
1824 let encoded = encoder.fit_transform(&data).unwrap();
1825
1826 assert_eq!(encoded.shape(), (4, 6));
1828
1829 let encoded_dense = encoded.to_dense();
1831
1832 assert_abs_diff_eq!(encoded_dense[[0, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 2]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 3]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 4]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 5]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 1]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 2]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 3]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 4]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 5]], 0.0, epsilon = 1e-10); }
1848
1849 #[test]
1850 fn test_one_hot_encoder_drop_first() {
1851 let data = Array::from_shape_vec((3, 2), vec![0.0, 1.0, 1.0, 2.0, 2.0, 1.0]).unwrap();
1853
1854 let mut encoder = OneHotEncoder::new(Some("first".to_string()), "error", false).unwrap();
1855 let encoded = encoder.fit_transform(&data).unwrap();
1856
1857 assert_eq!(encoded.shape(), (3, 3));
1859
1860 let encoded_dense = encoded.to_dense();
1863
1864 assert_abs_diff_eq!(encoded_dense[[0, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[0, 2]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded_dense[[1, 2]], 1.0, epsilon = 1e-10); }
1874
1875 #[test]
1876 fn test_ordinal_encoder() {
1877 let data = Array::from_shape_vec(
1879 (4, 2),
1880 vec![
1881 2.0, 10.0, 1.0, 20.0, 3.0, 10.0, 2.0, 30.0,
1883 ],
1884 )
1885 .unwrap();
1886
1887 let mut encoder = OrdinalEncoder::with_defaults();
1888 let encoded = encoder.fit_transform(&data).unwrap();
1889
1890 assert_eq!(encoded.shape(), &[4, 2]);
1892
1893 assert_abs_diff_eq!(encoded[[0, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 1]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[2, 1]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[3, 1]], 2.0, epsilon = 1e-10); }
1906
1907 #[test]
1908 fn test_unknown_category_handling() {
1909 let train_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
1910
1911 let test_data = Array::from_shape_vec(
1912 (1, 1),
1913 vec![3.0], )
1915 .unwrap();
1916
1917 let mut encoder = OneHotEncoder::with_defaults(); encoder.fit(&train_data).unwrap();
1920 assert!(encoder.transform(&test_data).is_err());
1921
1922 let mut encoder = OneHotEncoder::new(None, "ignore", false).unwrap();
1924 encoder.fit(&train_data).unwrap();
1925 let encoded = encoder.transform(&test_data).unwrap();
1926
1927 assert_eq!(encoded.shape(), (1, 2));
1929 let encoded_dense = encoded.to_dense();
1930 assert_abs_diff_eq!(encoded_dense[[0, 0]], 0.0, epsilon = 1e-10);
1931 assert_abs_diff_eq!(encoded_dense[[0, 1]], 0.0, epsilon = 1e-10);
1932 }
1933
1934 #[test]
1935 fn test_ordinal_encoder_unknown_value() {
1936 let train_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
1937
1938 let test_data = Array::from_shape_vec(
1939 (1, 1),
1940 vec![3.0], )
1942 .unwrap();
1943
1944 let mut encoder = OrdinalEncoder::new("use_encoded_value", Some(-1.0)).unwrap();
1945 encoder.fit(&train_data).unwrap();
1946 let encoded = encoder.transform(&test_data).unwrap();
1947
1948 assert_eq!(encoded.shape(), &[1, 1]);
1950 assert_abs_diff_eq!(encoded[[0, 0]], -1.0, epsilon = 1e-10);
1951 }
1952
1953 #[test]
1954 fn test_get_feature_names() {
1955 let data = Array::from_shape_vec((2, 2), vec![1.0, 10.0, 2.0, 20.0]).unwrap();
1956
1957 let mut encoder = OneHotEncoder::with_defaults();
1958 encoder.fit(&data).unwrap();
1959
1960 let feature_names = encoder.get_feature_names(None).unwrap();
1961 assert_eq!(feature_names.len(), 4); let custom_names = vec!["feat_a".to_string(), "feat_b".to_string()];
1964 let feature_names = encoder.get_feature_names(Some(&custom_names)).unwrap();
1965 assert!(feature_names[0].starts_with("feat_a_cat_"));
1966 assert!(feature_names[2].starts_with("feat_b_cat_"));
1967 }
1968
1969 #[test]
1970 fn test_target_encoder_mean_strategy() {
1971 let x = Array::from_shape_vec((6, 1), vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0]).unwrap();
1973 let y = vec![1.0, 2.0, 3.0, 1.5, 2.5, 3.5];
1974
1975 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
1976 let encoded = encoder.fit_transform(&x, &y).unwrap();
1977
1978 assert_eq!(encoded.shape(), &[6, 1]);
1980
1981 assert_abs_diff_eq!(encoded[[0, 0]], 1.25, epsilon = 1e-10);
1987 assert_abs_diff_eq!(encoded[[1, 0]], 2.25, epsilon = 1e-10);
1988 assert_abs_diff_eq!(encoded[[2, 0]], 3.25, epsilon = 1e-10);
1989 assert_abs_diff_eq!(encoded[[3, 0]], 1.25, epsilon = 1e-10);
1990 assert_abs_diff_eq!(encoded[[4, 0]], 2.25, epsilon = 1e-10);
1991 assert_abs_diff_eq!(encoded[[5, 0]], 3.25, epsilon = 1e-10);
1992
1993 assert_abs_diff_eq!(encoder.global_mean(), 2.25, epsilon = 1e-10);
1995 }
1996
1997 #[test]
1998 fn test_target_encoder_median_strategy() {
1999 let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2000 let y = vec![1.0, 2.0, 3.0, 4.0];
2001
2002 let mut encoder = TargetEncoder::new("median", 0.0, 0.0).unwrap();
2003 let encoded = encoder.fit_transform(&x, &y).unwrap();
2004
2005 assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
2009 assert_abs_diff_eq!(encoded[[1, 0]], 3.0, epsilon = 1e-10);
2010 assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10);
2011 assert_abs_diff_eq!(encoded[[3, 0]], 3.0, epsilon = 1e-10);
2012 }
2013
2014 #[test]
2015 fn test_target_encoder_count_strategy() {
2016 let x = Array::from_shape_vec((5, 1), vec![0.0, 1.0, 0.0, 2.0, 1.0]).unwrap();
2017 let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2018
2019 let mut encoder = TargetEncoder::new("count", 0.0, 0.0).unwrap();
2020 let encoded = encoder.fit_transform(&x, &y).unwrap();
2021
2022 assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
2027 assert_abs_diff_eq!(encoded[[1, 0]], 2.0, epsilon = 1e-10);
2028 assert_abs_diff_eq!(encoded[[2, 0]], 2.0, epsilon = 1e-10);
2029 assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10);
2030 assert_abs_diff_eq!(encoded[[4, 0]], 2.0, epsilon = 1e-10);
2031 }
2032
2033 #[test]
2034 fn test_target_encoder_sum_strategy() {
2035 let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2036 let y = vec![1.0, 2.0, 3.0, 4.0];
2037
2038 let mut encoder = TargetEncoder::new("sum", 0.0, 0.0).unwrap();
2039 let encoded = encoder.fit_transform(&x, &y).unwrap();
2040
2041 assert_abs_diff_eq!(encoded[[0, 0]], 4.0, epsilon = 1e-10);
2045 assert_abs_diff_eq!(encoded[[1, 0]], 6.0, epsilon = 1e-10);
2046 assert_abs_diff_eq!(encoded[[2, 0]], 4.0, epsilon = 1e-10);
2047 assert_abs_diff_eq!(encoded[[3, 0]], 6.0, epsilon = 1e-10);
2048 }
2049
2050 #[test]
2051 fn test_target_encoder_smoothing() {
2052 let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2053 let y = vec![1.0, 2.0, 3.0];
2054
2055 let mut encoder = TargetEncoder::new("mean", 1.0, 0.0).unwrap();
2056 let encoded = encoder.fit_transform(&x, &y).unwrap();
2057
2058 assert_abs_diff_eq!(encoded[[0, 0]], 1.5, epsilon = 1e-10);
2064 assert_abs_diff_eq!(encoded[[1, 0]], 2.0, epsilon = 1e-10);
2065 assert_abs_diff_eq!(encoded[[2, 0]], 2.5, epsilon = 1e-10);
2066 }
2067
2068 #[test]
2069 fn test_target_encoder_unknown_categories() {
2070 let train_x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2071 let train_y = vec![1.0, 2.0, 3.0];
2072
2073 let test_x = Array::from_shape_vec((2, 1), vec![3.0, 4.0]).unwrap(); let mut encoder = TargetEncoder::new("mean", 0.0, -1.0).unwrap();
2076 encoder.fit(&train_x, &train_y).unwrap();
2077 let encoded = encoder.transform(&test_x).unwrap();
2078
2079 assert_abs_diff_eq!(encoded[[0, 0]], -1.0, epsilon = 1e-10);
2081 assert_abs_diff_eq!(encoded[[1, 0]], -1.0, epsilon = 1e-10);
2082 }
2083
2084 #[test]
2085 fn test_target_encoder_unknown_categories_global_mean() {
2086 let train_x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2087 let train_y = vec![1.0, 2.0, 3.0];
2088
2089 let test_x = Array::from_shape_vec((1, 1), vec![3.0]).unwrap(); let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap(); encoder.fit(&train_x, &train_y).unwrap();
2093 let encoded = encoder.transform(&test_x).unwrap();
2094
2095 assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10); }
2098
2099 #[test]
2100 fn test_target_encoder_multi_feature() {
2101 let x =
2102 Array::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]).unwrap();
2103 let y = vec![1.0, 2.0, 3.0, 4.0];
2104
2105 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2106 let encoded = encoder.fit_transform(&x, &y).unwrap();
2107
2108 assert_eq!(encoded.shape(), &[4, 2]);
2109
2110 assert_abs_diff_eq!(encoded[[0, 0]], 2.0, epsilon = 1e-10);
2116 assert_abs_diff_eq!(encoded[[0, 1]], 2.5, epsilon = 1e-10);
2117 assert_abs_diff_eq!(encoded[[1, 0]], 3.0, epsilon = 1e-10);
2118 assert_abs_diff_eq!(encoded[[1, 1]], 2.5, epsilon = 1e-10);
2119 }
2120
2121 #[test]
2122 fn test_target_encoder_cross_validation() {
2123 let x = Array::from_shape_vec(
2124 (10, 1),
2125 vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
2126 )
2127 .unwrap();
2128 let y = vec![1.0, 2.0, 1.5, 2.5, 1.2, 2.2, 1.3, 2.3, 1.1, 2.1];
2129
2130 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2131 let encoded = encoder.fit_transform_cv(&x, &y, 5).unwrap();
2132
2133 assert_eq!(encoded.shape(), &[10, 1]);
2135
2136 assert!(encoded[[0, 0]] < encoded[[1, 0]]); assert!(encoded[[2, 0]] < encoded[[3, 0]]);
2141 }
2142
2143 #[test]
2144 fn test_target_encoder_convenience_methods() {
2145 let _x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2146 let _y = [1.0, 2.0, 3.0, 4.0];
2147
2148 let encoder1 = TargetEncoder::with_mean(1.0);
2149 assert_eq!(encoder1.strategy, "mean");
2150 assert_abs_diff_eq!(encoder1.smoothing, 1.0, epsilon = 1e-10);
2151
2152 let encoder2 = TargetEncoder::with_median(0.5);
2153 assert_eq!(encoder2.strategy, "median");
2154 assert_abs_diff_eq!(encoder2.smoothing, 0.5, epsilon = 1e-10);
2155 }
2156
2157 #[test]
2158 fn test_target_encoder_validation_errors() {
2159 assert!(TargetEncoder::new("invalid", 0.0, 0.0).is_err());
2161
2162 assert!(TargetEncoder::new("mean", -1.0, 0.0).is_err());
2164
2165 let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2167 let y = vec![1.0, 2.0]; let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2170 assert!(encoder.fit(&x, &y).is_err());
2171
2172 let encoder2 = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2174 assert!(encoder2.transform(&x).is_err());
2175
2176 let train_x = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2178 let test_x = Array::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
2179 let train_y = vec![1.0, 2.0];
2180
2181 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2182 encoder.fit(&train_x, &train_y).unwrap();
2183 assert!(encoder.transform(&test_x).is_err());
2184
2185 let x = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 0.0, 1.0]).unwrap();
2187 let y = vec![1.0, 2.0, 3.0, 4.0];
2188 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2189 assert!(encoder.fit_transform_cv(&x, &y, 1).is_err()); }
2191
2192 #[test]
2193 fn test_target_encoder_accessors() {
2194 let x = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2195 let y = vec![1.0, 2.0, 3.0];
2196
2197 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2198
2199 assert!(!encoder.is_fitted());
2200 assert!(encoder.encodings().is_none());
2201
2202 encoder.fit(&x, &y).unwrap();
2203
2204 assert!(encoder.is_fitted());
2205 assert!(encoder.encodings().is_some());
2206 assert_abs_diff_eq!(encoder.global_mean(), 2.0, epsilon = 1e-10);
2207
2208 let encodings = encoder.encodings().unwrap();
2209 assert_eq!(encodings.len(), 1); assert_eq!(encodings[0].len(), 3); }
2212
2213 #[test]
2214 fn test_target_encoder_empty_data() {
2215 let empty_x = Array2::<f64>::zeros((0, 1));
2216 let empty_y = vec![];
2217
2218 let mut encoder = TargetEncoder::new("mean", 0.0, 0.0).unwrap();
2219 assert!(encoder.fit(&empty_x, &empty_y).is_err());
2220 }
2221
2222 #[test]
2225 fn test_binary_encoder_basic() {
2226 let data = Array::from_shape_vec((4, 1), vec![0.0, 1.0, 2.0, 3.0]).unwrap();
2228
2229 let mut encoder = BinaryEncoder::with_defaults();
2230 let encoded = encoder.fit_transform(&data).unwrap();
2231
2232 assert_eq!(encoded.shape(), &[4, 2]);
2234
2235 assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10);
2238 assert_abs_diff_eq!(encoded[[1, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[1, 1]], 1.0, epsilon = 1e-10);
2240 assert_abs_diff_eq!(encoded[[2, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[2, 1]], 0.0, epsilon = 1e-10);
2242 assert_abs_diff_eq!(encoded[[3, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[3, 1]], 1.0, epsilon = 1e-10);
2244 }
2245
2246 #[test]
2247 fn test_binary_encoder_power_of_two() {
2248 let data =
2250 Array::from_shape_vec((8, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).unwrap();
2251
2252 let mut encoder = BinaryEncoder::with_defaults();
2253 let encoded = encoder.fit_transform(&data).unwrap();
2254
2255 assert_eq!(encoded.shape(), &[8, 3]);
2257
2258 assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], 0.0, epsilon = 1e-10);
2261 assert_abs_diff_eq!(encoded[[0, 2]], 0.0, epsilon = 1e-10);
2262
2263 assert_abs_diff_eq!(encoded[[7, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(encoded[[7, 1]], 1.0, epsilon = 1e-10);
2265 assert_abs_diff_eq!(encoded[[7, 2]], 1.0, epsilon = 1e-10);
2266 }
2267
2268 #[test]
2269 fn test_binary_encoder_non_power_of_two() {
2270 let data = Array::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).unwrap();
2272
2273 let mut encoder = BinaryEncoder::with_defaults();
2274 let encoded = encoder.fit_transform(&data).unwrap();
2275
2276 assert_eq!(encoded.shape(), &[5, 3]);
2278 assert_eq!(encoder.n_output_features().unwrap(), 3);
2279 }
2280
2281 #[test]
2282 fn test_binary_encoder_single_category() {
2283 let data = Array::from_shape_vec((3, 1), vec![5.0, 5.0, 5.0]).unwrap();
2285
2286 let mut encoder = BinaryEncoder::with_defaults();
2287 let encoded = encoder.fit_transform(&data).unwrap();
2288
2289 assert_eq!(encoded.shape(), &[3, 1]);
2291 assert_eq!(encoder.n_output_features().unwrap(), 1);
2292
2293 for i in 0..3 {
2295 assert_abs_diff_eq!(encoded[[i, 0]], 0.0, epsilon = 1e-10);
2296 }
2297 }
2298
2299 #[test]
2300 fn test_binary_encoder_multi_feature() {
2301 let data = Array::from_shape_vec(
2303 (4, 2),
2304 vec![
2305 0.0, 10.0, 1.0, 11.0, 2.0, 10.0, 0.0, 11.0,
2307 ],
2308 )
2309 .unwrap();
2310
2311 let mut encoder = BinaryEncoder::with_defaults();
2312 let encoded = encoder.fit_transform(&data).unwrap();
2313
2314 assert_eq!(encoded.shape(), &[4, 3]);
2317 assert_eq!(encoder.n_output_features().unwrap(), 3);
2318
2319 let n_binary_features = encoder.n_binary_features().unwrap();
2320 assert_eq!(n_binary_features, &[2, 1]);
2321 }
2322
2323 #[test]
2324 fn test_binary_encoder_separate_fit_transform() {
2325 let train_data = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2326 let test_data = Array::from_shape_vec((2, 1), vec![1.0, 0.0]).unwrap();
2327
2328 let mut encoder = BinaryEncoder::with_defaults();
2329
2330 encoder.fit(&train_data).unwrap();
2332 assert!(encoder.is_fitted());
2333
2334 let encoded = encoder.transform(&test_data).unwrap();
2336 assert_eq!(encoded.shape(), &[2, 2]); let train_encoded = encoder.transform(&train_data).unwrap();
2340 assert_abs_diff_eq!(encoded[[0, 0]], train_encoded[[1, 0]], epsilon = 1e-10); assert_abs_diff_eq!(encoded[[0, 1]], train_encoded[[1, 1]], epsilon = 1e-10);
2342 }
2343
2344 #[test]
2345 fn test_binary_encoder_unknown_categories_error() {
2346 let train_data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2347 let test_data = Array::from_shape_vec((1, 1), vec![2.0]).unwrap(); let mut encoder = BinaryEncoder::new("error").unwrap();
2350 encoder.fit(&train_data).unwrap();
2351
2352 assert!(encoder.transform(&test_data).is_err());
2354 }
2355
2356 #[test]
2357 fn test_binary_encoder_unknown_categories_ignore() {
2358 let train_data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2359 let test_data = Array::from_shape_vec((1, 1), vec![2.0]).unwrap(); let mut encoder = BinaryEncoder::new("ignore").unwrap();
2362 encoder.fit(&train_data).unwrap();
2363 let encoded = encoder.transform(&test_data).unwrap();
2364
2365 assert_eq!(encoded.shape(), &[1, 1]); assert_abs_diff_eq!(encoded[[0, 0]], 0.0, epsilon = 1e-10);
2368 }
2369
2370 #[test]
2371 fn test_binary_encoder_categories_accessor() {
2372 let data = Array::from_shape_vec((3, 1), vec![10.0, 20.0, 30.0]).unwrap();
2373
2374 let mut encoder = BinaryEncoder::with_defaults();
2375
2376 assert!(!encoder.is_fitted());
2378 assert!(encoder.categories().is_none());
2379 assert!(encoder.n_binary_features().is_none());
2380 assert!(encoder.n_output_features().is_none());
2381
2382 encoder.fit(&data).unwrap();
2383
2384 assert!(encoder.is_fitted());
2386 assert!(encoder.categories().is_some());
2387 assert!(encoder.n_binary_features().is_some());
2388 assert!(encoder.n_output_features().is_some());
2389
2390 let categories = encoder.categories().unwrap();
2391 assert_eq!(categories.len(), 1); assert_eq!(categories[0].len(), 3); let category_map = &categories[0];
2396 assert!(category_map.contains_key(&10));
2397 assert!(category_map.contains_key(&20));
2398 assert!(category_map.contains_key(&30));
2399 }
2400
2401 #[test]
2402 fn test_binary_encoder_int_to_binary() {
2403 assert_eq!(BinaryEncoder::int_to_binary(0, 3), vec![0, 0, 0]);
2405 assert_eq!(BinaryEncoder::int_to_binary(1, 3), vec![0, 0, 1]);
2406 assert_eq!(BinaryEncoder::int_to_binary(2, 3), vec![0, 1, 0]);
2407 assert_eq!(BinaryEncoder::int_to_binary(3, 3), vec![0, 1, 1]);
2408 assert_eq!(BinaryEncoder::int_to_binary(7, 3), vec![1, 1, 1]);
2409
2410 assert_eq!(BinaryEncoder::int_to_binary(5, 4), vec![0, 1, 0, 1]);
2412 assert_eq!(BinaryEncoder::int_to_binary(1, 1), vec![1]);
2413 }
2414
2415 #[test]
2416 fn test_binary_encoder_validation_errors() {
2417 assert!(BinaryEncoder::new("invalid").is_err());
2419
2420 let empty_data = Array2::<f64>::zeros((0, 1));
2422 let mut encoder = BinaryEncoder::with_defaults();
2423 assert!(encoder.fit(&empty_data).is_err());
2424
2425 let data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2427 let encoder = BinaryEncoder::with_defaults();
2428 assert!(encoder.transform(&data).is_err());
2429
2430 let train_data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2432 let test_data = Array::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
2433
2434 let mut encoder = BinaryEncoder::with_defaults();
2435 encoder.fit(&train_data).unwrap();
2436 assert!(encoder.transform(&test_data).is_err());
2437 }
2438
2439 #[test]
2440 fn test_binary_encoder_consistency() {
2441 let data = Array::from_shape_vec((4, 1), vec![3.0, 1.0, 4.0, 1.0]).unwrap();
2443
2444 let mut encoder = BinaryEncoder::with_defaults();
2445 let encoded1 = encoder.fit_transform(&data).unwrap();
2446 let encoded2 = encoder.transform(&data).unwrap();
2447
2448 for i in 0..encoded1.shape()[0] {
2450 for j in 0..encoded1.shape()[1] {
2451 assert_abs_diff_eq!(encoded1[[i, j]], encoded2[[i, j]], epsilon = 1e-10);
2452 }
2453 }
2454
2455 assert_abs_diff_eq!(encoded1[[1, 0]], encoded1[[3, 0]], epsilon = 1e-10); assert_abs_diff_eq!(encoded1[[1, 1]], encoded1[[3, 1]], epsilon = 1e-10);
2458 }
2459
2460 #[test]
2461 fn test_binary_encoder_memory_efficiency() {
2462 let data = Array::from_shape_vec(
2465 (10, 1),
2466 vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
2467 )
2468 .unwrap();
2469
2470 let mut binary_encoder = BinaryEncoder::with_defaults();
2471 let binary_encoded = binary_encoder.fit_transform(&data).unwrap();
2472
2473 let mut onehot_encoder = OneHotEncoder::with_defaults();
2474 let onehot_encoded = onehot_encoder.fit_transform(&data).unwrap();
2475
2476 assert_eq!(binary_encoded.shape()[1], 4); assert_eq!(onehot_encoded.shape().1, 10); assert!(binary_encoded.shape()[1] < onehot_encoded.shape().1);
2480 }
2481
2482 #[test]
2483 fn test_sparse_matrix_basic() {
2484 let mut sparse = SparseMatrix::new((3, 4));
2485 sparse.push(0, 1, 1.0);
2486 sparse.push(1, 2, 1.0);
2487 sparse.push(2, 0, 1.0);
2488
2489 assert_eq!(sparse.shape, (3, 4));
2490 assert_eq!(sparse.nnz(), 3);
2491
2492 let dense = sparse.to_dense();
2493 assert_eq!(dense.shape(), &[3, 4]);
2494 assert_eq!(dense[[0, 1]], 1.0);
2495 assert_eq!(dense[[1, 2]], 1.0);
2496 assert_eq!(dense[[2, 0]], 1.0);
2497 assert_eq!(dense[[0, 0]], 0.0); }
2499
2500 #[test]
2501 fn test_onehot_sparse_output() {
2502 let data =
2503 Array::from_shape_vec((4, 2), vec![0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 0.0, 1.0]).unwrap();
2504
2505 let mut encoder_sparse = OneHotEncoder::new(None, "error", true).unwrap();
2507 let result_sparse = encoder_sparse.fit_transform(&data).unwrap();
2508
2509 match &result_sparse {
2510 EncodedOutput::Sparse(sparse) => {
2511 assert_eq!(sparse.shape, (4, 6)); assert_eq!(sparse.nnz(), 8); let dense = sparse.to_dense();
2516
2517 assert_eq!(dense[[0, 0]], 1.0); assert_eq!(dense[[0, 4]], 1.0); assert_eq!(dense[[0, 1]], 0.0); }
2522 EncodedOutput::Dense(_) => assert!(false, "Expected sparse output, got dense"),
2523 }
2524
2525 let mut encoder_dense = OneHotEncoder::new(None, "error", false).unwrap();
2527 let result_dense = encoder_dense.fit_transform(&data).unwrap();
2528
2529 match result_dense {
2530 EncodedOutput::Dense(dense) => {
2531 assert_eq!(dense.shape(), &[4, 6]);
2532 let sparse_as_dense = result_sparse.to_dense();
2534 for i in 0..4 {
2535 for j in 0..6 {
2536 assert_abs_diff_eq!(
2537 dense[[i, j]],
2538 sparse_as_dense[[i, j]],
2539 epsilon = 1e-10
2540 );
2541 }
2542 }
2543 }
2544 EncodedOutput::Sparse(_) => assert!(false, "Expected dense output, got sparse"),
2545 }
2546 }
2547
2548 #[test]
2549 fn test_onehot_sparse_with_drop() {
2550 let data = Array::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).unwrap();
2551
2552 let mut encoder = OneHotEncoder::new(Some("first".to_string()), "error", true).unwrap();
2553 let result = encoder.fit_transform(&data).unwrap();
2554
2555 match result {
2556 EncodedOutput::Sparse(sparse) => {
2557 assert_eq!(sparse.shape, (3, 2)); assert_eq!(sparse.nnz(), 2); let dense = sparse.to_dense();
2561 assert_eq!(dense[[0, 0]], 0.0); assert_eq!(dense[[0, 1]], 0.0);
2563 assert_eq!(dense[[1, 0]], 1.0); assert_eq!(dense[[2, 1]], 1.0); }
2566 EncodedOutput::Dense(_) => assert!(false, "Expected sparse output, got dense"),
2567 }
2568 }
2569
2570 #[test]
2571 fn test_onehot_sparse_backward_compatibility() {
2572 let data = Array::from_shape_vec((2, 1), vec![0.0, 1.0]).unwrap();
2573
2574 let mut encoder = OneHotEncoder::new(None, "error", true).unwrap();
2575 encoder.fit(&data).unwrap();
2576
2577 let dense_result = encoder.transform_dense(&data).unwrap();
2579 assert_eq!(dense_result.shape(), &[2, 2]);
2580 assert_eq!(dense_result[[0, 0]], 1.0);
2581 assert_eq!(dense_result[[1, 1]], 1.0);
2582
2583 let mut encoder2 = OneHotEncoder::new(None, "error", true).unwrap();
2584 let dense_result2 = encoder2.fit_transform_dense(&data).unwrap();
2585 assert_eq!(dense_result2.shape(), &[2, 2]);
2586
2587 for i in 0..2 {
2589 for j in 0..2 {
2590 assert_abs_diff_eq!(dense_result[[i, j]], dense_result2[[i, j]], epsilon = 1e-10);
2591 }
2592 }
2593 }
2594
2595 #[test]
2596 fn test_encoded_output_methods() {
2597 let dense_array =
2598 Array::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
2599 let dense_output = EncodedOutput::Dense(dense_array);
2600
2601 let mut sparse_matrix = SparseMatrix::new((2, 3));
2602 sparse_matrix.push(0, 0, 1.0);
2603 sparse_matrix.push(1, 1, 1.0);
2604 let sparse_output = EncodedOutput::Sparse(sparse_matrix);
2605
2606 assert_eq!(dense_output.shape(), (2, 3));
2608 assert_eq!(sparse_output.shape(), (2, 3));
2609
2610 let dense_from_dense = dense_output.to_dense();
2612 let dense_from_sparse = sparse_output.to_dense();
2613
2614 assert_eq!(dense_from_dense.shape(), &[2, 3]);
2615 assert_eq!(dense_from_sparse.shape(), &[2, 3]);
2616
2617 assert_eq!(dense_from_dense[[0, 0]], 1.0);
2619 assert_eq!(dense_from_sparse[[0, 0]], 1.0);
2620 assert_eq!(dense_from_dense[[1, 1]], 1.0);
2621 assert_eq!(dense_from_sparse[[1, 1]], 1.0);
2622 }
2623}