sklears_feature_selection/
type_safe.rs

1//! Type-Safe Feature Selection Framework
2//!
3//! This module provides compile-time guarantees for feature selection operations,
4//! using Rust's advanced type system features including phantom types, const generics,
5//! and zero-cost abstractions to ensure correctness and performance.
6
7use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::marker::PhantomData;
10
11type Result<T> = SklResult<T>;
12
/// Zero-sized marker types that classify a feature selection method.
///
/// The markers carry no data. They are intended to be used as type
/// parameters (for example through `PhantomData`) so that the category of a
/// selector is visible in its type.
pub mod selection_types {
    /// Tags a selector as filter-based.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Filter;

    /// Tags a selector as wrapper-based.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Wrapper;

    /// Tags a selector as embedded.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Embedded;

    /// Tags a selector as univariate.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Univariate;

    /// Tags a selector as multivariate.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Multivariate;

    /// Tags a selector as supervised.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Supervised;

    /// Tags a selector as unsupervised.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Unsupervised;

    /// Tags a selector as deterministic.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Deterministic;

    /// Tags a selector as stochastic.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Stochastic;
}
51
/// Zero-sized marker types that describe the lifecycle state of a selector
/// or pipeline.
///
/// They are used as type parameters so that, for instance, an untrained and
/// a trained pipeline are different types.
pub mod data_states {
    /// State marker: not yet fitted.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Untrained;

    /// State marker: fitted.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Trained;

    /// State marker: validated.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Validated;

    /// State marker: optimized.
    #[derive(Clone, Copy, Debug, Default)]
    pub struct Optimized;
}
70
71/// Type-safe feature index with compile-time bounds checking
72#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
73pub struct FeatureIndex<const MAX_FEATURES: usize> {
74    index: usize,
75}
76
77impl<const MAX_FEATURES: usize> FeatureIndex<MAX_FEATURES> {
78    /// Create a new feature index with compile-time bounds checking
79    pub const fn new(index: usize) -> Option<Self> {
80        if index < MAX_FEATURES {
81            Some(Self { index })
82        } else {
83            None
84        }
85    }
86
87    /// Create a new feature index without bounds checking (unsafe)
88    ///
89    /// # Safety
90    /// The caller must ensure that `index < MAX_FEATURES`
91    pub const unsafe fn new_unchecked(index: usize) -> Self {
92        Self { index }
93    }
94
95    /// Get the inner index value
96    pub const fn get(self) -> usize {
97        self.index
98    }
99
100    /// Convert to a runtime feature index
101    pub const fn to_runtime(self) -> RuntimeFeatureIndex {
102        RuntimeFeatureIndex::new(self.index)
103    }
104}
105
/// Feature index whose validity is checked against a feature count supplied
/// at run time (see [`RuntimeFeatureIndex::is_valid`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct RuntimeFeatureIndex {
    index: usize,
}

impl RuntimeFeatureIndex {
    /// Wraps `index`; no range check is performed here.
    pub const fn new(index: usize) -> Self {
        RuntimeFeatureIndex { index }
    }

    /// Returns the wrapped `usize`.
    pub const fn get(self) -> usize {
        self.index
    }

