sklears_compose/
column_transformer.rs

1//! Column Transformer
2//!
3//! Apply different transformers to different subsets of features.
4
5use scirs2_core::ndarray::{s, Array2, ArrayView1, ArrayView2};
6use sklears_core::{
7    error::{Result as SklResult, SklearsError},
8    traits::{Estimator, Fit, Transform, Untrained},
9    types::Float,
10};
11// TODO: Migrate to scirs2-sparse when implementing sparse functionality
12// use scirs2_sparse::{CsMat, TriMat};
13use std::collections::HashMap;
14
15use crate::{MockTransformer, PipelineStep};
16
/// Column Transformer
///
/// Apply different transformers to different subsets of features.
/// This allows you to apply different preprocessing steps to different
/// types of features (e.g., numerical vs. categorical).
///
/// The type parameter `S` holds the training state: `Untrained` before
/// fitting and `ColumnTransformerTrained` once `fit` has succeeded.
///
/// # Parameters
///
/// * `transformers` - List of (name, transformer, columns) tuples
/// * `remainder` - How to handle remaining columns ('drop', 'passthrough')
/// * `sparse_threshold` - Threshold for sparse output
/// * `n_jobs` - Number of parallel jobs
/// * `transformer_weights` - Weights for each transformer
///
/// # Examples
///
/// ```ignore
/// use sklears_compose::ColumnTransformer;
/// use scirs2_core::ndarray::array;
///
/// let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
///
/// let mut ct = ColumnTransformer::new();
/// ct.add_transformer("numeric".to_string(), vec![0, 1]);
/// ct.add_transformer("categorical".to_string(), vec![2]);
/// ```
#[derive(Debug, Clone)]
pub struct ColumnTransformer<S = Untrained> {
    /// Training state (`Untrained` marker, or the fitted payload).
    state: S,
    /// Name of each registered transformer; index-aligned with
    /// `transformer_columns`.
    transformer_names: Vec<String>,
    /// Input column indices handled by each registered transformer.
    transformer_columns: Vec<Vec<usize>>,
    /// Strategy for columns no transformer claims. `fit` only gives special
    /// treatment to the exact value `"passthrough"`; any other value leaves
    /// the unclaimed columns out of the output.
    remainder: String,
    /// Fraction of zero entries in the fit input at or above which the
    /// fitted state is flagged as sparse.
    sparse_threshold: f64,
    /// Requested number of parallel jobs. Stored only; nothing in this file
    /// reads it.
    n_jobs: Option<i32>,
    /// Optional per-transformer multipliers, keyed by transformer name and
    /// applied to that transformer's output during `transform`.
    transformer_weights: Option<HashMap<String, f64>>,
}
53
/// Trained state for `ColumnTransformer`
///
/// Produced by `fit` and read back by `transform` and the inspection helpers.
#[derive(Debug)]
pub struct ColumnTransformerTrained {
    /// One `(name, fitted step, input column indices)` entry per transformer,
    /// in registration order; a trailing `"remainder"` entry is appended when
    /// unclaimed columns are passed through.
    fitted_transformers: Vec<(String, Box<dyn PipelineStep>, Vec<usize>)>,
    /// Per-transformer output indices. `fit` currently fills each entry with
    /// `0..columns.len()` (a simplified mapping).
    output_indices: Vec<Vec<usize>>,
    /// Number of columns in the matrix passed to `fit`.
    n_features_in: usize,
    /// Input feature names; `fit` always stores `None` at present.
    feature_names_in: Option<Vec<String>>,
    /// Whether the fit input met the sparsity threshold. Recorded only while
    /// sparse output is disabled.
    sparse_output: bool,
}
63
64impl ColumnTransformer<Untrained> {
65    /// Create a new `ColumnTransformer` instance
66    #[must_use]
67    pub fn new() -> Self {
68        Self {
69            state: Untrained,
70            transformer_names: Vec::new(),
71            transformer_columns: Vec::new(),
72            remainder: "drop".to_string(),
73            sparse_threshold: 0.3,
74            n_jobs: None,
75            transformer_weights: None,
76        }
77    }
78
79    /// Create a column transformer builder
80    #[must_use]
81    pub fn builder() -> ColumnTransformerBuilder {
82        ColumnTransformerBuilder::new()
83    }
84
85    /// Add a transformer for specific columns (legacy method)
86    pub fn add_transformer(&mut self, name: String, columns: Vec<usize>) {
87        self.transformer_names.push(name);
88        self.transformer_columns.push(columns);
89    }
90
91    /// Add a transformer with actual `PipelineStep` implementation
92    pub fn add_transformer_step(
93        &mut self,
94        name: String,
95        transformer: Box<dyn PipelineStep>,
96        columns: Vec<usize>,
97    ) {
98        self.transformer_names.push(name);
99        self.transformer_columns.push(columns);
100    }
101
102    /// Set what to do with remaining columns
103    #[must_use]
104    pub fn remainder(mut self, remainder: String) -> Self {
105        self.remainder = remainder;
106        self
107    }
108
109    /// Set the sparse threshold
110    #[must_use]
111    pub fn sparse_threshold(mut self, threshold: f64) -> Self {
112        self.sparse_threshold = threshold;
113        self
114    }
115
116    /// Set the number of parallel jobs
117    #[must_use]
118    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
119        self.n_jobs = n_jobs;
120        self
121    }
122
123    /// Set transformer weights
124    #[must_use]
125    pub fn transformer_weights(mut self, weights: HashMap<String, f64>) -> Self {
126        self.transformer_weights = Some(weights);
127        self
128    }
129
130    /// Extract specified columns from input array
131    fn extract_columns(
132        &self,
133        x: &ArrayView2<'_, Float>,
134        columns: &[usize],
135    ) -> SklResult<Array2<Float>> {
136        if columns.is_empty() {
137            return Ok(Array2::zeros((x.nrows(), 0)));
138        }
139
140        let mut result = Array2::zeros((x.nrows(), columns.len()));
141        for (col_idx, &original_col) in columns.iter().enumerate() {
142            if original_col >= x.ncols() {
143                return Err(SklearsError::InvalidInput(format!(
144                    "Column index {original_col} out of bounds"
145                )));
146            }
147            result.column_mut(col_idx).assign(&x.column(original_col));
148        }
149        Ok(result)
150    }
151
152    /// Determine if output should be sparse based on sparsity and threshold
153    fn should_output_sparse(&self, x: &ArrayView2<'_, Float>) -> bool {
154        let total_elements = x.nrows() * x.ncols();
155        if total_elements == 0 {
156            return false;
157        }
158
159        let zero_count = x.iter().filter(|&&val| val == 0.0).count();
160        let sparsity = zero_count as f64 / total_elements as f64;
161
162        sparsity >= self.sparse_threshold
163    }
164}
165
166impl Default for ColumnTransformer<Untrained> {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
/// Configuration for `ColumnTransformer`
///
/// Holds the same tunable settings as the builder, without the list of
/// registered transformers.
#[derive(Debug, Clone)]
pub struct ColumnTransformerConfig {
    /// Strategy for columns no transformer claims (e.g. `"drop"`, `"passthrough"`).
    pub remainder: String,
    /// Sparsity fraction at or above which output is considered sparse.
    pub sparse_threshold: f64,
    /// Requested number of parallel jobs, if any.
    pub n_jobs: Option<i32>,
    /// Optional per-transformer weights keyed by transformer name.
    pub transformer_weights: Option<HashMap<String, f64>>,
}
180
181impl Default for ColumnTransformerConfig {
182    fn default() -> Self {
183        Self {
184            remainder: "drop".to_string(),
185            sparse_threshold: 0.3,
186            n_jobs: None,
187            transformer_weights: None,
188        }
189    }
190}
191
192impl Estimator for ColumnTransformer<Untrained> {
193    type Config = ColumnTransformerConfig;
194    type Error = SklearsError;
195    type Float = Float;
196
197    fn config(&self) -> &Self::Config {
198        // For now, create a default config
199        // In a real implementation, this should be stored in the struct
200        static DEFAULT_CONFIG: ColumnTransformerConfig = ColumnTransformerConfig {
201            remainder: String::new(),
202            sparse_threshold: 0.3,
203            n_jobs: None,
204            transformer_weights: None,
205        };
206        &DEFAULT_CONFIG
207    }
208}
209
210impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for ColumnTransformer<Untrained> {
211    type Fitted = ColumnTransformer<ColumnTransformerTrained>;
212
213    fn fit(
214        self,
215        x: &ArrayView2<'_, Float>,
216        y: &Option<&ArrayView1<'_, Float>>,
217    ) -> SklResult<Self::Fitted> {
218        let n_features_in = x.ncols();
219        let mut fitted_transformers = Vec::new();
220        let mut output_indices = Vec::new();
221        let mut used_columns = vec![false; n_features_in];
222
223        // Fit each transformer on its specified columns
224        for (name, columns) in self
225            .transformer_names
226            .iter()
227            .zip(self.transformer_columns.iter())
228        {
229            // Validate column indices
230            for &col in columns {
231                if col >= n_features_in {
232                    return Err(SklearsError::InvalidInput(format!(
233                        "Column index {col} out of bounds for {n_features_in} features"
234                    )));
235                }
236                used_columns[col] = true;
237            }
238
239            // Extract columns for this transformer
240            let x_subset = self.extract_columns(x, columns)?;
241
242            // Create a mock transformer for now - in real implementation, this would be provided
243            let mut transformer = Box::new(MockTransformer::new()) as Box<dyn PipelineStep>;
244            transformer.fit(&x_subset.view(), y.as_ref().copied())?;
245
246            fitted_transformers.push((name.clone(), transformer, columns.clone()));
247            output_indices.push((0..columns.len()).collect()); // Simplified output mapping
248        }
249
250        // Handle remainder columns
251        let remainder_columns: Vec<usize> =
252            (0..n_features_in).filter(|&i| !used_columns[i]).collect();
253
254        if !remainder_columns.is_empty() && self.remainder == "passthrough" {
255            let x_remainder = self.extract_columns(x, &remainder_columns)?;
256            let mut remainder_transformer =
257                Box::new(MockTransformer::new()) as Box<dyn PipelineStep>;
258            remainder_transformer.fit(&x_remainder.view(), y.as_ref().copied())?;
259            fitted_transformers.push((
260                "remainder".to_string(),
261                remainder_transformer,
262                remainder_columns.clone(),
263            ));
264            output_indices.push((0..remainder_columns.len()).collect());
265        }
266
267        // Determine if output should be sparse
268        let sparse_output = self.should_output_sparse(x);
269
270        Ok(ColumnTransformer {
271            state: ColumnTransformerTrained {
272                fitted_transformers,
273                output_indices,
274                n_features_in,
275                feature_names_in: None,
276                sparse_output,
277            },
278            transformer_names: self.transformer_names,
279            transformer_columns: self.transformer_columns,
280            remainder: self.remainder,
281            sparse_threshold: self.sparse_threshold,
282            n_jobs: self.n_jobs,
283            transformer_weights: self.transformer_weights,
284        })
285    }
286}
287
288/// Transform trait implementation for trained `ColumnTransformer`
289impl Transform<ArrayView2<'_, Float>, Array2<f64>> for ColumnTransformer<ColumnTransformerTrained> {
290    fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
291        if x.ncols() != self.state.n_features_in {
292            return Err(SklearsError::InvalidInput(format!(
293                "Input has {} features, expected {}",
294                x.ncols(),
295                self.state.n_features_in
296            )));
297        }
298
299        if self.state.fitted_transformers.is_empty() {
300            return Ok(x.mapv(|v| v));
301        }
302
303        let mut transformed_results = Vec::new();
304
305        // Transform each subset using fitted transformers
306        for (name, transformer, columns) in &self.state.fitted_transformers {
307            let x_subset = self.extract_columns(x, columns)?;
308            let mut transformed = transformer.transform(&x_subset.view())?;
309
310            // Apply weights if specified
311            if let Some(ref weights) = self.transformer_weights {
312                if let Some(&weight) = weights.get(name) {
313                    transformed.mapv_inplace(|v| v * weight);
314                }
315            }
316
317            transformed_results.push(transformed);
318        }
319
320        if transformed_results.is_empty() {
321            return Ok(Array2::zeros((x.nrows(), 0)));
322        }
323
324        // Concatenate all transformed results
325        if transformed_results.len() == 1 {
326            Ok(transformed_results.into_iter().next().unwrap())
327        } else {
328            self.concatenate_results(transformed_results)
329        }
330    }
331}
332
/// Sparse output support for `ColumnTransformer`
///
/// Wrapper returned by `transform_output`. Only the dense variant exists at
/// the moment; the sparse variant is disabled until the sparse backend is
/// migrated.
#[derive(Debug, Clone)]
pub enum ColumnTransformerOutput {
    /// Dense matrix holding the concatenated transformer outputs.
    Dense(Array2<f64>),
    // TODO: Re-enable sparse support with scirs2-sparse
    // Sparse(CsMat<f64>),
}
341
342impl ColumnTransformer<ColumnTransformerTrained> {
343    /// Transform data and return appropriate output format (dense or sparse)
344    pub fn transform_output(
345        &self,
346        x: &ArrayView2<'_, Float>,
347    ) -> SklResult<ColumnTransformerOutput> {
348        let dense_result = self.transform(x)?;
349
350        // TODO: Re-enable sparse support with scirs2-sparse
351        // if self.state.sparse_output {
352        //     // Convert to sparse matrix if threshold is met
353        //     let sparse_result = self.dense_to_sparse(&dense_result)?;
354        //     Ok(ColumnTransformerOutput::Sparse(sparse_result))
355        // } else {
356        Ok(ColumnTransformerOutput::Dense(dense_result))
357        // }
358    }
359
360    // TODO: Re-enable sparse support with scirs2-sparse
361    // /// Convert dense matrix to sparse CSR format
362    // fn dense_to_sparse(&self, dense: &Array2<f64>) -> SklResult<CsMat<f64>> {
363    //     let mut triplets = TriMat::new((dense.nrows(), dense.ncols()));
364    //
365    //     for (i, row) in dense.outer_iter().enumerate() {
366    //         for (j, &value) in row.iter().enumerate() {
367    //             if value != 0.0 {
368    //                 triplets.add_triplet(i, j, value);
369    //             }
370    //         }
371    //     }
372    //
373    //     Ok(triplets.to_csr())
374    // }
375
376    /// Concatenate multiple transformed results
377    fn concatenate_results(&self, results: Vec<Array2<f64>>) -> SklResult<Array2<f64>> {
378        let n_samples = results[0].nrows();
379        let total_features: usize = results
380            .iter()
381            .map(scirs2_core::ndarray::ArrayBase::ncols)
382            .sum();
383
384        let mut concatenated = Array2::zeros((n_samples, total_features));
385        let mut col_idx = 0;
386
387        for result in results {
388            if result.nrows() != n_samples {
389                return Err(SklearsError::InvalidInput(
390                    "All transformer outputs must have the same number of samples".to_string(),
391                ));
392            }
393
394            let end_idx = col_idx + result.ncols();
395            concatenated
396                .slice_mut(s![.., col_idx..end_idx])
397                .assign(&result);
398            col_idx = end_idx;
399        }
400
401        Ok(concatenated)
402    }
403
404    /// Extract specified columns from input array (helper for trained transformer)
405    fn extract_columns(
406        &self,
407        x: &ArrayView2<'_, Float>,
408        columns: &[usize],
409    ) -> SklResult<Array2<Float>> {
410        if columns.is_empty() {
411            return Ok(Array2::zeros((x.nrows(), 0)));
412        }
413
414        let mut result = Array2::zeros((x.nrows(), columns.len()));
415        for (col_idx, &original_col) in columns.iter().enumerate() {
416            if original_col >= x.ncols() {
417                return Err(SklearsError::InvalidInput(format!(
418                    "Column index {original_col} out of bounds"
419                )));
420            }
421            result.column_mut(col_idx).assign(&x.column(original_col));
422        }
423        Ok(result)
424    }
425
426    /// Get information about fitted transformers
427    #[must_use]
428    pub fn get_transformer_info(&self) -> Vec<(String, Vec<usize>)> {
429        self.state
430            .fitted_transformers
431            .iter()
432            .map(|(name, _, columns)| (name.clone(), columns.clone()))
433            .collect()
434    }
435
436    /// Get number of output features
437    #[must_use]
438    pub fn n_features_out(&self) -> usize {
439        self.state
440            .output_indices
441            .iter()
442            .map(std::vec::Vec::len)
443            .sum()
444    }
445}
446
/// Column transformer builder for fluent construction
///
/// Accumulates the same settings as `ColumnTransformer` and hands them over
/// unchanged in `build`.
#[derive(Debug, Clone)]
pub struct ColumnTransformerBuilder {
    /// Names of the registered transformers; index-aligned with
    /// `transformer_columns`.
    transformer_names: Vec<String>,
    /// Input column indices for each registered transformer.
    transformer_columns: Vec<Vec<usize>>,
    /// Strategy for columns no transformer claims.
    remainder: String,
    /// Threshold for sparse output.
    sparse_threshold: f64,
    /// Requested number of parallel jobs, if any.
    n_jobs: Option<i32>,
    /// Optional per-transformer weights keyed by transformer name.
    transformer_weights: Option<HashMap<String, f64>>,
}
457
458impl ColumnTransformerBuilder {
459    /// Create a new builder
460    #[must_use]
461    pub fn new() -> Self {
462        Self {
463            transformer_names: Vec::new(),
464            transformer_columns: Vec::new(),
465            remainder: "drop".to_string(),
466            sparse_threshold: 0.3,
467            n_jobs: None,
468            transformer_weights: None,
469        }
470    }
471
472    /// Add a transformer
473    #[must_use]
474    pub fn transformer(mut self, name: String, columns: Vec<usize>) -> Self {
475        self.transformer_names.push(name);
476        self.transformer_columns.push(columns);
477        self
478    }
479
480    /// Set remainder strategy
481    #[must_use]
482    pub fn remainder(mut self, remainder: String) -> Self {
483        self.remainder = remainder;
484        self
485    }
486
487    /// Set sparse threshold
488    #[must_use]
489    pub fn sparse_threshold(mut self, threshold: f64) -> Self {
490        self.sparse_threshold = threshold;
491        self
492    }
493
494    /// Set number of jobs
495    #[must_use]
496    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
497        self.n_jobs = n_jobs;
498        self
499    }
500
501    /// Set transformer weights
502    #[must_use]
503    pub fn transformer_weights(mut self, weights: HashMap<String, f64>) -> Self {
504        self.transformer_weights = Some(weights);
505        self
506    }
507
508    /// Build the `ColumnTransformer`
509    #[must_use]
510    pub fn build(self) -> ColumnTransformer<Untrained> {
511        /// ColumnTransformer
512        ColumnTransformer {
513            state: Untrained,
514            transformer_names: self.transformer_names,
515            transformer_columns: self.transformer_columns,
516            remainder: self.remainder,
517            sparse_threshold: self.sparse_threshold,
518            n_jobs: self.n_jobs,
519            transformer_weights: self.transformer_weights,
520        }
521    }
522}
523
524impl Default for ColumnTransformerBuilder {
525    fn default() -> Self {
526        Self::new()
527    }
528}