1use scirs2_core::ndarray::{s, Array2, Axis};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::collections::HashMap;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20struct OrderedFloat(u64);
21
22impl From<Float> for OrderedFloat {
23 fn from(val: Float) -> Self {
24 OrderedFloat(val.to_bits())
25 }
26}
27
/// Identifies which input columns a transformer step operates on.
#[derive(Debug, Clone)]
pub enum ColumnSelector {
    /// Explicit zero-based column indices; each must be `< n_features`.
    Indices(Vec<usize>),
    /// Column names. Name-based resolution is not implemented yet and yields an error.
    Names(Vec<String>),
    /// Every column whose inferred [`DataType`] equals this value (resolved from the data).
    DataType(DataType),
    /// Currently resolves to no columns in the column-resolution code.
    Remainder,
}

/// Coarse column kind inferred heuristically from a column's values.
///
/// Fieldless, so it is cheap to copy and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Continuous values with many distinct entries.
    Numeric,
    /// Few distinct values relative to the column length.
    Categorical,
    /// Only the values `0.0` and `1.0`.
    Boolean,
}
48
49#[derive(Debug, Clone)]
51pub enum RemainderStrategy {
52 Drop,
54 Passthrough,
56 Transform(Box<dyn TransformerWrapper>),
58}
59
60#[derive(Debug, Clone, Copy, PartialEq)]
62pub enum ColumnErrorStrategy {
63 StopOnError,
65 SkipOnError,
67 Fallback,
69 ReplaceWithZeros,
71 ReplaceWithNaN,
73}
74
75impl Default for ColumnErrorStrategy {
76 fn default() -> Self {
77 Self::StopOnError
78 }
79}
80
81impl Default for RemainderStrategy {
82 fn default() -> Self {
83 Self::Drop
84 }
85}
86
/// Object-safe interface through which the column transformer runs a transformer
/// on a subset of columns.
///
/// Every method takes `&self`, so this trait gives an implementation no way to
/// store fitted state between `fit_transform_wrapper` and `transform_wrapper`
/// (short of interior mutability).
pub trait TransformerWrapper: Send + Sync + std::fmt::Debug {
    /// Fits on `x` (where applicable) and returns the transformed columns.
    fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
    /// Transforms `x` and returns the transformed columns.
    fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
    /// Number of output columns if known in advance, `None` otherwise.
    fn get_n_features_out(&self) -> Option<usize>;
    /// Clones the wrapper into a new box; backs `Clone for Box<dyn TransformerWrapper>`.
    fn clone_box(&self) -> Box<dyn TransformerWrapper>;
}

/// Lets boxed wrappers, and the structs/enums that hold them, derive `Clone`.
impl Clone for Box<dyn TransformerWrapper> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
100
/// One named transformer together with the input columns it applies to.
#[derive(Debug, Clone)]
pub struct TransformerStep {
    /// Identifier used in warnings and error messages, and as the key of `output_indices`.
    pub name: String,
    /// Selector for the input columns this step consumes.
    pub columns: ColumnSelector,
    /// The transformer applied to the selected columns.
    pub transformer: Box<dyn TransformerWrapper>,
}
111
112#[derive(Debug, Clone)]
114pub struct ColumnTransformerConfig {
115 pub remainder: RemainderStrategy,
117 pub preserve_order: bool,
119 pub n_jobs: Option<usize>,
121 pub validate_input: bool,
123 pub error_strategy: ColumnErrorStrategy,
125 pub parallel_execution: bool,
127 pub fallback_transformer: Option<Box<dyn TransformerWrapper>>,
129}
130
131impl Default for ColumnTransformerConfig {
132 fn default() -> Self {
133 Self {
134 remainder: RemainderStrategy::Drop,
135 preserve_order: false,
136 n_jobs: None,
137 validate_input: true,
138 error_strategy: ColumnErrorStrategy::StopOnError,
139 parallel_execution: false,
140 fallback_transformer: None,
141 }
142 }
143}
144
/// Applies different transformers to different column subsets of a 2-D input and
/// concatenates their outputs column-wise.
///
/// `State` is a type-state marker: `Untrained` before `fit`, `Trained` after.
/// The trailing-underscore fields are filled in only by `fit`.
#[derive(Debug)]
pub struct ColumnTransformer<State = Untrained> {
    config: ColumnTransformerConfig,
    // Steps as configured through the builder methods.
    transformers: Vec<TransformerStep>,
    state: PhantomData<State>,
    // Steps kept after fitting, each with its columns resolved to explicit indices.
    fitted_transformers_: Option<Vec<TransformerStep>>,
    // Always `None` in this module (named column selection is not implemented).
    feature_names_in_: Option<Vec<String>>,
    // Input column count seen by `fit`; `transform` requires the same count.
    n_features_in_: Option<usize>,
    // Per-step output column indices recorded by `fit`, keyed by step name.
    output_indices_: Option<HashMap<String, Vec<usize>>>,
    // Input columns not claimed by any step.
    remainder_indices_: Option<Vec<usize>>,
}

