sklears_preprocessing/
column_transformer.rs

1//! Column Transformer
2//!
3//! This module provides ColumnTransformer which applies different transformers
4//! to specific columns of a dataset.
5
6use scirs2_core::ndarray::{s, Array2, Axis};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Estimator, Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::collections::HashMap;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18// For floating point comparison in HashSet
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20struct OrderedFloat(u64);
21
22impl From<Float> for OrderedFloat {
23    fn from(val: Float) -> Self {
24        OrderedFloat(val.to_bits())
25    }
26}
27
28/// Column selector type
29#[derive(Debug, Clone)]
30pub enum ColumnSelector {
31    /// Select columns by indices
32    Indices(Vec<usize>),
33    /// Select columns by name (when working with named columns)
34    Names(Vec<String>),
35    /// Select columns by data type (would require runtime type checking)
36    DataType(DataType),
37    /// Select all remaining columns
38    Remainder,
39}
40
/// Coarse column category used by [`ColumnSelector::DataType`].
///
/// The enum is fieldless, so it is cheap to copy and can be used as a hash
/// key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Continuous / many-valued column.
    Numeric,
    /// Column with few distinct values.
    Categorical,
    /// Column containing only `0.0` and/or `1.0`.
    Boolean,
}
48
49/// Strategy for handling remaining columns
50#[derive(Debug, Clone)]
51pub enum RemainderStrategy {
52    /// Drop remaining columns
53    Drop,
54    /// Pass through remaining columns unchanged
55    Passthrough,
56    /// Apply a specific transformer to remaining columns
57    Transform(Box<dyn TransformerWrapper>),
58}
59
/// How a failing transformer step is handled.
///
/// The default is [`ColumnErrorStrategy::StopOnError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ColumnErrorStrategy {
    /// Propagate the first error to the caller.
    #[default]
    StopOnError,
    /// Log a warning and let the failed step contribute no output columns.
    SkipOnError,
    /// Log a warning and run the configured fallback transformer instead; if
    /// none is configured, the step's input columns are passed through.
    Fallback,
    /// Log a warning and emit zeros with the shape of the step's input.
    ReplaceWithZeros,
    /// Log a warning and emit NaN with the shape of the step's input.
    ReplaceWithNaN,
}
80
81impl Default for RemainderStrategy {
82    fn default() -> Self {
83        Self::Drop
84    }
85}
86
87/// Trait for transformer wrappers to enable dynamic dispatch
88pub trait TransformerWrapper: Send + Sync + std::fmt::Debug {
89    fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
90    fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
91    fn get_n_features_out(&self) -> Option<usize>;
92    fn clone_box(&self) -> Box<dyn TransformerWrapper>;
93}
94
95impl Clone for Box<dyn TransformerWrapper> {
96    fn clone(&self) -> Self {
97        self.clone_box()
98    }
99}
100
101/// A transformer step in the column transformer
102#[derive(Debug, Clone)]
103pub struct TransformerStep {
104    /// Name of the transformer step
105    pub name: String,
106    /// The column selector
107    pub columns: ColumnSelector,
108    /// The transformer (boxed for dynamic dispatch)
109    pub transformer: Box<dyn TransformerWrapper>,
110}
111
112/// Configuration for ColumnTransformer
113#[derive(Debug, Clone)]
114pub struct ColumnTransformerConfig {
115    /// Strategy for handling remaining columns
116    pub remainder: RemainderStrategy,
117    /// Whether to preserve column order in output
118    pub preserve_order: bool,
119    /// Whether to use parallel processing
120    pub n_jobs: Option<usize>,
121    /// Whether to validate input
122    pub validate_input: bool,
123    /// Strategy for handling transformation errors
124    pub error_strategy: ColumnErrorStrategy,
125    /// Enable parallel processing for transformers
126    pub parallel_execution: bool,
127    /// Fallback transformer for error handling
128    pub fallback_transformer: Option<Box<dyn TransformerWrapper>>,
129}
130
131impl Default for ColumnTransformerConfig {
132    fn default() -> Self {
133        Self {
134            remainder: RemainderStrategy::Drop,
135            preserve_order: false,
136            n_jobs: None,
137            validate_input: true,
138            error_strategy: ColumnErrorStrategy::StopOnError,
139            parallel_execution: false,
140            fallback_transformer: None,
141        }
142    }
143}
144
145/// ColumnTransformer applies different transformers to different columns
146#[derive(Debug)]
147pub struct ColumnTransformer<State = Untrained> {
148    config: ColumnTransformerConfig,
149    transformers: Vec<TransformerStep>,
150    state: PhantomData<State>,
151    // Fitted parameters
152    fitted_transformers_: Option<Vec<TransformerStep>>,
153    feature_names_in_: Option<Vec<String>>,
154    n_features_in_: Option<usize>,
155    output_indices_: Option<HashMap<String, Vec<usize>>>,
156    remainder_indices_: Option<Vec<usize>>,
157}
158
159/// Result of a column transformation attempt
160#[derive(Debug)]
161struct ColumnTransformResult {
162    transformer_name: String,
163    column_indices: Vec<usize>,
164    result: Result<Array2<Float>>,
165    original_indices: Vec<usize>,
166}
167
168// Shared methods for both Untrained and Trained states
169impl<State> ColumnTransformer<State> {
170    /// Apply transformer with error handling
171    fn apply_transformer_with_error_handling(
172        &self,
173        step: &TransformerStep,
174        _data: &Array2<Float>,
175        subset: &Array2<Float>,
176        is_fit_transform: bool,
177        resolved_indices: &[usize],
178    ) -> ColumnTransformResult {
179        let column_indices = resolved_indices.to_vec();
180
181        let transform_result = if is_fit_transform {
182            step.transformer.fit_transform_wrapper(subset)
183        } else {
184            step.transformer.transform_wrapper(subset)
185        };
186
187        let final_result = match transform_result {
188            Ok(transformed) => Ok(transformed),
189            Err(error) => {
190                // Apply error strategy
191                match self.config.error_strategy {
192                    ColumnErrorStrategy::StopOnError => Err(error),
193                    ColumnErrorStrategy::SkipOnError => {
194                        eprintln!(
195                            "Warning: Transformer '{}' failed on columns {:?}: {}. Skipping...",
196                            step.name, column_indices, error
197                        );
198                        // Return empty result to indicate skipping
199                        Ok(Array2::zeros((subset.nrows(), 0)))
200                    }
201                    ColumnErrorStrategy::Fallback => {
202                        if let Some(ref fallback) = self.config.fallback_transformer {
203                            eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Using fallback...", 
204                                    step.name, column_indices, error);
205                            if is_fit_transform {
206                                fallback.fit_transform_wrapper(subset)
207                            } else {
208                                fallback.transform_wrapper(subset)
209                            }
210                        } else {
211                            eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. No fallback available, passing through...", 
212                                    step.name, column_indices, error);
213                            Ok(subset.clone())
214                        }
215                    }
216                    ColumnErrorStrategy::ReplaceWithZeros => {
217                        eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with zeros...", 
218                                step.name, column_indices, error);
219                        Ok(Array2::zeros(subset.dim()))
220                    }
221                    ColumnErrorStrategy::ReplaceWithNaN => {
222                        eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with NaN...", 
223                                step.name, column_indices, error);
224                        Ok(Array2::from_elem(subset.dim(), Float::NAN))
225                    }
226                }
227            }
228        };
229
230        ColumnTransformResult {
231            transformer_name: step.name.clone(),
232            column_indices: column_indices.clone(),
233            result: final_result,
234            original_indices: column_indices,
235        }
236    }
237}
238
239impl ColumnTransformer<Untrained> {
240    /// Create a new ColumnTransformer
241    pub fn new() -> Self {
242        Self {
243            config: ColumnTransformerConfig::default(),
244            transformers: Vec::new(),
245            state: PhantomData,
246            fitted_transformers_: None,
247            feature_names_in_: None,
248            n_features_in_: None,
249            output_indices_: None,
250            remainder_indices_: None,
251        }
252    }
253
254    /// Add a transformer for specific columns
255    pub fn add_transformer<T>(mut self, name: &str, transformer: T, columns: ColumnSelector) -> Self
256    where
257        T: TransformerWrapper + 'static,
258    {
259        self.transformers.push(TransformerStep {
260            name: name.to_string(),
261            columns,
262            transformer: Box::new(transformer),
263        });
264        self
265    }
266
267    /// Set the remainder strategy
268    pub fn remainder(mut self, strategy: RemainderStrategy) -> Self {
269        self.config.remainder = strategy;
270        self
271    }
272
273    /// Set whether to preserve column order
274    pub fn preserve_order(mut self, preserve: bool) -> Self {
275        self.config.preserve_order = preserve;
276        self
277    }
278
279    /// Set number of parallel jobs
280    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
281        self.config.n_jobs = n_jobs;
282        self
283    }
284
285    /// Set input validation
286    pub fn validate_input(mut self, validate: bool) -> Self {
287        self.config.validate_input = validate;
288        self
289    }
290
291    /// Set error handling strategy
292    pub fn error_strategy(mut self, strategy: ColumnErrorStrategy) -> Self {
293        self.config.error_strategy = strategy;
294        self
295    }
296
297    /// Enable/disable parallel execution
298    pub fn parallel_execution(mut self, parallel: bool) -> Self {
299        self.config.parallel_execution = parallel;
300        self
301    }
302
303    /// Set fallback transformer for error handling
304    pub fn fallback_transformer<T>(mut self, transformer: T) -> Self
305    where
306        T: TransformerWrapper + 'static,
307    {
308        self.config.fallback_transformer = Some(Box::new(transformer));
309        self
310    }
311
312    /// Resolve column indices from selectors
313    fn resolve_columns(&self, selector: &ColumnSelector, n_features: usize) -> Result<Vec<usize>> {
314        match selector {
315            ColumnSelector::Indices(indices) => {
316                // Validate indices
317                for &idx in indices {
318                    if idx >= n_features {
319                        return Err(SklearsError::InvalidInput(format!(
320                            "Column index {} is out of bounds for {} features",
321                            idx, n_features
322                        )));
323                    }
324                }
325                Ok(indices.clone())
326            }
327            ColumnSelector::Names(_names) => {
328                // For now, return error as we need named column support
329                Err(SklearsError::NotImplemented(
330                    "Named column selection not yet implemented".to_string(),
331                ))
332            }
333            ColumnSelector::DataType(_dtype) => {
334                // DataType selection requires training data, handled in resolve_columns_with_data
335                Err(SklearsError::InvalidInput(
336                    "DataType column selection requires training data. Use resolve_columns_with_data.".to_string(),
337                ))
338            }
339            ColumnSelector::Remainder => {
340                // This should be handled separately
341                Ok(Vec::new())
342            }
343        }
344    }
345
346    /// Resolve column indices from selectors with access to training data
347    fn resolve_columns_with_data(
348        &self,
349        selector: &ColumnSelector,
350        data: &Array2<Float>,
351    ) -> Result<Vec<usize>> {
352        let (_, n_features) = data.dim();
353
354        match selector {
355            ColumnSelector::Indices(indices) => {
356                // Validate indices
357                for &idx in indices {
358                    if idx >= n_features {
359                        return Err(SklearsError::InvalidInput(format!(
360                            "Column index {} is out of bounds for {} features",
361                            idx, n_features
362                        )));
363                    }
364                }
365                Ok(indices.clone())
366            }
367            ColumnSelector::Names(_names) => {
368                // For now, return error as we need named column support
369                Err(SklearsError::NotImplemented(
370                    "Named column selection not yet implemented".to_string(),
371                ))
372            }
373            ColumnSelector::DataType(dtype) => self.infer_columns_by_dtype_with_data(dtype, data),
374            ColumnSelector::Remainder => {
375                // This should be handled separately
376                Ok(Vec::new())
377            }
378        }
379    }
380
381    /// Infer column indices by data type using heuristics (without training data)
382    fn infer_columns_by_dtype(&self, _dtype: &DataType, _n_features: usize) -> Result<Vec<usize>> {
383        // This method cannot work without training data
384        Err(SklearsError::InvalidInput(
385            "Data type column selection requires training data context. \
386             Use resolve_columns_with_data during fit."
387                .to_string(),
388        ))
389    }
390
391    /// Infer column indices by data type using heuristics on training data
392    fn infer_columns_by_dtype_with_data(
393        &self,
394        dtype: &DataType,
395        data: &Array2<Float>,
396    ) -> Result<Vec<usize>> {
397        let (_n_samples, n_features) = data.dim();
398        let mut matching_columns = Vec::new();
399
400        for col_idx in 0..n_features {
401            let column = data.column(col_idx);
402            let column_type = self.infer_column_type(&column);
403
404            if column_type == *dtype {
405                matching_columns.push(col_idx);
406            }
407        }
408
409        Ok(matching_columns)
410    }
411
412    /// Infer the data type of a single column using heuristics
413    fn infer_column_type(&self, column: &scirs2_core::ndarray::ArrayView1<Float>) -> DataType {
414        let unique_values: std::collections::HashSet<_> =
415            column.iter().map(|&x| OrderedFloat::from(x)).collect();
416
417        let n_unique = unique_values.len();
418        let n_total = column.len();
419
420        // Check if column is boolean (only 0.0 and 1.0 values)
421        if n_unique <= 2 {
422            let zero_bits = OrderedFloat::from(0.0);
423            let one_bits = OrderedFloat::from(1.0);
424            if unique_values
425                .iter()
426                .all(|&x| x == zero_bits || x == one_bits)
427            {
428                return DataType::Boolean;
429            }
430        }
431
432        // Heuristic for categorical vs numeric
433        // If the ratio of unique values to total values is low, consider it categorical
434        let unique_ratio = n_unique as f64 / n_total as f64;
435
436        // Use a more balanced approach: categorical if either condition is met
437        // but with more conservative thresholds
438        if (unique_ratio < 0.6 && n_unique <= 5) || unique_ratio < 0.2 {
439            DataType::Categorical
440        } else {
441            DataType::Numeric
442        }
443    }
444
445    /// Get indices of columns that are not selected by any transformer
446    fn get_remainder_indices(&self, data: &Array2<Float>) -> Result<Vec<usize>> {
447        let (_, n_features) = data.dim();
448        let mut used_indices = std::collections::HashSet::new();
449
450        // Collect all used indices
451        for step in &self.transformers {
452            let indices = match &step.columns {
453                ColumnSelector::DataType(_) => {
454                    self.resolve_columns_with_data(&step.columns, data)?
455                }
456                _ => self.resolve_columns(&step.columns, n_features)?,
457            };
458            for idx in indices {
459                used_indices.insert(idx);
460            }
461        }
462
463        // Return unused indices
464        Ok((0..n_features)
465            .filter(|i| !used_indices.contains(i))
466            .collect())
467    }
468}
469
470impl Default for ColumnTransformer<Untrained> {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476impl Estimator for ColumnTransformer<Untrained> {
477    type Config = ColumnTransformerConfig;
478    type Error = SklearsError;
479    type Float = Float;
480
481    fn config(&self) -> &Self::Config {
482        &self.config
483    }
484}
485
486impl Estimator for ColumnTransformer<Trained> {
487    type Config = ColumnTransformerConfig;
488    type Error = SklearsError;
489    type Float = Float;
490
491    fn config(&self) -> &Self::Config {
492        &self.config
493    }
494}
495
496impl Fit<Array2<Float>, ()> for ColumnTransformer<Untrained> {
497    type Fitted = ColumnTransformer<Trained>;
498
499    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
500        let (n_samples, n_features) = x.dim();
501
502        if n_samples == 0 {
503            return Err(SklearsError::InvalidInput(
504                "Cannot fit transformer on empty dataset".to_string(),
505            ));
506        }
507
508        // Get remainder indices
509        let remainder_indices = self.get_remainder_indices(x)?;
510
511        // Prepare transformer steps with resolved indices
512        let mut transformer_tasks: Vec<(TransformerStep, Vec<usize>)> = Vec::new();
513
514        for step in &self.transformers {
515            // Resolve column indices - use data-aware method for DataType selectors
516            let indices = match &step.columns {
517                ColumnSelector::DataType(_) => self.resolve_columns_with_data(&step.columns, x)?,
518                _ => self.resolve_columns(&step.columns, n_features)?,
519            };
520
521            if !indices.is_empty() {
522                transformer_tasks.push((step.clone(), indices));
523            }
524        }
525
526        // Apply transformers with parallel processing and error handling
527        let transform_results: Vec<ColumnTransformResult> = if self.config.parallel_execution
528            && transformer_tasks.len() > 1
529        {
530            #[cfg(feature = "parallel")]
531            {
532                transformer_tasks
533                    .into_par_iter()
534                    .map(|(step, indices)| {
535                        let subset = x.select(Axis(1), &indices);
536                        self.apply_transformer_with_error_handling(
537                            &step, x, &subset, true, &indices,
538                        )
539                    })
540                    .collect()
541            }
542            #[cfg(not(feature = "parallel"))]
543            {
544                // Fallback to sequential processing
545                transformer_tasks
546                    .into_iter()
547                    .map(|(step, indices)| {
548                        let subset = x.select(Axis(1), &indices);
549                        self.apply_transformer_with_error_handling(
550                            &step, x, &subset, true, &indices,
551                        )
552                    })
553                    .collect()
554            }
555        } else {
556            // Sequential processing
557            transformer_tasks
558                .into_iter()
559                .map(|(step, indices)| {
560                    let subset = x.select(Axis(1), &indices);
561                    self.apply_transformer_with_error_handling(&step, x, &subset, true, &indices)
562                })
563                .collect()
564        };
565
566        // Process results and create fitted transformers
567        let mut fitted_transformers = Vec::new();
568        let mut output_indices = HashMap::new();
569
570        for transform_result in transform_results {
571            match transform_result.result {
572                Ok(transformed) => {
573                    if transformed.ncols() > 0 {
574                        // Skip empty results (from SkipOnError)
575                        // Store output indices mapping
576                        let output_cols = (0..transformed.ncols()).collect();
577                        output_indices
578                            .insert(transform_result.transformer_name.clone(), output_cols);
579
580                        // Create fitted transformer step
581                        let transformer_name = transform_result.transformer_name.clone();
582                        fitted_transformers.push(TransformerStep {
583                            name: transformer_name.clone(),
584                            columns: ColumnSelector::Indices(transform_result.original_indices),
585                            transformer: self
586                                .transformers
587                                .iter()
588                                .find(|s| s.name == transformer_name)
589                                .unwrap()
590                                .transformer
591                                .clone_box(),
592                        });
593                    }
594                }
595                Err(e) => {
596                    // If we reach here, it means StopOnError was used
597                    return Err(SklearsError::TransformError(format!(
598                        "Transformer '{}' failed: {}",
599                        transform_result.transformer_name, e
600                    )));
601                }
602            }
603        }
604
605        Ok(ColumnTransformer {
606            config: self.config,
607            transformers: self.transformers,
608            state: PhantomData,
609            fitted_transformers_: Some(fitted_transformers),
610            feature_names_in_: None,
611            n_features_in_: Some(n_features),
612            output_indices_: Some(output_indices),
613            remainder_indices_: Some(remainder_indices),
614        })
615    }
616}
617
618impl Transform<Array2<Float>, Array2<Float>> for ColumnTransformer<Trained> {
619    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
620        let (n_samples, n_features) = x.dim();
621
622        if Some(n_features) != self.n_features_in_ {
623            return Err(SklearsError::FeatureMismatch {
624                expected: self.n_features_in_.unwrap_or(0),
625                actual: n_features,
626            });
627        }
628
629        let fitted_transformers = self.fitted_transformers_.as_ref().unwrap();
630        let remainder_indices = self.remainder_indices_.as_ref().unwrap();
631
632        // Prepare transformer tasks for parallel processing
633        let transformer_tasks: Vec<&TransformerStep> = fitted_transformers.iter().collect();
634
635        // Apply transformers with parallel processing and error handling
636        let transform_results: Vec<ColumnTransformResult> =
637            if self.config.parallel_execution && transformer_tasks.len() > 1 {
638                #[cfg(feature = "parallel")]
639                {
640                    transformer_tasks
641                        .into_par_iter()
642                        .filter_map(|step| {
643                            if let ColumnSelector::Indices(indices) = &step.columns {
644                                if !indices.is_empty() {
645                                    let subset = x.select(Axis(1), indices);
646                                    Some(self.apply_transformer_with_error_handling(
647                                        step, x, &subset, false, indices,
648                                    ))
649                                } else {
650                                    None
651                                }
652                            } else {
653                                None
654                            }
655                        })
656                        .collect()
657                }
658                #[cfg(not(feature = "parallel"))]
659                {
660                    // Fallback to sequential processing
661                    transformer_tasks
662                        .into_iter()
663                        .filter_map(|step| {
664                            if let ColumnSelector::Indices(indices) = &step.columns {
665                                if !indices.is_empty() {
666                                    let subset = x.select(Axis(1), indices);
667                                    Some(self.apply_transformer_with_error_handling(
668                                        step, x, &subset, false, indices,
669                                    ))
670                                } else {
671                                    None
672                                }
673                            } else {
674                                None
675                            }
676                        })
677                        .collect()
678                }
679            } else {
680                // Sequential processing
681                transformer_tasks
682                    .into_iter()
683                    .filter_map(|step| {
684                        if let ColumnSelector::Indices(indices) = &step.columns {
685                            if !indices.is_empty() {
686                                let subset = x.select(Axis(1), indices);
687                                Some(self.apply_transformer_with_error_handling(
688                                    step, x, &subset, false, indices,
689                                ))
690                            } else {
691                                None
692                            }
693                        } else {
694                            None
695                        }
696                    })
697                    .collect()
698            };
699
700        // Process results and create column outputs
701        let mut column_outputs: Vec<(usize, Array2<Float>)> = Vec::new();
702
703        for transform_result in transform_results {
704            match transform_result.result {
705                Ok(transformed) => {
706                    if transformed.ncols() > 0 {
707                        // Skip empty results (from SkipOnError)
708                        // For each original column index, store its min value to maintain order
709                        let min_index = *transform_result.original_indices.iter().min().unwrap();
710                        column_outputs.push((min_index, transformed));
711                    }
712                }
713                Err(e) => {
714                    // If we reach here, it means StopOnError was used
715                    return Err(SklearsError::TransformError(format!(
716                        "Transformer '{}' failed: {}",
717                        transform_result.transformer_name, e
718                    )));
719                }
720            }
721        }
722
723        // Handle remainder columns
724        if !remainder_indices.is_empty() {
725            let remainder_data = x.select(Axis(1), remainder_indices);
726
727            let transformed_remainder = match &self.config.remainder {
728                RemainderStrategy::Drop => {
729                    None // remainder is dropped
730                }
731                RemainderStrategy::Passthrough => Some(remainder_data),
732                RemainderStrategy::Transform(transformer) => {
733                    let transformed = transformer.transform_wrapper(&remainder_data)?;
734                    Some(transformed)
735                }
736            };
737
738            if let Some(remainder_output) = transformed_remainder {
739                // Add remainder with the minimum remainder index
740                if let Some(&min_remainder_index) = remainder_indices.iter().min() {
741                    column_outputs.push((min_remainder_index, remainder_output));
742                }
743            }
744        }
745
746        // Sort by original column indices to maintain proper ordering
747        column_outputs.sort_by_key(|(idx, _)| *idx);
748
749        // Concatenate all output parts in the correct order
750        if column_outputs.is_empty() {
751            return Err(SklearsError::InvalidInput(
752                "No output from any transformer".to_string(),
753            ));
754        }
755
756        // Calculate total columns
757        let total_cols: usize = column_outputs.iter().map(|(_, arr)| arr.ncols()).sum();
758        let mut result = Array2::zeros((n_samples, total_cols));
759
760        // Concatenate in order
761        let mut col_offset = 0;
762        for (_, part) in column_outputs {
763            let part_cols = part.ncols();
764            result
765                .slice_mut(s![.., col_offset..col_offset + part_cols])
766                .assign(&part);
767            col_offset += part_cols;
768        }
769
770        Ok(result)
771    }
772}
773
774impl ColumnTransformer<Trained> {
775    /// Get the number of features seen during fitting
776    pub fn n_features_in(&self) -> usize {
777        self.n_features_in_.unwrap()
778    }
779
780    /// Get the output indices mapping
781    pub fn output_indices(&self) -> &HashMap<String, Vec<usize>> {
782        self.output_indices_.as_ref().unwrap()
783    }
784
785    /// Get the remainder indices
786    pub fn remainder_indices(&self) -> &Vec<usize> {
787        self.remainder_indices_.as_ref().unwrap()
788    }
789}
790
791#[allow(non_snake_case)]
792#[cfg(test)]
793mod tests {
794    use super::*;
795    use scirs2_core::ndarray::array;
796
797    // Mock transformer for testing
798    #[derive(Debug, Clone)]
799    struct MockTransformer {
800        scale: Float,
801    }
802
803    impl TransformerWrapper for MockTransformer {
804        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
805            Ok(x * self.scale)
806        }
807
808        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
809            Ok(x * self.scale)
810        }
811
812        fn get_n_features_out(&self) -> Option<usize> {
813            None // Same as input
814        }
815
816        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
817            Box::new(self.clone())
818        }
819    }
820
821    #[test]
822    fn test_column_transformer_basic() {
823        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],];
824
825        let ct = ColumnTransformer::new()
826            .add_transformer(
827                "scale_first_two",
828                MockTransformer { scale: 2.0 },
829                ColumnSelector::Indices(vec![0, 1]),
830            )
831            .remainder(RemainderStrategy::Passthrough);
832
833        let fitted_ct = ct.fit(&x, &()).unwrap();
834        let result = fitted_ct.transform(&x).unwrap();
835
836        // First two columns should be scaled by 2, last column passed through
837        assert_eq!(result.dim(), (3, 3));
838        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2.0
839        assert_eq!(result[[0, 1]], 4.0); // 2.0 * 2.0
840        assert_eq!(result[[0, 2]], 3.0); // 3.0 (passthrough)
841    }
842
843    #[test]
844    fn test_column_transformer_drop_remainder() {
845        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],];
846
847        let ct = ColumnTransformer::new()
848            .add_transformer(
849                "scale_middle",
850                MockTransformer { scale: 3.0 },
851                ColumnSelector::Indices(vec![1, 2]),
852            )
853            .remainder(RemainderStrategy::Drop);
854
855        let fitted_ct = ct.fit(&x, &()).unwrap();
856        let result = fitted_ct.transform(&x).unwrap();
857
858        // Only middle two columns should remain (scaled by 3)
859        assert_eq!(result.dim(), (2, 2));
860        assert_eq!(result[[0, 0]], 6.0); // 2.0 * 3.0
861        assert_eq!(result[[0, 1]], 9.0); // 3.0 * 3.0
862    }
863
864    #[test]
865    fn test_column_transformer_multiple_transformers() {
866        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],];
867
868        let ct = ColumnTransformer::new()
869            .add_transformer(
870                "scale_first",
871                MockTransformer { scale: 2.0 },
872                ColumnSelector::Indices(vec![0]),
873            )
874            .add_transformer(
875                "scale_last",
876                MockTransformer { scale: 0.5 },
877                ColumnSelector::Indices(vec![3]),
878            )
879            .remainder(RemainderStrategy::Passthrough);
880
881        let fitted_ct = ct.fit(&x, &()).unwrap();
882        let result = fitted_ct.transform(&x).unwrap();
883
884        // Should have 4 columns: [scaled_first, middle_two_passthrough, scaled_last]
885        assert_eq!(result.dim(), (2, 4));
886        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2.0 (first transformer)
887        assert_eq!(result[[0, 1]], 2.0); // 2.0 (passthrough)
888        assert_eq!(result[[0, 2]], 3.0); // 3.0 (passthrough)
889        assert_eq!(result[[0, 3]], 2.0); // 4.0 * 0.5 (second transformer)
890    }
891
892    #[test]
893    fn test_column_transformer_empty_data() {
894        let x_empty: Array2<Float> = Array2::zeros((0, 3));
895
896        let ct = ColumnTransformer::new().add_transformer(
897            "test",
898            MockTransformer { scale: 1.0 },
899            ColumnSelector::Indices(vec![0]),
900        );
901
902        let result = ct.fit(&x_empty, &());
903        assert!(result.is_err());
904    }
905
906    #[test]
907    fn test_column_transformer_invalid_indices() {
908        let x = array![[1.0, 2.0], [3.0, 4.0],];
909
910        let ct = ColumnTransformer::new().add_transformer(
911            "invalid",
912            MockTransformer { scale: 1.0 },
913            ColumnSelector::Indices(vec![0, 5]), // Index 5 doesn't exist
914        );
915
916        let result = ct.fit(&x, &());
917        assert!(result.is_err());
918    }
919
920    #[test]
921    fn test_column_type_inference() {
922        let ct = ColumnTransformer::new();
923
924        // Test boolean detection (strict 0.0 and 1.0 only)
925        let bool_col = scirs2_core::ndarray::array![0.0, 1.0, 0.0, 1.0, 0.0];
926        let bool_type = ct.infer_column_type(&bool_col.view());
927        assert_eq!(bool_type, DataType::Boolean);
928
929        // Test categorical detection (few unique values)
930        let cat_col = scirs2_core::ndarray::array![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 2.0, 3.0];
931        let cat_type = ct.infer_column_type(&cat_col.view());
932        assert_eq!(cat_type, DataType::Categorical);
933
934        // Test numeric detection (many unique values)
935        let num_col =
936            scirs2_core::ndarray::array![1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0];
937        let num_type = ct.infer_column_type(&num_col.view());
938        assert_eq!(num_type, DataType::Numeric);
939    }
940
941    // Failing transformer for testing error handling
942    #[derive(Debug, Clone)]
943    struct FailingTransformer {
944        should_fail: bool,
945    }
946
947    impl TransformerWrapper for FailingTransformer {
948        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
949            if self.should_fail {
950                Err(SklearsError::InvalidInput(
951                    "Intentional failure for testing".to_string(),
952                ))
953            } else {
954                Ok(x * 2.0)
955            }
956        }
957
958        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
959            if self.should_fail {
960                Err(SklearsError::InvalidInput(
961                    "Intentional failure for testing".to_string(),
962                ))
963            } else {
964                Ok(x * 2.0)
965            }
966        }
967
968        fn get_n_features_out(&self) -> Option<usize> {
969            None
970        }
971
972        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
973            Box::new(self.clone())
974        }
975    }
976
977    #[test]
978    fn test_column_transformer_error_handling_stop_on_error() {
979        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
980
981        let ct = ColumnTransformer::new()
982            .add_transformer(
983                "failing",
984                FailingTransformer { should_fail: true },
985                ColumnSelector::Indices(vec![0]),
986            )
987            .error_strategy(ColumnErrorStrategy::StopOnError);
988
989        let result = ct.fit(&x, &());
990        assert!(result.is_err(), "Should fail with StopOnError");
991    }
992
993    #[test]
994    fn test_column_transformer_error_handling_skip_on_error() {
995        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
996
997        let ct = ColumnTransformer::new()
998            .add_transformer(
999                "failing",
1000                FailingTransformer { should_fail: true },
1001                ColumnSelector::Indices(vec![0]),
1002            )
1003            .add_transformer(
1004                "working",
1005                MockTransformer { scale: 2.0 },
1006                ColumnSelector::Indices(vec![1]),
1007            )
1008            .error_strategy(ColumnErrorStrategy::SkipOnError)
1009            .remainder(RemainderStrategy::Passthrough);
1010
1011        let fitted_ct = ct.fit(&x, &()).unwrap();
1012        let result = fitted_ct.transform(&x).unwrap();
1013
1014        // Should have 2 columns: working transformer output + remainder
1015        assert_eq!(result.dim(), (2, 2));
1016        assert_eq!(result[[0, 0]], 4.0); // 2.0 * 2.0 (working transformer)
1017        assert_eq!(result[[0, 1]], 3.0); // 3.0 (remainder passthrough)
1018    }
1019
1020    #[test]
1021    fn test_column_transformer_error_handling_replace_with_zeros() {
1022        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
1023
1024        let ct = ColumnTransformer::new()
1025            .add_transformer(
1026                "failing",
1027                FailingTransformer { should_fail: true },
1028                ColumnSelector::Indices(vec![0]),
1029            )
1030            .error_strategy(ColumnErrorStrategy::ReplaceWithZeros)
1031            .remainder(RemainderStrategy::Passthrough);
1032
1033        let fitted_ct = ct.fit(&x, &()).unwrap();
1034        let result = fitted_ct.transform(&x).unwrap();
1035
1036        // Should have 3 columns: zeros (replacement) + remainder passthrough
1037        assert_eq!(result.dim(), (2, 3));
1038        assert_eq!(result[[0, 0]], 0.0); // Replaced with zero
1039        assert_eq!(result[[0, 1]], 2.0); // Remainder passthrough
1040        assert_eq!(result[[0, 2]], 3.0); // Remainder passthrough
1041    }
1042
1043    #[test]
1044    fn test_column_transformer_error_handling_fallback() {
1045        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
1046
1047        let ct = ColumnTransformer::new()
1048            .add_transformer(
1049                "failing",
1050                FailingTransformer { should_fail: true },
1051                ColumnSelector::Indices(vec![0]),
1052            )
1053            .error_strategy(ColumnErrorStrategy::Fallback)
1054            .fallback_transformer(MockTransformer { scale: 0.5 })
1055            .remainder(RemainderStrategy::Passthrough);
1056
1057        let fitted_ct = ct.fit(&x, &()).unwrap();
1058        let result = fitted_ct.transform(&x).unwrap();
1059
1060        // Should have 3 columns: fallback transformer output + remainder
1061        assert_eq!(result.dim(), (2, 3));
1062        assert_eq!(result[[0, 0]], 0.5); // 1.0 * 0.5 (fallback)
1063        assert_eq!(result[[0, 1]], 2.0); // Remainder passthrough
1064        assert_eq!(result[[0, 2]], 3.0); // Remainder passthrough
1065    }
1066
1067    #[test]
1068    fn test_column_transformer_parallel_execution() {
1069        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
1070
1071        let ct = ColumnTransformer::new()
1072            .add_transformer(
1073                "scale_first",
1074                MockTransformer { scale: 2.0 },
1075                ColumnSelector::Indices(vec![0]),
1076            )
1077            .add_transformer(
1078                "scale_second",
1079                MockTransformer { scale: 3.0 },
1080                ColumnSelector::Indices(vec![1]),
1081            )
1082            .parallel_execution(true)
1083            .remainder(RemainderStrategy::Passthrough);
1084
1085        let fitted_ct = ct.fit(&x, &()).unwrap();
1086        let result = fitted_ct.transform(&x).unwrap();
1087
1088        // Should have 4 columns: 2 transformed + 2 remainder
1089        assert_eq!(result.dim(), (2, 4));
1090        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2.0
1091        assert_eq!(result[[0, 1]], 6.0); // 2.0 * 3.0
1092        assert_eq!(result[[0, 2]], 3.0); // Remainder
1093        assert_eq!(result[[0, 3]], 4.0); // Remainder
1094    }
1095
1096    #[test]
1097    fn test_column_transformer_dtype_selection() {
1098        // Create data with very clear column types:
1099        // Col 0: Numeric (many unique continuous values)
1100        // Col 1: Boolean (strict 0.0/1.0 only)
1101        // Col 2: Categorical (few repeated values)
1102        let x = array![
1103            [1.23456, 0.0, 1.0],
1104            [2.78901, 1.0, 1.0],
1105            [3.45678, 0.0, 2.0],
1106            [4.98765, 1.0, 1.0],
1107            [5.12345, 0.0, 2.0],
1108            [6.67890, 1.0, 3.0],
1109            [7.11111, 0.0, 1.0],
1110            [8.22222, 1.0, 2.0],
1111        ];
1112
1113        // Test Boolean column selection
1114        let ct_bool = ColumnTransformer::new().add_transformer(
1115            "bool_transformer",
1116            MockTransformer { scale: 10.0 },
1117            ColumnSelector::DataType(DataType::Boolean),
1118        );
1119
1120        let fitted_ct_bool = ct_bool.fit(&x, &()).unwrap();
1121        let result_bool = fitted_ct_bool.transform(&x).unwrap();
1122
1123        // Should have 1 column (boolean column scaled by 10)
1124        assert_eq!(result_bool.dim(), (8, 1));
1125        assert_eq!(result_bool[[0, 0]], 0.0); // 0.0 * 10.0
1126        assert_eq!(result_bool[[1, 0]], 10.0); // 1.0 * 10.0
1127
1128        // Test Categorical column selection
1129        let ct_cat = ColumnTransformer::new().add_transformer(
1130            "cat_transformer",
1131            MockTransformer { scale: 0.1 },
1132            ColumnSelector::DataType(DataType::Categorical),
1133        );
1134
1135        let fitted_ct_cat = ct_cat.fit(&x, &()).unwrap();
1136        let result_cat = fitted_ct_cat.transform(&x).unwrap();
1137
1138        // Should have 1 column (categorical column scaled by 0.1)
1139        assert_eq!(result_cat.dim(), (8, 1));
1140        assert_eq!(result_cat[[0, 0]], 0.1); // 1.0 * 0.1
1141
1142        // Test Numeric column selection
1143        let ct_num = ColumnTransformer::new().add_transformer(
1144            "num_transformer",
1145            MockTransformer { scale: 2.0 },
1146            ColumnSelector::DataType(DataType::Numeric),
1147        );
1148
1149        let fitted_ct_num = ct_num.fit(&x, &()).unwrap();
1150        let result_num = fitted_ct_num.transform(&x).unwrap();
1151
1152        // Should have 1 column (numeric column scaled by 2.0)
1153        assert_eq!(result_num.dim(), (8, 1));
1154        let expected_first = 1.23456 * 2.0;
1155        assert!((result_num[[0, 0]] - expected_first).abs() < 1e-10)
1156    }
1157}