sklears_feature_selection/
hierarchical.rs

1//! Hierarchical feature selection methods
2//!
3//! This module provides algorithms for feature selection that respect hierarchical
4//! structure in the features, such as grouped features, multi-level categorical variables,
5//! or features with natural parent-child relationships.
6
7use crate::base::FeatureSelector;
8use scirs2_core::ndarray::{Array1, Array2, Axis};
9use sklears_core::{
10    error::{Result as SklResult, SklearsError},
11    traits::{Estimator, Fit, Trained, Transform, Untrained},
12    types::Float,
13};
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::marker::PhantomData;
16
/// Represents a node in the feature hierarchy
///
/// Nodes are created through [`FeatureHierarchy::add_node`], which fills in
/// `level` and maintains the parent's `children` list.
#[derive(Debug, Clone)]
pub struct HierarchyNode {
    /// Identifier of the feature; selectors match it against column indices of
    /// the fitted data (ids without a computed score are never selected).
    pub feature_id: usize,
    /// Id of the parent node, or `None` for a root.
    pub parent: Option<usize>,
    /// Ids of the direct children, in the order they were added.
    pub children: Vec<usize>,
    /// Depth in the hierarchy: 0 for roots, parent's level + 1 otherwise.
    pub level: usize,
    /// Optional group label, used by group-based selection.
    pub group_id: Option<usize>,
}
26
/// Feature hierarchy structure
///
/// A forest of [`HierarchyNode`]s keyed by feature id.
#[derive(Debug, Clone)]
pub struct FeatureHierarchy {
    /// Every node in the hierarchy, keyed by its feature id.
    nodes: HashMap<usize, HierarchyNode>,
    /// Ids of nodes added without a parent, in insertion order.
    root_nodes: Vec<usize>,
    /// Deepest level of any node added so far (0 for an empty hierarchy).
    max_level: usize,
}
34
35impl FeatureHierarchy {
36    /// Create a new feature hierarchy
37    pub fn new() -> Self {
38        Self {
39            nodes: HashMap::new(),
40            root_nodes: Vec::new(),
41            max_level: 0,
42        }
43    }
44
45    /// Add a feature node to the hierarchy
46    pub fn add_node(
47        &mut self,
48        feature_id: usize,
49        parent: Option<usize>,
50        group_id: Option<usize>,
51    ) -> SklResult<()> {
52        let level = if let Some(parent_id) = parent {
53            if let Some(parent_node) = self.nodes.get(&parent_id) {
54                parent_node.level + 1
55            } else {
56                return Err(SklearsError::InvalidInput(format!(
57                    "Parent node {} not found",
58                    parent_id
59                )));
60            }
61        } else {
62            0
63        };
64
65        let node = HierarchyNode {
66            feature_id,
67            parent,
68            children: Vec::new(),
69            level,
70            group_id,
71        };
72
73        // Update parent's children list
74        if let Some(parent_id) = parent {
75            if let Some(parent_node) = self.nodes.get_mut(&parent_id) {
76                parent_node.children.push(feature_id);
77            }
78        } else {
79            self.root_nodes.push(feature_id);
80        }
81
82        self.max_level = self.max_level.max(level);
83        self.nodes.insert(feature_id, node);
84        Ok(())
85    }
86
87    /// Get all descendants of a node
88    pub fn get_descendants(&self, feature_id: usize) -> Vec<usize> {
89        let mut descendants = Vec::new();
90        let mut queue = VecDeque::new();
91
92        if let Some(node) = self.nodes.get(&feature_id) {
93            queue.extend(&node.children);
94        }
95
96        while let Some(child_id) = queue.pop_front() {
97            descendants.push(child_id);
98            if let Some(child_node) = self.nodes.get(&child_id) {
99                queue.extend(&child_node.children);
100            }
101        }
102
103        descendants
104    }
105
106    /// Get all ancestors of a node
107    pub fn get_ancestors(&self, feature_id: usize) -> Vec<usize> {
108        let mut ancestors = Vec::new();
109        let mut current_id = feature_id;
110
111        while let Some(node) = self.nodes.get(&current_id) {
112            if let Some(parent_id) = node.parent {
113                ancestors.push(parent_id);
114                current_id = parent_id;
115            } else {
116                break;
117            }
118        }
119
120        ancestors
121    }
122
123    /// Get features at a specific level
124    pub fn get_features_at_level(&self, level: usize) -> Vec<usize> {
125        let mut features: Vec<usize> = self
126            .nodes
127            .values()
128            .filter(|node| node.level == level)
129            .map(|node| node.feature_id)
130            .collect();
131        features.sort();
132        features
133    }
134
135    /// Get features in a specific group
136    pub fn get_features_in_group(&self, group_id: usize) -> Vec<usize> {
137        let mut features: Vec<usize> = self
138            .nodes
139            .values()
140            .filter(|node| node.group_id == Some(group_id))
141            .map(|node| node.feature_id)
142            .collect();
143        features.sort();
144        features
145    }
146
147    /// Check if a feature is a leaf node (has no children)
148    pub fn is_leaf(&self, feature_id: usize) -> bool {
149        self.nodes
150            .get(&feature_id)
151            .map(|node| node.children.is_empty())
152            .unwrap_or(false)
153    }
154
155    /// Get all leaf nodes
156    pub fn get_leaf_nodes(&self) -> Vec<usize> {
157        self.nodes
158            .values()
159            .filter(|node| node.children.is_empty())
160            .map(|node| node.feature_id)
161            .collect()
162    }
163}
164
165impl Default for FeatureHierarchy {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
/// Hierarchical feature selector
///
/// Scores every column with an F-statistic against the target and then keeps
/// `k` features according to a [`HierarchicalSelectionStrategy`] (top-down by
/// default). With the top-down strategy a child only becomes a candidate after
/// its parent has been selected.
#[derive(Debug, Clone)]
pub struct HierarchicalFeatureSelector<State = Untrained> {
    /// Structure the selection has to respect.
    hierarchy: FeatureHierarchy,
    /// Number of features to select.
    k: usize,
    /// Algorithm used to walk the hierarchy.
    selection_strategy: HierarchicalSelectionStrategy,
    /// How child scores are combined (used by the bottom-up strategy).
    score_aggregation: ScoreAggregation,

    // Fitted state
    /// Selected feature ids, set by `fit`.
    selected_features_: Option<Vec<usize>>,
    /// Per-column F-scores computed during `fit`.
    feature_scores_: Option<HashMap<usize, Float>>,

    /// Type-state marker (`Untrained` / `Trained`).
    state: PhantomData<State>,
}
188
/// Strategy for hierarchical selection
#[derive(Debug, Clone)]
pub enum HierarchicalSelectionStrategy {
    /// Walk breadth-first from the roots: at each level the candidates are taken
    /// best-score-first until `k` is reached, and a child only becomes a
    /// candidate once its parent has been selected.
    TopDown,
    /// Propagate scores upward (deepest level first): each node's score is
    /// increased by the [`ScoreAggregation`] of its children's scores, then the
    /// top `k` features by the resulting score are kept.
    BottomUp,
    /// Split `k` evenly across levels (lower level numbers receive the
    /// remainder) and take the best features within each level independently.
    LevelWise,
    /// Split `k` evenly across group ids and take the best features within each
    /// group; falls back to plain top-`k` when no node has a group.
    GroupBased,
}
201
/// Method for combining the scores of a node's children.
///
/// Within this module it is used by [`HierarchicalSelectionStrategy::BottomUp`]:
/// the aggregate of a node's (already propagated) child scores is added to the
/// node's own score.
#[derive(Debug, Clone)]
pub enum ScoreAggregation {
    /// Sum of the child scores
    Sum,
    /// Largest child score
    Max,
    /// Arithmetic mean of the child scores.
    ///
    /// NOTE: despite the name, no per-level weighting is currently applied.
    WeightedAverage,
    /// Product of the child scores
    Product,
}
214
215impl HierarchicalFeatureSelector<Untrained> {
216    /// Create a new hierarchical feature selector
217    pub fn new(hierarchy: FeatureHierarchy, k: usize) -> Self {
218        Self {
219            hierarchy,
220            k,
221            selection_strategy: HierarchicalSelectionStrategy::TopDown,
222            score_aggregation: ScoreAggregation::Sum,
223            selected_features_: None,
224            feature_scores_: None,
225            state: PhantomData,
226        }
227    }
228
229    /// Set the selection strategy
230    pub fn selection_strategy(mut self, strategy: HierarchicalSelectionStrategy) -> Self {
231        self.selection_strategy = strategy;
232        self
233    }
234
235    /// Set the score aggregation method
236    pub fn score_aggregation(mut self, aggregation: ScoreAggregation) -> Self {
237        self.score_aggregation = aggregation;
238        self
239    }
240}
241
// The selector has no separate configuration object (`Config = ()`); its
// hyper-parameters are fields set through the builder methods.
impl Estimator for HierarchicalFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &()
    }
}
251
252impl Fit<Array2<Float>, Array1<Float>> for HierarchicalFeatureSelector<Untrained> {
253    type Fitted = HierarchicalFeatureSelector<Trained>;
254
255    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
256        let (n_samples, n_features) = x.dim();
257        if n_samples == 0 || n_features == 0 {
258            return Err(SklearsError::InvalidInput(
259                "Input data cannot be empty".to_string(),
260            ));
261        }
262
263        if self.k > n_features {
264            return Err(SklearsError::InvalidInput(
265                "k cannot be larger than number of features".to_string(),
266            ));
267        }
268
269        // Compute base feature scores using F-statistic
270        let mut feature_scores = HashMap::new();
271        for feature_idx in 0..n_features {
272            let feature_col = x.column(feature_idx);
273            let score = compute_f_score(&feature_col.to_owned(), y);
274            feature_scores.insert(feature_idx, score);
275        }
276
277        // Apply hierarchical selection based on strategy
278        let selected_features = match self.selection_strategy {
279            HierarchicalSelectionStrategy::TopDown => self.select_top_down(&feature_scores)?,
280            HierarchicalSelectionStrategy::BottomUp => self.select_bottom_up(&feature_scores)?,
281            HierarchicalSelectionStrategy::LevelWise => self.select_level_wise(&feature_scores)?,
282            HierarchicalSelectionStrategy::GroupBased => {
283                self.select_group_based(&feature_scores)?
284            }
285        };
286
287        Ok(HierarchicalFeatureSelector {
288            hierarchy: self.hierarchy,
289            k: self.k,
290            selection_strategy: self.selection_strategy,
291            score_aggregation: self.score_aggregation,
292            selected_features_: Some(selected_features),
293            feature_scores_: Some(feature_scores),
294            state: PhantomData,
295        })
296    }
297}
298
299impl HierarchicalFeatureSelector<Untrained> {
300    /// Top-down hierarchical selection
301    fn select_top_down(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
302        let mut selected = HashSet::new();
303        let mut candidates = VecDeque::new();
304
305        // Start with root nodes
306        candidates.extend(&self.hierarchy.root_nodes);
307
308        while !candidates.is_empty() && selected.len() < self.k {
309            let mut level_scores: Vec<(usize, Float)> = candidates
310                .iter()
311                .filter_map(|&feature_id| {
312                    feature_scores
313                        .get(&feature_id)
314                        .map(|&score| (feature_id, score))
315                })
316                .collect();
317
318            // Sort by score (descending)
319            level_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
320
321            // Select best features from current level
322            let mut next_candidates: VecDeque<usize> = VecDeque::new();
323            for (feature_id, _) in level_scores {
324                if selected.len() >= self.k {
325                    break;
326                }
327
328                selected.insert(feature_id);
329                candidates.retain(|&x| x != feature_id);
330
331                // Add children to next level candidates
332                if let Some(node) = self.hierarchy.nodes.get(&feature_id) {
333                    next_candidates.extend(&node.children);
334                }
335            }
336
337            candidates.extend(next_candidates);
338        }
339
340        Ok(selected.into_iter().collect())
341    }
342
343    /// Bottom-up hierarchical selection
344    fn select_bottom_up(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
345        let mut aggregated_scores = feature_scores.clone();
346
347        // Propagate scores from leaves to roots
348        for level in (0..=self.hierarchy.max_level).rev() {
349            let level_features = self.hierarchy.get_features_at_level(level);
350
351            for feature_id in level_features {
352                if let Some(node) = self.hierarchy.nodes.get(&feature_id) {
353                    if !node.children.is_empty() {
354                        // Aggregate children scores
355                        let child_scores: Vec<Float> = node
356                            .children
357                            .iter()
358                            .filter_map(|&child_id| aggregated_scores.get(&child_id))
359                            .cloned()
360                            .collect();
361
362                        if !child_scores.is_empty() {
363                            let aggregated = self.aggregate_scores(&child_scores);
364                            let current_score =
365                                aggregated_scores.get(&feature_id).cloned().unwrap_or(0.0);
366                            aggregated_scores.insert(feature_id, current_score + aggregated);
367                        }
368                    }
369                }
370            }
371        }
372
373        // Select top k features based on aggregated scores
374        let mut scored_features: Vec<(usize, Float)> = aggregated_scores.into_iter().collect();
375        scored_features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
376
377        Ok(scored_features
378            .into_iter()
379            .take(self.k)
380            .map(|(feature_id, _)| feature_id)
381            .collect())
382    }
383
384    /// Level-wise hierarchical selection
385    fn select_level_wise(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
386        let mut selected = Vec::new();
387        let features_per_level = self.k / (self.hierarchy.max_level + 1);
388        let remaining = self.k % (self.hierarchy.max_level + 1);
389
390        for level in 0..=self.hierarchy.max_level {
391            let level_features = self.hierarchy.get_features_at_level(level);
392            let mut level_scores: Vec<(usize, Float)> = level_features
393                .into_iter()
394                .filter_map(|feature_id| {
395                    feature_scores
396                        .get(&feature_id)
397                        .map(|&score| (feature_id, score))
398                })
399                .collect();
400
401            level_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
402
403            let k_for_level = if level < remaining {
404                features_per_level + 1
405            } else {
406                features_per_level
407            };
408
409            selected.extend(
410                level_scores
411                    .into_iter()
412                    .take(k_for_level)
413                    .map(|(feature_id, _)| feature_id),
414            );
415        }
416
417        Ok(selected)
418    }
419
420    /// Group-based hierarchical selection
421    fn select_group_based(&self, feature_scores: &HashMap<usize, Float>) -> SklResult<Vec<usize>> {
422        // Get all unique groups
423        let mut groups: HashSet<usize> = HashSet::new();
424        for node in self.hierarchy.nodes.values() {
425            if let Some(group_id) = node.group_id {
426                groups.insert(group_id);
427            }
428        }
429
430        if groups.is_empty() {
431            // Fallback to regular top-k selection
432            let mut scored_features: Vec<(usize, Float)> = feature_scores
433                .iter()
434                .map(|(&feature_id, &score)| (feature_id, score))
435                .collect();
436            scored_features.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
437
438            return Ok(scored_features
439                .into_iter()
440                .take(self.k)
441                .map(|(feature_id, _)| feature_id)
442                .collect());
443        }
444
445        let features_per_group = self.k / groups.len();
446        let remaining = self.k % groups.len();
447        let mut selected = Vec::new();
448
449        for (group_idx, group_id) in groups.into_iter().enumerate() {
450            let group_features = self.hierarchy.get_features_in_group(group_id);
451            let mut group_scores: Vec<(usize, Float)> = group_features
452                .into_iter()
453                .filter_map(|feature_id| {
454                    feature_scores
455                        .get(&feature_id)
456                        .map(|&score| (feature_id, score))
457                })
458                .collect();
459
460            group_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
461
462            let k_for_group = if group_idx < remaining {
463                features_per_group + 1
464            } else {
465                features_per_group
466            };
467
468            selected.extend(
469                group_scores
470                    .into_iter()
471                    .take(k_for_group)
472                    .map(|(feature_id, _)| feature_id),
473            );
474        }
475
476        Ok(selected)
477    }
478
479    /// Aggregate scores using the specified method
480    fn aggregate_scores(&self, scores: &[Float]) -> Float {
481        if scores.is_empty() {
482            return 0.0;
483        }
484
485        match self.score_aggregation {
486            ScoreAggregation::Sum => scores.iter().sum(),
487            ScoreAggregation::Max => scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
488            ScoreAggregation::WeightedAverage => {
489                let sum: Float = scores.iter().sum();
490                sum / scores.len() as Float
491            }
492            ScoreAggregation::Product => scores.iter().product(),
493        }
494    }
495}
496
497impl FeatureSelector for HierarchicalFeatureSelector<Trained> {
498    fn selected_features(&self) -> &Vec<usize> {
499        match &self.selected_features_ {
500            Some(features) => features,
501            None => {
502                static EMPTY: Vec<usize> = Vec::new();
503                &EMPTY
504            }
505        }
506    }
507}
508
509impl Transform<Array2<Float>, Array2<Float>> for HierarchicalFeatureSelector<Trained> {
510    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
511        if let Some(selected) = &self.selected_features_ {
512            if selected.is_empty() {
513                return Err(SklearsError::InvalidData {
514                    reason: "No features selected".to_string(),
515                });
516            }
517
518            let selected_cols = x.select(Axis(1), selected);
519            Ok(selected_cols)
520        } else {
521            Err(SklearsError::InvalidData {
522                reason: "Selector not fitted yet".to_string(),
523            })
524        }
525    }
526}
527
/// Multi-level hierarchical feature selector
///
/// Performs feature selection independently at every level of the hierarchy,
/// keeping a separate list of selected features per level.
#[derive(Debug, Clone)]
pub struct MultiLevelHierarchicalSelector<State = Untrained> {
    /// Structure whose levels are processed.
    hierarchy: FeatureHierarchy,
    /// Number of features to keep per level; levels missing from the map keep up
    /// to 5 features.
    k_per_level: HashMap<usize, usize>,
    /// Multiplier applied to the scores of each level (1.0 when missing).
    level_weights: HashMap<usize, Float>,

