1use crate::base::FeatureSelector;
8use scirs2_core::ndarray::{Array1, Array2, Axis};
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Trained, Transform, Untrained},
12 types::Float,
13};
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::marker::PhantomData;
16
/// A single feature's position inside a [`FeatureHierarchy`].
#[derive(Debug, Clone)]
pub struct HierarchyNode {
    /// Identifier of the feature this node represents; the hierarchy stores
    /// the node under this key.
    pub feature_id: usize,
    /// Identifier of the parent node, or `None` when the node is a root.
    pub parent: Option<usize>,
    /// Identifiers of the direct children, in the order they were added.
    pub children: Vec<usize>,
    /// Depth of the node: `0` for roots, the parent's level plus one otherwise.
    pub level: usize,
    /// Optional group label, used by the group queries and group-based selection.
    pub group_id: Option<usize>,
}
26
/// A forest of features: each node has at most one parent, and nodes added
/// without a parent are roots.
#[derive(Debug, Clone)]
pub struct FeatureHierarchy {
    /// All nodes, keyed by feature id.
    nodes: HashMap<usize, HierarchyNode>,
    /// Ids of the nodes that were added without a parent, in insertion order.
    root_nodes: Vec<usize>,
    /// Deepest level assigned to any node so far (`0` for an empty hierarchy).
    max_level: usize,
}
34
35impl FeatureHierarchy {
36 pub fn new() -> Self {
38 Self {
39 nodes: HashMap::new(),
40 root_nodes: Vec::new(),
41 max_level: 0,
42 }
43 }
44
45 pub fn add_node(
47 &mut self,
48 feature_id: usize,
49 parent: Option<usize>,
50 group_id: Option<usize>,
51 ) -> SklResult<()> {
52 let level = if let Some(parent_id) = parent {
53 if let Some(parent_node) = self.nodes.get(&parent_id) {
54 parent_node.level + 1
55 } else {
56 return Err(SklearsError::InvalidInput(format!(
57 "Parent node {} not found",
58 parent_id
59 )));
60 }
61 } else {
62 0
63 };
64
65 let node = HierarchyNode {
66 feature_id,
67 parent,
68 children: Vec::new(),
69 level,
70 group_id,
71 };
72
73 if let Some(parent_id) = parent {
75 if let Some(parent_node) = self.nodes.get_mut(&parent_id) {
76 parent_node.children.push(feature_id);
77 }
78 } else {
79 self.root_nodes.push(feature_id);
80 }
81
82 self.max_level = self.max_level.max(level);
83 self.nodes.insert(feature_id, node);
84 Ok(())
85 }
86
87 pub fn get_descendants(&self, feature_id: usize) -> Vec<usize> {
89 let mut descendants = Vec::new();
90 let mut queue = VecDeque::new();
91
92 if let Some(node) = self.nodes.get(&feature_id) {
93 queue.extend(&node.children);
94 }
95
96 while let Some(child_id) = queue.pop_front() {
97 descendants.push(child_id);
98 if let Some(child_node) = self.nodes.get(&child_id) {
99 queue.extend(&child_node.children);
100 }
101 }
102
103 descendants
104 }
105
106 pub fn get_ancestors(&self, feature_id: usize) -> Vec<usize> {
108 let mut ancestors = Vec::new();
109 let mut current_id = feature_id;
110
111 while let Some(node) = self.nodes.get(¤t_id) {
112 if let Some(parent_id) = node.parent {
113 ancestors.push(parent_id);
114 current_id = parent_id;
115 } else {
116 break;
117 }
118 }
119
120 ancestors
121 }
122
123 pub fn get_features_at_level(&self, level: usize) -> Vec<usize> {
125 let mut features: Vec<usize> = self
126 .nodes
127 .values()
128 .filter(|node| node.level == level)
129 .map(|node| node.feature_id)
130 .collect();
131 features.sort();
132 features
133 }
134
135 pub fn get_features_in_group(&self, group_id: usize) -> Vec<usize> {
137 let mut features: Vec<usize> = self
138 .nodes
139 .values()
140 .filter(|node| node.group_id == Some(group_id))
141 .map(|node| node.feature_id)
142 .collect();
143 features.sort();
144 features
145 }
146
147 pub fn is_leaf(&self, feature_id: usize) -> bool {
149 self.nodes
150 .get(&feature_id)
151 .map(|node| node.children.is_empty())
152 .unwrap_or(false)
153 }
154
155 pub fn get_leaf_nodes(&self) -> Vec<usize> {
157 self.nodes
158 .values()
159 .filter(|node| node.children.is_empty())
160 .map(|node| node.feature_id)
161 .collect()
162 }
163}
164
165impl Default for FeatureHierarchy {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
/// Feature selector that picks up to `k` features while taking a
/// [`FeatureHierarchy`] into account.
///
/// The `State` parameter is a compile-time marker: the selector starts as
/// `Untrained` and `fit` produces the `Trained` variant.
#[derive(Debug, Clone)]
pub struct HierarchicalFeatureSelector<State = Untrained> {
    /// Hierarchy describing how the feature (column) ids relate to each other.
    hierarchy: FeatureHierarchy,
    /// Maximum number of features to select.
    k: usize,
    /// How the hierarchy is walked when choosing features.
    selection_strategy: HierarchicalSelectionStrategy,
    /// How child scores are combined by the bottom-up strategy.
    score_aggregation: ScoreAggregation,

    /// Selected feature ids; `None` until `fit` has run.
    selected_features_: Option<Vec<usize>>,
    /// Per-column score computed during `fit`; `None` until `fit` has run.
    feature_scores_: Option<HashMap<usize, Float>>,

    /// Zero-sized marker carrying the `State` type parameter.
    state: PhantomData<State>,
}
188
/// Strategy used by [`HierarchicalFeatureSelector`] to pick features.
#[derive(Debug, Clone)]
pub enum HierarchicalSelectionStrategy {
    /// Start from the root nodes, take the best-scoring candidates, then make
    /// the children of every selected node the next candidates.
    TopDown,
    /// Walk the levels from deepest to shallowest, adding the aggregated
    /// scores of each node's children to the node's own score, then take the
    /// features with the highest resulting scores.
    BottomUp,
    /// Split `k` across the hierarchy levels (shallower levels receive the
    /// remainder) and take the best-scoring features of each level.
    LevelWise,
    /// Split `k` across the group ids present in the hierarchy and take the
    /// best-scoring features of each group; without any group this is a plain
    /// top-`k` by score.
    GroupBased,
}
201
/// How a list of child scores is reduced to a single value.
#[derive(Debug, Clone)]
pub enum ScoreAggregation {
    /// Sum of the scores.
    Sum,
    /// Largest of the scores.
    Max,
    /// Sum of the scores divided by their count.
    // NOTE(review): despite the name, the implementation in this file applies
    // no weights — it is a plain arithmetic mean.
    WeightedAverage,
    /// Product of the scores.
    Product,
}
214
215impl HierarchicalFeatureSelector<Untrained> {
216 pub fn new(hierarchy: FeatureHierarchy, k: usize) -> Self {
218 Self {
219 hierarchy,
220 k,
221 selection_strategy: HierarchicalSelectionStrategy::TopDown,
222 score_aggregation: ScoreAggregation::Sum,
223 selected_features_: None,
224 feature_scores_: None,
225 state: PhantomData,
226 }
227 }
228
229 pub fn selection_strategy(mut self, strategy: HierarchicalSelectionStrategy) -> Self {
231 self.selection_strategy = strategy;
232 self
233 }
234
235 pub fn score_aggregation(mut self, aggregation: ScoreAggregation) -> Self {
237 self.score_aggregation = aggregation;
238 self
239 }
240}
241
242impl Estimator for HierarchicalFeatureSelector<Untrained> {
243 type Config = ();
244 type Error = SklearsError;
245 type Float = f64;
246
247 fn config(&self) -> &Self::Config {
248 &()
249 }
250}
251
252impl Fit<Array2<Float>, Array1<Float>> for HierarchicalFeatureSelector<Untrained> {
253 type Fitted = HierarchicalFeatureSelector<Trained>;
254
255 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
256 let (n_samples, n_features) = x.dim();
257 if n_samples == 0 || n_features == 0 {
258 return Err(SklearsError::InvalidInput(
259 "Input data cannot be empty".to_string(),
260 ));
261 }
262
263 if self.k > n_features {
264 return Err(SklearsError::InvalidInput(
265 "k cannot be larger than number of features".to_string(),
266 ));
267 }
268
269 let mut feature_scores = HashMap::new();
271 for feature_idx in 0..n_features {
272 let feature_col = x.column(feature_idx);
273 let score = compute_f_score(&feature_col.to_owned(), y);
274 feature_scores.insert(feature_idx, score);
275 }
276
277 let selected_features = match self.selection_strategy {
279 HierarchicalSelectionStrategy::TopDown => self.select_top_down(&feature_scores)?,
280 HierarchicalSelectionStrategy::BottomUp => self.select_bottom_up(&feature_scores)?,
281 HierarchicalSelectionStrategy::LevelWise => self.select_level_wise(&feature_scores)?,
282 HierarchicalSelectionStrategy::GroupBased => {
283 self.select_group_based(&feature_scores)?
284 }
285 };
286
287 Ok(HierarchicalFeatureSelector {
288 hierarchy: self.hierarchy,
289 k: self.k,
290 selection_strategy: self.selection_strategy,
291 score_aggregation: self.score_aggregation,
292 selected_features_: Some(selected_features),
293 feature_scores_: Some(feature_scores),
294 state: PhantomData,
295 })
296 }
297}
298
299impl HierarchicalFeatureSelector<Untrained> {
300 fn select_top_down(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
302 let mut selected = HashSet::new();
303 let mut candidates = VecDeque::new();
304
305 candidates.extend(&self.hierarchy.root_nodes);
307
308 while !candidates.is_empty() && selected.len() < self.k {
309 let mut level_scores: Vec<(usize, Float)> = candidates
310 .iter()
311 .filter_map(|&feature_id| {
312 feature_scores
313 .get(&feature_id)
314 .map(|&score| (feature_id, score))
315 })
316 .collect();
317
318 level_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
320
321 let mut next_candidates: VecDeque<usize> = VecDeque::new();
323 for (feature_id, _) in level_scores {
324 if selected.len() >= self.k {
325 break;
326 }
327
328 selected.insert(feature_id);
329 candidates.retain(|&x| x != feature_id);
330
331 if let Some(node) = self.hierarchy.nodes.get(&feature_id) {
333 next_candidates.extend(&node.children);
334 }
335 }
336
337 candidates.extend(next_candidates);
338 }
339
340 Ok(selected.into_iter().collect())
341 }
342
343 fn select_bottom_up(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
345 let mut aggregated_scores = feature_scores.clone();
346
347 for level in (0..=self.hierarchy.max_level).rev() {
349 let level_features = self.hierarchy.get_features_at_level(level);
350
351 for feature_id in level_features {
352 if let Some(node) = self.hierarchy.nodes.get(&feature_id) {
353 if !node.children.is_empty() {
354 let child_scores: Vec<Float> = node
356 .children
357 .iter()
358 .filter_map(|&child_id| aggregated_scores.get(&child_id))
359 .cloned()
360 .collect();
361
362 if !child_scores.is_empty() {
363 let aggregated = self.aggregate_scores(&child_scores);
364 let current_score =
365 aggregated_scores.get(&feature_id).cloned().unwrap_or(0.0);
366 aggregated_scores.insert(feature_id, current_score + aggregated);
367 }
368 }
369 }
370 }
371 }
372
373 let mut scored_features: Vec<(usize, Float)> = aggregated_scores.into_iter().collect();
375 scored_features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
376
377 Ok(scored_features
378 .into_iter()
379 .take(self.k)
380 .map(|(feature_id, _)| feature_id)
381 .collect())
382 }
383
384 fn select_level_wise(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
386 let mut selected = Vec::new();
387 let features_per_level = self.k / (self.hierarchy.max_level + 1);
388 let remaining = self.k % (self.hierarchy.max_level + 1);
389
390 for level in 0..=self.hierarchy.max_level {
391 let level_features = self.hierarchy.get_features_at_level(level);
392 let mut level_scores: Vec<(usize, Float)> = level_features
393 .into_iter()
394 .filter_map(|feature_id| {
395 feature_scores
396 .get(&feature_id)
397 .map(|&score| (feature_id, score))
398 })
399 .collect();
400
401 level_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
402
403 let k_for_level = if level < remaining {
404 features_per_level + 1
405 } else {
406 features_per_level
407 };
408
409 selected.extend(
410 level_scores
411 .into_iter()
412 .take(k_for_level)
413 .map(|(feature_id, _)| feature_id),
414 );
415 }
416
417 Ok(selected)
418 }
419
420 fn select_group_based(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
422 let mut groups: HashSet<usize> = HashSet::new();
424 for node in self.hierarchy.nodes.values() {
425 if let Some(group_id) = node.group_id {
426 groups.insert(group_id);
427 }
428 }
429
430 if groups.is_empty() {
431 let mut scored_features: Vec<(usize, Float)> = feature_scores
433 .iter()
434 .map(|(&feature_id, &score)| (feature_id, score))
435 .collect();
436 scored_features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
437
438 return Ok(scored_features
439 .into_iter()
440 .take(self.k)
441 .map(|(feature_id, _)| feature_id)
442 .collect());
443 }
444
445 let features_per_group = self.k / groups.len();
446 let remaining = self.k % groups.len();
447 let mut selected = Vec::new();
448
449 for (group_idx, group_id) in groups.into_iter().enumerate() {
450 let group_features = self.hierarchy.get_features_in_group(group_id);
451 let mut group_scores: Vec<(usize, Float)> = group_features
452 .into_iter()
453 .filter_map(|feature_id| {
454 feature_scores
455 .get(&feature_id)
456 .map(|&score| (feature_id, score))
457 })
458 .collect();
459
460 group_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
461
462 let k_for_group = if group_idx < remaining {
463 features_per_group + 1
464 } else {
465 features_per_group
466 };
467
468 selected.extend(
469 group_scores
470 .into_iter()
471 .take(k_for_group)
472 .map(|(feature_id, _)| feature_id),
473 );
474 }
475
476 Ok(selected)
477 }
478
479 fn aggregate_scores(&self, scores: &[Float]) -> Float {
481 if scores.is_empty() {
482 return 0.0;
483 }
484
485 match self.score_aggregation {
486 ScoreAggregation::Sum => scores.iter().sum(),
487 ScoreAggregation::Max => scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
488 ScoreAggregation::WeightedAverage => {
489 let sum: Float = scores.iter().sum();
490 sum / scores.len() as Float
491 }
492 ScoreAggregation::Product => scores.iter().product(),
493 }
494 }
495}
496
497impl FeatureSelector for HierarchicalFeatureSelector<Trained> {
498 fn selected_features(&self) -> &Vec<usize> {
499 match &self.selected_features_ {
500 Some(features) => features,
501 None => {
502 static EMPTY: Vec<usize> = Vec::new();
503 &EMPTY
504 }
505 }
506 }
507}
508
509impl Transform<Array2<Float>, Array2<Float>> for HierarchicalFeatureSelector<Trained> {
510 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
511 if let Some(selected) = &self.selected_features_ {
512 if selected.is_empty() {
513 return Err(SklearsError::InvalidData {
514 reason: "No features selected".to_string(),
515 });
516 }
517
518 let selected_cols = x.select(Axis(1), selected);
519 Ok(selected_cols)
520 } else {
521 Err(SklearsError::InvalidData {
522 reason: "Selector not fitted yet".to_string(),
523 })
524 }
525 }
526}
527
/// Feature selector that chooses features independently for every level of a
/// [`FeatureHierarchy`].
///
/// The `State` parameter is a compile-time marker: the selector starts as
/// `Untrained` and `fit` produces the `Trained` variant.
#[derive(Debug, Clone)]
pub struct MultiLevelHierarchicalSelector<State = Untrained> {
    /// Hierarchy describing which feature (column) ids sit on which level.
    hierarchy: FeatureHierarchy,
    /// Number of features to keep per level, keyed by level.
    k_per_level: HashMap<usize, usize>,
    /// Multiplier applied to the scores of a level, keyed by level.
    level_weights: HashMap<usize, Float>,

