1use crate::base::{FeatureSelector, SelectorMixin};
7use scirs2_core::ndarray::{Array1, Array2};
8use sklears_core::{
9 error::{validate, Result as SklResult, SklearsError},
10 traits::{Estimator, Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::collections::{HashMap, HashSet};
14use std::marker::PhantomData;
15
/// Multi-label target matrix: one row per sample and one column per label.
///
/// Values are not validated. The selectors in this file compute Pearson
/// correlations against each column, and the label-frequency filter uses the
/// column mean, which is most meaningful for 0/1 indicator columns.
pub type MultiLabelTarget = Array2<Float>;
18
/// Strategy used by `MultiLabelFeatureSelector` to reduce per-label
/// feature/label correlations to a single relevance score per feature.
///
/// The enum carries no data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiLabelStrategy {
    /// Mean absolute Pearson correlation of the feature over all labels.
    GlobalRelevance,
    /// Maximum absolute correlation over the labels whose frequency (column mean)
    /// reaches `min_label_frequency`.
    LabelSpecific,
    /// Weighted mean absolute correlation, with per-label weights derived from the
    /// label-label correlation matrix.
    LabelCorrelationAware,
    /// Currently the `LabelSpecific` scores scaled by a constant factor; no explicit
    /// label hierarchy is modelled.
    HierarchicalLabels,
    /// Average of the `GlobalRelevance`, `LabelSpecific` and
    /// `LabelCorrelationAware` scores.
    Ensemble,
}
33
/// How `LabelSpecificSelector` merges the per-label feature selections into one
/// final feature set.
///
/// The enum carries no data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateMethod {
    /// Keep a feature selected for at least one label.
    Union,
    /// Keep a feature selected for every label.
    Intersection,
    /// Keep a feature selected for at least half of the labels (rounded up).
    MajorityVote,
    /// Intended as a weighted union; currently behaves exactly like `Union`.
    WeightedUnion,
}
46
/// Feature selector that scores each feature against every column of a
/// multi-label target and keeps the highest-scoring features.
///
/// Uses a type-state parameter: `MultiLabelFeatureSelector<Untrained>` holds only
/// configuration, and `fit` produces a `MultiLabelFeatureSelector<Trained>` in
/// which the trailing-underscore fields are populated.
#[derive(Debug, Clone)]
pub struct MultiLabelFeatureSelector<State = Untrained> {
    // Scoring strategy applied during `fit`.
    strategy: MultiLabelStrategy,
    // If `Some(k)`, keep exactly the `k` top-scoring features; otherwise `threshold` is used.
    n_features: Option<usize>,
    // Minimum score a feature needs to be kept when `n_features` is `None`.
    threshold: Float,
    // Labels whose column mean is below this value are skipped by the label-specific scoring.
    min_label_frequency: Float,
    // NOTE(review): stored and carried through `fit`, but not read by any scoring code in this file.
    use_label_correlation: bool,
    // NOTE(review): stored and carried through `fit`, but not read by any scoring code in this file.
    correlation_threshold: Float,
    // Zero-sized type-state marker (`Untrained` / `Trained`).
    state: PhantomData<State>,
    // Per-feature relevance scores computed by `fit`.
    scores_: Option<Array1<Float>>,
    // Indices of selected input features; their order is the column order produced by `transform`.
    selected_features_: Option<Vec<usize>>,
    // Number of input feature columns seen by `fit`.
    n_features_: Option<usize>,
    // Number of label columns seen by `fit`.
    n_labels_: Option<usize>,
}
63
64impl Default for MultiLabelFeatureSelector<Untrained> {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl MultiLabelFeatureSelector<Untrained> {
71 pub fn new() -> Self {
73 Self {
74 strategy: MultiLabelStrategy::LabelSpecific,
75 n_features: None,
76 threshold: 0.01,
77 min_label_frequency: 0.01,
78 use_label_correlation: true,
79 correlation_threshold: 0.1,
80 state: PhantomData,
81 scores_: None,
82 selected_features_: None,
83 n_features_: None,
84 n_labels_: None,
85 }
86 }
87
88 pub fn strategy(mut self, strategy: MultiLabelStrategy) -> Self {
90 self.strategy = strategy;
91 self
92 }
93
94 pub fn n_features(mut self, n_features: usize) -> Self {
96 self.n_features = Some(n_features);
97 self
98 }
99
100 pub fn threshold(mut self, threshold: Float) -> Self {
102 self.threshold = threshold;
103 self
104 }
105
106 pub fn min_label_frequency(mut self, frequency: Float) -> Self {
108 self.min_label_frequency = frequency;
109 self
110 }
111
112 pub fn use_label_correlation(mut self, use_correlation: bool) -> Self {
114 self.use_label_correlation = use_correlation;
115 self
116 }
117
118 pub fn correlation_threshold(mut self, threshold: Float) -> Self {
120 self.correlation_threshold = threshold;
121 self
122 }
123
124 fn compute_multi_label_relevance(
126 &self,
127 features: &Array2<Float>,
128 labels: &MultiLabelTarget,
129 ) -> SklResult<Array1<Float>> {
130 let n_features = features.ncols();
131 let mut relevance_scores = Array1::zeros(n_features);
132
133 match self.strategy {
134 MultiLabelStrategy::GlobalRelevance => {
135 self.compute_global_relevance(features, labels, &mut relevance_scores)?;
136 }
137 MultiLabelStrategy::LabelSpecific => {
138 self.compute_label_specific_relevance(features, labels, &mut relevance_scores)?;
139 }
140 MultiLabelStrategy::LabelCorrelationAware => {
141 self.compute_correlation_aware_relevance(features, labels, &mut relevance_scores)?;
142 }
143 MultiLabelStrategy::HierarchicalLabels => {
144 self.compute_hierarchical_relevance(features, labels, &mut relevance_scores)?;
145 }
146 MultiLabelStrategy::Ensemble => {
147 self.compute_ensemble_relevance(features, labels, &mut relevance_scores)?;
148 }
149 }
150
151 Ok(relevance_scores)
152 }
153
154 fn compute_global_relevance(
156 &self,
157 features: &Array2<Float>,
158 labels: &MultiLabelTarget,
159 scores: &mut Array1<Float>,
160 ) -> SklResult<()> {
161 let n_features = features.ncols();
162 let n_labels = labels.ncols();
163
164 for feature_idx in 0..n_features {
165 let feature_col = features.column(feature_idx);
166 let mut total_relevance = 0.0;
167
168 for label_idx in 0..n_labels {
169 let label_col = labels.column(label_idx);
170
171 let corr = self.compute_correlation(&feature_col, &label_col)?;
173 total_relevance += corr.abs();
174 }
175
176 scores[feature_idx] = total_relevance / n_labels as Float;
177 }
178
179 Ok(())
180 }
181
182 fn compute_label_specific_relevance(
184 &self,
185 features: &Array2<Float>,
186 labels: &MultiLabelTarget,
187 scores: &mut Array1<Float>,
188 ) -> SklResult<()> {
189 let n_features = features.ncols();
190 let n_labels = labels.ncols();
191
192 let mut label_relevances = Array2::zeros((n_labels, n_features));
194
195 for label_idx in 0..n_labels {
196 let label_col = labels.column(label_idx);
197
198 let label_frequency = label_col.sum() / label_col.len() as Float;
200 if label_frequency < self.min_label_frequency {
201 continue;
202 }
203
204 for feature_idx in 0..n_features {
205 let feature_col = features.column(feature_idx);
206 let corr = self.compute_correlation(&feature_col, &label_col)?;
207 label_relevances[[label_idx, feature_idx]] = corr.abs();
208 }
209 }
210
211 for feature_idx in 0..n_features {
213 let feature_relevances = label_relevances.column(feature_idx);
214 scores[feature_idx] = feature_relevances.iter().cloned().fold(0.0, Float::max);
215 }
216
217 Ok(())
218 }
219
220 fn compute_correlation_aware_relevance(
222 &self,
223 features: &Array2<Float>,
224 labels: &MultiLabelTarget,
225 scores: &mut Array1<Float>,
226 ) -> SklResult<()> {
227 let n_features = features.ncols();
228 let n_labels = labels.ncols();
229
230 let label_correlations = self.compute_label_correlation_matrix(labels)?;
232
233 for feature_idx in 0..n_features {
234 let feature_col = features.column(feature_idx);
235 let mut weighted_relevance = 0.0;
236 let mut total_weight = 0.0;
237
238 for label_idx in 0..n_labels {
239 let label_col = labels.column(label_idx);
240 let corr = self.compute_correlation(&feature_col, &label_col)?;
241
242 let label_weight = self.compute_label_weight(label_idx, &label_correlations);
244 weighted_relevance += corr.abs() * label_weight;
245 total_weight += label_weight;
246 }
247
248 scores[feature_idx] = if total_weight > 0.0 {
249 weighted_relevance / total_weight
250 } else {
251 0.0
252 };
253 }
254
255 Ok(())
256 }
257
258 fn compute_hierarchical_relevance(
260 &self,
261 features: &Array2<Float>,
262 labels: &MultiLabelTarget,
263 scores: &mut Array1<Float>,
264 ) -> SklResult<()> {
265 self.compute_label_specific_relevance(features, labels, scores)?;
268
269 for score in scores.iter_mut() {
271 *score *= 1.1; }
273
274 Ok(())
275 }
276
277 fn compute_ensemble_relevance(
279 &self,
280 features: &Array2<Float>,
281 labels: &MultiLabelTarget,
282 scores: &mut Array1<Float>,
283 ) -> SklResult<()> {
284 let n_features = features.ncols();
285 let mut global_scores = Array1::zeros(n_features);
286 let mut specific_scores = Array1::zeros(n_features);
287 let mut correlation_scores = Array1::zeros(n_features);
288
289 self.compute_global_relevance(features, labels, &mut global_scores)?;
291 self.compute_label_specific_relevance(features, labels, &mut specific_scores)?;
292 self.compute_correlation_aware_relevance(features, labels, &mut correlation_scores)?;
293
294 for feature_idx in 0..n_features {
296 scores[feature_idx] = (global_scores[feature_idx]
297 + specific_scores[feature_idx]
298 + correlation_scores[feature_idx])
299 / 3.0;
300 }
301
302 Ok(())
303 }
304
305 fn compute_correlation(
307 &self,
308 feature: &scirs2_core::ndarray::ArrayView1<Float>,
309 label: &scirs2_core::ndarray::ArrayView1<Float>,
310 ) -> SklResult<Float> {
311 let feature_mean = feature.mean().unwrap_or(0.0);
312 let label_mean = label.mean().unwrap_or(0.0);
313
314 let mut covariance = 0.0;
315 let mut feature_var = 0.0;
316 let mut label_var = 0.0;
317
318 let n = feature.len();
319 if n == 0 {
320 return Ok(0.0);
321 }
322
323 for i in 0..n {
324 let f_diff = feature[i] - feature_mean;
325 let l_diff = label[i] - label_mean;
326
327 covariance += f_diff * l_diff;
328 feature_var += f_diff * f_diff;
329 label_var += l_diff * l_diff;
330 }
331
332 if feature_var == 0.0 || label_var == 0.0 {
333 return Ok(0.0);
334 }
335
336 let correlation = covariance / (feature_var * label_var).sqrt();
337 Ok(correlation)
338 }
339
340 fn compute_label_correlation_matrix(
342 &self,
343 labels: &MultiLabelTarget,
344 ) -> SklResult<Array2<Float>> {
345 let n_labels = labels.ncols();
346 let mut correlations = Array2::zeros((n_labels, n_labels));
347
348 for i in 0..n_labels {
349 for j in 0..n_labels {
350 if i == j {
351 correlations[[i, j]] = 1.0;
352 } else {
353 let label_i = labels.column(i);
354 let label_j = labels.column(j);
355 let corr = self.compute_correlation(&label_i, &label_j)?;
356 correlations[[i, j]] = corr;
357 }
358 }
359 }
360
361 Ok(correlations)
362 }
363
364 fn compute_label_weight(&self, label_idx: usize, correlations: &Array2<Float>) -> Float {
366 let label_correlations = correlations.row(label_idx);
367 let avg_correlation = label_correlations.mean().unwrap_or(0.0);
368
369 1.0 - (avg_correlation - 0.5).abs()
371 }
372
373 fn select_features(&self, relevance_scores: &Array1<Float>) -> SklResult<Vec<usize>> {
375 let n_features = relevance_scores.len();
376
377 if let Some(k) = self.n_features {
378 if k > n_features {
379 return Err(SklearsError::InvalidInput(format!(
380 "n_features ({}) must be <= total features ({})",
381 k, n_features
382 )));
383 }
384 let mut indices: Vec<usize> = (0..n_features).collect();
386 indices.sort_by(|&a, &b| {
387 relevance_scores[b]
388 .partial_cmp(&relevance_scores[a])
389 .unwrap()
390 });
391 indices.truncate(k);
392 Ok(indices)
393 } else {
394 let selected: Vec<usize> = relevance_scores
396 .iter()
397 .enumerate()
398 .filter(|(_, &score)| score >= self.threshold)
399 .map(|(idx, _)| idx)
400 .collect();
401
402 if selected.is_empty() {
403 return Err(SklearsError::InvalidInput(
404 "No features selected with current threshold".to_string(),
405 ));
406 }
407 Ok(selected)
408 }
409 }
410}
411
/// `Estimator` plumbing for the untrained selector.
///
/// The configuration lives in the selector's own fields, so the associated
/// `Config` is the unit type and `config` returns a reference to `()`.
impl Estimator for MultiLabelFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
421
422impl Fit<Array2<Float>, MultiLabelTarget> for MultiLabelFeatureSelector<Untrained> {
423 type Fitted = MultiLabelFeatureSelector<Trained>;
424
425 fn fit(self, features: &Array2<Float>, target: &MultiLabelTarget) -> SklResult<Self::Fitted> {
426 if features.nrows() != target.nrows() {
428 return Err(SklearsError::InvalidInput(format!(
429 "Inconsistent numbers of samples: features has {} samples, target has {}",
430 features.nrows(),
431 target.nrows()
432 )));
433 }
434
435 let n_features = features.ncols();
436 let n_labels = target.ncols();
437
438 if n_features == 0 {
439 return Err(SklearsError::InvalidInput(
440 "No features provided".to_string(),
441 ));
442 }
443 if n_labels == 0 {
444 return Err(SklearsError::InvalidInput("No labels provided".to_string()));
445 }
446
447 let relevance_scores = self.compute_multi_label_relevance(features, target)?;
448 let selected_features = self.select_features(&relevance_scores)?;
449
450 Ok(MultiLabelFeatureSelector {
451 strategy: self.strategy,
452 n_features: self.n_features,
453 threshold: self.threshold,
454 min_label_frequency: self.min_label_frequency,
455 use_label_correlation: self.use_label_correlation,
456 correlation_threshold: self.correlation_threshold,
457 state: PhantomData,
458 scores_: Some(relevance_scores),
459 selected_features_: Some(selected_features),
460 n_features_: Some(n_features),
461 n_labels_: Some(n_labels),
462 })
463 }
464}
465
466impl Transform<Array2<Float>> for MultiLabelFeatureSelector<Trained> {
467 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
468 validate::check_n_features(x, self.n_features_.unwrap())?;
469
470 let selected_features = self.selected_features_.as_ref().unwrap();
471 let n_samples = x.nrows();
472 let n_selected = selected_features.len();
473 let mut x_new = Array2::zeros((n_samples, n_selected));
474
475 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
476 x_new.column_mut(new_idx).assign(&x.column(old_idx));
477 }
478
479 Ok(x_new)
480 }
481}
482
483impl SelectorMixin for MultiLabelFeatureSelector<Trained> {
484 fn get_support(&self) -> SklResult<Array1<bool>> {
485 let n_features = self.n_features_.unwrap();
486 let selected_features = self.selected_features_.as_ref().unwrap();
487 let mut support = Array1::from_elem(n_features, false);
488
489 for &idx in selected_features {
490 support[idx] = true;
491 }
492
493 Ok(support)
494 }
495
496 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
497 let selected_features = self.selected_features_.as_ref().unwrap();
498 Ok(indices
499 .iter()
500 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
501 .collect())
502 }
503}
504
impl FeatureSelector for MultiLabelFeatureSelector<Trained> {
    /// Returns the indices of the input features chosen during `fit`.
    ///
    /// Panics if `selected_features_` is `None`. `fit` always sets it, so this
    /// should only happen for a value not produced by `fit`.
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}
510
511impl MultiLabelFeatureSelector<Trained> {
512 pub fn scores(&self) -> &Array1<Float> {
514 self.scores_.as_ref().unwrap()
515 }
516
517 pub fn n_features_out(&self) -> usize {
519 self.selected_features_.as_ref().unwrap().len()
520 }
521
522 pub fn n_labels(&self) -> usize {
524 self.n_labels_.unwrap()
525 }
526
527 pub fn is_feature_selected(&self, feature_idx: usize) -> bool {
529 self.selected_features_
530 .as_ref()
531 .unwrap()
532 .contains(&feature_idx)
533 }
534
535 pub fn feature_ranking(&self) -> Vec<usize> {
537 let scores = self.scores_.as_ref().unwrap();
538 let mut indices: Vec<usize> = (0..scores.len()).collect();
539 indices.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
540
541 let mut ranking = vec![0; scores.len()];
542 for (rank, &feature_idx) in indices.iter().enumerate() {
543 ranking[feature_idx] = rank;
544 }
545 ranking
546 }
547}
548
/// Selects features separately for each label and merges the per-label
/// selections with an `AggregateMethod`.
///
/// Uses a type-state parameter: `fit` turns a `LabelSpecificSelector<Untrained>`
/// into a `LabelSpecificSelector<Trained>` with the trailing-underscore fields set.
#[derive(Debug, Clone)]
pub struct LabelSpecificSelector<State = Untrained> {
    // If `Some(k)`, keep the top `k` features per label; otherwise `threshold` is used.
    n_features_per_label: Option<usize>,
    // Minimum absolute correlation for a feature to be kept for a label in threshold mode.
    threshold: Float,
    // How the per-label selections are merged.
    aggregate_method: AggregateMethod,
    // Zero-sized type-state marker (`Untrained` / `Trained`).
    state: PhantomData<State>,
    // Final merged selection; its order is the column order produced by `transform`.
    selected_features_: Option<Vec<usize>>,
    // Per-label selections, indexed by label column.
    label_selections_: Option<Vec<Vec<usize>>>,
    // Number of input feature columns seen by `fit`.
    n_features_: Option<usize>,
    // Number of label columns seen by `fit`.
    n_labels_: Option<usize>,
}
562
563impl Default for LabelSpecificSelector<Untrained> {
564 fn default() -> Self {
565 Self::new()
566 }
567}
568
569impl LabelSpecificSelector<Untrained> {
570 pub fn new() -> Self {
571 Self {
572 n_features_per_label: None,
573 threshold: 0.01,
574 aggregate_method: AggregateMethod::Union,
575 state: PhantomData,
576 selected_features_: None,
577 label_selections_: None,
578 n_features_: None,
579 n_labels_: None,
580 }
581 }
582
583 pub fn n_features_per_label(mut self, n_features: usize) -> Self {
584 self.n_features_per_label = Some(n_features);
585 self
586 }
587
588 pub fn threshold(mut self, threshold: Float) -> Self {
589 self.threshold = threshold;
590 self
591 }
592
593 pub fn aggregate_method(mut self, method: AggregateMethod) -> Self {
594 self.aggregate_method = method;
595 self
596 }
597
598 fn select_for_label(
599 &self,
600 features: &Array2<Float>,
601 label: &scirs2_core::ndarray::ArrayView1<Float>,
602 ) -> SklResult<Vec<usize>> {
603 let n_features = features.ncols();
604 let mut scores = Array1::zeros(n_features);
605
606 for feature_idx in 0..n_features {
607 let feature_col = features.column(feature_idx);
608 scores[feature_idx] = self.compute_feature_label_relevance(&feature_col, label)?;
609 }
610
611 if let Some(k) = self.n_features_per_label {
612 let mut indices: Vec<usize> = (0..n_features).collect();
613 indices.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
614 indices.truncate(k);
615 Ok(indices)
616 } else {
617 Ok(scores
618 .iter()
619 .enumerate()
620 .filter(|(_, &score)| score >= self.threshold)
621 .map(|(idx, _)| idx)
622 .collect())
623 }
624 }
625
626 fn compute_feature_label_relevance(
627 &self,
628 feature: &scirs2_core::ndarray::ArrayView1<Float>,
629 label: &scirs2_core::ndarray::ArrayView1<Float>,
630 ) -> SklResult<Float> {
631 let feature_mean = feature.mean().unwrap_or(0.0);
633 let label_mean = label.mean().unwrap_or(0.0);
634
635 let mut numerator = 0.0;
636 let mut feature_variance = 0.0;
637 let mut label_variance = 0.0;
638
639 let n = feature.len();
640 for i in 0..n {
641 let f_diff = feature[i] - feature_mean;
642 let l_diff = label[i] - label_mean;
643
644 numerator += f_diff * l_diff;
645 feature_variance += f_diff * f_diff;
646 label_variance += l_diff * l_diff;
647 }
648
649 if feature_variance == 0.0 || label_variance == 0.0 {
650 return Ok(0.0);
651 }
652
653 let correlation = numerator / (feature_variance * label_variance).sqrt();
654 Ok(correlation.abs())
655 }
656
657 fn aggregate_selections(&self, label_selections: &[Vec<usize>]) -> Vec<usize> {
658 match self.aggregate_method {
659 AggregateMethod::Union => {
660 let mut result = HashSet::new();
661 for selection in label_selections {
662 result.extend(selection);
663 }
664 result.into_iter().collect()
665 }
666 AggregateMethod::Intersection => {
667 if label_selections.is_empty() {
668 return vec![];
669 }
670 let mut result: HashSet<usize> = label_selections[0].iter().cloned().collect();
671 for selection in &label_selections[1..] {
672 let selection_set: HashSet<usize> = selection.iter().cloned().collect();
673 result = result.intersection(&selection_set).cloned().collect();
674 }
675 result.into_iter().collect()
676 }
677 AggregateMethod::MajorityVote => {
678 let mut feature_counts: HashMap<usize, usize> = HashMap::new();
679 for selection in label_selections {
680 for &feature in selection {
681 *feature_counts.entry(feature).or_insert(0) += 1;
682 }
683 }
684 let majority_threshold = (label_selections.len() + 1) / 2;
685 feature_counts
686 .into_iter()
687 .filter(|(_, count)| *count >= majority_threshold)
688 .map(|(feature, _)| feature)
689 .collect()
690 }
691 AggregateMethod::WeightedUnion => {
692 let mut result = HashSet::new();
694 for selection in label_selections {
695 result.extend(selection);
696 }
697 result.into_iter().collect()
698 }
699 }
700 }
701}
702
/// `Estimator` plumbing for the untrained selector.
///
/// The configuration lives in the selector's own fields, so the associated
/// `Config` is the unit type and `config` returns a reference to `()`.
impl Estimator for LabelSpecificSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
712
713impl Fit<Array2<Float>, MultiLabelTarget> for LabelSpecificSelector<Untrained> {
714 type Fitted = LabelSpecificSelector<Trained>;
715
716 fn fit(self, features: &Array2<Float>, target: &MultiLabelTarget) -> SklResult<Self::Fitted> {
717 if features.nrows() != target.nrows() {
719 return Err(SklearsError::InvalidInput(format!(
720 "Inconsistent numbers of samples: features has {} samples, target has {}",
721 features.nrows(),
722 target.nrows()
723 )));
724 }
725
726 let n_features = features.ncols();
727 let n_labels = target.ncols();
728 let mut label_selections = Vec::with_capacity(n_labels);
729
730 for label_idx in 0..n_labels {
731 let label_col = target.column(label_idx);
732 let selection = self.select_for_label(features, &label_col)?;
733 label_selections.push(selection);
734 }
735
736 let selected_features = self.aggregate_selections(&label_selections);
737
738 Ok(LabelSpecificSelector {
739 n_features_per_label: self.n_features_per_label,
740 threshold: self.threshold,
741 aggregate_method: self.aggregate_method,
742 state: PhantomData,
743 selected_features_: Some(selected_features),
744 label_selections_: Some(label_selections),
745 n_features_: Some(n_features),
746 n_labels_: Some(n_labels),
747 })
748 }
749}
750
751impl Transform<Array2<Float>> for LabelSpecificSelector<Trained> {
752 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
753 validate::check_n_features(x, self.n_features_.unwrap())?;
754
755 let selected_features = self.selected_features_.as_ref().unwrap();
756 let n_samples = x.nrows();
757 let n_selected = selected_features.len();
758
759 if n_selected == 0 {
760 return Err(SklearsError::InvalidInput(
761 "No features were selected".to_string(),
762 ));
763 }
764
765 let mut x_new = Array2::zeros((n_samples, n_selected));
766
767 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
768 x_new.column_mut(new_idx).assign(&x.column(old_idx));
769 }
770
771 Ok(x_new)
772 }
773}
774
775impl SelectorMixin for LabelSpecificSelector<Trained> {
776 fn get_support(&self) -> SklResult<Array1<bool>> {
777 let n_features = self.n_features_.unwrap();
778 let selected_features = self.selected_features_.as_ref().unwrap();
779 let mut support = Array1::from_elem(n_features, false);
780
781 for &idx in selected_features {
782 support[idx] = true;
783 }
784
785 Ok(support)
786 }
787
788 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
789 let selected_features = self.selected_features_.as_ref().unwrap();
790 Ok(indices
791 .iter()
792 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
793 .collect())
794 }
795}
796
impl FeatureSelector for LabelSpecificSelector<Trained> {
    /// Returns the merged selection of input-feature indices computed by `fit`.
    ///
    /// Panics if `selected_features_` is `None`. `fit` always sets it, so this
    /// should only happen for a value not produced by `fit`.
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}
802
803impl LabelSpecificSelector<Trained> {
804 pub fn features_for_label(&self, label_idx: usize) -> Option<&[usize]> {
805 self.label_selections_
806 .as_ref()?
807 .get(label_idx)
808 .map(|v| v.as_slice())
809 }
810
811 pub fn n_features_out(&self) -> usize {
812 self.selected_features_.as_ref().unwrap().len()
813 }
814
815 pub fn n_labels(&self) -> usize {
816 self.n_labels_.unwrap()
817 }
818}
819
// NOTE(review): no non-snake-case identifiers are visible in this module; this
// allow may be removable.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests and proptest-based property tests for both selectors.
    use super::*;
    use proptest::prelude::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a 100x10 feature matrix and a 100x3 label matrix.
    ///
    /// NOTE(review): the label vector is filled row-major, so `i % 3` equals the
    /// column index. Label column 0 is all 1.0 and columns 1-2 are all 0.0.
    /// Every label column is constant, so every feature/label correlation is 0 and
    /// all scores tie. The tests below therefore check counts and shapes, not
    /// ranking quality.
    fn create_test_data() -> (Array2<Float>, MultiLabelTarget) {
        let features =
            Array2::from_shape_vec((100, 10), (0..1000).map(|i| (i as Float) * 0.01).collect())
                .unwrap();
        let labels = Array2::from_shape_vec(
            (100, 3),
            (0..300)
                .map(|i| if i % 3 == 0 { 1.0 } else { 0.0 })
                .collect(),
        )
        .unwrap();
        (features, labels)
    }

    #[test]
    fn test_multi_label_selector_global_relevance() {
        let (features, labels) = create_test_data();

        let selector = MultiLabelFeatureSelector::new()
            .strategy(MultiLabelStrategy::GlobalRelevance)
            .n_features(5);

        let trained = selector.fit(&features, &labels).unwrap();
        assert_eq!(trained.n_features_out(), 5);
        assert_eq!(trained.selected_features().len(), 5);
    }

    #[test]
    fn test_multi_label_selector_label_specific() {
        let (features, labels) = create_test_data();

        let selector = MultiLabelFeatureSelector::new()
            .strategy(MultiLabelStrategy::LabelSpecific)
            .n_features(3);
        let trained = selector.fit(&features, &labels).unwrap();
        assert_eq!(trained.n_features_out(), 3);
    }

    #[test]
    fn test_multi_label_transform() {
        let (features, labels) = create_test_data();

        let selector = MultiLabelFeatureSelector::new().n_features(3);

        let trained = selector.fit(&features, &labels).unwrap();
        let transformed = trained.transform(&features).unwrap();

        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), features.nrows());
    }

    #[test]
    fn test_label_specific_selector() {
        let (features, labels) = create_test_data();

        let selector = LabelSpecificSelector::new()
            .n_features_per_label(2)
            .aggregate_method(AggregateMethod::Union);

        let trained = selector.fit(&features, &labels).unwrap();
        assert!(trained.n_features_out() > 0);
        // Union of 3 labels x 2 features each can hold at most 6 distinct features.
        assert!(trained.n_features_out() <= 6);
    }

    #[test]
    fn test_ensemble_strategy() {
        let (features, labels) = create_test_data();

        let selector = MultiLabelFeatureSelector::new()
            .strategy(MultiLabelStrategy::Ensemble)
            .n_features(4);

        let trained = selector.fit(&features, &labels).unwrap();
        assert_eq!(trained.n_features_out(), 4);
    }

    #[test]
    fn test_feature_ranking() {
        let (features, labels) = create_test_data();

        let selector = MultiLabelFeatureSelector::new().n_features(5);

        let trained = selector.fit(&features, &labels).unwrap();
        let ranking = trained.feature_ranking();

        assert_eq!(ranking.len(), features.ncols());
        // The top-5 selection and the ranking sort the same scores the same way,
        // so every selected feature must have one of the first 5 ranks.
        for &selected_idx in trained.selected_features() {
            assert!(ranking[selected_idx] < 5);
        }
    }

    #[test]
    fn test_selector_mixin() {
        let (features, labels) = create_test_data();

        let selector = MultiLabelFeatureSelector::new().n_features(3);

        let trained = selector.fit(&features, &labels).unwrap();
        let support = trained.get_support().unwrap();

        assert_eq!(support.len(), features.ncols());
        assert_eq!(support.iter().filter(|&&x| x).count(), 3);
    }

    mod proptests {
        use super::*;

        /// Random feature matrices with 10-49 rows and 5-19 columns, values in [-10, 10).
        fn valid_array_2d() -> impl Strategy<Value = Array2<Float>> {
            (5usize..20, 10usize..50).prop_flat_map(|(n_cols, n_rows)| {
                prop::collection::vec(-10.0..10.0f64, n_rows * n_cols).prop_map(move |values| {
                    Array2::from_shape_vec((n_rows, n_cols), values).unwrap()
                })
            })
        }

        /// Random label matrices with values in [0, 1).
        ///
        /// NOTE(review): not referenced by any test below (likely a dead-code
        /// warning); the property tests build constant labels with
        /// `Array2::from_elem` instead, so every score they see is 0.
        fn valid_multilabel_target(
            n_samples: usize,
            n_labels: usize,
        ) -> impl Strategy<Value = MultiLabelTarget> {
            prop::collection::vec(0.0..1.0f64, n_samples * n_labels).prop_map(move |values| {
                Array2::from_shape_vec((n_samples, n_labels), values).unwrap()
            })
        }

        proptest! {
            // Top-k mode must return exactly k in-range indices, and transform must
            // produce k columns.
            #[test]
            fn prop_multi_label_selector_respects_feature_count(
                features in valid_array_2d(),
                n_features in 1usize..10
            ) {
                let n_labels = 3;
                let labels = Array2::from_elem((features.nrows(), n_labels), 0.5);

                let n_select = n_features.min(features.ncols());
                let selector = MultiLabelFeatureSelector::new()
                    .n_features(n_select);

                if let Ok(trained) = selector.fit(&features, &labels) {
                    prop_assert_eq!(trained.n_features_out(), n_select);
                    prop_assert!(trained.selected_features().len() == n_select);

                    for &idx in trained.selected_features() {
                        prop_assert!(idx < features.ncols());
                    }

                    if let Ok(transformed) = trained.transform(&features) {
                        prop_assert_eq!(transformed.ncols(), n_select);
                        prop_assert_eq!(transformed.nrows(), features.nrows());
                    }
                }
            }

            // Fitting the same configuration twice on the same data must give the
            // same selection.
            #[test]
            fn prop_multi_label_selector_deterministic(
                features in valid_array_2d(),
                n_features in 1usize..5
            ) {
                let n_labels = 2;
                let labels = Array2::from_elem((features.nrows(), n_labels), 0.3);

                let n_select = n_features.min(features.ncols());
                let selector = MultiLabelFeatureSelector::new()
                    .strategy(MultiLabelStrategy::GlobalRelevance)
                    .n_features(n_select);

                if let Ok(trained1) = selector.clone().fit(&features, &labels) {
                    if let Ok(trained2) = selector.fit(&features, &labels) {
                        prop_assert_eq!(trained1.selected_features(), trained2.selected_features());
                        prop_assert_eq!(trained1.n_features_out(), trained2.n_features_out());
                    }
                }
            }

            // Scores must be non-negative (they are built from absolute correlations).
            #[test]
            fn prop_multi_label_selector_scores_non_negative(
                features in valid_array_2d(),
                n_features in 1usize..5
            ) {
                let n_labels = 2;
                let labels = Array2::from_elem((features.nrows(), n_labels), 0.4);

                let n_select = n_features.min(features.ncols());
                let selector = MultiLabelFeatureSelector::new()
                    .n_features(n_select);

                if let Ok(trained) = selector.fit(&features, &labels) {
                    let scores = trained.scores();

                    for &score in scores.iter() {
                        prop_assert!(score >= 0.0);
                    }

                    let selected_indices = trained.selected_features();
                    let min_selected_score = selected_indices.iter()
                        .map(|&idx| scores[idx])
                        .fold(f64::INFINITY, f64::min);

                    let count_above_min = scores.iter()
                        .filter(|&&score| score >= min_selected_score)
                        .count();

                    // Weak check: the selected scores themselves are all >= their
                    // minimum, so this holds whenever the scores are not NaN.
                    prop_assert!(count_above_min >= selected_indices.len());
                }
            }

            // The Union result is bounded by k * n_labels, the Intersection result
            // by k, and Intersection must be a subset of Union.
            #[test]
            fn prop_label_specific_selector_aggregation_consistency(
                features in valid_array_2d(),
                n_features_per_label in 1usize..3
            ) {
                let n_labels = 3;
                let labels = Array2::from_elem((features.nrows(), n_labels), 0.5);

                let n_select = n_features_per_label.min(features.ncols());

                let selector_union = LabelSpecificSelector::new()
                    .n_features_per_label(n_select)
                    .aggregate_method(AggregateMethod::Union);

                if let Ok(trained_union) = selector_union.fit(&features, &labels) {
                    prop_assert!(trained_union.n_features_out() <= n_select * n_labels);

                    let selector_intersect = LabelSpecificSelector::new()
                        .n_features_per_label(n_select)
                        .aggregate_method(AggregateMethod::Intersection);

                    if let Ok(trained_intersect) = selector_intersect.fit(&features, &labels) {
                        prop_assert!(trained_intersect.n_features_out() <= n_select);

                        let union_set: std::collections::HashSet<_> = trained_union.selected_features().iter().collect();
                        for &feature in trained_intersect.selected_features() {
                            prop_assert!(union_set.contains(&feature));
                        }
                    }
                }
            }

            // Transform must keep every row and copy the selected columns' values
            // unchanged.
            #[test]
            fn prop_multi_label_transform_preserves_samples(
                features in valid_array_2d(),
                n_features in 1usize..5
            ) {
                let n_labels = 2;
                let labels = Array2::from_elem((features.nrows(), n_labels), 0.4);

                let n_select = n_features.min(features.ncols());
                let selector = MultiLabelFeatureSelector::new()
                    .n_features(n_select);

                if let Ok(trained) = selector.fit(&features, &labels) {
                    if let Ok(transformed) = trained.transform(&features) {
                        prop_assert_eq!(transformed.nrows(), features.nrows());

                        prop_assert_eq!(transformed.ncols(), n_select);

                        for (sample_idx, row) in transformed.rows().into_iter().enumerate() {
                            for (feat_idx, &value) in row.iter().enumerate() {
                                let original_feat_idx = trained.selected_features()[feat_idx];
                                let expected_value = features[[sample_idx, original_feat_idx]];
                                prop_assert!((value - expected_value).abs() < 1e-10);
                            }
                        }
                    }
                }
            }
        }
    }
}