    // Fitted state
    /// Level -> selected feature ids (best first), set by `fit`.
    selected_features_: Option<HashMap<usize, Vec<usize>>>,
    /// Level -> {feature id -> weighted score} for the selected features.
    level_scores_: Option<HashMap<usize, HashMap<usize, Float>>>,

    /// Type-state marker (`Untrained` / `Trained`).
    state: PhantomData<State>,
}
543
544impl MultiLevelHierarchicalSelector<Untrained> {
545    /// Create a new multi-level hierarchical selector
546    pub fn new(hierarchy: FeatureHierarchy) -> Self {
547        Self {
548            hierarchy,
549            k_per_level: HashMap::new(),
550            level_weights: HashMap::new(),
551            selected_features_: None,
552            level_scores_: None,
553            state: PhantomData,
554        }
555    }
556
557    /// Set number of features to select at each level
558    pub fn k_per_level(mut self, k_per_level: HashMap<usize, usize>) -> Self {
559        self.k_per_level = k_per_level;
560        self
561    }
562
563    /// Set weights for each level (used in scoring)
564    pub fn level_weights(mut self, level_weights: HashMap<usize, Float>) -> Self {
565        self.level_weights = level_weights;
566        self
567    }
568}
569
// The selector has no separate configuration object (`Config = ()`); its
// hyper-parameters are fields set through the builder methods.
impl Estimator for MultiLevelHierarchicalSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &()
    }
}
579
580impl Fit<Array2<Float>, Array1<Float>> for MultiLevelHierarchicalSelector<Untrained> {
581    type Fitted = MultiLevelHierarchicalSelector<Trained>;
582
583    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
584        let (n_samples, n_features) = x.dim();
585        if n_samples == 0 || n_features == 0 {
586            return Err(SklearsError::InvalidInput(
587                "Input data cannot be empty".to_string(),
588            ));
589        }
590
591        // Compute feature scores
592        let mut feature_scores = HashMap::new();
593        for feature_idx in 0..n_features {
594            let feature_col = x.column(feature_idx);
595            let score = compute_f_score(&feature_col.to_owned(), y);
596            feature_scores.insert(feature_idx, score);
597        }
598
599        // Select features at each level
600        let mut selected_features = HashMap::new();
601        let mut level_scores = HashMap::new();
602
603        for level in 0..=self.hierarchy.max_level {
604            let level_features = self.hierarchy.get_features_at_level(level);
605            let k_for_level = self.k_per_level.get(&level).cloned().unwrap_or(
606                level_features.len().min(5), // Default to 5 or all features at level
607            );
608
609            let mut level_feature_scores: Vec<(usize, Float)> = level_features
610                .into_iter()
611                .filter_map(|feature_id| {
612                    feature_scores.get(&feature_id).map(|&score| {
613                        let weight = self.level_weights.get(&level).cloned().unwrap_or(1.0);
614                        (feature_id, score * weight)
615                    })
616                })
617                .collect();
618
619            level_feature_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
620
621            let selected_at_level: Vec<usize> = level_feature_scores
622                .into_iter()
623                .take(k_for_level)
624                .map(|(feature_id, score)| {
625                    level_scores
626                        .entry(level)
627                        .or_insert_with(HashMap::new)
628                        .insert(feature_id, score);
629                    feature_id
630                })
631                .collect();
632
633            selected_features.insert(level, selected_at_level);
634        }
635
636        Ok(MultiLevelHierarchicalSelector {
637            hierarchy: self.hierarchy,
638            k_per_level: self.k_per_level,
639            level_weights: self.level_weights,
640            selected_features_: Some(selected_features),
641            level_scores_: Some(level_scores),
642            state: PhantomData,
643        })
644    }
645}
646
647impl MultiLevelHierarchicalSelector<Trained> {
648    /// Get selected features at a specific level
649    pub fn selected_features_at_level(&self, level: usize) -> Option<&Vec<usize>> {
650        self.selected_features_.as_ref()?.get(&level)
651    }
652
653    /// Get all selected features across all levels
654    pub fn all_selected_features(&self) -> Vec<usize> {
655        if let Some(selected_features) = &self.selected_features_ {
656            let mut all_features = Vec::new();
657            for features in selected_features.values() {
658                all_features.extend_from_slice(features);
659            }
660            all_features.sort_unstable();
661            all_features.dedup();
662            all_features
663        } else {
664            Vec::new()
665        }
666    }
667}
668
impl FeatureSelector for MultiLevelHierarchicalSelector<Trained> {
    /// Always returns an empty list.
    ///
    /// NOTE(review): the selector stores one list per level rather than a single
    /// `Vec<usize>`, so there is nothing to borrow here. Callers should use
    /// `all_selected_features()` (what `transform` uses) until a flattened list
    /// is stored on the struct during `fit`.
    fn selected_features(&self) -> &Vec<usize> {
        // This is a bit tricky since we don't store a single Vec<usize>
        // We'll need a different approach for this trait implementation
        static EMPTY: Vec<usize> = Vec::new();
        &EMPTY
    }
}
677
678impl Transform<Array2<Float>, Array2<Float>> for MultiLevelHierarchicalSelector<Trained> {
679    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
680        let all_selected = self.all_selected_features();
681        if all_selected.is_empty() {
682            return Err(SklearsError::InvalidData {
683                reason: "No features selected".to_string(),
684            });
685        }
686
687        let selected_cols = x.select(Axis(1), &all_selected);
688        Ok(selected_cols)
689    }
690}
691
692/// Compute F-score for a feature
693fn compute_f_score(feature: &Array1<Float>, target: &Array1<Float>) -> Float {
694    if feature.len() != target.len() || feature.len() < 3 {
695        return 0.0;
696    }
697
698    let n = feature.len() as Float;
699    let feature_mean = feature.mean().unwrap_or(0.0);
700    let target_mean = target.mean().unwrap_or(0.0);
701
702    // Compute correlation coefficient
703    let mut numerator = 0.0;
704    let mut feature_var = 0.0;
705    let mut target_var = 0.0;
706
707    for i in 0..feature.len() {
708        let feature_dev = feature[i] - feature_mean;
709        let target_dev = target[i] - target_mean;
710        numerator += feature_dev * target_dev;
711        feature_var += feature_dev * feature_dev;
712        target_var += target_dev * target_dev;
713    }
714
715    let r = if feature_var > 0.0 && target_var > 0.0 {
716        numerator / (feature_var * target_var).sqrt()
717    } else {
718        0.0
719    };
720
721    // Convert correlation to F-statistic
722    let r_squared = r * r;
723    if (1.0 - r_squared).abs() < 1e-10 {
724        f64::INFINITY
725    } else {
726        r_squared * (n - 2.0) / (1.0 - r_squared)
727    }
728}
729
// NOTE(review): no non-snake-case identifiers are visible in this module, so this
// allow looks unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds a two-root hierarchy and checks roots, depth, descendants,
    /// level membership and group membership.
    #[test]
    fn test_feature_hierarchy_creation() {
        let mut hierarchy = FeatureHierarchy::new();

        // Add root features
        hierarchy.add_node(0, None, Some(0)).unwrap();
        hierarchy.add_node(1, None, Some(1)).unwrap();

        // Add child features
        hierarchy.add_node(2, Some(0), Some(0)).unwrap();
        hierarchy.add_node(3, Some(0), Some(0)).unwrap();
        hierarchy.add_node(4, Some(1), Some(1)).unwrap();

        assert_eq!(hierarchy.root_nodes.len(), 2);
        assert_eq!(hierarchy.max_level, 1);

        let descendants_0 = hierarchy.get_descendants(0);
        assert_eq!(descendants_0, vec![2, 3]);

        let level_0_features = hierarchy.get_features_at_level(0);
        assert_eq!(level_0_features, vec![0, 1]);

        let group_0_features = hierarchy.get_features_in_group(0);
        assert_eq!(group_0_features, vec![0, 2, 3]);
    }

    /// Top-down selection with k = 2 over two roots (one with two children)
    /// returns between 1 and 2 features.
    #[test]
    fn test_hierarchical_selector_top_down() {
        let mut hierarchy = FeatureHierarchy::new();
        hierarchy.add_node(0, None, None).unwrap();
        hierarchy.add_node(1, Some(0), None).unwrap();
        hierarchy.add_node(2, Some(0), None).unwrap();
        hierarchy.add_node(3, None, None).unwrap();

        let x = array![
            [1.0, 0.5, 0.8, 2.0],
            [2.0, 1.0, 1.2, 4.0],
            [3.0, 1.5, 1.8, 6.0],
            [4.0, 2.0, 2.4, 8.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let selector = HierarchicalFeatureSelector::new(hierarchy, 2)
            .selection_strategy(HierarchicalSelectionStrategy::TopDown);
        let fitted = selector.fit(&x, &y).unwrap();

        let selected = fitted.selected_features();
        assert!(!selected.is_empty());
        assert!(selected.len() <= 2);
    }

    /// A fitted selector applied to new data keeps every row and at most k columns.
    #[test]
    fn test_hierarchical_selector_transform() {
        let mut hierarchy = FeatureHierarchy::new();
        hierarchy.add_node(0, None, None).unwrap();
        hierarchy.add_node(1, None, None).unwrap();
        hierarchy.add_node(2, None, None).unwrap();

        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],];
        let y = array![1.0, 2.0, 3.0];

        let selector = HierarchicalFeatureSelector::new(hierarchy, 2);
        let fitted = selector.fit(&x, &y).unwrap();

        let test_x = array![[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];
        let transformed = fitted.transform(&test_x).unwrap();

        assert_eq!(transformed.nrows(), 2);
        assert!(transformed.ncols() <= 2);
    }

    /// Explicit per-level counts are honoured at level 0.
    #[test]
    fn test_multi_level_selector() {
        let mut hierarchy = FeatureHierarchy::new();
        hierarchy.add_node(0, None, None).unwrap();
        hierarchy.add_node(1, Some(0), None).unwrap();
        hierarchy.add_node(2, Some(0), None).unwrap();
        hierarchy.add_node(3, None, None).unwrap();

        let x = array![
            [1.0, 0.5, 0.8, 2.0],
            [2.0, 1.0, 1.2, 4.0],
            [3.0, 1.5, 1.8, 6.0],
        ];
        let y = array![1.0, 2.0, 3.0];

        let mut k_per_level = HashMap::new();
        k_per_level.insert(0, 1); // Select 1 feature at level 0
        k_per_level.insert(1, 1); // Select 1 feature at level 1

        let selector = MultiLevelHierarchicalSelector::new(hierarchy).k_per_level(k_per_level);
        let fitted = selector.fit(&x, &y).unwrap();

        let level_0_selected = fitted.selected_features_at_level(0);
        assert!(level_0_selected.is_some());
        assert_eq!(level_0_selected.unwrap().len(), 1);
    }
}