sklears_preprocessing/
feature_union.rs

1//! Feature Union
2//!
3//! This module provides FeatureUnion which combines the output of multiple
4//! transformers by applying them all to the same input data and concatenating results.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Estimator, Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14/// Re-export the TransformerWrapper trait from column_transformer
15pub use crate::column_transformer::TransformerWrapper;
16
17/// Feature selection strategy for FeatureUnion
18#[derive(Debug, Clone, Copy, PartialEq)]
19pub enum FeatureSelectionStrategy {
20    /// No feature selection (keep all features)
21    None,
22    /// Select top k features based on variance
23    VarianceThreshold(Float),
24    /// Select top k features by count
25    TopK(usize),
26    /// Select features with importance above threshold
27    ImportanceThreshold(Float),
28    /// Select top percentage of features
29    TopPercentile(Float),
30}
31
32impl Default for FeatureSelectionStrategy {
33    fn default() -> Self {
34        Self::None
35    }
36}
37
/// Feature importance calculation method
///
/// Each method produces one non-negative score per output column; higher means
/// more important. The default is [`FeatureImportanceMethod::Variance`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FeatureImportanceMethod {
    /// Use the population variance of the column as importance measure
    #[default]
    Variance,
    /// Use the mean of absolute values as importance measure
    AbsoluteMean,
    /// Use L1 norm as importance measure
    L1Norm,
    /// Use L2 norm as importance measure
    L2Norm,
    /// Use the absolute covariance between the column and an equal-weight sum of
    /// all mean-centered columns (a simplified stand-in for the first principal
    /// component; no eigen-decomposition is performed)
    PrincipalComponent,
}
58
/// A transformer step in the feature union
///
/// One step pairs a named transformer with an optional output weight. Steps are
/// applied to the same input and their outputs are concatenated column-wise in
/// the order the steps were added.
#[derive(Debug, Clone)]
pub struct FeatureUnionStep {
    /// Name of the transformer step (used in error messages to identify the step)
    pub name: String,
    /// The transformer (boxed for dynamic dispatch)
    pub transformer: Box<dyn TransformerWrapper>,
    /// Weight for this transformer's output (optional).
    /// `None` is treated as `1.0`, i.e. the output is left unscaled.
    pub weight: Option<Float>,
}
69
/// Configuration for FeatureUnion
///
/// NOTE(review): `n_jobs`, `validate_input` and `preserve_order` are stored by the
/// builder methods but are not read by the fit/transform code visible in this
/// file — confirm whether they are consumed elsewhere.
#[derive(Debug, Clone)]
pub struct FeatureUnionConfig {
    /// Whether to use parallel processing (requested number of jobs, if any)
    pub n_jobs: Option<usize>,
    /// Whether to validate input
    pub validate_input: bool,
    /// Whether to preserve transformer order in output
    pub preserve_order: bool,
    /// Feature selection strategy applied to the concatenated output
    pub feature_selection: FeatureSelectionStrategy,
    /// Feature importance calculation method used to score the concatenated columns
    pub importance_method: FeatureImportanceMethod,
    /// Whether to enable feature selection; when `false` the strategy above is ignored
    pub enable_feature_selection: bool,
}
86
87impl Default for FeatureUnionConfig {
88    fn default() -> Self {
89        Self {
90            n_jobs: None,
91            validate_input: true,
92            preserve_order: true,
93            feature_selection: FeatureSelectionStrategy::None,
94            importance_method: FeatureImportanceMethod::Variance,
95            enable_feature_selection: false,
96        }
97    }
98}
99
/// FeatureUnion concatenates the results of multiple transformers
///
/// Unlike ColumnTransformer which applies different transformers to different columns,
/// FeatureUnion applies all transformers to the same input data and concatenates
/// their outputs column-wise.
///
/// This is useful for creating feature combinations, such as applying both
/// PCA and polynomial features to the same input.
///
/// The `State` type parameter (`Untrained` or `Trained`) tracks at compile time
/// whether `fit` has been called; the trailing-underscore fields are `None` until
/// fitting populates them.
#[derive(Debug)]
pub struct FeatureUnion<State = Untrained> {
    // Builder-supplied configuration
    config: FeatureUnionConfig,
    // Steps in the order they were added
    transformers: Vec<FeatureUnionStep>,
    // Zero-sized marker carrying the Untrained/Trained state
    state: PhantomData<State>,
    // Fitted parameters
    // Steps used by `transform` (one per entry in `transformers`)
    fitted_transformers_: Option<Vec<FeatureUnionStep>>,
    // Number of input columns seen during fit
    n_features_in_: Option<usize>,
    // Number of output columns (after feature selection, when enabled)
    n_features_out_: Option<usize>,
    // Per-step weights with missing weights resolved to 1.0
    transformer_weights_: Option<Vec<Float>>,
    // Feature selection parameters
    // Indices into the concatenated output that are kept; `None` when selection is disabled
    selected_features_: Option<Vec<usize>>,
    // One importance score per concatenated column; `None` when selection is disabled
    feature_importances_: Option<Array1<Float>>,
    // Optional output feature names (always `None` in the fit implementation in this file)
    feature_names_: Option<Vec<String>>,
}
123
124impl FeatureUnion<Untrained> {
125    /// Create a new FeatureUnion
126    pub fn new() -> Self {
127        Self {
128            config: FeatureUnionConfig::default(),
129            transformers: Vec::new(),
130            state: PhantomData,
131            fitted_transformers_: None,
132            n_features_in_: None,
133            n_features_out_: None,
134            transformer_weights_: None,
135            selected_features_: None,
136            feature_importances_: None,
137            feature_names_: None,
138        }
139    }
140
141    /// Add a transformer to the union
142    pub fn add_transformer<T>(mut self, name: &str, transformer: T) -> Self
143    where
144        T: TransformerWrapper + 'static,
145    {
146        self.transformers.push(FeatureUnionStep {
147            name: name.to_string(),
148            transformer: Box::new(transformer),
149            weight: None,
150        });
151        self
152    }
153
154    /// Add a weighted transformer to the union
155    pub fn add_weighted_transformer<T>(mut self, name: &str, transformer: T, weight: Float) -> Self
156    where
157        T: TransformerWrapper + 'static,
158    {
159        self.transformers.push(FeatureUnionStep {
160            name: name.to_string(),
161            transformer: Box::new(transformer),
162            weight: Some(weight),
163        });
164        self
165    }
166
167    /// Set number of parallel jobs
168    pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
169        self.config.n_jobs = n_jobs;
170        self
171    }
172
173    /// Set input validation
174    pub fn validate_input(mut self, validate: bool) -> Self {
175        self.config.validate_input = validate;
176        self
177    }
178
179    /// Set whether to preserve transformer order
180    pub fn preserve_order(mut self, preserve: bool) -> Self {
181        self.config.preserve_order = preserve;
182        self
183    }
184
185    /// Set feature selection strategy
186    pub fn feature_selection(mut self, strategy: FeatureSelectionStrategy) -> Self {
187        self.config.feature_selection = strategy;
188        self.config.enable_feature_selection = !matches!(strategy, FeatureSelectionStrategy::None);
189        self
190    }
191
192    /// Set feature importance calculation method
193    pub fn importance_method(mut self, method: FeatureImportanceMethod) -> Self {
194        self.config.importance_method = method;
195        self
196    }
197
198    /// Enable or disable feature selection
199    pub fn enable_feature_selection(mut self, enable: bool) -> Self {
200        self.config.enable_feature_selection = enable;
201        self
202    }
203
204    /// Calculate feature importance scores for the transformed data
205    fn calculate_feature_importance(
206        &self,
207        data: &Array2<Float>,
208        method: FeatureImportanceMethod,
209    ) -> Array1<Float> {
210        let n_features = data.ncols();
211        let mut importances = Array1::zeros(n_features);
212
213        match method {
214            FeatureImportanceMethod::Variance => {
215                for (i, col) in data.columns().into_iter().enumerate() {
216                    let mean = col.mean().unwrap_or(0.0);
217                    let variance = col.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
218                        / (col.len() as Float);
219                    importances[i] = variance;
220                }
221            }
222            FeatureImportanceMethod::AbsoluteMean => {
223                for (i, col) in data.columns().into_iter().enumerate() {
224                    let abs_mean =
225                        col.iter().map(|&x| x.abs()).sum::<Float>() / (col.len() as Float);
226                    importances[i] = abs_mean;
227                }
228            }
229            FeatureImportanceMethod::L1Norm => {
230                for (i, col) in data.columns().into_iter().enumerate() {
231                    let l1_norm = col.iter().map(|&x| x.abs()).sum::<Float>();
232                    importances[i] = l1_norm;
233                }
234            }
235            FeatureImportanceMethod::L2Norm => {
236                for (i, col) in data.columns().into_iter().enumerate() {
237                    let l2_norm = col.iter().map(|&x| x * x).sum::<Float>().sqrt();
238                    importances[i] = l2_norm;
239                }
240            }
241            FeatureImportanceMethod::PrincipalComponent => {
242                // Simplified correlation with the first principal component (mean-centered sum)
243                let means: Vec<Float> = (0..n_features)
244                    .map(|i| data.column(i).mean().unwrap_or(0.0))
245                    .collect();
246
247                // Calculate first principal component as weighted sum of all features
248                let mut pc1: Array1<Float> = Array1::zeros(data.nrows());
249                for (i, col) in data.columns().into_iter().enumerate() {
250                    let mean = means[i];
251                    for (row_idx, &val) in col.iter().enumerate() {
252                        pc1[row_idx] += (val - mean) / (n_features as Float).sqrt();
253                    }
254                }
255
256                // Calculate correlation of each feature with PC1
257                for (i, col) in data.columns().into_iter().enumerate() {
258                    let mean = means[i];
259                    let centered_col: Vec<Float> = col.iter().map(|&x| x - mean).collect();
260                    let correlation: Float = centered_col
261                        .iter()
262                        .zip(pc1.iter())
263                        .map(|(&x, &y): (&Float, &Float)| x * y)
264                        .sum::<Float>()
265                        / ((data.nrows() - 1) as Float);
266                    importances[i] = correlation.abs();
267                }
268            }
269        }
270
271        importances
272    }
273
274    /// Select features based on the configured strategy
275    fn select_features(
276        &self,
277        importances: &Array1<Float>,
278        strategy: FeatureSelectionStrategy,
279    ) -> Vec<usize> {
280        let n_features = importances.len();
281        let mut feature_indices: Vec<(usize, Float)> = importances
282            .iter()
283            .enumerate()
284            .map(|(i, &score)| (i, score))
285            .collect();
286
287        match strategy {
288            FeatureSelectionStrategy::None => (0..n_features).collect(),
289            FeatureSelectionStrategy::VarianceThreshold(threshold) => feature_indices
290                .into_iter()
291                .filter_map(|(idx, score)| if score >= threshold { Some(idx) } else { None })
292                .collect(),
293            FeatureSelectionStrategy::TopK(k) => {
294                feature_indices
295                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
296                feature_indices
297                    .into_iter()
298                    .take(k.min(n_features))
299                    .map(|(idx, _)| idx)
300                    .collect()
301            }
302            FeatureSelectionStrategy::ImportanceThreshold(threshold) => feature_indices
303                .into_iter()
304                .filter_map(|(idx, score)| if score >= threshold { Some(idx) } else { None })
305                .collect(),
306            FeatureSelectionStrategy::TopPercentile(percentile) => {
307                if percentile <= 0.0 || percentile > 100.0 {
308                    return (0..n_features).collect();
309                }
310                feature_indices
311                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
312                let k = ((n_features as Float * percentile / 100.0).ceil() as usize).max(1);
313                feature_indices
314                    .into_iter()
315                    .take(k)
316                    .map(|(idx, _)| idx)
317                    .collect()
318            }
319        }
320    }
321}
322
323impl Default for FeatureUnion<Untrained> {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329impl Estimator for FeatureUnion<Untrained> {
330    type Config = FeatureUnionConfig;
331    type Error = SklearsError;
332    type Float = Float;
333
334    fn config(&self) -> &Self::Config {
335        &self.config
336    }
337}
338
339impl Estimator for FeatureUnion<Trained> {
340    type Config = FeatureUnionConfig;
341    type Error = SklearsError;
342    type Float = Float;
343
344    fn config(&self) -> &Self::Config {
345        &self.config
346    }
347}
348
349impl Fit<Array2<Float>, ()> for FeatureUnion<Untrained> {
350    type Fitted = FeatureUnion<Trained>;
351
352    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
353        let (n_samples, n_features) = x.dim();
354
355        if n_samples == 0 {
356            return Err(SklearsError::InvalidInput(
357                "Cannot fit FeatureUnion on empty dataset".to_string(),
358            ));
359        }
360
361        if self.transformers.is_empty() {
362            return Err(SklearsError::InvalidInput(
363                "FeatureUnion requires at least one transformer".to_string(),
364            ));
365        }
366
367        // Fit each transformer and calculate output dimensions
368        let mut fitted_transformers = Vec::new();
369        let mut transformer_weights = Vec::new();
370        let mut all_transformed_data = Vec::new();
371
372        for step in &self.transformers {
373            // Fit and transform to get the fitted transformer and output shape
374            let transformed = step.transformer.fit_transform_wrapper(x)?;
375
376            // Validate that all transformers return the same number of samples
377            if transformed.nrows() != n_samples {
378                return Err(SklearsError::InvalidInput(format!(
379                    "Transformer '{}' returned {} samples, expected {}",
380                    step.name,
381                    transformed.nrows(),
382                    n_samples
383                )));
384            }
385
386            // Store the weight (default to 1.0 if not specified)
387            transformer_weights.push(step.weight.unwrap_or(1.0));
388
389            // Apply weight if specified
390            let mut weighted_transformed = transformed;
391            let weight = step.weight.unwrap_or(1.0);
392            if (weight - 1.0).abs() > Float::EPSILON {
393                weighted_transformed *= weight;
394            }
395
396            all_transformed_data.push(weighted_transformed);
397
398            // Create fitted transformer step (clone for now)
399            fitted_transformers.push(FeatureUnionStep {
400                name: step.name.clone(),
401                transformer: step.transformer.clone_box(),
402                weight: step.weight,
403            });
404        }
405
406        // Concatenate all transformed data for feature selection
407        let concatenated_data = concatenate_features(all_transformed_data)?;
408
409        // Perform feature selection if enabled
410        let (selected_features, feature_importances, total_output_features) = if self
411            .config
412            .enable_feature_selection
413        {
414            let importances = self
415                .calculate_feature_importance(&concatenated_data, self.config.importance_method);
416            let selected = self.select_features(&importances, self.config.feature_selection);
417            let n_selected = selected.len();
418            (Some(selected), Some(importances), n_selected)
419        } else {
420            (None, None, concatenated_data.ncols())
421        };
422
423        Ok(FeatureUnion {
424            config: self.config,
425            transformers: self.transformers,
426            state: PhantomData,
427            fitted_transformers_: Some(fitted_transformers),
428            n_features_in_: Some(n_features),
429            n_features_out_: Some(total_output_features),
430            transformer_weights_: Some(transformer_weights),
431            selected_features_: selected_features,
432            feature_importances_: feature_importances,
433            feature_names_: None, // Could be set if feature names are provided
434        })
435    }
436}
437
438impl Transform<Array2<Float>, Array2<Float>> for FeatureUnion<Trained> {
439    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
440        let (n_samples, n_features) = x.dim();
441
442        if Some(n_features) != self.n_features_in_ {
443            return Err(SklearsError::FeatureMismatch {
444                expected: self.n_features_in_.unwrap_or(0),
445                actual: n_features,
446            });
447        }
448
449        let fitted_transformers = self.fitted_transformers_.as_ref().unwrap();
450        let transformer_weights = self.transformer_weights_.as_ref().unwrap();
451
452        if fitted_transformers.is_empty() {
453            return Err(SklearsError::InvalidInput(
454                "No fitted transformers available".to_string(),
455            ));
456        }
457
458        // Apply each transformer and collect results
459        let mut transformed_parts = Vec::new();
460
461        for (i, step) in fitted_transformers.iter().enumerate() {
462            // Transform the input
463            let mut transformed = step.transformer.transform_wrapper(x)?;
464
465            // Validate output shape
466            if transformed.nrows() != n_samples {
467                return Err(SklearsError::InvalidInput(format!(
468                    "Transformer '{}' returned {} samples, expected {}",
469                    step.name,
470                    transformed.nrows(),
471                    n_samples
472                )));
473            }
474
475            // Apply weight if specified
476            let weight = transformer_weights[i];
477            if (weight - 1.0).abs() > Float::EPSILON {
478                transformed *= weight;
479            }
480
481            transformed_parts.push(transformed);
482        }
483
484        // Concatenate all transformed parts along the feature axis
485        let concatenated = concatenate_features(transformed_parts)?;
486
487        // Apply feature selection if enabled
488        if let Some(ref selected_features) = self.selected_features_ {
489            if selected_features.is_empty() {
490                return Err(SklearsError::InvalidInput(
491                    "No features were selected during fitting".to_string(),
492                ));
493            }
494
495            // Select only the chosen features
496            let selected_data = concatenated.select(Axis(1), selected_features);
497            Ok(selected_data)
498        } else {
499            Ok(concatenated)
500        }
501    }
502}
503
504/// Helper function to concatenate arrays along the feature (column) axis
505fn concatenate_features(parts: Vec<Array2<Float>>) -> Result<Array2<Float>> {
506    if parts.is_empty() {
507        return Err(SklearsError::InvalidInput(
508            "No arrays to concatenate".to_string(),
509        ));
510    }
511
512    if parts.len() == 1 {
513        return Ok(parts.into_iter().next().unwrap());
514    }
515
516    // Calculate total columns
517    let total_cols: usize = parts.iter().map(|p| p.ncols()).sum();
518    let n_rows = parts[0].nrows();
519
520    // Create result array
521    let mut result = Array2::zeros((n_rows, total_cols));
522
523    // Copy each part into the result
524    let mut col_offset = 0;
525    for part in parts {
526        let part_cols = part.ncols();
527        result
528            .slice_mut(scirs2_core::ndarray::s![
529                ..,
530                col_offset..col_offset + part_cols
531            ])
532            .assign(&part);
533        col_offset += part_cols;
534    }
535
536    Ok(result)
537}
538
539impl FeatureUnion<Trained> {
540    /// Get the number of input features
541    pub fn n_features_in(&self) -> usize {
542        self.n_features_in_.unwrap()
543    }
544
545    /// Get the number of output features
546    pub fn n_features_out(&self) -> usize {
547        self.n_features_out_.unwrap()
548    }
549
550    /// Get the fitted transformers
551    pub fn get_transformers(&self) -> &Vec<FeatureUnionStep> {
552        self.fitted_transformers_.as_ref().unwrap()
553    }
554
555    /// Get the weights used for each transformer
556    pub fn get_weights(&self) -> &Vec<Float> {
557        self.transformer_weights_.as_ref().unwrap()
558    }
559
560    /// Get the selected feature indices (if feature selection is enabled)
561    pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
562        self.selected_features_.as_ref()
563    }
564
565    /// Get the feature importance scores (if feature selection is enabled)
566    pub fn get_feature_importances(&self) -> Option<&Array1<Float>> {
567        self.feature_importances_.as_ref()
568    }
569
570    /// Get the number of features that were selected
571    pub fn n_features_selected(&self) -> usize {
572        self.selected_features_
573            .as_ref()
574            .map(|features| features.len())
575            .unwrap_or_else(|| self.n_features_out())
576    }
577
578    /// Check if feature selection is enabled
579    pub fn is_feature_selection_enabled(&self) -> bool {
580        self.selected_features_.is_some()
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Mock transformer for testing: multiplies the input by `scale` and, when
    /// `output_features` is set, emits that many columns by cycling through the
    /// scaled input columns.
    #[derive(Debug, Clone)]
    struct MockTransformer {
        scale: Float,
        output_features: Option<usize>,
    }

    impl TransformerWrapper for MockTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            // The mock is stateless, so fitting is the same as transforming.
            self.transform_wrapper(x)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            let result = x * self.scale;

            // If output_features is specified, duplicate or reduce features
            if let Some(out_features) = self.output_features {
                let n_rows = result.nrows();
                let mut output = Array2::zeros((n_rows, out_features));

                for i in 0..out_features {
                    let source_col = i % result.ncols();
                    output.column_mut(i).assign(&result.column(source_col));
                }

                Ok(output)
            } else {
                Ok(result)
            }
        }

        fn get_n_features_out(&self) -> Option<usize> {
            self.output_features
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_feature_union_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let fu = FeatureUnion::new()
            .add_transformer(
                "scale_by_2",
                MockTransformer {
                    scale: 2.0,
                    output_features: None,
                },
            )
            .add_transformer(
                "scale_by_3",
                MockTransformer {
                    scale: 3.0,
                    output_features: None,
                },
            );

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Should have 4 features: [original*2, original*3]
        assert_eq!(result.dim(), (3, 4));

        // First transformer output (scale by 2)
        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 2
        assert_eq!(result[[0, 1]], 4.0); // 2.0 * 2

        // Second transformer output (scale by 3)
        assert_eq!(result[[0, 2]], 3.0); // 1.0 * 3
        assert_eq!(result[[0, 3]], 6.0); // 2.0 * 3
    }

    #[test]
    fn test_feature_union_weighted() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new().add_weighted_transformer(
            "weighted",
            MockTransformer {
                scale: 1.0,
                output_features: None,
            },
            2.0,
        );

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Features should be scaled by weight (2.0)
        assert_eq!(result.dim(), (2, 2));
        assert_eq!(result[[0, 0]], 2.0); // 1.0 * 1.0 * 2.0
        assert_eq!(result[[0, 1]], 4.0); // 2.0 * 1.0 * 2.0
        assert_eq!(result[[1, 0]], 6.0); // 3.0 * 1.0 * 2.0
        assert_eq!(result[[1, 1]], 8.0); // 4.0 * 1.0 * 2.0
    }

    #[test]
    fn test_feature_union_different_output_sizes() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            ) // 2 features out
            .add_transformer(
                "expand",
                MockTransformer {
                    scale: 1.0,
                    output_features: Some(3),
                },
            ); // 3 features out

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Should have 5 features total (2 + 3)
        assert_eq!(result.dim(), (2, 5));
        assert_eq!(fitted_fu.n_features_out(), 5);
    }

    #[test]
    fn test_feature_union_empty_transformers() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new();

        let result = fu.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_feature_union_empty_data() {
        let x_empty: Array2<Float> = Array2::zeros((0, 2));

        let fu = FeatureUnion::new().add_transformer(
            "test",
            MockTransformer {
                scale: 1.0,
                output_features: None,
            },
        );

        let result = fu.fit(&x_empty, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_feature_union_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![
            [1.0, 2.0, 3.0], // Wrong number of features
            [4.0, 5.0, 6.0],
        ];

        let fu = FeatureUnion::new().add_transformer(
            "test",
            MockTransformer {
                scale: 1.0,
                output_features: None,
            },
        );

        let fitted_fu = fu.fit(&x_train, &()).unwrap();
        let result = fitted_fu.transform(&x_test);

        assert!(result.is_err());
        if let Err(SklearsError::FeatureMismatch { expected, actual }) = result {
            assert_eq!(expected, 2);
            assert_eq!(actual, 3);
        } else {
            panic!("Expected FeatureMismatch error");
        }
    }

    #[test]
    fn test_concatenate_features() {
        let part1 = array![[1.0, 2.0], [3.0, 4.0],];

        let part2 = array![[5.0], [6.0],];

        let part3 = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0],];

        let parts = vec![part1, part2, part3];
        let result = concatenate_features(parts).unwrap();

        assert_eq!(result.dim(), (2, 6)); // 2 + 1 + 3 columns

        // Check values
        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[0, 1]], 2.0);
        assert_eq!(result[[0, 2]], 5.0);
        assert_eq!(result[[0, 3]], 7.0);
        assert_eq!(result[[0, 4]], 8.0);
        assert_eq!(result[[0, 5]], 9.0);
    }

    #[test]
    fn test_feature_selection_variance_threshold() {
        let x = array![
            [1.0, 1.0, 1.0, 2.0], // Low variance features + high variance
            [1.1, 1.0, 1.0, 4.0],
            [0.9, 1.0, 1.0, 6.0],
            [1.0, 1.0, 1.0, 8.0],
        ];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .feature_selection(FeatureSelectionStrategy::VarianceThreshold(0.1))
            .importance_method(FeatureImportanceMethod::Variance);

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Population variances are roughly [0.005, 0, 0, 5], so only the last
        // column reaches the 0.1 threshold.
        assert!(fitted_fu.is_feature_selection_enabled());
        assert!(fitted_fu.get_feature_importances().is_some());
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![3]));
        assert_eq!(fitted_fu.n_features_selected(), 1);
        assert_eq!(fitted_fu.n_features_out(), 1);
        assert_eq!(result.dim(), (4, 1));
        assert_eq!(result.column(0).to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_feature_selection_top_k() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fu = FeatureUnion::new()
            .add_transformer(
                "scale_by_2",
                MockTransformer {
                    scale: 2.0,
                    output_features: None,
                },
            )
            .add_transformer(
                "scale_by_3",
                MockTransformer {
                    scale: 3.0,
                    output_features: None,
                },
            )
            .feature_selection(FeatureSelectionStrategy::TopK(2))
            .importance_method(FeatureImportanceMethod::L2Norm);

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Should select exactly 2 features
        assert_eq!(fitted_fu.n_features_selected(), 2);
        assert_eq!(result.ncols(), 2);
        assert!(fitted_fu.is_feature_selection_enabled());

        // The two columns scaled by 3 have the largest L2 norms; the second
        // input column outranks the first.
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![3, 2]));
        assert_eq!(result[[0, 0]], 6.0); // 2.0 * 3
        assert_eq!(result[[0, 1]], 3.0); // 1.0 * 3
    }

    #[test]
    fn test_feature_selection_top_percentile() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fu = FeatureUnion::new()
            .add_transformer(
                "expand",
                MockTransformer {
                    scale: 1.0,
                    output_features: Some(6),
                },
            )
            .feature_selection(FeatureSelectionStrategy::TopPercentile(50.0))
            .importance_method(FeatureImportanceMethod::AbsoluteMean);

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Should select 50% of features (3 out of 6)
        assert_eq!(fitted_fu.n_features_selected(), 3);
        assert_eq!(result.ncols(), 3);
        assert!(fitted_fu.is_feature_selection_enabled());

        // Odd-indexed columns repeat the second input column (absolute mean 4),
        // which beats the first input column (absolute mean 3).
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![1, 3, 5]));
        assert_eq!(result[[0, 0]], 2.0);
    }

    #[test]
    fn test_feature_selection_disabled() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .feature_selection(FeatureSelectionStrategy::None);

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        // Should keep all features
        assert!(!fitted_fu.is_feature_selection_enabled());
        assert_eq!(fitted_fu.n_features_selected(), 2);
        assert_eq!(result.ncols(), 2);
        assert!(fitted_fu.get_feature_importances().is_none());
        assert!(fitted_fu.get_selected_features().is_none());
    }

    #[test]
    fn test_feature_importance_methods() {
        let x = array![[1.0, 0.0, 10.0], [2.0, 0.0, 20.0], [3.0, 0.0, 30.0],];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .enable_feature_selection(true)
            .importance_method(FeatureImportanceMethod::Variance);

        let fitted_fu = fu.fit(&x, &()).unwrap();
        let importances = fitted_fu.get_feature_importances().unwrap();

        // Third column should have highest variance
        assert!(importances[2] > importances[0]);
        assert!(importances[0] > importances[1]); // Second column has zero variance

        // Selection is enabled with the `None` strategy, so every feature is kept.
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![0, 1, 2]));
    }

    #[test]
    fn test_get_methods() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new()
            .add_weighted_transformer(
                "test1",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
                2.0,
            )
            .add_transformer(
                "test2",
                MockTransformer {
                    scale: 1.0,
                    output_features: Some(3),
                },
            );

        let fitted_fu = fu.fit(&x, &()).unwrap();

        assert_eq!(fitted_fu.n_features_in(), 2);
        assert_eq!(fitted_fu.n_features_out(), 5); // 2 + 3
        assert_eq!(fitted_fu.get_transformers().len(), 2);
        assert_eq!(fitted_fu.get_weights(), &vec![2.0, 1.0]);
    }
}