Skip to main content

sklears_preprocessing/
column_transformer.rs

1//! Column Transformer
2//!
3//! This module provides ColumnTransformer which applies different transformers
4//! to specific columns of a dataset.
5
6use scirs2_core::ndarray::{s, Array2, Axis};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Estimator, Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::collections::HashMap;
13use std::marker::PhantomData;
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
18// For floating point comparison in HashSet
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20struct OrderedFloat(u64);
21
22impl From<Float> for OrderedFloat {
23    fn from(val: Float) -> Self {
24        OrderedFloat(val.to_bits())
25    }
26}
27
28/// Column selector type
29#[derive(Debug, Clone)]
30pub enum ColumnSelector {
31    /// Select columns by indices
32    Indices(Vec<usize>),
33    /// Select columns by name (when working with named columns)
34    Names(Vec<String>),
35    /// Select columns by data type (would require runtime type checking)
36    DataType(DataType),
37    /// Select all remaining columns
38    Remainder,
39}
40
/// Coarse column category used by [`ColumnSelector::DataType`].
#[derive(Clone, Debug, PartialEq)]
pub enum DataType {
    /// Column treated as continuous numeric data.
    Numeric,
    /// Column with few distinct values relative to its length.
    Categorical,
    /// Column containing only `0.0` and/or `1.0`.
    Boolean,
}
48
49/// Strategy for handling remaining columns
50#[derive(Debug, Clone)]
51pub enum RemainderStrategy {
52    /// Drop remaining columns
53    Drop,
54    /// Pass through remaining columns unchanged
55    Passthrough,
56    /// Apply a specific transformer to remaining columns
57    Transform(Box<dyn TransformerWrapper>),
58}
59
/// Policy applied when a transformer step returns an error.
///
/// The default is [`ColumnErrorStrategy::StopOnError`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ColumnErrorStrategy {
    /// Propagate the error and abort.
    #[default]
    StopOnError,
    /// Emit a warning and leave the failed step out of the output.
    SkipOnError,
    /// Emit a warning and run the configured fallback transformer instead
    /// (or pass the columns through when no fallback is configured).
    Fallback,
    /// Emit a warning and substitute zeros for the failed columns.
    ReplaceWithZeros,
    /// Emit a warning and substitute NaN for the failed columns.
    ReplaceWithNaN,
}
80
81impl Default for RemainderStrategy {
82    fn default() -> Self {
83        Self::Drop
84    }
85}
86
87/// Trait for transformer wrappers to enable dynamic dispatch
88pub trait TransformerWrapper: Send + Sync + std::fmt::Debug {
89    fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
90    fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>>;
91    fn get_n_features_out(&self) -> Option<usize>;
92    fn clone_box(&self) -> Box<dyn TransformerWrapper>;
93}
94
95impl Clone for Box<dyn TransformerWrapper> {
96    fn clone(&self) -> Self {
97        self.clone_box()
98    }
99}
100
101/// A transformer step in the column transformer
102#[derive(Debug, Clone)]
103pub struct TransformerStep {
104    /// Name of the transformer step
105    pub name: String,
106    /// The column selector
107    pub columns: ColumnSelector,
108    /// The transformer (boxed for dynamic dispatch)
109    pub transformer: Box<dyn TransformerWrapper>,
110}
111
112/// Configuration for ColumnTransformer
113#[derive(Debug, Clone)]
114pub struct ColumnTransformerConfig {
115    /// Strategy for handling remaining columns
116    pub remainder: RemainderStrategy,
117    /// Whether to preserve column order in output
118    pub preserve_order: bool,
119    /// Whether to use parallel processing
120    pub n_jobs: Option<usize>,
121    /// Whether to validate input
122    pub validate_input: bool,
123    /// Strategy for handling transformation errors
124    pub error_strategy: ColumnErrorStrategy,
125    /// Enable parallel processing for transformers
126    pub parallel_execution: bool,
127    /// Fallback transformer for error handling
128    pub fallback_transformer: Option<Box<dyn TransformerWrapper>>,
129}
130
131impl Default for ColumnTransformerConfig {
132    fn default() -> Self {
133        Self {
134            remainder: RemainderStrategy::Drop,
135            preserve_order: false,
136            n_jobs: None,
137            validate_input: true,
138            error_strategy: ColumnErrorStrategy::StopOnError,
139            parallel_execution: false,
140            fallback_transformer: None,
141        }
142    }
143}
144
145/// ColumnTransformer applies different transformers to different columns
146#[derive(Debug)]
147pub struct ColumnTransformer<State = Untrained> {
148    config: ColumnTransformerConfig,
149    transformers: Vec<TransformerStep>,
150    state: PhantomData<State>,
151    // Fitted parameters
152    fitted_transformers_: Option<Vec<TransformerStep>>,
153    feature_names_in_: Option<Vec<String>>,
154    n_features_in_: Option<usize>,
155    output_indices_: Option<HashMap<String, Vec<usize>>>,
156    remainder_indices_: Option<Vec<usize>>,
157}
158
159/// Result of a column transformation attempt
160#[derive(Debug)]
161struct ColumnTransformResult {
162    transformer_name: String,
163    column_indices: Vec<usize>,
164    result: Result<Array2<Float>>,
165    original_indices: Vec<usize>,
166}
167
168// Shared methods for both Untrained and Trained states
169impl<State> ColumnTransformer<State> {
170    /// Apply transformer with error handling
171    fn apply_transformer_with_error_handling(
172        &self,
173        step: &TransformerStep,
174        _data: &Array2<Float>,
175        subset: &Array2<Float>,
176        is_fit_transform: bool,
177        resolved_indices: &[usize],
178    ) -> ColumnTransformResult {
179        let column_indices = resolved_indices.to_vec();
180
181        let transform_result = if is_fit_transform {
182            step.transformer.fit_transform_wrapper(subset)
183        } else {
184            step.transformer.transform_wrapper(subset)
185        };
186
187        let final_result = match transform_result {
188            Ok(transformed) => Ok(transformed),
189            Err(error) => {
190                // Apply error strategy
191                match self.config.error_strategy {
192                    ColumnErrorStrategy::StopOnError => Err(error),
193                    ColumnErrorStrategy::SkipOnError => {
194                        eprintln!(
195                            "Warning: Transformer '{}' failed on columns {:?}: {}. Skipping...",
196                            step.name, column_indices, error
197                        );
198                        // Return empty result to indicate skipping
199                        Ok(Array2::zeros((subset.nrows(), 0)))
200                    }
201                    ColumnErrorStrategy::Fallback => {
202                        if let Some(ref fallback) = self.config.fallback_transformer {
203                            eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Using fallback...", 
204                                    step.name, column_indices, error);
205                            if is_fit_transform {
206                                fallback.fit_transform_wrapper(subset)
207                            } else {
208                                fallback.transform_wrapper(subset)
209                            }
210                        } else {
211                            eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. No fallback available, passing through...", 
212                                    step.name, column_indices, error);
213                            Ok(subset.clone())
214                        }
215                    }
216                    ColumnErrorStrategy::ReplaceWithZeros => {
217                        eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with zeros...", 
218                                step.name, column_indices, error);
219                        Ok(Array2::zeros(subset.dim()))
220                    }
221                    ColumnErrorStrategy::ReplaceWithNaN => {
222                        eprintln!("Warning: Transformer '{}' failed on columns {:?}: {}. Replacing with NaN...", 
223                                step.name, column_indices, error);
224                        Ok(Array2::from_elem(subset.dim(), Float::NAN))
225                    }
226                }
227            }
228        };
229
230        ColumnTransformResult {
231            transformer_name: step.name.clone(),
232            column_indices: column_indices.clone(),
233            result: final_result,
234            original_indices: column_indices,
235        }
236    }
237}
238
239impl ColumnTransformer<Untrained> {
240    /// Create a new ColumnTransformer
241    pub fn new() -> Self {
242        Self {
243            config: ColumnTransformerConfig::default(),
244            transformers: Vec::new(),
245            state: PhantomData,
246            fitted_transformers_: None,
247            feature_names_in_: None,
248            n_features_in_: None,
249            output_indices_: None,
250            remainder_indices_: None,
251        }
252    }
253
254    /// Add a transformer for specific columns
255    pub fn add_transformer<T>(mut self, name: &str, transformer: T, columns: ColumnSelector) -> Self
256    where
257        T: TransformerWrapper + 'static,
258    {
259        self.transformers.push(TransformerStep {
260            name: name.to_string(),
261            columns,
262            transformer: Box::new(transformer),
263        });
264        self
265    }
266
267    /// Set the remainder strategy
268    pub fn remainder(mut self, strategy: RemainderStrategy) -> Self {
269        self.config.remainder = strategy;
270        self
271    }
272
273    /// Set whether to preserve column order
274    pub fn preserve_order(mut self, preserve: bool) -> Self {
275        self.config.preserve_order = preserve;
276        self
277    }
278
279    /// Set number of parallel jobs
280    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
281        self.config.n_jobs = n_jobs;
282        self
283    }
284
285    /// Set input validation
286    pub fn validate_input(mut self, validate: bool) -> Self {
287        self.config.validate_input = validate;
288        self
289    }
290
291    /// Set error handling strategy
292    pub fn error_strategy(mut self, strategy: ColumnErrorStrategy) -> Self {
293        self.config.error_strategy = strategy;
294        self
295    }
296
297    /// Enable/disable parallel execution
298    pub fn parallel_execution(mut self, parallel: bool) -> Self {
299        self.config.parallel_execution = parallel;
300        self
301    }
302
303    /// Set fallback transformer for error handling
304    pub fn fallback_transformer<T>(mut self, transformer: T) -> Self
305    where
306        T: TransformerWrapper + 'static,
307    {
308        self.config.fallback_transformer = Some(Box::new(transformer));
309        self
310    }
311
312    /// Resolve column indices from selectors
313    fn resolve_columns(&self, selector: &ColumnSelector, n_features: usize) -> Result<Vec<usize>> {
314        match selector {
315            ColumnSelector::Indices(indices) => {
316                // Validate indices
317                for &idx in indices {
318                    if idx >= n_features {
319                        return Err(SklearsError::InvalidInput(format!(
320                            "Column index {} is out of bounds for {} features",
321                            idx, n_features
322                        )));
323                    }
324                }
325                Ok(indices.clone())
326            }
327            ColumnSelector::Names(_names) => {
328                // For now, return error as we need named column support
329                Err(SklearsError::NotImplemented(
330                    "Named column selection not yet implemented".to_string(),
331                ))
332            }
333            ColumnSelector::DataType(_dtype) => {
334                // DataType selection requires training data, handled in resolve_columns_with_data
335                Err(SklearsError::InvalidInput(
336                    "DataType column selection requires training data. Use resolve_columns_with_data.".to_string(),
337                ))
338            }
339            ColumnSelector::Remainder => {
340                // This should be handled separately
341                Ok(Vec::new())
342            }
343        }
344    }
345
346    /// Resolve column indices from selectors with access to training data
347    fn resolve_columns_with_data(
348        &self,
349        selector: &ColumnSelector,
350        data: &Array2<Float>,
351    ) -> Result<Vec<usize>> {
352        let (_, n_features) = data.dim();
353
354        match selector {
355            ColumnSelector::Indices(indices) => {
356                // Validate indices
357                for &idx in indices {
358                    if idx >= n_features {
359                        return Err(SklearsError::InvalidInput(format!(
360                            "Column index {} is out of bounds for {} features",
361                            idx, n_features
362                        )));
363                    }
364                }
365                Ok(indices.clone())
366            }
367            ColumnSelector::Names(_names) => {
368                // For now, return error as we need named column support
369                Err(SklearsError::NotImplemented(
370                    "Named column selection not yet implemented".to_string(),
371                ))
372            }
373            ColumnSelector::DataType(dtype) => self.infer_columns_by_dtype_with_data(dtype, data),
374            ColumnSelector::Remainder => {
375                // This should be handled separately
376                Ok(Vec::new())
377            }
378        }
379    }
380
381    /// Infer column indices by data type using heuristics (without training data)
382    fn infer_columns_by_dtype(&self, _dtype: &DataType, _n_features: usize) -> Result<Vec<usize>> {
383        // This method cannot work without training data
384        Err(SklearsError::InvalidInput(
385            "Data type column selection requires training data context. \
386             Use resolve_columns_with_data during fit."
387                .to_string(),
388        ))
389    }
390
391    /// Infer column indices by data type using heuristics on training data
392    fn infer_columns_by_dtype_with_data(
393        &self,
394        dtype: &DataType,
395        data: &Array2<Float>,
396    ) -> Result<Vec<usize>> {
397        let (_n_samples, n_features) = data.dim();
398        let mut matching_columns = Vec::new();
399
400        for col_idx in 0..n_features {
401            let column = data.column(col_idx);
402            let column_type = self.infer_column_type(&column);
403
404            if column_type == *dtype {
405                matching_columns.push(col_idx);
406            }
407        }
408
409        Ok(matching_columns)
410    }
411
412    /// Infer the data type of a single column using heuristics
413    fn infer_column_type(&self, column: &scirs2_core::ndarray::ArrayView1<Float>) -> DataType {
414        let unique_values: std::collections::HashSet<_> =
415            column.iter().map(|&x| OrderedFloat::from(x)).collect();
416
417        let n_unique = unique_values.len();
418        let n_total = column.len();
419
420        // Check if column is boolean (only 0.0 and 1.0 values)
421        if n_unique <= 2 {
422            let zero_bits = OrderedFloat::from(0.0);
423            let one_bits = OrderedFloat::from(1.0);
424            if unique_values
425                .iter()
426                .all(|&x| x == zero_bits || x == one_bits)
427            {
428                return DataType::Boolean;
429            }
430        }
431
432        // Heuristic for categorical vs numeric
433        // If the ratio of unique values to total values is low, consider it categorical
434        let unique_ratio = n_unique as f64 / n_total as f64;
435
436        // Use a more balanced approach: categorical if either condition is met
437        // but with more conservative thresholds
438        if (unique_ratio < 0.6 && n_unique <= 5) || unique_ratio < 0.2 {
439            DataType::Categorical
440        } else {
441            DataType::Numeric
442        }
443    }
444
445    /// Get indices of columns that are not selected by any transformer
446    fn get_remainder_indices(&self, data: &Array2<Float>) -> Result<Vec<usize>> {
447        let (_, n_features) = data.dim();
448        let mut used_indices = std::collections::HashSet::new();
449
450        // Collect all used indices
451        for step in &self.transformers {
452            let indices = match &step.columns {
453                ColumnSelector::DataType(_) => {
454                    self.resolve_columns_with_data(&step.columns, data)?
455                }
456                _ => self.resolve_columns(&step.columns, n_features)?,
457            };
458            for idx in indices {
459                used_indices.insert(idx);
460            }
461        }
462
463        // Return unused indices
464        Ok((0..n_features)
465            .filter(|i| !used_indices.contains(i))
466            .collect())
467    }
468}
469
470impl Default for ColumnTransformer<Untrained> {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
476impl Estimator for ColumnTransformer<Untrained> {
477    type Config = ColumnTransformerConfig;
478    type Error = SklearsError;
479    type Float = Float;
480
481    fn config(&self) -> &Self::Config {
482        &self.config
483    }
484}
485
486impl Estimator for ColumnTransformer<Trained> {
487    type Config = ColumnTransformerConfig;
488    type Error = SklearsError;
489    type Float = Float;
490
491    fn config(&self) -> &Self::Config {
492        &self.config
493    }
494}
495
496impl Fit<Array2<Float>, ()> for ColumnTransformer<Untrained> {
497    type Fitted = ColumnTransformer<Trained>;
498
499    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
500        let (n_samples, n_features) = x.dim();
501
502        if n_samples == 0 {
503            return Err(SklearsError::InvalidInput(
504                "Cannot fit transformer on empty dataset".to_string(),
505            ));
506        }
507
508        // Get remainder indices
509        let remainder_indices = self.get_remainder_indices(x)?;
510
511        // Prepare transformer steps with resolved indices
512        let mut transformer_tasks: Vec<(TransformerStep, Vec<usize>)> = Vec::new();
513
514        for step in &self.transformers {
515            // Resolve column indices - use data-aware method for DataType selectors
516            let indices = match &step.columns {
517                ColumnSelector::DataType(_) => self.resolve_columns_with_data(&step.columns, x)?,
518                _ => self.resolve_columns(&step.columns, n_features)?,
519            };
520
521            if !indices.is_empty() {
522                transformer_tasks.push((step.clone(), indices));
523            }
524        }
525
526        // Apply transformers with parallel processing and error handling
527        let transform_results: Vec<ColumnTransformResult> = if self.config.parallel_execution
528            && transformer_tasks.len() > 1
529        {
530            #[cfg(feature = "parallel")]
531            {
532                transformer_tasks
533                    .into_par_iter()
534                    .map(|(step, indices)| {
535                        let subset = x.select(Axis(1), &indices);
536                        self.apply_transformer_with_error_handling(
537                            &step, x, &subset, true, &indices,
538                        )
539                    })
540                    .collect()
541            }
542            #[cfg(not(feature = "parallel"))]
543            {
544                // Fallback to sequential processing
545                transformer_tasks
546                    .into_iter()
547                    .map(|(step, indices)| {
548                        let subset = x.select(Axis(1), &indices);
549                        self.apply_transformer_with_error_handling(
550                            &step, x, &subset, true, &indices,
551                        )
552                    })
553                    .collect()
554            }
555        } else {
556            // Sequential processing
557            transformer_tasks
558                .into_iter()
559                .map(|(step, indices)| {
560                    let subset = x.select(Axis(1), &indices);
561                    self.apply_transformer_with_error_handling(&step, x, &subset, true, &indices)
562                })
563                .collect()
564        };
565
566        // Process results and create fitted transformers
567        let mut fitted_transformers = Vec::new();
568        let mut output_indices = HashMap::new();
569
570        for transform_result in transform_results {
571            match transform_result.result {
572                Ok(transformed) => {
573                    if transformed.ncols() > 0 {
574                        // Skip empty results (from SkipOnError)
575                        // Store output indices mapping
576                        let output_cols = (0..transformed.ncols()).collect();
577                        output_indices
578                            .insert(transform_result.transformer_name.clone(), output_cols);
579
580                        // Create fitted transformer step
581                        let transformer_name = transform_result.transformer_name.clone();
582                        fitted_transformers.push(TransformerStep {
583                            name: transformer_name.clone(),
584                            columns: ColumnSelector::Indices(transform_result.original_indices),
585                            transformer: self
586                                .transformers
587                                .iter()
588                                .find(|s| s.name == transformer_name)
589                                .expect("operation should succeed")
590                                .transformer
591                                .clone_box(),
592                        });
593                    }
594                }
595                Err(e) => {
596                    // If we reach here, it means StopOnError was used
597                    return Err(SklearsError::TransformError(format!(
598                        "Transformer '{}' failed: {}",
599                        transform_result.transformer_name, e
600                    )));
601                }
602            }
603        }
604
605        Ok(ColumnTransformer {
606            config: self.config,
607            transformers: self.transformers,
608            state: PhantomData,
609            fitted_transformers_: Some(fitted_transformers),
610            feature_names_in_: None,
611            n_features_in_: Some(n_features),
612            output_indices_: Some(output_indices),
613            remainder_indices_: Some(remainder_indices),
614        })
615    }
616}
617
618impl Transform<Array2<Float>, Array2<Float>> for ColumnTransformer<Trained> {
619    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
620        let (n_samples, n_features) = x.dim();
621
622        if Some(n_features) != self.n_features_in_ {
623            return Err(SklearsError::FeatureMismatch {
624                expected: self.n_features_in_.unwrap_or(0),
625                actual: n_features,
626            });
627        }
628
629        let fitted_transformers = self
630            .fitted_transformers_
631            .as_ref()
632            .expect("operation should succeed");
633        let remainder_indices = self
634            .remainder_indices_
635            .as_ref()
636            .expect("operation should succeed");
637
638        // Prepare transformer tasks for parallel processing
639        let transformer_tasks: Vec<&TransformerStep> = fitted_transformers.iter().collect();
640
641        // Apply transformers with parallel processing and error handling
642        let transform_results: Vec<ColumnTransformResult> =
643            if self.config.parallel_execution && transformer_tasks.len() > 1 {
644                #[cfg(feature = "parallel")]
645                {
646                    transformer_tasks
647                        .into_par_iter()
648                        .filter_map(|step| {
649                            if let ColumnSelector::Indices(indices) = &step.columns {
650                                if !indices.is_empty() {
651                                    let subset = x.select(Axis(1), indices);
652                                    Some(self.apply_transformer_with_error_handling(
653                                        step, x, &subset, false, indices,
654                                    ))
655                                } else {
656                                    None
657                                }
658                            } else {
659                                None
660                            }
661                        })
662                        .collect()
663                }
664                #[cfg(not(feature = "parallel"))]
665                {
666                    // Fallback to sequential processing
667                    transformer_tasks
668                        .into_iter()
669                        .filter_map(|step| {
670                            if let ColumnSelector::Indices(indices) = &step.columns {
671                                if !indices.is_empty() {
672                                    let subset = x.select(Axis(1), indices);
673                                    Some(self.apply_transformer_with_error_handling(
674                                        step, x, &subset, false, indices,
675                                    ))
676                                } else {
677                                    None
678                                }
679                            } else {
680                                None
681                            }
682                        })
683                        .collect()
684                }
685            } else {
686                // Sequential processing
687                transformer_tasks
688                    .into_iter()
689                    .filter_map(|step| {
690                        if let ColumnSelector::Indices(indices) = &step.columns {
691                            if !indices.is_empty() {
692                                let subset = x.select(Axis(1), indices);
693                                Some(self.apply_transformer_with_error_handling(
694                                    step, x, &subset, false, indices,
695                                ))
696                            } else {
697                                None
698                            }
699                        } else {
700                            None
701                        }
702                    })
703                    .collect()
704            };
705
706        // Process results and create column outputs
707        let mut column_outputs: Vec<(usize, Array2<Float>)> = Vec::new();
708
709        for transform_result in transform_results {
710            match transform_result.result {
711                Ok(transformed) => {
712                    if transformed.ncols() > 0 {
713                        // Skip empty results (from SkipOnError)
714                        // For each original column index, store its min value to maintain order
715                        let min_index = *transform_result
716                            .original_indices
717                            .iter()
718                            .min()
719                            .expect("collection should not be empty for min/max");
720                        column_outputs.push((min_index, transformed));
721                    }
722                }
723                Err(e) => {
724                    // If we reach here, it means StopOnError was used
725                    return Err(SklearsError::TransformError(format!(
726                        "Transformer '{}' failed: {}",
727                        transform_result.transformer_name, e
728                    )));
729                }
730            }
731        }
732
733        // Handle remainder columns
734        if !remainder_indices.is_empty() {
735            let remainder_data = x.select(Axis(1), remainder_indices);
736
737            let transformed_remainder = match &self.config.remainder {
738                RemainderStrategy::Drop => {
739                    None // remainder is dropped
740                }
741                RemainderStrategy::Passthrough => Some(remainder_data),
742                RemainderStrategy::Transform(transformer) => {
743                    let transformed = transformer.transform_wrapper(&remainder_data)?;
744                    Some(transformed)
745                }
746            };
747
748            if let Some(remainder_output) = transformed_remainder {
749                // Add remainder with the minimum remainder index
750                if let Some(&min_remainder_index) = remainder_indices.iter().min() {
751                    column_outputs.push((min_remainder_index, remainder_output));
752                }
753            }
754        }
755
756        // Sort by original column indices to maintain proper ordering
757        column_outputs.sort_by_key(|(idx, _)| *idx);
758
759        // Concatenate all output parts in the correct order
760        if column_outputs.is_empty() {
761            return Err(SklearsError::InvalidInput(
762                "No output from any transformer".to_string(),
763            ));
764        }
765
766        // Calculate total columns
767        let total_cols: usize = column_outputs.iter().map(|(_, arr)| arr.ncols()).sum();
768        let mut result = Array2::zeros((n_samples, total_cols));
769
770        // Concatenate in order
771        let mut col_offset = 0;
772        for (_, part) in column_outputs {
773            let part_cols = part.ncols();
774            result
775                .slice_mut(s![.., col_offset..col_offset + part_cols])
776                .assign(&part);
777            col_offset += part_cols;
778        }
779
780        Ok(result)
781    }
782}
783
784impl ColumnTransformer<Trained> {
785    /// Get the number of features seen during fitting
786    pub fn n_features_in(&self) -> usize {
787        self.n_features_in_.expect("operation should succeed")
788    }
789
790    /// Get the output indices mapping
791    pub fn output_indices(&self) -> &HashMap<String, Vec<usize>> {
792        self.output_indices_
793            .as_ref()
794            .expect("operation should succeed")
795    }
796
797    /// Get the remainder indices
798    pub fn remainder_indices(&self) -> &Vec<usize> {
799        self.remainder_indices_
800            .as_ref()
801            .expect("operation should succeed")
802    }
803}
804
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Test double that multiplies every element of its input by `scale`.
    #[derive(Debug, Clone)]
    struct MockTransformer {
        /// Factor applied element-wise by both wrapper methods.
        scale: Float,
    }

    impl TransformerWrapper for MockTransformer {
        /// The mock keeps no fitted state, so fit-and-transform is the same
        /// operation as a plain transform.
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.transform_wrapper(x)
        }

        /// Returns a new array holding `x` scaled element-wise by `self.scale`.
        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            let factor = self.scale;
            Ok(x.mapv(|value| value * factor))
        }

        /// Reports no explicit output width (original author's note: same as
        /// input).
        fn get_n_features_out(&self) -> Option<usize> {
            None
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    /// Scaling the first two columns with passthrough keeps all three columns.
    #[test]
    fn test_column_transformer_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "scale_first_two",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0, 1]),
            )
            .remainder(RemainderStrategy::Passthrough)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        assert_eq!(output.dim(), (3, 3));
        // Row 0: 1.0 * 2.0, 2.0 * 2.0, then 3.0 passed through unchanged.
        assert_eq!(output.row(0).to_vec(), vec![2.0, 4.0, 3.0]);
    }

    /// With `Drop`, only the columns handed to a transformer survive.
    #[test]
    fn test_column_transformer_drop_remainder() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "scale_middle",
                MockTransformer { scale: 3.0 },
                ColumnSelector::Indices(vec![1, 2]),
            )
            .remainder(RemainderStrategy::Drop)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        assert_eq!(output.dim(), (2, 2));
        // Row 0: 2.0 * 3.0 and 3.0 * 3.0; the outer columns are dropped.
        assert_eq!(output.row(0).to_vec(), vec![6.0, 9.0]);
    }

    /// Two transformers on the outer columns, passthrough for the middle ones.
    #[test]
    fn test_column_transformer_multiple_transformers() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "scale_first",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "scale_last",
                MockTransformer { scale: 0.5 },
                ColumnSelector::Indices(vec![3]),
            )
            .remainder(RemainderStrategy::Passthrough)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        assert_eq!(output.dim(), (2, 4));
        // Row 0: 1.0 * 2.0 (first transformer), 2.0 and 3.0 passed through,
        // 4.0 * 0.5 (second transformer).
        assert_eq!(output.row(0).to_vec(), vec![2.0, 2.0, 3.0, 2.0]);
    }

    /// Fitting on a matrix with zero rows must be rejected.
    #[test]
    fn test_column_transformer_empty_data() {
        let x_empty = Array2::<Float>::zeros((0, 3));

        let fit_outcome = ColumnTransformer::new()
            .add_transformer(
                "test",
                MockTransformer { scale: 1.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .fit(&x_empty, &());

        assert!(fit_outcome.is_err());
    }

    /// Selecting a column index beyond the input width must be rejected.
    #[test]
    fn test_column_transformer_invalid_indices() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        // The input has two columns, so index 5 does not exist.
        let fit_outcome = ColumnTransformer::new()
            .add_transformer(
                "invalid",
                MockTransformer { scale: 1.0 },
                ColumnSelector::Indices(vec![0, 5]),
            )
            .fit(&x, &());

        assert!(fit_outcome.is_err());
    }

    /// `infer_column_type` classifies boolean, categorical and numeric columns.
    #[test]
    fn test_column_type_inference() {
        let ct = ColumnTransformer::new();

        let cases = [
            // Boolean detection: strictly 0.0 and 1.0 values only.
            (array![0.0, 1.0, 0.0, 1.0, 0.0], DataType::Boolean),
            // Categorical detection: few unique values.
            (
                array![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 2.0, 3.0],
                DataType::Categorical,
            ),
            // Numeric detection: many unique values.
            (
                array![1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0],
                DataType::Numeric,
            ),
        ];

        for (column, expected) in cases.iter() {
            assert_eq!(ct.infer_column_type(&column.view()), *expected);
        }
    }

    /// Test double that either fails on every call or doubles its input,
    /// used to exercise the error strategies.
    #[derive(Debug, Clone)]
    struct FailingTransformer {
        /// When `true`, both wrapper methods return an error.
        should_fail: bool,
    }

    impl FailingTransformer {
        /// Shared body of both wrapper methods: error out when configured to
        /// fail, otherwise return `x` doubled.
        fn double_or_fail(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            if self.should_fail {
                return Err(SklearsError::InvalidInput(
                    "Intentional failure for testing".to_string(),
                ));
            }
            Ok(x * 2.0)
        }
    }

    impl TransformerWrapper for FailingTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.double_or_fail(x)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.double_or_fail(x)
        }

        fn get_n_features_out(&self) -> Option<usize> {
            None
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    /// `StopOnError` surfaces the transformer failure from `fit`.
    #[test]
    fn test_column_transformer_error_handling_stop_on_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fit_outcome = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::StopOnError)
            .fit(&x, &());

        assert!(fit_outcome.is_err(), "Should fail with StopOnError");
    }

    /// `SkipOnError` omits the failing transformer's output and keeps the rest.
    #[test]
    fn test_column_transformer_error_handling_skip_on_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "working",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![1]),
            )
            .error_strategy(ColumnErrorStrategy::SkipOnError)
            .remainder(RemainderStrategy::Passthrough)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        // Two columns remain: the working transformer's output and the remainder.
        assert_eq!(output.dim(), (2, 2));
        // Row 0: 2.0 * 2.0 (working transformer), then 3.0 (remainder passthrough).
        assert_eq!(output.row(0).to_vec(), vec![4.0, 3.0]);
    }

    /// `ReplaceWithZeros` substitutes zeros for the failing transformer's output.
    #[test]
    fn test_column_transformer_error_handling_replace_with_zeros() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::ReplaceWithZeros)
            .remainder(RemainderStrategy::Passthrough)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        // Three columns: the zero replacement plus two remainder columns.
        assert_eq!(output.dim(), (2, 3));
        // Row 0: replaced with zero, then 2.0 and 3.0 passed through.
        assert_eq!(output.row(0).to_vec(), vec![0.0, 2.0, 3.0]);
    }

    /// `Fallback` runs the configured fallback transformer on the failed column.
    #[test]
    fn test_column_transformer_error_handling_fallback() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "failing",
                FailingTransformer { should_fail: true },
                ColumnSelector::Indices(vec![0]),
            )
            .error_strategy(ColumnErrorStrategy::Fallback)
            .fallback_transformer(MockTransformer { scale: 0.5 })
            .remainder(RemainderStrategy::Passthrough)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        // Three columns: the fallback output plus two remainder columns.
        assert_eq!(output.dim(), (2, 3));
        // Row 0: 1.0 * 0.5 (fallback), then 2.0 and 3.0 passed through.
        assert_eq!(output.row(0).to_vec(), vec![0.5, 2.0, 3.0]);
    }

    /// Enabling parallel execution yields the expected scaled and remainder columns.
    #[test]
    fn test_column_transformer_parallel_execution() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let fitted = ColumnTransformer::new()
            .add_transformer(
                "scale_first",
                MockTransformer { scale: 2.0 },
                ColumnSelector::Indices(vec![0]),
            )
            .add_transformer(
                "scale_second",
                MockTransformer { scale: 3.0 },
                ColumnSelector::Indices(vec![1]),
            )
            .parallel_execution(true)
            .remainder(RemainderStrategy::Passthrough)
            .fit(&x, &())
            .expect("model fitting should succeed");
        let output = fitted.transform(&x).expect("transformation should succeed");

        // Four columns: two transformed plus two remainder columns.
        assert_eq!(output.dim(), (2, 4));
        // Row 0: 1.0 * 2.0, 2.0 * 3.0, then 3.0 and 4.0 from the remainder.
        assert_eq!(output.row(0).to_vec(), vec![2.0, 6.0, 3.0, 4.0]);
    }

    /// `ColumnSelector::DataType` picks exactly one column per data type here.
    #[test]
    fn test_column_transformer_dtype_selection() {
        // Fixture with one column per type:
        //   column 0 - numeric (many distinct continuous values)
        //   column 1 - boolean (strictly 0.0 / 1.0)
        //   column 2 - categorical (a few repeated values)
        let x = array![
            [1.23456, 0.0, 1.0],
            [2.78901, 1.0, 1.0],
            [3.45678, 0.0, 2.0],
            [4.98765, 1.0, 1.0],
            [5.12345, 0.0, 2.0],
            [6.67890, 1.0, 3.0],
            [7.11111, 0.0, 1.0],
            [8.22222, 1.0, 2.0],
        ];

        // Boolean selection: a single output column, scaled by 10.
        let bool_fitted = ColumnTransformer::new()
            .add_transformer(
                "bool_transformer",
                MockTransformer { scale: 10.0 },
                ColumnSelector::DataType(DataType::Boolean),
            )
            .fit(&x, &())
            .expect("model fitting should succeed");
        let bool_output = bool_fitted
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(bool_output.dim(), (8, 1));
        assert_eq!(bool_output[[0, 0]], 0.0); // 0.0 * 10.0
        assert_eq!(bool_output[[1, 0]], 10.0); // 1.0 * 10.0

        // Categorical selection: a single output column, scaled by 0.1.
        let cat_fitted = ColumnTransformer::new()
            .add_transformer(
                "cat_transformer",
                MockTransformer { scale: 0.1 },
                ColumnSelector::DataType(DataType::Categorical),
            )
            .fit(&x, &())
            .expect("model fitting should succeed");
        let cat_output = cat_fitted
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(cat_output.dim(), (8, 1));
        assert_eq!(cat_output[[0, 0]], 0.1); // 1.0 * 0.1

        // Numeric selection: a single output column, scaled by 2.
        let num_fitted = ColumnTransformer::new()
            .add_transformer(
                "num_transformer",
                MockTransformer { scale: 2.0 },
                ColumnSelector::DataType(DataType::Numeric),
            )
            .fit(&x, &())
            .expect("model fitting should succeed");
        let num_output = num_fitted
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(num_output.dim(), (8, 1));
        // Compared with a tolerance rather than exact equality.
        let doubled_first = 1.23456 * 2.0;
        let deviation = (num_output[[0, 0]] - doubled_first).abs();
        assert!(deviation < 1e-10);
    }
}