1use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::marker::PhantomData;
10
11type Result<T> = SklResult<T>;
12
/// Zero-sized marker types named after categories of feature-selection methods.
///
/// Every item is a unit struct that carries no data. Nothing in the visible
/// part of this file refers to them; presumably they are intended as
/// type-level tags (for example for the `Method` parameter of
/// `TypeSafeSelector`) — TODO confirm against callers.
pub mod selection_types {
    /// Marker named for filter-style methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Filter;

    /// Marker named for wrapper-style methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Wrapper;

    /// Marker named for embedded methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Embedded;

    /// Marker named for univariate methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Univariate;

    /// Marker named for multivariate methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Multivariate;

    /// Marker named for supervised methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Supervised;

    /// Marker named for unsupervised methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Unsupervised;

    /// Marker named for deterministic methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Deterministic;

    /// Marker named for stochastic methods (meaning inferred from the name).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Stochastic;
}
51
/// Zero-sized marker types used as type-state parameters.
///
/// Each item is a unit struct with no data; the state exists only at the type
/// level.
pub mod data_states {
    /// Default `State` parameter of `TypeSafeSelector` and
    /// `TypeSafeSelectionPipeline`.
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Untrained;

    /// State of the pipeline value returned by `TypeSafeSelectionPipeline::fit`.
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Trained;

    /// Not referenced in the visible part of this file.
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Validated;

    /// Not referenced in the visible part of this file.
    #[derive(Debug, Clone, Copy, Default)]
    pub struct Optimized;
}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
73pub struct FeatureIndex<const MAX_FEATURES: usize> {
74 index: usize,
75}
76
77impl<const MAX_FEATURES: usize> FeatureIndex<MAX_FEATURES> {
78 pub const fn new(index: usize) -> Option<Self> {
80 if index < MAX_FEATURES {
81 Some(Self { index })
82 } else {
83 None
84 }
85 }
86
87 pub const unsafe fn new_unchecked(index: usize) -> Self {
92 Self { index }
93 }
94
95 pub const fn get(self) -> usize {
97 self.index
98 }
99
100 pub const fn to_runtime(self) -> RuntimeFeatureIndex {
102 RuntimeFeatureIndex::new(self.index)
103 }
104}
105
/// A feature position whose bound is only known at run time.
///
/// Construction never fails; use [`RuntimeFeatureIndex::is_valid`] to compare
/// the position against a concrete feature count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RuntimeFeatureIndex {
    index: usize,
}

impl RuntimeFeatureIndex {
    /// Wraps `index` without any validation.
    pub const fn new(index: usize) -> Self {
        RuntimeFeatureIndex { index }
    }

    /// Returns the wrapped position.
    pub const fn get(self) -> usize {
        let Self { index } = self;
        index
    }