    /// Returns `true` when this index addresses one of `n_features` features,
    /// i.e. when it is strictly smaller than `n_features`.
    pub const fn is_valid(self, n_features: usize) -> bool {
        n_features > self.index
    }
}
128
129/// Type-safe feature mask with const generic size
130#[derive(Debug, Clone)]
131pub struct FeatureMask<const N_FEATURES: usize> {
132    mask: [bool; N_FEATURES],
133}
134
135impl<const N_FEATURES: usize> FeatureMask<N_FEATURES> {
136    /// Create a new feature mask with all features selected
137    pub const fn all_selected() -> Self {
138        Self {
139            mask: [true; N_FEATURES],
140        }
141    }
142
143    /// Create a new feature mask with no features selected
144    pub const fn none_selected() -> Self {
145        Self {
146            mask: [false; N_FEATURES],
147        }
148    }
149
150    /// Create a feature mask from a boolean array
151    pub const fn from_array(mask: [bool; N_FEATURES]) -> Self {
152        Self { mask }
153    }
154
155    /// Create a feature mask from selected indices
156    pub fn from_indices(indices: &[FeatureIndex<N_FEATURES>]) -> Self {
157        let mut mask = [false; N_FEATURES];
158        for &index in indices {
159            mask[index.get()] = true;
160        }
161        Self { mask }
162    }
163
164    /// Get the mask as a boolean array
165    pub const fn as_array(&self) -> &[bool; N_FEATURES] {
166        &self.mask
167    }
168
169    /// Check if a feature is selected
170    pub const fn is_selected(&self, index: FeatureIndex<N_FEATURES>) -> bool {
171        self.mask[index.get()]
172    }
173
174    /// Set a feature as selected or unselected
175    pub fn set(&mut self, index: FeatureIndex<N_FEATURES>, selected: bool) {
176        self.mask[index.get()] = selected;
177    }
178
179    /// Get the number of selected features
180    pub fn count_selected(&self) -> usize {
181        self.mask.iter().filter(|&&x| x).count()
182    }
183
184    /// Get indices of selected features
185    pub fn selected_indices(&self) -> Vec<FeatureIndex<N_FEATURES>> {
186        self.mask
187            .iter()
188            .enumerate()
189            .filter_map(|(i, &selected)| {
190                if selected {
191                    // Safety: i is always < N_FEATURES in this context
192                    Some(unsafe { FeatureIndex::new_unchecked(i) })
193                } else {
194                    None
195                }
196            })
197            .collect()
198    }
199
200    /// Combine with another mask using logical AND
201    pub fn and(&self, other: &Self) -> Self {
202        let mut result = [false; N_FEATURES];
203        for i in 0..N_FEATURES {
204            result[i] = self.mask[i] && other.mask[i];
205        }
206        Self::from_array(result)
207    }
208
209    /// Combine with another mask using logical OR
210    pub fn or(&self, other: &Self) -> Self {
211        let mut result = [false; N_FEATURES];
212        for i in 0..N_FEATURES {
213            result[i] = self.mask[i] || other.mask[i];
214        }
215        Self::from_array(result)
216    }
217
218    /// Invert the mask
219    pub fn not(&self) -> Self {
220        let mut result = [false; N_FEATURES];
221        for i in 0..N_FEATURES {
222            result[i] = !self.mask[i];
223        }
224        Self::from_array(result)
225    }
226}
227
228/// Type-safe feature matrix with compile-time feature count validation
229#[derive(Debug, Clone)]
230pub struct FeatureMatrix<T, const N_FEATURES: usize> {
231    data: Array2<T>,
232    _phantom: PhantomData<[T; N_FEATURES]>,
233}
234
235impl<T, const N_FEATURES: usize> FeatureMatrix<T, N_FEATURES>
236where
237    T: Clone + Default,
238{
239    /// Create a new feature matrix with compile-time feature count validation
240    pub fn new(data: Array2<T>) -> Result<Self> {
241        if data.ncols() == N_FEATURES {
242            Ok(Self {
243                data,
244                _phantom: PhantomData,
245            })
246        } else {
247            Err(SklearsError::InvalidInput(format!(
248                "Expected {} features, got {}",
249                N_FEATURES,
250                data.ncols()
251            )))
252        }
253    }
254
255    /// Create a new feature matrix without validation (unsafe)
256    ///
257    /// # Safety
258    /// The caller must ensure that `data.ncols() == N_FEATURES`
259    pub unsafe fn new_unchecked(data: Array2<T>) -> Self {
260        Self {
261            data,
262            _phantom: PhantomData,
263        }
264    }
265
266    /// Get the number of samples
267    pub fn n_samples(&self) -> usize {
268        self.data.nrows()
269    }
270
271    /// Get the number of features (always N_FEATURES)
272    pub const fn n_features(&self) -> usize {
273        N_FEATURES
274    }
275
276    /// Get a view of the underlying data
277    pub fn view(&self) -> ArrayView2<'_, T> {
278        self.data.view()
279    }
280
281    /// Get a feature column by type-safe index
282    pub fn feature(&self, index: FeatureIndex<N_FEATURES>) -> ArrayView1<'_, T> {
283        self.data.column(index.get())
284    }
285
286    /// Select features using a type-safe mask
287    pub fn select_features<const N_SELECTED: usize>(
288        &self,
289        mask: &FeatureMask<N_FEATURES>,
290    ) -> Result<FeatureMatrix<T, N_SELECTED>> {
291        let selected_indices = mask.selected_indices();
292        if selected_indices.len() != N_SELECTED {
293            return Err(SklearsError::InvalidInput(format!(
294                "Expected {} selected features, got {}",
295                N_SELECTED,
296                selected_indices.len()
297            )));
298        }
299
300        let mut selected_data = Array2::default((self.n_samples(), N_SELECTED));
301        for (new_col, &old_index) in selected_indices.iter().enumerate() {
302            for row in 0..self.n_samples() {
303                selected_data[[row, new_col]] = self.data[[row, old_index.get()]].clone();
304            }
305        }
306
307        Ok(FeatureMatrix {
308            data: selected_data,
309            _phantom: PhantomData,
310        })
311    }
312
313    /// Convert to a dynamic feature matrix
314    pub fn to_dynamic(self) -> DynamicFeatureMatrix<T> {
315        DynamicFeatureMatrix::new(self.data)
316    }
317}
318
319/// Dynamic feature matrix for runtime feature count
320#[derive(Debug, Clone)]
321pub struct DynamicFeatureMatrix<T> {
322    data: Array2<T>,
323}
324
325impl<T> DynamicFeatureMatrix<T> {
326    /// Create a new dynamic feature matrix
327    pub fn new(data: Array2<T>) -> Self {
328        Self { data }
329    }
330
331    /// Get the number of samples
332    pub fn n_samples(&self) -> usize {
333        self.data.nrows()
334    }
335
336    /// Get the number of features
337    pub fn n_features(&self) -> usize {
338        self.data.ncols()
339    }
340
341    /// Get a view of the underlying data
342    pub fn view(&self) -> ArrayView2<'_, T> {
343        self.data.view()
344    }
345
346    /// Get a feature column by runtime index
347    pub fn feature(&self, index: RuntimeFeatureIndex) -> Result<ArrayView1<'_, T>> {
348        if index.is_valid(self.n_features()) {
349            Ok(self.data.column(index.get()))
350        } else {
351            Err(SklearsError::InvalidInput(format!(
352                "Feature index {} out of bounds for {} features",
353                index.get(),
354                self.n_features()
355            )))
356        }
357    }
358
359    /// Convert to a compile-time feature matrix if the size matches
360    pub fn to_static<const N_FEATURES: usize>(self) -> Result<FeatureMatrix<T, N_FEATURES>>
361    where
362        T: Clone + Default,
363    {
364        FeatureMatrix::new(self.data)
365    }
366}
367
368/// Zero-cost abstraction for feature selection algorithms
369pub trait TypeSafeSelector<Method, State = data_states::Untrained> {
370    /// The output state after fitting
371    type FittedState;
372
373    /// The type of selection results
374    type SelectionResult;
375
376    /// Fit the selector on training data
377    fn fit_typed<const N_FEATURES: usize>(
378        self,
379        X: &FeatureMatrix<f64, N_FEATURES>,
380        y: ArrayView1<f64>,
381    ) -> Result<TypeSafeSelectorWrapper<Method, Self::FittedState, N_FEATURES>>;
382}
383
384/// Zero-cost wrapper for type-safe selectors
385#[derive(Debug, Clone)]
386pub struct TypeSafeSelectorWrapper<Method, State, const N_FEATURES: usize> {
387    method_params: MethodParameters,
388    selection_result: Option<FeatureMask<N_FEATURES>>,
389    _phantom: PhantomData<(Method, State)>,
390}
391
392impl<Method, State, const N_FEATURES: usize> TypeSafeSelectorWrapper<Method, State, N_FEATURES> {
393    /// Create a new type-safe selector wrapper
394    pub fn new(method_params: MethodParameters) -> Self {
395        Self {
396            method_params,
397            selection_result: None,
398            _phantom: PhantomData,
399        }
400    }
401
402    /// Get the selection result
403    pub fn selection_mask(&self) -> Option<&FeatureMask<N_FEATURES>> {
404        self.selection_result.as_ref()
405    }
406
407    /// Set the selection result
408    pub fn set_selection(&mut self, mask: FeatureMask<N_FEATURES>) {
409        self.selection_result = Some(mask);
410    }
411}
412
/// Parameter sets for the supported selection algorithms, one variant per
/// algorithm.
#[derive(Debug, Clone)]
pub enum MethodParameters {
    /// Parameters for variance-threshold selection.
    VarianceThreshold {
        /// Variance threshold value.
        threshold: f64,
    },
    /// Parameters for univariate filtering.
    UnivariateFilter {
        /// Number of features to keep.
        k: usize,
        /// Name of the score function.
        score_function: String,
    },
    /// Parameters for recursive elimination.
    RecursiveElimination {
        /// Target number of features.
        n_features: usize,
        /// Elimination step.
        step: f64,
    },
    /// Parameters for Lasso-based selection.
    LassoSelection {
        /// `alpha` parameter.
        alpha: f64,
        /// Iteration limit.
        max_iter: usize,
    },
    /// Parameters for tree-based selection.
    TreeBasedSelection {
        /// Number of estimators.
        n_estimators: usize,
        /// Optional depth limit; `None` when unset.
        max_depth: Option<usize>,
    },
    /// Parameters for correlation filtering.
    CorrelationFilter {
        /// Correlation threshold value.
        threshold: f64,
    },
    /// Parameters for mutual-information selection.
    MutualInfoSelection {
        /// Number of features to keep.
        k: usize,
        /// One flag per feature in `discrete_features`.
        discrete_features: Vec<bool>,
    },
}
448
449/// Compile-time variance threshold selector
450#[derive(Debug, Clone)]
451pub struct VarianceThresholdSelector<const N_FEATURES: usize> {
452    threshold: f64,
453    feature_variances: Option<[f64; N_FEATURES]>,
454}
455
456impl<const N_FEATURES: usize> VarianceThresholdSelector<N_FEATURES> {
457    /// Create a new variance threshold selector
458    pub const fn new(threshold: f64) -> Self {
459        Self {
460            threshold,
461            feature_variances: None,
462        }
463    }
464
465    /// Fit the selector on data
466    pub fn fit(&mut self, X: &FeatureMatrix<f64, N_FEATURES>) -> Result<FeatureMask<N_FEATURES>> {
467        let mut variances = [0.0; N_FEATURES];
468
469        for i in 0..N_FEATURES {
470            // Safety: i is always < N_FEATURES
471            let feature_index = unsafe { FeatureIndex::new_unchecked(i) };
472            let feature_data = X.feature(feature_index);
473            variances[i] = feature_data.var(1.0);
474        }
475
476        self.feature_variances = Some(variances);
477
478        let mut mask = [false; N_FEATURES];
479        for i in 0..N_FEATURES {
480            mask[i] = variances[i] > self.threshold;
481        }
482
483        Ok(FeatureMask::from_array(mask))
484    }
485
486    /// Transform data using the fitted selector
487    pub fn transform<const N_SELECTED: usize>(
488        &self,
489        X: &FeatureMatrix<f64, N_FEATURES>,
490        mask: &FeatureMask<N_FEATURES>,
491    ) -> Result<FeatureMatrix<f64, N_SELECTED>> {
492        X.select_features(mask)
493    }
494
495    /// Get feature variances (if fitted)
496    pub const fn feature_variances(&self) -> Option<&[f64; N_FEATURES]> {
497        self.feature_variances.as_ref()
498    }
499}
500
501/// Compile-time univariate feature selector
502#[derive(Debug, Clone)]
503pub struct UnivariateSelector<const N_FEATURES: usize, const K: usize> {
504    score_function: UnivariateScoreFunction,
505    feature_scores: Option<[f64; N_FEATURES]>,
506}
507
508impl<const N_FEATURES: usize, const K: usize> UnivariateSelector<N_FEATURES, K> {
509    /// Create a new univariate selector
510    ///
511    /// # Compile-time checks
512    /// - K must be <= N_FEATURES
513    pub const fn new(score_function: UnivariateScoreFunction) -> Option<Self> {
514        if K <= N_FEATURES {
515            Some(Self {
516                score_function,
517                feature_scores: None,
518            })
519        } else {
520            None
521        }
522    }
523
524    /// Fit the selector on data
525    pub fn fit(
526        &mut self,
527        X: &FeatureMatrix<f64, N_FEATURES>,
528        y: ArrayView1<f64>,
529    ) -> Result<FeatureMask<N_FEATURES>> {
530        let mut scores = [0.0; N_FEATURES];
531
532        for i in 0..N_FEATURES {
533            // Safety: i is always < N_FEATURES
534            let feature_index = unsafe { FeatureIndex::new_unchecked(i) };
535            let feature_data = X.feature(feature_index);
536            scores[i] = match self.score_function {
537                UnivariateScoreFunction::Correlation => self.compute_correlation(feature_data, y),
538                UnivariateScoreFunction::MutualInfo => self.compute_mutual_info(feature_data, y),
539                UnivariateScoreFunction::Chi2 => self.compute_chi2_score(feature_data, y),
540                UnivariateScoreFunction::FStatistic => self.compute_f_statistic(feature_data, y),
541            };
542        }
543
544        self.feature_scores = Some(scores);
545
546        // Select top K features
547        let mut indexed_scores: Vec<(usize, f64)> = Vec::with_capacity(N_FEATURES);
548        for i in 0..N_FEATURES {
549            indexed_scores.push((i, scores[i]));
550        }
551
552        // Sort by score (descending)
553        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
554
555        let mut mask = [false; N_FEATURES];
556        for i in 0..K {
557            if let Some(&(feature_idx, _)) = indexed_scores.get(i) {
558                mask[feature_idx] = true;
559            }
560        }
561
562        Ok(FeatureMask::from_array(mask))
563    }
564
565    /// Transform data using the fitted selector
566    pub fn transform(
567        &self,
568        X: &FeatureMatrix<f64, N_FEATURES>,
569        mask: &FeatureMask<N_FEATURES>,
570    ) -> Result<FeatureMatrix<f64, K>> {
571        X.select_features(mask)
572    }
573
574    /// Get feature scores (if fitted)
575    pub const fn feature_scores(&self) -> Option<&[f64; N_FEATURES]> {
576        self.feature_scores.as_ref()
577    }
578
579    // Helper methods for score computation
580    fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
581        let n = x.len() as f64;
582        if n < 2.0 {
583            return 0.0;
584        }
585
586        let mean_x = x.mean().unwrap_or(0.0);
587        let mean_y = y.mean().unwrap_or(0.0);
588
589        let mut sum_xy = 0.0;
590        let mut sum_x2 = 0.0;
591        let mut sum_y2 = 0.0;
592
593        for i in 0..x.len() {
594            let dx = x[i] - mean_x;
595            let dy = y[i] - mean_y;
596            sum_xy += dx * dy;
597            sum_x2 += dx * dx;
598            sum_y2 += dy * dy;
599        }
600
601        let denom = (sum_x2 * sum_y2).sqrt();
602        if denom < 1e-10 {
603            0.0
604        } else {
605            (sum_xy / denom).abs()
606        }
607    }
608
609    fn compute_mutual_info(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
610        // Simplified mutual information computation
611        // In a real implementation, this would use proper MI estimation algorithms
612        self.compute_correlation(x, y)
613    }
614
615    fn compute_chi2_score(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
616        // Simplified chi-square test
617        // In a real implementation, this would compute proper chi-square statistics
618        self.compute_correlation(x, y)
619    }
620
621    fn compute_f_statistic(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
622        // Simplified F-statistic computation
623        // In a real implementation, this would compute ANOVA F-statistic
624        self.compute_correlation(x, y)
625    }
626}
627
/// Scoring criteria that a univariate selector can be configured with.
#[derive(Clone, Copy, Debug)]
pub enum UnivariateScoreFunction {
    /// Score by correlation with the target.
    Correlation,
    /// Score by mutual information with the target.
    MutualInfo,
    /// Score by the chi-square statistic.
    Chi2,
    /// Score by the F-statistic.
    FStatistic,
}
640
641/// Compile-time correlation-based feature selector
642#[derive(Debug, Clone)]
643pub struct CorrelationSelector<const N_FEATURES: usize> {
644    threshold: f64,
645    correlation_matrix: Option<[[f64; N_FEATURES]; N_FEATURES]>,
646}
647
648impl<const N_FEATURES: usize> CorrelationSelector<N_FEATURES> {
649    /// Create a new correlation-based selector
650    pub const fn new(threshold: f64) -> Self {
651        Self {
652            threshold,
653            correlation_matrix: None,
654        }
655    }
656
657    /// Fit the selector on data
658    pub fn fit(&mut self, X: &FeatureMatrix<f64, N_FEATURES>) -> Result<FeatureMask<N_FEATURES>> {
659        let mut corr_matrix = [[0.0; N_FEATURES]; N_FEATURES];
660
661        // Compute correlation matrix
662        for i in 0..N_FEATURES {
663            for j in 0..N_FEATURES {
664                if i == j {
665                    corr_matrix[i][j] = 1.0;
666                } else {
667                    // Safety: i and j are always < N_FEATURES
668                    let feature_i = unsafe { FeatureIndex::new_unchecked(i) };
669                    let feature_j = unsafe { FeatureIndex::new_unchecked(j) };
670                    let data_i = X.feature(feature_i);
671                    let data_j = X.feature(feature_j);
672                    corr_matrix[i][j] = self.compute_correlation(data_i, data_j);
673                }
674            }
675        }
676
677        self.correlation_matrix = Some(corr_matrix);
678
679        // Remove highly correlated features
680        let mut mask = [true; N_FEATURES];
681        for i in 0..N_FEATURES {
682            for j in (i + 1)..N_FEATURES {
683                if corr_matrix[i][j].abs() > self.threshold && mask[i] && mask[j] {
684                    // Keep the feature with higher variance
685                    // Safety: i and j are always < N_FEATURES
686                    let feature_i = unsafe { FeatureIndex::new_unchecked(i) };
687                    let feature_j = unsafe { FeatureIndex::new_unchecked(j) };
688                    let var_i = X.feature(feature_i).var(1.0);
689                    let var_j = X.feature(feature_j).var(1.0);
690                    if var_i < var_j {
691                        mask[i] = false;
692                    } else {
693                        mask[j] = false;
694                    }
695                }
696            }
697        }
698
699        Ok(FeatureMask::from_array(mask))
700    }
701
702    /// Transform data using the fitted selector
703    pub fn transform<const N_SELECTED: usize>(
704        &self,
705        X: &FeatureMatrix<f64, N_FEATURES>,
706        mask: &FeatureMask<N_FEATURES>,
707    ) -> Result<FeatureMatrix<f64, N_SELECTED>> {
708        X.select_features(mask)
709    }
710
711    /// Get correlation matrix (if fitted)
712    pub const fn correlation_matrix(&self) -> Option<&[[f64; N_FEATURES]; N_FEATURES]> {
713        self.correlation_matrix.as_ref()
714    }
715
716    fn compute_correlation(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
717        let n = x.len() as f64;
718        if n < 2.0 {
719            return 0.0;
720        }
721
722        let mean_x = x.mean().unwrap_or(0.0);
723        let mean_y = y.mean().unwrap_or(0.0);
724
725        let mut sum_xy = 0.0;
726        let mut sum_x2 = 0.0;
727        let mut sum_y2 = 0.0;
728
729        for i in 0..x.len() {
730            let dx = x[i] - mean_x;
731            let dy = y[i] - mean_y;
732            sum_xy += dx * dy;
733            sum_x2 += dx * dx;
734            sum_y2 += dy * dy;
735        }
736
737        let denom = (sum_x2 * sum_y2).sqrt();
738        if denom < 1e-10 {
739            0.0
740        } else {
741            sum_xy / denom
742        }
743    }
744}
745
746/// Type-safe feature selection pipeline with compile-time guarantees
747#[derive(Debug, Clone)]
748pub struct TypeSafeSelectionPipeline<const N_FEATURES: usize, State = data_states::Untrained> {
749    steps: Vec<PipelineStep>,
750    current_mask: Option<FeatureMask<N_FEATURES>>,
751    _phantom: PhantomData<State>,
752}
753
754impl<const N_FEATURES: usize> Default
755    for TypeSafeSelectionPipeline<N_FEATURES, data_states::Untrained>
756{
757    fn default() -> Self {
758        Self::new()
759    }
760}
761
762impl<const N_FEATURES: usize> TypeSafeSelectionPipeline<N_FEATURES, data_states::Untrained> {
763    /// Create a new type-safe selection pipeline
764    pub const fn new() -> Self {
765        Self {
766            steps: Vec::new(),
767            current_mask: None,
768            _phantom: PhantomData,
769        }
770    }
771
772    /// Add a variance threshold step
773    pub fn add_variance_threshold(mut self, threshold: f64) -> Self {
774        self.steps.push(PipelineStep::VarianceThreshold(threshold));
775        self
776    }
777
778    /// Add a correlation filter step
779    pub fn add_correlation_filter(mut self, threshold: f64) -> Self {
780        self.steps.push(PipelineStep::CorrelationFilter(threshold));
781        self
782    }
783
784    /// Add a univariate selection step
785    pub fn add_univariate_selection<const K: usize>(
786        mut self,
787        score_function: UnivariateScoreFunction,
788    ) -> Self {
789        self.steps.push(PipelineStep::UnivariateSelection {
790            k: K,
791            score_function,
792        });
793        self
794    }
795
796    /// Fit the pipeline on training data
797    pub fn fit(
798        self,
799        X: &FeatureMatrix<f64, N_FEATURES>,
800        y: ArrayView1<f64>,
801    ) -> Result<TypeSafeSelectionPipeline<N_FEATURES, data_states::Trained>> {
802        let mut current_mask = FeatureMask::all_selected();
803
804        for step in &self.steps {
805            let step_mask = match step {
806                PipelineStep::VarianceThreshold(threshold) => {
807                    let mut selector = VarianceThresholdSelector::new(*threshold);
808                    selector.fit(X)?
809                }
810                PipelineStep::CorrelationFilter(threshold) => {
811                    let mut selector = CorrelationSelector::new(*threshold);
812                    selector.fit(X)?
813                }
814                PipelineStep::UnivariateSelection {
815                    k: _,
816                    score_function,
817                } => {
818                    // This is a simplification - in practice we'd need to handle different K values
819                    // For demonstration, we'll use a fixed K
820                    const DEFAULT_K: usize = 10;
821                    if DEFAULT_K <= N_FEATURES {
822                        let mut selector =
823                            UnivariateSelector::<N_FEATURES, DEFAULT_K>::new(*score_function)
824                                .ok_or_else(|| {
825                                    SklearsError::InvalidInput(
826                                        "Invalid K for univariate selection".to_string(),
827                                    )
828                                })?;
829                        selector.fit(X, y)?
830                    } else {
831                        FeatureMask::all_selected()
832                    }
833                }
834            };
835
836            current_mask = current_mask.and(&step_mask);
837        }
838
839        Ok(TypeSafeSelectionPipeline {
840            steps: self.steps,
841            current_mask: Some(current_mask),
842            _phantom: PhantomData,
843        })
844    }
845}
846
847impl<const N_FEATURES: usize> TypeSafeSelectionPipeline<N_FEATURES, data_states::Trained> {
848    /// Transform data using the fitted pipeline
849    pub fn transform<const N_SELECTED: usize>(
850        &self,
851        X: &FeatureMatrix<f64, N_FEATURES>,
852    ) -> Result<FeatureMatrix<f64, N_SELECTED>> {
853        if let Some(ref mask) = self.current_mask {
854            X.select_features(mask)
855        } else {
856            Err(SklearsError::FitError("Pipeline not fitted".to_string()))
857        }
858    }
859
860    /// Get the feature selection mask
861    pub fn selection_mask(&self) -> Option<&FeatureMask<N_FEATURES>> {
862        self.current_mask.as_ref()
863    }
864
865    /// Get the number of selected features
866    pub fn n_selected_features(&self) -> usize {
867        self.current_mask
868            .as_ref()
869            .map(|mask| mask.count_selected())
870            .unwrap_or(0)
871    }
872}
873
874/// Pipeline step enumeration
875#[derive(Debug, Clone)]
876enum PipelineStep {
877    VarianceThreshold(f64),
878    CorrelationFilter(f64),
879    UnivariateSelection {
880        k: usize,
881        score_function: UnivariateScoreFunction,
882    },
883}
884
/// Conversion from `Input` to `Output` exposed as an associated function
/// (no receiver), so it can be implemented on marker types such as `()`.
pub trait ZeroCostTransform<Input, Output> {
    /// Converts `input` into the `Output` representation.
    fn transform_zero_cost(input: Input) -> Output;
}
890
891/// Zero-cost feature index conversion
892impl<const N: usize> ZeroCostTransform<FeatureIndex<N>, usize> for () {
893    fn transform_zero_cost(input: FeatureIndex<N>) -> usize {
894        input.get()
895    }
896}
897
898/// Zero-cost feature mask conversion
899impl<const N: usize> ZeroCostTransform<FeatureMask<N>, Vec<bool>> for () {
900    fn transform_zero_cost(input: FeatureMask<N>) -> Vec<bool> {
901        input.as_array().to_vec()
902    }
903}
904
905/// Compile-time feature count validator
906pub struct FeatureCountValidator<const EXPECTED: usize>;
907
908impl<const EXPECTED: usize> FeatureCountValidator<EXPECTED> {
909    /// Validate feature count at compile time
910    pub const fn validate<const ACTUAL: usize>() -> bool {
911        EXPECTED == ACTUAL
912    }
913
914    /// Validate and convert feature matrix type
915    pub fn validate_matrix<T>(matrix: FeatureMatrix<T, EXPECTED>) -> FeatureMatrix<T, EXPECTED>
916    where
917        T: Clone + Default,
918    {
919        matrix
920    }
921}
922
923/// Type-safe feature selection trait bounds
924pub trait TypeSafeFeatureSelection {
925    /// The feature matrix type
926    type FeatureMatrix;
927
928    /// The selection result type
929    type SelectionResult;
930
931    /// The number of input features (compile-time constant)
932    const INPUT_FEATURES: usize;
933
934    /// Perform feature selection with compile-time guarantees
935    fn select_features_typed(data: Self::FeatureMatrix) -> Result<Self::SelectionResult>;
936}
937
/// Implements `TypeSafeFeatureSelection` for a selector type.
///
/// Arguments, in order:
/// - `$selector`: the type receiving the implementation;
/// - `$method`: a method marker type (accepted for the call signature but
///   not used by the generated code);
/// - `$n_features`: number of input features;
/// - `$n_selected`: number of features the selection must yield.
///
/// The generated `select_features_typed` fits a `VarianceThresholdSelector`
/// with threshold `0.0` and fails with `SklearsError::InvalidInput` unless
/// exactly `$n_selected` features pass it.
///
/// `TypeSafeFeatureSelection`, `FeatureMatrix`, `Result` and `SklearsError`
/// must be in scope at the invocation site.
#[macro_export]
macro_rules! impl_type_safe_selector {
    ($selector:ty, $method:ty, $n_features:expr, $n_selected:expr) => {
        impl TypeSafeFeatureSelection for $selector {
            type FeatureMatrix = FeatureMatrix<f64, $n_features>;
            type SelectionResult = FeatureMatrix<f64, $n_selected>;
            const INPUT_FEATURES: usize = $n_features;

            fn select_features_typed(data: Self::FeatureMatrix) -> Result<Self::SelectionResult> {
                // Default implementation using variance threshold
                // This can be overridden by implementing the trait directly
                //
                // `$crate` (not `crate`) so the path resolves to the crate
                // defining this macro even when it is invoked from another
                // crate.
                use $crate::type_safe::VarianceThresholdSelector;

                let mut selector = VarianceThresholdSelector::<$n_features>::new(0.0);
                let mask = selector.fit(&data)?;

                // Verify we have the expected number of selected features
                let n_selected = mask.count_selected();
                if n_selected != $n_selected {
                    return Err(SklearsError::InvalidInput(format!(
                        "Expected {} selected features, got {}. Consider adjusting selection parameters.",
                        $n_selected,
                        n_selected
                    )));
                }

                data.select_features(&mask)
            }
        }
    };
}
969
/// Compute the binomial coefficient `C(n, k)`; usable in constant contexts.
///
/// Returns `0` when `k > n` and `1` when `k == 0` or `k == n`.
///
/// The running product is updated without ever forming `result * (n - i)`
/// directly, so no intermediate value exceeds the final result: any `C(n, k)`
/// that fits in `usize` is computed exactly. If the true result does not fit in
/// `usize`, the arithmetic overflows as usual for the build profile.
pub const fn binomial_coefficient(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }

    // C(n, k) == C(n, n - k): iterate over the smaller of the two.
    let k = if k > n - k { n - k } else { k };

    let mut result: usize = 1;
    let mut i = 0;
    while i < k {
        let numer = n - i;
        let denom = i + 1;
        // `result * numer` is always divisible by `denom` (the quotient is
        // C(n, i + 1)). Splitting `result` into quotient and remainder keeps
        // every intermediate at or below that quotient instead of `denom`
        // times larger, which is where the naive product overflowed.
        result = (result / denom) * numer + (result % denom) * numer / denom;
        i += 1;
    }
    result
}
987
/// Check that selecting `K` features out of `N_FEATURES` is a valid request.
///
/// Returns `true` exactly when `K` is non-zero and does not exceed
/// `N_FEATURES`. The function is `const`, so the result can be computed in
/// constant contexts.
pub const fn validate_selection_count<const N_FEATURES: usize, const K: usize>() -> bool {
    // Selecting zero features is never valid.
    if K == 0 {
        return false;
    }
    K <= N_FEATURES
}
992
/// Type-level boolean for compile-time feature validation.
///
/// Implementors expose the boolean they stand for through `VALUE`.
pub trait TypeBool {
    /// The boolean value represented by the implementing type.
    const VALUE: bool;
}

