1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14pub use crate::column_transformer::TransformerWrapper;
16
17#[derive(Debug, Clone, Copy, PartialEq)]
19pub enum FeatureSelectionStrategy {
20 None,
22 VarianceThreshold(Float),
24 TopK(usize),
26 ImportanceThreshold(Float),
28 TopPercentile(Float),
30}
31
32impl Default for FeatureSelectionStrategy {
33 fn default() -> Self {
34 Self::None
35 }
36}
37
/// How a `FeatureUnion` scores each concatenated output column for feature
/// selection. Higher scores mean more important columns.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FeatureImportanceMethod {
    /// Population variance of the column (the default).
    #[default]
    Variance,
    /// Mean of the absolute values in the column.
    AbsoluteMean,
    /// Sum of the absolute values in the column (L1 norm).
    L1Norm,
    /// Euclidean (L2) norm of the column.
    L2Norm,
    /// Absolute sample covariance between the column and a cheap first-component
    /// proxy (the scaled sum of all centered columns). This is a heuristic, not a
    /// full principal component analysis.
    PrincipalComponent,
}
58
59#[derive(Debug, Clone)]
61pub struct FeatureUnionStep {
62 pub name: String,
64 pub transformer: Box<dyn TransformerWrapper>,
66 pub weight: Option<Float>,
68}
69
70#[derive(Debug, Clone)]
72pub struct FeatureUnionConfig {
73 pub n_jobs: Option<usize>,
75 pub validate_input: bool,
77 pub preserve_order: bool,
79 pub feature_selection: FeatureSelectionStrategy,
81 pub importance_method: FeatureImportanceMethod,
83 pub enable_feature_selection: bool,
85}
86
87impl Default for FeatureUnionConfig {
88 fn default() -> Self {
89 Self {
90 n_jobs: None,
91 validate_input: true,
92 preserve_order: true,
93 feature_selection: FeatureSelectionStrategy::None,
94 importance_method: FeatureImportanceMethod::Variance,
95 enable_feature_selection: false,
96 }
97 }
98}
99
100#[derive(Debug)]
109pub struct FeatureUnion<State = Untrained> {
110 config: FeatureUnionConfig,
111 transformers: Vec<FeatureUnionStep>,
112 state: PhantomData<State>,
113 fitted_transformers_: Option<Vec<FeatureUnionStep>>,
115 n_features_in_: Option<usize>,
116 n_features_out_: Option<usize>,
117 transformer_weights_: Option<Vec<Float>>,
118 selected_features_: Option<Vec<usize>>,
120 feature_importances_: Option<Array1<Float>>,
121 feature_names_: Option<Vec<String>>,
122}
123
124impl FeatureUnion<Untrained> {
125 pub fn new() -> Self {
127 Self {
128 config: FeatureUnionConfig::default(),
129 transformers: Vec::new(),
130 state: PhantomData,
131 fitted_transformers_: None,
132 n_features_in_: None,
133 n_features_out_: None,
134 transformer_weights_: None,
135 selected_features_: None,
136 feature_importances_: None,
137 feature_names_: None,
138 }
139 }
140
141 pub fn add_transformer<T>(mut self, name: &str, transformer: T) -> Self
143 where
144 T: TransformerWrapper + 'static,
145 {
146 self.transformers.push(FeatureUnionStep {
147 name: name.to_string(),
148 transformer: Box::new(transformer),
149 weight: None,
150 });
151 self
152 }
153
154 pub fn add_weighted_transformer<T>(mut self, name: &str, transformer: T, weight: Float) -> Self
156 where
157 T: TransformerWrapper + 'static,
158 {
159 self.transformers.push(FeatureUnionStep {
160 name: name.to_string(),
161 transformer: Box::new(transformer),
162 weight: Some(weight),
163 });
164 self
165 }
166
167 pub fn n_jobs(mut self, n_jobs: Option<usize>) -> Self {
169 self.config.n_jobs = n_jobs;
170 self
171 }
172
173 pub fn validate_input(mut self, validate: bool) -> Self {
175 self.config.validate_input = validate;
176 self
177 }
178
179 pub fn preserve_order(mut self, preserve: bool) -> Self {
181 self.config.preserve_order = preserve;
182 self
183 }
184
185 pub fn feature_selection(mut self, strategy: FeatureSelectionStrategy) -> Self {
187 self.config.feature_selection = strategy;
188 self.config.enable_feature_selection = !matches!(strategy, FeatureSelectionStrategy::None);
189 self
190 }
191
192 pub fn importance_method(mut self, method: FeatureImportanceMethod) -> Self {
194 self.config.importance_method = method;
195 self
196 }
197
198 pub fn enable_feature_selection(mut self, enable: bool) -> Self {
200 self.config.enable_feature_selection = enable;
201 self
202 }
203
204 fn calculate_feature_importance(
206 &self,
207 data: &Array2<Float>,
208 method: FeatureImportanceMethod,
209 ) -> Array1<Float> {
210 let n_features = data.ncols();
211 let mut importances = Array1::zeros(n_features);
212
213 match method {
214 FeatureImportanceMethod::Variance => {
215 for (i, col) in data.columns().into_iter().enumerate() {
216 let mean = col.mean().unwrap_or(0.0);
217 let variance = col.iter().map(|&x| (x - mean).powi(2)).sum::<Float>()
218 / (col.len() as Float);
219 importances[i] = variance;
220 }
221 }
222 FeatureImportanceMethod::AbsoluteMean => {
223 for (i, col) in data.columns().into_iter().enumerate() {
224 let abs_mean =
225 col.iter().map(|&x| x.abs()).sum::<Float>() / (col.len() as Float);
226 importances[i] = abs_mean;
227 }
228 }
229 FeatureImportanceMethod::L1Norm => {
230 for (i, col) in data.columns().into_iter().enumerate() {
231 let l1_norm = col.iter().map(|&x| x.abs()).sum::<Float>();
232 importances[i] = l1_norm;
233 }
234 }
235 FeatureImportanceMethod::L2Norm => {
236 for (i, col) in data.columns().into_iter().enumerate() {
237 let l2_norm = col.iter().map(|&x| x * x).sum::<Float>().sqrt();
238 importances[i] = l2_norm;
239 }
240 }
241 FeatureImportanceMethod::PrincipalComponent => {
242 let means: Vec<Float> = (0..n_features)
244 .map(|i| data.column(i).mean().unwrap_or(0.0))
245 .collect();
246
247 let mut pc1: Array1<Float> = Array1::zeros(data.nrows());
249 for (i, col) in data.columns().into_iter().enumerate() {
250 let mean = means[i];
251 for (row_idx, &val) in col.iter().enumerate() {
252 pc1[row_idx] += (val - mean) / (n_features as Float).sqrt();
253 }
254 }
255
256 for (i, col) in data.columns().into_iter().enumerate() {
258 let mean = means[i];
259 let centered_col: Vec<Float> = col.iter().map(|&x| x - mean).collect();
260 let correlation: Float = centered_col
261 .iter()
262 .zip(pc1.iter())
263 .map(|(&x, &y): (&Float, &Float)| x * y)
264 .sum::<Float>()
265 / ((data.nrows() - 1) as Float);
266 importances[i] = correlation.abs();
267 }
268 }
269 }
270
271 importances
272 }
273
274 fn select_features(
276 &self,
277 importances: &Array1<Float>,
278 strategy: FeatureSelectionStrategy,
279 ) -> Vec<usize> {
280 let n_features = importances.len();
281 let mut feature_indices: Vec<(usize, Float)> = importances
282 .iter()
283 .enumerate()
284 .map(|(i, &score)| (i, score))
285 .collect();
286
287 match strategy {
288 FeatureSelectionStrategy::None => (0..n_features).collect(),
289 FeatureSelectionStrategy::VarianceThreshold(threshold) => feature_indices
290 .into_iter()
291 .filter_map(|(idx, score)| if score >= threshold { Some(idx) } else { None })
292 .collect(),
293 FeatureSelectionStrategy::TopK(k) => {
294 feature_indices
295 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
296 feature_indices
297 .into_iter()
298 .take(k.min(n_features))
299 .map(|(idx, _)| idx)
300 .collect()
301 }
302 FeatureSelectionStrategy::ImportanceThreshold(threshold) => feature_indices
303 .into_iter()
304 .filter_map(|(idx, score)| if score >= threshold { Some(idx) } else { None })
305 .collect(),
306 FeatureSelectionStrategy::TopPercentile(percentile) => {
307 if percentile <= 0.0 || percentile > 100.0 {
308 return (0..n_features).collect();
309 }
310 feature_indices
311 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
312 let k = ((n_features as Float * percentile / 100.0).ceil() as usize).max(1);
313 feature_indices
314 .into_iter()
315 .take(k)
316 .map(|(idx, _)| idx)
317 .collect()
318 }
319 }
320 }
321}
322
323impl Default for FeatureUnion<Untrained> {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329impl Estimator for FeatureUnion<Untrained> {
330 type Config = FeatureUnionConfig;
331 type Error = SklearsError;
332 type Float = Float;
333
334 fn config(&self) -> &Self::Config {
335 &self.config
336 }
337}
338
339impl Estimator for FeatureUnion<Trained> {
340 type Config = FeatureUnionConfig;
341 type Error = SklearsError;
342 type Float = Float;
343
344 fn config(&self) -> &Self::Config {
345 &self.config
346 }
347}
348
349impl Fit<Array2<Float>, ()> for FeatureUnion<Untrained> {
350 type Fitted = FeatureUnion<Trained>;
351
352 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
353 let (n_samples, n_features) = x.dim();
354
355 if n_samples == 0 {
356 return Err(SklearsError::InvalidInput(
357 "Cannot fit FeatureUnion on empty dataset".to_string(),
358 ));
359 }
360
361 if self.transformers.is_empty() {
362 return Err(SklearsError::InvalidInput(
363 "FeatureUnion requires at least one transformer".to_string(),
364 ));
365 }
366
367 let mut fitted_transformers = Vec::new();
369 let mut transformer_weights = Vec::new();
370 let mut all_transformed_data = Vec::new();
371
372 for step in &self.transformers {
373 let transformed = step.transformer.fit_transform_wrapper(x)?;
375
376 if transformed.nrows() != n_samples {
378 return Err(SklearsError::InvalidInput(format!(
379 "Transformer '{}' returned {} samples, expected {}",
380 step.name,
381 transformed.nrows(),
382 n_samples
383 )));
384 }
385
386 transformer_weights.push(step.weight.unwrap_or(1.0));
388
389 let mut weighted_transformed = transformed;
391 let weight = step.weight.unwrap_or(1.0);
392 if (weight - 1.0).abs() > Float::EPSILON {
393 weighted_transformed *= weight;
394 }
395
396 all_transformed_data.push(weighted_transformed);
397
398 fitted_transformers.push(FeatureUnionStep {
400 name: step.name.clone(),
401 transformer: step.transformer.clone_box(),
402 weight: step.weight,
403 });
404 }
405
406 let concatenated_data = concatenate_features(all_transformed_data)?;
408
409 let (selected_features, feature_importances, total_output_features) = if self
411 .config
412 .enable_feature_selection
413 {
414 let importances = self
415 .calculate_feature_importance(&concatenated_data, self.config.importance_method);
416 let selected = self.select_features(&importances, self.config.feature_selection);
417 let n_selected = selected.len();
418 (Some(selected), Some(importances), n_selected)
419 } else {
420 (None, None, concatenated_data.ncols())
421 };
422
423 Ok(FeatureUnion {
424 config: self.config,
425 transformers: self.transformers,
426 state: PhantomData,
427 fitted_transformers_: Some(fitted_transformers),
428 n_features_in_: Some(n_features),
429 n_features_out_: Some(total_output_features),
430 transformer_weights_: Some(transformer_weights),
431 selected_features_: selected_features,
432 feature_importances_: feature_importances,
433 feature_names_: None, })
435 }
436}
437
438impl Transform<Array2<Float>, Array2<Float>> for FeatureUnion<Trained> {
439 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
440 let (n_samples, n_features) = x.dim();
441
442 if Some(n_features) != self.n_features_in_ {
443 return Err(SklearsError::FeatureMismatch {
444 expected: self.n_features_in_.unwrap_or(0),
445 actual: n_features,
446 });
447 }
448
449 let fitted_transformers = self
450 .fitted_transformers_
451 .as_ref()
452 .expect("operation should succeed");
453 let transformer_weights = self
454 .transformer_weights_
455 .as_ref()
456 .expect("operation should succeed");
457
458 if fitted_transformers.is_empty() {
459 return Err(SklearsError::InvalidInput(
460 "No fitted transformers available".to_string(),
461 ));
462 }
463
464 let mut transformed_parts = Vec::new();
466
467 for (i, step) in fitted_transformers.iter().enumerate() {
468 let mut transformed = step.transformer.transform_wrapper(x)?;
470
471 if transformed.nrows() != n_samples {
473 return Err(SklearsError::InvalidInput(format!(
474 "Transformer '{}' returned {} samples, expected {}",
475 step.name,
476 transformed.nrows(),
477 n_samples
478 )));
479 }
480
481 let weight = transformer_weights[i];
483 if (weight - 1.0).abs() > Float::EPSILON {
484 transformed *= weight;
485 }
486
487 transformed_parts.push(transformed);
488 }
489
490 let concatenated = concatenate_features(transformed_parts)?;
492
493 if let Some(ref selected_features) = self.selected_features_ {
495 if selected_features.is_empty() {
496 return Err(SklearsError::InvalidInput(
497 "No features were selected during fitting".to_string(),
498 ));
499 }
500
501 let selected_data = concatenated.select(Axis(1), selected_features);
503 Ok(selected_data)
504 } else {
505 Ok(concatenated)
506 }
507 }
508}
509
510fn concatenate_features(parts: Vec<Array2<Float>>) -> Result<Array2<Float>> {
512 if parts.is_empty() {
513 return Err(SklearsError::InvalidInput(
514 "No arrays to concatenate".to_string(),
515 ));
516 }
517
518 if parts.len() == 1 {
519 return Ok(parts.into_iter().next().expect("operation should succeed"));
520 }
521
522 let total_cols: usize = parts.iter().map(|p| p.ncols()).sum();
524 let n_rows = parts[0].nrows();
525
526 let mut result = Array2::zeros((n_rows, total_cols));
528
529 let mut col_offset = 0;
531 for part in parts {
532 let part_cols = part.ncols();
533 result
534 .slice_mut(scirs2_core::ndarray::s![
535 ..,
536 col_offset..col_offset + part_cols
537 ])
538 .assign(&part);
539 col_offset += part_cols;
540 }
541
542 Ok(result)
543}
544
545impl FeatureUnion<Trained> {
546 pub fn n_features_in(&self) -> usize {
548 self.n_features_in_.expect("operation should succeed")
549 }
550
551 pub fn n_features_out(&self) -> usize {
553 self.n_features_out_.expect("operation should succeed")
554 }
555
556 pub fn get_transformers(&self) -> &Vec<FeatureUnionStep> {
558 self.fitted_transformers_
559 .as_ref()
560 .expect("operation should succeed")
561 }
562
563 pub fn get_weights(&self) -> &Vec<Float> {
565 self.transformer_weights_
566 .as_ref()
567 .expect("operation should succeed")
568 }
569
570 pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
572 self.selected_features_.as_ref()
573 }
574
575 pub fn get_feature_importances(&self) -> Option<&Array1<Float>> {
577 self.feature_importances_.as_ref()
578 }
579
580 pub fn n_features_selected(&self) -> usize {
582 self.selected_features_
583 .as_ref()
584 .map(|features| features.len())
585 .unwrap_or_else(|| self.n_features_out())
586 }
587
588 pub fn is_feature_selection_enabled(&self) -> bool {
590 self.selected_features_.is_some()
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Test double that scales its input. It can also widen the result to a
    /// fixed number of columns by cycling through the input columns.
    #[derive(Debug, Clone)]
    struct MockTransformer {
        scale: Float,
        output_features: Option<usize>,
    }

    impl TransformerWrapper for MockTransformer {
        fn fit_transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            self.transform_wrapper(x)
        }

        fn transform_wrapper(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
            let result = x * self.scale;

            if let Some(out_features) = self.output_features {
                let n_rows = result.nrows();
                let mut output = Array2::zeros((n_rows, out_features));

                for i in 0..out_features {
                    let source_col = i % result.ncols();
                    output.column_mut(i).assign(&result.column(source_col));
                }

                Ok(output)
            } else {
                Ok(result)
            }
        }

        fn get_n_features_out(&self) -> Option<usize> {
            self.output_features
        }

        fn clone_box(&self) -> Box<dyn TransformerWrapper> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_feature_union_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let fu = FeatureUnion::new()
            .add_transformer(
                "scale_by_2",
                MockTransformer {
                    scale: 2.0,
                    output_features: None,
                },
            )
            .add_transformer(
                "scale_by_3",
                MockTransformer {
                    scale: 3.0,
                    output_features: None,
                },
            );

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(result.dim(), (3, 4));

        // scale_by_2 output (2 columns) followed by scale_by_3 output (2 columns).
        assert_eq!(result[[0, 0]], 2.0);
        assert_eq!(result[[0, 1]], 4.0);
        assert_eq!(result[[0, 2]], 3.0);
        assert_eq!(result[[0, 3]], 6.0);
    }

    #[test]
    fn test_feature_union_weighted() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new().add_weighted_transformer(
            "weighted",
            MockTransformer {
                scale: 1.0,
                output_features: None,
            },
            2.0,
        );

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(result[[0, 0]], 2.0);
        assert_eq!(result[[0, 1]], 4.0);
    }

    #[test]
    fn test_feature_union_different_output_sizes() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .add_transformer(
                "expand",
                MockTransformer {
                    scale: 1.0,
                    output_features: Some(3),
                },
            );

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        // 2 identity columns + 3 expanded columns.
        assert_eq!(result.dim(), (2, 5));
        assert_eq!(fitted_fu.n_features_out(), 5);
    }

    #[test]
    fn test_feature_union_empty_transformers() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new();

        let result = fu.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_feature_union_empty_data() {
        let x_empty: Array2<Float> = Array2::zeros((0, 2));

        let fu = FeatureUnion::new().add_transformer(
            "test",
            MockTransformer {
                scale: 1.0,
                output_features: None,
            },
        );

        let result = fu.fit(&x_empty, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_feature_union_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],];

        let fu = FeatureUnion::new().add_transformer(
            "test",
            MockTransformer {
                scale: 1.0,
                output_features: None,
            },
        );

        let fitted_fu = fu.fit(&x_train, &()).expect("model fitting should succeed");
        let result = fitted_fu.transform(&x_test);

        assert!(result.is_err());
        if let Err(SklearsError::FeatureMismatch { expected, actual }) = result {
            assert_eq!(expected, 2);
            assert_eq!(actual, 3);
        } else {
            panic!("Expected FeatureMismatch error");
        }
    }

    #[test]
    fn test_concatenate_features() {
        let part1 = array![[1.0, 2.0], [3.0, 4.0],];

        let part2 = array![[5.0], [6.0],];

        let part3 = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0],];

        let parts = vec![part1, part2, part3];
        let result = concatenate_features(parts).expect("operation should succeed");

        assert_eq!(result.dim(), (2, 6));
        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[0, 1]], 2.0);
        assert_eq!(result[[0, 2]], 5.0);
        assert_eq!(result[[0, 3]], 7.0);
        assert_eq!(result[[0, 4]], 8.0);
        assert_eq!(result[[0, 5]], 9.0);
    }

    #[test]
    fn test_feature_selection_variance_threshold() {
        let x = array![
            [1.0, 1.0, 1.0, 2.0],
            [1.1, 1.0, 1.0, 4.0],
            [0.9, 1.0, 1.0, 6.0],
            [1.0, 1.0, 1.0, 8.0],
        ];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .feature_selection(FeatureSelectionStrategy::VarianceThreshold(0.1))
            .importance_method(FeatureImportanceMethod::Variance);

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        assert!(fitted_fu.is_feature_selection_enabled());
        // Variances: ~0.005, 0, 0, 5.0, so only the last column clears the 0.1 threshold.
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![3]));
        assert_eq!(fitted_fu.n_features_selected(), 1);
        assert_eq!(result.ncols(), 1);
        assert_eq!(result.column(0).to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
        assert!(fitted_fu.get_feature_importances().is_some());
    }

    #[test]
    fn test_feature_selection_top_k() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fu = FeatureUnion::new()
            .add_transformer(
                "scale_by_2",
                MockTransformer {
                    scale: 2.0,
                    output_features: None,
                },
            )
            .add_transformer(
                "scale_by_3",
                MockTransformer {
                    scale: 3.0,
                    output_features: None,
                },
            )
            .feature_selection(FeatureSelectionStrategy::TopK(2))
            .importance_method(FeatureImportanceMethod::L2Norm);

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(fitted_fu.n_features_selected(), 2);
        assert_eq!(result.ncols(), 2);
        assert!(fitted_fu.is_feature_selection_enabled());
        // Squared L2 norms: 140, 224, 315, 504, so columns 3 then 2 are kept in descending order.
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![3, 2]));
        assert_eq!(result[[0, 0]], 6.0);
        assert_eq!(result[[0, 1]], 3.0);
    }

    #[test]
    fn test_feature_selection_top_percentile() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fu = FeatureUnion::new()
            .add_transformer(
                "expand",
                MockTransformer {
                    scale: 1.0,
                    output_features: Some(6),
                },
            )
            .feature_selection(FeatureSelectionStrategy::TopPercentile(50.0))
            .importance_method(FeatureImportanceMethod::AbsoluteMean);

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        assert_eq!(fitted_fu.n_features_selected(), 3);
        assert_eq!(result.ncols(), 3);
        assert!(fitted_fu.is_feature_selection_enabled());
        // Odd columns copy input column 1 (mean 4), even ones column 0 (mean 3); ties keep index order.
        assert_eq!(fitted_fu.get_selected_features(), Some(&vec![1, 3, 5]));
    }

    #[test]
    fn test_feature_selection_disabled() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .feature_selection(FeatureSelectionStrategy::None);

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let result = fitted_fu
            .transform(&x)
            .expect("transformation should succeed");

        assert!(!fitted_fu.is_feature_selection_enabled());
        assert_eq!(fitted_fu.n_features_selected(), 2);
        assert_eq!(result.ncols(), 2);
        assert!(fitted_fu.get_feature_importances().is_none());
        assert!(fitted_fu.get_selected_features().is_none());
    }

    #[test]
    fn test_feature_importance_methods() {
        let x = array![[1.0, 0.0, 10.0], [2.0, 0.0, 20.0], [3.0, 0.0, 30.0],];

        let fu = FeatureUnion::new()
            .add_transformer(
                "identity",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
            )
            .enable_feature_selection(true)
            .importance_method(FeatureImportanceMethod::Variance);

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");
        let importances = fitted_fu
            .get_feature_importances()
            .expect("operation should succeed");

        // Variances are 200/3, 2/3 and 0 for columns 2, 0 and 1 respectively.
        assert!(importances[2] > importances[0]);
        assert!(importances[0] > importances[1]);
    }

    #[test]
    fn test_get_methods() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let fu = FeatureUnion::new()
            .add_weighted_transformer(
                "test1",
                MockTransformer {
                    scale: 1.0,
                    output_features: None,
                },
                2.0,
            )
            .add_transformer(
                "test2",
                MockTransformer {
                    scale: 1.0,
                    output_features: Some(3),
                },
            );

        let fitted_fu = fu.fit(&x, &()).expect("model fitting should succeed");

        assert_eq!(fitted_fu.n_features_in(), 2);
        assert_eq!(fitted_fu.n_features_out(), 5);
        assert_eq!(fitted_fu.get_transformers().len(), 2);
        assert_eq!(fitted_fu.get_weights(), &vec![2.0, 1.0]);
    }
}