    /// Returns `true` when the position is strictly below `n_features`.
    pub const fn is_valid(self, n_features: usize) -> bool {
        n_features > self.get()
    }
}
128
129#[derive(Debug, Clone)]
131pub struct FeatureMask<const N_FEATURES: usize> {
132 mask: [bool; N_FEATURES],
133}
134
135impl<const N_FEATURES: usize> FeatureMask<N_FEATURES> {
136 pub const fn all_selected() -> Self {
138 Self {
139 mask: [true; N_FEATURES],
140 }
141 }
142
143 pub const fn none_selected() -> Self {
145 Self {
146 mask: [false; N_FEATURES],
147 }
148 }
149
150 pub const fn from_array(mask: [bool; N_FEATURES]) -> Self {
152 Self { mask }
153 }
154
155 pub fn from_indices(indices: &[FeatureIndex<N_FEATURES>]) -> Self {
157 let mut mask = [false; N_FEATURES];
158 for &index in indices {
159 mask[index.get()] = true;
160 }
161 Self { mask }
162 }
163
164 pub const fn as_array(&self) -> &[bool; N_FEATURES] {
166 &self.mask
167 }
168
169 pub const fn is_selected(&self, index: FeatureIndex<N_FEATURES>) -> bool {
171 self.mask[index.get()]
172 }
173
174 pub fn set(&mut self, index: FeatureIndex<N_FEATURES>, selected: bool) {
176 self.mask[index.get()] = selected;
177 }
178
179 pub fn count_selected(&self) -> usize {
181 self.mask.iter().filter(|&&x| x).count()
182 }
183
184 pub fn selected_indices(&self) -> Vec<FeatureIndex<N_FEATURES>> {
186 self.mask
187 .iter()
188 .enumerate()
189 .filter_map(|(i, &selected)| {
190 if selected {
191 Some(unsafe { FeatureIndex::new_unchecked(i) })
193 } else {
194 None
195 }
196 })
197 .collect()
198 }
199
200 pub fn and(&self, other: &Self) -> Self {
202 let mut result = [false; N_FEATURES];
203 for i in 0..N_FEATURES {
204 result[i] = self.mask[i] && other.mask[i];
205 }
206 Self::from_array(result)
207 }
208
209 pub fn or(&self, other: &Self) -> Self {
211 let mut result = [false; N_FEATURES];
212 for i in 0..N_FEATURES {
213 result[i] = self.mask[i] || other.mask[i];
214 }
215 Self::from_array(result)
216 }
217
218 pub fn not(&self) -> Self {
220 let mut result = [false; N_FEATURES];
221 for i in 0..N_FEATURES {
222 result[i] = !self.mask[i];
223 }
224 Self::from_array(result)
225 }
226}
227
228#[derive(Debug, Clone)]
230pub struct FeatureMatrix<T, const N_FEATURES: usize> {
231 data: Array2<T>,
232 _phantom: PhantomData<[T; N_FEATURES]>,
233}
234
235impl<T, const N_FEATURES: usize> FeatureMatrix<T, N_FEATURES>
236where
237 T: Clone + Default,
238{
239 pub fn new(data: Array2<T>) -> Result<Self> {
241 if data.ncols() == N_FEATURES {
242 Ok(Self {
243 data,
244 _phantom: PhantomData,
245 })
246 } else {
247 Err(SklearsError::InvalidInput(format!(
248 "Expected {} features, got {}",
249 N_FEATURES,
250 data.ncols()
251 )))
252 }
253 }
254
255 pub unsafe fn new_unchecked(data: Array2<T>) -> Self {
260 Self {
261 data,
262 _phantom: PhantomData,
263 }
264 }
265
266 pub fn n_samples(&self) -> usize {
268 self.data.nrows()
269 }
270
271 pub const fn n_features(&self) -> usize {
273 N_FEATURES
274 }
275
276 pub fn view(&self) -> ArrayView2<'_, T> {
278 self.data.view()
279 }
280
281 pub fn feature(&self, index: FeatureIndex<N_FEATURES>) -> ArrayView1<'_, T> {
283 self.data.column(index.get())
284 }
285
286 pub fn select_features<const N_SELECTED: usize>(
288 &self,
289 mask: &FeatureMask<N_FEATURES>,
290 ) -> Result<FeatureMatrix<T, N_SELECTED>> {
291 let selected_indices = mask.selected_indices();
292 if selected_indices.len() != N_SELECTED {
293 return Err(SklearsError::InvalidInput(format!(
294 "Expected {} selected features, got {}",
295 N_SELECTED,
296 selected_indices.len()
297 )));
298 }
299
300 let mut selected_data = Array2::default((self.n_samples(), N_SELECTED));
301 for (new_col, &old_index) in selected_indices.iter().enumerate() {
302 for row in 0..self.n_samples() {
303 selected_data[[row, new_col]] = self.data[[row, old_index.get()]].clone();
304 }
305 }
306
307 Ok(FeatureMatrix {
308 data: selected_data,
309 _phantom: PhantomData,
310 })
311 }
312
313 pub fn to_dynamic(self) -> DynamicFeatureMatrix<T> {
315 DynamicFeatureMatrix::new(self.data)
316 }
317}
318
319#[derive(Debug, Clone)]
321pub struct DynamicFeatureMatrix<T> {
322 data: Array2<T>,
323}
324
325impl<T> DynamicFeatureMatrix<T> {
326 pub fn new(data: Array2<T>) -> Self {
328 Self { data }
329 }
330
331 pub fn n_samples(&self) -> usize {
333 self.data.nrows()
334 }
335
336 pub fn n_features(&self) -> usize {
338 self.data.ncols()
339 }
340
341 pub fn view(&self) -> ArrayView2<'_, T> {
343 self.data.view()
344 }
345
346 pub fn feature(&self, index: RuntimeFeatureIndex) -> Result<ArrayView1<'_, T>> {
348 if index.is_valid(self.n_features()) {
349 Ok(self.data.column(index.get()))
350 } else {
351 Err(SklearsError::InvalidInput(format!(
352 "Feature index {} out of bounds for {} features",
353 index.get(),
354 self.n_features()
355 )))
356 }
357 }
358
359 pub fn to_static<const N_FEATURES: usize>(self) -> Result<FeatureMatrix<T, N_FEATURES>>
361 where
362 T: Clone + Default,
363 {
364 FeatureMatrix::new(self.data)
365 }
366}
367
/// A selector tagged with a `Method` and a type-level `State`.
///
/// `State` defaults to `data_states::Untrained`. No implementation of this
/// trait appears in the visible part of the file.
pub trait TypeSafeSelector<Method, State = data_states::Untrained> {
    /// State tag carried by the wrapper that `fit_typed` returns.
    type FittedState;

