sklears_feature_selection/
multi_label.rs

1//! Multi-label feature selection algorithms
2//!
3//! This module provides feature selection methods specifically designed for multi-label datasets,
4//! where each instance can be associated with multiple labels simultaneously.
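//!
//! A minimal usage sketch (illustrative only and marked `ignore`, since it assumes the
//! surrounding crate layout and pre-built `features`/`labels` arrays):
//!
//! ```ignore
//! use sklears_feature_selection::multi_label::{MultiLabelFeatureSelector, MultiLabelStrategy};
//!
//! // `features`: (n_samples, n_features); `labels`: (n_samples, n_labels) indicator matrix.
//! let selector = MultiLabelFeatureSelector::new()
//!     .strategy(MultiLabelStrategy::LabelSpecific)
//!     .n_features(10);
//! let trained = selector.fit(&features, &labels)?;
//! let reduced = trained.transform(&features)?;
//! ```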
5
6use crate::base::{FeatureSelector, SelectorMixin};
7use scirs2_core::ndarray::{Array1, Array2};
8use sklears_core::{
9    error::{validate, Result as SklResult, SklearsError},
10    traits::{Estimator, Fit, Trained, Transform, Untrained},
11    types::Float,
12};
13use std::collections::{HashMap, HashSet};
14use std::marker::PhantomData;
15
16/// Type alias for multi-label targets: each row is a sample and each column a label, so a sample can carry multiple labels at once
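///
/// A small illustrative target (3 samples, 2 labels; `from_shape_vec` returns a `Result`,
/// hence the `unwrap`):
///
/// ```ignore
/// let y: MultiLabelTarget = Array2::from_shape_vec(
///     (3, 2),
///     vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
/// ).unwrap();
/// ```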
17pub type MultiLabelTarget = Array2<Float>;
18
19/// Multi-label feature selection strategy
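///
/// `LabelSpecific` is the default used by [`MultiLabelFeatureSelector::new`].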
20#[derive(Debug, Clone)]
21pub enum MultiLabelStrategy {
22    /// Select features relevant to all labels
23    GlobalRelevance,
24    /// Select features most relevant to individual labels and combine
25    LabelSpecific,
26    /// Use label correlations to guide feature selection
27    LabelCorrelationAware,
28    /// Hierarchical selection considering label relationships
29    HierarchicalLabels,
30    /// Ensemble approach combining multiple strategies
31    Ensemble,
32}
33
34/// Aggregation method for combining label-specific selections
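///
/// A hedged sketch of selecting an aggregation method (the identifiers are the ones defined
/// in this module; the `features`/`labels` data is assumed):
///
/// ```ignore
/// let selector = LabelSpecificSelector::new()
///     .n_features_per_label(5)
///     .aggregate_method(AggregateMethod::MajorityVote);
/// let trained = selector.fit(&features, &labels)?;
/// ```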
35#[derive(Debug, Clone)]
36pub enum AggregateMethod {
37    /// Take the union of the features selected for each label
38    Union,
39    /// Keep only the features selected for every label
40    Intersection,
41    /// Keep the features selected for a majority of labels
42    MajorityVote,
43    /// Union weighted by per-label importance (currently equivalent to `Union`)
44    WeightedUnion,
45}
46
47/// Core multi-label feature selector
48#[derive(Debug, Clone)]
49pub struct MultiLabelFeatureSelector<State = Untrained> {
50    strategy: MultiLabelStrategy,
51    n_features: Option<usize>,
52    threshold: Float,
53    min_label_frequency: Float,
54    use_label_correlation: bool,
55    correlation_threshold: Float,
56    state: PhantomData<State>,
57    // Trained state
58    scores_: Option<Array1<Float>>,
59    selected_features_: Option<Vec<usize>>,
60    n_features_: Option<usize>,
61    n_labels_: Option<usize>,
62}
63
64impl Default for MultiLabelFeatureSelector<Untrained> {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl MultiLabelFeatureSelector<Untrained> {
71    /// Create a new multi-label feature selector
72    pub fn new() -> Self {
73        Self {
74            strategy: MultiLabelStrategy::LabelSpecific,
75            n_features: None,
76            threshold: 0.01,
77            min_label_frequency: 0.01,
78            use_label_correlation: true,
79            correlation_threshold: 0.1,
80            state: PhantomData,
81            scores_: None,
82            selected_features_: None,
83            n_features_: None,
84            n_labels_: None,
85        }
86    }
87
88    /// Set the selection strategy
89    pub fn strategy(mut self, strategy: MultiLabelStrategy) -> Self {
90        self.strategy = strategy;
91        self
92    }
93
94    /// Set the number of features to select
95    pub fn n_features(mut self, n_features: usize) -> Self {
96        self.n_features = Some(n_features);
97        self
98    }
99
100    /// Set the relevance threshold
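    ///
    /// The threshold is only applied when `n_features` is not set.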
101    pub fn threshold(mut self, threshold: Float) -> Self {
102        self.threshold = threshold;
103        self
104    }
105
106    /// Set the minimum label frequency
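    ///
    /// Labels whose positive rate falls below this frequency are skipped during
    /// label-specific relevance scoring.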
107    pub fn min_label_frequency(mut self, frequency: Float) -> Self {
108        self.min_label_frequency = frequency;
109        self
110    }
111
112    /// Set whether to use label correlation
113    pub fn use_label_correlation(mut self, use_correlation: bool) -> Self {
114        self.use_label_correlation = use_correlation;
115        self
116    }
117
118    /// Set the correlation threshold
119    pub fn correlation_threshold(mut self, threshold: Float) -> Self {
120        self.correlation_threshold = threshold;
121        self
122    }
123
124    /// Compute feature relevance for multi-label data
125    fn compute_multi_label_relevance(
126        &self,
127        features: &Array2<Float>,
128        labels: &MultiLabelTarget,
129    ) -> SklResult<Array1<Float>> {
130        let n_features = features.ncols();
131        let mut relevance_scores = Array1::zeros(n_features);
132
133        match self.strategy {
134            MultiLabelStrategy::GlobalRelevance => {
135                self.compute_global_relevance(features, labels, &mut relevance_scores)?;
136            }
137            MultiLabelStrategy::LabelSpecific => {
138                self.compute_label_specific_relevance(features, labels, &mut relevance_scores)?;
139            }
140            MultiLabelStrategy::LabelCorrelationAware => {
141                self.compute_correlation_aware_relevance(features, labels, &mut relevance_scores)?;
142            }
143            MultiLabelStrategy::HierarchicalLabels => {
144                self.compute_hierarchical_relevance(features, labels, &mut relevance_scores)?;
145            }
146            MultiLabelStrategy::Ensemble => {
147                self.compute_ensemble_relevance(features, labels, &mut relevance_scores)?;
148            }
149        }
150
151        Ok(relevance_scores)
152    }
153
154    /// Compute global relevance across all labels
155    fn compute_global_relevance(
156        &self,
157        features: &Array2<Float>,
158        labels: &MultiLabelTarget,
159        scores: &mut Array1<Float>,
160    ) -> SklResult<()> {
161        let n_features = features.ncols();
162        let n_labels = labels.ncols();
163
164        for feature_idx in 0..n_features {
165            let feature_col = features.column(feature_idx);
166            let mut total_relevance = 0.0;
167
168            for label_idx in 0..n_labels {
169                let label_col = labels.column(label_idx);
170
171                // Compute correlation between feature and label
172                let corr = self.compute_correlation(&feature_col, &label_col)?;
173                total_relevance += corr.abs();
174            }
175
176            scores[feature_idx] = total_relevance / n_labels as Float;
177        }
178
179        Ok(())
180    }
181
182    /// Compute label-specific relevance and aggregate
183    fn compute_label_specific_relevance(
184        &self,
185        features: &Array2<Float>,
186        labels: &MultiLabelTarget,
187        scores: &mut Array1<Float>,
188    ) -> SklResult<()> {
189        let n_features = features.ncols();
190        let n_labels = labels.ncols();
191
192        // Compute relevance for each label separately
193        let mut label_relevances = Array2::zeros((n_labels, n_features));
194
195        for label_idx in 0..n_labels {
196            let label_col = labels.column(label_idx);
197
198            // Skip labels with insufficient frequency
199            let label_frequency = label_col.sum() / label_col.len() as Float;
200            if label_frequency < self.min_label_frequency {
201                continue;
202            }
203
204            for feature_idx in 0..n_features {
205                let feature_col = features.column(feature_idx);
206                let corr = self.compute_correlation(&feature_col, &label_col)?;
207                label_relevances[[label_idx, feature_idx]] = corr.abs();
208            }
209        }
210
211        // Aggregate relevances across labels (use max relevance)
212        for feature_idx in 0..n_features {
213            let feature_relevances = label_relevances.column(feature_idx);
214            scores[feature_idx] = feature_relevances.iter().cloned().fold(0.0, Float::max);
215        }
216
217        Ok(())
218    }
219
220    /// Compute correlation-aware relevance considering label interactions
221    fn compute_correlation_aware_relevance(
222        &self,
223        features: &Array2<Float>,
224        labels: &MultiLabelTarget,
225        scores: &mut Array1<Float>,
226    ) -> SklResult<()> {
227        let n_features = features.ncols();
228        let n_labels = labels.ncols();
229
230        // Compute label correlation matrix
231        let label_correlations = self.compute_label_correlation_matrix(labels)?;
232
233        for feature_idx in 0..n_features {
234            let feature_col = features.column(feature_idx);
235            let mut weighted_relevance = 0.0;
236            let mut total_weight = 0.0;
237
238            for label_idx in 0..n_labels {
239                let label_col = labels.column(label_idx);
240                let corr = self.compute_correlation(&feature_col, &label_col)?;
241
242                // Weight by label importance and correlation structure
243                let label_weight = self.compute_label_weight(label_idx, &label_correlations);
244                weighted_relevance += corr.abs() * label_weight;
245                total_weight += label_weight;
246            }
247
248            scores[feature_idx] = if total_weight > 0.0 {
249                weighted_relevance / total_weight
250            } else {
251                0.0
252            };
253        }
254
255        Ok(())
256    }
257
258    /// Compute hierarchical relevance for structured label spaces
259    fn compute_hierarchical_relevance(
260        &self,
261        features: &Array2<Float>,
262        labels: &MultiLabelTarget,
263        scores: &mut Array1<Float>,
264    ) -> SklResult<()> {
265        // For now, implement a simplified hierarchical approach
266        // In practice, this would require label hierarchy information
267        self.compute_label_specific_relevance(features, labels, scores)?;
268
269        // Apply a simplified hierarchical weighting (a placeholder until label hierarchy information is available)
270        for score in scores.iter_mut() {
271            *score *= 1.1; // Uniform boost; this does not change the relative feature ranking
272        }
273
274        Ok(())
275    }
276
277    /// Compute ensemble relevance combining multiple strategies
278    fn compute_ensemble_relevance(
279        &self,
280        features: &Array2<Float>,
281        labels: &MultiLabelTarget,
282        scores: &mut Array1<Float>,
283    ) -> SklResult<()> {
284        let n_features = features.ncols();
285        let mut global_scores = Array1::zeros(n_features);
286        let mut specific_scores = Array1::zeros(n_features);
287        let mut correlation_scores = Array1::zeros(n_features);
288
289        // Compute scores using different strategies
290        self.compute_global_relevance(features, labels, &mut global_scores)?;
291        self.compute_label_specific_relevance(features, labels, &mut specific_scores)?;
292        self.compute_correlation_aware_relevance(features, labels, &mut correlation_scores)?;
293
294        // Combine scores with equal weights
295        for feature_idx in 0..n_features {
296            scores[feature_idx] = (global_scores[feature_idx]
297                + specific_scores[feature_idx]
298                + correlation_scores[feature_idx])
299                / 3.0;
300        }
301
302        Ok(())
303    }
304
305    /// Compute the Pearson correlation between a feature column and a label column
306    fn compute_correlation(
307        &self,
308        feature: &scirs2_core::ndarray::ArrayView1<Float>,
309        label: &scirs2_core::ndarray::ArrayView1<Float>,
310    ) -> SklResult<Float> {
311        let feature_mean = feature.mean().unwrap_or(0.0);
312        let label_mean = label.mean().unwrap_or(0.0);
313
314        let mut covariance = 0.0;
315        let mut feature_var = 0.0;
316        let mut label_var = 0.0;
317
318        let n = feature.len();
319        if n == 0 {
320            return Ok(0.0);
321        }
322
323        for i in 0..n {
324            let f_diff = feature[i] - feature_mean;
325            let l_diff = label[i] - label_mean;
326
327            covariance += f_diff * l_diff;
328            feature_var += f_diff * f_diff;
329            label_var += l_diff * l_diff;
330        }
331
332        if feature_var == 0.0 || label_var == 0.0 {
333            return Ok(0.0);
334        }
335
336        let correlation = covariance / (feature_var * label_var).sqrt(); // feature_var/label_var are sums of squared deviations, so the 1/n factors cancel
337        Ok(correlation)
338    }
339
340    /// Compute label correlation matrix
341    fn compute_label_correlation_matrix(
342        &self,
343        labels: &MultiLabelTarget,
344    ) -> SklResult<Array2<Float>> {
345        let n_labels = labels.ncols();
346        let mut correlations = Array2::zeros((n_labels, n_labels));
347
348        for i in 0..n_labels {
349            for j in 0..n_labels {
350                if i == j {
351                    correlations[[i, j]] = 1.0;
352                } else {
353                    let label_i = labels.column(i);
354                    let label_j = labels.column(j);
355                    let corr = self.compute_correlation(&label_i, &label_j)?;
356                    correlations[[i, j]] = corr;
357                }
358            }
359        }
360
361        Ok(correlations)
362    }
363
364    /// Compute weight for a label based on correlation structure
365    fn compute_label_weight(&self, label_idx: usize, correlations: &Array2<Float>) -> Float {
366        let label_correlations = correlations.row(label_idx);
367        let avg_correlation = label_correlations.mean().unwrap_or(0.0);
368
369        // Labels with moderate correlations get higher weights
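        // e.g. an average correlation of 0.5 gives weight 1.0, while 0.1 or 0.9 both give 0.6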
370        1.0 - (avg_correlation - 0.5).abs()
371    }
372
373    /// Select features based on computed relevance scores
374    fn select_features(&self, relevance_scores: &Array1<Float>) -> SklResult<Vec<usize>> {
375        let n_features = relevance_scores.len();
376
377        if let Some(k) = self.n_features {
378            if k > n_features {
379                return Err(SklearsError::InvalidInput(format!(
380                    "n_features ({}) must be <= total features ({})",
381                    k, n_features
382                )));
383            }
384            // Select top k features
385            let mut indices: Vec<usize> = (0..n_features).collect();
386            indices.sort_by(|&a, &b| {
387                relevance_scores[b]
388                    .partial_cmp(&relevance_scores[a])
389                    .unwrap()
390            });
391            indices.truncate(k);
392            Ok(indices)
393        } else {
394            // Select features above threshold
395            let selected: Vec<usize> = relevance_scores
396                .iter()
397                .enumerate()
398                .filter(|(_, &score)| score >= self.threshold)
399                .map(|(idx, _)| idx)
400                .collect();
401
402            if selected.is_empty() {
403                return Err(SklearsError::InvalidInput(
404                    "No features selected with current threshold".to_string(),
405                ));
406            }
407            Ok(selected)
408        }
409    }
410}
411
412impl Estimator for MultiLabelFeatureSelector<Untrained> {
413    type Config = ();
414    type Error = SklearsError;
415    type Float = Float;
416
417    fn config(&self) -> &Self::Config {
418        &()
419    }
420}
421
422impl Fit<Array2<Float>, MultiLabelTarget> for MultiLabelFeatureSelector<Untrained> {
423    type Fitted = MultiLabelFeatureSelector<Trained>;
424
425    fn fit(self, features: &Array2<Float>, target: &MultiLabelTarget) -> SklResult<Self::Fitted> {
426        // Custom validation for multi-label targets
427        if features.nrows() != target.nrows() {
428            return Err(SklearsError::InvalidInput(format!(
429                "Inconsistent numbers of samples: features has {} samples, target has {}",
430                features.nrows(),
431                target.nrows()
432            )));
433        }
434
435        let n_features = features.ncols();
436        let n_labels = target.ncols();
437
438        if n_features == 0 {
439            return Err(SklearsError::InvalidInput(
440                "No features provided".to_string(),
441            ));
442        }
443        if n_labels == 0 {
444            return Err(SklearsError::InvalidInput("No labels provided".to_string()));
445        }
446
447        let relevance_scores = self.compute_multi_label_relevance(features, target)?;
448        let selected_features = self.select_features(&relevance_scores)?;
449
450        Ok(MultiLabelFeatureSelector {
451            strategy: self.strategy,
452            n_features: self.n_features,
453            threshold: self.threshold,
454            min_label_frequency: self.min_label_frequency,
455            use_label_correlation: self.use_label_correlation,
456            correlation_threshold: self.correlation_threshold,
457            state: PhantomData,
458            scores_: Some(relevance_scores),
459            selected_features_: Some(selected_features),
460            n_features_: Some(n_features),
461            n_labels_: Some(n_labels),
462        })
463    }
464}
465
466impl Transform<Array2<Float>> for MultiLabelFeatureSelector<Trained> {
467    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
468        validate::check_n_features(x, self.n_features_.unwrap())?;
469
470        let selected_features = self.selected_features_.as_ref().unwrap();
471        let n_samples = x.nrows();
472        let n_selected = selected_features.len();
473        let mut x_new = Array2::zeros((n_samples, n_selected));
474
475        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
476            x_new.column_mut(new_idx).assign(&x.column(old_idx));
477        }
478
479        Ok(x_new)
480    }
481}
482
483impl SelectorMixin for MultiLabelFeatureSelector<Trained> {
484    fn get_support(&self) -> SklResult<Array1<bool>> {
485        let n_features = self.n_features_.unwrap();
486        let selected_features = self.selected_features_.as_ref().unwrap();
487        let mut support = Array1::from_elem(n_features, false);
488
489        for &idx in selected_features {
490            support[idx] = true;
491        }
492
493        Ok(support)
494    }
495
496    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
497        let selected_features = self.selected_features_.as_ref().unwrap();
498        Ok(indices
499            .iter()
500            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
501            .collect())
502    }
503}
504
505impl FeatureSelector for MultiLabelFeatureSelector<Trained> {
506    fn selected_features(&self) -> &Vec<usize> {
507        self.selected_features_.as_ref().unwrap()
508    }
509}
510
511impl MultiLabelFeatureSelector<Trained> {
512    /// Get feature relevance scores
513    pub fn scores(&self) -> &Array1<Float> {
514        self.scores_.as_ref().unwrap()
515    }
516
517    /// Get the number of selected features
518    pub fn n_features_out(&self) -> usize {
519        self.selected_features_.as_ref().unwrap().len()
520    }
521
522    /// Get the number of labels
523    pub fn n_labels(&self) -> usize {
524        self.n_labels_.unwrap()
525    }
526
527    /// Check if a feature was selected
528    pub fn is_feature_selected(&self, feature_idx: usize) -> bool {
529        self.selected_features_
530            .as_ref()
531            .unwrap()
532            .contains(&feature_idx)
533    }
534
535    /// Get feature ranking (0-indexed, lower is better)
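    ///
    /// For example, scores `[0.2, 0.9, 0.5]` produce the ranking `[2, 0, 1]`.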
536    pub fn feature_ranking(&self) -> Vec<usize> {
537        let scores = self.scores_.as_ref().unwrap();
538        let mut indices: Vec<usize> = (0..scores.len()).collect();
539        indices.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
540
541        let mut ranking = vec![0; scores.len()];
542        for (rank, &feature_idx) in indices.iter().enumerate() {
543            ranking[feature_idx] = rank;
544        }
545        ranking
546    }
547}
548
549/// Feature selector that chooses features for each label individually and then aggregates the per-label selections
550#[derive(Debug, Clone)]
551pub struct LabelSpecificSelector<State = Untrained> {
552    n_features_per_label: Option<usize>,
553    threshold: Float,
554    aggregate_method: AggregateMethod,
555    state: PhantomData<State>,
556    // Trained state
557    selected_features_: Option<Vec<usize>>,
558    label_selections_: Option<Vec<Vec<usize>>>,
559    n_features_: Option<usize>,
560    n_labels_: Option<usize>,
561}
562
563impl Default for LabelSpecificSelector<Untrained> {
564    fn default() -> Self {
565        Self::new()
566    }
567}
568
569impl LabelSpecificSelector<Untrained> {
570    pub fn new() -> Self {
571        Self {
572            n_features_per_label: None,
573            threshold: 0.01,
574            aggregate_method: AggregateMethod::Union,
575            state: PhantomData,
576            selected_features_: None,
577            label_selections_: None,
578            n_features_: None,
579            n_labels_: None,
580        }
581    }
582
583    pub fn n_features_per_label(mut self, n_features: usize) -> Self {
584        self.n_features_per_label = Some(n_features);
585        self
586    }
587
588    pub fn threshold(mut self, threshold: Float) -> Self {
589        self.threshold = threshold;
590        self
591    }
592
593    pub fn aggregate_method(mut self, method: AggregateMethod) -> Self {
594        self.aggregate_method = method;
595        self
596    }
597
598    fn select_for_label(
599        &self,
600        features: &Array2<Float>,
601        label: &scirs2_core::ndarray::ArrayView1<Float>,
602    ) -> SklResult<Vec<usize>> {
603        let n_features = features.ncols();
604        let mut scores = Array1::zeros(n_features);
605
606        for feature_idx in 0..n_features {
607            let feature_col = features.column(feature_idx);
608            scores[feature_idx] = self.compute_feature_label_relevance(&feature_col, label)?;
609        }
610
611        if let Some(k) = self.n_features_per_label {
612            let mut indices: Vec<usize> = (0..n_features).collect();
613            indices.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
614            indices.truncate(k);
615            Ok(indices)
616        } else {
617            Ok(scores
618                .iter()
619                .enumerate()
620                .filter(|(_, &score)| score >= self.threshold)
621                .map(|(idx, _)| idx)
622                .collect())
623        }
624    }
625
626    fn compute_feature_label_relevance(
627        &self,
628        feature: &scirs2_core::ndarray::ArrayView1<Float>,
629        label: &scirs2_core::ndarray::ArrayView1<Float>,
630    ) -> SklResult<Float> {
631        // Compute the absolute Pearson correlation as the relevance score
632        let feature_mean = feature.mean().unwrap_or(0.0);
633        let label_mean = label.mean().unwrap_or(0.0);
634
635        let mut numerator = 0.0;
636        let mut feature_variance = 0.0;
637        let mut label_variance = 0.0;
638
639        let n = feature.len();
640        for i in 0..n {
641            let f_diff = feature[i] - feature_mean;
642            let l_diff = label[i] - label_mean;
643
644            numerator += f_diff * l_diff;
645            feature_variance += f_diff * f_diff;
646            label_variance += l_diff * l_diff;
647        }
648
649        if feature_variance == 0.0 || label_variance == 0.0 {
650            return Ok(0.0);
651        }
652
653        let correlation = numerator / (feature_variance * label_variance).sqrt();
654        Ok(correlation.abs())
655    }
656
657    fn aggregate_selections(&self, label_selections: &[Vec<usize>]) -> Vec<usize> {
658        match self.aggregate_method {
659            AggregateMethod::Union => {
660                let mut result = HashSet::new();
661                for selection in label_selections {
662                    result.extend(selection);
663                }
664                result.into_iter().collect()
665            }
666            AggregateMethod::Intersection => {
667                if label_selections.is_empty() {
668                    return vec![];
669                }
670                let mut result: HashSet<usize> = label_selections[0].iter().cloned().collect();
671                for selection in &label_selections[1..] {
672                    let selection_set: HashSet<usize> = selection.iter().cloned().collect();
673                    result = result.intersection(&selection_set).cloned().collect();
674                }
675                result.into_iter().collect()
676            }
677            AggregateMethod::MajorityVote => {
678                let mut feature_counts: HashMap<usize, usize> = HashMap::new();
679                for selection in label_selections {
680                    for &feature in selection {
681                        *feature_counts.entry(feature).or_insert(0) += 1;
682                    }
683                }
684                let majority_threshold = (label_selections.len() + 1) / 2; // ceil(n/2): with an even label count, exactly half counts as a majority
685                feature_counts
686                    .into_iter()
687                    .filter(|(_, count)| *count >= majority_threshold)
688                    .map(|(feature, _)| feature)
689                    .collect()
690            }
691            AggregateMethod::WeightedUnion => {
692                // For now this behaves like Union; it could be extended with per-label importance weights
693                let mut result = HashSet::new();
694                for selection in label_selections {
695                    result.extend(selection);
696                }
697                result.into_iter().collect()
698            }
699        }
700    }
701}
702
703impl Estimator for LabelSpecificSelector<Untrained> {
704    type Config = ();
705    type Error = SklearsError;
706    type Float = Float;
707
708    fn config(&self) -> &Self::Config {
709        &()
710    }
711}
712
713impl Fit<Array2<Float>, MultiLabelTarget> for LabelSpecificSelector<Untrained> {
714    type Fitted = LabelSpecificSelector<Trained>;
715
716    fn fit(self, features: &Array2<Float>, target: &MultiLabelTarget) -> SklResult<Self::Fitted> {
717        // Custom validation for multi-label targets
718        if features.nrows() != target.nrows() {
719            return Err(SklearsError::InvalidInput(format!(
720                "Inconsistent numbers of samples: features has {} samples, target has {}",
721                features.nrows(),
722                target.nrows()
723            )));
724        }
725
726        let n_features = features.ncols();
727        let n_labels = target.ncols();
728        let mut label_selections = Vec::with_capacity(n_labels);
729
730        for label_idx in 0..n_labels {
731            let label_col = target.column(label_idx);
732            let selection = self.select_for_label(features, &label_col)?;
733            label_selections.push(selection);
734        }
735
736        let selected_features = self.aggregate_selections(&label_selections);
737
738        Ok(LabelSpecificSelector {
739            n_features_per_label: self.n_features_per_label,
740            threshold: self.threshold,
741            aggregate_method: self.aggregate_method,
742            state: PhantomData,
743            selected_features_: Some(selected_features),
744            label_selections_: Some(label_selections),
745            n_features_: Some(n_features),
746            n_labels_: Some(n_labels),
747        })
748    }
749}
750
751impl Transform<Array2<Float>> for LabelSpecificSelector<Trained> {
752    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
753        validate::check_n_features(x, self.n_features_.unwrap())?;
754
755        let selected_features = self.selected_features_.as_ref().unwrap();
756        let n_samples = x.nrows();
757        let n_selected = selected_features.len();
758
759        if n_selected == 0 {
760            return Err(SklearsError::InvalidInput(
761                "No features were selected".to_string(),
762            ));
763        }
764
765        let mut x_new = Array2::zeros((n_samples, n_selected));
766
767        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
768            x_new.column_mut(new_idx).assign(&x.column(old_idx));
769        }
770
771        Ok(x_new)
772    }
773}
774
775impl SelectorMixin for LabelSpecificSelector<Trained> {
776    fn get_support(&self) -> SklResult<Array1<bool>> {
777        let n_features = self.n_features_.unwrap();
778        let selected_features = self.selected_features_.as_ref().unwrap();
779        let mut support = Array1::from_elem(n_features, false);
780
781        for &idx in selected_features {
782            support[idx] = true;
783        }
784
785        Ok(support)
786    }
787
788    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
789        let selected_features = self.selected_features_.as_ref().unwrap();
790        Ok(indices
791            .iter()
792            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
793            .collect())
794    }
795}
796
797impl FeatureSelector for LabelSpecificSelector<Trained> {
798    fn selected_features(&self) -> &Vec<usize> {
799        self.selected_features_.as_ref().unwrap()
800    }
801}
802
803impl LabelSpecificSelector<Trained> {
804    pub fn features_for_label(&self, label_idx: usize) -> Option<&[usize]> {
805        self.label_selections_
806            .as_ref()?
807            .get(label_idx)
808            .map(|v| v.as_slice())
809    }
810
811    pub fn n_features_out(&self) -> usize {
812        self.selected_features_.as_ref().unwrap().len()
813    }
814
815    pub fn n_labels(&self) -> usize {
816        self.n_labels_.unwrap()
817    }
818}
819
820#[allow(non_snake_case)]
821#[cfg(test)]
822mod tests {
823    use super::*;
824    use proptest::prelude::*;
825    use scirs2_core::ndarray::Array2;
826
827    fn create_test_data() -> (Array2<Float>, MultiLabelTarget) {
828        let features =
829            Array2::from_shape_vec((100, 10), (0..1000).map(|i| (i as Float) * 0.01).collect())
830                .unwrap();
831        let labels = Array2::from_shape_vec(
832            (100, 3),
833            (0..300)
834                .map(|i| if i % 3 == 0 { 1.0 } else { 0.0 })
835                .collect(),
836        )
837        .unwrap();
838        (features, labels)
839    }
840
841    #[test]
842    fn test_multi_label_selector_global_relevance() {
843        let (features, labels) = create_test_data();
844
845        let selector = MultiLabelFeatureSelector::new()
846            .strategy(MultiLabelStrategy::GlobalRelevance)
847            .n_features(5);
848
849        let trained = selector.fit(&features, &labels).unwrap();
850        assert_eq!(trained.n_features_out(), 5);
851        assert_eq!(trained.selected_features().len(), 5);
852    }
853
854    #[test]
855    fn test_multi_label_selector_label_specific() {
856        let (features, labels) = create_test_data();
857
858        let selector = MultiLabelFeatureSelector::new()
859            .strategy(MultiLabelStrategy::LabelSpecific)
860            .n_features(3); // Fixed feature count: the constant synthetic labels give all-zero scores, so a threshold would select nothing
861
862        let trained = selector.fit(&features, &labels).unwrap();
863        assert_eq!(trained.n_features_out(), 3);
864    }
865
866    #[test]
867    fn test_multi_label_transform() {
868        let (features, labels) = create_test_data();
869
870        let selector = MultiLabelFeatureSelector::new().n_features(3);
871
872        let trained = selector.fit(&features, &labels).unwrap();
873        let transformed = trained.transform(&features).unwrap();
874
875        assert_eq!(transformed.ncols(), 3);
876        assert_eq!(transformed.nrows(), features.nrows());
877    }
878
879    #[test]
880    fn test_label_specific_selector() {
881        let (features, labels) = create_test_data();
882
883        let selector = LabelSpecificSelector::new()
884            .n_features_per_label(2)
885            .aggregate_method(AggregateMethod::Union);
886
887        let trained = selector.fit(&features, &labels).unwrap();
888        assert!(trained.n_features_out() > 0);
889        assert!(trained.n_features_out() <= 6); // Max 2 per label * 3 labels
890    }
891
892    #[test]
893    fn test_ensemble_strategy() {
894        let (features, labels) = create_test_data();
895
896        let selector = MultiLabelFeatureSelector::new()
897            .strategy(MultiLabelStrategy::Ensemble)
898            .n_features(4);
899
900        let trained = selector.fit(&features, &labels).unwrap();
901        assert_eq!(trained.n_features_out(), 4);
902    }
903
904    #[test]
905    fn test_feature_ranking() {
906        let (features, labels) = create_test_data();
907
908        let selector = MultiLabelFeatureSelector::new().n_features(5);
909
910        let trained = selector.fit(&features, &labels).unwrap();
911        let ranking = trained.feature_ranking();
912
913        assert_eq!(ranking.len(), features.ncols());
914        // Check that selected features have better (lower) ranks
915        for &selected_idx in trained.selected_features() {
916            assert!(ranking[selected_idx] < 5);
917        }
918    }
919
920    #[test]
921    fn test_selector_mixin() {
922        let (features, labels) = create_test_data();
923
924        let selector = MultiLabelFeatureSelector::new().n_features(3);
925
926        let trained = selector.fit(&features, &labels).unwrap();
927        let support = trained.get_support().unwrap();
928
929        assert_eq!(support.len(), features.ncols());
930        assert_eq!(support.iter().filter(|&&x| x).count(), 3);
931    }
932
933    // Property-based tests for multi-label feature selection
934    mod proptests {
935        use super::*;
936
937        fn valid_array_2d() -> impl Strategy<Value = Array2<Float>> {
938            (5usize..20, 10usize..50).prop_flat_map(|(n_cols, n_rows)| {
939                prop::collection::vec(-10.0..10.0f64, n_rows * n_cols).prop_map(move |values| {
940                    Array2::from_shape_vec((n_rows, n_cols), values).unwrap()
941                })
942            })
943        }
944
945        fn valid_multilabel_target(
946            n_samples: usize,
947            n_labels: usize,
948        ) -> impl Strategy<Value = MultiLabelTarget> {
949            prop::collection::vec(0.0..1.0f64, n_samples * n_labels).prop_map(move |values| {
950                Array2::from_shape_vec((n_samples, n_labels), values).unwrap()
951            })
952        }
953
954        proptest! {
955            #[test]
956            fn prop_multi_label_selector_respects_feature_count(
957                features in valid_array_2d(),
958                n_features in 1usize..10
959            ) {
960                let n_labels = 3;
961                let labels = Array2::from_elem((features.nrows(), n_labels), 0.5);
962
963                let n_select = n_features.min(features.ncols());
964                let selector = MultiLabelFeatureSelector::new()
965                    .n_features(n_select);
966
967                if let Ok(trained) = selector.fit(&features, &labels) {
968                    prop_assert_eq!(trained.n_features_out(), n_select);
969                    prop_assert!(trained.selected_features().len() == n_select);
970
971                    // All selected features should be valid indices
972                    for &idx in trained.selected_features() {
973                        prop_assert!(idx < features.ncols());
974                    }
975
976                    // Transform should work correctly
977                    if let Ok(transformed) = trained.transform(&features) {
978                        prop_assert_eq!(transformed.ncols(), n_select);
979                        prop_assert_eq!(transformed.nrows(), features.nrows());
980                    }
981                }
982            }
983
984            #[test]
985            fn prop_multi_label_selector_deterministic(
986                features in valid_array_2d(),
987                n_features in 1usize..5
988            ) {
989                let n_labels = 2;
990                let labels = Array2::from_elem((features.nrows(), n_labels), 0.3);
991
992                let n_select = n_features.min(features.ncols());
993                let selector = MultiLabelFeatureSelector::new()
994                    .strategy(MultiLabelStrategy::GlobalRelevance)
995                    .n_features(n_select);
996
997                if let Ok(trained1) = selector.clone().fit(&features, &labels) {
998                    if let Ok(trained2) = selector.fit(&features, &labels) {
999                        // Same input should produce same output
1000                        prop_assert_eq!(trained1.selected_features(), trained2.selected_features());
1001                        prop_assert_eq!(trained1.n_features_out(), trained2.n_features_out());
1002                    }
1003                }
1004            }
1005
1006            #[test]
1007            fn prop_multi_label_selector_scores_non_negative(
1008                features in valid_array_2d(),
1009                n_features in 1usize..5
1010            ) {
1011                let n_labels = 2;
1012                let labels = Array2::from_elem((features.nrows(), n_labels), 0.4);
1013
1014                let n_select = n_features.min(features.ncols());
1015                let selector = MultiLabelFeatureSelector::new()
1016                    .n_features(n_select);
1017
1018                if let Ok(trained) = selector.fit(&features, &labels) {
1019                    let scores = trained.scores();
1020
1021                    // All scores should be non-negative (using absolute correlation)
1022                    for &score in scores.iter() {
1023                        prop_assert!(score >= 0.0);
1024                    }
1025
1026                    // Selected features should have higher scores
1027                    let selected_indices = trained.selected_features();
1028                    let min_selected_score = selected_indices.iter()
1029                        .map(|&idx| scores[idx])
1030                        .fold(f64::INFINITY, f64::min);
1031
1032                    // Count how many features have scores >= min_selected_score
1033                    let count_above_min = scores.iter()
1034                        .filter(|&&score| score >= min_selected_score)
1035                        .count();
1036
1037                    // Should be at least as many as selected
1038                    prop_assert!(count_above_min >= selected_indices.len());
1039                }
1040            }
1041
1042            #[test]
1043            fn prop_label_specific_selector_aggregation_consistency(
1044                features in valid_array_2d(),
1045                n_features_per_label in 1usize..3
1046            ) {
1047                let n_labels = 3;
1048                let labels = Array2::from_elem((features.nrows(), n_labels), 0.5);
1049
1050                let n_select = n_features_per_label.min(features.ncols());
1051
1052                // Test union aggregation
1053                let selector_union = LabelSpecificSelector::new()
1054                    .n_features_per_label(n_select)
1055                    .aggregate_method(AggregateMethod::Union);
1056
1057                if let Ok(trained_union) = selector_union.fit(&features, &labels) {
1058                    // Union should select at most n_select * n_labels features
1059                    prop_assert!(trained_union.n_features_out() <= n_select * n_labels);
1060
1061                    // Test intersection aggregation
1062                    let selector_intersect = LabelSpecificSelector::new()
1063                        .n_features_per_label(n_select)
1064                        .aggregate_method(AggregateMethod::Intersection);
1065
1066                    if let Ok(trained_intersect) = selector_intersect.fit(&features, &labels) {
1067                        // Intersection should select at most n_select features
1068                        prop_assert!(trained_intersect.n_features_out() <= n_select);
1069
1070                        // Intersection features should be subset of union features
1071                        let union_set: std::collections::HashSet<_> = trained_union.selected_features().iter().collect();
1072                        for &feature in trained_intersect.selected_features() {
1073                            prop_assert!(union_set.contains(&feature));
1074                        }
1075                    }
1076                }
1077            }
1078
1079            #[test]
1080            fn prop_multi_label_transform_preserves_samples(
1081                features in valid_array_2d(),
1082                n_features in 1usize..5
1083            ) {
1084                let n_labels = 2;
1085                let labels = Array2::from_elem((features.nrows(), n_labels), 0.4);
1086
1087                let n_select = n_features.min(features.ncols());
1088                let selector = MultiLabelFeatureSelector::new()
1089                    .n_features(n_select);
1090
1091                if let Ok(trained) = selector.fit(&features, &labels) {
1092                    if let Ok(transformed) = trained.transform(&features) {
1093                        // Transform should preserve number of samples
1094                        prop_assert_eq!(transformed.nrows(), features.nrows());
1095
1096                        // Should have correct number of features
1097                        prop_assert_eq!(transformed.ncols(), n_select);
1098
1099                        // Values should be from original features
1100                        for (sample_idx, row) in transformed.rows().into_iter().enumerate() {
1101                            for (feat_idx, &value) in row.iter().enumerate() {
1102                                let original_feat_idx = trained.selected_features()[feat_idx];
1103                                let expected_value = features[[sample_idx, original_feat_idx]];
1104                                prop_assert!((value - expected_value).abs() < 1e-10);
1105                            }
1106                        }
1107                    }
1108                }
1109            }
1110        }
1111    }
1112}