    /// Selected feature ids per level; `None` until `fit` has run.
    selected_features_: Option<HashMap<usize, Vec<usize>>>,
    /// Weighted scores of the selected features per level; `None` until `fit`
    /// has run.
    level_scores_: Option<HashMap<usize, HashMap<usize, Float>>>,

    /// Zero-sized marker carrying the `State` type parameter.
    state: PhantomData<State>,
}
543
544impl MultiLevelHierarchicalSelector<Untrained> {
545 pub fn new(hierarchy: FeatureHierarchy) -> Self {
547 Self {
548 hierarchy,
549 k_per_level: HashMap::new(),
550 level_weights: HashMap::new(),
551 selected_features_: None,
552 level_scores_: None,
553 state: PhantomData,
554 }
555 }
556
557 pub fn k_per_level(mut self, k_per_level: HashMap<usize, usize>) -> Self {
559 self.k_per_level = k_per_level;
560 self
561 }
562
563 pub fn level_weights(mut self, level_weights: HashMap<usize, Float>) -> Self {
565 self.level_weights = level_weights;
566 self
567 }
568}
569
570impl Estimator for MultiLevelHierarchicalSelector<Untrained> {
571 type Config = ();
572 type Error = SklearsError;
573 type Float = f64;
574
575 fn config(&self) -> &Self::Config {
576 &()
577 }
578}
579
580impl Fit<Array2<Float>, Array1<Float>> for MultiLevelHierarchicalSelector<Untrained> {
581 type Fitted = MultiLevelHierarchicalSelector<Trained>;
582
583 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
584 let (n_samples, n_features) = x.dim();
585 if n_samples == 0 || n_features == 0 {
586 return Err(SklearsError::InvalidInput(
587 "Input data cannot be empty".to_string(),
588 ));
589 }
590
591 let mut feature_scores = HashMap::new();
593 for feature_idx in 0..n_features {
594 let feature_col = x.column(feature_idx);
595 let score = compute_f_score(&feature_col.to_owned(), y);
596 feature_scores.insert(feature_idx, score);
597 }
598
599 let mut selected_features = HashMap::new();
601 let mut level_scores = HashMap::new();
602
603 for level in 0..=self.hierarchy.max_level {
604 let level_features = self.hierarchy.get_features_at_level(level);
605 let k_for_level = self.k_per_level.get(&level).cloned().unwrap_or(
606 level_features.len().min(5), );
608
609 let mut level_feature_scores: Vec<(usize, Float)> = level_features
610 .into_iter()
611 .filter_map(|feature_id| {
612 feature_scores.get(&feature_id).map(|&score| {
613 let weight = self.level_weights.get(&level).cloned().unwrap_or(1.0);
614 (feature_id, score * weight)
615 })
616 })
617 .collect();
618
619 level_feature_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
620
621 let selected_at_level: Vec<usize> = level_feature_scores
622 .into_iter()
623 .take(k_for_level)
624 .map(|(feature_id, score)| {
625 level_scores
626 .entry(level)
627 .or_insert_with(HashMap::new)
628 .insert(feature_id, score);
629 feature_id
630 })
631 .collect();
632
633 selected_features.insert(level, selected_at_level);
634 }
635
636 Ok(MultiLevelHierarchicalSelector {
637 hierarchy: self.hierarchy,
638 k_per_level: self.k_per_level,
639 level_weights: self.level_weights,
640 selected_features_: Some(selected_features),
641 level_scores_: Some(level_scores),
642 state: PhantomData,
643 })
644 }
645}
646
647impl MultiLevelHierarchicalSelector<Trained> {
648 pub fn selected_features_at_level(&self, level: usize) -> Option<&Vec<usize>> {
650 self.selected_features_.as_ref()?.get(&level)
651 }
652
653 pub fn all_selected_features(&self) -> Vec<usize> {
655 if let Some(selected_features) = &self.selected_features_ {
656 let mut all_features = Vec::new();
657 for features in selected_features.values() {
658 all_features.extend_from_slice(features);
659 }
660 all_features.sort_unstable();
661 all_features.dedup();
662 all_features
663 } else {
664 Vec::new()
665 }
666 }
667}
668
669impl FeatureSelector for MultiLevelHierarchicalSelector<Trained> {
670 fn selected_features(&self) -> &Vec<usize> {
671 static EMPTY: Vec<usize> = Vec::new();
674 &EMPTY
675 }
676}
677
678impl Transform<Array2<Float>, Array2<Float>> for MultiLevelHierarchicalSelector<Trained> {
679 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
680 let all_selected = self.all_selected_features();
681 if all_selected.is_empty() {
682 return Err(SklearsError::InvalidData {
683 reason: "No features selected".to_string(),
684 });
685 }
686
687 let selected_cols = x.select(Axis(1), &all_selected);
688 Ok(selected_cols)
689 }
690}
691
692fn compute_f_score(feature: &Array1<Float>, target: &Array1<Float>) -> Float {
694 if feature.len() != target.len() || feature.len() < 3 {
695 return 0.0;
696 }
697
698 let n = feature.len() as Float;
699 let feature_mean = feature.mean().unwrap_or(0.0);
700 let target_mean = target.mean().unwrap_or(0.0);
701
702 let mut numerator = 0.0;
704 let mut feature_var = 0.0;
705 let mut target_var = 0.0;
706
707 for i in 0..feature.len() {
708 let feature_dev = feature[i] - feature_mean;
709 let target_dev = target[i] - target_mean;
710 numerator += feature_dev * target_dev;
711 feature_var += feature_dev * feature_dev;
712 target_var += target_dev * target_dev;
713 }
714
715 let r = if feature_var > 0.0 && target_var > 0.0 {
716 numerator / (feature_var * target_var).sqrt()
717 } else {
718 0.0
719 };
720
721 let r_squared = r * r;
723 if (1.0 - r_squared).abs() < 1e-10 {
724 f64::INFINITY
725 } else {
726 r_squared * (n - 2.0) / (1.0 - r_squared)
727 }
728}
729
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds a hierarchy from `(feature_id, parent, group_id)` triples,
    /// adding the nodes in the given order.
    fn hierarchy_from(spec: &[(usize, Option<usize>, Option<usize>)]) -> FeatureHierarchy {
        let mut hierarchy = FeatureHierarchy::new();
        for &(feature_id, parent, group_id) in spec {
            hierarchy.add_node(feature_id, parent, group_id).unwrap();
        }
        hierarchy
    }

    /// Two roots (0 and 3); node 0 has the children 1 and 2. No groups.
    fn two_root_hierarchy() -> FeatureHierarchy {
        hierarchy_from(&[
            (0, None, None),
            (1, Some(0), None),
            (2, Some(0), None),
            (3, None, None),
        ])
    }

    #[test]
    fn test_feature_hierarchy_creation() {
        // Two roots in separate groups, followed by their children.
        let hierarchy = hierarchy_from(&[
            (0, None, Some(0)),
            (1, None, Some(1)),
            (2, Some(0), Some(0)),
            (3, Some(0), Some(0)),
            (4, Some(1), Some(1)),
        ]);

        assert_eq!(hierarchy.root_nodes.len(), 2);
        assert_eq!(hierarchy.max_level, 1);

        assert_eq!(hierarchy.get_descendants(0), vec![2, 3]);
        assert_eq!(hierarchy.get_features_at_level(0), vec![0, 1]);
        assert_eq!(hierarchy.get_features_in_group(0), vec![0, 2, 3]);
    }

    #[test]
    fn test_hierarchical_selector_top_down() {
        let x = array![
            [1.0, 0.5, 0.8, 2.0],
            [2.0, 1.0, 1.2, 4.0],
            [3.0, 1.5, 1.8, 6.0],
            [4.0, 2.0, 2.4, 8.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let fitted = HierarchicalFeatureSelector::new(two_root_hierarchy(), 2)
            .selection_strategy(HierarchicalSelectionStrategy::TopDown)
            .fit(&x, &y)
            .unwrap();

        let selected = fitted.selected_features();
        assert!(!selected.is_empty());
        assert!(selected.len() <= 2);
    }

    #[test]
    fn test_hierarchical_selector_transform() {
        // Three independent root features.
        let hierarchy = hierarchy_from(&[(0, None, None), (1, None, None), (2, None, None)]);

        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],];
        let y = array![1.0, 2.0, 3.0];

        let fitted = HierarchicalFeatureSelector::new(hierarchy, 2)
            .fit(&x, &y)
            .unwrap();

        let test_x = array![[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];
        let transformed = fitted.transform(&test_x).unwrap();

        assert_eq!(transformed.nrows(), 2);
        assert!(transformed.ncols() <= 2);
    }

    #[test]
    fn test_multi_level_selector() {
        let x = array![
            [1.0, 0.5, 0.8, 2.0],
            [2.0, 1.0, 1.2, 4.0],
            [3.0, 1.5, 1.8, 6.0],
        ];
        let y = array![1.0, 2.0, 3.0];

        // Keep exactly one feature on level 0 and one on level 1.
        let k_per_level: HashMap<usize, usize> = vec![(0, 1), (1, 1)].into_iter().collect();

        let fitted = MultiLevelHierarchicalSelector::new(two_root_hierarchy())
            .k_per_level(k_per_level)
            .fit(&x, &y)
            .unwrap();

        let level_0_selected = fitted.selected_features_at_level(0);
        assert!(level_0_selected.is_some());
        assert_eq!(level_0_selected.unwrap().len(), 1);
    }
}