1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14pub use crate::column_transformer::TransformerWrapper;
16
17#[derive(Debug, Clone, Copy, PartialEq)]
19pub enum FeatureSelectionStrategy {
20 None,
22 VarianceThreshold(Float),
24 TopK(usize),
26 ImportanceThreshold(Float),
28 TopPercentile(Float),
30}
31
32impl Default for FeatureSelectionStrategy {
33 fn default() -> Self {
34 Self::None
35 }
36}
37
/// Scoring rule used to assign one importance value to each output column.
///
/// The enum carries no data, so it also derives `Eq` and `Hash` and can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureImportanceMethod {
    /// Population variance of the column (sum of squared deviations divided
    /// by the number of rows).
    Variance,
    /// Mean of the absolute values in the column.
    AbsoluteMean,
    /// Sum of the absolute values in the column.
    L1Norm,
    /// Square root of the sum of squared values in the column.
    L2Norm,
    /// Absolute covariance between the centered column and the scaled sum of
    /// all centered columns, which the implementation uses as a stand-in for
    /// a first principal component.
    PrincipalComponent,
}

/// Variance is the scoring rule used when none is chosen explicitly.
impl Default for FeatureImportanceMethod {
    /// Returns [`FeatureImportanceMethod::Variance`].
    fn default() -> Self {
        Self::Variance
    }
}
58
59#[derive(Debug, Clone)]
61pub struct FeatureUnionStep {
62 pub name: String,
64 pub transformer: Box<dyn TransformerWrapper>,
66 pub weight: Option<Float>,
68}
69
70#[derive(Debug, Clone)]
72pub struct FeatureUnionConfig {
73 pub n_jobs: Option<usize>,
75 pub validate_input: bool,
77 pub preserve_order: bool,
79 pub feature_selection: FeatureSelectionStrategy,
81 pub importance_method: FeatureImportanceMethod,
83 pub enable_feature_selection: bool,
85}
86
87impl Default for FeatureUnionConfig {
88 fn default() -> Self {
89 Self {
90 n_jobs: None,
91 validate_input: true,
92 preserve_order: true,
93 feature_selection: FeatureSelectionStrategy::None,
94 importance_method: FeatureImportanceMethod::Variance,
95 enable_feature_selection: false,
96 }
97 }
98}
99
100#[derive(Debug)]
109pub struct FeatureUnion<State = Untrained> {
110 config: FeatureUnionConfig,
111 transformers: Vec<FeatureUnionStep>,
112 state: PhantomData<State>,
113 fitted_transformers_: Option<Vec<FeatureUnionStep>>,
115 n_features_in_: Option<usize>,
116 n_features_out_: Option<usize>,
117 transformer_weights_: Option<Vec<Float>>,
118 selected_features_: Option<Vec<usize>>,
120 feature_importances_: Option<Array1<Float>>,
121 feature_names_: Option<Vec<String>>,
122}
123
124impl FeatureUnion<Untrained> {
125 pub fn new() -> Self {
127 Self {
128 config: FeatureUnionConfig::default(),
129 transformers: Vec::new(),
130 state: PhantomData,
131 fitted_transformers_: None,
132 n_features_in_: None,
133 n_features_out_: None,
134 transformer_weights_: None,
135 selected_features_: None,
136 feature_importances_: None,
137 feature_names_: None,
138 }
139 }
140
141 pub fn add_transformer<T>(mut self, name: &str, transformer: T) -> Self
143 where
144 T: TransformerWrapper + 'static,
145 {
146 self.transformers.push(FeatureUnionStep {
147 name: name.to_string(),
148 transformer: Box::new(transformer),
149 weight: None,
150 });
151 self
152 }
153
154 pub fn add_weighted_transformer<T>(mut self, name: &str, transformer: T, weight: Float) -> Self
156 where
157 T: TransformerWrapper + 'static,
158 {
159 self.transformers.push(FeatureUnionStep {
160 name: name.to_string(),
161 transformer: Box::new(transformer),
162 weight: Some(weight),
163 });
164 self
165 }
166
167 pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
169 self.config.n_jobs = n_jobs;
170 self
171 }
172
173 pub fn validate_input(mut self, validate: bool) -> Self {
175 self.config.validate_input = validate;
176 self
177 }
178
179 pub fn preserve_order(mut self, preserve: bool) -> Self {
181 self.config.preserve_order = preserve;
182 self
183 }
184
185 pub fn feature_selection(mut self, strategy: FeatureSelectionStrategy) -> Self {
187 self.config.feature_selection = strategy;
188 self.config.enable_feature_selection = !matches!(strategy, FeatureSelectionStrategy::None);
189 self
190 }
191
192 pub fn importance_method(mut self, method: FeatureImportanceMethod) -> Self {
194 self.config.importance_method = method;
195 self
196 }
197
198 pub fn enable_feature_selection(mut self, enable: bool) -> Self {
200 self.config.enable_feature_selection = enable;
201 self
202 }
203
204 fn calculate_feature_importance(
206 &self,
207 data: &Array2<Float>,
208 method: FeatureImportanceMethod,
209 ) -> Array1<Float> {
210 let n_features = data.ncols();
211 let mut importances = Array1::zeros(n_features);
212
213 match method {
214 FeatureImportanceMethod::Variance => {
215 for (i, col) in data.columns().into_iter().enumerate() {
216 let mean = col.mean().unwrap_or(0.0);
217 let variance = col.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
218 / (col.len() as Float);
219 importances[i] = variance;
220 }
221 }
222 FeatureImportanceMethod::AbsoluteMean => {
223 for (i, col) in data.columns().into_iter().enumerate() {
224 let abs_mean =
225 col.iter().map(|&x| x.abs()).sum::<Float>() / (col.len() as Float);
226 importances[i] = abs_mean;
227 }
228 }
229 FeatureImportanceMethod::L1Norm => {
230 for (i, col) in data.columns().into_iter().enumerate() {
231 let l1_norm = col.iter().map(|&x| x.abs()).sum::<Float>();
232 importances[i] = l1_norm;
233 }
234 }
235 FeatureImportanceMethod::L2Norm => {
236 for (i, col) in data.columns().into_iter().enumerate() {
237 let l2_norm = col.iter().map(|&x| x * x).sum::<Float>().sqrt();
238 importances[i] = l2_norm;
239 }
240 }
241 FeatureImportanceMethod::PrincipalComponent => {
242 let means: Vec<Float> = (0..n_features)
244 .map(|i| data.column(i).mean().unwrap_or(0.0))
245 .collect();
246
247 let mut pc1: Array1<Float> = Array1::zeros(data.nrows());
249 for (i, col) in data.columns().into_iter().enumerate() {
250 let mean = means[i];
251 for (row_idx, &val) in col.iter().enumerate() {
252 pc1[row_idx] += (val - mean) / (n_features as Float).sqrt();
253 }
254 }
255
256 for (i, col) in data.columns().into_iter().enumerate() {
258 let mean = means[i];
259 let centered_col: Vec<Float> = col.iter().map(|&x| x - mean).collect();
260 let correlation: Float = centered_col
261 .iter()
262 .zip(pc1.iter())
263 .map(|(&x, &y): (&Float, &Float)| x * y)
264 .sum::<Float>()
265 / ((data.nrows() - 1) as Float);
266 importances[i] = correlation.abs();
267 }
268 }
269 }
270
271 importances
272 }
273
274 fn select_features(
276 &self,
277 importances: &Array1<Float>,
278 strategy: FeatureSelectionStrategy,
279 ) -> Vec<usize> {
280 let n_features = importances.len();
281 let mut feature_indices: Vec<(usize, Float)> = importances
282 .iter()
283 .enumerate()
284 .map(|(i, &score)| (i, score))
285 .collect();
286
287 match strategy {
288 FeatureSelectionStrategy::None => (0..n_features).collect(),
289 FeatureSelectionStrategy::VarianceThreshold(threshold) => feature_indices
290 .into_iter()
291 .filter_map(|(idx, score)| if score >= threshold { Some(idx) } else { None })
292 .collect(),
293 FeatureSelectionStrategy::TopK(k) => {
294 feature_indices
295 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
296 feature_indices
297 .into_iter()
298 .take(k.min(n_features))
299 .map(|(idx, _)| idx)
300 .collect()
301 }
302 FeatureSelectionStrategy::ImportanceThreshold(threshold) => feature_indices
303 .into_iter()
304 .filter_map(|(idx, score)| if score >= threshold { Some(idx) } else { None })
305 .collect(),
306 FeatureSelectionStrategy::TopPercentile(percentile) => {
307 if percentile <= 0.0 || percentile > 100.0 {
308 return (0..n_features).collect();
309 }
310 feature_indices
311 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
312 let k = ((n_features as Float * percentile / 100.0).ceil() as usize).max(1);
313 feature_indices
314 .into_iter()
315 .take(k)
316 .map(|(idx, _)| idx)
317 .collect()
318 }
319 }
320 }
321}
322
323impl Default for FeatureUnion<Untrained> {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329impl Estimator for FeatureUnion<Untrained> {
330 type Config = FeatureUnionConfig;
331 type Error = SklearsError;
332 type Float = Float;
333
334 fn config(&self) -> &Self::Config {
335 &self.config
336 }
337}
338
339impl Estimator for FeatureUnion<Trained> {
340 type Config = FeatureUnionConfig;
341 type Error = SklearsError;
342 type Float = Float;
343
344 fn config(&self) -> &Self::Config {
345 &self.config
346 }
347}
348
349impl Fit<Array2<Float>, ()> for FeatureUnion<Untrained> {
350 type Fitted = FeatureUnion<Trained>;
351
352 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
353 let (n_samples, n_features) = x.dim();
354
355 if n_samples == 0 {
356 return Err(SklearsError::InvalidInput(
357 "Cannot fit FeatureUnion on empty dataset".to_string(),
358 ));
359 }
360
361 if self.transformers.is_empty() {
362 return Err(SklearsError::InvalidInput(
363 "FeatureUnion requires at least one transformer".to_string(),
364 ));
365 }
366
367 let mut fitted_transformers = Vec::new();
369 let mut transformer_weights = Vec::new();
370 let mut all_transformed_data = Vec::new();
371
372 for step in &self.transformers {
373 let transformed = step.transformer.fit_transform_wrapper(x)?;
375
376 if transformed.nrows() != n_samples {
378 return Err(SklearsError::InvalidInput(format!(
379 "Transformer '{}' returned {} samples, expected {}",
380 step.name,
381 transformed.nrows(),
382 n_samples
383 )));
384 }
385
386 transformer_weights.push(step.weight.unwrap_or(1.0));
388
389 let mut weighted_transformed = transformed;
391 let weight = step.weight.unwrap_or(1.0);
392 if (weight - 1.0).abs() > Float::EPSILON {
393 weighted_transformed *= weight;
394 }
395
396 all_transformed_data.push(weighted_transformed);
397
398 fitted_transformers.push(FeatureUnionStep {
400 name: step.name.clone(),
401 transformer: step.transformer.clone_box(),
402 weight: step.weight,
403 });
404 }
405
406 let concatenated_data = concatenate_features(all_transformed_data)?;
408
409 let (selected_features, feature_importances, total_output_features) = if self
411 .config
412 .enable_feature_selection
413 {
414 let importances = self
415 .calculate_feature_importance(&concatenated_data, self.config.importance_method);
416 let selected = self.select_features(&importances, self.config.feature_selection);
417 let n_selected = selected.len();
418 (Some(selected), Some(importances), n_selected)
419 } else {
420 (None, None, concatenated_data.ncols())
421 };
422
423 Ok(FeatureUnion {
424 config: self.config,
425 transformers: self.transformers,
426 state: PhantomData,
427 fitted_transformers_: Some(fitted_transformers),
428 n_features_in_: Some(n_features),
429 n_features_out_: Some(total_output_features),
430 transformer_weights_: Some(transformer_weights),
431 selected_features_: selected_features,
432 feature_importances_: feature_importances,
433 feature_names_: None, })
435 }
436}
437
438impl Transform<Array2<Float>, Array2<Float>> for FeatureUnion<Trained> {
439 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
440 let (n_samples, n_features) = x.dim();
441
442 if Some(n_features) != self.n_features_in_ {
443 return Err(SklearsError::FeatureMismatch {
444 expected: self.n_features_in_.unwrap_or(0),
445 actual: n_features,
446 });
447 }
448
449 let fitted_transformers = self.fitted_transformers_.as_ref().unwrap();
450 let transformer_weights = self.transformer_weights_.as_ref().unwrap();
451
452 if fitted_transformers.is_empty() {
453 return Err(SklearsError::InvalidInput(
454 "No fitted transformers available".to_string(),
455 ));
456 }
457
458 let mut transformed_parts = Vec::new();
460
461 for (i, step) in fitted_transformers.iter().enumerate() {
462 let mut transformed = step.transformer.transform_wrapper(x)?;
464
465 if transformed.nrows() != n_samples {
467 return Err(SklearsError::InvalidInput(format!(
468 "Transformer '{}' returned {} samples, expected {}",
469 step.name,
470 transformed.nrows(),
471 n_samples
472 )));
473 }
474
475 let weight = transformer_weights[i];
477 if (weight - 1.0).abs() > Float::EPSILON {
478 transformed *= weight;
479 }
480
481 transformed_parts.push(transformed);
482 }
483
484 let concatenated = concatenate_features(transformed_parts)?;
486
487 if let Some(ref selected_features) = self.selected_features_ {
489 if selected_features.is_empty() {
490 return Err(SklearsError::InvalidInput(
491 "No features were selected during fitting".to_string(),
492 ));
493 }
494
495 let selected_data = concatenated.select(Axis(1), selected_features);
497 Ok(selected_data)
498 } else {
499 Ok(concatenated)
500 }
501 }
502}
503
504fn concatenate_features(parts: Vec<Array2<Float>>) -> Result<Array2<Float>> {
506 if parts.is_empty() {
507 return Err(SklearsError::InvalidInput(
508 "No arrays to concatenate".to_string(),
509 ));
510 }
511
512 if parts.len() == 1 {
513 return Ok(parts.into_iter().next().unwrap());
514 }
515
516 let total_cols: usize = parts.iter().map(|p| p.ncols()).sum();
518 let n_rows = parts[0].nrows();
519
520 let mut result = Array2::zeros((n_rows, total_cols));
522
523 let mut col_offset = 0;
525 for part in parts {
526 let part_cols = part.ncols();
527 result
528 .slice_mut(scirs2_core::ndarray::s![
529 ..,
530 col_offset..col_offset + part_cols
531 ])
532 .assign(&part);
533 col_offset += part_cols;
534 }
535
536 Ok(result)
537}
538
539impl FeatureUnion<Trained> {
540 pub fn n_features_in(&self) -> usize {
542 self.n_features_in_.unwrap()
543 }
544
545 pub fn n_features_out(&self) -> usize {
547 self.n_features_out_.unwrap()
548 }
549
550 pub fn get_transformers(&self) -> &Vec<FeatureUnionStep> {
552 self.fitted_transformers_.as_ref().unwrap()
553 }
554
555 pub fn get_weights(&self) -> &Vec<Float> {
557 self.transformer_weights_.as_ref().unwrap()
558 }
559
560 pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
562 self.selected_features_.as_ref()
563 }
564
565 pub fn get_feature_importances(&self) -> Option<&Array1<Float>> {
567 self.feature_importances_.as_ref()
568 }
569
570 pub fn n_features_selected(&self) -> usize {
572 self.selected_features_
573 .as_ref()
574 .map(|features| features.len())
575 .unwrap_or_else(|| self.n_features_out())
576 }
577
578 pub fn is_feature_selection_enabled(&self) -> bool {
580 self.selected_features_.is_some()
581 }
582}
583
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Test double that multiplies its input by `scale` and, when
    /// `output_features` is set, emits that many columns by cycling through
    /// the scaled input columns.
    #[derive(Debug, Clone)]
    struct MockTransformer {
        scale: Float,
        output_features: Option<usize>,
    }

    impl MockTransformer {
        /// Mock that keeps the input width.
        fn scaling(scale: Float) -> Self {
            MockTransformer {
                scale,
                output_features: None,
            }
        }

        /// Mock that emits exactly `width` columns.
        fn expanding(scale: Float, width: usize) -> Self {
            MockTransformer {
                scale,
                output_features: Some(width),
            }
        }
    }

    impl TransformerWrapper for MockTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.transform_wrapper(x)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            let scaled = x * self.scale;

            match self.output_features {
                None => Ok(scaled),
                Some(width) => {
                    let mut widened = Array2::zeros((scaled.nrows(), width));
                    for target in 0..width {
                        // Wrap around the available source columns.
                        let source = target % scaled.ncols();
                        widened.column_mut(target).assign(&scaled.column(source));
                    }
                    Ok(widened)
                }
            }
        }

        fn get_n_features_out(&self) -> Option<usize> {
            self.output_features
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    /// Two unweighted steps produce their outputs side by side.
    #[test]
    fn test_feature_union_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("scale_by_2", MockTransformer::scaling(2.0))
            .add_transformer("scale_by_3", MockTransformer::scaling(3.0))
            .fit(&x, &())
            .unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        assert_eq!(result.dim(), (3, 4));

        // First row: [1, 2] scaled by 2, then [1, 2] scaled by 3.
        assert_eq!(result[[0, 0]], 2.0);
        assert_eq!(result[[0, 1]], 4.0);
        assert_eq!(result[[0, 2]], 3.0);
        assert_eq!(result[[0, 3]], 6.0);
    }

    /// A step weight multiplies that step's output.
    #[test]
    fn test_feature_union_weighted() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted_fu = FeatureUnion::new()
            .add_weighted_transformer("weighted", MockTransformer::scaling(1.0), 2.0)
            .fit(&x, &())
            .unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        assert_eq!(result[[0, 0]], 2.0);
        assert_eq!(result[[0, 1]], 4.0);
    }

    /// Steps with different output widths are concatenated correctly.
    #[test]
    fn test_feature_union_different_output_sizes() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("identity", MockTransformer::scaling(1.0))
            .add_transformer("expand", MockTransformer::expanding(1.0, 3))
            .fit(&x, &())
            .unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        assert_eq!(result.dim(), (2, 5));
        assert_eq!(fitted_fu.n_features_out(), 5);
    }

    /// Fitting without any step is rejected.
    #[test]
    fn test_feature_union_empty_transformers() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let outcome = FeatureUnion::new().fit(&x, &());
        assert!(outcome.is_err());
    }

    /// Fitting on data without rows is rejected.
    #[test]
    fn test_feature_union_empty_data() {
        let x_empty: Array2<Float> = Array2::zeros((0, 2));

        let outcome = FeatureUnion::new()
            .add_transformer("test", MockTransformer::scaling(1.0))
            .fit(&x_empty, &());
        assert!(outcome.is_err());
    }

    /// Transforming data with a different column count reports both widths.
    #[test]
    fn test_feature_union_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("test", MockTransformer::scaling(1.0))
            .fit(&x_train, &())
            .unwrap();
        let result = fitted_fu.transform(&x_test);

        assert!(result.is_err());
        match result {
            Err(SklearsError::FeatureMismatch { expected, actual }) => {
                assert_eq!(expected, 2);
                assert_eq!(actual, 3);
            }
            _ => panic!("Expected FeatureMismatch error"),
        }
    }

    /// The helper lays parts out left to right.
    #[test]
    fn test_concatenate_features() {
        let parts = vec![
            array![[1.0, 2.0], [3.0, 4.0]],
            array![[5.0], [6.0]],
            array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
        ];
        let result = concatenate_features(parts).unwrap();

        assert_eq!(result.dim(), (2, 6));
        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[0, 1]], 2.0);
        assert_eq!(result[[0, 2]], 5.0);
        assert_eq!(result[[0, 3]], 7.0);
        assert_eq!(result[[0, 4]], 8.0);
        assert_eq!(result[[0, 5]], 9.0);
    }

    /// A variance threshold records importances and a selection.
    #[test]
    fn test_feature_selection_variance_threshold() {
        let x = array![
            [1.0, 1.0, 1.0, 2.0],
            [1.1, 1.0, 1.0, 4.0],
            [0.9, 1.0, 1.0, 6.0],
            [1.0, 1.0, 1.0, 8.0],
        ];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("identity", MockTransformer::scaling(1.0))
            .feature_selection(FeatureSelectionStrategy::VarianceThreshold(0.1))
            .importance_method(FeatureImportanceMethod::Variance)
            .fit(&x, &())
            .unwrap();
        let _result = fitted_fu.transform(&x).unwrap();

        assert!(fitted_fu.is_feature_selection_enabled());
        assert!(fitted_fu.n_features_selected() <= 4);
        assert!(fitted_fu.get_feature_importances().is_some());
        assert!(fitted_fu.get_selected_features().is_some());
    }

    /// Top-k selection keeps exactly `k` columns.
    #[test]
    fn test_feature_selection_top_k() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("scale_by_2", MockTransformer::scaling(2.0))
            .add_transformer("scale_by_3", MockTransformer::scaling(3.0))
            .feature_selection(FeatureSelectionStrategy::TopK(2))
            .importance_method(FeatureImportanceMethod::L2Norm)
            .fit(&x, &())
            .unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        assert_eq!(fitted_fu.n_features_selected(), 2);
        assert_eq!(result.ncols(), 2);
        assert!(fitted_fu.is_feature_selection_enabled());
    }

    /// Keeping the top 50% of six columns leaves three.
    #[test]
    fn test_feature_selection_top_percentile() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("expand", MockTransformer::expanding(1.0, 6))
            .feature_selection(FeatureSelectionStrategy::TopPercentile(50.0))
            .importance_method(FeatureImportanceMethod::AbsoluteMean)
            .fit(&x, &())
            .unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        assert_eq!(fitted_fu.n_features_selected(), 3);
        assert_eq!(result.ncols(), 3);
        assert!(fitted_fu.is_feature_selection_enabled());
    }

    /// The `None` strategy leaves selection switched off.
    #[test]
    fn test_feature_selection_disabled() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("identity", MockTransformer::scaling(1.0))
            .feature_selection(FeatureSelectionStrategy::None)
            .fit(&x, &())
            .unwrap();
        let result = fitted_fu.transform(&x).unwrap();

        assert!(!fitted_fu.is_feature_selection_enabled());
        assert_eq!(fitted_fu.n_features_selected(), 2);
        assert_eq!(result.ncols(), 2);
        assert!(fitted_fu.get_feature_importances().is_none());
        assert!(fitted_fu.get_selected_features().is_none());
    }

    /// Variance scores order the columns by their spread.
    #[test]
    fn test_feature_importance_methods() {
        let x = array![[1.0, 0.0, 10.0], [2.0, 0.0, 20.0], [3.0, 0.0, 30.0]];

        let fitted_fu = FeatureUnion::new()
            .add_transformer("identity", MockTransformer::scaling(1.0))
            .enable_feature_selection(true)
            .importance_method(FeatureImportanceMethod::Variance)
            .fit(&x, &())
            .unwrap();
        let importances = fitted_fu.get_feature_importances().unwrap();

        // Widest spread in column 2, none at all in column 1.
        assert!(importances[2] > importances[0]);
        assert!(importances[0] > importances[1]);
    }

    /// Accessors report what fitting recorded.
    #[test]
    fn test_get_methods() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted_fu = FeatureUnion::new()
            .add_weighted_transformer("test1", MockTransformer::scaling(1.0), 2.0)
            .add_transformer("test2", MockTransformer::expanding(1.0, 3))
            .fit(&x, &())
            .unwrap();

        assert_eq!(fitted_fu.n_features_in(), 2);
        assert_eq!(fitted_fu.n_features_out(), 5);
        assert_eq!(fitted_fu.get_transformers().len(), 2);
        assert_eq!(fitted_fu.get_weights(), &vec![2.0, 1.0]);
    }
}