/// Outcome of running one step through `apply_transformer_with_error_handling`.
#[derive(Debug)]
struct ColumnTransformResult {
    transformer_name: String,
    // Same contents as `original_indices`; not read elsewhere in this module.
    column_indices: Vec<usize>,
    // Transformed columns, or the error remaining after the error strategy was applied.
    result: Result<Array2<Float>>,
    // Input column indices the step was applied to.
    original_indices: Vec<usize>,
}
167
168impl<State> ColumnTransformer<State> {
170 fn apply_transformer_with_error_handling(
172 &self,
173 step: &TransformerStep,
174 _data: &Array2<Float>,
175 subset: &Array2<Float>,
176 is_fit_transform: bool,
177 resolved_indices: &[usize],
178 ) -> ColumnTransformResult {
179 let column_indices = resolved_indices.to_vec();
180
181 let transform_result = if is_fit_transform {
182 step.transformer.fit_transform_wrapper(subset)
183 } else {
184 step.transformer.transform_wrapper(subset)
185 };
186
187 let final_result = match transform_result {
188 Ok(transformed) => Ok(transformed),
189 Err(error) => {
190 match self.config.error_strategy {
192 ColumnErrorStrategy::StopOnError => Err(error),
193 ColumnErrorStrategy::SkipOnError => {
194 eprintln!(
195 "Warning: Transformer '{}' failed on columns {:?}: {}. Skipping...",
196 step.name, column_indices, error
197 );
198 Ok(Array2::zeros((subset.nrows(), 0)))
200 }
201 ColumnErrorStrategy::Fallback => {
202 if let Some(ref fallback) = self.config.fallback_transformer {
203 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Using fallback...",
204 step.name, column_indices, error);
205 if is_fit_transform {
206 fallback.fit_transform_wrapper(subset)
207 } else {
208 fallback.transform_wrapper(subset)
209 }
210 } else {
211 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. No fallback available, passing through...",
212 step.name, column_indices, error);
213 Ok(subset.clone())
214 }
215 }
216 ColumnErrorStrategy::ReplaceWithZeros => {
217 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with zeros...",
218 step.name, column_indices, error);
219 Ok(Array2::zeros(subset.dim()))
220 }
221 ColumnErrorStrategy::ReplaceWithNaN => {
222 eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with NaN...",
223 step.name, column_indices, error);
224 Ok(Array2::from_elem(subset.dim(), Float::NAN))
225 }
226 }
227 }
228 };
229
230 ColumnTransformResult {
231 transformer_name: step.name.clone(),
232 column_indices: column_indices.clone(),
233 result: final_result,
234 original_indices: column_indices,
235 }
236 }
237}
238
239impl ColumnTransformer<Untrained> {
240 pub fn new() -> Self {
242 Self {
243 config: ColumnTransformerConfig::default(),
244 transformers: Vec::new(),
245 state: PhantomData,
246 fitted_transformers_: None,
247 feature_names_in_: None,
248 n_features_in_: None,
249 output_indices_: None,
250 remainder_indices_: None,
251 }
252 }
253
254 pub fn add_transformer<T>(mut self, name: &str, transformer: T, columns: ColumnSelector) -> Self
256 where
257 T: TransformerWrapper + 'static,
258 {
259 self.transformers.push(TransformerStep {
260 name: name.to_string(),
261 columns,
262 transformer: Box::new(transformer),
263 });
264 self
265 }
266
267 pub fn remainder(mut self, strategy: RemainderStrategy) -> Self {
269 self.config.remainder = strategy;
270 self
271 }
272
273 pub fn preserve_order(mut self, preserve: bool) -> Self {
275 self.config.preserve_order = preserve;
276 self
277 }
278
279 pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
281 self.config.n_jobs = n_jobs;
282 self
283 }
284
285 pub fn validate_input(mut self, validate: bool) -> Self {
287 self.config.validate_input = validate;
288 self
289 }
290
291 pub fn error_strategy(mut self, strategy: ColumnErrorStrategy) -> Self {
293 self.config.error_strategy = strategy;
294 self
295 }
296
297 pub fn parallel_execution(mut self, parallel: bool) -> Self {
299 self.config.parallel_execution = parallel;
300 self
301 }
302
303 pub fn fallback_transformer<T>(mut self, transformer: T) -> Self
305 where
306 T: TransformerWrapper + 'static,
307 {
308 self.config.fallback_transformer = Some(Box::new(transformer));
309 self
310 }
311
312 fn resolve_columns(&self, selector: &ColumnSelector, n_features: usize) -> Result<Vec<usize>> {
314 match selector {
315 ColumnSelector::Indices(indices) => {
316 for &idx in indices {
318 if idx >= n_features {
319 return Err(SklearsError::InvalidInput(format!(
320 "Column index {} is out of bounds for {} features",
321 idx, n_features
322 )));
323 }
324 }
325 Ok(indices.clone())
326 }
327 ColumnSelector::Names(_names) => {
328 Err(SklearsError::NotImplemented(
330 "Named column selection not yet implemented".to_string(),
331 ))
332 }
333 ColumnSelector::DataType(_dtype) => {
334 Err(SklearsError::InvalidInput(
336 "DataType column selection requires training data. Use resolve_columns_with_data.".to_string(),
337 ))
338 }
339 ColumnSelector::Remainder => {
340 Ok(Vec::new())
342 }
343 }
344 }
345
346 fn resolve_columns_with_data(
348 &self,
349 selector: &ColumnSelector,
350 data: &Array2<Float>,
351 ) -> Result<Vec<usize>> {
352 let (_, n_features) = data.dim();
353
354 match selector {
355 ColumnSelector::Indices(indices) => {
356 for &idx in indices {
358 if idx >= n_features {
359 return Err(SklearsError::InvalidInput(format!(
360 "Column index {} is out of bounds for {} features",
361 idx, n_features
362 )));
363 }
364 }
365 Ok(indices.clone())
366 }
367 ColumnSelector::Names(_names) => {
368 Err(SklearsError::NotImplemented(
370 "Named column selection not yet implemented".to_string(),
371 ))
372 }
373 ColumnSelector::DataType(dtype) => self.infer_columns_by_dtype_with_data(dtype, data),
374 ColumnSelector::Remainder => {
375 Ok(Vec::new())
377 }
378 }
379 }
380
381 fn infer_columns_by_dtype(&self, _dtype: &DataType, _n_features: usize) -> Result<Vec<usize>> {
383 Err(SklearsError::InvalidInput(
385 "Data type column selection requires training data context. \
386 Use resolve_columns_with_data during fit."
387 .to_string(),
388 ))
389 }
390
391 fn infer_columns_by_dtype_with_data(
393 &self,
394 dtype: &DataType,
395 data: &Array2<Float>,
396 ) -> Result<Vec<usize>> {
397 let (_n_samples, n_features) = data.dim();
398 let mut matching_columns = Vec::new();
399
400 for col_idx in 0..n_features {
401 let column = data.column(col_idx);
402 let column_type = self.infer_column_type(&column);
403
404 if column_type == *dtype {
405 matching_columns.push(col_idx);
406 }
407 }
408
409 Ok(matching_columns)
410 }
411
412 fn infer_column_type(&self, column: &scirs2_core::ndarray::ArrayView1<Float>) -> DataType {
414 let unique_values: std::collections::HashSet<_> =
415 column.iter().map(|&x| OrderedFloat::from(x)).collect();
416
417 let n_unique = unique_values.len();
418 let n_total = column.len();
419
420 if n_unique <= 2 {
422 let zero_bits = OrderedFloat::from(0.0);
423 let one_bits = OrderedFloat::from(1.0);
424 if unique_values
425 .iter()
426 .all(|&x| x == zero_bits || x == one_bits)
427 {
428 return DataType::Boolean;
429 }
430 }
431
432 let unique_ratio = n_unique as f64 / n_total as f64;
435
436 if (unique_ratio < 0.6 && n_unique <= 5) || unique_ratio < 0.2 {
439 DataType::Categorical
440 } else {
441 DataType::Numeric
442 }
443 }
444
445 fn get_remainder_indices(&self, data: &Array2<Float>) -> Result<Vec<usize>> {
447 let (_, n_features) = data.dim();
448 let mut used_indices = std::collections::HashSet::new();
449
450 for step in &self.transformers {
452 let indices = match &step.columns {
453 ColumnSelector::DataType(_) => {
454 self.resolve_columns_with_data(&step.columns, data)?
455 }
456 _ => self.resolve_columns(&step.columns, n_features)?,
457 };
458 for idx in indices {
459 used_indices.insert(idx);
460 }
461 }
462
463 Ok((0..n_features)
465 .filter(|i| !used_indices.contains(i))
466 .collect())
467 }
468}
469
470impl Default for ColumnTransformer<Untrained> {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
/// Exposes the configuration of an unfitted column transformer via `Estimator`.
impl Estimator for ColumnTransformer<Untrained> {
    type Config = ColumnTransformerConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

/// Exposes the configuration of a fitted column transformer via `Estimator`.
/// `fit` moves the configuration unchanged into the fitted value.
impl Estimator for ColumnTransformer<Trained> {
    type Config = ColumnTransformerConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
495
496impl Fit<Array2<Float>, ()> for ColumnTransformer<Untrained> {
497 type Fitted = ColumnTransformer<Trained>;
498
499 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
500 let (n_samples, n_features) = x.dim();
501
502 if n_samples == 0 {
503 return Err(SklearsError::InvalidInput(
504 "Cannot fit transformer on empty dataset".to_string(),
505 ));
506 }
507
508 let remainder_indices = self.get_remainder_indices(x)?;
510
511 let mut transformer_tasks: Vec<(TransformerStep, Vec<usize>)> = Vec::new();
513
514 for step in &self.transformers {
515 let indices = match &step.columns {
517 ColumnSelector::DataType(_) => self.resolve_columns_with_data(&step.columns, x)?,
518 _ => self.resolve_columns(&step.columns, n_features)?,
519 };
520
521 if !indices.is_empty() {
522 transformer_tasks.push((step.clone(), indices));
523 }
524 }
525
526 let transform_results: Vec<ColumnTransformResult> = if self.config.parallel_execution
528 && transformer_tasks.len() > 1
529 {
530 #[cfg(feature = "parallel")]
531 {
532 transformer_tasks
533 .into_par_iter()
534 .map(|(step, indices)| {
535 let subset = x.select(Axis(1), &indices);
536 self.apply_transformer_with_error_handling(
537 &step, x, &subset, true, &indices,
538 )
539 })
540 .collect()
541 }
542 #[cfg(not(feature = "parallel"))]
543 {
544 transformer_tasks
546 .into_iter()
547 .map(|(step, indices)| {
548 let subset = x.select(Axis(1), &indices);
549 self.apply_transformer_with_error_handling(
550 &step, x, &subset, true, &indices,
551 )
552 })
553 .collect()
554 }
555 } else {
556 transformer_tasks
558 .into_iter()
559 .map(|(step, indices)| {
560 let subset = x.select(Axis(1), &indices);
561 self.apply_transformer_with_error_handling(&step, x, &subset, true, &indices)
562 })
563 .collect()
564 };
565
566 let mut fitted_transformers = Vec::new();
568 let mut output_indices = HashMap::new();
569
570 for transform_result in transform_results {
571 match transform_result.result {
572 Ok(transformed) => {
573 if transformed.ncols() > 0 {
574 let output_cols = (0..transformed.ncols()).collect();
577 output_indices
578 .insert(transform_result.transformer_name.clone(), output_cols);
579
580 let transformer_name = transform_result.transformer_name.clone();
582 fitted_transformers.push(TransformerStep {
583 name: transformer_name.clone(),
584 columns: ColumnSelector::Indices(transform_result.original_indices),
585 transformer: self
586 .transformers
587 .iter()
588 .find(|s| s.name == transformer_name)
589 .expect("operation should succeed")
590 .transformer
591 .clone_box(),
592 });
593 }
594 }
595 Err(e) => {
596 return Err(SklearsError::TransformError(format!(
598 "Transformer '{}' failed: {}",
599 transform_result.transformer_name, e
600 )));
601 }
602 }
603 }
604
605 Ok(ColumnTransformer {
606 config: self.config,
607 transformers: self.transformers,
608 state: PhantomData,
609 fitted_transformers_: Some(fitted_transformers),
610 feature_names_in_: None,
611 n_features_in_: Some(n_features),
612 output_indices_: Some(output_indices),
613 remainder_indices_: Some(remainder_indices),
614 })
615 }
616}
617
618impl Transform<Array2<Float>, Array2<Float>> for ColumnTransformer<Trained> {
619 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
620 let (n_samples, n_features) = x.dim();
621
622 if Some(n_features) != self.n_features_in_ {
623 return Err(SklearsError::FeatureMismatch {
624 expected: self.n_features_in_.unwrap_or(0),
625 actual: n_features,
626 });
627 }
628
629 let fitted_transformers = self
630 .fitted_transformers_
631 .as_ref()
632 .expect("operation should succeed");
633 let remainder_indices = self
634 .remainder_indices_
635 .as_ref()
636 .expect("operation should succeed");
637
638 let transformer_tasks: Vec<&TransformerStep> = fitted_transformers.iter().collect();
640
641 let transform_results: Vec<ColumnTransformResult> =
643 if self.config.parallel_execution && transformer_tasks.len() > 1 {
644 #[cfg(feature = "parallel")]
645 {
646 transformer_tasks
647 .into_par_iter()
648 .filter_map(|step| {
649 if let ColumnSelector::Indices(indices) = &step.columns {
650 if !indices.is_empty() {
651 let subset = x.select(Axis(1), indices);
652 Some(self.apply_transformer_with_error_handling(
653 step, x, &subset, false, indices,
654 ))
655 } else {
656 None
657 }
658 } else {
659 None
660 }
661 })
662 .collect()
663 }
664 #[cfg(not(feature = "parallel"))]
665 {
666 transformer_tasks
668 .into_iter()
669 .filter_map(|step| {
670 if let ColumnSelector::Indices(indices) = &step.columns {
671 if !indices.is_empty() {
672 let subset = x.select(Axis(1), indices);
673 Some(self.apply_transformer_with_error_handling(
674 step, x, &subset, false, indices,
675 ))
676 } else {
677 None
678 }
679 } else {
680 None
681 }
682 })
683 .collect()
684 }
685 } else {
686 transformer_tasks
688 .into_iter()
689 .filter_map(|step| {
690 if let ColumnSelector::Indices(indices) = &step.columns {
691 if !indices.is_empty() {
692 let subset = x.select(Axis(1), indices);
693 Some(self.apply_transformer_with_error_handling(
694 step, x, &subset, false, indices,
695 ))
696 } else {
697 None
698 }
699 } else {
700 None
701 }
702 })
703 .collect()
704 };
705
706 let mut column_outputs: Vec<(usize, Array2<Float>)> = Vec::new();
708
709 for transform_result in transform_results {
710 match transform_result.result {
711 Ok(transformed) => {
712 if transformed.ncols() > 0 {
713 let min_index = *transform_result
716 .original_indices
717 .iter()
718 .min()
719 .expect("collection should not be empty for min/max");
720 column_outputs.push((min_index, transformed));
721 }
722 }
723 Err(e) => {
724 return Err(SklearsError::TransformError(format!(
726 "Transformer '{}' failed: {}",
727 transform_result.transformer_name, e
728 )));
729 }
730 }
731 }
732
733 if !remainder_indices.is_empty() {
735 let remainder_data = x.select(Axis(1), remainder_indices);
736
737 let transformed_remainder = match &self.config.remainder {
738 RemainderStrategy::Drop => {
739 None }
741 RemainderStrategy::Passthrough => Some(remainder_data),
742 RemainderStrategy::Transform(transformer) => {
743 let transformed = transformer.transform_wrapper(&remainder_data)?;
744 Some(transformed)
745 }
746 };
747
748 if let Some(remainder_output) = transformed_remainder {
749 if let Some(&min_remainder_index) = remainder_indices.iter().min() {
751 column_outputs.push((min_remainder_index, remainder_output));
752 }
753 }
754 }
755
756 column_outputs.sort_by_key(|(idx, _)| *idx);
758
759 if column_outputs.is_empty() {
761 return Err(SklearsError::InvalidInput(
762 "No output from any transformer".to_string(),
763 ));
764 }
765
766 let total_cols: usize = column_outputs.iter().map(|(_, arr)| arr.ncols()).sum();
768 let mut result = Array2::zeros((n_samples, total_cols));
769
770 let mut col_offset = 0;
772 for (_, part) in column_outputs {
773 let part_cols = part.ncols();
774 result
775 .slice_mut(s![.., col_offset..col_offset + part_cols])
776 .assign(&part);
777 col_offset += part_cols;
778 }
779
780 Ok(result)
781 }
782}
783
784impl ColumnTransformer<Trained> {
785 pub fn n_features_in(&self) -> usize {
787 self.n_features_in_.expect("operation should succeed")
788 }
789
790 pub fn output_indices(&self) -> &HashMap<String, Vec<usize>> {
792 self.output_indices_
793 .as_ref()
794 .expect("operation should succeed")
795 }
796
797 pub fn remainder_indices(&self) -> &Vec<usize> {
799 self.remainder_indices_
800 .as_ref()
801 .expect("operation should succeed")
802 }
803}
804
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Test double: both fit and transform multiply every element by `scale`.
    #[derive(Debug, Clone)]
    struct MockTransformer {
        scale: Float,
    }

    impl TransformerWrapper for MockTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            Ok(x * self.scale)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            Ok(x * self.scale)
        }

        fn get_n_features_out(&self) -> Option<usize> {
            // Output width equals input width; not reported ahead of time.
            None
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_column_transformer_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_first_two",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0, 1]),
            )
            .remainder(RemainderStrategy::Passthrough);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        // Two scaled columns, then the passed-through column 2.
        assert_eq!(result.dim(), (3, 3));
        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2
        assert_eq!(result[[0, 1]], 4.0); // 2.0 * 2
        assert_eq!(result[[0, 2]], 3.0); // remainder passthrough
    }

    #[test]
    fn test_column_transformer_drop_remainder() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_middle",
                MockTransformer { scale: 3.0 },
                ColumnSelector::Indices(vec![1, 2]),
            )
            .remainder(RemainderStrategy::Drop);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        // Columns 0 and 3 are dropped; only the two scaled middle columns remain.
        assert_eq!(result.dim(), (2, 2));
        assert_eq!(result[[0, 0]], 6.0); // 2.0 * 3
        assert_eq!(result[[0, 1]], 9.0); // 3.0 * 3
    }

    #[test]
    fn test_column_transformer_multiple_transformers() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_first",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "scale_last",
                MockTransformer { scale: 0.5 },
                ColumnSelector::Indices(vec![3]),
            )
            .remainder(RemainderStrategy::Passthrough);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        // Blocks are ordered by smallest input index: col 0, remainder (cols 1-2), col 3.
        assert_eq!(result.dim(), (2, 4));
        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2
        assert_eq!(result[[0, 1]], 2.0); // remainder col 1
        assert_eq!(result[[0, 2]], 3.0); // remainder col 2
        assert_eq!(result[[0, 3]], 2.0); // 4.0 * 0.5
    }

    #[test]
    fn test_column_transformer_empty_data() {
        let x_empty: Array2<Float> = Array2::zeros((0, 3));

        let ct = ColumnTransformer::new().add_transformer(
            "test",
            MockTransformer { scale: 1.0 },
            ColumnSelector::Indices(vec![0]),
        );

        // Fitting on zero rows is rejected.
        let result = ct.fit(&x_empty, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_column_transformer_invalid_indices() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let ct = ColumnTransformer::new().add_transformer(
            "invalid",
            MockTransformer { scale: 1.0 },
            ColumnSelector::Indices(vec![0, 5]), // index 5 is out of range for 2 columns
        );

        let result = ct.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_column_type_inference() {
        let ct = ColumnTransformer::new();

        // Only 0.0 / 1.0 values.
        let bool_col = scirs2_core::ndarray::array![0.0, 1.0, 0.0, 1.0, 0.0];
        let bool_type = ct.infer_column_type(&bool_col.view());
        assert_eq!(bool_type, DataType::Boolean);

        // 3 distinct values in 8 entries: ratio 0.375 < 0.6 with <= 5 distinct values.
        let cat_col = scirs2_core::ndarray::array![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 2.0, 3.0];
        let cat_type = ct.infer_column_type(&cat_col.view());
        assert_eq!(cat_type, DataType::Categorical);

        // All values distinct: ratio 1.0.
        let num_col =
            scirs2_core::ndarray::array![1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0];
        let num_type = ct.infer_column_type(&num_col.view());
        assert_eq!(num_type, DataType::Numeric);
    }

    // Test double that either always errors or doubles its input.
    #[derive(Debug, Clone)]
    struct FailingTransformer {
        should_fail: bool,
    }

    impl TransformerWrapper for FailingTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            if self.should_fail {
                Err(SklearsError::InvalidInput(
                    "Intentional failure for testing".to_string(),
                ))
            } else {
                Ok(x * 2.0)
            }
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            if self.should_fail {
                Err(SklearsError::InvalidInput(
                    "Intentional failure for testing".to_string(),
                ))
            } else {
                Ok(x * 2.0)
            }
        }

        fn get_n_features_out(&self) -> Option<usize> {
            None
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_column_transformer_error_handling_stop_on_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::StopOnError);

        let result = ct.fit(&x, &());
        assert!(result.is_err(), "Should fail with StopOnError");
    }

    #[test]
    fn test_column_transformer_error_handling_skip_on_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "working",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![1]),
            )
            .error_strategy(ColumnErrorStrategy::SkipOnError)
            .remainder(RemainderStrategy::Passthrough);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        // Column 0 is skipped entirely (it is claimed, so it is not remainder either).
        assert_eq!(result.dim(), (2, 2));
        assert_eq!(result[[0, 0]], 4.0); // 2.0 * 2
        assert_eq!(result[[0, 1]], 3.0); // remainder col 2
    }

    #[test]
    fn test_column_transformer_error_handling_replace_with_zeros() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::ReplaceWithZeros)
            .remainder(RemainderStrategy::Passthrough);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(result.dim(), (2, 3));
        assert_eq!(result[[0, 0]], 0.0); // failed step replaced with zeros
        assert_eq!(result[[0, 1]], 2.0); // remainder col 1
        assert_eq!(result[[0, 2]], 3.0); // remainder col 2
    }

    #[test]
    fn test_column_transformer_error_handling_fallback() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::Fallback)
            .fallback_transformer(MockTransformer { scale: 0.5 })
            .remainder(RemainderStrategy::Passthrough);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(result.dim(), (2, 3));
        assert_eq!(result[[0, 0]], 0.5); // fallback: 1.0 * 0.5
        assert_eq!(result[[0, 1]], 2.0); // remainder col 1
        assert_eq!(result[[0, 2]], 3.0); // remainder col 2
    }

    #[test]
    fn test_column_transformer_parallel_execution() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let ct = ColumnTransformer::new()
            .add_transformer(
                "scale_first",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "scale_second",
                MockTransformer { scale: 3.0 },
                ColumnSelector::Indices(vec![1]),
            )
            .parallel_execution(true)
            .remainder(RemainderStrategy::Passthrough);

        let fitted_ct = ct.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_ct
            .transform(&x)
            .expect("transformation should succeed");

        // Same column layout as sequential execution.
        assert_eq!(result.dim(), (2, 4));
        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2
        assert_eq!(result[[0, 1]], 6.0); // 2.0 * 3
        assert_eq!(result[[0, 2]], 3.0); // remainder col 2
        assert_eq!(result[[0, 3]], 4.0); // remainder col 3
    }

    #[test]
    fn test_column_transformer_dtype_selection() {
        // Column 0: 8 distinct values (Numeric); column 1: only 0/1 (Boolean);
        // column 2: values {1, 2, 3} in 8 rows (Categorical).
        let x = array![
            [1.23456, 0.0, 1.0],
            [2.78901, 1.0, 1.0],
            [3.45678, 0.0, 2.0],
            [4.98765, 1.0, 1.0],
            [5.12345, 0.0, 2.0],
            [6.67890, 1.0, 3.0],
            [7.11111, 0.0, 1.0],
            [8.22222, 1.0, 2.0],
        ];

        let ct_bool = ColumnTransformer::new().add_transformer(
            "bool_transformer",
            MockTransformer { scale: 10.0 },
            ColumnSelector::DataType(DataType::Boolean),
        );

        let fitted_ct_bool = ct_bool.fit(&x, &()).expect("model fitting should succeed");
        let result_bool = fitted_ct_bool
            .transform(&x)
            .expect("transformation should succeed");

        // Only column 1 is selected; the default remainder strategy drops the rest.
        assert_eq!(result_bool.dim(), (8, 1));
        assert_eq!(result_bool[[0, 0]], 0.0); // 0.0 * 10
        assert_eq!(result_bool[[1, 0]], 10.0); // 1.0 * 10

        let ct_cat = ColumnTransformer::new().add_transformer(
            "cat_transformer",
            MockTransformer { scale: 0.1 },
            ColumnSelector::DataType(DataType::Categorical),
        );

        let fitted_ct_cat = ct_cat.fit(&x, &()).expect("model fitting should succeed");
        let result_cat = fitted_ct_cat
            .transform(&x)
            .expect("transformation should succeed");

        // Only column 2 is selected.
        assert_eq!(result_cat.dim(), (8, 1));
        assert_eq!(result_cat[[0, 0]], 0.1); // 1.0 * 0.1

        let ct_num = ColumnTransformer::new().add_transformer(
            "num_transformer",
            MockTransformer { scale: 2.0 },
            ColumnSelector::DataType(DataType::Numeric),
        );

        let fitted_ct_num = ct_num.fit(&x, &()).expect("model fitting should succeed");
        let result_num = fitted_ct_num
            .transform(&x)
            .expect("transformation should succeed");

        // Only column 0 is selected.
        assert_eq!(result_num.dim(), (8, 1));
        let expected_first = 1.23456 * 2.0;
        assert!((result_num[[0, 0]] - expected_first).abs() < 1e-10)
    }
}