    /// Not used by any signature in this trait.
    /// NOTE(review): presumably the type produced by applying the selection — confirm.
    type SelectionResult;

    /// Consumes the selector and fits it on `X` (exactly `N_FEATURES`
    /// columns) and the target view `y`, returning a
    /// `TypeSafeSelectorWrapper` tagged with `Self::FittedState`.
    fn fit_typed<const N_FEATURES: usize>(
        self,
        X: &FeatureMatrix<f64, N_FEATURES>,
        y: ArrayView1<f64>,
    ) -> Result<TypeSafeSelectorWrapper<Method, Self::FittedState, N_FEATURES>>;
}
383
/// Holds method parameters and an optional selection mask, tagged at the type
/// level with a `Method` and a `State`.
///
/// `Method` and `State` appear only inside `PhantomData`; no value of either
/// type is stored.
#[derive(Debug, Clone)]
pub struct TypeSafeSelectorWrapper<Method, State, const N_FEATURES: usize> {
    // Stored by `new`; never read in the visible part of this file.
    method_params: MethodParameters,
    // `None` until `set_selection` is called.
    selection_result: Option<FeatureMask<N_FEATURES>>,
    _phantom: PhantomData<(Method, State)>,
}

impl<Method, State, const N_FEATURES: usize> TypeSafeSelectorWrapper<Method, State, N_FEATURES> {
    /// Creates a wrapper holding `method_params` and no selection mask.
    pub fn new(method_params: MethodParameters) -> Self {
        Self {
            method_params,
            selection_result: None,
            _phantom: PhantomData,
        }
    }

    /// Returns the stored mask, or `None` if none has been set.
    pub fn selection_mask(&self) -> Option<&FeatureMask<N_FEATURES>> {
        self.selection_result.as_ref()
    }