/// Type-level `true`.
#[derive(Debug, Clone, Copy, Default)]
pub struct True;

/// Type-level `false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct False;

impl TypeBool for True {
    const VALUE: bool = true;
}

impl TypeBool for False {
    const VALUE: bool = false;
}
1008
// Note: the type-level programming sketches below are commented out because they require
// unstable Rust features. They can be enabled once const generic expressions and inherent
// associated types are stabilized.
1011
1012// /// Compile-time assertion for feature selection validity
1013// pub type Assert<T> = <T as TypeBool>::Value;
1014
1015// pub trait TypeBoolTrait {
1016//     type Value: TypeBool;
1017// }
1018
1019// /// Feature selection validity checker
1020// pub struct FeatureSelectionValid<const N_FEATURES: usize, const K: usize>;
1021
1022// /// Conditional type selection (requires unstable features)
1023// pub type If<const CONDITION: bool, T, F> = IfImpl<{ CONDITION }, T, F>::Type;
1024
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Read a type-level boolean through a generic parameter.
    fn type_bool_value<B: TypeBool>() -> bool {
        B::VALUE
    }

    #[test]
    fn test_feature_index() {
        const MAX_FEATURES: usize = 10;

        // Valid index
        let valid_index = FeatureIndex::<MAX_FEATURES>::new(5).unwrap();
        assert_eq!(valid_index.get(), 5);

        // Invalid index
        assert!(FeatureIndex::<MAX_FEATURES>::new(15).is_none());
    }

    #[test]
    fn test_feature_mask() {
        const N_FEATURES: usize = 5;

        let mask = FeatureMask::<N_FEATURES>::from_array([true, false, true, false, true]);
        assert_eq!(mask.count_selected(), 3);

        let indices = mask.selected_indices();
        assert_eq!(indices.len(), 3);
    }

    #[test]
    fn test_feature_matrix() -> Result<()> {
        const N_FEATURES: usize = 3;

        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let matrix = FeatureMatrix::<f64, N_FEATURES>::new(data)?;

        assert_eq!(matrix.n_features(), 3);
        assert_eq!(matrix.n_samples(), 2);

        Ok(())
    }

    #[test]
    fn test_variance_threshold_selector() -> Result<()> {
        const N_FEATURES: usize = 3;

        let data = array![[1.0, 2.0, 3.0], [1.1, 5.0, 3.1], [0.9, 8.0, 2.9]];
        let matrix = FeatureMatrix::<f64, N_FEATURES>::new(data)?;

        let mut selector = VarianceThresholdSelector::new(0.1);
        let mask = selector.fit(&matrix)?;

        // At least one feature is expected to pass the 0.1 threshold
        assert!(mask.count_selected() > 0);

        Ok(())
    }

    #[test]
    fn test_compile_time_validation() {
        const N_FEATURES: usize = 10;
        const K: usize = 5;

        // Enforced by the compiler: the build fails if this assertion is false.
        const _: () = assert!(validate_selection_count::<N_FEATURES, K>());

        // The same check evaluated at run time.
        assert!(validate_selection_count::<N_FEATURES, K>());
        assert!(validate_selection_count::<N_FEATURES, N_FEATURES>());

        // An invalid count does not stop compilation by itself; the function
        // simply returns `false`.
        assert!(!validate_selection_count::<5, 10>());
        assert!(!validate_selection_count::<5, 0>());
    }

    #[test]
    fn test_feature_count_validator() {
        assert!(FeatureCountValidator::<3>::validate::<3>());
        assert!(!FeatureCountValidator::<3>::validate::<4>());
    }

    #[test]
    fn test_binomial_coefficient() {
        assert_eq!(binomial_coefficient(5, 0), 1);
        assert_eq!(binomial_coefficient(5, 5), 1);
        assert_eq!(binomial_coefficient(5, 2), 10);
        assert_eq!(binomial_coefficient(10, 3), 120);
        assert_eq!(binomial_coefficient(10, 7), 120);
        assert_eq!(binomial_coefficient(3, 5), 0);
    }

    #[test]
    fn test_type_bool() {
        assert!(type_bool_value::<True>());
        assert!(!type_bool_value::<False>());
    }

    #[test]
    fn test_type_safe_pipeline() -> Result<()> {
        const N_FEATURES: usize = 4;

        let data = array![
            [1.0, 2.0, 3.0, 4.0],
            [1.1, 5.0, 3.1, 4.1],
            [0.9, 8.0, 2.9, 3.9],
            [1.2, 2.1, 3.2, 4.2]
        ];
        let matrix = FeatureMatrix::<f64, N_FEATURES>::new(data)?;
        let y = array![0.0, 1.0, 0.0, 1.0];

        let pipeline = TypeSafeSelectionPipeline::<N_FEATURES>::new()
            .add_variance_threshold(0.01)
            .add_correlation_filter(0.9);

        let fitted_pipeline = pipeline.fit(&matrix, y.view())?;
        assert!(fitted_pipeline.n_selected_features() > 0);

        Ok(())
    }
}