1use scirs2_core::ndarray::{s, Array2, Axis};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::collections::HashMap;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20struct OrderedFloat(u64);
21
22impl From<Float> for OrderedFloat {
23 fn from(val: Float) -> Self {
24 OrderedFloat(val.to_bits())
25 }
26}
27
28#[derive(Debug, Clone)]
30pub enum ColumnSelector {
31 Indices(Vec<usize>),
33 Names(Vec<String>),
35 DataType(DataType),
37 Remainder,
39}
40
/// Coarse column type inferred from the values in a column.
///
/// This is a fieldless enum, so `Copy`, `Eq` and `Hash` are derived: values can be passed
/// by value and used as set/map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Anything not classified as `Boolean` or `Categorical`.
    Numeric,
    /// Few distinct values relative to the number of rows.
    Categorical,
    /// Only the values `0.0` and `1.0` occur.
    Boolean,
}
48
49#[derive(Debug, Clone)]
51pub enum RemainderStrategy {
52 Drop,
54 Passthrough,
56 Transform(Box<dyn TransformerWrapper>),
58}
59
/// What to do when a step's transformer returns an error.
///
/// Fieldless and `Copy`; `Eq` and `Hash` are derived as well so strategies can be
/// compared exactly and used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColumnErrorStrategy {
    /// Propagate the error, so `fit`/`transform` fails.
    StopOnError,
    /// Print a warning and emit a zero-width block for the step.
    SkipOnError,
    /// Print a warning and run the configured fallback transformer instead; if none is
    /// configured, pass the selected columns through unchanged.
    Fallback,
    /// Print a warning and emit a block of zeros.
    ReplaceWithZeros,
    /// Print a warning and emit a block of NaN.
    ReplaceWithNaN,
}
74
75impl Default for ColumnErrorStrategy {
76 fn default() -> Self {
77 Self::StopOnError
78 }
79}
80
81impl Default for RemainderStrategy {
82 fn default() -> Self {
83 Self::Drop
84 }
85}
86
87pub trait TransformerWrapper: Send + Sync + std::fmt::Debug {
89 fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
90 fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
91 fn get_n_features_out(&self) -> Option<usize>;
92 fn clone_box(&self) -> Box<dyn TransformerWrapper>;
93}
94
95impl Clone for Box<dyn TransformerWrapper> {
96 fn clone(&self) -> Self {
97 self.clone_box()
98 }
99}
100
101#[derive(Debug, Clone)]
103pub struct TransformerStep {
104 pub name: String,
106 pub columns: ColumnSelector,
108 pub transformer: Box<dyn TransformerWrapper>,
110}
111
112#[derive(Debug, Clone)]
114pub struct ColumnTransformerConfig {
115 pub remainder: RemainderStrategy,
117 pub preserve_order: bool,
119 pub n_jobs: Option<usize>,
121 pub validate_input: bool,
123 pub error_strategy: ColumnErrorStrategy,
125 pub parallel_execution: bool,
127 pub fallback_transformer: Option<Box<dyn TransformerWrapper>>,
129}
130
131impl Default for ColumnTransformerConfig {
132 fn default() -> Self {
133 Self {
134 remainder: RemainderStrategy::Drop,
135 preserve_order: false,
136 n_jobs: None,
137 validate_input: true,
138 error_strategy: ColumnErrorStrategy::StopOnError,
139 parallel_execution: false,
140 fallback_transformer: None,
141 }
142 }
143}
144
145#[derive(Debug)]
147pub struct ColumnTransformer<State = Untrained> {
148 config: ColumnTransformerConfig,
149 transformers: Vec<TransformerStep>,
150 state: PhantomData<State>,
151 fitted_transformers_: Option<Vec<TransformerStep>>,
153 feature_names_in_: Option<Vec<String>>,
154 n_features_in_: Option<usize>,
155 output_indices_: Option<HashMap<String, Vec<usize>>>,
156 remainder_indices_: Option<Vec<usize>>,
157}
158
159#[derive(Debug)]
161struct ColumnTransformResult {
162 transformer_name: String,
163 column_indices: Vec<usize>,
164 result: Result<Array2<Float>>,
165 original_indices: Vec<usize>,
166}
167
168impl<State> ColumnTransformer<State> {
170 fn apply_transformer_with_error_handling(
172 &self,
173 step: &TransformerStep,
174 _data: &Array2<Float>,
175 subset: &Array2<Float>,
176 is_fit_transform: bool,
177 resolved_indices: &[usize],
178 ) -> ColumnTransformResult {
179 let column_indices = resolved_indices.to_vec();
180
181 let transform_result = if is_fit_transform {
182 step.transformer.fit_transform_wrapper(subset)
183 } else {
184 step.transformer.transform_wrapper(subset)
185 };
186
187 let final_result = match transform_result {
188 Ok(transformed) => Ok(transformed),
189 Err(error) => {
190 match self.config.error_strategy {
192 ColumnErrorStrategy::StopOnError => Err(error),
193 ColumnErrorStrategy::SkipOnError => {
194 eprintln!(
195 "Warning: Transformer '{}' failed on columns {:?}: {}. Skipping...",
196 step.name, column_indices, error
197 );
198 Ok(Array2::zeros((subset.nrows(), 0)))
200 }
201 ColumnErrorStrategy::Fallback => {
202 if let Some(ref fallback) = self.config.fallback_transformer {
203 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Using fallback...",
204 step.name, column_indices, error);
205 if is_fit_transform {
206 fallback.fit_transform_wrapper(subset)
207 } else {
208 fallback.transform_wrapper(subset)
209 }
210 } else {
211 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. No fallback available, passing through...",
212 step.name, column_indices, error);
213 Ok(subset.clone())
214 }
215 }
216 ColumnErrorStrategy::ReplaceWithZeros => {
217 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with zeros...",
218 step.name, column_indices, error);
219 Ok(Array2::zeros(subset.dim()))
220 }
221 ColumnErrorStrategy::ReplaceWithNaN => {
222 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with NaN...",
223 step.name, column_indices, error);
224 Ok(Array2::from_elem(subset.dim(), Float::NAN))
225 }
226 }
227 }
228 };
229
230 ColumnTransformResult {
231 transformer_name: step.name.clone(),
232 column_indices: column_indices.clone(),
233 result: final_result,
234 original_indices: column_indices,
235 }
236 }
237}
238
239impl ColumnTransformer<Untrained> {
240 pub fn new() -> Self {
242 Self {
243 config: ColumnTransformerConfig::default(),
244 transformers: Vec::new(),
245 state: PhantomData,
246 fitted_transformers_: None,
247 feature_names_in_: None,
248 n_features_in_: None,
249 output_indices_: None,
250 remainder_indices_: None,
251 }
252 }
253
254 pub fn add_transformer<T>(mut self, name: &str, transformer: T, columns: ColumnSelector) -> Self
256 where
257 T: TransformerWrapper + 'static,
258 {
259 self.transformers.push(TransformerStep {
260 name: name.to_string(),
261 columns,
262 transformer: Box::new(transformer),
263 });
264 self
265 }
266
267 pub fn remainder(mut self, strategy: RemainderStrategy) -> Self {
269 self.config.remainder = strategy;
270 self
271 }
272
273 pub fn preserve_order(mut self, preserve: bool) -> Self {
275 self.config.preserve_order = preserve;
276 self
277 }
278
279 pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
281 self.config.n_jobs = n_jobs;
282 self
283 }
284
285 pub fn validate_input(mut self, validate: bool) -> Self {
287 self.config.validate_input = validate;
288 self
289 }
290
291 pub fn error_strategy(mut self, strategy: ColumnErrorStrategy) -> Self {
293 self.config.error_strategy = strategy;
294 self
295 }
296
297 pub fn parallel_execution(mut self, parallel: bool) -> Self {
299 self.config.parallel_execution = parallel;
300 self
301 }
302
303 pub fn fallback_transformer<T>(mut self, transformer: T) -> Self
305 where
306 T: TransformerWrapper + 'static,
307 {
308 self.config.fallback_transformer = Some(Box::new(transformer));
309 self
310 }
311
312 fn resolve_columns(&self, selector: &ColumnSelector, n_features: usize) -> Result<Vec<usize>> {
314 match selector {
315 ColumnSelector::Indices(indices) => {
316 for &idx in indices {
318 if idx >= n_features {
319 return Err(SklearsError::InvalidInput(format!(
320 "Column index {} is out of bounds for {} features",
321 idx, n_features
322 )));
323 }
324 }
325 Ok(indices.clone())
326 }
327 ColumnSelector::Names(_names) => {
328 Err(SklearsError::NotImplemented(
330 "Named column selection not yet implemented".to_string(),
331 ))
332 }
333 ColumnSelector::DataType(_dtype) => {
334 Err(SklearsError::InvalidInput(
336 "DataType column selection requires training data. Use resolve_columns_with_data.".to_string(),
337 ))
338 }
339 ColumnSelector::Remainder => {
340 Ok(Vec::new())
342 }
343 }
344 }
345
346 fn resolve_columns_with_data(
348 &self,
349 selector: &ColumnSelector,
350 data: &Array2<Float>,
351 ) -> Result<Vec<usize>> {
352 let (_, n_features) = data.dim();
353
354 match selector {
355 ColumnSelector::Indices(indices) => {
356 for &idx in indices {
358 if idx >= n_features {
359 return Err(SklearsError::InvalidInput(format!(
360 "Column index {} is out of bounds for {} features",
361 idx, n_features
362 )));
363 }
364 }
365 Ok(indices.clone())
366 }
367 ColumnSelector::Names(_names) => {
368 Err(SklearsError::NotImplemented(
370 "Named column selection not yet implemented".to_string(),
371 ))
372 }
373 ColumnSelector::DataType(dtype) => self.infer_columns_by_dtype_with_data(dtype, data),
374 ColumnSelector::Remainder => {
375 Ok(Vec::new())
377 }
378 }
379 }
380
381 fn infer_columns_by_dtype(&self, _dtype: &DataType, _n_features: usize) -> Result<Vec<usize>> {
383 Err(SklearsError::InvalidInput(
385 "Data type column selection requires training data context. \
386 Use resolve_columns_with_data during fit."
387 .to_string(),
388 ))
389 }
390
391 fn infer_columns_by_dtype_with_data(
393 &self,
394 dtype: &DataType,
395 data: &Array2<Float>,
396 ) -> Result<Vec<usize>> {
397 let (_n_samples, n_features) = data.dim();
398 let mut matching_columns = Vec::new();
399
400 for col_idx in 0..n_features {
401 let column = data.column(col_idx);
402 let column_type = self.infer_column_type(&column);
403
404 if column_type == *dtype {
405 matching_columns.push(col_idx);
406 }
407 }
408
409 Ok(matching_columns)
410 }
411
412 fn infer_column_type(&self, column: &scirs2_core::ndarray::ArrayView1<Float>) -> DataType {
414 let unique_values: std::collections::HashSet<_> =
415 column.iter().map(|&x| OrderedFloat::from(x)).collect();
416
417 let n_unique = unique_values.len();
418 let n_total = column.len();
419
420 if n_unique <= 2 {
422 let zero_bits = OrderedFloat::from(0.0);
423 let one_bits = OrderedFloat::from(1.0);
424 if unique_values
425 .iter()
426 .all(|&x| x == zero_bits || x == one_bits)
427 {
428 return DataType::Boolean;
429 }
430 }
431
432 let unique_ratio = n_unique as f64 / n_total as f64;
435
436 if (unique_ratio < 0.6 && n_unique <= 5) || unique_ratio < 0.2 {
439 DataType::Categorical
440 } else {
441 DataType::Numeric
442 }
443 }
444
445 fn get_remainder_indices(&self, data: &Array2<Float>) -> Result<Vec<usize>> {
447 let (_, n_features) = data.dim();
448 let mut used_indices = std::collections::HashSet::new();
449
450 for step in &self.transformers {
452 let indices = match &step.columns {
453 ColumnSelector::DataType(_) => {
454 self.resolve_columns_with_data(&step.columns, data)?
455 }
456 _ => self.resolve_columns(&step.columns, n_features)?,
457 };
458 for idx in indices {
459 used_indices.insert(idx);
460 }
461 }
462
463 Ok((0..n_features)
465 .filter(|i| !used_indices.contains(i))
466 .collect())
467 }
468}
469
470impl Default for ColumnTransformer<Untrained> {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
476impl Estimator for ColumnTransformer<Untrained> {
477 type Config = ColumnTransformerConfig;
478 type Error = SklearsError;
479 type Float = Float;
480
481 fn config(&self) -> &Self::Config {
482 &self.config
483 }
484}
485
486impl Estimator for ColumnTransformer<Trained> {
487 type Config = ColumnTransformerConfig;
488 type Error = SklearsError;
489 type Float = Float;
490
491 fn config(&self) -> &Self::Config {
492 &self.config
493 }
494}
495
496impl Fit<Array2<Float>, ()> for ColumnTransformer<Untrained> {
497 type Fitted = ColumnTransformer<Trained>;
498
499 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
500 let (n_samples, n_features) = x.dim();
501
502 if n_samples == 0 {
503 return Err(SklearsError::InvalidInput(
504 "Cannot fit transformer on empty dataset".to_string(),
505 ));
506 }
507
508 let remainder_indices = self.get_remainder_indices(x)?;
510
511 let mut transformer_tasks: Vec<(TransformerStep, Vec<usize>)> = Vec::new();
513
514 for step in &self.transformers {
515 let indices = match &step.columns {
517 ColumnSelector::DataType(_) => self.resolve_columns_with_data(&step.columns, x)?,
518 _ => self.resolve_columns(&step.columns, n_features)?,
519 };
520
521 if !indices.is_empty() {
522 transformer_tasks.push((step.clone(), indices));
523 }
524 }
525
526 let transform_results: Vec<ColumnTransformResult> = if self.config.parallel_execution
528 && transformer_tasks.len() > 1
529 {
530 #[cfg(feature = "parallel")]
531 {
532 transformer_tasks
533 .into_par_iter()
534 .map(|(step, indices)| {
535 let subset = x.select(Axis(1), &indices);
536 self.apply_transformer_with_error_handling(
537 &step, x, &subset, true, &indices,
538 )
539 })
540 .collect()
541 }
542 #[cfg(not(feature = "parallel"))]
543 {
544 transformer_tasks
546 .into_iter()
547 .map(|(step, indices)| {
548 let subset = x.select(Axis(1), &indices);
549 self.apply_transformer_with_error_handling(
550 &step, x, &subset, true, &indices,
551 )
552 })
553 .collect()
554 }
555 } else {
556 transformer_tasks
558 .into_iter()
559 .map(|(step, indices)| {
560 let subset = x.select(Axis(1), &indices);
561 self.apply_transformer_with_error_handling(&step, x, &subset, true, &indices)
562 })
563 .collect()
564 };
565
566 let mut fitted_transformers = Vec::new();
568 let mut output_indices = HashMap::new();
569
570 for transform_result in transform_results {
571 match transform_result.result {
572 Ok(transformed) => {
573 if transformed.ncols() > 0 {
574 let output_cols = (0..transformed.ncols()).collect();
577 output_indices
578 .insert(transform_result.transformer_name.clone(), output_cols);
579
580 let transformer_name = transform_result.transformer_name.clone();
582 fitted_transformers.push(TransformerStep {
583 name: transformer_name.clone(),
584 columns: ColumnSelector::Indices(transform_result.original_indices),
585 transformer: self
586 .transformers
587 .iter()
588 .find(|s| s.name == transformer_name)
589 .unwrap()
590 .transformer
591 .clone_box(),
592 });
593 }
594 }
595 Err(e) => {
596 return Err(SklearsError::TransformError(format!(
598 "Transformer '{}' failed: {}",
599 transform_result.transformer_name, e
600 )));
601 }
602 }
603 }
604
605 Ok(ColumnTransformer {
606 config: self.config,
607 transformers: self.transformers,
608 state: PhantomData,
609 fitted_transformers_: Some(fitted_transformers),
610 feature_names_in_: None,
611 n_features_in_: Some(n_features),
612 output_indices_: Some(output_indices),
613 remainder_indices_: Some(remainder_indices),
614 })
615 }
616}
617
618impl Transform<Array2<Float>, Array2<Float>> for ColumnTransformer<Trained> {
619 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
620 let (n_samples, n_features) = x.dim();
621
622 if Some(n_features) != self.n_features_in_ {
623 return Err(SklearsError::FeatureMismatch {
624 expected: self.n_features_in_.unwrap_or(0),
625 actual: n_features,
626 });
627 }
628
629 let fitted_transformers = self.fitted_transformers_.as_ref().unwrap();
630 let remainder_indices = self.remainder_indices_.as_ref().unwrap();
631
632 let transformer_tasks: Vec<&TransformerStep> = fitted_transformers.iter().collect();
634
635 let transform_results: Vec<ColumnTransformResult> =
637 if self.config.parallel_execution && transformer_tasks.len() > 1 {
638 #[cfg(feature = "parallel")]
639 {
640 transformer_tasks
641 .into_par_iter()
642 .filter_map(|step| {
643 if let ColumnSelector::Indices(indices) = &step.columns {
644 if !indices.is_empty() {
645 let subset = x.select(Axis(1), indices);
646 Some(self.apply_transformer_with_error_handling(
647 step, x, &subset, false, indices,
648 ))
649 } else {
650 None
651 }
652 } else {
653 None
654 }
655 })
656 .collect()
657 }
658 #[cfg(not(feature = "parallel"))]
659 {
660 transformer_tasks
662 .into_iter()
663 .filter_map(|step| {
664 if let ColumnSelector::Indices(indices) = &step.columns {
665 if !indices.is_empty() {
666 let subset = x.select(Axis(1), indices);
667 Some(self.apply_transformer_with_error_handling(
668 step, x, &subset, false, indices,
669 ))
670 } else {
671 None
672 }
673 } else {
674 None
675 }
676 })
677 .collect()
678 }
679 } else {
680 transformer_tasks
682 .into_iter()
683 .filter_map(|step| {
684 if let ColumnSelector::Indices(indices) = &step.columns {
685 if !indices.is_empty() {
686 let subset = x.select(Axis(1), indices);
687 Some(self.apply_transformer_with_error_handling(
688 step, x, &subset, false, indices,
689 ))
690 } else {
691 None
692 }
693 } else {
694 None
695 }
696 })
697 .collect()
698 };
699
700 let mut column_outputs: Vec<(usize, Array2<Float>)> = Vec::new();
702
703 for transform_result in transform_results {
704 match transform_result.result {
705 Ok(transformed) => {
706 if transformed.ncols() > 0 {
707 let min_index = *transform_result.original_indices.iter().min().unwrap();
710 column_outputs.push((min_index, transformed));
711 }
712 }
713 Err(e) => {
714 return Err(SklearsError::TransformError(format!(
716 "Transformer '{}' failed: {}",
717 transform_result.transformer_name, e
718 )));
719 }
720 }
721 }
722
723 if !remainder_indices.is_empty() {
725 let remainder_data = x.select(Axis(1), remainder_indices);
726
727 let transformed_remainder = match &self.config.remainder {
728 RemainderStrategy::Drop => {
729 None }
731 RemainderStrategy::Passthrough => Some(remainder_data),
732 RemainderStrategy::Transform(transformer) => {
733 let transformed = transformer.transform_wrapper(&remainder_data)?;
734 Some(transformed)
735 }
736 };
737
738 if let Some(remainder_output) = transformed_remainder {
739 if let Some(&min_remainder_index) = remainder_indices.iter().min() {
741 column_outputs.push((min_remainder_index, remainder_output));
742 }
743 }
744 }
745
746 column_outputs.sort_by_key(|(idx, _)| *idx);
748
749 if column_outputs.is_empty() {
751 return Err(SklearsError::InvalidInput(
752 "No output from any transformer".to_string(),
753 ));
754 }
755
756 let total_cols: usize = column_outputs.iter().map(|(_, arr)| arr.ncols()).sum();
758 let mut result = Array2::zeros((n_samples, total_cols));
759
760 let mut col_offset = 0;
762 for (_, part) in column_outputs {
763 let part_cols = part.ncols();
764 result
765 .slice_mut(s![.., col_offset..col_offset + part_cols])
766 .assign(&part);
767 col_offset += part_cols;
768 }
769
770 Ok(result)
771 }
772}
773
774impl ColumnTransformer<Trained> {
775 pub fn n_features_in(&self) -> usize {
777 self.n_features_in_.unwrap()
778 }
779
780 pub fn output_indices(&self) -> &HashMap<String, Vec<usize>> {
782 self.output_indices_.as_ref().unwrap()
783 }
784
785 pub fn remainder_indices(&self) -> &Vec<usize> {
787 self.remainder_indices_.as_ref().unwrap()
788 }
789}
790
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Stateless test transformer that multiplies every value by `scale`.
    #[derive(Debug, Clone)]
    struct MockTransformer {
        scale: Float,
    }

    impl TransformerWrapper for MockTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            // Nothing to learn, so fitting is just transforming.
            self.transform_wrapper(x)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            Ok(x * self.scale)
        }

        fn get_n_features_out(&self) -> Option<usize> {
            None
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    /// Test transformer that doubles its input, or fails with `InvalidInput` when
    /// `should_fail` is set.
    #[derive(Debug, Clone)]
    struct FailingTransformer {
        should_fail: bool,
    }

    impl FailingTransformer {
        /// Shared body of both wrapper entry points.
        fn run(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            if self.should_fail {
                return Err(SklearsError::InvalidInput(
                    "Intentional failure for testing".to_string(),
                ));
            }
            Ok(x * 2.0)
        }
    }

    impl TransformerWrapper for FailingTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.run(x)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.run(x)
        }

        fn get_n_features_out(&self) -> Option<usize> {
            None
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    /// Fits `ct` on `x`, then transforms the same `x`; panics on any error.
    fn fit_then_transform(ct: ColumnTransformer<Untrained>, x: &Array2<Float>) -> Array2<Float> {
        let fitted = ct.fit(x, &()).expect("fit should succeed");
        fitted.transform(x).expect("transform should succeed")
    }

    /// Row `i` of `a` as a `Vec`, for compact whole-row assertions.
    fn row_vec(a: &Array2<Float>, i: usize) -> Vec<Float> {
        a.row(i).to_vec()
    }

    #[test]
    fn test_column_transformer_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_first_two",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0, 1]),
            )
            .remainder(RemainderStrategy::Passthrough);

        let result = fit_then_transform(ct, &x);

        assert_eq!(result.dim(), (3, 3));
        // Columns 0 and 1 are doubled; column 2 passes through.
        assert_eq!(row_vec(&result, 0), vec![2.0, 4.0, 3.0]);
    }

    #[test]
    fn test_column_transformer_drop_remainder() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_middle",
                MockTransformer { scale: 3.0 },
                ColumnSelector::Indices(vec![1, 2]),
            )
            .remainder(RemainderStrategy::Drop);

        let result = fit_then_transform(ct, &x);

        assert_eq!(result.dim(), (2, 2));
        assert_eq!(row_vec(&result, 0), vec![6.0, 9.0]);
    }

    #[test]
    fn test_column_transformer_multiple_transformers() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_first",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "scale_last",
                MockTransformer { scale: 0.5 },
                ColumnSelector::Indices(vec![3]),
            )
            .remainder(RemainderStrategy::Passthrough);

        let result = fit_then_transform(ct, &x);

        assert_eq!(result.dim(), (2, 4));
        // Blocks are ordered by their smallest source column: [0], remainder [1, 2], [3].
        assert_eq!(row_vec(&result, 0), vec![2.0, 2.0, 3.0, 2.0]);
    }

    #[test]
    fn test_column_transformer_empty_data() {
        let x_empty: Array2<Float> = Array2::zeros((0, 3));
        let ct = ColumnTransformer::new().add_transformer(
            "test",
            MockTransformer { scale: 1.0 },
            ColumnSelector::Indices(vec![0]),
        );

        assert!(ct.fit(&x_empty, &()).is_err());
    }

    #[test]
    fn test_column_transformer_invalid_indices() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        // Index 5 is out of bounds for a two-column input.
        let ct = ColumnTransformer::new().add_transformer(
            "invalid",
            MockTransformer { scale: 1.0 },
            ColumnSelector::Indices(vec![0, 5]),
        );

        assert!(ct.fit(&x, &()).is_err());
    }

    #[test]
    fn test_column_type_inference() {
        let ct = ColumnTransformer::new();
        let cases = [
            (array![0.0, 1.0, 0.0, 1.0, 0.0], DataType::Boolean),
            (
                array![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 2.0, 3.0],
                DataType::Categorical,
            ),
            (
                array![1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0],
                DataType::Numeric,
            ),
        ];

        for (column, expected) in cases {
            assert_eq!(ct.infer_column_type(&column.view()), expected);
        }
    }

    #[test]
    fn test_column_transformer_error_handling_stop_on_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::StopOnError);

        assert!(ct.fit(&x, &()).is_err(), "Should fail with StopOnError");
    }

    #[test]
    fn test_column_transformer_error_handling_skip_on_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "working",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![1]),
            )
            .error_strategy(ColumnErrorStrategy::SkipOnError)
            .remainder(RemainderStrategy::Passthrough);

        let result = fit_then_transform(ct, &x);

        // Column 0 is skipped; column 1 is doubled; column 2 passes through.
        assert_eq!(result.dim(), (2, 2));
        assert_eq!(row_vec(&result, 0), vec![4.0, 3.0]);
    }

    #[test]
    fn test_column_transformer_error_handling_replace_with_zeros() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::ReplaceWithZeros)
            .remainder(RemainderStrategy::Passthrough);

        let result = fit_then_transform(ct, &x);

        assert_eq!(result.dim(), (2, 3));
        assert_eq!(row_vec(&result, 0), vec![0.0, 2.0, 3.0]);
    }

    #[test]
    fn test_column_transformer_error_handling_fallback() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::Fallback)
            .fallback_transformer(MockTransformer { scale: 0.5 })
            .remainder(RemainderStrategy::Passthrough);

        let result = fit_then_transform(ct, &x);

        assert_eq!(result.dim(), (2, 3));
        assert_eq!(row_vec(&result, 0), vec![0.5, 2.0, 3.0]);
    }

    #[test]
    fn test_column_transformer_parallel_execution() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_first",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "scale_second",
                MockTransformer { scale: 3.0 },
                ColumnSelector::Indices(vec![1]),
            )
            .parallel_execution(true)
            .remainder(RemainderStrategy::Passthrough);

        let result = fit_then_transform(ct, &x);

        assert_eq!(result.dim(), (2, 4));
        assert_eq!(row_vec(&result, 0), vec![2.0, 6.0, 3.0, 4.0]);
    }

    #[test]
    fn test_column_transformer_dtype_selection() {
        // Column 0: distinct reals (numeric); column 1: 0/1 (boolean);
        // column 2: three repeated levels (categorical).
        let x = array![
            [1.23456, 0.0, 1.0],
            [2.78901, 1.0, 1.0],
            [3.45678, 0.0, 2.0],
            [4.98765, 1.0, 1.0],
            [5.12345, 0.0, 2.0],
            [6.67890, 1.0, 3.0],
            [7.11111, 0.0, 1.0],
            [8.22222, 1.0, 2.0],
        ];

        let select_and_scale = |name: &str, scale: Float, dtype: DataType| {
            let ct = ColumnTransformer::new().add_transformer(
                name,
                MockTransformer { scale },
                ColumnSelector::DataType(dtype),
            );
            fit_then_transform(ct, &x)
        };

        let result_bool = select_and_scale("bool_transformer", 10.0, DataType::Boolean);
        assert_eq!(result_bool.dim(), (8, 1));
        assert_eq!(result_bool[[0, 0]], 0.0);
        assert_eq!(result_bool[[1, 0]], 10.0);

        let result_cat = select_and_scale("cat_transformer", 0.1, DataType::Categorical);
        assert_eq!(result_cat.dim(), (8, 1));
        assert_eq!(result_cat[[0, 0]], 0.1);

        let result_num = select_and_scale("num_transformer", 2.0, DataType::Numeric);
        assert_eq!(result_num.dim(), (8, 1));
        let expected_first = 1.23456 * 2.0;
        assert!((result_num[[0, 0]] - expected_first).abs() < 1e-10);
    }
}