    /// Stores `mask`, replacing any previously stored mask.
    pub fn set_selection(&mut self, mask: FeatureMask<N_FEATURES>) {
        self.selection_result = Some(mask);
    }
}
412
/// Parameter sets for different selection methods, one variant per method.
///
/// The visible code only stores values of this type (in
/// `TypeSafeSelectorWrapper`); it never reads the fields, so the field
/// descriptions below are inferred from their names — TODO confirm.
#[derive(Debug, Clone)]
pub enum MethodParameters {
    /// Parameters named for a variance-threshold method.
    VarianceThreshold {
        /// Presumably the variance cut-off.
        threshold: f64,
    },
    /// Parameters named for a univariate filter.
    UnivariateFilter {
        /// Presumably the number of features to keep.
        k: usize,

        /// Score function given as a string; accepted values are not visible here.
        score_function: String,
    },
    /// Parameters named for recursive elimination.
    RecursiveElimination {
        /// Presumably the target number of features.
        n_features: usize,

        /// Presumably the elimination step; units are not visible here.
        step: f64,
    },
    /// Parameters named for Lasso-based selection.
    LassoSelection {
        alpha: f64,
        max_iter: usize,
    },
    /// Parameters named for tree-based selection.
    TreeBasedSelection {
        n_estimators: usize,
        max_depth: Option<usize>,
    },
    /// Parameters named for a correlation filter.
    CorrelationFilter {
        threshold: f64,
    },
    /// Parameters named for mutual-information selection.
    MutualInfoSelection {
        k: usize,
        discrete_features: Vec<bool>,
    },
}
448
449#[derive(Debug, Clone)]
451pub struct VarianceThresholdSelector<const N_FEATURES: usize> {
452 threshold: f64,
453 feature_variances: Option<[f64; N_FEATURES]>,
454}
455
456impl<const N_FEATURES: usize> VarianceThresholdSelector<N_FEATURES> {
457 pub const fn new(threshold: f64) -> Self {
459 Self {
460 threshold,
461 feature_variances: None,
462 }
463 }
464
465 pub fn fit(&mut self, X: &FeatureMatrix<f64, N_FEATURES>) -> Result<FeatureMask<N_FEATURES>> {
467 let mut variances = [0.0; N_FEATURES];
468
469 for i in 0..N_FEATURES {
470 let feature_index = unsafe { FeatureIndex::new_unchecked(i) };
472 let feature_data = X.feature(feature_index);
473 variances[i] = feature_data.var(1.0);
474 }
475
476 self.feature_variances = Some(variances);
477
478 let mut mask = [false; N_FEATURES];
479 for i in 0..N_FEATURES {
480 mask[i] = variances[i] > self.threshold;
481 }
482
483 Ok(FeatureMask::from_array(mask))
484 }
485
486 pub fn transform<const N_SELECTED: usize>(
488 &self,
489 X: &FeatureMatrix<f64, N_FEATURES>,
490 mask: &FeatureMask<N_FEATURES>,
491 ) -> Result<FeatureMatrix<f64, N_SELECTED>> {
492 X.select_features(mask)
493 }
494
495 pub const fn feature_variances(&self) -> Option<&[f64; N_FEATURES]> {
497 self.feature_variances.as_ref()
498 }
499}
500
501#[derive(Debug, Clone)]
503pub struct UnivariateSelector<const N_FEATURES: usize, const K: usize> {
504 score_function: UnivariateScoreFunction,
505 feature_scores: Option<[f64; N_FEATURES]>,
506}
507
508impl<const N_FEATURES: usize, const K: usize> UnivariateSelector<N_FEATURES, K> {
509 pub const fn new(score_function: UnivariateScoreFunction) -> Option<Self> {
514 if K <= N_FEATURES {
515 Some(Self {
516 score_function,
517 feature_scores: None,
518 })
519 } else {
520 None
521 }
522 }
523
524 pub fn fit(
526 &mut self,
527 X: &FeatureMatrix<f64, N_FEATURES>,
528 y: ArrayView1<f64>,
529 ) -> Result<FeatureMask<N_FEATURES>> {
530 let mut scores = [0.0; N_FEATURES];
531
532 for i in 0..N_FEATURES {
533 let feature_index = unsafe { FeatureIndex::new_unchecked(i) };
535 let feature_data = X.feature(feature_index);
536 scores[i] = match self.score_function {
537 UnivariateScoreFunction::Correlation => self.compute_correlation(feature_data, y),
538 UnivariateScoreFunction::MutualInfo => self.compute_mutual_info(feature_data, y),
539 UnivariateScoreFunction::Chi2 => self.compute_chi2_score(feature_data, y),
540 UnivariateScoreFunction::FStatistic => self.compute_f_statistic(feature_data, y),
541 };
542 }
543
544 self.feature_scores = Some(scores);
545
546 let mut indexed_scores: Vec<(usize, f64)> = Vec::with_capacity(N_FEATURES);
548 for i in 0..N_FEATURES {
549 indexed_scores.push((i, scores[i]));
550 }
551
552 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
554
555 let mut mask = [false; N_FEATURES];
556 for i in 0..K {
557 if let Some(&(feature_idx, _)) = indexed_scores.get(i) {
558 mask[feature_idx] = true;
559 }
560 }
561
562 Ok(FeatureMask::from_array(mask))
563 }
564
565 pub fn transform(
567 &self,
568 X: &FeatureMatrix<f64, N_FEATURES>,
569 mask: &FeatureMask<N_FEATURES>,
570 ) -> Result<FeatureMatrix<f64, K>> {
571 X.select_features(mask)
572 }
573
574 pub const fn feature_scores(&self) -> Option<&[f64; N_FEATURES]> {
576 self.feature_scores.as_ref()
577 }
578
579 fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
581 let n = x.len() as f64;
582 if n < 2.0 {
583 return 0.0;
584 }
585
586 let mean_x = x.mean().unwrap_or(0.0);
587 let mean_y = y.mean().unwrap_or(0.0);
588
589 let mut sum_xy = 0.0;
590 let mut sum_x2 = 0.0;
591 let mut sum_y2 = 0.0;
592
593 for i in 0..x.len() {
594 let dx = x[i] - mean_x;
595 let dy = y[i] - mean_y;
596 sum_xy += dx * dy;
597 sum_x2 += dx * dx;
598 sum_y2 += dy * dy;
599 }
600
601 let denom = (sum_x2 * sum_y2).sqrt();
602 if denom < 1e-10 {
603 0.0
604 } else {
605 (sum_xy / denom).abs()
606 }
607 }
608
609 fn compute_mutual_info(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
610 self.compute_correlation(x, y)
613 }
614
615 fn compute_chi2_score(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
616 self.compute_correlation(x, y)
619 }
620
621 fn compute_f_statistic(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
622 self.compute_correlation(x, y)
625 }
626}
627
/// Chooses how `UnivariateSelector` scores a feature against the target.
///
/// In the visible `UnivariateSelector`, the `MutualInfo`, `Chi2` and
/// `FStatistic` scorers all delegate to the same absolute-correlation
/// computation as `Correlation`, so every variant currently yields the same
/// scores.
#[derive(Debug, Clone, Copy)]
pub enum UnivariateScoreFunction {
    /// Absolute Pearson correlation with the target.
    Correlation,
    /// Named for mutual information; see the note on the enum.
    MutualInfo,
    /// Named for the chi-squared statistic; see the note on the enum.
    Chi2,
    /// Named for the F-statistic; see the note on the enum.
    FStatistic,
}
640
641#[derive(Debug, Clone)]
643pub struct CorrelationSelector<const N_FEATURES: usize> {
644 threshold: f64,
645 correlation_matrix: Option<[[f64; N_FEATURES]; N_FEATURES]>,
646}
647
648impl<const N_FEATURES: usize> CorrelationSelector<N_FEATURES> {
649 pub const fn new(threshold: f64) -> Self {
651 Self {
652 threshold,
653 correlation_matrix: None,
654 }
655 }
656
657 pub fn fit(&mut self, X: &FeatureMatrix<f64, N_FEATURES>) -> Result<FeatureMask<N_FEATURES>> {
659 let mut corr_matrix = [[0.0; N_FEATURES]; N_FEATURES];
660
661 for i in 0..N_FEATURES {
663 for j in 0..N_FEATURES {
664 if i == j {
665 corr_matrix[i][j] = 1.0;
666 } else {
667 let feature_i = unsafe { FeatureIndex::new_unchecked(i) };
669 let feature_j = unsafe { FeatureIndex::new_unchecked(j) };
670 let data_i = X.feature(feature_i);
671 let data_j = X.feature(feature_j);
672 corr_matrix[i][j] = self.compute_correlation(data_i, data_j);
673 }
674 }
675 }
676
677 self.correlation_matrix = Some(corr_matrix);
678
679 let mut mask = [true; N_FEATURES];
681 for i in 0..N_FEATURES {
682 for j in (i + 1)..N_FEATURES {
683 if corr_matrix[i][j].abs() > self.threshold && mask[i] && mask[j] {
684 let feature_i = unsafe { FeatureIndex::new_unchecked(i) };
687 let feature_j = unsafe { FeatureIndex::new_unchecked(j) };
688 let var_i = X.feature(feature_i).var(1.0);
689 let var_j = X.feature(feature_j).var(1.0);
690 if var_i < var_j {
691 mask[i] = false;
692 } else {
693 mask[j] = false;
694 }
695 }
696 }
697 }
698
699 Ok(FeatureMask::from_array(mask))
700 }
701
702 pub fn transform<const N_SELECTED: usize>(
704 &self,
705 X: &FeatureMatrix<f64, N_FEATURES>,
706 mask: &FeatureMask<N_FEATURES>,
707 ) -> Result<FeatureMatrix<f64, N_SELECTED>> {
708 X.select_features(mask)
709 }
710
711 pub const fn correlation_matrix(&self) -> Option<&[[f64; N_FEATURES]; N_FEATURES]> {
713 self.correlation_matrix.as_ref()
714 }
715
716 fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
717 let n = x.len() as f64;
718 if n < 2.0 {
719 return 0.0;
720 }
721
722 let mean_x = x.mean().unwrap_or(0.0);
723 let mean_y = y.mean().unwrap_or(0.0);
724
725 let mut sum_xy = 0.0;
726 let mut sum_x2 = 0.0;
727 let mut sum_y2 = 0.0;
728
729 for i in 0..x.len() {
730 let dx = x[i] - mean_x;
731 let dy = y[i] - mean_y;
732 sum_xy += dx * dy;
733 sum_x2 += dx * dx;
734 sum_y2 += dy * dy;
735 }
736
737 let denom = (sum_x2 * sum_y2).sqrt();
738 if denom < 1e-10 {
739 0.0
740 } else {
741 sum_xy / denom
742 }
743 }
744}
745
746#[derive(Debug, Clone)]
748pub struct TypeSafeSelectionPipeline<const N_FEATURES: usize, State = data_states::Untrained> {
749 steps: Vec<PipelineStep>,
750 current_mask: Option<FeatureMask<N_FEATURES>>,
751 _phantom: PhantomData<State>,
752}
753
754impl<const N_FEATURES: usize> Default
755 for TypeSafeSelectionPipeline<N_FEATURES, data_states::Untrained>
756{
757 fn default() -> Self {
758 Self::new()
759 }
760}
761
762impl<const N_FEATURES: usize> TypeSafeSelectionPipeline<N_FEATURES, data_states::Untrained> {
763 pub const fn new() -> Self {
765 Self {
766 steps: Vec::new(),
767 current_mask: None,
768 _phantom: PhantomData,
769 }
770 }
771
772 pub fn add_variance_threshold(mut self, threshold: f64) -> Self {
774 self.steps.push(PipelineStep::VarianceThreshold(threshold));
775 self
776 }
777
778 pub fn add_correlation_filter(mut self, threshold: f64) -> Self {
780 self.steps.push(PipelineStep::CorrelationFilter(threshold));
781 self
782 }
783
784 pub fn add_univariate_selection<const K: usize>(
786 mut self,
787 score_function: UnivariateScoreFunction,
788 ) -> Self {
789 self.steps.push(PipelineStep::UnivariateSelection {
790 k: K,
791 score_function,
792 });
793 self
794 }
795
796 pub fn fit(
798 self,
799 X: &FeatureMatrix<f64, N_FEATURES>,
800 y: ArrayView1<f64>,
801 ) -> Result<TypeSafeSelectionPipeline<N_FEATURES, data_states::Trained>> {
802 let mut current_mask = FeatureMask::all_selected();
803
804 for step in &self.steps {
805 let step_mask = match step {
806 PipelineStep::VarianceThreshold(threshold) => {
807 let mut selector = VarianceThresholdSelector::new(*threshold);
808 selector.fit(X)?
809 }
810 PipelineStep::CorrelationFilter(threshold) => {
811 let mut selector = CorrelationSelector::new(*threshold);
812 selector.fit(X)?
813 }
814 PipelineStep::UnivariateSelection {
815 k: _,
816 score_function,
817 } => {
818 const DEFAULT_K: usize = 10;
821 if DEFAULT_K <= N_FEATURES {
822 let mut selector =
823 UnivariateSelector::<N_FEATURES, DEFAULT_K>::new(*score_function)
824 .ok_or_else(|| {
825 SklearsError::InvalidInput(
826 "Invalid K for univariate selection".to_string(),
827 )
828 })?;
829 selector.fit(X, y)?
830 } else {
831 FeatureMask::all_selected()
832 }
833 }
834 };
835
836 current_mask = current_mask.and(&step_mask);
837 }
838
839 Ok(TypeSafeSelectionPipeline {
840 steps: self.steps,
841 current_mask: Some(current_mask),
842 _phantom: PhantomData,
843 })
844 }
845}
846
847impl<const N_FEATURES: usize> TypeSafeSelectionPipeline<N_FEATURES, data_states::Trained> {
848 pub fn transform<const N_SELECTED: usize>(
850 &self,
851 X: &FeatureMatrix<f64, N_FEATURES>,
852 ) -> Result<FeatureMatrix<f64, N_SELECTED>> {
853 if let Some(ref mask) = self.current_mask {
854 X.select_features(mask)
855 } else {
856 Err(SklearsError::FitError("Pipeline not fitted".to_string()))
857 }
858 }
859
860 pub fn selection_mask(&self) -> Option<&FeatureMask<N_FEATURES>> {
862 self.current_mask.as_ref()
863 }
864
865 pub fn n_selected_features(&self) -> usize {
867 self.current_mask
868 .as_ref()
869 .map(|mask| mask.count_selected())
870 .unwrap_or(0)
871 }
872}
873
/// One step recorded by the builder methods of `TypeSafeSelectionPipeline`.
#[derive(Debug, Clone)]
enum PipelineStep {
    /// Threshold passed to `VarianceThresholdSelector::new` during `fit`.
    VarianceThreshold(f64),
    /// Threshold passed to `CorrelationSelector::new` during `fit`.
    CorrelationFilter(f64),
    /// Univariate selection step.
    UnivariateSelection {
        /// The const parameter `K` recorded by `add_univariate_selection`.
        k: usize,
        /// Score function recorded by `add_univariate_selection`.
        score_function: UnivariateScoreFunction,
    },
}
884
885pub trait ZeroCostTransform<Input, Output> {
887 fn transform_zero_cost(input: Input) -> Output;
889}
890
891impl<const N: usize> ZeroCostTransform<FeatureIndex<N>, usize> for () {
893 fn transform_zero_cost(input: FeatureIndex<N>) -> usize {
894 input.get()
895 }
896}
897
898impl<const N: usize> ZeroCostTransform<FeatureMask<N>, Vec<bool>> for () {
900 fn transform_zero_cost(input: FeatureMask<N>) -> Vec<bool> {
901 input.as_array().to_vec()
902 }
903}
904
/// Namespace type for checks against an expected feature count `EXPECTED`.
pub struct FeatureCountValidator<const EXPECTED: usize>;

impl<const EXPECTED: usize> FeatureCountValidator<EXPECTED> {
    /// Returns `true` when `ACTUAL` equals `EXPECTED`.
    pub const fn validate<const ACTUAL: usize>() -> bool {
        EXPECTED == ACTUAL
    }

    /// Returns `matrix` unchanged.
    ///
    /// The check is the signature itself: only a matrix with exactly
    /// `EXPECTED` columns type-checks as the argument.
    pub fn validate_matrix<T>(matrix: FeatureMatrix<T, EXPECTED>) -> FeatureMatrix<T, EXPECTED>
    where
        T: Clone + Default,
    {
        matrix
    }
}
922
/// Feature selection expressed through associated types, with the input
/// feature count available as an associated constant.
///
/// The `impl_type_safe_selector!` macro in this file generates
/// implementations of this trait.
pub trait TypeSafeFeatureSelection {
    /// Input type accepted by `select_features_typed`.
    type FeatureMatrix;

    /// Type returned on success by `select_features_typed`.
    type SelectionResult;

    /// Number of input features.
    const INPUT_FEATURES: usize;

    /// Consumes `data` and returns the selection result or an error.
    fn select_features_typed(data: Self::FeatureMatrix) -> Result<Self::SelectionResult>;
}
937
/// Implements `TypeSafeFeatureSelection` for `$selector`, mapping a matrix
/// with `$n_features` columns to one with `$n_selected` columns.
///
/// The generated `select_features_typed` fits a `VarianceThresholdSelector`
/// with threshold `0.0` and fails with `SklearsError::InvalidInput` unless
/// exactly `$n_selected` features pass. `$method` is accepted but not used by
/// the expansion.
///
/// In-crate items are named through `$crate`, so the macro also resolves when
/// invoked from another crate. `Result` and `SklearsError` are still looked up
/// at the call site and must be in scope there.
#[macro_export]
macro_rules! impl_type_safe_selector {
    ($selector:ty, $method:ty, $n_features:expr, $n_selected:expr) => {
        impl $crate::type_safe::TypeSafeFeatureSelection for $selector {
            type FeatureMatrix = $crate::type_safe::FeatureMatrix<f64, $n_features>;
            type SelectionResult = $crate::type_safe::FeatureMatrix<f64, $n_selected>;
            const INPUT_FEATURES: usize = $n_features;

            fn select_features_typed(data: Self::FeatureMatrix) -> Result<Self::SelectionResult> {
                // `$crate` (not `crate`) so the path points at the defining
                // crate rather than at whichever crate invokes the macro.
                let mut selector =
                    $crate::type_safe::VarianceThresholdSelector::<$n_features>::new(0.0);
                let mask = selector.fit(&data)?;

                if mask.count_selected() != $n_selected {
                    return Err(SklearsError::InvalidInput(format!(
                        "Expected {} selected features, got {}. Consider adjusting selection parameters.",
                        $n_selected,
                        mask.count_selected()
                    )));
                }

                data.select_features(&mask)
            }
        }
    };
}
969
/// Computes the binomial coefficient "n choose k".
///
/// Returns `0` when `k > n`. The running product is held in a `u128`, so the
/// intermediate multiplication cannot overflow while the result still fits in
/// a `usize`. When the true value exceeds `usize::MAX`, the function
/// saturates and returns `usize::MAX`.
pub const fn binomial_coefficient(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }

    // C(n, k) == C(n, n - k); use the smaller of the two to shorten the loop.
    let k = if k > n - k { n - k } else { k };
    let limit = usize::MAX as u128;

    let mut result: u128 = 1;
    let mut i = 0;
    while i < k {
        // After this step `result == C(n, i + 1)`; the division is exact.
        result = result * (n - i) as u128 / (i + 1) as u128;
        // With `k <= n / 2` the partial values only grow, so exceeding the
        // limit here means the final value would exceed it as well.
        if result > limit {
            return usize::MAX;
        }
        i += 1;
    }
    result as usize
}
987
/// Returns `true` when `K` is a usable selection count for `N_FEATURES`
/// features, i.e. `1 <= K <= N_FEATURES`.
pub const fn validate_selection_count<const N_FEATURES: usize, const K: usize>() -> bool {
    if K == 0 {
        return false;
    }
    N_FEATURES >= K
}
992
/// A type-level boolean: implementors expose their value as `VALUE`.
pub trait TypeBool {
    /// The boolean this type stands for.
    const VALUE: bool;
}

/// Type-level `true`.
pub struct True;
/// Type-level `false`.
pub struct False;

impl TypeBool for True {
    const VALUE: bool = true;
}

impl TypeBool for False {
    const VALUE: bool = false;
}
1008
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// In-range positions are accepted and out-of-range ones rejected.
    #[test]
    fn test_feature_index() {
        const MAX_FEATURES: usize = 10;

        let in_range = FeatureIndex::<MAX_FEATURES>::new(5).expect("5 is below the bound of 10");
        assert_eq!(5, in_range.get());

        assert_eq!(FeatureIndex::<MAX_FEATURES>::new(15), None);
    }

    /// Counting and listing agree on the number of selected features.
    #[test]
    fn test_feature_mask() {
        const N_FEATURES: usize = 5;

        let flags = [true, false, true, false, true];
        let mask = FeatureMask::<N_FEATURES>::from_array(flags);

        assert_eq!(3, mask.count_selected());
        assert_eq!(3, mask.selected_indices().len());
    }

    /// A 2x3 array is accepted as a three-feature matrix.
    #[test]
    fn test_feature_matrix() -> Result<()> {
        const N_FEATURES: usize = 3;

        let matrix = FeatureMatrix::<f64, N_FEATURES>::new(array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0]
        ])?;

        assert_eq!(3, matrix.n_features());
        assert_eq!(2, matrix.n_samples());

        Ok(())
    }

    /// At least one feature survives a variance threshold of 0.1.
    #[test]
    fn test_variance_threshold_selector() -> Result<()> {
        const N_FEATURES: usize = 3;

        let samples = array![[1.0, 2.0, 3.0], [1.1, 5.0, 3.1], [0.9, 8.0, 2.9]];
        let matrix = FeatureMatrix::<f64, N_FEATURES>::new(samples)?;

        let mut selector = VarianceThresholdSelector::new(0.1);
        let kept = selector.fit(&matrix)?.count_selected();

        assert!(kept > 0);

        Ok(())
    }

    /// Selecting 5 out of 10 features is a valid configuration.
    #[test]
    fn test_compile_time_validation() {
        const N_FEATURES: usize = 10;
        const K: usize = 5;

        let is_valid = validate_selection_count::<N_FEATURES, K>();
        assert!(is_valid);
    }

    /// A variance + correlation pipeline keeps at least one feature.
    #[test]
    fn test_type_safe_pipeline() -> Result<()> {
        const N_FEATURES: usize = 4;

        let samples = array![
            [1.0, 2.0, 3.0, 4.0],
            [1.1, 5.0, 3.1, 4.1],
            [0.9, 8.0, 2.9, 3.9],
            [1.2, 2.1, 3.2, 4.2]
        ];
        let matrix = FeatureMatrix::<f64, N_FEATURES>::new(samples)?;
        let targets = array![0.0, 1.0, 0.0, 1.0];

        let fitted_pipeline = TypeSafeSelectionPipeline::<N_FEATURES>::new()
            .add_variance_threshold(0.01)
            .add_correlation_filter(0.9)
            .fit(&matrix, targets.view())?;

        assert!(fitted_pipeline.n_selected_features() > 0);

        Ok(())
    }
}